-
-
Save alwc/62155ea2f6f026f0cc939b5c6f5dd2c1 to your computer and use it in GitHub Desktop.
Semi-supervised iterator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader, TensorDataset | |
from baal.active import ActiveLearningDataset | |
class AlternateIterator: | |
def __init__(self, dl_1: DataLoader, dl_2: DataLoader, num_steps: Optional[int] = None, | |
p: Optional[float] = None): | |
""" | |
Create an iterator that will alternate between two dataloaders. | |
Args: | |
dl_1 (DataLoader): first DataLoader | |
dl_2 (DataLoader): second DataLoader | |
num_steps (Optional[int]): Number of steps, if None will be the sum of both length. | |
p (Optional[float]): Probability of choosing dl_1 over dl_2. | |
If None, will be alternate between the two. | |
""" | |
self.dl_1 = iter(dl_1) | |
self.len_dl1 = len(dl_1) | |
self.dl_2 = iter(dl_2) | |
self.len_dl2 = len(dl_2) | |
self.num_steps = num_steps or (self.len_dl1 + self.len_dl2) | |
self.p = None if p is None else [p, 1 - p] | |
self._pool = None | |
self._iter_idx = None | |
def _make_index(self): | |
if self.p is None: | |
# If p is None, we just alternate. | |
arr = np.array([i % 2 for i in range(self.num_steps)]) | |
else: | |
arr = np.random.choice([0, 1], self.num_steps, p=self.p) | |
yield from arr | |
def __len__(self): | |
return self.num_steps | |
def __iter__(self): | |
self._iter_idx = self._make_index() | |
return self | |
def __next__(self): | |
for idx in self._iter_idx: | |
if idx == 0: | |
return next(self.dl_1), idx | |
else: | |
return next(self.dl_2), idx | |
raise StopIteration | |
class SemiSupervisedIterator(AlternateIterator):
    """Alternate between batches of the labelled set and the unlabelled pool
    of an active-learning dataset.
    """

    def __init__(self, al_dataset: "ActiveLearningDataset", num_steps=None, p=0.5,
                 batch_size: int = 3, shuffle: bool = True, num_workers: int = 4):
        """
        Args:
            al_dataset: active-learning dataset; iterated directly for the
                labelled loader, and its ``.pool`` attribute for the
                unlabelled loader.
            num_steps: number of steps; if None, the sum of both loader lengths.
            p: probability of drawing a labelled batch over a pool batch.
            batch_size: batch size for both loaders (previously hard-coded to 3).
            shuffle: whether to shuffle both loaders (previously hard-coded to True).
            num_workers: worker processes for both loaders (previously hard-coded to 4).
        """
        self.al_dataset = al_dataset
        active_dl = DataLoader(al_dataset, batch_size=batch_size, shuffle=shuffle,
                               num_workers=num_workers)
        pool_dl = DataLoader(al_dataset.pool, batch_size=batch_size, shuffle=shuffle,
                             num_workers=num_workers)
        super().__init__(dl_1=active_dl, dl_2=pool_dl, num_steps=num_steps, p=p)
def main():
    """Demo: alternate between two plain loaders, then over an active-learning dataset."""
    labelled = TensorDataset(torch.randn(3000, 100))
    unlabelled = TensorDataset(torch.randn(5000, 100))
    loader_a = DataLoader(labelled, batch_size=3)
    loader_b = DataLoader(unlabelled, batch_size=3)

    alternator = AlternateIterator(loader_a, loader_b, 100, p=0.2)
    for step, ((batch,), source) in enumerate(alternator):
        print(step, ":", batch.shape, source)
    # The iterator supports multiple passes.
    for step, ((batch,), source) in enumerate(alternator):
        print(step, ":", batch.shape, source)

    # With num_steps omitted, the length defaults to the sum of both loaders.
    combined = AlternateIterator(loader_a, loader_b)
    for step, ((batch,), source) in enumerate(combined):
        print(step, ":", batch.shape, source)

    print("Active learning!")
    al_dataset = ActiveLearningDataset(labelled)
    al_dataset.label_randomly(100)
    semi = SemiSupervisedIterator(al_dataset, 100, p=0.2)
    for step, ((batch,), source) in enumerate(semi):
        print(step, ":", batch.shape, source)
    # Reusable as well.
    for step, ((batch,), source) in enumerate(semi):
        print(step, ":", batch.shape, source)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment