@pupuis
Last active Jul 30, 2021
From MAML - Actual MAML: Given a TensorFlow model and a batch of samples from different tasks
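import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()  # or cross-entropy, or ...

def fastWeights(model, grads, inputs, lr=0.01):
    # Forward pass with one-step-adapted weights (w - lr * grad); the model's own variables stay unchanged.
    # Assumes a Sequential stack of Dense layers, each with exactly one kernel and one bias.
    # lr is the inner step size; the default is an assumed value, tune it for your tasks.
    output = inputs
    for layerIndex in range(len(model.layers)):
        kernel = model.trainable_weights[layerIndex * 2] - lr * grads[layerIndex * 2]
        bias = model.trainable_weights[layerIndex * 2 + 1] - lr * grads[layerIndex * 2 + 1]
        output = model.layers[layerIndex].activation(output @ kernel + bias)
    return output

def updateMAML(model, optimizer, batch):
    def taskLoss(task):
        y_train, x_train, y_test, x_test = task
        with tf.GradientTape() as taskTape:
            loss = mse(y_train, model(x_train))
        grads = taskTape.gradient(loss, model.trainable_weights)
        # Test loss of this task, evaluated with the adapted weights.
        return mse(y_test, fastWeights(model, grads, x_test))
    with tf.GradientTape() as tape:
        # batch is a tuple of four tensors, each stacked along a leading task axis.
        batchLoss = tf.map_fn(taskLoss, elems=batch,
                              fn_output_signature=tf.float32)
        loss = tf.reduce_sum(batchLoss)
    # Outer update: differentiate through the inner step, then update the model.
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))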
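As a sanity check on what the nested tapes are meant to compute, here is a hand-worked sketch with a single scalar weight and a squared-error loss. It is not part of the gist: every name in it is hypothetical, and it uses plain Python instead of TensorFlow so the chain rule is visible. The meta-gradient is the test-loss gradient at the adapted weight multiplied by `1 - lr * H`, where `H` is the second derivative of the training loss. The outer tape has to see the inner gradient computation to pick up that factor.

```python
# One-parameter MAML worked by hand. Model: y_hat = w * x, loss: (w * x - y) ** 2.
# All names are illustrative and do not appear in the gist.

def loss(w, x, y):
    return (w * x - y) ** 2

def grad(w, x, y):
    # First derivative of the loss with respect to w.
    return 2.0 * x * (w * x - y)

def hessian(x):
    # Second derivative of the loss with respect to w (constant in w).
    return 2.0 * x * x

def meta_loss(w, train, test, lr):
    x_tr, y_tr = train
    x_te, y_te = test
    w_fast = w - lr * grad(w, x_tr, y_tr)  # inner step; w itself is untouched
    return loss(w_fast, x_te, y_te)

def meta_grad(w, train, test, lr):
    x_tr, y_tr = train
    x_te, y_te = test
    w_fast = w - lr * grad(w, x_tr, y_tr)
    # Chain rule: dL_test/dw = dL_test/dw_fast * dw_fast/dw,
    # with dw_fast/dw = 1 - lr * hessian of the training loss.
    return grad(w_fast, x_te, y_te) * (1.0 - lr * hessian(x_tr))

w, lr = 0.5, 0.1
train, test = (1.0, 2.0), (2.0, 4.0)

analytic = meta_grad(w, train, test, lr)

# Central finite difference on the full two-level objective.
h = 1e-6
numeric = (meta_loss(w + h, train, test, lr)
           - meta_loss(w - h, train, test, lr)) / (2 * h)

# First-order approximation: treats the inner gradient as a constant.
first_order = grad(w - lr * grad(w, *train), *test)

print(analytic, numeric, first_order)
```

The analytic and finite-difference values should agree closely, while the first-order value differs by the `1 - lr * H` factor.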

pupuis commented Jul 4, 2021

Regarding the design choice of implementing fastWeights: we are working with nested instances of tf.GradientTape. We don't want to apply the gradients computed in taskLoss to the model itself, because the outer update needs the current parameters unchanged. Copying the model and optimizing the copy appears to disconnect the computation graph, resulting in empty gradients (also see this issue). So fastWeights reads the weights of model directly and applies the adapted weights to the input tensor immediately, which avoids both problems. The cost is that the inner optimization step has to be implemented explicitly. If we wanted to use Adam as the inner optimizer, for example, we would have to implement that update scheme ourselves.
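To illustrate the last point, here is a rough sketch of what a hand-written Adam inner step could look like. It is an assumption-laden example, not code from the gist: `adam_step` and all of its parameters are hypothetical names, and it follows the textbook Adam update rule, so it is not guaranteed to match `tf.keras.optimizers.Adam` numerically. It uses only arithmetic operators and returns new values instead of assigning to variables, so the same function accepts Python floats or `tf.Tensor` values and leaves the inputs untouched.

```python
def adam_step(params, grads, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    """Textbook Adam update, written functionally (hypothetical helper).

    params, grads, m, v are lists of equal length; t is the 1-based step count.
    Returns (new_params, new_m, new_v) and mutates none of the inputs.
    """
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        # Exponential moving averages of the gradient and its square.
        m_i = beta1 * m_i + (1.0 - beta1) * g
        v_i = beta2 * v_i + (1.0 - beta2) * g * g
        # Bias correction for the zero-initialised averages.
        m_hat = m_i / (1.0 - beta1 ** t)
        v_hat = v_i / (1.0 - beta2 ** t)
        # New parameter value; the original p is left as it was.
        new_params.append(p - lr * m_hat / (v_hat ** 0.5 + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


# Usage with plain floats: one step from zero-initialised moment estimates.
params, grads = [1.0], [0.5]
fast, m1, v1 = adam_step(params, grads, m=[0.0], v=[0.0], t=1, lr=0.1)
print(fast, params)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly `lr`. When differentiating through such a step, note that the square root has an unbounded derivative where the gradient is exactly zero, so second-order gradients may need extra care.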
