Chainer MNIST | python3Xのブログ

python3Xのブログ

ここでは40代、50代の方が日々の生活で役に立つ情報や私の趣味であるプログラム、Excelや科学に関する内容で投稿する予定です。

(Node.jsも必要だとは思うのですが、軸はPythonです)

横道ですが、Chainer と Keras の表記の違いを

MNISTを使って表現したいと思います

今回はChainer です

コード

#!/usr/bin/env python
import argparse
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
import chainerx
import matplotlib
matplotlib.use('Agg')

# ネットワークの定義
class MLP(chainer.Chain):
    def __init__(self, n_units, n_out):    # Chainerでお決まりの文句:中間層と出力層のノード数を受け取る(None)
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)  # 入力層
            self.l2 = L.Linear(None, n_units)  # 中間層(隠れ層)
            self.l3 = L.Linear(None, n_out)    # 出力層、n_outは10
    def forward(self, x):                        # (順伝播)ここは『foward()』の代わりに『__call__()』を使うこともある
        h1 = F.relu(self.l1(x))                  # u_i = Σx_ij wji ⇒ z_i = f(u_i) このf が活性化関数ReLU
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)
# 引数の定義(--initmodel:保存したモデルを使って追加学習をする際、保存したモデルファイルのパスを指定)
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,            # ミニバッチサイズ
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,                  # 学習するエポック数
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,              # スナップショットの頻度
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device', '-d', type=str, default='-1',
                        help='Device specifier. Either ChainerX device '          # ChainerXはCupyをC++で高速に改良したもの
                        'specifier or an integer. If non-negative integer, '
                        'CuPy arrays with specified device id are used. If '
                        'negative integer, NumPy arrays are used')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', type=str,  #保存した最適化の状態を復元する際、保存したモデルファイル(snapshot)を指定
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,                # 中間層のノード数
                        help='Number of units')
    group = parser.add_argument_group('deprecated arguments')
    group.add_argument('--gpu', '-g', dest='device',                     # GPUを使用する場合、GPU の ID を指定。1枚の場合、0 を指定
                       type=int, nargs='?', const=0,
                       help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()
    device = chainer.get_device(args.device)
# 引数で指定した値を標準出力に出力する
    print('Device: {}'.format(device))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')
    # モデルの作成
    model = L.Classifier(MLP(args.unit, 10))
    model.to_device(device)
    device.use()
    # 最適化法の設定
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    # MNIST データセットの読み込み
    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)
    # 学習の設定
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    # エポックごとにテストデータセットでモデル評価を行う
    trainer.extend(extensions.Evaluator(test_iter, model, device=device))
    # chainerxに対応してないならエクステンションの損失のダンプグラフ作成
    if device.xp is not chainerx:
        trainer.extend(extensions.DumpGraph('main/loss'))
    # 指定されたエポックごとにスナップショットを取る(frequency:頻度、trigger:引き金を引く=スナップショットのタイミング)
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))  # frequencyで指定されたエポックごとにスナップショットを取る
    # エポックごとにログを残す
    trainer.extend(extensions.LogReport())
    # 以下のメインロス、検証ロス、エポックをchainer/result/loss.pngとしてプロット
    trainer.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'],
            'epoch', file_name='accuracy.png'))
    # 下記の内容で出力する
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
    # 処理の進捗度合いの横棒
    trainer.extend(extensions.ProgressBar())
    if args.resume is not None:
        # スナップショットが存在したらそこから再開する
        chainer.serializers.load_npz(args.resume, trainer)
    # 学習開始
    trainer.run()

if __name__ == '__main__':
    main()