Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Created March 12, 2024 07:03
Show Gist options
  • Save pszemraj/8f10d362bdb56329532bf31c4df821a5 to your computer and use it in GitHub Desktop.
hf datasets train_test_split with stratify_by_column for any type (by tricking it)
import os
import numpy as np
from datasets import ClassLabel, Dataset, DatasetDict
def split_dataset(
    dataset: Dataset,
    test_size=0.025,
    validation_size=0.025,
    stratify_by_column: str = None,
):
    """
    Split a dataset into training, testing, and validation sets with optional
    stratification on a column of any type.

    ``Dataset.train_test_split`` only stratifies on ``ClassLabel`` columns, so
    non-ClassLabel columns are temporarily encoded into a ClassLabel helper
    column, which is dropped from the returned splits.

    Parameters:
    - dataset: The Hugging Face ``Dataset`` to split.
    - test_size: Proportion of the dataset to allocate to the test set.
    - validation_size: Proportion of the dataset to allocate to the validation set.
    - stratify_by_column: Optional name of the column to stratify by.

    Returns:
    - A DatasetDict with keys 'train', 'test', and 'validation'.

    Raises:
    - ValueError: if ``test_size + validation_size >= 1``.
    """
    # Validate sizes first, before any (potentially expensive) mapping work.
    nontrain_size = test_size + validation_size
    if nontrain_size >= 1:
        raise ValueError(
            "The combined size of test and validation sets must be less than 1."
        )

    tmp_stratify_col = None  # name of the temporary helper column, if created
    stratify_col = stratify_by_column  # column actually passed to train_test_split
    if stratify_by_column and not isinstance(
        dataset.features[stratify_by_column], ClassLabel
    ):
        # BUG FIX: the original compared ``features[col].dtype != "ClassLabel"``,
        # but a ClassLabel feature's dtype is "int64", so that test could never
        # detect an existing ClassLabel; isinstance is the reliable check.
        unique_values = sorted(set(dataset[stratify_by_column]))
        value_to_int = {v: i for i, v in enumerate(unique_values)}
        tmp_stratify_col = f"{stratify_by_column}-ClassLabel"
        dataset = dataset.map(
            lambda example: {
                tmp_stratify_col: value_to_int[example[stratify_by_column]]
            },
            load_from_cache_file=False,
            num_proc=os.cpu_count(),
        )
        # ClassLabel names must be strings; stringify in case the column holds
        # ints (e.g. years) or other non-string values.
        dataset = dataset.cast_column(
            tmp_stratify_col,
            ClassLabel(
                num_classes=len(unique_values),
                names=[str(v) for v in unique_values],
            ),
        )
        stratify_col = tmp_stratify_col

    # BUG FIX: when the column was already a ClassLabel, the original passed
    # stratify_by_column=None here and silently skipped stratification; we now
    # stratify on the original column in that case.
    train_test_split = dataset.train_test_split(
        test_size=nontrain_size,
        stratify_by_column=stratify_col,
    )
    train_set = train_test_split["train"]
    non_train_set = train_test_split["test"]

    # Carve the held-out pool into test vs. validation, preserving the
    # requested test/validation ratio.
    temp_test_proportion = test_size / nontrain_size
    test_validation_split = non_train_set.train_test_split(
        test_size=temp_test_proportion,
        stratify_by_column=stratify_col,
    )

    split_ds = DatasetDict(
        {
            "train": train_set,
            "test": test_validation_split["test"],
            "validation": test_validation_split["train"],
        }
    )
    # Drop the helper column only if we actually created one.
    if tmp_stratify_col:
        split_ds = split_ds.remove_columns(tmp_stratify_col)
    return split_ds
# Example usage: split `ds_unique` (defined elsewhere) into 95% train,
# 2.5% test, 2.5% validation, stratifying on the "year" column.
ds_u_split = split_dataset(
    ds_unique, test_size=0.025, validation_size=0.025, stratify_by_column="year"
)
# Bare expression to display the resulting DatasetDict (notebook-style).
ds_u_split
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment