Last active
September 19, 2021 18:10
-
-
Save ryanzidago/8963ad4d07b4b6568f16aa675048bb01 to your computer and use it in GitHub Desktop.
Fastbook Chapter 4 - Custom Learner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def mnist_loss(predictions, targets):
    """Return the mean per-example loss for binary 3-vs-7 classification.

    Raw activations are squashed into (0, 1) with sigmoid; each example's
    loss is its distance from the correct end of that range (1 when the
    target is 1, 0 otherwise), averaged over the batch.
    """
    probs = predictions.sigmoid()
    distances = torch.where(targets == 1, 1 - probs, probs)
    return distances.mean()
class BasicOptim:
    """Minimal SGD-style optimizer.

    `params` is a *callable* (e.g. the bound `model.parameters` method)
    that yields the tensors to update each time it is invoked.
    """

    def __init__(self, params, learning_rate):
        self.params = params
        self.learning_rate = learning_rate

    def step(self, *args, **kwargs):
        # Update .data directly so autograd does not track the step.
        for p in self.params():
            p.data -= self.learning_rate * p.grad.data

    def zero_grad(self, *args, **kwargs):
        # Drop the gradients entirely instead of zeroing them in place.
        for p in self.params():
            p.grad = None
# custom learner | |
# custom learner
class CustomLearner:
    """Hand-rolled replacement for fastai's Learner (Fastbook ch. 4).

    Parameters
    ----------
    dls : indexable pair (train_dataloader, valid_dataloader) of
        (xb, yb) batch iterables.
    model : callable mapping a batch xb to predictions.
    opt_func : factory called as opt_func(model.parameters(), lr) that
        returns an optimizer exposing step() and zero_grad().
    loss_func : maps (predictions, targets) to a scalar loss tensor.
    metrics : scores a validation batch as metrics(predictions, yb);
        defaults to the built-in batch_accuracy.
    """

    def __init__(self, dls, model, opt_func, loss_func, metrics=None):
        self.dls = dls
        self.model = model
        self.opt_func = opt_func
        self.loss_func = loss_func
        # Fix: the old default referenced an undefined global
        # `batch_accuracy`; fall back to the built-in method instead.
        self.metrics = metrics if metrics is not None else self.batch_accuracy

    def fit(self, epochs, learning_rate=1.):
        """Train for `epochs` full passes over the training data."""
        self.train_model(epochs, learning_rate)

    def train_model(self, epochs, learning_rate):
        # Fix: honour the opt_func the caller supplied instead of
        # hard-coding BasicOptim (the `opt_func` argument was silently
        # ignored before).
        opt = self.opt_func(self.model.parameters(), learning_rate)
        for _ in range(epochs):
            self.train_epoch(self.dls[0], opt)
            # Report validation performance once per epoch.
            print(self.validate_epoch(self.dls[1]), end="\n")

    def train_epoch(self, train_dls, opt):
        """Run one full pass over the training dataloader."""
        for xb, yb in train_dls:
            # step 2: make predictions
            predictions = self.model(xb)
            # step 3: calculate the loss from those predictions
            loss = self.calc_loss(predictions, yb)
            # step 4: backpropagate the loss to get gradients
            self.calc_grad(loss)
            # step 5: update the weights, then clear the gradients
            opt.step()
            opt.zero_grad()

    def calc_loss(self, predictions, yb):
        """Apply the configured loss function to one batch."""
        return self.loss_func(predictions, yb)

    def calc_grad(self, loss):
        """Populate .grad on the model parameters via backprop."""
        loss.backward()

    def validate_epoch(self, valid_dl):
        """Mean metric over the validation set, rounded to 4 places."""
        # Fix: use the configured metric (previously self.batch_accuracy
        # was hard-coded and the `metrics` argument was ignored).
        scores = [self.metrics(self.model(xb), yb) for xb, yb in valid_dl]
        return round(torch.stack(scores).mean().item(), 4)

    def batch_accuracy(self, xb, yb):
        """Fraction of predictions on the correct side of 0.5 after sigmoid."""
        predictions = xb.sigmoid()
        correct = (predictions > 0.5) == yb
        return correct.float().mean()
# --- Data preparation and training script (MNIST_SAMPLE: 3s vs 7s) ---
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path

# Training images, stacked into float tensors scaled to [0, 1].
threes = (path/"train"/"3").ls().sorted()
sevens = (path/"train"/"7").ls().sorted()
three_tensors = [tensor(Image.open(image)) for image in threes]
seven_tensors = [tensor(Image.open(image)) for image in sevens]
stacked_threes = torch.stack(three_tensors).float() / 255
stacked_sevens = torch.stack(seven_tensors).float() / 255

# Validation images, same treatment.
valid_threes = (path/"valid"/"3").ls()
valid_sevens = (path/"valid"/"7").ls()
valid_three_tensors = [tensor(Image.open(image)) for image in valid_threes]
valid_seven_tensors = [tensor(Image.open(image)) for image in valid_sevens]
valid_stacked_threes = torch.stack(valid_three_tensors).float() / 255
valid_stacked_sevens = torch.stack(valid_seven_tensors).float() / 255

# Flatten 28x28 images into 784-vectors; label 3s as 1 and 7s as 0.
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28 * 28)
train_y = tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1)
dset = list(zip(train_x, train_y))

valid_x = torch.cat([valid_stacked_threes, valid_stacked_sevens]).view(-1, 28 * 28)
valid_y = tensor([1] * len(valid_stacked_threes) + [0] * len(valid_stacked_sevens)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))

# step 1. initialize the weights and bias
# NOTE(review): these hand-initialized parameters are never used below --
# the learner trains an nn.Linear module, which owns its own weights.
weights = (torch.randn((28 * 28, 1)) * 1.0).requires_grad_()
bias = (torch.randn(1) * 1.0).requires_grad_()

dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)
# Fix: "DataLoder" was a NameError; fastai's container class is DataLoaders.
dls = DataLoaders(dl, valid_dl)

custom_learner = CustomLearner(dls, nn.Linear(28 * 28, 1), opt_func=SGD, loss_func=mnist_loss)
custom_learner.fit(5, learning_rate=0.1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment