本文主要借鑒并綜合了以下兩個博客的內容(樣本生成和流圖構建訓練),并在其基礎上繪制了擬合后的直線和“訓練次數-代價函數值”曲線,可更直觀的觀察訓練效果:
https://www.cnblogs.com/xianhan/p/9090426.html
https://www.cnblogs.com/selenaf/p/9102398.html
具體步驟如下:
步驟1.在很多情況下,初學者都沒有樣本庫,一般可自建樣本庫。使用random函數隨機初始化樣本庫:
num_points=1000? # 生成的樣本數 vectors_set=[] # 初始化樣本集,為空 for i in range(num_points): ??? x1=np.random.normal(0.0,0.55)?? #橫坐標,進行隨機高斯處理化,以0為均值,以0.55為標準差 ??? y1=x1*0.1+0.3+np.random.normal(-0.03,0.03)?? #縱坐標,數據點在y1=x1*0.1+0.3上小范圍浮動 vectors_set.append([x1,y1]) ? # 將樣本集分為輸入集x_data和輸出集y_data x_data=[v[0] for v in vectors_set] y_data=[v[1] for v in vectors_set] ? # 繪制散點圖,查看生成樣本的分布情況 plt.scatter(x_data,y_data,c='r') plt.show() |
?
步驟2.建立計算流圖,包含“假設函數”“代價函數”和“訓練函數”。如下,訓練函數為梯度下降:
x = tf.placeholder(tf.float32) W = tf.Variable(tf.zeros([1])) b = tf.Variable(tf.zeros([1])) y_ = tf.placeholder(tf.float32) ? y = W * x + b ? lost = tf.reduce_mean(tf.square(y_-y)) # 設置代價函數 optimizer = tf.train.GradientDescentOptimizer(0.01)? # 設置梯度下降及其步長0.01 train_step = optimizer.minimize(lost) |
?
步驟3.初始化流圖,主要是初始化運行環境。如:
sess = tf.Session() init = tf.global_variables_initializer() sess.run(init) |
?
步驟4.開始訓練:向占位符中輸入數據,使用sess.run()方法進行訓練。如下:
x_plot=[] y_plot=[] steps = 2000? # 訓練的步數 for i in range(steps): ??? xs = x_data ??? ys = y_data ??? feed = { x: xs, y_: ys }? # 向占位符中輸入數據 ??? sess.run(train_step, feed_dict=feed) ??? if i % 50 == 0 : ??????? print("After %d iteration:" % i) ??????? # print(xs,ys) ??????? print("W: %f" % sess.run(W))? # 查看當前訓練的W值 ??????? print("b: %f" % sess.run(b)) ??????? print("lost: %f" % sess.run(lost, feed_dict=feed)) ??????? if i>200: ?????? ?????x_plot.append(i) ??????????? y_plot.append(sess.run(lost, feed_dict=feed)) |
?
步驟5.繪制“訓練次數-代價函數值”曲線,觀察訓練的收斂趨勢。如下:
x_result=[-2,2] W_result=sess.run(W) b_result=sess.run(b) y_result=[] for x_temp in x_result:? # 如果直接賦值y_result=x_result*W+b,則y_result類型為Tensor,不能直接打印 ??? y_result.append(x_temp*W_result+b_result) plt.subplot(1,2,1) plt.scatter(x_data,y_data,c='r') plt.plot(x_result,y_result,'-y') plt.subplot(1,2,2) plt.plot(x_plot,y_plot,'-') plt.show() |
?
輸出結果大致如下:
打印結果(最后一次): After 1950 iteration: W: 0.102901 b: 0.270892 lost: 0.000836 |
更多文章、技術交流、商務合作、聯系博主
微信掃碼或搜索:z360901061

微信掃一掃加我為好友
QQ號聯系: 360901061
您的支持是博主寫作最大的動力,如果您喜歡我的文章,感覺我的文章對您有幫助,請用微信掃描下面二維碼支持博主2元、5元、10元、20元等您想捐的金額吧,狠狠點擊下面給點支持吧,站長非常感激您!手機微信長按不能支付解決辦法:請將微信支付二維碼保存到相冊,切換到微信,然后點擊微信右上角掃一掃功能,選擇支付二維碼完成支付。
【本文對您有幫助就好】元
