Created December 4, 2020 14:58
import copy
import math

import pandas as pd
import torch
from matplotlib import pyplot as plt
from torch import nn
from tqdm import tqdm, trange
class LearningRateFinder:
    """
    Train a model using different learning rates within a range to find the optimal learning rate.
    """

    def __init__(self,
                 model: nn.Module,
                 criterion: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 device: torch.device
                 ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.loss_history = {}
        # Deep-copy the state dicts: state_dict() returns references to the live
        # tensors, so without a copy reset() would restore the *trained* weights.
        self._model_init = copy.deepcopy(model.state_dict())
        self._opt_init = copy.deepcopy(optimizer.state_dict())
        self.device = device
    def fit(self,
            data_loader: torch.utils.data.DataLoader,
            steps: int = 100,
            min_lr: float = 1e-7,
            max_lr: float = 1,
            constant_increment: bool = False
            ):
        """
        Trains the model for the given number of steps with a varying learning rate
        and records the loss observed at each learning rate.
        """
        self.loss_history = {}
        self.model.train()
        current_lr = min_lr
        steps_counter = 0
        epochs = math.ceil(steps / len(data_loader))

        progressbar = trange(epochs, desc='Progress')
        for epoch in progressbar:
            batch_iter = tqdm(enumerate(data_loader), 'Training', total=len(data_loader),
                              leave=False)

            for i, (x, y) in batch_iter:
                x, y = x.to(self.device), y.to(self.device)
                # Apply the current learning rate to every parameter group
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = current_lr
                self.optimizer.zero_grad()
                out = self.model(x)
                loss = self.criterion(out, y)
                loss.backward()
                self.optimizer.step()
                self.loss_history[current_lr] = loss.item()

                steps_counter += 1
                if steps_counter >= steps:  # >= so exactly `steps` batches are trained
                    break

                if constant_increment:
                    # Linear schedule: equal additive increment per step
                    current_lr += (max_lr - min_lr) / steps
                else:
                    # Exponential schedule: equal multiplicative factor per step, so the
                    # learning rate sweeps from min_lr to max_lr in `steps` steps
                    current_lr = current_lr * (max_lr / min_lr) ** (1 / steps)
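                    # Worked example with the defaults (min_lr=1e-7, max_lr=1, steps=100):
                    # the per-step factor is (1 / 1e-7) ** (1 / 100) = 10 ** 0.07 ≈ 1.175,
                    # i.e. the learning rate grows by roughly 17.5 % per step.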
    def plot(self,
             smoothing: bool = True,
             clipping: bool = True,
             smoothing_factor: float = 0.1
             ):
        """
        Shows loss vs. learning rate (log scale) in a matplotlib plot
        """
        loss_data = pd.Series(list(self.loss_history.values()))
        lr_list = list(self.loss_history.keys())
        if smoothing:
            # Exponential moving average, with bias correction for the early values
            loss_data = loss_data.ewm(alpha=smoothing_factor).mean()
            loss_data = loss_data.divide(pd.Series(
                [1 - (1.0 - smoothing_factor) ** i for i in range(1, loss_data.shape[0] + 1)]))  # bias correction
        if clipping:
            # Drop the noisy first and last few measurements
            loss_data = loss_data[10:-5]
            lr_list = lr_list[10:-5]
        plt.plot(lr_list, loss_data)
        plt.xscale('log')
        plt.title('Loss vs Learning rate')
        plt.xlabel('Learning rate (log scale)')
        plt.ylabel('Loss (exponential moving average)')
        plt.show()
    def reset(self):
        """
        Resets the model and optimizer to their initial state
        """
        self.model.load_state_dict(self._model_init)
        self.optimizer.load_state_dict(self._opt_init)
        print('Model and optimizer in initial state.')
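A minimal usage sketch (assuming the class above is in scope): the two-layer network, the synthetic regression data and the SGD settings below are illustrative assumptions, not part of the gist.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic regression data, assumed for illustration: 512 samples, 10 features
x = torch.randn(512, 10)
y = x @ torch.randn(10, 1) + 0.1 * torch.randn(512, 1)
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-7)  # initial lr is overwritten by fit()

lr_finder = LearningRateFinder(model, criterion, optimizer, device)
lr_finder.fit(loader, steps=100, min_lr=1e-7, max_lr=1)
lr_finder.plot()   # pick a rate on the steepest downward slope, before the loss diverges
lr_finder.reset()  # restore initial weights and optimizer state before real training

A common heuristic is to choose a learning rate roughly an order of magnitude below the point where the smoothed loss reaches its minimum.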