@oscarknagg
Last active November 26, 2018 22:04
Key functionality for Prototypical Networks (Snell et al 2017)
from typing import Callable

import torch
from torch.nn import Module
from torch.optim import Optimizer


def proto_net_episode(model: Module,
                      optimiser: Optimizer,
                      loss_fn: Callable,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      n_shot: int,
                      k_way: int,
                      q_queries: int,
                      distance: str,
                      train: bool):
    """Performs a single training episode for a Prototypical Network.

    # Arguments
        model: Prototypical Network to be trained.
        optimiser: Optimiser used to take the gradient step
        loss_fn: Loss function between predictions and query labels. It receives
            log-probabilities, so it should compute cross-entropy from them
            (e.g. torch.nn.NLLLoss)
        x: Input samples of few shot classification task
        y: Input labels of few shot classification task, ordered like x. Query labels
            must be class indices in [0, k_way), in the same class order as the support set
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set
        distance: Distance metric to use when calculating distance between class
            prototypes and queries. This snippet only computes squared Euclidean
            distance, so the argument is currently unused
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Prototypical Network on this task
        y_pred: Predicted class probabilities for the query set on this task
    """
    if train:
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model(x)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[:n_shot * k_way]
    queries = embeddings[n_shot * k_way:]
    y_support = y[:n_shot * k_way]
    # Query labels start after the n_shot * k_way support labels
    y_queries = y[n_shot * k_way:]

    # Reshape so the first dimension indexes by class then take the mean
    # along the shot dimension to generate the "prototypes" for each class
    prototypes = support.reshape(k_way, n_shot, -1).mean(dim=1)

    # Calculate squared distances between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = (
        queries.unsqueeze(1).expand(queries.shape[0], prototypes.shape[0], -1) -
        prototypes.unsqueeze(0).expand(queries.shape[0], prototypes.shape[0], -1)
    ).pow(2).sum(dim=2)

    # Calculate log p_{phi} (y = k | x)
    log_p_y = (-distances).log_softmax(dim=1)
    loss = loss_fn(log_p_y, y_queries)

    # Prediction probabilities are softmax over negative distances
    y_pred = (-distances).softmax(dim=1)

    if train:
        # Take gradient step
        loss.backward()
        optimiser.step()

    return loss, y_pred
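
The prototype and distance steps can be checked by hand on a toy episode. The sketch below is not part of the original gist: it assumes the episode ordering described in the comments above (all support samples grouped by class, then all queries), uses hand-picked 2-D "embeddings" in place of a model, and introduces a hypothetical helper `squared_euclidean` plus episode-relative query labels built with `torch.arange`.

```python
import torch
import torch.nn.functional as F


def squared_euclidean(queries: torch.Tensor, prototypes: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: squared L2 distance from every query to every prototype.

    queries: (num_queries, dim), prototypes: (k_way, dim) -> (num_queries, k_way)
    """
    # Broadcasting (num_queries, 1, dim) - (1, k_way, dim) -> (num_queries, k_way, dim)
    return (queries.unsqueeze(1) - prototypes.unsqueeze(0)).pow(2).sum(dim=2)


# Toy episode: 2 classes, 2 support samples per class, 1 query per class
k_way, n_shot, q_queries = 2, 2, 1

# Stand-in embeddings in episode order: support grouped by class, then queries
embeddings = torch.tensor([
    [0.0, 0.0], [2.0, 0.0],   # class 0 support -> prototype (1, 0)
    [0.0, 4.0], [2.0, 4.0],   # class 1 support -> prototype (1, 4)
    [1.0, 1.0],               # query belonging to class 0
    [1.0, 3.0],               # query belonging to class 1
])
support = embeddings[:n_shot * k_way]
queries = embeddings[n_shot * k_way:]

# Mean over the shot dimension gives one prototype per class: shape (k_way, dim)
prototypes = support.reshape(k_way, n_shot, -1).mean(dim=1)

# Shape (num_queries, k_way); column j is the distance to the j-th class prototype
distances = squared_euclidean(queries, prototypes)

# Log-probabilities over classes, then negative log-likelihood against
# episode-relative labels 0..k_way-1 (an assumption about how labels are encoded)
log_p_y = (-distances).log_softmax(dim=1)
y_queries = torch.arange(k_way).repeat_interleave(q_queries)
loss = F.nll_loss(log_p_y, y_queries)
y_pred = log_p_y.exp()
```

Here each query sits at squared distance 1 from its own prototype and 9 from the other, so `y_pred.argmax(dim=1)` recovers the labels `[0, 1]`. Relying on broadcasting instead of explicit `expand` calls yields the same distances for these shapes.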