When working with imbalanced data for machine learning tasks in PyTorch, a simple random split might not divide under-represented classes proportionally between the resulting subsets. The resulting splits might then not reflect the real-world population, leading to poor predictive performance in the resulting model.
Therefore, I have created a simple function for conducting a stratified split with random shuffling, similar to StratifiedShuffleSplit from scikit-learn (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html).
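To illustrate the problem with a plain random split, here is a minimal plain-Python sketch on hypothetical toy data (1,000 samples, 2 % minority class). The helper name minority_in_holdout is made up for this illustration. It counts how many minority samples land in an unstratified 10 % holdout for different seeds. A stratified 10 % holdout would always contain round(20 * 0.1) = 2 of them.

```python
import random


def minority_in_holdout(labels, holdout_size, seed):
    """Hypothetical helper: count minority-class (label 1) samples in a plain,
    unstratified random holdout of the given size."""
    rng = random.Random(seed)
    holdout = rng.sample(range(len(labels)), holdout_size)
    return sum(labels[i] for i in holdout)


# Toy data: 980 majority samples (label 0) and 20 minority samples (label 1)
labels = [0] * 980 + [1] * 20

# Minority count in a 100-sample holdout for 20 different seeds;
# the values fluctuate around 2 instead of always being exactly 2
counts = [minority_in_holdout(labels, 100, seed) for seed in range(20)]
print(counts)
```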
```python
import random
from collections import defaultdict

import torch.utils.data


def stratified_split(dataset: torch.utils.data.Dataset, labels, fraction, random_state=None):
    # Local random generator: seeding it leaves Python's global random state untouched,
    # and random_state=0 is honoured (a plain "if random_state:" would skip it)
    rng = random.Random(random_state)

    # Group the sample indices by class label
    indices_per_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_label[label].append(index)

    first_set_indices, second_set_indices = [], []
    for indices in indices_per_label.values():
        # Draw the same fraction of samples from every class
        n_samples_for_label = round(len(indices) * fraction)
        random_indices_sample = rng.sample(indices, n_samples_for_label)
        first_set_indices.extend(random_indices_sample)
        # The remaining indices of this class go to the second subset
        sampled = set(random_indices_sample)
        second_set_indices.extend(i for i in indices if i not in sampled)

    # Indices are grouped by class, so shuffle when batching (e.g. DataLoader(shuffle=True))
    first_set_inputs = torch.utils.data.Subset(dataset, first_set_indices)
    first_set_labels = [labels[i] for i in first_set_indices]
    second_set_inputs = torch.utils.data.Subset(dataset, second_set_indices)
    second_set_labels = [labels[i] for i in second_set_indices]
    return first_set_inputs, first_set_labels, second_set_inputs, second_set_labels
```
The function splits a provided PyTorch Dataset object into two PyTorch Subset objects using stratified random sampling, and also returns the labels belonging to each subset. The fraction parameter must be a float (0.0 < fraction < 1.0) giving the proportion of samples of each class that goes into the first subset.
For example, given a set of 100 samples, a fraction of 0.75 returns two stratified subsets of roughly 75 and 25 samples, and a fraction of 2/3 returns two stratified subsets of roughly 67 and 33 samples. Because the number of samples is rounded separately for each class, the exact lengths depend on the class counts and can differ from these figures by a sample or so.
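The per-class rounding can be checked without PyTorch. The sketch below uses a hypothetical helper, stratified_indices, that repeats the index selection of stratified_split on plain lists and returns the two index lists instead of Subset objects. It assumes a toy dataset of 100 samples with a 90/10 class imbalance. Note that Python's round() rounds exact halves to the nearest even integer.

```python
import random
from collections import Counter, defaultdict


def stratified_indices(labels, fraction, random_state=None):
    """Hypothetical torch-free helper: same per-class sampling as stratified_split,
    but returns the two index lists instead of Subset objects."""
    rng = random.Random(random_state)
    indices_per_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_label[label].append(index)

    first, second = [], []
    for indices in indices_per_label.values():
        # Rounding happens per class, so totals can deviate from fraction * len(labels)
        sampled = rng.sample(indices, round(len(indices) * fraction))
        chosen = set(sampled)
        first.extend(sampled)
        second.extend(i for i in indices if i not in chosen)
    return first, second


# 100 samples: 90 of class 0 and 10 of class 1
labels = [0] * 90 + [1] * 10

first, second = stratified_indices(labels, fraction=2 / 3, random_state=1)
print(len(first), len(second))  # 67 33  (60 + 7 and 30 + 3)

first, second = stratified_indices(labels, fraction=0.75, random_state=1)
print(len(first), len(second))  # 76 24  (67.5 -> 68 and 7.5 -> 8)
print(Counter(labels[i] for i in first)[1], Counter(labels[i] for i in second)[1])  # 8 2
```

With this class mix, the 0.75 split therefore comes out as 76/24 rather than exactly 75/25, while each subset still keeps the 90/10 class ratio as closely as rounding allows.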
The code below demonstrates how to split the Credit Card Fraud dataset (https://www.kaggle.com/mlg-ulb/creditcardfraud) into a training set, validation set and testing set in PyTorch. This dataset consists of 284,807 samples.
```python
import pandas
import sklearn.preprocessing
import torch

df = pandas.read_csv('./data/CreditCardFraud/creditcard.csv')
inputs = df.iloc[:, :-1].values  # all feature columns
labels = df.iloc[:, -1].values   # last column ("Class"): 1 = fraud, 0 = legitimate

# Standardize the features (note: the scaler is fit on all samples, test set included)
sc = sklearn.preprocessing.StandardScaler()
scaled_inputs = sc.fit_transform(inputs)
torch_inputs = torch.from_numpy(scaled_inputs).float()
torch_labels = torch.from_numpy(labels).float()
dataset = torch.utils.data.TensorDataset(torch_inputs, torch_labels)

# Split into training, validation and test sets: 90 % / 10 % (train + validation / test),
# then 90 % / 10 % of the first part (train / validation)
train_data, train_labels, test_data, _ = stratified_split(dataset, labels, fraction=0.9, random_state=1)
train_data, train_labels, eval_data, _ = stratified_split(train_data, train_labels, fraction=0.9, random_state=1)
```
This yields the following datasets:
Train data length: 230,695
Eval data length: 25,632
Test data length: 28,480
Total data length: 284,807
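These lengths follow directly from the per-class rounding. The sketch below reproduces them using only the class counts. It assumes the dataset's published imbalance of 492 fraudulent and 284,315 legitimate transactions. The helper split_lengths is hypothetical and mirrors round(len(indices) * fraction) from stratified_split.

```python
def split_lengths(class_counts, fraction):
    """Hypothetical helper: per-class subset sizes produced by one stratified_split call."""
    first = [round(n * fraction) for n in class_counts]
    second = [n - f for n, f in zip(class_counts, first)]
    return first, second


# Assumed class counts from the dataset description: [legitimate, fraud]
class_counts = [284_315, 492]

train_val, test = split_lengths(class_counts, 0.9)  # first call: fraction=0.9
train, evaluation = split_lengths(train_val, 0.9)   # second call on the training part

print(sum(train), sum(evaluation), sum(test))  # 230695 25632 28480
print(train[1], evaluation[1], test[1])        # fraud samples per set: 399 44 49
```

The totals match the lengths listed above. Chaining two 0.9 splits yields roughly 81 % / 9 % / 10 % of the data, and every set keeps a share of fraud cases close to the overall one.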