Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save radekosmulski/54dc35136133cfceb67aded2004d18c2 to your computer and use it in GitHub Desktop.
Save radekosmulski/54dc35136133cfceb67aded2004d18c2 to your computer and use it in GitHub Desktop.
from fastai.vision import *
from fastai.script import *
from torch import nn
from fastai.metrics import top_k_accuracy
# Fetch CIFAR-10 with fastai's `untar_data`, then build the DataBunch from the
# folders under `path`, using the `test` folder as the validation set.
path = untar_data(URLs.CIFAR)
# Normalize with the per-channel CIFAR mean/std so the first conv block sees
# roughly zero-mean, unit-variance inputs instead of raw pixel values.
data = ImageDataBunch.from_folder(path, valid='test').normalize(cifar_stats)
class block(nn.Module):
    """Conv (or Linear) -> ReLU -> BatchNorm building block.

    With ``two_d=True`` the learnable layer is an ``nn.Conv2d`` with a 3x3
    kernel and default (zero) padding, followed by ``nn.BatchNorm2d``.
    With ``two_d=False`` it is an ``nn.Linear`` followed by ``nn.BatchNorm1d``.

    The submodule attribute names ``op`` and ``bn`` determine the state-dict
    keys of saved checkpoints, so they must not be renamed.
    """

    def __init__(self, n_in, n_out, two_d=True):
        super().__init__()
        if two_d:
            self.op = nn.Conv2d(n_in, n_out, 3)
            self.bn = nn.BatchNorm2d(n_out)
        else:
            self.op = nn.Linear(n_in, n_out)
            self.bn = nn.BatchNorm1d(n_out)

    def forward(self, x):
        # Batch norm is applied after the non-linearity in this design.
        return self.bn(F.relu(self.op(x)))
# Assuming 32x32 inputs (CIFAR), each unpadded 3x3 conv block trims 2 pixels
# per side pair and each 2x2 max-pool halves the size:
#   32 -> 30 -> 28 -> pool 14 -> 12 -> 10 -> pool 5
# so the flattened feature vector has 32 channels * 5 * 5 = 800 entries.
# That vector feeds two fully connected blocks and a final 10-way linear layer.
# `SequentialEx` is kept (rather than nn.Sequential) so checkpoint keys are unchanged.
arch = SequentialEx(
    block(3, 32),
    block(32, 32),
    nn.MaxPool2d(2),

    block(32, 32),
    block(32, 32),
    nn.MaxPool2d(2),

    Flatten(),
    block(32 * 5 * 5, 800, two_d=False),
    block(800, 800, two_d=False),
    nn.Linear(800, 10),
)
def top_3_accuracy(preds, targs):
    """Return fastai's top-k accuracy of ``preds`` against ``targs`` with k = 3."""
    k = 3
    return top_k_accuracy(preds, targs, k)
# Bundle the data, the model and the two reported metrics (top-1 accuracy and
# top-3 accuracy, in that order) into a fastai Learner.
learn = Learner(
    data,
    arch,
    metrics=[accuracy, top_3_accuracy],
)
@call_parse
def train(
    epochs: Param("Number of epochs to train", int)=1,
    max_lr: Param("Maximum lr for one cycle", float)=1e-3
):
    """Run an LR range test, train with the one-cycle policy, then save the weights.

    The checkpoint name encodes the run settings and the validation results:
    ``{epochs}_{max_lr}_{loss}_{top1}_{top3}``, with the three validation
    numbers formatted to two decimals.
    """
    # LR range test. Its curve is plotted for inspection only: the learning
    # rate used for training comes from the `max_lr` argument, not from here.
    # NOTE(review): when run as a plain script these figures are neither shown
    # nor saved — consider saving them to disk.
    learn.lr_find()
    learn.recorder.plot()

    learn.fit_one_cycle(epochs, max_lr)
    learn.recorder.plot_losses()

    # The unpacking assumes validate() returns the validation loss followed by
    # the metrics in Learner order (accuracy, top_3_accuracy).
    val_loss, acc, top3_acc = learn.validate()
    checkpoint_name = f'{epochs}_{max_lr}_{val_loss:.2f}_{acc:.2f}_{top3_acc:.2f}'
    learn.save(checkpoint_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment