Skip to content

Instantly share code, notes, and snippets.

@poppingtonic
Last active January 24, 2018 14:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save poppingtonic/8cbc1bd8be6a77c6efde53abc62dfeb8 to your computer and use it in GitHub Desktop.
Save poppingtonic/8cbc1bd8be6a77c6efde53abc62dfeb8 to your computer and use it in GitHub Desktop.
from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
# Image size passed to tfms_from_model below.
sz = 224
# Architecture, from https://github.com/facebookresearch/ResNeXt
arch = resnext50
# Batch size, handed to the data object below.
bs = 64
# Root directory of the dataset.
# NOTE(review): from_paths presumably expects train/ and valid/ sub-folders
# with one folder per class under this path -- confirm against the data layout.
PATH = 'data/spiderscorpions/'
# Build the transforms with data augmentation enabled:
#   transforms_side_on flips the image along the vertical axis
#   max_zoom=1.1 makes images up to 10% larger
tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
# Pass bs explicitly: previously `bs` was defined above but never used, so
# changing it had no effect on the batch size of the data loaders.
data = ImageClassifierData.from_paths(PATH, bs=bs, tfms=tfms)
# Start with precompute=True; it is switched off later in the script so that
# the last layer is trained with data augmentation.
learn = ConvLearner.pretrained(arch, data, precompute=True)
# Run the learning-rate finder: we are looking for the highest learning rate
# at which the loss is still clearly improving.
learn.lr_find()
# Inspect the plot to read off that learning rate.
learn.sched.plot()
# Taking 0.01 as the optimal learning rate, train for 3 epochs.
learn.fit(1e-2, n_cycle=3)
# Switch precompute off so the last layer is trained with data augmentation,
# then run 3 more cycles with cycle_len=1.
learn.precompute = False
learn.fit(1e-2, n_cycle=3, cycle_len=1)
# Unfreeze all layers, thus opening up resnext50's original ImageNet weights
# for fine-tuning on the features of the two spider and scorpion classes.
learn.unfreeze()
lr = 0.01
# fastai groups the layers in all of the pre-packaged pretrained convolutional
# networks into three groups; retrain the three layer groups in resnext50
# using one learning rate per group.
# Earlier layers get a 3x-10x lower learning rate than the next higher layer.
lrs = np.array([lr/9, lr/3, lr])
learn.fit(lrs, 3)
# Use lr_find() again and inspect the plot before the final training run.
learn.lr_find()
learn.sched.plot()
# Keep the per-group learning rates for the final run. Passing a single
# scalar (previously 1e-2) here would discard the differential rates set up
# above and contradict the intent of training earlier layers more gently.
learn.fit(lrs, 3, cycle_len=1, cycle_mult=2)
# Evaluate with test-time augmentation.
# NOTE(review): assumes learn.TTA() returns log-probabilities stacked along
# axis 0 (one slice per augmented pass) plus the targets as a numpy array --
# confirm against the installed fastai version.
log_preds, y = learn.TTA()
# Convert log-probabilities to probabilities and average them over axis 0.
preds = np.mean(np.exp(log_preds), 0)
# Score the averaged predictions. Previously accuracy was computed on the raw
# log_preds, leaving `preds` unused and ignoring the TTA averaging, and the
# result was discarded when run as a script.
acc = (np.argmax(preds, axis=1) == y).mean()
print(acc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment