Skip to content

Instantly share code, notes, and snippets.

@dboyliao
Last active January 29, 2023 15:48
Show Gist options
  • Save dboyliao/1a66ee4cf4b825c16a75c439ded7beac to your computer and use it in GitHub Desktop.
Save dboyliao/1a66ee4cf4b825c16a75c439ded7beac to your computer and use it in GitHub Desktop.
Simple Example: Solving a Lagrange-Multiplier Problem with PyTorch
import torch
# Solve the toy constrained problem
#     minimize  x**2 + y**2   subject to  x + y = 1
# by gradient descent-ascent on the Lagrangian
#     L(x, y, lam) = x**2 + y**2 + lam * (1 - x - y):
# descend in the primal variables (x, y), then ascend in the multiplier lam.
# The analytic saddle point is x = y = 0.5, lam = 1.


def lagrangian(x, y, lam):
    """Return the Lagrangian x**2 + y**2 + lam * (1 - x - y)."""
    return x**2 + y**2 + lam * (1 - x - y)


def solve_lagrangian(lr=0.1, steps=100, dtype=torch.float64):
    """Run alternating gradient descent-ascent on the Lagrangian.

    Each iteration first moves (x, y) against the gradient of the
    Lagrangian. It then re-evaluates the Lagrangian at the updated (x, y)
    and moves the multiplier *along* its gradient (ascent in lam).

    Args:
        lr: step size used for both the primal and the dual update.
        steps: number of descent-ascent iterations.
        dtype: floating dtype of the scalar tensors.

    Returns:
        Tuple ``(x, y, lam)`` of 0-d leaf tensors with ``requires_grad=True``.
        All three start from 0.
    """
    x = torch.tensor(0.0, requires_grad=True, dtype=dtype)
    y = torch.tensor(0.0, requires_grad=True, dtype=dtype)
    lam = torch.tensor(0.0, requires_grad=True, dtype=dtype)
    for _ in range(steps):
        # Primal descent. torch.autograd.grad returns the gradients directly
        # instead of accumulating them into .grad, so no manual zeroing is needed.
        grad_x, grad_y = torch.autograd.grad(lagrangian(x, y, lam), (x, y))
        with torch.no_grad():  # in-place update of leaves must not be tracked
            x -= lr * grad_x
            y -= lr * grad_y
        # Dual ascent, evaluated at the freshly updated x and y (alternating
        # order, as opposed to a simultaneous update).
        (grad_lam,) = torch.autograd.grad(lagrangian(x, y, lam), (lam,))
        with torch.no_grad():
            lam += lr * grad_lam
    return x, y, lam


lr = 0.1
x, y, l = solve_lagrangian(lr=lr)  # noqa: E741 - name kept for compatibility
print(x, y, x + y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment