Skip to content

Instantly share code, notes, and snippets.

@radekosmulski
Created June 20, 2019 19:29
Show Gist options
  • Save radekosmulski/adc82cdcbd486b027e389d8b97e7e12f to your computer and use it in GitHub Desktop.
Save radekosmulski/adc82cdcbd486b027e389d8b97e7e12f to your computer and use it in GitHub Desktop.
Training on CIFAR-10 using fastai from the command line.
import fire
import fastai
from fastai.vision import *
from torch import nn
from fastai.metrics import top_k_accuracy
# Fetch and extract the CIFAR-10 archive with fastai's `untar_data`, then
# build an image DataBunch from its class-per-folder layout. The archive's
# 'test' folder is used as the validation split.
# NOTE(review): no `.normalize(...)` call or augmentation transforms are
# applied here, so the model sees un-normalized pixels — confirm intended.
path = untar_data(URLs.CIFAR)
data = ImageDataBunch.from_folder(path=path, valid='test')
class block(nn.Module):
    """A conv (or linear) layer followed by ReLU and then batch norm.

    Args:
        n_in: input channels when ``two_d`` is true, else input features.
        n_out: output channels when ``two_d`` is true, else output features.
        two_d: if true, use an unpadded 3x3 ``nn.Conv2d`` paired with
            ``nn.BatchNorm2d``; otherwise an ``nn.Linear`` paired with
            ``nn.BatchNorm1d``.
    """

    def __init__(self, n_in, n_out, two_d=True):
        super().__init__()
        # `op` is created before `bn`, so parameter initialisation order
        # (and the resulting state_dict key order) is fixed.
        if two_d:
            self.op = nn.Conv2d(n_in, n_out, 3)
            self.bn = nn.BatchNorm2d(n_out)
        else:
            self.op = nn.Linear(n_in, n_out)
            self.bn = nn.BatchNorm1d(n_out)

    def forward(self, x):
        # The activation is applied *before* normalisation: op -> ReLU -> BN.
        return self.bn(F.relu(self.op(x)))
# Small CNN, sized for 32x32 RGB inputs (CIFAR-10 images). Each unpadded 3x3
# conv shrinks each spatial dimension by 2 and each 2x2 max-pool halves it:
#   32 -> 30 -> 28 -> 14 -> 12 -> 10 -> 5
# so Flatten yields 32 * 5 * 5 = 800 features. The 800s below must change if
# the input size, conv widths, or number of stages change.
arch = SequentialEx(
    # Stage 1: two conv blocks, then downsample.
    block(3, 32),
    block(32, 32),
    nn.MaxPool2d(2),
    # Stage 2: two more conv blocks, then downsample.
    block(32, 32),
    block(32, 32),
    nn.MaxPool2d(2),
    Flatten(),
    # Head: two fully connected blocks, then a 10-way output layer (logits).
    block(800, 800, two_d=False),
    block(800, 800, two_d=False),
    nn.Linear(800, 10),
)
def top_3_accuracy(preds, targs):
    """Top-3 accuracy: fastai's ``top_k_accuracy`` with ``k`` fixed to 3.

    Fixing ``k`` gives a two-argument callable that can be passed directly in
    the ``metrics=`` list of the Learner.
    """
    return top_k_accuracy(preds, targs, k=3)
# Bundle the DataBunch and model into a fastai Learner, tracking top-1 and
# top-3 accuracy as metrics.
learn = Learner(data=data, model=arch, metrics=[accuracy, top_3_accuracy])
def train(epochs=3, max_lr=1e-3, find_lr=False, plot_losses=False, save_model=False):
    """Train the module-level ``learn`` with the one-cycle policy.

    This function is handed to ``fire.Fire`` in the ``__main__`` guard, so its
    parameters are supplied from the command line.

    Args:
        epochs: number of epochs passed to ``learn.fit_one_cycle``.
        max_lr: maximum learning rate passed to ``learn.fit_one_cycle``.
        find_lr: if true, run ``learn.lr_find()`` before training and save the
            resulting plot to ``lr_find.png``.
        plot_losses: if true, save the loss plot recorded during training to
            ``losses.png``.
        save_model: if true, run ``learn.validate()`` after training and save
            the weights under a name built from ``epochs``, ``max_lr`` and the
            validation loss / top-1 / top-3 accuracy (two decimals each).
    """
    if find_lr:
        learn.lr_find()
        learn.recorder.plot(return_fig=True).savefig('lr_find.png')

    learn.fit_one_cycle(epochs, max_lr)

    if plot_losses:
        learn.recorder.plot_losses(return_fig=True).savefig('losses.png')

    if not save_model:
        return

    # validate() yields the loss followed by each metric, in the order given
    # to the Learner: accuracy, then top_3_accuracy.
    val_loss, acc_top1, acc_top3 = learn.validate()
    checkpoint_name = f'{epochs}_{max_lr}_{val_loss:.2f}_{acc_top1:.2f}_{acc_top3:.2f}'
    learn.save(checkpoint_name)
if __name__ == '__main__':
    # Turn `train` and its keyword arguments into a command-line interface.
    fire.Fire(component=train)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment