Skip to content

Instantly share code, notes, and snippets.

Created November 25, 2018 18:24
Show Gist options
  • Save oscarknagg/b99cbbb1361c36f0bf1ff54ba92e238e to your computer and use it in GitHub Desktop.
Save oscarknagg/b99cbbb1361c36f0bf1ff54ba92e238e to your computer and use it in GitHub Desktop.
class NShotTaskSampler(Sampler):
    def __init__(self,
                 dataset,
                 episodes_per_epoch: int = None,
                 n: int = None,
                 k: int = None,
                 q: int = None,
                 num_tasks: int = 1,
                 fixed_tasks: List[Iterable[int]] = None):
        """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.

        Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets
        of `q` samples. Each task is laid out so that its first n * k ids are the support set and its
        remaining q * k ids are the query set.

        The support and query sets are sampled so that they are disjoint, i.e. they do not contain
        overlapping samples.

        # Arguments
            dataset: Dataset from which to draw samples. It must expose a pandas DataFrame as `dataset.df`
                with an `id` column (the values placed in each batch) and a `class_id` column.
            episodes_per_epoch: Arbitrary number of batches of n-shot tasks to generate in one epoch.
            n: int. Number of support samples for each class in the n-shot classification tasks.
            k: int. Number of classes in the n-shot classification tasks. Only used when `fixed_tasks`
                is None; otherwise each fixed task supplies its own classes.
            q: int. Number of query samples for each class in the n-shot classification tasks.
            num_tasks: Number of n-shot tasks to group into a single batch.
            fixed_tasks: If this argument is specified, this Sampler will always generate tasks from
                the specified classes, cycling through the list in order.

        # Raises
            ValueError: if `num_tasks` is less than 1.
        """
        super().__init__(dataset)
        self.episodes_per_epoch = episodes_per_epoch
        self.dataset = dataset
        if num_tasks < 1:
            raise ValueError('num_tasks must be >= 1.')
        self.num_tasks = num_tasks
        self.k = k
        self.n = n
        self.q = q
        self.fixed_tasks = fixed_tasks
        # Position in `fixed_tasks`. It is never reset, so cycling continues across epochs.
        self.i_task = 0

    def __len__(self):
        """Return the number of batches generated per epoch (`episodes_per_epoch`)."""
        return self.episodes_per_epoch

    def __iter__(self):
        """Yield `episodes_per_epoch` batches of sample ids.

        Each batch is a 1-D numpy array made of `num_tasks` consecutive task blocks. Within a task block,
        the first n * k ids are the support set (grouped by class, `n` per class). The following q * k
        ids are the query set (grouped by class in the same class order, `q` per class).
        """
        for _ in range(self.episodes_per_epoch):
            batch = []
            for _ in range(self.num_tasks):
                batch.extend(self._sample_task(self._next_episode_classes()))
            yield np.stack(batch)

    def _next_episode_classes(self):
        """Return the list of classes for the next task (random, or the next fixed task)."""
        if self.fixed_tasks is None:
            # Draw k distinct classes from all classes present in the dataset.
            return list(np.random.choice(self.dataset.df['class_id'].unique(), size=self.k, replace=False))
        # Loop through classes in fixed_tasks.
        episode_classes = self.fixed_tasks[self.i_task % len(self.fixed_tasks)]
        self.i_task += 1
        # Materialise the task, because its classes are traversed several times in _sample_task.
        # Without this, a one-shot iterable (e.g. a generator) would be exhausted by the first pass.
        return list(episode_classes)

    def _sample_task(self, episode_classes):
        """Return the ids of one task: `n` support ids per class, then `q` query ids per class."""
        df = self.dataset.df[self.dataset.df['class_id'].isin(episode_classes)]
        task_ids = []

        support_by_class = {}
        for class_id in episode_classes:
            # Select support examples.
            support = df[df['class_id'] == class_id].sample(self.n)
            support_by_class[class_id] = support
            # Read the `id` column directly. iterrows() would upcast ids to float
            # when the DataFrame also contains float columns.
            task_ids.extend(support['id'])

        # Queries are drawn only after every class's support set has been drawn. Each query set
        # excludes that class's support ids, so support and query are disjoint.
        for class_id in episode_classes:
            candidates = df[(df['class_id'] == class_id)
                            & (~df['id'].isin(support_by_class[class_id]['id']))]
            task_ids.extend(candidates.sample(self.q)['id'])

        return task_ids
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment