Skip to content

Instantly share code, notes, and snippets.

@radekosmulski
Last active August 17, 2019 09:31
Show Gist options
  • Save radekosmulski/e31e8d18bafcec06c6d6eb8dd17180c2 to your computer and use it in GitHub Desktop.
Save radekosmulski/e31e8d18bafcec06c6d6eb8dd17180c2 to your computer and use it in GitHub Desktop.
Train a model on CIFAR-10 from the command line using fastai.
from fastai.vision import *
from fastai.script import *
from torch import nn
from fastai.metrics import top_k_accuracy
# Obtain the CIFAR dataset through fastai's `untar_data` and build an image
# DataBunch from its folder layout. The extracted `test` folder is used as the
# validation split (there is no separate held-out test set in this script).
path = untar_data(URLs.CIFAR)
data = ImageDataBunch.from_folder(
    path,
    valid='test',
)
class block(nn.Module):
    """One layer followed by ReLU and then batch normalisation.

    Args:
        n_in: input channels (when ``two_d``) or input features (otherwise).
        n_out: output channels / features.
        two_d: when True the layer is an unpadded 3x3 ``nn.Conv2d`` paired
            with ``nn.BatchNorm2d``; when False it is ``nn.Linear`` paired
            with ``nn.BatchNorm1d``.
    """

    def __init__(self, n_in, n_out, two_d=True):
        super().__init__()
        if two_d:
            layer = nn.Conv2d(n_in, n_out, 3)
            norm = nn.BatchNorm2d(n_out)
        else:
            layer = nn.Linear(n_in, n_out)
            norm = nn.BatchNorm1d(n_out)
        # `op` and `bn` are the submodule names that appear in saved
        # state_dicts, so they (and their registration order) must not change.
        self.op = layer
        self.bn = norm

    def forward(self, x):
        # Note the ordering: activation is applied *before* normalisation.
        return self.bn(F.relu(self.op(x)))
# Network: two conv stages (two 3x3 conv blocks + 2x2 max-pool each), then
# flatten into two dense blocks and a 10-way linear output layer.
# The first dense width 800 = 32 channels * 5 * 5: four unpadded 3x3 convs and
# two 2x2 pools shrink a 32x32 input as 32 -> 30 -> 28 -> 14 -> 12 -> 10 -> 5
# (assumes 32x32 input images — confirm if the dataset changes).
arch = SequentialEx(
    *[block(3, 32), block(32, 32), nn.MaxPool2d(2)],
    *[block(32, 32), block(32, 32), nn.MaxPool2d(2)],
    *[Flatten(), block(32 * 5 * 5, 800, two_d=False), block(800, 800, two_d=False)],
    nn.Linear(800, 10),
)
def top_3_accuracy(preds, targs):
    """Top-3 accuracy metric: fastai's ``top_k_accuracy`` with k fixed to 3."""
    k = 3
    return top_k_accuracy(preds, targs, k)
# Bundle data, model and metrics into the fastai Learner that `train` drives.
# Metric order here fixes the order of the values returned by `validate()`
# after the loss (presumably [loss, accuracy, top_3_accuracy]).
learn = Learner(
    data,
    arch,
    metrics=[accuracy, top_3_accuracy],
)
def _str2bool(value):
    """Convert a command-line string to a bool for boolean CLI options.

    Plain ``bool`` is unsuitable as an argparse ``type``: it maps every
    non-empty string, including ``"False"`` and ``"0"``, to True.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _flag(help_text):
    """Build a boolean ``Param`` accepting ``--opt``, ``--opt True`` or ``--opt False``."""
    # nargs='?' + const=True: a bare `--opt` yields True; an explicit value is
    # parsed by `_str2bool`, so existing `--opt True` invocations keep working.
    return Param(help_text, _str2bool, nargs='?', const=True)


@call_parse
def train(
    epochs: Param("Number of epochs to train", int)=1,
    max_lr: Param("Maximum lr for one cycle", float)=1e-3,
    find_lr: _flag("Run lr finder and save figure to lr_find.png")=False,
    plot_losses: _flag("Plot losses after training and save figure to losses.png")=False,
    save_model: _flag("Save model after training (name will consist of hyperparams)")=False,
):
    """Train the module-level `learn` with a one-cycle schedule.

    Optionally runs the LR finder first (plot saved to lr_find.png), saves the
    training loss plot (losses.png), and saves the model under a name built
    from epochs, max_lr and the final validation loss / top-1 / top-3 scores.
    """
    if find_lr:
        learn.lr_find()
        fig = learn.recorder.plot(return_fig=True)
        fig.savefig('lr_find.png')
    learn.fit_one_cycle(epochs, max_lr)
    if plot_losses:
        fig = learn.recorder.plot_losses(return_fig=True)
        fig.savefig('losses.png')
    if save_model:
        # Unpacking assumes validate() returns the loss followed by the two
        # metrics the Learner was built with (accuracy, top-3) — confirm if
        # the metrics list changes.
        loss, top_1, top_3 = learn.validate()
        model_name = f'{epochs}_{max_lr}_{loss:.2f}_{top_1:.2f}_{top_3:.2f}'
        print(f'Saving model with name: {model_name}')
        learn.save(model_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment