This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def accuracy(out, y_true):
    """Return the fraction of positions where the arg-max prediction equals the target.

    Both the predicted class indices and the targets are flattened to
    ``(batch, -1)`` before being compared element-wise.

    Args:
        out: score tensor whose last dimension indexes classes; ``argmax``
            over ``dim=-1`` gives the predicted class per position.
            Presumably shaped ``(*y_true.shape, n_classes)`` — TODO confirm
            against callers (mismatched shapes would broadcast silently).
        y_true: target class-index tensor whose first dimension is the batch.

    Returns:
        A 0-dim float tensor holding the mean of the element-wise matches.
    """
    batch_size = y_true.size(0)
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. targets
    # produced by permute/transpose; it only copies when a view is impossible.
    y_hat = out.argmax(dim=-1).reshape(batch_size, -1)
    y_true = y_true.reshape(batch_size, -1)
    match = y_hat == y_true
    return match.float().mean()
class Accuracy(Callback): | |
def epoch_started(self, **kwargs): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class RollingLoss(Callback): | |
def __init__(self, smooth=0.98): | |
self.smooth = smooth | |
def batch_ended(self, phase, **kwargs): | |
prev = phase.rolling_loss | |
a = self.smooth | |
avg_loss = a * prev + (1 - a) * phase.batch_loss | |
debias_loss = avg_loss / (1 - a ** phase.batch_index) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(model, opt, phases, callbacks=None, epochs=1, device=default_device, loss_fn=F.nll_loss): | |
model.to(device) | |
cb = callbacks | |
cb.training_started(phases=phases, optimizer=opt) | |
for epoch in range(1, epochs + 1): | |
cb.epoch_started(epoch=epoch) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model = create_model(params) | |
phases = create_train_valid_data() | |
opt = optim.SGD(model.params, lr=1e-3) | |
model.to(device) | |
for epoch in range(1, epochs + 1): | |
for phase in phases: | |
n = len(phase.loader) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Converting FER2013 dataset from CSV representation into folder with images. | |
The dataset is taken from: | |
https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge | |
Encoding: | |
(0=Angry, 1=Disgust, 2=Fear, 3=Happy, 4=Sad, 5=Surprise, 6=Neutral). | |
""" |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Training ResNet18 model on 50000 samples per category. | |
""" | |
import sys | |
from fastai import defaults | |
from fastai.vision import create_cnn, get_transforms | |
from fastai.metrics import accuracy | |
from fastai.callbacks import EarlyStoppingCallback, SaveModelCallback, CSVLogger | |
from fastai.vision.data import ImageItemList, imagenet_stats |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.