Skip to content

Instantly share code, notes, and snippets.

@kojix2
Created March 27, 2018 23:20
Show Gist options
  • Save kojix2/42a08884e568d0ba9e2da951c086e8b9 to your computer and use it in GitHub Desktop.
Save kojix2/42a08884e568d0ba9e2da951c086e8b9 to your computer and use it in GitHub Desktop.
red-chainer MNIST CNN sample
require 'chainer'
require 'fileutils'
require 'optparse'
require 'tmpdir'
# Short aliases for the red-chainer namespaces this script refers to repeatedly.
C = Chainer
CT = Chainer::Training # trainer and updater classes
CI = Chainer::Iterators # dataset iterators
CTE = Chainer::Training::Extensions # trainer extensions (evaluation, logging, snapshots)
# Convolutional classifier for single-channel images (trained on MNIST below).
#
# Layer layout:
#   conv1: 1 -> 10 channels, 5x5 kernel, then ReLU and 2x2 max-pooling
#   conv2: 10 -> 20 channels, 5x5 kernel, then ReLU and 2x2 max-pooling
#   l1:    fully connected -> n_units, then ReLU
#   l2:    fully connected -> n_out class scores (no activation; the loss is
#          applied by the Classifier wrapper the script puts around this chain)
class CNN < C::Chain
  include C::Links::Connection

  F = C::Functions
  R = F::Activation::Relu
  P = F::Pooling::MaxPooling2D

  # @param n_units [Integer] width of the hidden fully connected layer
  # @param n_out [Integer] number of output classes
  def initialize(n_units: 500, n_out: 10)
    super()
    init_scope do
      @conv1 = Convolution2D.new(1, 10, 5)
      @conv2 = Convolution2D.new(10, 20, 5)
      # in_size nil: the input width is inferred lazily from the first batch
      # (Chainer's Linear API), so the flattened conv output size is not
      # hard-coded here.
      @l1 = Linear.new(nil, out_size: n_units)
      @l2 = Linear.new(nil, out_size: n_out)
    end
  end

  # Forward pass.
  #
  # @param x batch of images; presumably (batch, 1, H, W) -- TODO confirm
  # @return unnormalised class scores produced by l2
  def call(x)
    h = P.max_pooling_2d(R.relu(@conv1.(x)), 2)
    h = P.max_pooling_2d(R.relu(@conv2.(h)), 2)
    # Non-linearity between the two dense layers: without it l1 and l2
    # compose into a single linear map and the hidden layer adds no capacity.
    h = R.relu(@l1.(h))
    @l2.(h)
  end
end
# Train the CNN on MNIST with Adam, evaluating on the test split and logging
# loss/accuracy each epoch. With no arguments this reproduces the original
# hard-coded run: batch size 100, 20 epochs, one snapshot at the stop epoch,
# output written under ./result.
#
# Options:
#   -b, --batchsize N  mini-batch size for the train and test iterators
#   -e, --epoch N      number of training epochs
#   -f, --frequency N  snapshot every N epochs; -1 = use the epoch count
#   -o, --out DIR      directory for the log and snapshots
args = {
  batchsize: 100,
  epoch: 20,
  frequency: -1,
  out: 'result'
}

OptionParser.new do |opt|
  opt.on('-b', '--batchsize VALUE', Integer,
         "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v }
  opt.on('-e', '--epoch VALUE', Integer,
         "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v }
  opt.on('-f', '--frequency VALUE', Integer,
         "Frequency of taking a snapshot in epochs (default: #{args[:frequency]})") { |v| args[:frequency] = v }
  opt.on('-o', '--out VALUE', String,
         "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
end.parse!(ARGV)

%i[batchsize epoch].each do |key|
  abort "--#{key} must be a positive integer" unless args[key].positive?
end

# -1 ties the snapshot interval to the stop epoch (the original setup);
# any other value is clamped to at least one epoch.
snapshot_interval = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max

model = C::Links::Model::Classifier.new(CNN.new)
optimizer = C::Optimizers::Adam.new
optimizer.setup(model)

# ndim: 3 keeps a channel axis on each image, which Convolution2D needs
# (presumably 1x28x28 per sample, matching conv1's single input channel).
train, test = C::Datasets::Mnist.get_mnist(ndim: 3)
train_iter = CI::SerialIterator.new(train, args[:batchsize])
# The evaluation iterator makes one ordered, non-repeating pass over the test set.
test_iter = CI::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)

# device: -1 runs on the CPU (Chainer's device convention).
updater = CT::StandardUpdater.new(train_iter, optimizer, device: -1)
trainer = CT::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
trainer.extend(CTE::Evaluator.new(test_iter, model, device: -1))
# Take a snapshot every `snapshot_interval` epochs.
trainer.extend(CTE::Snapshot.new, trigger: [snapshot_interval, 'epoch'], priority: -100)
trainer.extend(CTE::LogReport.new)
trainer.extend(CTE::PrintReport.new(%w[epoch main/loss validation/main/loss
                                       main/accuracy validation/main/accuracy elapsed_time]))
trainer.extend(CTE::ProgressBar.new)
trainer.run
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment