Created
November 25, 2018 18:24
-
-
Save oscarknagg/b99cbbb1361c36f0bf1ff54ba92e238e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NShotTaskSampler(Sampler):
    """Batch sampler that yields the sample ids making up n-shot, k-way, q-query tasks."""

    def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 episodes_per_epoch: int = None,
                 n: int = None,
                 k: int = None,
                 q: int = None,
                 num_tasks: int = 1,
                 fixed_tasks: List[Iterable[int]] = None):
        """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.

        Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets
        of `q` samples. The ids of both sets are grouped into one array such that, for every task, the
        first n * k ids belong to the support set and the following q * k ids belong to the query set.
        When `num_tasks` > 1 the ids of the individual tasks are laid out one task after another.

        Within a class, the query samples are drawn only from rows whose `id` was not picked for the
        support set, so the two sets do not overlap.

        # Arguments
            dataset: Object from which to draw samples. It must expose a pandas DataFrame as
                `dataset.df` with (at least) the columns `class_id` and `id`.
            episodes_per_epoch: Number of batches of n-shot tasks to generate in one epoch
            n: int. Number of support samples for each class in the n-shot classification tasks.
            k: int. Number of classes in the n-shot classification tasks. Only used when
                `fixed_tasks` is None.
            q: int. Number of query samples for each class in the n-shot classification tasks.
            num_tasks: Number of n-shot tasks to group into a single batch. Must be >= 1.
            fixed_tasks: If this argument is specified this Sampler cycles through the given
                collections of class ids instead of drawing random classes.

        # Raises
            ValueError: If `num_tasks` is smaller than 1.
        """
        # Validate before touching any state so a bad argument fails fast.
        if num_tasks < 1:
            raise ValueError('num_tasks must be >= 1.')

        super().__init__(dataset)
        self.episodes_per_epoch = episodes_per_epoch
        self.dataset = dataset
        self.num_tasks = num_tasks
        self.k = k
        self.n = n
        self.q = q
        self.fixed_tasks = fixed_tasks
        # Position in `fixed_tasks`; persists across epochs so the cycle continues where it stopped.
        self.i_task = 0

    def __len__(self):
        """Return the number of batches produced per epoch."""
        return self.episodes_per_epoch

    def __iter__(self):
        """Yield one array of sample ids per episode (`num_tasks` tasks per array)."""
        df = self.dataset.df

        for _ in range(self.episodes_per_epoch):
            batch = []

            for _ in range(self.num_tasks):
                if self.fixed_tasks is None:
                    # Get random classes
                    episode_classes = np.random.choice(df['class_id'].unique(), size=self.k, replace=False)
                else:
                    # Loop through classes in fixed_tasks
                    episode_classes = self.fixed_tasks[self.i_task % len(self.fixed_tasks)]
                    self.i_task += 1

                # Materialise once: the classes are walked twice below and an entry of
                # `fixed_tasks` is only promised to be an Iterable.
                episode_classes = list(episode_classes)

                class_rows = {}
                support_ids = {}
                for class_id in episode_classes:
                    # Select support examples
                    rows = df[df['class_id'] == class_id]
                    support = rows.sample(self.n)
                    class_rows[class_id] = rows
                    support_ids[class_id] = support['id']
                    # Read the column directly instead of iterrows(): iterrows() builds a
                    # Series per row and may upcast the id (e.g. int -> float) when the
                    # frame holds columns of mixed numeric dtypes.
                    batch.extend(support['id'].tolist())

                # Queries are appended only after every support sample of the task so the
                # support block precedes the query block.
                for class_id in episode_classes:
                    rows = class_rows[class_id]
                    candidates = rows[~rows['id'].isin(support_ids[class_id])]
                    batch.extend(candidates.sample(self.q)['id'].tolist())

            yield np.stack(batch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment