Skip to content

Instantly share code, notes, and snippets.

@asi1024
Last active May 7, 2020
Embed
What would you like to do?
Migration Guide from Chainer to PyTorch
#!/usr/bin/env python
import argparse
import os.path
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torchvision import datasets, transforms
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions
import matplotlib
matplotlib.use('Agg')
# Network definition
class MLP(nn.Module):
    """Three-layer perceptron with ReLU activations between layers.

    Every layer is a ``ppe.nn.LazyLinear`` whose input size is passed as
    ``None``, so the input sizes are inferred from the first forward pass.

    Args:
        n_units: Width of both hidden layers.
        n_out: Number of output logits.
    """

    def __init__(self, n_units, n_out):
        super().__init__()
        # in_features=None: the input width is filled in lazily
        self.l1 = ppe.nn.LazyLinear(None, n_units)  # n_in -> n_units
        self.l2 = ppe.nn.LazyLinear(None, n_units)  # n_units -> n_units
        self.l3 = ppe.nn.LazyLinear(None, n_out)  # n_units -> n_out

    def forward(self, x):
        """Flatten ``x`` into rows of 784 values and return the logits.

        784 matches the flattened 28x28 MNIST images used by ``main``.
        """
        flat = x.reshape(-1, 784)
        hidden = F.relu(self.l2(F.relu(self.l1(flat))))
        return self.l3(hidden)
class Classifier(nn.Module):
    """Wrap a predictor and compute cross-entropy loss and accuracy.

    The last positional argument given to ``forward`` is the target tensor.
    All preceding arguments are passed through to the predictor.

    After each call the following attributes are set:
    - ``self.y``: the predictor output.
    - ``self.loss``: the loss tensor.
    - ``self.accuracy``: a Python float, the fraction of samples whose
      argmax over dim 1 equals the target.
    """

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor
        self.lossfun = nn.CrossEntropyLoss()

    def forward(self, *args):
        """Run the predictor on ``args[:-1]``; return the loss against ``args[-1]``."""
        inputs, t = args[:-1], args[-1]
        self.y = self.predictor(*inputs)
        self.loss = self.lossfun(self.y, t)
        # Class index per sample, reshaped so it compares element-wise with t
        hits = self.y.argmax(dim=1).reshape(t.shape) == t
        self.accuracy = hits.sum().item() / len(t)
        return self.loss
def train_func(manager, model, optimizer, device, train_loader):
    """Train ``model`` until ``manager.stop_trigger`` becomes true.

    For every batch, this function:
    - moves inputs and targets to ``device``;
    - computes the loss with ``model(x, t)``;
    - reports ``main/loss`` and ``main/accuracy`` via ``ppe.reporting``;
    - takes one optimizer step.

    Each step runs inside ``manager.run_iteration()``.
    """
    while not manager.stop_trigger:
        # test_func (the evaluator's eval_func in main) calls model.eval(),
        # so training mode is re-enabled at the start of every pass.
        model.train()
        for batch_x, batch_t in train_loader:
            with manager.run_iteration():
                batch_x, batch_t = batch_x.to(device), batch_t.to(device)
                optimizer.zero_grad()
                loss = model(batch_x, batch_t)
                ppe.reporting.report({'main/loss': loss.item()})
                ppe.reporting.report({'main/accuracy': model.accuracy})
                loss.backward()
                optimizer.step()
def test_func(model, device, x, t):
    """Evaluate ``model`` on one batch and report validation metrics.

    Used as the ``eval_func`` of the ppe ``Evaluator`` in ``main``.
    Reports ``validation/main/loss`` and ``validation/main/accuracy``.

    Args:
        model: A module returning a scalar loss from ``model(x, t)`` and
            exposing an ``accuracy`` attribute afterwards (e.g. Classifier).
        device: Device the batch is moved to.
        x: Input batch.
        t: Target batch.
    """
    model.eval()
    x, t = x.to(device), t.to(device)
    # Evaluation needs no gradients. Without no_grad the forward pass records
    # an autograd graph, and a Classifier-style model keeps it alive through
    # its stored ``y``/``loss`` attributes until the next call.
    with torch.no_grad():
        loss = model(x, t).item()
    accuracy = model.accuracy
    ppe.reporting.report({'validation/main/loss': loss})
    ppe.reporting.report({'validation/main/accuracy': accuracy})
