Skip to content

Instantly share code, notes, and snippets.

@alwc
Forked from Dref360/semisupervised.py
Created April 8, 2020 06:57
Show Gist options
  • Save alwc/62155ea2f6f026f0cc939b5c6f5dd2c1 to your computer and use it in GitHub Desktop.
Save alwc/62155ea2f6f026f0cc939b5c6f5dd2c1 to your computer and use it in GitHub Desktop.
Semi-supervised iterator
from typing import Optional
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from baal.active import ActiveLearningDataset
class AlternateIterator:
    """Iterate over two dataloaders, choosing one of them at every step.

    Every item produced is a ``(batch, idx)`` tuple, where ``idx`` is 0 when
    the batch was drawn from ``dl_1`` and 1 when it was drawn from ``dl_2``.
    The object can be iterated several times; each pass yields ``num_steps``
    items and the underlying loaders keep their position between passes.
    """

    def __init__(self, dl_1: DataLoader, dl_2: DataLoader, num_steps: Optional[int] = None,
                 p: Optional[float] = None):
        """
        Create an iterator that will alternate between two dataloaders.

        Args:
            dl_1 (DataLoader): first DataLoader
            dl_2 (DataLoader): second DataLoader
            num_steps (Optional[int]): Number of steps, if None will be the sum of both length.
            p (Optional[float]): Probability of choosing dl_1 over dl_2.
                If None, will be alternate between the two.

        Raises:
            ValueError: If ``p`` is given but is not within [0, 1].
        """
        if p is not None and not 0.0 <= p <= 1.0:
            # Fail at construction time rather than deep inside
            # np.random.choice on the first step.
            raise ValueError(f"p must be within [0, 1], got {p!r}")
        # Keep the loaders themselves so that an exhausted iterator can be
        # re-created; `dl_1` / `dl_2` hold the live iterators.
        self._loaders = (dl_1, dl_2)
        self.dl_1 = iter(dl_1)
        self.len_dl1 = len(dl_1)
        self.dl_2 = iter(dl_2)
        self.len_dl2 = len(dl_2)
        # `is None` (not truthiness) so that an explicit 0 means "no steps".
        self.num_steps = (self.len_dl1 + self.len_dl2) if num_steps is None else num_steps
        self.p = None if p is None else [p, 1 - p]
        self._pool = None
        self._iter_idx = None

    def _make_index(self):
        """Yield, for each step of one pass, which loader to draw from (0 or 1)."""
        if self.p is None:
            # If p is None, we just alternate: 0, 1, 0, 1, ...
            arr = np.arange(self.num_steps) % 2
        else:
            arr = np.random.choice([0, 1], self.num_steps, p=self.p)
        yield from arr

    def _draw(self, first: bool):
        """Return the next batch of the selected loader, restarting it if exhausted.

        Without the restart, the StopIteration of an exhausted loader would
        leak out of ``__next__`` and silently cut the pass short.  If the
        loader is empty, the StopIteration of the fresh iterator propagates
        and ends the pass.
        """
        attr, loader = ("dl_1", self._loaders[0]) if first else ("dl_2", self._loaders[1])
        try:
            return next(getattr(self, attr))
        except StopIteration:
            fresh = iter(loader)
            setattr(self, attr, fresh)
            return next(fresh)

    def __len__(self):
        """Number of items produced by one pass."""
        return self.num_steps

    def __iter__(self):
        """Start a new pass by drawing a fresh sequence of loader choices."""
        self._iter_idx = self._make_index()
        return self

    def __next__(self):
        """Return ``(batch, idx)`` for the next step.

        Raises:
            StopIteration: When the current pass has produced ``num_steps`` items.
        """
        if self._iter_idx is None:
            # next() was called without a prior iter(); start a pass lazily.
            self._iter_idx = self._make_index()
        idx = next(self._iter_idx)
        return self._draw(idx == 0), idx
class SemiSupervisedIterator(AlternateIterator):
    """Mix batches from an ActiveLearningDataset and from its ``pool``.

    Items with index 0 come from a loader over ``al_dataset`` itself and
    items with index 1 from a loader over ``al_dataset.pool``
    (presumably the labelled and unlabelled parts respectively; confirm
    against baal's ActiveLearningDataset).
    """

    def __init__(self, al_dataset: ActiveLearningDataset, num_steps=None, p=0.5,
                 *, batch_size: int = 3, num_workers: int = 4):
        """
        Build both shuffled DataLoaders and hand them to AlternateIterator.

        Args:
            al_dataset (ActiveLearningDataset): dataset to draw from.
            num_steps (Optional[int]): Number of steps per pass, forwarded
                to AlternateIterator.
            p (Optional[float]): Probability of drawing from ``al_dataset``
                rather than from ``al_dataset.pool``.
            batch_size (int): Batch size used by both loaders.
            num_workers (int): Worker processes used by each loader.
        """
        self.al_dataset = al_dataset
        # Both loaders share the same settings so their batches are comparable.
        loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers)
        active_dl = DataLoader(al_dataset, **loader_kwargs)
        pool_dl = DataLoader(al_dataset.pool, **loader_kwargs)
        super().__init__(dl_1=active_dl, dl_2=pool_dl, num_steps=num_steps, p=p)
def main():
    """Demo run: print the batch shape and source index for several iterators."""

    def _dump(iterator):
        # One pass over `iterator`, printing step number, batch shape and
        # the index of the loader the batch came from.
        for step, ((batch,), source) in enumerate(iterator):
            print(step, ":", batch.shape, source)

    first_ds = TensorDataset(torch.randn(3000, 100))
    second_ds = TensorDataset(torch.randn(5000, 100))
    loader_a = DataLoader(first_ds, batch_size=3)
    loader_b = DataLoader(second_ds, batch_size=3)

    # Random mixing over 100 steps; the second pass shows the same object
    # can be iterated again.
    mixed = AlternateIterator(loader_a, loader_b, 100, p=0.2)
    for _ in range(2):
        _dump(mixed)

    # No num_steps and no p: strict alternation, with the step count
    # defaulting to the sum of both loader lengths.
    combined = AlternateIterator(loader_a, loader_b)
    _dump(combined)

    print("Active learning!")
    al_dataset = ActiveLearningDataset(first_ds)
    al_dataset.label_randomly(100)
    semi = SemiSupervisedIterator(al_dataset, 100, p=0.2)
    # Iterated twice as well, to show reuse.
    for _ in range(2):
        _dump(semi)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment