-
-
Save asi1024/9ed419f652c1c81301ec70b895610a54 to your computer and use it in GitHub Desktop.
Migration Guide from Chainer to PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
"""MNIST training example migrated from Chainer to PyTorch.

Uses pytorch-pfn-extras (ppe) to reproduce Chainer's trainer/extension
workflow (snapshots, LogReport, PlotReport, PrintReport, ProgressBar).
"""
import argparse
import os.path
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torchvision import datasets, transforms
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions
import matplotlib
# Non-interactive backend so the PlotReport extensions can write PNG files
# on headless machines (no display required).
matplotlib.use('Agg')
# Network definition | |
class MLP(nn.Module):
    """Three-layer fully-connected network with ReLU activations.

    The input size of every layer is left unspecified (``None``) and is
    inferred by ``ppe.nn.LazyLinear`` on the first forward pass, mirroring
    Chainer's ``L.Linear(None, ...)`` idiom.
    """

    def __init__(self, n_units, n_out):
        super().__init__()
        # the size of the inputs to each layer will be inferred
        self.l1 = ppe.nn.LazyLinear(None, n_units)  # n_in -> n_units
        self.l2 = ppe.nn.LazyLinear(None, n_units)  # n_units -> n_units
        self.l3 = ppe.nn.LazyLinear(None, n_out)    # n_units -> n_out

    def forward(self, x):
        # Flatten each 28x28 MNIST image into a 784-dimensional vector.
        flat = x.reshape(-1, 784)
        hidden = F.relu(self.l2(F.relu(self.l1(flat))))
        return self.l3(hidden)
class Classifier(nn.Module):
    """Wrap a predictor with a softmax cross-entropy loss head.

    After each forward call the logits, loss and batch accuracy are kept
    on ``self.y``, ``self.loss`` and ``self.accuracy`` so that training
    and evaluation loops can report them.
    """

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor
        self.lossfun = nn.CrossEntropyLoss()

    def forward(self, *args):
        # The last positional argument is the target; everything before it
        # is forwarded to the predictor.
        *inputs, t = args
        self.y = self.predictor(*inputs)
        self.loss = self.lossfun(self.y, t)
        hits = self.y.argmax(dim=1).reshape(t.shape) == t
        self.accuracy = hits.sum().item() / len(t)
        return self.loss
def train_func(manager, model, optimizer, device, train_loader):
    """Run the optimization loop until the manager's stop trigger fires.

    Each mini-batch is processed inside ``manager.run_iteration()`` so the
    ppe extensions (logging, snapshots, progress bar) observe every step.
    """
    while not manager.stop_trigger:
        model.train()
        for batch, target in train_loader:
            with manager.run_iteration():
                batch = batch.to(device)
                target = target.to(device)
                optimizer.zero_grad()
                loss = model(batch, target)
                # model.accuracy was stored by Classifier.forward above.
                ppe.reporting.report({'main/loss': loss.item()})
                ppe.reporting.report({'main/accuracy': model.accuracy})
                loss.backward()
                optimizer.step()
def test_func(model, device, x, t):
    """Evaluate one validation batch and report its loss and accuracy.

    Called by the ppe ``Evaluator`` extension once per test mini-batch.

    Fix: the forward pass now runs under ``torch.no_grad()`` — the
    original built an autograd graph for every validation batch, wasting
    memory and time even though no backward pass ever follows.
    """
    model.eval()
    x, t = x.to(device), t.to(device)
    with torch.no_grad():
        loss = model(x, t).item()
        # model.accuracy is stored by Classifier.forward.
        accuracy = model.accuracy
    ppe.reporting.report({'validation/main/loss': loss})
    ppe.reporting.report({'validation/main/accuracy': accuracy})
def main():
    """Train the MLP classifier on MNIST with pytorch-pfn-extras.

    Parses CLI options, builds the data loaders and model, registers the
    ppe training extensions (evaluator, snapshot, logging, plots, progress
    bar), optionally resumes from a saved state, trains, and saves the
    final model/optimizer state dicts.
    """
    # NOTE(review): description string still says "Chainer" — leftover
    # from the original example this was ported from.
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--device', '-d', type=str, default='cpu')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', type=str,
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    device = torch.device(args.device)
    # CUDA-only DataLoader options are enabled below only when a CUDA
    # device was both requested and is actually available.
    use_cuda = device.type == 'cuda' and torch.cuda.is_available()

    print('Device: {}'.format(device))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Load the MNIST dataset
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           # Standard MNIST mean/std normalization.
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batchsize, shuffle=True, **kwargs)
    # NOTE(review): shuffle=True on the test split is unusual (evaluation
    # order does not matter, but neither does shuffling help) — confirm
    # this was intentional.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.batchsize, shuffle=True, **kwargs)
    iters_per_epoch = len(train_loader)

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = Classifier(MLP(args.unit, 10)).to(device=device)
    # Dummy forward pass so the LazyLinear layers materialize their
    # parameters before the optimizer collects model.parameters() below.
    model(torch.zeros(100, 784).to(device).to(torch.float32),
          torch.zeros(100).to(device).to(torch.int64))

    # Setup an optimizer
    optimizer = optim.Adam(model.parameters())
    manager = ppe.training.ExtensionsManager(
        {'main': model}, {'main': optimizer}, args.epoch,
        iters_per_epoch=iters_per_epoch)

    # Evaluator
    manager.extend(
        extensions.Evaluator(
            {'main': test_loader}, model,
            eval_func=lambda x, t: test_func(
                model, device, x, t),
            progress_bar=True))

    # Take a snapshot for each specified epoch
    # Take a snapshot each ``frequency`` epoch, delete old stale
    # snapshots and automatically load from snapshot files if any
    # files are already resident at result directory.
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    manager.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    manager.extend(extensions.LogReport(), call_before_training=True)

    # Save two plot images to the result dir
    manager.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              'epoch', file_name='loss.png'),
        call_before_training=True)
    manager.extend(
        extensions.PlotReport(
            ['main/accuracy', 'validation/main/accuracy'],
            'epoch', file_name='accuracy.png'),
        call_before_training=True)

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    manager.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']),
        call_before_training=True)

    # Print a progress bar to stdout
    manager.extend(extensions.ProgressBar())

    # --resume expects a directory containing mlp.model / mlp.state as
    # written by the save block at the end of this function.
    if args.resume is not None:
        model.load_state_dict(
            torch.load(os.path.join(args.resume, 'mlp.model')))
        optimizer.load_state_dict(
            torch.load(os.path.join(args.resume, 'mlp.state')))

    # Run the training
    train_func(manager, model, optimizer, device, train_loader)

    # NOTE(review): torch.save does not create args.out — this presumes
    # the directory already exists (the snapshot extension presumably
    # created it during training); confirm.
    if args.out is not None:
        torch.save(model.state_dict(), os.path.join(args.out, 'mlp.model'))
        torch.save(optimizer.state_dict(), os.path.join(args.out, 'mlp.state'))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment