@FreeFly19
Created July 21, 2023 16:16
Square root finder with SGD (momentum + learning-rate decay) in PyTorch
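The snippet minimises a squared error on the square root of `w`, so it searches for the `w` whose square root equals 5 (the minimiser is w* = 25). The update rule below is read directly off the loop. The symbols are labels introduced here, not names from the code: η_t is `lr`, μ is `momentum_coef`, λ is `weight_decay` (which in this loop shrinks the learning rate rather than the weight), and v_t is the momentum-augmented gradient `grad`.

```latex
L(w) = \left(\sqrt{w} - 5\right)^2, \qquad
g_t = \frac{\partial L}{\partial w}\Big|_{w_t}
    = 2\left(\sqrt{w_t} - 5\right)\cdot\frac{1}{2\sqrt{w_t}}
    = 1 - \frac{5}{\sqrt{w_t}}

v_t = g_t + \mu\, v_{t-1}, \quad v_0 = 0
w_{t+1} = w_t - \eta_t\, v_t
\eta_{t+1} = (1 - \lambda)\,\eta_t, \quad \eta_0 = 0.6,\ \mu = 0.9,\ \lambda = 0.001

% First step from w_0 = 5:
% g_0 = 1 - 5/\sqrt{5} = 1 - \sqrt{5} \approx -1.236
% w_1 = 5 - 0.6 \cdot (-1.236) \approx 5.742
```

Because g_t < 0 whenever w < 25, the velocity keeps pushing `w` upward, and the momentum term makes each step larger than plain gradient descent would.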
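import torch

# Find w with sqrt(w) == 5 (i.e. w -> 25) by minimising (sqrt(w) - 5)^2
w = torch.tensor([5.], requires_grad=True)
lr = 0.6
w_mom_grad = torch.zeros_like(w)   # momentum buffer (scaled previous step direction)
momentum_coef = 0.9
weight_decay = 0.001               # note: this decays the learning rate, not w

for i in range(10):
    loss = (torch.sqrt(w) - 5) ** 2
    loss.backward()
    # Momentum: current gradient plus the decayed previous step direction
    grad = w.grad + w_mom_grad
    # Update w in place without recording the step in the autograd graph
    with torch.no_grad():
        w -= lr * grad
    w.grad.zero_()                 # clear the accumulated gradient for the next iteration
    print(f'Loss: {loss.item():.4f}, w: {w.item():.4f}, Learning Rate: {lr:.4f}')
    w_mom_grad = grad * momentum_coef
    lr = lr * (1 - weight_decay)   # exponential learning-rate decay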
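As a cross-check, here is a sketch that reproduces the same run with PyTorch's built-in optimizer, assuming the loop is meant to be ordinary heavy-ball momentum. `torch.optim.SGD` with `momentum=0.9` keeps a buffer `v = momentum * v + grad` (initialised to the first gradient), and `torch.optim.lr_scheduler.ExponentialLR` with `gamma = 1 - 0.001` multiplies the learning rate by 0.999 after every step. The helper names `manual_sgd` and `optim_sgd` are hypothetical. Note that the `weight_decay` argument of `torch.optim.SGD` is an L2 penalty on the parameter, not learning-rate decay, so it is deliberately left at its default of 0.

```python
import torch


def manual_sgd(steps=10, lr=0.6, momentum_coef=0.9, weight_decay=0.001):
    """Hand-written loop from the gist: momentum plus learning-rate decay."""
    w = torch.tensor([5.], requires_grad=True)
    w_mom_grad = torch.zeros_like(w)
    history = []
    for _ in range(steps):
        loss = (torch.sqrt(w) - 5) ** 2
        loss.backward()
        grad = w.grad + w_mom_grad          # momentum-augmented gradient
        with torch.no_grad():
            w -= lr * grad                  # parameter update outside autograd
        w.grad.zero_()                      # reset gradient for the next step
        history.append(w.item())
        w_mom_grad = grad * momentum_coef
        lr *= 1 - weight_decay              # decay the learning rate
    return history


def optim_sgd(steps=10, lr=0.6, momentum=0.9, lr_decay=0.001):
    """Same run using torch.optim.SGD and an exponential LR scheduler."""
    w = torch.tensor([5.], requires_grad=True)
    opt = torch.optim.SGD([w], lr=lr, momentum=momentum)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=1 - lr_decay)
    history = []
    for _ in range(steps):
        opt.zero_grad()
        loss = (torch.sqrt(w) - 5) ** 2
        loss.backward()
        opt.step()                          # momentum update of w
        sched.step()                        # lr <- lr * gamma
        history.append(w.item())
    return history


manual = manual_sgd()
builtin = optim_sgd()
for a, b in zip(manual, builtin):
    print(f'manual w: {a:.4f}  optim w: {b:.4f}')
```

If the two columns agree up to float rounding, the hand-written loop is plain momentum SGD with exponential learning-rate decay. With only 10 iterations `w` is still climbing toward 25, so raising the step count is the natural next experiment.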