Semi-supervised iterator
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from baal.active import ActiveLearningDataset


class AlternateIterator:
    def __init__(self, dl_1: DataLoader, dl_2: DataLoader, num_steps: Optional[int] = None,
                 p: Optional[float] = None):
        """
        Create an iterator that will alternate between two dataloaders.

        Args:
            dl_1 (DataLoader): first DataLoader.
            dl_2 (DataLoader): second DataLoader.
            num_steps (Optional[int]): Number of steps. If None, defaults to the sum of both lengths.
            p (Optional[float]): Probability of choosing dl_1 over dl_2.
                If None, the iterator alternates strictly between the two.
        """
        self.dl_1 = iter(dl_1)
        self.len_dl1 = len(dl_1)
        self.dl_2 = iter(dl_2)
        self.len_dl2 = len(dl_2)
        self.num_steps = num_steps or (self.len_dl1 + self.len_dl2)
        self.p = None if p is None else [p, 1 - p]
        self._pool = None
        self._iter_idx = None

    def _make_index(self):
        if self.p is None:
            # If p is None, we just alternate.
            arr = np.array([i % 2 for i in range(self.num_steps)])
        else:
            arr = np.random.choice([0, 1], self.num_steps, p=self.p)
        yield from arr

    def __len__(self):
        return self.num_steps

    def __iter__(self):
        self._iter_idx = self._make_index()
        return self

    def __next__(self):
        # Yields (batch, loader_index) pairs: 0 means the batch came from dl_1, 1 from dl_2.
        for idx in self._iter_idx:
            if idx == 0:
                return next(self.dl_1), idx
            else:
                return next(self.dl_2), idx
        raise StopIteration


class SemiSupervisedIterator(AlternateIterator):
    def __init__(self, al_dataset: ActiveLearningDataset, num_steps=None, p=0.5):
        self.al_dataset = al_dataset
        # dl_1 iterates the labelled set, dl_2 the unlabelled pool.
        active_dl = DataLoader(al_dataset, batch_size=3, shuffle=True, num_workers=4)
        pool_dl = DataLoader(al_dataset.pool, batch_size=3, shuffle=True, num_workers=4)
        super().__init__(dl_1=active_dl, dl_2=pool_dl, num_steps=num_steps, p=p)


def main():
    d1 = TensorDataset(torch.randn(3000, 100))
    d2 = TensorDataset(torch.randn(5000, 100))
    dl_1 = DataLoader(d1, batch_size=3)
    dl_2 = DataLoader(d2, batch_size=3)

    it = AlternateIterator(dl_1, dl_2, 100, p=0.2)
    for i_, ((batch,), idx) in enumerate(it):
        print(i_, ":", batch.shape, idx)

    # Can be reused!
    for i_, ((batch,), idx) in enumerate(it):
        print(i_, ":", batch.shape, idx)

    # Can concatenate both
    it2 = AlternateIterator(dl_1, dl_2)
    for i_, ((batch,), idx) in enumerate(it2):
        print(i_, ":", batch.shape, idx)

    print("Active learning!")
    al_dataset = ActiveLearningDataset(d1)
    al_dataset.label_randomly(100)
    it2 = SemiSupervisedIterator(al_dataset, 100, p=0.2)
    for i_, ((batch,), idx) in enumerate(it2):
        print(i_, ":", batch.shape, idx)

    # Can be reused
    for i_, ((batch,), idx) in enumerate(it2):
        print(i_, ":", batch.shape, idx)


if __name__ == '__main__':
    main()
I like this approach; it's a straightforward way to balance two unequal datasets, and because the shuffled interleaving is saved, it's repeatable.
Maybe the default num_steps should be the sum of the entries in the two datasets, and the default p should be such that the two datasets can be interleaved perfectly. For that to work, something other than np.random.choice might be needed; maybe shuffle 0s and 1s with the respective counts?
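A minimal sketch of that idea (a hypothetical make_balanced_index helper, not part of the gist; it assumes the counts are the two dataloader lengths, so num_steps becomes their sum):

import numpy as np

def make_balanced_index(len_dl1: int, len_dl2: int) -> np.ndarray:
    # Build an index with exactly len_dl1 zeros and len_dl2 ones, then shuffle it,
    # so over a full pass every batch of each loader is drawn exactly once.
    arr = np.array([0] * len_dl1 + [1] * len_dl2)
    np.random.shuffle(arr)
    return arr

Dropping something like that into _make_index (in place of np.random.choice) would make the default a perfect interleaving of both loaders.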
In the case where num_steps doesn't iterate the whole of one of the datasets (it effectively truncates it), should it be assumed that the datasets are shuffled and sampled from uniformly? Or should that be handled here?