This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NShotTaskSampler(Sampler): | |
def __init__(self, | |
dataset: torch.utils.data.Dataset, | |
episodes_per_epoch: int = None, | |
n: int = None, | |
k: int = None, | |
q: int = None, | |
num_tasks: int = 1, | |
fixed_tasks: List[Iterable[int]] = None): | |
"""PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Input, Subtract, Dense, Lambda | |
from keras.models import Model | |
import keras.backend as K | |
def build_siamese_network(encoder, input_shape): | |
input_1 = Input(input_shape) | |
input_2 = Input(input_shape) | |
# `encoder` is any predefined network that maps a single sample | |
# into an embedding space. |
NewerOlder