MAML v2: a PyTorch implementation of the MAML meta-gradient (maml.py), with an optional memory-saving checkpointed mode, plus a pytest gradient-checking test.
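Concretely, the code computes the gradient, with respect to the initial parameters, of the total loss accumulated along an inner loop of SGD steps. A sketch of the reverse recursion that _maml_grad implements (notation mine; the inner update is \theta_{t+1} = \theta_t - \alpha \nabla_{\theta_t} L_t(\theta_t)):

    g_T = 0, \qquad
    g_t = \nabla_{\theta_t} L_t(\theta_t) + \Big(\tfrac{\partial \theta_{t+1}}{\partial \theta_t}\Big)^{\top} g_{t+1}, \qquad
    g_0 = \tfrac{\partial}{\partial \theta_0} \sum_{t} L_t(\theta_t).

The numerical test at the bottom checks exactly this quantity by finite differences on the summed inner-loop losses.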
import math

import torch
import torch.nn.functional as F


def maml_grad(model, inputs, outputs, lr, batch=1, checkpoint=False):
    """
    Update the model's gradient using MAML.

    The meta-gradient is accumulated into each parameter's .grad field,
    so this composes with an outer-loop optimizer the same way
    loss.backward() would.

    Args:
        model: the module to meta-train.
        inputs: inputs for the inner-loop mini-batches.
        outputs: targets for the inner-loop mini-batches. Floating-point
          targets use BCE-with-logits; integer labels use cross-entropy.
        lr: the inner-loop SGD learning rate.
        batch: the inner-loop mini-batch size.
        checkpoint: if True, keep only about sqrt(N) parameter
          checkpoints and recompute the rest, trading compute for memory.

    Returns:
        A list of scalar inner-loop losses, one per mini-batch.
    """
    params = list(model.parameters())
    device = params[0].device
    batches = list(_split_batches(inputs.to(device), outputs.to(device), batch))
    if outputs.dtype.is_floating_point:
        loss_fn = F.binary_cross_entropy_with_logits
    else:
        loss_fn = F.cross_entropy
    if not checkpoint:
        gradient, losses = _maml_grad(model, batches, lr, loss_fn,
                                      [torch.zeros_like(p) for p in params])
    else:
        gradient, losses = _checkpointed_maml_grad(model, batches, lr, loss_fn)
    for p, g in zip(params, gradient):
        if p.grad is None:
            p.grad = g
        else:
            p.grad.add_(g)
    return losses


def _split_batches(inputs, outputs, batch):
    for i in range(0, inputs.shape[0], batch):
        yield (inputs[i:i+batch], outputs[i:i+batch])


def _checkpointed_maml_grad(model, batches, lr, loss_fn):
    params = list(model.parameters())
    interval = int(math.sqrt(len(batches)))
    checkpoints = []
    scalar_losses = []
    # Forward: run plain SGD over the mini-batches, saving a copy of the
    # parameters every `interval` steps.
    for i, (x, y) in enumerate(batches):
        if i % interval == 0:
            checkpoints.append([p.clone().detach() for p in params])
        out = model(x)
        loss = loss_fn(out, y)
        scalar_losses.append(loss.item())
        grads = torch.autograd.grad(loss, params)
        for p, g in zip(params, grads):
            p.data.add_(-lr * g)
    # Backward: replay each segment from its checkpoint and back-propagate
    # the meta-gradient through it, walking the segments in reverse.
    gradient = [torch.zeros_like(p) for p in params]
    for i in list(range(0, len(batches), interval))[::-1]:
        checkpoint = checkpoints[i // interval]
        for p, v in zip(params, checkpoint):
            p.data.copy_(v)
        gradient, _ = _maml_grad(model, batches[i:i+interval], lr, loss_fn, gradient)
    return gradient, scalar_losses


def _maml_grad(model, batches, lr, loss_fn, grad_outputs):
    params = list(model.parameters())
    initial_values = []
    final_values = []
    loss_grads = []
    scalar_losses = []
    # Forward: take one differentiable SGD step per mini-batch, keeping the
    # graph for each update so it can be back-propagated through later.
    for x, y in batches:
        out = model(x)
        loss = loss_fn(out, y)
        scalar_losses.append(loss.item())
        initial_values.append([p.clone().detach() for p in params])
        grads = torch.autograd.grad(loss, params, create_graph=True, retain_graph=True)
        loss_grads.append([g.detach() for g in grads])
        updated = []
        for grad, param in zip(grads, params):
            new_param = param - lr * grad
            updated.append(new_param)
            param.data.copy_(new_param)
        final_values.append(updated)
    # Backward: walk the SGD steps in reverse, restoring the parameters at
    # each step and accumulating the meta-gradient through the update.
    gradient = grad_outputs
    for loss_grad, initial, final in list(zip(loss_grads, initial_values, final_values))[::-1]:
        for p, x in zip(params, initial):
            p.data.copy_(x)
        future_grad = torch.autograd.grad(final, params, grad_outputs=gradient, retain_graph=True)
        gradient = [v1 + v2 for v1, v2 in zip(loss_grad, future_grad)]
    return gradient, scalar_losses
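For reference, here is a minimal sketch of how maml_grad could sit inside an outer meta-training loop. The model, the Adam hyperparameters, and the sample_task helper are illustrative assumptions, not part of the gist; the gradient-checking tests follow below.

# Illustrative outer loop (assumptions noted above); not part of the gist.
import torch
import torch.nn as nn

from maml import maml_grad


def sample_task():
    # Hypothetical task sampler: a tiny binary-classification support set
    # with float targets, so maml_grad uses BCE-with-logits internally.
    xs = torch.randn(5, 3)
    ys = (xs.sum(dim=1, keepdim=True) > 0).float()
    return xs, ys


model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(100):
    meta_opt.zero_grad()
    inputs, outputs = sample_task()
    # Accumulates the meta-gradient into each parameter's .grad and
    # returns the per-mini-batch inner-loop losses.
    losses = maml_grad(model, inputs, outputs, lr=0.01, batch=1)
    meta_opt.step()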
import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from maml import maml_grad


@pytest.mark.parametrize('checkpoint', [False, True])
def test_maml_grad(checkpoint):
    model = nn.Sequential(
        nn.Linear(3, 4),
        nn.Tanh(),
        nn.Linear(4, 3),
        nn.Tanh(),
        nn.Linear(3, 1),
    )
    # More precision for gradient checking.
    model.to(torch.double)
    inputs = torch.from_numpy(np.array([[1.0, 2.0, -0.5], [0.5, 1, -1], [0, 1, 2]]))
    outputs = torch.from_numpy(np.array([[0.9], [0.2], [0.4]]))

    # Make sure single-step gradients are correct
    # without any numerical approximation.
    exact_grads = _exact_maml_grad(model, inputs[:1], outputs[:1], 0.01, checkpoint)
    step_grads = _one_step_maml_grad(model, inputs[:1], outputs[:1], 0.01)
    for ex, ap in zip(exact_grads, step_grads):
        assert np.allclose(ex, ap, rtol=1e-4, atol=1e-4)

    # Check the multi-step gradient against a finite-difference estimate.
    exact_grads = _exact_maml_grad(model, inputs, outputs, 0.01, checkpoint)
    approx_grads = _numerical_maml_grad(model, inputs, outputs, 0.01)
    for ex, ap in zip(exact_grads, approx_grads):
        assert np.allclose(ex, ap, rtol=1e-4, atol=1e-4)


def _exact_maml_grad(model, inputs, outputs, lr, checkpoint):
    for p in model.parameters():
        p.grad = None
    maml_grad(model, inputs, outputs, lr, checkpoint=checkpoint)
    return [p.grad.numpy().copy() for p in model.parameters()]


def _one_step_maml_grad(model, inputs, outputs, lr):
    for p in model.parameters():
        p.grad = None
    F.binary_cross_entropy_with_logits(model(inputs), outputs).backward()
    return [p.grad.numpy() for p in model.parameters()]


def _numerical_maml_grad(model, inputs, outputs, lr, delta=1e-4):
    # Central-difference approximation of the MAML gradient, perturbing one
    # parameter entry at a time.
    grad = []
    for p in model.parameters():
        param_grad = []
        np_value = p.detach().numpy()
        flat_np = np_value.reshape([-1])
        for i, x in enumerate(flat_np):
            flat_np[i] = x - delta
            p.data.copy_(torch.from_numpy(np_value).to(p.device))
            loss1 = _numerical_maml_loss(model, inputs, outputs, lr)
            flat_np[i] = x + delta
            p.data.copy_(torch.from_numpy(np_value).to(p.device))
            loss2 = _numerical_maml_loss(model, inputs, outputs, lr)
            flat_np[i] = x
            p.data.copy_(torch.from_numpy(np_value).to(p.device))
            param_grad.append((loss2 - loss1) / (2 * delta))
        grad.append(np.array(param_grad, dtype=np.float64).reshape(p.shape))
    return grad


def _numerical_maml_loss(model, inputs, outputs, lr):
    # Run the inner loop with plain SGD and sum the per-step losses,
    # restoring the original parameters afterwards.
    backup = [p.data.clone() for p in model.parameters()]
    opt = optim.SGD(model.parameters(), lr)
    losses = []
    for i in range(inputs.shape[0]):
        x, y = inputs[i:i+1], outputs[i:i+1]
        loss = F.binary_cross_entropy_with_logits(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    for p, b in zip(model.parameters(), backup):
        p.data.copy_(b)
    return np.sum(losses)