ツインテールdeエンジェルモード!! で、損失関数を記述してみたいと思います。
(その2)
前回は、2乗和誤差を記述しましたので、
今回は、交差エントロピー誤差を記述してみます。
英語で書くと cross entropy error となるらしいので、
頭文字をとって ceerr() ぐらいで作ってみます。
% cat lossfunc
### Deep Learning - 損失関数の定義 ###
(途中略)
def ceerr(y,t){ # 交差エントロピー誤差 [ Cross Entropy Error ]
if( type(y)=='A' ){
ans = 0
each( k=keys(y) )
ans += -t[k]*log(y[k]+1e-7)
}
else
ans = -t*log(y+1e-7)
retn(ans)
}
こちらも、場合分けしているのは、
ベクトルだけでなくスカラーでも処理できるようにするためです。
2乗和誤差で使ったサンプルデータで計算&表示させてみます。
スクリプトの全体像はこんな感じです。
(先の損失関数を、ファイルで読み込んでいます。)
配列 t が教師データ。(one-hot 表現)
配列 y が推論出力。といったイメージです。
% cat v
:r lossfunc
t = { 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 }
y = { 0.1 , 0.05 , 0.6 , 0 , 0.05 , 0.1 , 0 , 0.1 , 0 , 0 }
loss = ceerr(y,t)
p(loss)
y = { 0.1 , 0.05 , 0.1 , 0 , 0.05 , 0.1 , 0 , 0.6 , 0 , 0 }
loss = ceerr(y,t)
p(loss)
では、実行させてみます。
前回の2乗和誤差と値は違いますが、
推論が正しいとき(最初の結果)が、出力が低くなり、
推論が正しくない時(二番目の結果)が、出力が高くなるという点は同じでした。
% tt v
0.510825
2.302584
交差エントロピー誤差でも、トレースモードで実行してみます。
% tt -t v
<File=v,Line=001> :r lossfunc
*** EOF ***
<File=v,Line=003> t = { 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 }
t[0]=0,t[1]=0,t[2]=1,t[3]=0,t[4]=0,t[5]=0,t[6]=0,t[7]=0,t[8]=0,t[9]=0
<File=v,Line=005> y = { 0.1 , 0.05 , 0.6 , 0 , 0.05 , 0.1 , 0 , 0.1 , 0 , 0 }
y[0]=0.100000,y[1]=0.050000,y[2]=0.600000,y[3]=0,y[4]=0.050000,y[5]=0.100000,y[6]=0,y[7]=0.100000,y[8]=0,y[9]=0
<File=v,Line=006> loss = ceerr(y,t)
Call: ceerr(y[0]=0.100000,y[1]=0.050000,y[2]=0.600000,y[3]=0,y[4]=0.050000,y[5]=0.100000,y[6]=0,y[7]=0.100000,y[8]=0,y[9]=0,t[0]=0,t[1]=0,t[2]=1,t[3]=0,t[4]=0,t[5]=0,t[6]=0,t[7]=0,t[8]=0,t[9]=0)
<File=lossfunc,Line=015> if( type(y)=='A' ){
Call: type(y[0]=0.100000,y[1]=0.050000,y[2]=0.600000,y[3]=0,y[4]=0.050000,y[5]=0.100000,y[6]=0,y[7]=0.100000,y[8]=0,y[9]=0)
if(TRUE)
<File=lossfunc,Line=016> ans = 0
ans=0
<File=lossfunc,Line=017> each( k=keys(y) )
Call: keys(y[0]=0.100000,y[1]=0.050000,y[2]=0.600000,y[3]=0,y[4]=0.050000,y[5]=0.100000,y[6]=0,y[7]=0.100000,y[8]=0,y[9]=0)
<File=lossfunc,Line=018> ans += -t[k]*log(y[k]+1e-7)
Call: log(0.100000)
ans=0.000000
Call: log(0.050000)
ans=0.000000
Call: log(0.600000)
ans=0.510825
Call: log(0.000000)
ans=0.510825
Call: log(0.050000)
ans=0.510825
Call: log(0.100000)
ans=0.510825
Call: log(0.000000)
ans=0.510825
Call: log(0.100000)
ans=0.510825
Call: log(0.000000)
ans=0.510825
Call: log(0.000000)
ans=0.510825
<File=lossfunc,Line=022> retn(ans)
loss=0.510825
<File=v,Line=007> p(loss)
Call: p(0.510825)
0.510825
<File=v,Line=009> y = { 0.1 , 0.05 , 0.1 , 0 , 0.05 , 0.1 , 0 , 0.6 , 0 , 0 }
y[0]=0.100000,y[1]=0.050000,y[2]=0.100000,y[3]=0,y[4]=0.050000,y[5]=0.100000,y[6]=0,y[7]=0.600000,y[8]=0,y[9]=0
<File=v,Line=010> loss = ceerr(y,t)
Call: ceerr(y[0]=0.100000,y[1]=0.050000,y[2]=0.100000,y[3]=0,y[4]=0.050000,y[5]=0.100000,y[6]=0,y[7]=0.600000,y[8]=0,y[9]=0,t[0]=0,t[1]=0,t[2]=1,t[3]=0,t[4]=0,t[5]=0,t[6]=0,t[7]=0,t[8]=0,t[9]=0)
<File=lossfunc,Line=015> if( type(y)=='A' ){
Call: type(y[0]=0.100000,y[1]=0.050000,y[2]=0.100000,y[3]=0,y[4]=0.050000,y[5]=0.100000,y[6]=0,y[7]=0.600000,y[8]=0,y[9]=0)
if(TRUE)
<File=lossfunc,Line=016> ans = 0
ans=0
<File=lossfunc,Line=017> each( k=keys(y) )
Call: keys(y[0]=0.100000,y[1]=0.050000,y[2]=0.100000,y[3]=0,y[4]=0.050000,y[5]=0.100000,y[6]=0,y[7]=0.600000,y[8]=0,y[9]=0)
<File=lossfunc,Line=018> ans += -t[k]*log(y[k]+1e-7)
Call: log(0.100000)
ans=0.000000
Call: log(0.050000)
ans=0.000000
Call: log(0.100000)
ans=2.302584
Call: log(0.000000)
ans=2.302584
Call: log(0.050000)
ans=2.302584
Call: log(0.100000)
ans=2.302584
Call: log(0.000000)
ans=2.302584
Call: log(0.600000)
ans=2.302584
Call: log(0.000000)
ans=2.302584
Call: log(0.000000)
ans=2.302584
<File=lossfunc,Line=022> retn(ans)
loss=2.302584
<File=v,Line=011> p(loss)
Call: p(2.302584)
2.302584
変数の変化や、関数の呼び出しなどが、追跡できました。
うまく動作しているみたいです。