reptile in tensorflow
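
The script below implements the Reptile meta-learning algorithm (Nichol, Achiam & Schulman, 2018) on 1-D sine-wave regression using the TensorFlow 1.x graph API. Each outer iteration samples a random sine task, runs a few epochs of inner-loop SGD from the current initialization phi to get task-adapted weights W, then moves phi a small step toward W. Stripped of the TensorFlow plumbing, one outer step amounts to the sketch below (illustrative names, assuming parameters are plain numpy arrays; the gist instead feeds phi - W to an Adam optimizer as a pseudo-gradient):

    def reptile_outer_step(phi, sample_task, inner_sgd, outer_lr=0.001):
        # phi: dict mapping parameter names to numpy arrays (the meta-initialization)
        task = sample_task()            # e.g. a random sine wave y = a*sin(x + b)
        W = inner_sgd(task, dict(phi))  # a few epochs of SGD starting from phi
        # Reptile update: nudge the initialization toward the adapted weights.
        return {k: phi[k] + outer_lr * (W[k] - phi[k]) for k in phi}
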
#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
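
# Note: requires TensorFlow 1.x -- tf.placeholder, tf.layers, and the
# tf.train.* optimizers used below are not in the default TF 2.x API --
# and Python 3 (print(..., end=' ') is used further down).
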
def get_model(hidden_units=[64, 64],
              activation=tf.tanh,
              inner_optimizer=tf.train.GradientDescentOptimizer,
              inner_learning_rate=0.01,
              outer_optimizer=tf.train.AdamOptimizer,
              outer_learning_rate=0.001):
    # Input Layer
    x = tf.placeholder(tf.float32, shape=[None, 1], name='x')
    # Hidden Layers
    net = tf.identity(x)
    for i, units in enumerate(hidden_units):
        net = tf.layers.dense(net, units=units, activation=activation,
                              name='hidden%d' % i)
    # Output Layer
    y = tf.layers.dense(net, units=1, name='y')
    # Labels
    y_ = tf.placeholder(y.dtype, shape=y.shape, name='y_')
    # Loss
    loss = tf.losses.mean_squared_error(labels=y_, predictions=y)
    # Inner Loop: plain SGD on a single task's data
    in_opt = inner_optimizer(learning_rate=inner_learning_rate, name='in_opt')
    in_train = in_opt.minimize(loss, name='in_train')
    # Variable assignment operation
    assign_dict = {}  # map: variable -> placeholder for assignment
    assign_ops = []   # list of assignments
    for v in tf.trainable_variables():
        assign_dict[v] = tf.placeholder(v.dtype, shape=v.shape)
        assign_ops.append(v.assign(assign_dict[v]))

    def assign(phi, sess):
        sess.run(assign_ops, feed_dict={
            assign_dict[k]: v for k, v in phi.items()})

    # Outer Loop Gradient Update: the Reptile step is implemented by feeding
    # (phi - W) to the outer optimizer as if it were a gradient, so a plain
    # gradient-descent step would give phi <- phi + lr * (W - phi); Adam
    # applies its usual rescaling to that pseudo-gradient.
    out_opt = outer_optimizer(learning_rate=outer_learning_rate)
    update_dict = {}
    grads_vars = []
    for v in tf.trainable_variables():
        update_dict[v] = tf.placeholder(v.dtype, shape=v.shape)
        grads_vars.append((update_dict[v], v))
    update_op = out_opt.apply_gradients(grads_vars)

    def update(phi, W, sess):
        # Restore phi, then apply the pseudo-gradient (phi - W) to it.
        assign(phi, sess)
        sess.run(update_op, feed_dict={
            update_dict[v]: phi[v] - W[v] for v in phi.keys()})

    return (x, y, y_, loss, in_train, assign, update)
def get_task():
    a = np.random.uniform(.1, 5.)
    b = np.random.uniform(0, np.pi * 2)
    return lambda x: a * np.sin(x + b)
def get_params(sess=None):
    return {v: sess.run(v) for v in tf.trainable_variables()}
def SGD(tau, x, y_, train,
        epochs=10,
        batch_size=10,
        sess=None):
    for _ in range(epochs):
        data_x = np.random.uniform(-5, 5, size=(batch_size, 1))
        sess.run(train, feed_dict={x: data_x, y_: tau(data_x)})
    return get_params(sess=sess)
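
# render() fine-tunes a copy of the current initialization on each of the four
# fixed test tasks for 32 single-epoch SGD steps, plotting the prediction after
# every step (shading from blue at step 0 toward red at step 32) against the
# true sine curve (green). The vertical lines at x = +/-5 mark the range the
# inner-loop training data is drawn from; outside it the fit is extrapolation.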
def render(step, x, y, y_, train, assign, tests, sess=None):
    if getattr(render, 'fig', None) is None:
        plt.ion()
        f, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, sharex='col', sharey='row')
        render.fig = f
        render.axes = (a1, a2, a3, a4)
    full_x = np.linspace(-10, 10, 100).reshape(-1, 1)
    phi = get_params(sess=sess)
    plt.suptitle('iteration %d' % step)
    for i, (test, ax) in enumerate(zip(tests, render.axes)):
        assign(phi, sess=sess)
        test_preds = {}
        test_preds[0] = sess.run(y, feed_dict={x: full_x})
        for j in range(32):
            SGD(test, x, y_, train, epochs=1, sess=sess)
            test_preds[j + 1] = sess.run(y, feed_dict={x: full_x})
        ax.cla()
        ax.set_xlim([-10, 10])
        ax.set_ylim([-5, 5])
        ax.axvline(x=-5, color='k')
        ax.axvline(x=5, color='k')
        ax.set_title('test %d' % i)
        ax.plot(full_x, test(full_x), label='true', color=(0, 1, 0))
        for j, test_pred in test_preds.items():
            color = (j / 32, 0, 1 - (j / 32), 0.5)
            label = 'after %d' % j if j % 4 == 0 else None
            ax.plot(full_x, test_pred, label=label, color=color)
    plt.pause(0.01)
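
# Reptile outer loop: each iteration snapshots the current parameters phi, runs
# inner-loop SGD on one freshly sampled task to get adapted weights W, and calls
# update(phi, W) to move phi toward W. Every `step_test` iterations the four
# fixed test tasks are scored (and optionally rendered) starting from phi, and
# phi is restored before the next Reptile step.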
def main(outer_epochs=30000,
         step_test=500,
         render_test=True):
    x, y, y_, loss, in_train, assign, update = get_model()
    tests = [get_task() for _ in range(4)]
    test_x = np.linspace(-5, 5, 50).reshape(-1, 1)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    for i in range(outer_epochs):
        phi = get_params(sess=sess)
        if i % step_test == 0 or i == outer_epochs - 1:
            if render_test:
                render(i, x, y, y_, in_train, assign, tests, sess=sess)
            print('outer', i)
            for j, test in enumerate(tests):
                assign(phi, sess=sess)
                score = sess.run(loss, feed_dict={x: test_x, y_: test(test_x)})
                print('test%d %0.3f' % (j, score), end=' ')
            print()
            assign(phi, sess=sess)
        W = SGD(get_task(), x, y_, in_train, sess=sess)
        update(phi, W, sess=sess)


if __name__ == '__main__':
    main()
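
To skip the live plots and just watch the printed test losses (reported every step_test iterations, 500 by default), one option is to call main with the plotting flag off; an illustrative invocation, assuming the file is saved as reptile.py:

    python -c "import reptile; reptile.main(outer_epochs=5000, render_test=False)"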