MAML v2
Created October 21, 2019 17:35
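This gist contains two files: the MAML gradient implementation (maml.py, the module name implied by the test's "from maml import maml_grad" import) and a pytest suite that checks the exact gradients against analytic one-step gradients and central finite differences.

maml.py: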
import math

import torch
import torch.nn.functional as F


def maml_grad(model, inputs, outputs, lr, batch=1, checkpoint=False):
    """
    Update the model gradient using MAML.

    Accumulates the gradient of the summed losses along the
    inner-loop SGD trajectory into each parameter's .grad
    attribute, and returns the per-batch loss values.
    """
    params = list(model.parameters())
    device = params[0].device
    batches = list(_split_batches(inputs.to(device), outputs.to(device), batch))
    # Floating-point targets imply per-logit binary targets;
    # integer targets imply multi-class classification.
    if outputs.dtype.is_floating_point:
        loss_fn = F.binary_cross_entropy_with_logits
    else:
        loss_fn = F.cross_entropy
    if not checkpoint:
        gradient, losses = _maml_grad(model, batches, lr, loss_fn,
                                      [torch.zeros_like(p) for p in params])
    else:
        gradient, losses = _checkpointed_maml_grad(model, batches, lr, loss_fn)
    for p, g in zip(params, gradient):
        if p.grad is None:
            p.grad = g
        else:
            p.grad.add_(g)
    return losses


def _split_batches(inputs, outputs, batch):
    for i in range(0, inputs.shape[0], batch):
        yield (inputs[i:i+batch], outputs[i:i+batch])


def _checkpointed_maml_grad(model, batches, lr, loss_fn):
    # Gradient checkpointing: first run the inner loop without
    # building autograd graphs, saving a parameter snapshot every
    # sqrt(n) steps. Then sweep the intervals in reverse, restoring
    # each snapshot and re-running _maml_grad on just that interval.
    # This trades one extra inner-loop pass for O(sqrt(n)) instead
    # of O(n) retained graphs.
    params = list(model.parameters())
    interval = int(math.sqrt(len(batches)))
    checkpoints = []
    scalar_losses = []
    for i, (x, y) in enumerate(batches):
        if i % interval == 0:
            checkpoints.append([p.clone().detach() for p in params])
        out = model(x)
        loss = loss_fn(out, y)
        scalar_losses.append(loss.item())
        grads = torch.autograd.grad(loss, params)
        for p, g in zip(params, grads):
            p.data.add_(-lr * g)
    gradient = [torch.zeros_like(p) for p in params]
    for i in list(range(0, len(batches), interval))[::-1]:
        checkpoint = checkpoints[i // interval]
        for p, v in zip(params, checkpoint):
            p.data.copy_(v)
        gradient, _ = _maml_grad(model, batches[i:i+interval], lr, loss_fn, gradient)
    return gradient, scalar_losses


def _maml_grad(model, batches, lr, loss_fn, grad_outputs):
    # Forward: run the inner-loop SGD updates with create_graph=True
    # so the updates themselves are differentiable. Backward: walk
    # the steps in reverse, applying the recurrence
    #   g_t = grad(L_t)(theta_t) + (d theta_{t+1} / d theta_t)^T g_{t+1},
    # where grad_outputs seeds g from any later (checkpointed) steps.
    params = list(model.parameters())
    initial_values = []
    final_values = []
    loss_grads = []
    scalar_losses = []
    for x, y in batches:
        out = model(x)
        loss = loss_fn(out, y)
        scalar_losses.append(loss.item())
        initial_values.append([p.clone().detach() for p in params])
        grads = torch.autograd.grad(loss, params, create_graph=True, retain_graph=True)
        loss_grads.append([g.detach() for g in grads])
        updated = []
        for grad, param in zip(grads, params):
            new_param = param - lr * grad
            updated.append(new_param)
            param.data.copy_(new_param)
        final_values.append(updated)
    gradient = grad_outputs
    for loss_grad, initial, final in list(zip(loss_grads, initial_values, final_values))[::-1]:
        # Restore the pre-update parameters: the graph for this step
        # was built (and its tensors saved) at those values, and
        # param.data was mutated in place afterward.
        for p, x in zip(params, initial):
            p.data.copy_(x)
        future_grad = torch.autograd.grad(final, params, grad_outputs=gradient, retain_graph=True)
        gradient = [v1 + v2 for v1, v2 in zip(loss_grad, future_grad)]
    return gradient, scalar_losses
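For context, here is a minimal sketch of how maml_grad might drive a meta-training outer loop. Everything in it besides maml_grad itself (the toy sample_task, the model, and the hyperparameters) is hypothetical and not part of this gist. It relies on the properties visible above: maml_grad computes the gradient of the summed inner-loop losses, accumulates it into .grad, and restores the parameters to their initial values, so a standard optimizer step can follow.

import torch
import torch.nn as nn
import torch.optim as optim

from maml import maml_grad


def sample_task():
    # Hypothetical stand-in for a task sampler. Targets lie in
    # [0, 1] so maml_grad selects binary_cross_entropy_with_logits.
    inputs = torch.randn(5, 3)
    outputs = torch.rand(5, 1)
    return inputs, outputs


model = nn.Sequential(nn.Linear(3, 32), nn.Tanh(), nn.Linear(32, 1))
meta_opt = optim.SGD(model.parameters(), lr=1e-3)

for step in range(1000):
    inputs, outputs = sample_task()
    meta_opt.zero_grad()
    # Inner loop of per-sample SGD steps (batch=1); the meta-gradient
    # is accumulated into .grad and the parameters are restored.
    losses = maml_grad(model, inputs, outputs, lr=0.1, checkpoint=True)
    meta_opt.step()

Passing checkpoint=True costs one extra inner-loop pass but keeps only O(sqrt(n)) autograd graphs alive, which matters for long inner loops.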
Gradient checks (pytest):
import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from maml import maml_grad


@pytest.mark.parametrize('checkpoint', [False, True])
def test_maml_grad(checkpoint):
    model = nn.Sequential(
        nn.Linear(3, 4),
        nn.Tanh(),
        nn.Linear(4, 3),
        nn.Tanh(),
        nn.Linear(3, 1),
    )
    # More precision for gradient checking.
    model.to(torch.double)
    inputs = torch.from_numpy(np.array([[1.0, 2.0, -0.5], [0.5, 1, -1], [0, 1, 2]]))
    outputs = torch.from_numpy(np.array([[0.9], [0.2], [0.4]]))
    # Make sure single-step gradients are correct
    # without any numerical approximation.
    exact_grads = _exact_maml_grad(model, inputs[:1], outputs[:1], 0.01, checkpoint)
    step_grads = _one_step_maml_grad(model, inputs[:1], outputs[:1], 0.01)
    for ex, ap in zip(exact_grads, step_grads):
        assert np.allclose(ex, ap, rtol=1e-4, atol=1e-4)
    exact_grads = _exact_maml_grad(model, inputs, outputs, 0.01, checkpoint)
    approx_grads = _numerical_maml_grad(model, inputs, outputs, 0.01)
    for ex, ap in zip(exact_grads, approx_grads):
        assert np.allclose(ex, ap, rtol=1e-4, atol=1e-4)


def _exact_maml_grad(model, inputs, outputs, lr, checkpoint):
    for p in model.parameters():
        p.grad = None
    maml_grad(model, inputs, outputs, lr, checkpoint=checkpoint)
    return [p.grad.numpy().copy() for p in model.parameters()]


def _one_step_maml_grad(model, inputs, outputs, lr):
    # With a single batch, the MAML gradient reduces to the plain
    # loss gradient at the initial parameters.
    for p in model.parameters():
        p.grad = None
    F.binary_cross_entropy_with_logits(model(inputs), outputs).backward()
    return [p.grad.numpy() for p in model.parameters()]


def _numerical_maml_grad(model, inputs, outputs, lr, delta=1e-4):
    # Central finite differences of the meta-objective with respect
    # to every parameter entry. The numpy view shares memory with
    # the parameter; the copy_ calls keep any device copies in sync.
    grad = []
    for p in model.parameters():
        param_grad = []
        np_value = p.detach().numpy()
        flat_np = np_value.reshape([-1])
        for i, x in enumerate(flat_np):
            flat_np[i] = x - delta
            p.data.copy_(torch.from_numpy(np_value).to(p.device))
            loss1 = _numerical_maml_loss(model, inputs, outputs, lr)
            flat_np[i] = x + delta
            p.data.copy_(torch.from_numpy(np_value).to(p.device))
            loss2 = _numerical_maml_loss(model, inputs, outputs, lr)
            flat_np[i] = x
            p.data.copy_(torch.from_numpy(np_value).to(p.device))
            param_grad.append((loss2 - loss1) / (2 * delta))
        grad.append(np.array(param_grad, dtype=np.float64).reshape(p.shape))
    return grad


def _numerical_maml_loss(model, inputs, outputs, lr):
    # Sum of the losses along the inner-loop SGD trajectory,
    # restoring the original parameters afterward.
    backup = [p.data.clone() for p in model.parameters()]
    opt = optim.SGD(model.parameters(), lr)
    losses = []
    for i in range(inputs.shape[0]):
        x, y = inputs[i:i+1], outputs[i:i+1]
        loss = F.binary_cross_entropy_with_logits(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    for p, b in zip(model.parameters(), backup):
        p.data.copy_(b)
    return np.sum(losses)
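The suite can be run with pytest (for instance, pytest maml_test.py, if the test file is saved under that hypothetical name). The parametrization exercises both the direct and the checkpointed code paths, and the double-precision model keeps the finite-difference comparison within the 1e-4 tolerances.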