Skip to content

Instantly share code, notes, and snippets.

@gbiz123
Created July 28, 2024 18:39
Show Gist options
  • Save gbiz123/27f77712a7c4d26b000c41ef795aa2c2 to your computer and use it in GitHub Desktop.
Save gbiz123/27f77712a7c4d26b000c41ef795aa2c2 to your computer and use it in GitHub Desktop.
Downsample Binary Pytorch Dataset Down To Size Of Smallest Class
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory
import random
def downsample_balance_binary_dataset(dataset: Dataset) -> Dataset:
    """Balance a binary-labeled dataset by downsampling the majority class.

    Each dataset item is expected to be a ``(sample, label)`` pair with
    ``label`` equal to 0 or 1. The majority class is truncated to the size
    of the minority class by keeping its first N indices in dataset order
    (deterministic — shuffle the dataset first if random selection is wanted).

    Args:
        dataset: A binary-labeled dataset of ``(sample, label)`` pairs.

    Returns:
        The original dataset unchanged when the two classes are already the
        same size; otherwise a ``Subset`` containing all minority-class
        indices and an equal number of majority-class indices.
    """
    # Single pass over the dataset to collect labels (the original code
    # iterated the whole dataset four times).
    labels = [item[1] for item in dataset]
    class_0_indices = [i for i, label in enumerate(labels) if label == 0]
    class_1_indices = [i for i, label in enumerate(labels) if label == 1]

    # Already balanced — nothing to do.
    if len(class_0_indices) == len(class_1_indices):
        return dataset

    # Truncate both lists to the minority-class size; the minority list is
    # unaffected by its own (no-op) slice. This fixes the original bug where
    # the class-1-majority branch concatenated the truncated class-1 indices
    # with the *full* class-1 list and dropped class 0 entirely (and whose
    # sanity check unconditionally raised ValueError).
    minority_count = min(len(class_0_indices), len(class_1_indices))
    all_indices = class_0_indices[:minority_count] + class_1_indices[:minority_count]
    return Subset(dataset, all_indices)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment