@Alvtron
Last active November 22, 2023 12:37
Split a PyTorch Dataset into two subsets using stratified random sampling.

Stratified dataset split in PyTorch

When working with imbalanced data for machine learning tasks in PyTorch, a simple random split might not divide poorly represented classes fairly between the subsets. The resulting splits might not reflect the real-world population, leading to poor predictive performance in the resulting model.
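To put a rough number on this, the sketch below computes the chance that a plain random hold-out contains no minority samples at all. The figures (1,000 samples, 5 of them in the minority class, a 100-sample hold-out) and the helper name prob_no_minority are made up for illustration and do not come from any dataset discussed here.

```python
import math


def prob_no_minority(n_total, n_minority, n_holdout):
    """Probability that a uniformly random hold-out of n_holdout samples
    contains none of the n_minority minority-class samples (hypergeometric)."""
    return math.comb(n_total - n_minority, n_holdout) / math.comb(n_total, n_holdout)


# Hypothetical numbers: 1,000 samples, 5 minority samples, 100-sample hold-out.
p = prob_no_minority(1000, 5, 100)
print(f"{p:.2f}")  # roughly 0.59
```

Under these assumptions, more than half of all plain random splits would leave the hold-out set without a single minority sample, which is the situation a stratified split avoids.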

Therefore, I have created a simple function for conducting a stratified split with random shuffling, similar to StratifiedShuffleSplit from scikit-learn (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html).

import random
from collections import defaultdict

import torch.utils.data


def stratified_split(dataset: torch.utils.data.Dataset, labels, fraction, random_state=None):
    """Split dataset into two Subsets, preserving the class proportions in each."""
    # A local generator leaves the global random state untouched and also honours random_state=0.
    rng = random.Random(random_state)
    # Group the sample indices by class label.
    indices_per_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_label[label].append(index)
    first_set_indices, second_set_indices = [], []
    for indices in indices_per_label.values():
        # Draw round(class size * fraction) indices for the first subset; the rest go to the second.
        n_samples_for_label = round(len(indices) * fraction)
        random_indices_sample = rng.sample(indices, n_samples_for_label)
        chosen = set(random_indices_sample)
        first_set_indices.extend(random_indices_sample)
        second_set_indices.extend(i for i in indices if i not in chosen)
    first_set_inputs = torch.utils.data.Subset(dataset, first_set_indices)
    first_set_labels = [labels[i] for i in first_set_indices]
    second_set_inputs = torch.utils.data.Subset(dataset, second_set_indices)
    second_set_labels = [labels[i] for i in second_set_indices]
    return first_set_inputs, first_set_labels, second_set_inputs, second_set_labels

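The function splits a provided PyTorch Dataset object into two PyTorch Subset objects using stratified random sampling, and also returns the labels belonging to each subset. The labels argument must hold one label per sample, in the same order as the dataset. The fraction parameter must be a float value (0.0 < fraction < 1.0); it is the share of each class, and therefore approximately the share of the whole dataset, that goes into the first resulting subset.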

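For example, given a set of 100 samples, a fraction of 0.75 will return two stratified subsets of about 75 and 25 samples respectively.

For a different example with a fraction of 2/3, a dataset of 100 samples will be split into two stratified subsets of about 67 and 33 samples respectively. The number of samples is rounded separately for each class, so the exact lengths depend on the class distribution and can differ from these figures by a sample or so.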
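Since the sample count is rounded within each class, the subset sizes depend on the class distribution and not only on the total. The following sketch mirrors the round(len(indices) * fraction) step of the function; the helper name split_sizes and the class counts are hypothetical.

```python
def split_sizes(class_counts, fraction):
    """Return (first_size, second_size) for the given per-class sample counts,
    using the same per-class rounding as stratified_split."""
    first = sum(round(n * fraction) for n in class_counts.values())
    total = sum(class_counts.values())
    return first, total - first


# 80/20 class balance: 60 + 15 samples land in the first subset.
print(split_sizes({"a": 80, "b": 20}, 0.75))  # (75, 25)
# 90/10 class balance: 67.5 and 7.5 are ties, which Python's round() resolves
# to the nearest even integer (68 and 8).
print(split_sizes({"a": 90, "b": 10}, 0.75))  # (76, 24)
```

Both inputs contain 100 samples, yet the second one yields a 76/24 split because of the per-class rounding.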

Example - CreditCardFraud dataset

The code below demonstrates how to split the Credit Card Fraud dataset (https://www.kaggle.com/mlg-ulb/creditcardfraud) into a training set, validation set and testing set in PyTorch. This dataset consists of 284,807 samples.

import pandas
import sklearn.preprocessing
import torch
import torch.utils.data

df = pandas.read_csv('./data/CreditCardFraud/creditcard.csv')
# the last column holds the class labels, all other columns are features
inputs = df.iloc[:, :-1].values
labels = df.iloc[:, -1].values
# standardize the features and convert everything to tensors
sc = sklearn.preprocessing.StandardScaler()
torch_inputs = sc.fit_transform(inputs)
torch_inputs = torch.from_numpy(torch_inputs).float()
torch_labels = torch.from_numpy(labels).float()
dataset = torch.utils.data.TensorDataset(torch_inputs, torch_labels)
# split the dataset into a training, a validation and a testing set
# first split: 90% training + validation, 10% testing
train_data, train_labels, test_data, _ = stratified_split(dataset, labels, fraction=0.9, random_state=1)
# second split: 90% training, 10% validation
train_data, train_labels, eval_data, _ = stratified_split(train_data, train_labels, fraction=0.9, random_state=1)

This yields the following datasets:

Train data length: 230,695
Eval data length: 25,632
Test data length: 28,480
Total data length: 284,807
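
The lengths above can be reproduced by hand from the class counts. The sketch below assumes the class distribution stated on the dataset's Kaggle page (492 fraud cases among the 284,807 transactions, hence 284,315 legitimate ones); the variable names are illustrative and it only repeats the rounding arithmetic, without touching the data.

```python
fraction = 0.9
counts = {"legitimate": 284_315, "fraud": 492}  # assumed class counts, see above

# First split: 90% training + validation, the remainder is the testing set.
train_val = {c: round(n * fraction) for c, n in counts.items()}
test = {c: counts[c] - train_val[c] for c in counts}

# Second split: 90% of training + validation for training, the remainder for validation.
train = {c: round(n * fraction) for c, n in train_val.items()}
val = {c: train_val[c] - train[c] for c in train_val}

for name, split in [("Train", train), ("Eval", val), ("Test", test)]:
    size = sum(split.values())
    print(f"{name}: {size:,} samples, {split['fraud'] / size:.3%} fraud")
```

Under that assumption the three sizes match the lengths listed above, and the share of fraud cases stays at roughly 0.17% in every subset.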