Last active
November 2, 2017 13:18
-
-
Save adash333/87c4329891613d0b39563a63fe711c8b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# original code from https://github.com/mitmul/chainer-handson/blob/master/2-Try-Trainer-class_ja.ipynb | |
# 解説ページ http://twosquirrel.mints.ne.jp/?p=20823 | |
#1 datasetの準備 | |
# 本当は画像からdatasetを準備するべきだが、簡便のため、今回は | |
# chainer.datasetsが提供するNumpy配列のmnistデータをimportする。 | |
# chainer.datasetsを使用しないで自分で画像からchainer用のデータセットを | |
# 作成する方法については、以下を参照 | |
# http://twosquirrel.mints.ne.jp/?p=20366 | |
from chainer.datasets import mnist | |
train, test = mnist.get_mnist() | |
# --------------------------------------------- # | |
#2 Iteratorの準備 | |
# ChainerにはTrainerというモジュールがあり、基本的にこれを用いて学習を行う。 | |
# Trainerの構造は、以下の図の通りだが、Trainerが使用できるように、順番に、 | |
# Updater, Iterator, Dataset, Optimizer, Model, Extensions | |
# を記載していく必要がある。 | |
# 上記を記載した後に、trainer.run() で学習を実行する。 | |
# Trainerの構造の図のリンク(絶対に閲覧お勧めの図) | |
# https://camo.qiitausercontent.com/d3af5d369c038fed989ea7839b1566ef49a6c331/68747470733a2f2f71696974612d696d6167652d73746f72652e73332e616d617a6f6e6177732e636f6d2f302f31373933342f61373531646633312d623939392d663639322d643833392d3438386332366231633438612e706e67 | |
from chainer import iterators | |
batchsize = 128 | |
train_iter = iterators.SerialIterator(train, batchsize) | |
test_iter = iterators.SerialIterator(test, batchsize, False, False) | |
# --------------------------------------------- # | |
#3 Modelの準備 | |
# ChainerのTrainerを使用するために、Modelを準備する | |
''' | |
# LinkとFunction | |
Chainerでは、ニューラルネットワークの各層を、LinkとFunctionに区別します。 | |
Linkは、パラメータを持つ関数です。 | |
Functionは、パラメータを持たない関数です。 | |
これらを組み合わせてモデルを記述します。 | |
''' | |
import chainer | |
import chainer.links as L | |
import chainer.functions as F | |
''' | |
今回は、手書き数字MNIST画像を、multiple layer perceptron(多層パーセプトロン)という | |
ニューラルネットワークモデルを用いて機械学習で分類します。 | |
層構造のイメージは、以下のリンクが参考になります。 | |
https://qiita.com/kenmatsu4/items/7b8d24d4c5144a686412 | |
ネットワークは3層で、入力層、隠れ層、出力層の3層とします。 | |
28x28のグレースケール画像を、0から255までの値をとる各ピクセルの値を、 | |
784個、横に並んだ数字の配列に変換して(、さらに255で割って)、 | |
入力層に入れます。入力層のunit数は784個となります。 | |
中間層のunit数(n_mid_units)は、今回は、100個に設定しています。 | |
手書き数字の0から9まで10種類の画像を分類するため、 | |
出力層のunit数(n_out)は、10個となります。 | |
''' | |
class MLP(chainer.Chain): | |
def __init__(self, n_mid_units=100, n_out=10): | |
super(MLP, self).__init__( | |
l1=L.Linear(None, n_mid_units), | |
l2=L.Linear(n_mid_units, n_mid_units), | |
l3=L.Linear(n_mid_units, n_out), | |
) | |
def __call__(self, x): | |
h1 = F.relu(self.l1(x)) | |
h2 = F.relu(self.l2(h1)) | |
return self.l3(h2) | |
# --------------------------------------------- # | |
#4 モデルと最適化アルゴリズムの設定(Updaterの準備) | |
''' | |
trainer.run()で学習をしているとき、Updaterが内部で、以下を行っている。 | |
1.データセットからデータを取り出し(Iterator) | |
2.モデルに渡してロスを計算し(Model = Optimizer.target) | |
3.Optimizerを使ってモデルのパラメータを更新する(Optimizer) | |
''' | |
from chainer import optimizers | |
from chainer import training | |
gpu_id = -1 # Set to -1 if you don't have a GPU | |
model = MLP() | |
if gpu_id >= 0: | |
model.to_gpu(gpu_id) | |
max_epoch = 3 | |
# モデルをClassifierで包んで、ロスの計算などをモデルに含める | |
model = L.Classifier(model) | |
if gpu_id >= 0: | |
model.to_gpu(gpu_id) | |
# 最適化手法の選択 | |
optimizer = optimizers.SGD() | |
optimizer.setup(model) | |
# UpdaterにIteratorとOptimizerを渡す | |
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id) | |
# --------------------------------------------- # | |
#5 学習と結果の出力 | |
# TrainerにUpdaterを渡す | |
trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='mnist_result') | |
# TrainerにExtensionを追加 | |
from chainer.training import extensions | |
# trainer.extend()で、学習の進行状況を表すプログレスバーや、lossのグラフ化と画像の保存などを行う | |
# 以下は、意味不明なコードが並んでいるが、分からなくてもとりあえずコピペしておく。 | |
trainer.extend(extensions.LogReport()) | |
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time'])) | |
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png')) | |
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png')) | |
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}')) | |
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}')) | |
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id)) | |
trainer.extend(extensions.dump_graph('main/loss')) | |
# 学習の実行(Extensionsによって、結果も出力される) | |
trainer.run() | |
# --------------------------------------------- # | |
#6 学習結果のパラメータの保存 | |
# Save paramaters | |
chainer.serializers.save_npz('my_mnist.model', model) | |
# --------------------------------------------- # |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment