from typing import Iterator, List

import numpy as np
from torch.utils.data.sampler import Sampler
class BalanceClassSampler(Sampler):
    """Sampler that yields a class-balanced set of dataset indices.

    Allows you to create a stratified sample over unbalanced classes:
    every epoch produces the same number of indices for each class,
    either by downsampling majority classes, upsampling minority classes
    (sampling with replacement), or drawing a fixed count per class.
    """

    def __init__(self, labels: List[int], mode: str = "downsampling"):
        """
        Args:
            labels (List[int]): class label for each element of the dataset
            mode (str or int): strategy used to balance classes; must be
                "downsampling", "upsampling", or an int giving the exact
                number of samples to draw per class

        Raises:
            ValueError: if ``mode`` is a string other than
                "downsampling" / "upsampling"
        """
        super().__init__(labels)
        labels = np.array(labels)
        # occurrence count of each distinct label
        samples_per_class = {
            label: (labels == label).sum() for label in set(labels)
        }
        # dataset indices grouped by label; used to sample per class
        self.lbl2idx = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }
        # explicit raise instead of `assert`: asserts are stripped under -O
        if isinstance(mode, str) and mode not in ("downsampling", "upsampling"):
            raise ValueError(
                "mode must be 'downsampling', 'upsampling' or an int"
            )
        if isinstance(mode, int) and not isinstance(mode, bool):
            # an int mode is the exact number of samples per class
            samples_per_class = mode
        elif mode == "upsampling":
            # grow every class to the size of the largest one
            samples_per_class = max(samples_per_class.values())
        else:
            # shrink every class to the size of the smallest one
            samples_per_class = min(samples_per_class.values())
        self.labels = labels
        # cast so __len__ returns a plain int rather than np.int64
        self.samples_per_class = int(samples_per_class)
        self.length = self.samples_per_class * len(self.lbl2idx)

    def __iter__(self) -> Iterator[int]:
        """Iterate over a shuffled, class-balanced sample of indices.

        Yields:
            indices of the stratified sample
        """
        indices = []
        for key in sorted(self.lbl2idx):
            # sample with replacement only when the class is too small
            replace_ = self.samples_per_class > len(self.lbl2idx[key])
            indices += np.random.choice(
                self.lbl2idx[key], self.samples_per_class, replace=replace_
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            length of the result sample (samples per class * class count)
        """
        return self.length
Last active
June 29, 2020 09:45
-
-
Save nlpjoe/56d7828e3520572109d54c4f0d4aadb4 to your computer and use it in GitHub Desktop.
[class balance] #pytorch #ml
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment