Key functionality for Matching Networks (Vinyals et al 2016)
from typing import Callable

import torch
from torch.nn import Module
from torch.nn.utils import clip_grad_norm_
from torch.optim import Optimizer

# The gist does not show where `Loss` comes from; a plain callable alias is
# assumed here so that the annotation below resolves
Loss = Callable[..., torch.Tensor]

def matching_net_episode(model: Module,
                         optimiser: Optimizer,
                         loss_fn: Loss,
                         x: torch.Tensor,
                         y: torch.Tensor,
                         n_shot: int,
                         k_way: int,
                         q_queries: int,
                         distance: str,
                         fce: bool,
                         train: bool):
"""Performs a single training episode for a Matching Network.
# Arguments
model: Matching Network to be trained.
optimiser: Optimiser to calculate gradient step from loss
loss_fn: Loss function to calculate between predictions and outputs
x: Input samples of few shot classification task
y: Input labels of few shot classification task
n_shot: Number of examples per class in the support set
k_way: Number of classes in the few shot classification task
q_queries: Number of examples per class in the query set
distance: Distance metric to use when calculating distance between support and query set samples
fce: Whether or not to us fully conditional embeddings
train: Whether (True) or not (False) to perform a parameter update
# Returns
loss: Loss of the Matching Network on this task
y_pred: Predicted class probabilities for the query set on this task
"""
    if train:
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model.encoder(x)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[:n_shot * k_way]
    queries = embeddings[n_shot * k_way:]
    y_support = y[:n_shot * k_way]
    y_queries = y[n_shot * k_way:]
    # Optionally apply full context embeddings
    if fce:
        # LSTM requires input of shape (seq_len, batch, input_size). `support` is of
        # shape (k_way * n_shot, embedding_dim) and we want the LSTM to treat the
        # support set as a sequence so add a single dimension to transform support set
        # to the shape (k_way * n_shot, 1, embedding_dim) and then remove the batch
        # dimension afterwards

        # Calculate the fully conditional embedding, g, for support set samples as
        # described in appendix A.2 of the paper. g takes the form of a
        # bidirectional LSTM with a skip connection from inputs to outputs
        support, _, _ = model.g(support.unsqueeze(1))
        support = support.squeeze(1)

        # Calculate the fully conditional embedding, f, for the query set samples
        # as described in appendix A.1 of the paper.
        queries = model.f(support, queries)
    # Calculate distance between all queries and all support samples.
    # Output has shape (q_queries * k_way, n_shot * k_way) = (num_queries, num_support).
    # Note: only the squared Euclidean distance is implemented below; the `distance`
    # argument is not used in this snippet
    distances = (
        queries.unsqueeze(1).expand(queries.shape[0], support.shape[0], -1) -
        support.unsqueeze(0).expand(queries.shape[0], support.shape[0], -1)
    ).pow(2).sum(dim=2)

    # Calculate "attention" as softmax over support-query distances
    attention = (-distances).softmax(dim=1)
    # Calculate predictions as in equation (1) from Matching Networks
    # y_hat = \sum_{i=1}^{k} a(x_hat, x_i) y_i

    # Create one-hot encoded label vector for the support set, the
    # default PyTorch format is for labels to be integers
    y_onehot = torch.zeros(k_way * n_shot, k_way)

    # Unsqueeze the support labels to be 2D as this
    # is needed for .scatter()
    y_onehot = y_onehot.scatter(1, y_support.unsqueeze(-1), 1)
    y_pred = torch.mm(attention, y_onehot.cuda().double())

    # Calculate loss with negative log likelihood
    # Clip predictions for numerical stability
    clipped_y_pred = y_pred.clamp(1e-8, 1 - 1e-8)
    loss = loss_fn(clipped_y_pred.log(), y_queries)
    if train:
        # Backpropagate gradients
        loss.backward()
        # I found training to be quite unstable so I clip the norm
        # of the gradient to be at most 1
        clip_grad_norm_(model.parameters(), 1)
        # Take gradient step
        optimiser.step()

    return loss, y_pred
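
For reference, below is a minimal sketch of how this episode function might be driven from a training loop. The gist itself only defines the episode function, so the model instance, the task loader and the choice of torch.nn.NLLLoss here are assumptions for illustration (the code above expects log-probabilities and integer targets, which NLLLoss accepts).

# Hypothetical driver loop for matching_net_episode (not part of the gist).
# Assumes `model` exposes .encoder, .g and .f as used above and lives on the GPU,
# and that `taskloader` yields (x, y) batches ordered support-first as described
# in the NShotWrapper comment inside matching_net_episode.
loss_fn = torch.nn.NLLLoss().cuda()  # predictions passed to it are already log-probabilities
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(40):
    for x, y in taskloader:
        loss, y_pred = matching_net_episode(
            model, optimiser, loss_fn,
            x.cuda(), y.cuda(),
            n_shot=1, k_way=5, q_queries=5,
            distance='l2', fce=True, train=True,
        )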