使用python/numpy實(shí)現(xiàn)im2col的學(xué)習(xí)心得
- 背景
- 書(shū)上的程序
- 分析
- 首先是:
- 其次:
- 寫(xiě)在最后
背景
最近在看深度學(xué)習(xí)的東西。使用的參考書(shū)是《深度學(xué)習(xí)入門(mén)——基于python的理論與實(shí)現(xiàn)》。在看到7.4時(shí),里面引入了一個(gè)im2col的函數(shù),從而方便講不斷循環(huán)進(jìn)行地相乘相加操作變成矩陣的運(yùn)算,通過(guò)空間資源換取時(shí)間效率。
為什么要這么操作和操作以后col矩陣的樣子比較好理解。由于對(duì)python和numpy不太熟悉,理解書(shū)上給出的程序?qū)崿F(xiàn)想了很久。終于有點(diǎn)感覺(jué)了,記錄下來(lái)。
書(shū)上的程序
def
im2col
(
input_data
,
filter_h
,
filter_w
,
stride
=
1
,
pad
=
0
)
:
"""
Parameters
----------
input_data : 由(數(shù)據(jù)量, 通道, 高, 長(zhǎng))的4維數(shù)組構(gòu)成的輸入數(shù)據(jù)
filter_h : 濾波器的高
filter_w : 濾波器的長(zhǎng)
stride : 步幅
pad : 填充
Returns
-------
col : 2維數(shù)組
"""
N
,
C
,
H
,
W
=
input_data
.
shape
out_h
=
(
H
+
2
*
pad
-
filter_h
)
//
stride
+
1
out_w
=
(
W
+
2
*
pad
-
filter_w
)
//
stride
+
1
img
=
np
.
pad
(
input_data
,
[
(
0
,
0
)
,
(
0
,
0
)
,
(
pad
,
pad
)
,
(
pad
,
pad
)
]
,
'constant'
)
col
=
np
.
zeros
(
(
N
,
C
,
filter_h
,
filter_w
,
out_h
,
out_w
)
)
for
y
in
range
(
filter_h
)
:
y_max
=
y
+
stride
*
out_h
for
x
in
range
(
filter_w
)
:
x_max
=
x
+
stride
*
out_w
col
[
:
,
:
,
y
,
x
,
:
,
:
]
=
img
[
:
,
:
,
y
:
y_max
:
stride
,
x
:
x_max
:
stride
]
col
=
col
.
transpose
(
0
,
4
,
5
,
1
,
2
,
3
)
.
reshape
(
N
*
out_h
*
out_w
,
-
1
)
return
col
分析
首先只考慮一個(gè)數(shù)據(jù),即此時(shí) N = 1 N=1 N = 1 。并且假設(shè)數(shù)據(jù)只有一層,比如灰度圖,即 C = 1 C=1 C = 1 。假設(shè)數(shù)據(jù)的高和長(zhǎng)分別為4,4。即 H = 4 H=4 H = 4 , W = 4 W=4 W = 4 。濾波器的長(zhǎng)和高分別為2,2。即 f i l t e r _ h = 2 filter\_h=2 f i l t e r _ h = 2 , f i l t e r _ w = 2 filter\_w=2 f i l t e r _ w = 2 。更進(jìn)一步地,將Pad簡(jiǎn)化為0。
此時(shí),img就是一個(gè)
4 ? 4 4*4
4
?
4
的矩陣,假設(shè)如下:
濾波器是
2 ? 2 2*2
2
?
2
的矩陣,假設(shè)為
因此,卷積層輸出是
3 ? 3 3*3
3
?
3
的矩陣。
有了這些預(yù)備設(shè)定,就可以開(kāi)始理解程序了。我們重點(diǎn)關(guān)注兩句話(huà)。
首先是:
col
[
:
,
:
,
y
,
x
,
:
,
:
]
=
img
[
:
,
:
,
y
:
y_max
:
stride
,
x
:
x_max
:
stride
]
y和x分別代表濾波器的尺寸,由于設(shè)定
N = 1 N=1
N
=
1
、
C = 1 C=1
C
=
1
。因此可以先只看后面四個(gè)維度。那么
( y , x , : , : ) (y,x,:,:)
(
y
,
x
,
:
,
:
)
意味著矩陣前兩維和濾波器尺寸一致,即
2 ? 2 2*2
2
?
2
,后面的兩個(gè)冒號(hào),就代表了在卷積運(yùn)算(濾波)時(shí),第y行第x列的濾波器參數(shù),需要和img中運(yùn)算的數(shù)的矩陣。
解釋一下:當(dāng)y=0,x=0時(shí),對(duì)應(yīng)的a的位置,如圖
此時(shí),完成整個(gè)卷積運(yùn)算的時(shí)候,a分別需要做(3*3)=9次的乘法,每次做乘法是對(duì)應(yīng)img中的數(shù)如下:
第一次:(綠色表示當(dāng)前卷積時(shí),每個(gè)濾波器參數(shù)對(duì)應(yīng)的位置,紅色表示a對(duì)應(yīng)的位置)
第二次:
以此類(lèi)推……
所以
( 0 , 0 , : , : ) (0,0,:,:)
(
0
,
0
,
:
,
:
)
中存的數(shù)如下:(黃色標(biāo)注的位置),對(duì)應(yīng)濾波器參數(shù)a所以進(jìn)行運(yùn)算的范圍。
所以:
img
[
:
,
:
,
y
:
y_max
:
stride
,
x
:
x_max
:
stride
]
中,y和x分別表示filter中第幾行,第幾列。然后每次移動(dòng)stride,直到走完img中所有的位置,抵達(dá)y_max和x_max。
這句話(huà),以及這個(gè)for循環(huán)的作用就解釋完了。
其次:
col
=
col
.
transpose
(
0
,
4
,
5
,
1
,
2
,
3
)
.
reshape
(
N
*
out_h
*
out_w
,
-
1
)
這句話(huà)目的是把矩陣重新排列,最后呈現(xiàn)出適合進(jìn)行矩陣運(yùn)算來(lái)代替循環(huán)的形式。
所以,這個(gè)矩陣一定是
N ? o u t _ h ? o u t _ w N*out\_h*out\_w
N
?
o
u
t
_
h
?
o
u
t
_
w
行。這里就是
3 ? 3 = 9 3*3=9
3
?
3
=
9
行。有多少列呢,肯定是濾波器系數(shù)的個(gè)數(shù),即
2 ? 2 = 4 2*2=4
2
?
2
=
4
列。
至于transpose函數(shù)中的設(shè)置,主要是為了配合后面的reshape函數(shù)的參數(shù)。
多說(shuō)一句,我覺(jué)得這里transpose不要老是想著轉(zhuǎn)置,我開(kāi)始也這么想,這么多維度,就轉(zhuǎn)不過(guò)來(lái)彎了。
我覺(jué)得,其實(shí)transpose就是決定一個(gè)新的取數(shù)順序,依次取出來(lái)就可以,然后能夠和原來(lái)對(duì)應(yīng)上,就沒(méi)問(wèn)題了。比如 a是一個(gè)三維的東西。然后b = a.transpose(1,2,0)。也就是說(shuō)
a [ y ] [ z ] [ x ] = b [ x ] [ y ] [ z ] a[y][z][x] = b[x][y][z]
a
[
y
]
[
z
]
[
x
]
=
b
[
x
]
[
y
]
[
z
]
transpose第一個(gè)參數(shù),0,表示第0維,也就是transpose以后,第0維不變,說(shuō)明即便展開(kāi),輸入的img也是按順序一個(gè)一個(gè)處理完的。
第2和3的參數(shù),之所以放 o u t _ h 和 o u t _ w out\_h和out\_w o u t _ h 和 o u t _ w 的大小,得明白reshape的操作方法。如果沒(méi)有指定order參數(shù),并且是默認(rèn)按照C的存儲(chǔ)格式(這里不理解可以看看reshape的參數(shù)有哪些),它是把矩陣按照從第0維開(kāi)始,依次全部排列開(kāi),然后在按需求重組。所以這里,要按照 o u t _ h 和 o u t _ w out\_h和out\_w o u t _ h 和 o u t _ w 優(yōu)先順序排列開(kāi),然后再使col總共就 N ? o u t _ h ? o u t _ w N*out\_h*out\_w N ? o u t _ h ? o u t _ w 行,那么reshpe函數(shù)會(huì)使每行中,就存儲(chǔ)一次卷積所需要所有值,即 C ? f i l t e r h ? f i l t e r w C*filter_h*filter_w C ? f i l t e r h ? ? f i l t e r w ? 列。
后面三個(gè)參數(shù)保證順序不變就行,方便和濾波器參數(shù)位置一一對(duì)應(yīng)。
以上,總結(jié)成一句話(huà):其實(shí)就是準(zhǔn)確找到濾波器每個(gè)參數(shù)對(duì)應(yīng)需要相乘的所有值,然后再變換一下矩陣的行狀,就可以了。
寫(xiě)在最后
由于本人水平有限,這一點(diǎn)代碼都想了一下午加一晚上才明白。還得繼續(xù)努力了。加油!雖然整理出來(lái)了,感覺(jué)有些東西不太好表述清楚,大家有什么問(wèn)題可以留言,多多交流,互相學(xué)習(xí)。
更多文章、技術(shù)交流、商務(wù)合作、聯(lián)系博主
微信掃碼或搜索:z360901061

微信掃一掃加我為好友
QQ號(hào)聯(lián)系: 360901061
您的支持是博主寫(xiě)作最大的動(dòng)力,如果您喜歡我的文章,感覺(jué)我的文章對(duì)您有幫助,請(qǐng)用微信掃描下面二維碼支持博主2元、5元、10元、20元等您想捐的金額吧,狠狠點(diǎn)擊下面給點(diǎn)支持吧,站長(zhǎng)非常感激您!手機(jī)微信長(zhǎng)按不能支付解決辦法:請(qǐng)將微信支付二維碼保存到相冊(cè),切換到微信,然后點(diǎn)擊微信右上角掃一掃功能,選擇支付二維碼完成支付。
【本文對(duì)您有幫助就好】元
