Skip to content

Instantly share code, notes, and snippets.

@cemoody
Created January 11, 2017 21:09
Show Gist options
  • Save cemoody/62bec60f7be6bc2be7a913f09293035e to your computer and use it in GitHub Desktop.
Save cemoody/62bec60f7be6bc2be7a913f09293035e to your computer and use it in GitHub Desktop.
Chainer Model Wrapper
import os
import re

import chainer
from chainer import cuda
from chainer import training
from chainer.training import extensions
from chainer.datasets import TupleDataset
from chainer.iterators import SerialIterator
class Wrapper(object):
    """Small scikit-learn-style training wrapper around a Chainer model.

    Sets up an Adam optimizer for ``model``. ``fit`` then runs a Chainer
    ``Trainer`` with per-epoch logging, a progress bar and periodic trainer
    snapshots, optionally resuming from an earlier snapshot.

    Args:
        model: Chainer link to train. ``fit`` also reads ``model.keys``, an
            iterable of metric names printed as ``main/<key>``
            (NOTE(review): presumably the model reports these itself via
            ``chainer.report`` -- confirm).
        batchsize: Minibatch size for the training iterator.
        n_epochs: Number of epochs ``fit`` trains for.
        device: GPU id. When not ``None`` the model is moved to that GPU and
            the updater is told to use it; ``None`` keeps training on CPU.
        resume: ``True`` resumes from the newest ``snapshot_iter_<N>`` file
            in the trainer output directory if one exists (otherwise trains
            from scratch); any other truthy value is treated as an explicit
            snapshot path; a falsy value always trains from scratch.
        snapshot_interval: Save a trainer snapshot every this many epochs.
    """

    def __init__(self, model, batchsize=512, n_epochs=100, device=None,
                 resume=True, snapshot_interval=1000):
        self.model = model
        self.n_epochs = n_epochs
        self.device = device
        self.batchsize = batchsize
        self.resume = resume
        self.snapshot_interval = snapshot_interval
        if device is not None:
            self.model.to_gpu(device)
        self.optimizer = chainer.optimizers.Adam()
        self.optimizer.setup(self.model)

    def fit(self, X, y, debug=False):
        """Train the model on ``X`` for ``self.n_epochs`` epochs.

        Args:
            X: Training inputs. They are wrapped in a one-column
                ``TupleDataset``, so each example reaches the model as a
                single argument.
            y: Currently ignored; kept for an sklearn-like signature.
                NOTE(review): for a supervised model this should probably be
                ``TupleDataset(X, y)`` -- confirm with callers before changing.
            debug: Forwarded to ``chainer.set_debug``. This sets Chainer's
                global debug flag, which stays set after ``fit`` returns.
        """
        chainer.set_debug(debug)
        # One output directory per device, shared by the trainer (where it
        # writes snapshots/logs) and the resume lookup below.
        out = 'out_' + str(self.device)

        train_iter = SerialIterator(TupleDataset(X), self.batchsize)
        updater = training.StandardUpdater(train_iter, self.optimizer,
                                           device=self.device)
        trainer = training.Trainer(updater, (self.n_epochs, 'epoch'), out=out)

        # Setup logging, printing & saving
        reports = ['epoch'] + ['main/' + key for key in self.model.keys]
        trainer.extend(extensions.snapshot(),
                       trigger=(self.snapshot_interval, 'epoch'))
        trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
        trainer.extend(extensions.PrintReport(reports))
        trainer.extend(extensions.ProgressBar(update_interval=10))

        # If a previous snapshot is requested/detected, resume from it.
        # resume=True is an "auto-detect" flag and must never reach load_npz
        # as a filename.
        path = self._resume_path(out)
        if path is not None:
            print("Loading from {}".format(path))
            chainer.serializers.load_npz(path, trainer)

        # Run the model
        trainer.run()

    def _resume_path(self, out):
        """Return the snapshot file ``fit`` should load, or ``None``.

        ``self.resume is True`` means: use the newest snapshot in ``out`` if
        there is one. Any other truthy value is returned as an explicit path;
        a falsy value yields ``None``.
        """
        if self.resume is True:
            return self._latest_snapshot(out)
        return self.resume or None

    @staticmethod
    def _latest_snapshot(out):
        """Return the path of the highest-iteration snapshot in ``out``.

        Looks for files named ``snapshot_iter_<N>`` (the default file name
        written by ``chainer.training.extensions.snapshot``) and compares
        ``N`` numerically. Returns ``None`` when ``out`` does not exist yet
        or contains no such file.
        """
        if not os.path.isdir(out):
            return None
        pattern = re.compile(r'snapshot_iter_(\d+)$')
        best_path, best_iter = None, -1
        for name in os.listdir(out):
            match = pattern.match(name)
            if match is None:
                continue
            iteration = int(match.group(1))
            if iteration > best_iter:
                best_iter = iteration
                best_path = os.path.join(out, name)
        return best_path
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment