@Mirodil
Created December 10, 2018 15:31
PyTorch Learning Rate Finder

find_learning_rate runs the learning-rate range test: it trains for up to epochs passes over data_loader while growing the learning rate exponentially from min_lr to max_lr, records the bias-corrected smoothed loss at each step, and then restores the model's original weights. Plotting the recorded loss against the learning rate shows the range in which training converges fastest and where it starts to diverge.
import numpy as np
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm_notebook


def find_learning_rate(model, data_loader, criterion, lr: tuple = (1e-7, 1), epochs: int = 1):
    history = []
    min_lr, max_lr = lr
    num_batches = epochs * len(data_loader)
    last_avg_loss, i, beta = 0.0, 0, 0.98
    # preserve the initial state so the test leaves the model untouched
    initial_weights = './temp.model'
    torch.save(model.state_dict(), initial_weights)
    optimizer = Adam(model.parameters(), lr=min_lr)
    # grow the learning rate exponentially from min_lr to max_lr over all batches
    scheduler = LambdaLR(optimizer, lr_lambda=lambda n: (max_lr / min_lr) ** (n / num_batches))
    model.train()
    for epoch_idx in range(epochs):
        progress_bar = tqdm_notebook(data_loader, leave=False)
        for batch_idx, (x, y) in enumerate(progress_bar):
            # x and y are assumed to already be on the same device as the model
            i += 1
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            # record the learning rate used for this batch, then advance the schedule
            # (scheduler.step() belongs after optimizer.step() since PyTorch 1.1)
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step()
            # exponentially smoothed, bias-corrected loss
            last_avg_loss = beta * last_avg_loss + (1 - beta) * loss.item()
            smooth_loss = last_avg_loss / (1 - beta ** i)
            history.append((current_lr, smooth_loss))
            if current_lr >= max_lr:
                break
        progress_bar.close()
    # restore the initial state
    model.load_state_dict(torch.load(initial_weights))
    # float32 rather than float16: learning rates near 1e-7 lose almost all precision in half precision
    return np.array(history, dtype=np.float32)
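
A minimal usage sketch, assuming a classification setup (net, train_loader, and the choice of CrossEntropyLoss are placeholders, not part of the gist): plot the returned (lr, loss) pairs on a log-scaled axis and pick a learning rate a bit below the point where the smoothed loss starts to climb.

import matplotlib.pyplot as plt
import torch.nn as nn

# net and train_loader are hypothetical; substitute your own model and DataLoader
history = find_learning_rate(net, train_loader, nn.CrossEntropyLoss())

plt.plot(history[:, 0], history[:, 1])  # learning rate on x, smoothed loss on y
plt.xscale('log')
plt.xlabel('learning rate')
plt.ylabel('smoothed loss')
plt.show()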