-
-
Save joschu/f503500cda64f2ce87c8288906b09e2d to your computer and use it in GitHub Desktop.
Reptile algorithm on sine waves
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
from torch import nn, autograd as ag | |
import matplotlib.pyplot as plt | |
from copy import deepcopy | |
seed = 0 | |
plot = True | |
innerstepsize = 0.02 # stepsize in inner SGD | |
innerepochs = 1 # number of epochs of each inner SGD | |
outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization | |
niterations = 30000 # number of outer updates; each iteration we sample one task and update on it | |
rng = np.random.RandomState(seed) | |
torch.manual_seed(seed) | |
# Define task distribution | |
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points | |
ntrain = 10 # Size of training minibatches | |
def gen_task(): | |
"Generate classification problem" | |
phase = rng.uniform(low=0, high=2*np.pi) | |
ampl = rng.uniform(0.1, 5) | |
f_randomsine = lambda x : np.sin(x + phase) * ampl | |
return f_randomsine | |
# Define model. Reptile paper uses ReLU, but Tanh gives slightly better results | |
model = nn.Sequential( | |
nn.Linear(1, 64), | |
nn.Tanh(), | |
nn.Linear(64, 64), | |
nn.Tanh(), | |
nn.Linear(64, 1), | |
) | |
def totorch(x): | |
return ag.Variable(torch.Tensor(x)) | |
def train_on_batch(x, y): | |
x = totorch(x) | |
y = totorch(y) | |
model.zero_grad() | |
ypred = model(x) | |
loss = (ypred - y).pow(2).mean() | |
loss.backward() | |
for param in model.parameters(): | |
param.data -= innerstepsize * param.grad.data | |
def predict(x): | |
x = totorch(x) | |
return model(x).data.numpy() | |
# Choose a fixed task and minibatch for visualization | |
f_plot = gen_task() | |
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)] | |
# Reptile training loop | |
for iteration in range(niterations): | |
weights_before = deepcopy(model.state_dict()) | |
# Generate task | |
f = gen_task() | |
y_all = f(x_all) | |
# Do SGD on this task | |
inds = rng.permutation(len(x_all)) | |
for _ in range(innerepochs): | |
for start in range(0, len(x_all), ntrain): | |
mbinds = inds[start:start+ntrain] | |
train_on_batch(x_all[mbinds], y_all[mbinds]) | |
# Interpolate between current weights and trained weights from this task | |
# I.e. (weights_before - weights_after) is the meta-gradient | |
weights_after = model.state_dict() | |
outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule | |
model.load_state_dict({name : | |
weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize | |
for name in weights_before}) | |
# Periodically plot the results on a particular task and minibatch | |
if plot and iteration==0 or (iteration+1) % 1000 == 0: | |
plt.cla() | |
f = f_plot | |
weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation | |
plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1)) | |
for inneriter in range(32): | |
train_on_batch(xtrain_plot, f(xtrain_plot)) | |
if (inneriter+1) % 8 == 0: | |
frac = (inneriter+1) / 32 | |
plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac)) | |
plt.plot(x_all, f(x_all), label="true", color=(0,1,0)) | |
lossval = np.square(predict(x_all) - f(x_all)).mean() | |
plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k") | |
plt.ylim(-4,4) | |
plt.legend(loc="lower right") | |
plt.pause(0.01) | |
model.load_state_dict(weights_before) # restore from snapshot | |
print(f"-----------------------------") | |
print(f"iteration {iteration+1}") | |
print(f"loss on plotted curve {lossval:.3f}") # would be better to average loss over a set of examples, but this is optimized for brevity |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment