Skip to content

Instantly share code, notes, and snippets.

@ryanzidago
Last active September 19, 2021 18:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ryanzidago/8963ad4d07b4b6568f16aa675048bb01 to your computer and use it in GitHub Desktop.
Fastbook Chapter 4 - Custom Learner
def mnist_loss(predictions, targets):
    """Mean distance between the sigmoid-squashed predictions and the 1/0 targets.

    A target of 1 penalises probabilities far below 1; a target of 0
    penalises probabilities far above 0.
    """
    probs = predictions.sigmoid()
    penalties = torch.where(targets == 1, 1 - probs, probs)
    return penalties.mean()
class BasicOptim:
    """Minimal SGD-style optimiser.

    ``params`` is a zero-argument callable returning the parameter tensors
    (e.g. ``model.parameters``) — this matches how the rest of the file
    constructs it: ``BasicOptim(self.model.parameters, learning_rate)``.
    """

    def __init__(self, params, learning_rate):
        self.params = params
        self.learning_rate = learning_rate

    def step(self, *args, **kwargs):
        """Take one gradient-descent step on every parameter in place."""
        for p in self.params():
            p.data -= p.grad.data * self.learning_rate

    def zero_grad(self, *args, **kwargs):
        """Drop accumulated gradients so the next backward pass starts fresh."""
        for p in self.params():
            p.grad = None
# custom learner
class CustomLearner:
    """Hand-rolled stand-in for fastai's Learner (fastbook chapter 4).

    Trains ``model`` on ``dls[0]`` and prints validation accuracy computed
    on ``dls[1]`` after every epoch.
    """

    def __init__(self, dls, model, opt_func, loss_func, metrics=None):
        self.dls = dls
        self.model = model
        # NOTE(review): opt_func is stored but never used — train_model
        # hard-codes BasicOptim (see the commented-out line there).
        self.opt_func = opt_func
        self.loss_func = loss_func
        # BUG FIX: the default was the bare name `batch_accuracy`, which only
        # exists as a method of this class, so evaluating the signature at
        # class-definition time raised NameError. Fall back to the bound
        # method instead.
        self.metrics = metrics if metrics is not None else self.batch_accuracy

    def fit(self, epochs, learning_rate=1.):
        """Train for ``epochs`` epochs at the given learning rate."""
        self.train_model(epochs, learning_rate)

    def train_model(self, epochs, learning_rate):
        # opt = opt_func(self.model.parameters(), learning_rate)
        # BasicOptim expects the *callable* model.parameters, not its result.
        opt = BasicOptim(self.model.parameters, learning_rate)
        for _ in range(epochs):
            self.train_epoch(self.dls[0], opt)
            print(self.validate_epoch(self.dls[1]), end="\n")

    def train_epoch(self, train_dls, opt):
        """One full pass over the training batches, updating weights per batch."""
        for xb, yb in train_dls:
            # step 2: make predictions
            predictions = self.model(xb)
            # step 3: calculate the loss from those predictions
            loss = self.calc_loss(predictions, yb)
            # step 4: backpropagate to populate the .grad fields
            self.calc_grad(loss)
            # step 5: step the weights, then clear gradients for the next batch
            opt.step()
            opt.zero_grad()

    def calc_loss(self, predictions, yb):
        """Apply the configured loss function to one batch."""
        return self.loss_func(predictions, yb)

    def calc_grad(self, loss):
        """Backpropagate the loss through the model."""
        loss.backward()

    def validate_epoch(self, valid_dl):
        """Mean per-batch accuracy over the validation set, rounded to 4 dp."""
        accuracies = [self.batch_accuracy(self.model(xb), yb) for xb, yb in valid_dl]
        return round(torch.stack(accuracies).mean().item(), 4)

    def batch_accuracy(self, xb, yb):
        """Fraction of sigmoid outputs on the correct side of 0.5."""
        predictions = xb.sigmoid()
        correct = (predictions > 0.5) == yb
        return correct.float().mean()
# Download the MNIST_SAMPLE dataset (3s and 7s only) and load it as tensors.
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path

threes = (path/"train"/"3").ls().sorted()
sevens = (path/"train"/"7").ls().sorted()
three_tensors = [tensor(Image.open(image)) for image in threes]
seven_tensors = [tensor(Image.open(image)) for image in sevens]
# Stack into one float tensor per class, scaling pixel values to [0, 1].
stacked_threes = torch.stack(three_tensors).float() / 255
stacked_sevens = torch.stack(seven_tensors).float() / 255

valid_threes = (path/"valid"/"3").ls()
valid_sevens = (path/"valid"/"7").ls()
valid_three_tensors = [tensor(Image.open(image)) for image in valid_threes]
valid_seven_tensors = [tensor(Image.open(image)) for image in valid_sevens]
valid_stacked_threes = torch.stack(valid_three_tensors).float() / 255
valid_stacked_sevens = torch.stack(valid_seven_tensors).float() / 255

# Flatten each 28x28 image to a 784-vector; label 3s as 1 and 7s as 0.
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28 * 28)
train_y = tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1)
dset = list(zip(train_x, train_y))
valid_x = torch.cat([valid_stacked_threes, valid_stacked_sevens]).view(-1, 28 * 28)
valid_y = tensor([1] * len(valid_stacked_threes) + [0] * len(valid_stacked_sevens)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))

# step 1. initialize the weights and bias
# NOTE(review): these are unused below — CustomLearner trains an nn.Linear
# (which has its own weights) rather than this hand-made pair.
weights = (torch.randn((28 * 28, 1)) * 1.0).requires_grad_()
bias = (torch.randn(1) * 1.0).requires_grad_()

dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)
# BUG FIX: was `DataLoder(...)` (NameError). fastai's DataLoaders bundles the
# train and valid loaders and supports the dls[0]/dls[1] indexing CustomLearner uses.
dls = DataLoaders(dl, valid_dl)

custom_learner = CustomLearner(dls, nn.Linear(28 * 28, 1), opt_func=SGD, loss_func=mnist_loss)
custom_learner.fit(5, learning_rate=0.1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment