先の例で定義した learn() 関数の中身について説明を少し記述する。

for i in range(num) は学習回数を指定している。
x=Variable( … ) は、学習に用いる入力値を準備している。
[[-4.0333333 ]
[ 6.9666667 ]
[ 8.96666622]
[-7.0333333 ]

[-2.0333333 ]
[-0.03333334]
[ 5.9666667 ]]

といようなベクトルになっている。
つまり、このベクトルに含まれる値全てを使って一回のパラメータ更新を行う。
それを for ループで何回も繰り返す。

そして肝心なのは

mopt.update(model, x, x*x)

という1行で、入力 x に対して、正しい値(教師信号)として x*x を与えたときの、model の誤差を評価し、model に設定されている mnn のパラメータを更新する、ということをしている。



learn() 関数について、もう少し説明を追加しておく。

Classifier 関数

model=L.Classifier(mnn, lossfun=F.mean_squared_error)

を使わずに、学習を進めることもできる。
mopt.update(model, x, x*x) を分解することに相当するが、その場合以下の様になる。

mnn = MyNN(120)
mopt = optimizers.SGD()
mopt.setup(model)


を準備し、

def learn2(num, mnn, mopt, jitt=0, batch=20):
  for i in range(num):
   print "turn {0}".format(i)
   x=Variable( np.array([np.random.permutation(batch)-batch/2.0+1.0*i/num+jitt], dtype=np.float32).T )
   mopt.zero_grads()
   err=F.mean_squared_error(mnn(x), x*x)
   err.backward()
   mopt.update()
   print err.data/batch


という関数で学習をすすめる。
この関数では引数に model は使用しない。その代わりネットワークモデル自体を引数として受け取っておく。
mopt で勾配を初期化(zero_grads())、そして、mnn と MSE 関数を使ってエラーを直接評価して、そのエラーから誤差逆伝播に使う勾配を計算する(backward())。
そして最後に mopt で update() する、という手順だ。

learn() の mopt.update(model, x, x*x) はこれらの手順を1つの関数で済ますことができるように設計されている。


y=x^2

緑線が真の解(y=x^2)で、青線が mnn により学習された二次曲線。




やじるし 機械学習Chainer関連メモの目次