Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
chainer.trainingのextensions全部試す
import numpy as np
import chainer
from chainer import Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
class MyModel(Chain):
    """Small fully connected network: two 100-unit hidden layers, 10 outputs.

    Every ``L.Linear`` is created with ``None`` as its input size, so the
    input dimensionality is not fixed in this class (Chainer is expected to
    determine it from the first batch -- see the ``L.Linear`` documentation).
    """

    def __init__(self):
        # Links are registered under the attribute names l1, l2 and l3 by
        # passing them as keyword arguments to Chain.__init__.
        layers = dict(
            l1=L.Linear(None, 100),
            l2=L.Linear(None, 100),
            l3=L.Linear(None, 10),
        )
        super(MyModel, self).__init__(**layers)

    def __call__(self, x):
        """Run the forward pass and return the output of the last layer.

        ReLU is applied after ``l1`` and ``l2``; the result of ``l3`` is
        returned as-is, without an activation.
        """
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)
def train(device=0, batch_size=200, epochs=100, out="result"):
    """Train ``MyModel`` on MNIST while trying out chainer.training extensions.

    Calling ``train()`` with no arguments behaves exactly like the previous
    hard-coded version (GPU 0, batch size 200, 100 epochs, output in
    "result").

    Args:
        device: Device id used for the model, the updater and the evaluator.
            With a non-negative id the model is moved to that GPU. With a
            negative id the GPU setup is skipped, which lets the script run
            on a CPU-only machine (negative ids mean host memory for the
            updater/evaluator per the Chainer documentation).
        batch_size: Mini-batch size of both the training and the test
            iterator.
        epochs: Number of epochs after which the trainer stops.
        out: Output directory handed to the trainer.
    """
    model = L.Classifier(MyModel())
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Named *_data so the locals do not shadow this function's own name.
    train_data, test_data = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train_data, batch_size)
    # The test iterator makes a single, unshuffled pass per evaluation.
    test_iter = chainer.iterators.SerialIterator(
        test_data, batch_size, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (epochs, 'epoch'), out=out)

    # The commented-out calls below are kept on purpose: this script is a
    # catalogue of the available extensions, toggled by (un)commenting.

    # load trainer snapshot
    #serializers.load_npz('./result/snapshot_iter_40800', trainer)

    # dump_graph
    #trainer.extend(extensions.dump_graph("main/loss"))

    # Evaluator
    trainer.extend(extensions.Evaluator(test_iter, model, device=device))

    # ExponentialShift
    #trainer.extend(extensions.ExponentialShift("alpha", 1.000001))
    #trainer.extend(extensions.ExponentialShift("alpha", 1.0001))

    # LinearShift
    #trainer.extend(extensions.LinearShift("alpha", (0.001,0.0001), (20000,30000)))
    #trainer.extend(extensions.LinearShift("alpha", (0.01,0.001), (20000,30000)))

    # LogReport
    trainer.extend(extensions.LogReport())

    # snapshot
    #trainer.extend(extensions.snapshot())

    # snapshot_object
    # trainer.extend(extensions.snapshot_object(optimizer, 'optimizer_snapshot_{.updater.epoch}', trigger=(10,'epoch')))

    # PrintReport
    trainer.extend(extensions.PrintReport(
        entries=['epoch', 'main/loss', 'main/accuracy', 'elapsed_time']))

    # ProgressBar
    trainer.extend(extensions.ProgressBar())

    print("run")
    trainer.run()
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment