←前一篇 後一篇→
損失函式(
Loss function)是用來計算神經網路个輸出, 到底精差偌濟? 凡勢無仝个應用, 愛揀無仝个損失函式. 譬如講, 認聲音裡个字, 和認圖裡个數字, 損失函式無仝. 這嘛和輸出棧个樣有關係. 你欲認 0~9 十个數字, 和欲認人面, 嘛是無相siâng.
現此時佇這本冊的例是用認捌 0~9 做例, 伊的輸出棧是 10 粒神經元, 逐粒代表一个數字, 咱就先就這个例來紹介.
均方誤差 (Mean squared error)
這是上蓋出名的損失函式, 伊的數學算式是:
佇遮 y 是神經網路輸出, t 是訓練資料. k 是維度. 咱佇咧遮有 0~9 十个輸出, 伊的維度就是 10.
import sys, os
sys.path.append(os.pardir) # 設 module 搜揣路草
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax
def get_data():
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
return x_test, t_test
def init_network():
with open("sample_weight.pkl", 'rb') as f:
network = pickle.load(f)
return network
def predict(network, x):
W1, W2, W3 = network['W1'], network['W2'], network['W3']
b1, b2, b3 = network['b1'], network['b2'], network['b3']
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2, W3) + b3
y = softmax(a3)
return y
x, t = get_data()
network = init_network()
for i in range(100):
y = predict(network, x[i])
print(list(map(lambda x:round(x,1), y)))
伊會印出逐擺資料的輸出結果. 因為這是訓練好的資料, 所以大部份攏是 0, 正確的答案伊會真接近 1, 但是咱有時嘛會看著:
[0.0, 0.0, 0.0, 0.0, 0.40000001, 0.0, 0.0, 0.0, 0.0, 0.5]
[0.0, 0.0, 0.2, 0.69999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.89999998, 0.1, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.60000002, 0.0, 0.0, 0.0, 0.2, 0.1, 0.0, 0.0]
若是抑袂訓練好, index 0~9 就會攏有數字, 咱揀一兩个來算伊的均方誤差:
import numpy as np
def mean_squared_error(y, t):
return 0.5 * np.sum((y-t)**2)
y1 = [0.0, 0.0, 0.2, 0.69999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
t1 = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
y2 = [0.0, 0.0, 0.0, 0.0, 0.89999998, 0.1, 0.0, 0.0, 0.0, 0.0]
t2 = [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
print(mean_squared_error(np.array(y1), np.array(t1)))
print(mean_squared_error(np.array(y2), np.array(t2)))
伊的結果是:
愈細表示愈接近正確的答案, 佇遮, y2 較倚!
Khu-ló-sù Én-tso-phì 精差 (Cross Entropy Error)
Cross Entropy Error 是另一个四常用个損失函式, 伊的公式是按呢:
咱先來看 y = log(x) 个圖生做啥乜款?
#!/usr/bin/python3
import numpy as np
import matplotlib.pylab as plt
x = np.arange(-0.0, 1.0, 0.001)
delta = 1e-7
y = np.log(x + 1e-7)
plt.plot(x, y)
plt.ylim(-5, 0) # siat-tīng y kuainn ê huān-uî
plt.show()
運行這个程式, 出現:
log(1) 是 0, x < 1 是負數, 愈接近 0, 伊就負愈大. log(0) 是毋成數, 窮實, 伊是負个無限大, 咱共伊標做 -inf. 毋過, 這電腦無法度算落去, 咱共伊添一个 1e-7 這个微微仔數來閃過這个問題.
若是看 croess entropy error 本身: 若是完全對同, 比如講 3, t3 會是 1, 賰的 t0~t2, t4~t9 攏是 0, 所致成做 E = -t3 * log(y3), log(1)=0, E 是 0 表示完全無精差.
咱來計算看覓頭前个例:
#!/usr/bin/python3
import numpy as np
def cross_entropy_error(y, t):
delta = 1e-7
return -np.sum(t * np.log(y+delta))
y1 = [0.0, 0.0, 0.2, 0.69999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
t1 = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
y2 = [0.0, 0.0, 0.0, 0.0, 0.89999998, 0.1, 0.0, 0.0, 0.0, 0.0]
t2 = [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
print(cross_entropy_error(np.array(y1), np.array(t1)))
print(cross_entropy_error(np.array(y2), np.array(t2)))
伊的輸出:
0.356674815367
0.105360426769
伊的值, 和均方差無仝, 毋過猶原是 y2 較接近.
批次學習
到今, 咱討論ê是一擺訓練个誤差有偌大. 若是開始訓練, 咱會揀一批一批, 一批訓練了, 才規批來算伊有偌準, 所致:
就是共 N 个攏加--起來, 才閣分做 N 來平均. 紲落來 ê 問題是: 欲按怎 <凊彩> 揀 N 出來咧?
若以 MNIST 有 60000 筆資料, 咱一擺揀 100 筆來訓練, N 是 100. Numpy 有一个 random.choice() 函數, 會使共咱鬥做這个 '凊彩揀' 个工課:
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_maskmask]
t_batch = y_train[batch_mask]
train_size=600000, batch_size=100, batch_mask 是 10 个元素的 Numpy array, 逐个元素攏是 0~60000 个數字. 閣利用進前
Numpy array 个
蓋奅个索引 (Fancy indexing), 真輕可就共欲挃个物捎--出來.
伊的 cross_entropy_erorr(y, t) 就愛修改:
def cross_entropy_error(y, t):
if y.ndim == 1:
t = t.reshape(1, t.size)
y = y.reshape(1, y.size)
batch_size = y.shape[0]
return -np.sum(t * np.log(y)) / batch_size
這 if y.ndim == 1 所做的 reshape, 是 Numpy array 个一个鋩角, 你若毋知, 會使去
遮, 搜揣 <鋩角> 兩字就知.
按呢修改, 伊就仝時支援 "干焦一筆資料" 和 "濟筆資料". 為啥物? 因為濟筆資料, 伊是成做 2 維陣列, 逐个列 row 是一筆資料. 干焦一筆 y.ndim==1, 100 筆 y.ndim == 100. Numpy array 个運算攏是逐个元素相對同 (element-wise), 所致, 一擺就攏算好阿. np.sum() 也是一改就攏總加--起來
方法確定了後, 就是訓練个循環:
- 揀一批資料, 共資料飼予咱的神經網路, 算看覓伊的平均誤差是偌濟?
- 閣修正 W 這个陣列
- 閣揀另外一批資料, 轉去 1, 看誤差有變較細--無?
紲--落來, 欲按怎予 W 佇咧調整个過程中方, 愈來愈倚正確个答案, 也就是損失函式愈來愈細? 這就愛uì微分个觀念開是講起!
沒有留言:
張貼留言