def main():
    """Parse command-line options, train the MNIST MLP and save final weights.

    Options select batch size, epoch count, snapshot frequency, device,
    output directory, a directory to resume weights from, and hidden width.
    The final model and optimizer states are written to ``--out`` as
    ``mlp.model`` and ``mlp.state``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device', '-d', type=str, default='cpu')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', type=str,
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    device = torch.device(args.device)
    use_cuda = device.type == 'cuda' and torch.cuda.is_available()

    print('Device: {}'.format(device))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Load the MNIST dataset; both splits share the same preprocessing.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transform),
        batch_size=args.batchsize, shuffle=True, **kwargs)
    # Shuffling the test split is unnecessary for evaluation; a fixed order
    # keeps the reported validation metrics reproducible.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transform),
        batch_size=args.batchsize, shuffle=False, **kwargs)
    iters_per_epoch = len(train_loader)

    # Set up a neural network to train. The Classifier computes softmax
    # cross entropy loss and accuracy. train_func/test_func report them so
    # the LogReport/PrintReport extensions below can display them.
    model = Classifier(MLP(args.unit, 10)).to(device=device)
    # One dummy batch makes the lazy layers create their parameters, so the
    # optimizer built next sees all of them.
    model(torch.zeros(100, 784, dtype=torch.float32, device=device),
          torch.zeros(100, dtype=torch.int64, device=device))

    # Setup an optimizer
    optimizer = optim.Adam(model.parameters())

    # Pass --out to the manager so snapshots, logs and plots are written to
    # the same directory as the final weights instead of the manager's
    # default output directory.
    manager = ppe.training.ExtensionsManager(
        {'main': model}, {'main': optimizer}, args.epoch,
        iters_per_epoch=iters_per_epoch, out_dir=args.out)

    # Evaluator
    manager.extend(
        extensions.Evaluator(
            {'main': test_loader}, model,
            eval_func=lambda x, t: test_func(model, device, x, t),
            progress_bar=True))

    # Take a snapshot every ``frequency`` epochs (only at the last epoch when
    # --frequency is -1). No n_retains/autoload options are passed, so old
    # snapshots are kept and training does not auto-resume from them; use
    # --resume to restart from saved weights.
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    manager.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    manager.extend(extensions.LogReport(), call_before_training=True)

    # Save two plot images to the output directory
    manager.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              'epoch', file_name='loss.png'),
        call_before_training=True)
    manager.extend(
        extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'],
            'epoch', file_name='accuracy.png'),
        call_before_training=True)

    # Print selected entries of the log to stdout. "main/*" entries are
    # reported by train_func; "validation/*" entries come from test_func,
    # which the Evaluator calls.
    manager.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']),
        call_before_training=True)

    # Print a progress bar to stdout
    manager.extend(extensions.ProgressBar())

    if args.resume is not None:
        # map_location lets states saved on one device load on another
        # (e.g. a GPU checkpoint on a CPU-only machine).
        model.load_state_dict(torch.load(
            os.path.join(args.resume, 'mlp.model'), map_location=device))
        optimizer.load_state_dict(torch.load(
            os.path.join(args.resume, 'mlp.state'), map_location=device))

    # Run the training
    train_func(manager, model, optimizer, device, train_loader)

    if args.out is not None:
        # The directory may not exist yet, e.g. if no extension wrote to it.
        os.makedirs(args.out, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(args.out, 'mlp.model'))
        torch.save(optimizer.state_dict(),
                   os.path.join(args.out, 'mlp.state'))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment