網(wǎng)絡(luò)結(jié)構(gòu)定義的差異:
在Python中(network3.py),網(wǎng)絡(luò)定義時(shí),不但定義了結(jié)構(gòu)參數(shù) layers,還定義了對(duì)應(yīng)的 mini_batch_size。也就是說(shuō)在 network3.py中定義的網(wǎng)絡(luò),是與mini_batch_size有相關(guān)性的。如果計(jì)算過(guò)程中要進(jìn)行 mini_batch_size的調(diào)整,直接更改 mini_batch_size然后重新計(jì)算是不可行的。
因此,需要對(duì)已有的網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行重新生成。
#網(wǎng)絡(luò)結(jié)構(gòu)定義
class Network(object):
def init (self, layers, mini_batch_size):
self.layers = layers
self.mini_batch_size = mini_batch_size
self.params = [param for layer in self.layers for param in layer.params]
self.x = T.matrix(“x”)
self.y = T.ivector(“y”)
init_layer = self.layers[0]
init_layer.set_inpt(self.x, self.x, self.mini_batch_size)
for j in range(1, len(self.layers)):
prev_layer, layer = self.layers[j-1], self.layers[j]
layer.set_inpt(
prev_layer.output, prev_layer.output_dropout, self.mini_batch_size)
self.output = self.layers[-1].output
self.output_dropout = self.layers[-1].output_dropout
根據(jù)上述的 網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)定義的代碼,可以推斷出,如果想更改 mini_batch_size的大小,應(yīng)該需要把原來(lái)(已有)的網(wǎng)絡(luò)參數(shù) net.layers,net.params重新賦值到一個(gè)網(wǎng)絡(luò)中去。
載入已有的數(shù)據(jù)
step 1:
查找網(wǎng)絡(luò)的層級(jí)結(jié)構(gòu),
net.layers
[, ]
step 2:
查找各層的輸入與輸出參數(shù) n_in, n_out的數(shù)值
net.layers[0].n_in
784
net.layers[1].n_in
100
net.layers[0].n_out
100
net.layers[1].n_out
10
step 3:
仔細(xì)的推敲了原來(lái)的結(jié)構(gòu),使用如下代碼,可以更新mini_batch_size后重新計(jì)算。
net = Network(net.layers, NEW_mini_batch_size)
net2.SGD(training_data, epochs, NEW_mini_batch_size, eta,
validation_data, test_data)
看起來(lái)還不錯(cuò)。
the old mini_batch_size is: 50
the new mini_batch_size is: 100
Training mini-batch number 0
Epoch 0: validation accuracy 90.53%
This is the best validation accuracy to date.
The corresponding test accuracy is 90.43%
Finished training network.
Best validation accuracy of 90.53% obtained at iteration 499
Corresponding test accuracy of 90.43%
step 4: 實(shí)現(xiàn)思路與代碼
每次計(jì)算完畢后,把網(wǎng)絡(luò)學(xué)習(xí)后的參數(shù)保存到磁盤(pán)。需要使用的時(shí)候,再讀取到內(nèi)存,調(diào)整mini_batch_size(批處理大小), eta(學(xué)習(xí)率), epochs(迭代次數(shù))后再進(jìn)行計(jì)算。
實(shí)現(xiàn)代碼
#讀取網(wǎng)絡(luò)參數(shù)
rst_path = “rst/conv_net3.json”
pfile = open(rst_path, ‘rb’) # read current contents
net = pickle.load(pfile)
pfile.close()
#設(shè)置 批處理、學(xué)習(xí)率、迭代數(shù)等參數(shù)
mini_batch_size = 60
epochs = 5
eta = 0.1
#更新mini_batch_size和網(wǎng)絡(luò)參數(shù)!
#使用已有的 layers 結(jié)構(gòu)和參數(shù)。
net = Network(net.layers, mini_batch_size);
#載入訓(xùn)練和測(cè)試數(shù)據(jù)
training_data, validation_data, test_data = network3.load_data_shared()
net.SGD(training_data, epochs, mini_batch_size, eta,
validation_data, test_data)
#保存數(shù)據(jù)
with open(rst_path, ‘wb’) as save_file:
pickle.dump(net, save_file)
save_file.close()
更多文章、技術(shù)交流、商務(wù)合作、聯(lián)系博主
微信掃碼或搜索:z360901061

微信掃一掃加我為好友
QQ號(hào)聯(lián)系: 360901061
您的支持是博主寫(xiě)作最大的動(dòng)力,如果您喜歡我的文章,感覺(jué)我的文章對(duì)您有幫助,請(qǐng)用微信掃描下面二維碼支持博主2元、5元、10元、20元等您想捐的金額吧,狠狠點(diǎn)擊下面給點(diǎn)支持吧,站長(zhǎng)非常感激您!手機(jī)微信長(zhǎng)按不能支付解決辦法:請(qǐng)將微信支付二維碼保存到相冊(cè),切換到微信,然后點(diǎn)擊微信右上角掃一掃功能,選擇支付二維碼完成支付。
【本文對(duì)您有幫助就好】元
