@Daiver · Created January 16, 2019 17:28
import sys
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import catalyst
from catalyst.dl.callbacks import (
    ClassificationLossCallback,
    Logger, TensorboardLogger,
    OptimizerCallback, SchedulerCallback, CheckpointCallback,
    PrecisionCallback, OneCycleLR)
from catalyst.dl.runner import ClassificationRunner


class Net(nn.Module):
    """LeNet-style CNN for 3x32x32 CIFAR-10 images, 10 output classes."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3x32x32 -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)         # halves spatial dims
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16x5x5 = 400 after second pool
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten to (batch, 400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # raw logits; CrossEntropyLoss applies the softmax
        return x


def main():
    print('Python version:', sys.version)
    print('Catalyst version:', catalyst.__version__)

    bs = 32
    n_workers = 0
    # Standard CIFAR-10 preprocessing: normalize each channel to [-1, 1]
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    loaders = collections.OrderedDict()
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True,
        download=True, transform=data_transform)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=bs,
        shuffle=True, num_workers=n_workers)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False,
        download=True, transform=data_transform)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=bs,
        shuffle=False, num_workers=n_workers)
    loaders["train"] = trainloader
    loaders["valid"] = testloader

    model = Net().cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    # scheduler = None  # for OneCycle usage
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(
    #     optimizer, milestones=[3, 8], gamma=0.3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=2, verbose=True)

    # the only tricky part: wiring up the callbacks
    n_epochs = 10
    logdir = "./logs/cifar_simple_notebook"
    callbacks = collections.OrderedDict()
    callbacks["loss"] = ClassificationLossCallback()
    callbacks["optimizer"] = OptimizerCallback()
    callbacks["precision"] = PrecisionCallback(
        precision_args=[1, 3, 5])
    # OneCycle custom scheduler callback
    # callbacks["scheduler"] = OneCycleLR(
    #     cycle_len=n_epochs,
    #     div=3, cut_div=4, momentum_range=(0.95, 0.85))
    # PyTorch scheduler callback; ReduceLROnPlateau watches "precision01"
    callbacks["scheduler"] = SchedulerCallback(
        reduce_metric="precision01")
    callbacks["saver"] = CheckpointCallback()
    callbacks["logger"] = Logger()
    callbacks["tflogger"] = TensorboardLogger()

    runner = ClassificationRunner(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler)
    runner.train(
        loaders=loaders,
        callbacks=callbacks,
        logdir=logdir,
        epochs=n_epochs, verbose=True)


if __name__ == '__main__':
    main()
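
A quick way to sanity-check the architecture before launching a full training run is to push a dummy CIFAR-10-sized batch through Net and inspect the output shape. This is a minimal sketch, assuming it runs in the same file or session where Net is defined (no GPU or Catalyst needed for this check):

import torch

net = Net()                        # CPU is fine for a shape check
dummy = torch.randn(4, 3, 32, 32)  # fake batch of four 3x32x32 images
out = net(dummy)
print(out.shape)                   # expected: torch.Size([4, 10])

One caveat: the Catalyst names used above (ClassificationLossCallback, ClassificationRunner, and friends) belong to the early-2019 API this gist was written against; later Catalyst releases reorganized the callback and runner interfaces, so these imports may not work on current versions.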