@luis-mueller
Last active October 5, 2022 07:44
MAML - Actual MAML: Given a TensorFlow model and a batch of samples from different tasks.
```python
def fastWeights(model, weights, input):
    # Forward pass through `model` using an explicit list of weights,
    # laid out as [kernel_0, bias_0, kernel_1, bias_1, ...].
    output = input
    for layerIndex in range(len(model.layers)):
        kernel = weights[layerIndex * 2]
        bias = weights[layerIndex * 2 + 1]
        output = model.layers[layerIndex].activation(output @ kernel + bias)
    return output
```
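As a quick sanity check (a sketch assuming a Dense-only `tf.keras.Sequential` model, which is exactly the layer layout `fastWeights` relies on; the function is repeated here so the snippet runs standalone), `fastWeights` fed with the model's own trainable weights should reproduce an ordinary forward pass:

```python
import numpy as np
import tensorflow as tf

def fastWeights(model, weights, input):
    # Forward pass using an explicit [kernel, bias, kernel, bias, ...] list.
    output = input
    for layerIndex in range(len(model.layers)):
        kernel = weights[layerIndex * 2]
        bias = weights[layerIndex * 2 + 1]
        output = model.layers[layerIndex].activation(output @ kernel + bias)
    return output

# A small Dense-only model; fastWeights assumes exactly this layout.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
x = tf.random.normal((5, 4))
model(x)  # build the weights

# With the model's own weights, fastWeights matches a normal forward pass.
same = np.allclose(fastWeights(model, model.trainable_weights, x).numpy(),
                   model(x).numpy(), atol=1e-5)
print(same)  # prints True
```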
```python
import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()  # any per-task loss works

def updateMAML(model, optimizer, inner_lr, batch):
    def taskLoss(batch):
        y_train, x_train, y_test, x_test = batch
        # Inner step: one SGD update on the task's training split ...
        with tf.GradientTape() as taskTape:
            loss = mse(y_train, model(x_train))  # or cross-entropy, or ...
        grads = taskTape.gradient(loss, model.trainable_weights)
        weights = [w - inner_lr * g
                   for g, w in zip(grads, model.trainable_weights)]
        # ... then evaluate the adapted ("fast") weights on the test split.
        return mse(y_test, fastWeights(model, weights, x_test))  # or cross-entropy, or ...

    # Outer step: meta-gradient through the inner update, summed over tasks.
    with tf.GradientTape() as tape:
        batchLoss = tf.map_fn(taskLoss, elems=batch,
                              fn_output_signature=tf.float32)
        loss = tf.reduce_sum(batchLoss)
    optimizer.minimize(loss, model.trainable_weights, tape=tape)
```
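A minimal end-to-end usage sketch on made-up toy data: four 1-D regression tasks, each with train and test splits stacked along a leading task dimension, so `tf.map_fn` can unpack them per task. The definitions are repeated so the snippet runs standalone, and the outer update is written as an explicit `tape.gradient` plus `apply_gradients` (equivalent to the `optimizer.minimize(..., tape=tape)` call above):

```python
import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()

def fastWeights(model, weights, input):
    output = input
    for layerIndex in range(len(model.layers)):
        kernel = weights[layerIndex * 2]
        bias = weights[layerIndex * 2 + 1]
        output = model.layers[layerIndex].activation(output @ kernel + bias)
    return output

def updateMAML(model, optimizer, inner_lr, batch):
    def taskLoss(batch):
        y_train, x_train, y_test, x_test = batch
        with tf.GradientTape() as taskTape:
            loss = mse(y_train, model(x_train))
        grads = taskTape.gradient(loss, model.trainable_weights)
        weights = [w - inner_lr * g
                   for g, w in zip(grads, model.trainable_weights)]
        return mse(y_test, fastWeights(model, weights, x_test))

    with tf.GradientTape() as tape:
        batchLoss = tf.map_fn(taskLoss, elems=batch,
                              fn_output_signature=tf.float32)
        loss = tf.reduce_sum(batchLoss)
    # Equivalent to optimizer.minimize(loss, model.trainable_weights, tape=tape)
    metaGrads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(metaGrads, model.trainable_weights))

# Hypothetical toy setup: 4 tasks, 10 train / 10 test points, 1-D regression.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="tanh"),
    tf.keras.layers.Dense(1),
])
model(tf.zeros((1, 1)))  # build the weights
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)

num_tasks = 4
batch = (  # order matches the unpacking inside taskLoss
    tf.random.normal((num_tasks, 10, 1)),  # y_train
    tf.random.normal((num_tasks, 10, 1)),  # x_train
    tf.random.normal((num_tasks, 10, 1)),  # y_test
    tf.random.normal((num_tasks, 10, 1)),  # x_test
)
updateMAML(model, optimizer, inner_lr=0.01, batch=batch)
```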
@luis-mueller (Author):
Regarding the design choice of implementing fastWeights: we are working with nested instances of tf.GradientTape. We do not want to apply the gradients computed in taskLoss to the model itself, because the current meta-parameters must be preserved for the outer update. Copying the model and optimizing the copy instead appears to disconnect the computation graph, resulting in empty gradients (also see this issue). The fastWeights implementation therefore operates directly on the weights of model but applies the input tensor immediately, avoiding both problems. This comes at the cost of writing the inner optimization algorithm out explicitly: if we wanted to use, e.g., Adam as the inner optimizer, we would have to implement that update scheme ourselves.
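For illustration, a hand-rolled multi-step Adam-style inner loop might look as follows. This is only a sketch: `innerAdamSteps` and its hyperparameters are hypothetical names, the standard Adam defaults are assumed, and `fastWeights` is repeated so the snippet runs standalone. Because the fast weights are plain tensors rather than variables, every step stays on the outer tape:

```python
import tensorflow as tf

def fastWeights(model, weights, input):
    output = input
    for layerIndex in range(len(model.layers)):
        kernel = weights[layerIndex * 2]
        bias = weights[layerIndex * 2 + 1]
        output = model.layers[layerIndex].activation(output @ kernel + bias)
    return output

def innerAdamSteps(model, loss_fn, x, y, inner_lr, steps=3,
                   beta1=0.9, beta2=0.999, eps=1e-8):
    # Hand-rolled Adam on the fast weights; no optimizer object is used,
    # so the whole inner trajectory remains differentiable.
    weights = list(model.trainable_weights)
    m = [tf.zeros_like(w) for w in weights]  # first-moment estimates
    v = [tf.zeros_like(w) for w in weights]  # second-moment estimates
    for t in range(1, steps + 1):
        with tf.GradientTape() as taskTape:
            taskTape.watch(weights)  # fast weights are tensors after step 1
            loss = loss_fn(y, fastWeights(model, weights, x))
        grads = taskTape.gradient(loss, weights)
        m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grads)]
        v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grads)]
        m_hat = [mi / (1 - beta1 ** t) for mi in m]  # bias correction
        v_hat = [vi / (1 - beta2 ** t) for vi in v]
        weights = [w - inner_lr * mh / (tf.sqrt(vh) + eps)
                   for w, mh, vh in zip(weights, m_hat, v_hat)]
    return weights
```

The returned list can be fed straight into fastWeights on the test split, in place of the single SGD step in taskLoss.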
