Created
July 3, 2020 11:16
-
-
Save pranshuj73/2ccb64032fd23926999922a72a99418e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@torch.no_grad()  # autograd is switched off: validation needs no gradients
def evaluate(model, val_loader):
    """Run ``model`` over every validation batch and aggregate the results.

    Puts ``model`` in evaluation mode (``model.eval()``), collects
    ``model.validation_step(batch)`` for each batch of ``val_loader`` and hands
    that list to ``model.validation_epoch_end``, whose return value is passed
    straight back to the caller. ``fit`` stores extra keys on it, so it is
    expected to be dict-like (presumably validation loss/accuracy -- confirm
    against the model class).
    """
    model.eval()
    batch_results = []
    for batch in val_loader:
        batch_results.append(model.validation_step(batch))
    return model.validation_epoch_end(batch_results)
def get_lr(optimizer):
    """Return the current learning rate of the optimizer's first parameter group.

    Only the first entry of ``optimizer.param_groups`` is consulted; if the
    optimizer has no parameter groups at all, ``None`` is returned.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)
# this fit function follows the intuition of 1cycle lr
def fit(epochs, max_lr, model, train_loader=train_dl, val_loader=val_dl,
        weight_decay=0, grad_clip=None, opt_func=torch.optim.Adam):
    """Train ``model`` with a one-cycle learning-rate policy, validating after each epoch.

    Args:
        epochs: number of passes over ``train_loader``.
        max_lr: peak learning rate given to ``OneCycleLR``; it is also passed as
            the initial lr to ``opt_func`` before the scheduler takes over.
        model: object providing ``parameters()``, ``train()``,
            ``training_step(batch)`` (returning a scalar loss tensor),
            ``epoch_end(epoch, result)`` and whatever ``evaluate`` calls.
        train_loader: iterable of training batches that supports ``len()``
            (needed for ``steps_per_epoch``). The default is the module-level
            ``train_dl`` as it was when this function was defined.
        val_loader: iterable of validation batches; the default is the
            module-level ``val_dl``, bound the same way.
        weight_decay: forwarded to the optimizer constructor.
        grad_clip: if truthy, every gradient element is clamped to
            ``[-grad_clip, grad_clip]`` before each optimizer step.
        opt_func: optimizer class, called as
            ``opt_func(params, max_lr, weight_decay=weight_decay)``.

    Returns:
        A list with one entry per epoch: the result returned by ``evaluate``,
        extended with ``'train_loss'`` (mean training loss over the epoch) and
        ``'lrs'`` (the learning rate recorded at every training batch).
    """
    torch.cuda.empty_cache()
    history = []  # one evaluation result per epoch
    # Custom optimizer including weight decay.
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # One-cycle scheduler, stepped once per batch (hence steps_per_epoch).
    sched = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            # Keep only a detached copy: the live loss tensor references its
            # autograd graph (grad_fn), which would otherwise stay reachable for
            # the whole epoch, and stacking such tensors would build a new graph.
            train_losses.append(loss.detach())
            loss.backward()
            # Optional gradient clipping by value.
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            # Record the lr used for this step before the scheduler advances it.
            lrs.append(get_lr(optimizer))
            sched.step()

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment