Skip to content

Instantly share code, notes, and snippets.

View oscarknagg's full-sized avatar

Oscar Knagg oscarknagg

View GitHub Profile
class NShotTaskSampler(Sampler):
def __init__(self,
dataset: torch.utils.data.Dataset,
episodes_per_epoch: int = None,
n: int = None,
k: int = None,
q: int = None,
num_tasks: int = 1,
fixed_tasks: List[Iterable[int]] = None):
"""PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.
@oscarknagg
oscarknagg / siamese.py
Last active October 2, 2018 18:43
Demonstration code for a Siamese network
from keras.layers import Input, Subtract, Dense, Lambda
from keras.models import Model
import keras.backend as K
def build_siamese_network(encoder, input_shape):
input_1 = Input(input_shape)
input_2 = Input(input_shape)
# `encoder` is any predefined network that maps a single sample
# into an embedding space.