Actual MAML: Given a TensorFlow model and a batch of samples from different tasks.
def fastWeights(model, weights, input):
    # Forward pass through the model's layers using an explicit list of
    # weights (kernel, bias per layer) instead of the model's own variables.
    output = input
    for layerIndex in range(len(model.layers)):
        kernel = weights[layerIndex * 2]
        bias = weights[layerIndex * 2 + 1]
        output = model.layers[layerIndex].activation(output @ kernel + bias)
    return output
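As a quick sanity check, calling fastWeights with the model's current trainable weights should reproduce a plain forward pass. The model below is a hypothetical example of mine, not part of the gist; fastWeights assumes a purely sequential stack of Dense layers, so trainable_weights is ordered [kernel_0, bias_0, kernel_1, bias_1, ...]:

import tensorflow as tf

# Hypothetical Dense-only model; fastWeights relies on every layer exposing
# .activation and on the weights alternating kernel/bias per layer.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1, activation="linear"),
])

x = tf.random.normal((4, 1))
# With the unmodified weights, fastWeights is just model(x).
tf.debugging.assert_near(fastWeights(model, model.trainable_weights, x), model(x))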
import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()  # assumed definition; the gist uses mse without defining it

def updateMAML(model, optimizer, inner_lr, batch):
    def taskLoss(batch):
        y_train, x_train, y_test, x_test = batch
        # Inner loop: one gradient-descent step on the task's training split.
        with tf.GradientTape() as taskTape:
            loss = mse(y_train, model(x_train))  # or cross-entropy, or ...
        grads = taskTape.gradient(loss, model.trainable_weights)
        weights = [w - inner_lr * g for g, w in zip(grads, model.trainable_weights)]
        # Evaluate the adapted weights on the task's test split.
        return mse(y_test, fastWeights(model, weights, x_test))  # or cross-entropy, or ...

    # Outer loop: differentiate the summed post-adaptation losses w.r.t. the
    # original model parameters and let the meta-optimizer apply the update.
    with tf.GradientTape() as tape:
        batchLoss = tf.map_fn(taskLoss, elems=batch,
                              fn_output_signature=tf.float32)
        loss = tf.reduce_sum(batchLoss)
    optimizer.minimize(loss, model.trainable_weights, tape=tape)
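For context, here is a rough sketch of how updateMAML might be driven; the sine-regression task sampler, model, and hyperparameters are assumptions of mine, not part of the gist. The important bit is the batch layout: a tuple of four tensors, each stacking one component (y_train, x_train, y_test, x_test) across tasks, so that tf.map_fn hands taskLoss one task at a time.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(40, activation="relu"),
    tf.keras.layers.Dense(40, activation="relu"),
    tf.keras.layers.Dense(1, activation="linear"),
])
optimizer = tf.keras.optimizers.Adam(1e-3)

def sampleSineTask(k=10):
    # Hypothetical task: regress a sine wave with random amplitude and phase.
    amplitude = np.random.uniform(0.1, 5.0)
    phase = np.random.uniform(0.0, np.pi)
    x = np.random.uniform(-5.0, 5.0, size=(2 * k, 1)).astype(np.float32)
    y = (amplitude * np.sin(x + phase)).astype(np.float32)
    return y[:k], x[:k], y[k:], x[k:]  # y_train, x_train, y_test, x_test

for step in range(1000):
    tasks = [sampleSineTask() for _ in range(8)]
    # Stack each component across tasks: shapes become (num_tasks, k, 1).
    batch = tuple(tf.stack([task[i] for task in tasks]) for i in range(4))
    updateMAML(model, optimizer, inner_lr=0.01, batch=batch)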
Regarding the design choice of implementing fastWeights: we are obviously working with nested instances of tf.GradientTape. What we don't want to do is apply the gradients in taskLoss to the model itself, because we want to keep the current parameters. Further, copying the model and optimizing on that appears to disconnect the computation graph, resulting in empty gradients (also see this issue). So the fastWeights implementation works directly on the weights of model but applies the input tensor immediately, thereby avoiding both of the above problems. It comes at the cost of requiring this explicit implementation of the optimization algorithm. If we e.g. wanted to use Adam as the inner optimizer, we would have to implement that update scheme ourselves, as sketched below.