TensorFlow version of the Reptile sample from https://blog.openai.com/reptile/
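# A TensorFlow 1.x port of OpenAI's Reptile sine-regression demo
# (https://blog.openai.com/reptile/). The mode flag below switches the
# meta-learning algorithm between Reptile and first-order MAML; the
# network and the inner SGD loop are shared by both.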
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
mode = 'maml'
seed = 0
plot = True
innerstepsize = 0.02 # stepsize in inner SGD
innerepochs = 1 # number of epochs of each inner SGD
outerstepsize0 = 0.1 if mode == 'reptile' else 0.001 # stepsize of outer optimization, i.e., meta-optimization
niterations = 30000 # number of outer updates; each iteration we sample one task and update on it
rng = np.random.RandomState(seed)
# Define task distribution
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points
ntrain = 10 # Size of training minibatches
def gen_task():
    "Generate regression problem"
    phase = rng.uniform(low=0, high=2 * np.pi)
    ampl = rng.uniform(0.1, 5)
    f_randomsine = lambda x: np.sin(x + phase) * ampl
    return f_randomsine
class Network:
    def __init__(self, innerstepsize, mode='reptile'):
        self.innerstepsize = innerstepsize
        self.mode = mode
        self.build()

    def network(self, x, scope):
        with tf.variable_scope(scope):
            output = tf.layers.dense(x, 64)
            output = tf.nn.tanh(output)
            output = tf.layers.dense(output, 64)
            output = tf.nn.tanh(output)
            y = tf.layers.dense(output, 1, name='y')
        return y
    def build(self):
        # build network
        self.x = tf.placeholder(tf.float32, [None, 1], name='x')
        self.y = self.network(self.x, self.mode)
        # backup network for outer update
        backup_model = self.network(self.x, 'backup')
        variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, self.mode)
        backup_variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'backup')
        self.label = tf.placeholder(tf.float32, [None, 1], name='label')
        self.loss = tf.reduce_mean(tf.square(self.y - self.label))
        self.gradients = tf.gradients(self.loss, variables)
        # inner update: one step of plain SGD, expressed as in-graph assigns
        inner_optimize_expr = []
        for var, grad in zip(variables, self.gradients):
            inner_optimize_expr.append(
                var.assign(var - self.innerstepsize * grad))
        self.inner_optimize = tf.group(*inner_optimize_expr)
        # snapshot the live weights into the backup network
        backup_expr = []
        for var, backup_var in zip(variables, backup_variables):
            backup_expr.append(backup_var.assign(var))
        self.backup_ops = tf.group(*backup_expr)
        # restore the live weights from the backup network
        restore_expr = []
        for var, backup_var in zip(variables, backup_variables):
            restore_expr.append(var.assign(backup_var))
        self.restore_ops = tf.group(*restore_expr)
        self.outerstepsize = tf.placeholder(tf.float32, [], name='outerstepsize')
        outer_optimize_expr = []
        if self.mode == 'reptile':
            # Reptile: move the pre-adaptation weights toward the adapted ones
            for var, backup_var in zip(variables, backup_variables):
                outer_optimize_expr.append(
                    var.assign(backup_var + self.outerstepsize * (var - backup_var)))
        elif self.mode == 'maml':
            # first-order MAML: apply the post-adaptation gradient
            # to the pre-adaptation weights
            for var, backup_var, grad in zip(variables, backup_variables, self.gradients):
                outer_optimize_expr.append(
                    var.assign(backup_var - self.outerstepsize * grad))
        self.outer_optimize = tf.group(*outer_optimize_expr)
        if self.mode == 'reptile':
            self.outer_update = self.outer_update_reptile
        elif self.mode == 'maml':
            self.outer_update = self.outer_update_maml
    def get_session(self):
        return tf.get_default_session()

    def predict(self, x):
        sess = self.get_session()
        with sess.as_default():
            return sess.run(self.y, feed_dict={self.x: x})

    def inner_update(self, x, label):
        sess = self.get_session()
        with sess.as_default():
            feed_dict = {
                self.x: x,
                self.label: label
            }
            loss, _ = sess.run(
                [self.loss, self.inner_optimize], feed_dict=feed_dict)
        return loss
    def outer_update_reptile(self, outerstepsize):
        sess = self.get_session()
        with sess.as_default():
            feed_dict = {
                self.outerstepsize: outerstepsize
            }
            sess.run(self.outer_optimize, feed_dict=feed_dict)

    def outer_update_maml(self, x, label, outerstepsize):
        sess = self.get_session()
        with sess.as_default():
            feed_dict = {
                self.x: x,
                self.label: label,
                self.outerstepsize: outerstepsize
            }
            sess.run(self.outer_optimize, feed_dict=feed_dict)
    def backup(self):
        sess = self.get_session()
        with sess.as_default():
            sess.run(self.backup_ops)

    def restore(self):
        sess = self.get_session()
        with sess.as_default():
            sess.run(self.restore_ops)
def train_on_batch(x, y):
    model.inner_update(x, y)

def predict(x):
    return model.predict(x)
model = Network(innerstepsize=innerstepsize, mode=mode)
sess = tf.Session()
sess.__enter__()
sess.run(tf.global_variables_initializer())
# Choose a fixed task and minibatch for visualization
f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]
# Meta-training loop (Reptile or first-order MAML, depending on mode)
for iteration in range(niterations):
    model.backup()  # snapshot weights before adapting to the task
    # Generate task
    f = gen_task()
    y_all = f(x_all)
    # Do SGD on this task
    inds = rng.permutation(len(x_all))
    for _ in range(innerepochs):
        for start in range(0, len(x_all), ntrain):
            mbinds = inds[start:start + ntrain]
            train_on_batch(x_all[mbinds], y_all[mbinds])
    # Outer update: in Reptile, interpolate between the snapshot and the
    # adapted weights, i.e. (weights_before - weights_after) is the
    # meta-gradient; in first-order MAML, re-apply the task gradient to
    # the snapshot instead.
    outerstepsize = outerstepsize0 * (1 - iteration / niterations)  # linear schedule
    if mode == 'reptile':
        model.outer_update(outerstepsize)
    elif mode == 'maml':
        model.outer_update(x_all, y_all, outerstepsize)
    model.backup()  # snapshot again so the plotting fine-tune below can be undone
    # Periodically plot the results on a particular task and minibatch
    if plot and (iteration == 0 or (iteration + 1) % 1000 == 0):
        plt.cla()
        f = f_plot
        plt.plot(x_all, predict(x_all), label="pred after 0", color=(0, 0, 1))
        for inneriter in range(32):
            train_on_batch(xtrain_plot, f(xtrain_plot))
            if (inneriter + 1) % 8 == 0:
                frac = (inneriter + 1) / 32
                plt.plot(x_all, predict(x_all),
                         label="pred after %i" % (inneriter + 1),
                         color=(frac, 0, 1 - frac))
        plt.plot(x_all, f(x_all), label="true", color=(0, 1, 0))
        lossval = np.square(predict(x_all) - f(x_all)).mean()
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4, 4)
        plt.legend(loc="lower right")
        plt.pause(0.01)
        model.restore()  # restore from snapshot
        print("-----------------------------")
        print(f"iteration {iteration + 1}")
        print(f"loss on plotted curve {lossval:.3f}")
takuseno commented Jun 6, 2018:
The mode variable switches the algorithm between first-order MAML and Reptile.
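For reference, here is a minimal NumPy sketch (variable names are illustrative, not from the script) of the two outer updates applied after the inner SGD loop, where theta0 is the snapshot taken before adaptation, theta1 the adapted weights, and g the task gradient evaluated at theta1:

import numpy as np

def meta_update(theta0, theta1, g, outerstepsize, mode):
    if mode == 'reptile':
        # Reptile: move the initialization toward the adapted weights
        return theta0 + outerstepsize * (theta1 - theta0)
    elif mode == 'maml':
        # first-order MAML: step along the post-adaptation gradient,
        # dropping the second-order terms of full MAML
        return theta0 - outerstepsize * g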
