
@aakashns
Last active July 24, 2018 20:42
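
```python
from torch import nn
from fastai.conv_learner import ConvLearner, num_cpus, accuracy
from fastai.metrics import accuracy_np

# NOTE: get_data(bs, num_workers) is not defined in this snippet; it is
# assumed to build the fastai ModelData object elsewhere.

def get_learner(arch, bs):
    """Create a FastAI learner using the given model"""
    data = get_data(bs, num_cpus())
    learn = ConvLearner.from_model_data(arch.cuda(), data)
    learn.crit = nn.CrossEntropyLoss()
    learn.metrics = [accuracy]
    return learn

def get_TTA_accuracy(learn):
    """Calculate accuracy with Test Time Augmentation (TTA)"""
    # preds[0]: predictions on the original images; preds[1:]: augmented copies
    preds, targs = learn.TTA()
    # Weighted average: 0.6 for the original view, 0.4 for the mean of the
    # augmented views (mean, not sum, so the two weights add up to 1)
    preds = 0.6 * preds[0] + 0.4 * preds[1:].mean(0)
    return accuracy_np(preds, targs)
```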
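The blending step in `get_TTA_accuracy` can be checked in isolation with plain NumPy. This is a minimal sketch, assuming the TTA predictions are stacked with shape `(1 + n_aug, N, C)` and the un-augmented pass first; the function name `weighted_tta_accuracy` and the `w_orig` parameter are hypothetical and not part of fastai.

```python
import numpy as np

def weighted_tta_accuracy(preds, targs, w_orig=0.6):
    """Blend TTA predictions and return classification accuracy.

    preds: array of shape (1 + n_aug, N, C); preds[0] holds scores for the
           original images, preds[1:] scores for the augmented copies.
    targs: integer class labels of shape (N,).
    w_orig: hypothetical weight given to the original view.
    """
    preds = np.asarray(preds, dtype=float)
    targs = np.asarray(targs)
    # w_orig for the original view, (1 - w_orig) for the mean of the augmented views
    blended = w_orig * preds[0] + (1.0 - w_orig) * preds[1:].mean(axis=0)
    # Predicted class = highest blended score per sample
    predicted = blended.argmax(axis=1)
    return float((predicted == targs).mean())

# Toy example: 1 original view + 2 augmented views, 2 samples, 2 classes
p = np.array([
    [[0.9, 0.1], [0.2, 0.8]],  # original view
    [[0.4, 0.6], [0.3, 0.7]],  # augmentation 1
    [[0.5, 0.5], [0.6, 0.4]],  # augmentation 2
])
print(weighted_tta_accuracy(p, [0, 1]))  # 1.0
```

Averaging the augmented views (rather than summing them) keeps the original view's share fixed at `w_orig` no matter how many augmentations are run.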