Skip to content

Instantly share code, notes, and snippets.

@LennartPurucker
Created February 5, 2025 16:39
from __future__ import annotations
import logging
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, RepeatedStratifiedKFold, StratifiedKFold
logger = logging.getLogger(__name__)
class UnableToFixTooManySplitsError(Exception):
    """Raised when no valid stratified split can be produced.

    Signals that at least one class has fewer samples than the requested
    number of splits and automatic fixing was not permitted.
    """
def _save_stratified_splits(
_splitter: StratifiedKFold | RepeatedStratifiedKFold,
x: np.ndarray | pd.DataFrame,
y: np.ndarray | pd.Series,
n_splits: int,
auto_fix_stratified_splits: bool = False,
) -> list[list[list[int]]]:
"""Fix from AutoGluon to avoid unsafe splits for classification if less than n_splits instances exist for all classes.
https://github.com/autogluon/autogluon/blob/0ab001a1193869a88f7af846723d23245781a1ac/core/src/autogluon/core/utils/utils.py#L70.
"""
try:
splits = [[train_index, test_index] for train_index, test_index in _splitter.split(x, y)]
except ValueError as e:
x = pd.DataFrame(x)
y = pd.Series(y)
y_dummy = pd.concat([y, pd.Series([-1] * n_splits)], ignore_index=True)
X_dummy = pd.concat([x, x.head(n_splits)], ignore_index=True)
invalid_index = set(y_dummy.tail(n_splits).index)
splits = [[train_index, test_index] for train_index, test_index in _splitter.split(X_dummy, y_dummy)]
len_out = len(splits)
for i in range(len_out):
train_index, test_index = splits[i]
splits[i][0] = [index for index in train_index if index not in invalid_index]
splits[i][1] = [index for index in test_index if index not in invalid_index]
# only rais afterward because only now we know that we cannot fix it
if not auto_fix_stratified_splits:
raise UnableToFixTooManySplitsError(
"Cannot split data in a stratifed way with each class in each subset of the data.",
) from e
return [
[
[int(i) for i in (train_index if isinstance(train_index, list) else train_index.tolist())],
[int(i) for i in (test_index if isinstance(test_index, list) else test_index.tolist())],
]
for train_index, test_index in splits
]
def fix_split_by_dropping_classes(
    x: np.ndarray,
    y: np.ndarray,
    n_splits: int,
    spliter_kwargs: dict,
) -> list[list[list[int], list[int]]]:
    """Fixes stratified splits for edge case.

    For each class that has fewer instances than number of splits, we oversample before split to n_splits and then remove all oversamples and
    original samples from the splits; effectively removing the class from the data without touching the indices.

    Args:
        x: The data to split.
        y: The labels to split on.
        n_splits: Number of stratified splits to generate.
        spliter_kwargs: Keyword arguments forwarded to ``StratifiedKFold``.

    Returns:
        A list of ``[train_indices, test_indices]`` pairs with all indices of
        dropped (too small) classes removed.
    """
    val, counts = np.unique(y, return_counts=True)
    too_low = val[counts < n_splits]
    too_low_counts = counts[counts < n_splits]

    y_dummy = pd.Series(y.copy())
    X_dummy = pd.DataFrame(x.copy())
    org_index_max = len(X_dummy)
    # Set (not list) for O(1) membership checks in the filters below.
    invalid_index = set()
    for c_val, c_count in zip(too_low, too_low_counts, strict=True):
        fill_missing = n_splits - c_count
        # All original samples of the too-small class become invalid ...
        invalid_index.update(np.where(y == c_val)[0])
        y_dummy = pd.concat(
            [y_dummy, pd.Series([c_val] * fill_missing)],
            ignore_index=True,
        )
        X_dummy = pd.concat(
            [X_dummy, pd.DataFrame(x).head(fill_missing)],
            ignore_index=True,
        )
    # ... as do all padded rows appended above.
    invalid_index.update(range(org_index_max, len(y_dummy)))

    splits = _save_stratified_splits(
        _splitter=StratifiedKFold(**spliter_kwargs),
        x=X_dummy,
        y=y_dummy,
        n_splits=n_splits,
    )
    # Drop the invalid indices from every split in place.
    for split in splits:
        train_index, test_index = split
        split[0] = [index for index in train_index if index not in invalid_index]
        split[1] = [index for index in test_index if index not in invalid_index]
    return splits
def assert_valid_splits(
    splits: list[list[list[int], list[int]]],
    y: np.ndarray,
    *,
    non_empty: bool = True,
    each_selected_class_in_each_split_subset: bool = True,
    same_length_training_splits: bool = False,
) -> None:
    """Verify that the splits are valid."""
    if non_empty:
        assert len(splits) != 0, "No splits generated!"
        for split in splits:
            assert len(split) != 0, "Some split is empty!"
            assert len(split[0]) != 0, "A train subset of a split is empty!"
            assert len(split[1]) != 0, "A test subset of a split is empty!"

    if each_selected_class_in_each_split_subset:
        # Classes may have been dropped, so collect the classes that actually
        # appear across the test subsets first.
        selected_classes = {
            cls for split in splits for cls in np.unique(y[split[1]])
        }
        # Every split subset must contain exactly that set of classes.
        for split in splits:
            train_classes = set(np.unique(y[split[0]]))
            test_classes = set(np.unique(y[split[1]]))
            assert train_classes == selected_classes, "A class is missing in a train subset!"
            assert test_classes == selected_classes, "A class is missing in a test subset!"

    if same_length_training_splits:
        for split in splits:
            assert len(split[0]) == len(
                splits[0][0],
            ), "A train split has different amount of samples!"
def _equalize_training_splits(
input_splits: list[list[list[int], list[int]]],
rng: np.random.RandomState,
) -> list[list[list[int], list[int]]]:
"""Equalize training splits by duplicating samples in too small splits."""
splits = input_splits[:]
n_max_train_samples = max(len(split[0]) for split in splits)
for split in splits:
curr_split_len = len(split[0])
if curr_split_len < n_max_train_samples:
missing_samples = n_max_train_samples - curr_split_len
split[0].extend(
[int(dup_i) for dup_i in rng.choice(split[0], size=missing_samples)],
)
split[0] = sorted(split[0])
return splits
def get_cv_split_for_data(
    x: np.ndarray,
    y: np.ndarray,
    splits_seed: int,
    n_splits: int,
    *,
    stratified_split: bool,
    safety_shuffle: bool = True,
    auto_fix_stratified_splits: bool = False,
    force_same_length_training_splits: bool = False,
) -> list[list[list[int], list[int]]] | str:
    """Safety shuffle and generate (safe) splits.

    If it returns str at the first entry, no valid split could be generated and the str is the reason why.

    Note: the function does not support repeated splits at this point.
    Simply call this function multiple times with different seeds to get repeated splits.

    Test with:
    ```python
    if __name__ == "__main__":
        print(
            get_cv_split_for_data(
                x=np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).T,
                y=np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4]),
                splits_seed=42,
                n_splits=3,
                stratified_split=True,
                auto_fix_stratified_splits=True,
            )
        )
    ```

    Args:
        x: The data to split.
        y: The labels to split.
        splits_seed: The seed to use for the splits.
        n_splits: The number of splits to generate.
        stratified_split: Whether to use stratified splits.
        safety_shuffle: Whether to shuffle the data before splitting.
        auto_fix_stratified_splits: Whether to try to fix stratified splits automatically.
            Fix by dropping classes with less than n_splits samples.
        force_same_length_training_splits: Whether to force the training splits to have the same amount of samples.
            Force by duplicating random instance in the training subset of a too small split until all training splits have the same amount of samples.

    Out:
        A list of pairs of indexes, where in each pair first come the train examples, then test. So we get something like
        `[[TRAIN_INDICES_0, TEST_INDICES_0], [TRAIN_INDICES_1, TEST_INDICES_1]]` for 2 splits.
        Or a string if no valid split could be generated whereby the string gives the reason.
    """
    assert len(x) == len(y), "x and y must have the same length!"

    rng = np.random.RandomState(splits_seed)
    if safety_shuffle:
        p = rng.permutation(len(x))
        x, y = x[p], y[p]

    spliter_kwargs = {"n_splits": n_splits, "shuffle": True, "random_state": rng}

    if not stratified_split:
        # Convert sklearn's index arrays to plain int lists so this path is
        # consistent with the stratified path and so the list-based
        # _equalize_training_splits below works on these splits as well.
        splits = [
            [train_index.tolist(), test_index.tolist()]
            for train_index, test_index in KFold(**spliter_kwargs).split(x, y)
        ]
    else:
        try:
            splits = _save_stratified_splits(
                _splitter=StratifiedKFold(**spliter_kwargs),
                x=x,
                y=y,
                n_splits=n_splits,
                auto_fix_stratified_splits=auto_fix_stratified_splits,
            )
            assert_valid_splits(
                splits=splits,
                y=y,
                non_empty=True,
                same_length_training_splits=force_same_length_training_splits,
                each_selected_class_in_each_split_subset=stratified_split,
            )
        except (AssertionError, UnableToFixTooManySplitsError) as e:
            # Log the original failure once; both recovery branches need it.
            logger.debug(e)
            if auto_fix_stratified_splits:
                logger.warning("Splits are not valid. Trying to fix stratified splits automatically...")
                splits = fix_split_by_dropping_classes(
                    x=x,
                    y=y,
                    n_splits=n_splits,
                    spliter_kwargs=spliter_kwargs,
                )
            elif isinstance(e, UnableToFixTooManySplitsError):
                splits = "Cannot generate valid stratified splits for dataset due to not enough samples per class!"
            else:
                splits = (
                    "Cannot generate valid stratified splits for dataset without losing classes in some subsets!"
                )

    if isinstance(splits, str):
        return splits

    if force_same_length_training_splits:
        splits = _equalize_training_splits(splits, rng)

    assert_valid_splits(
        splits=splits,
        y=y,
        non_empty=True,
        same_length_training_splits=force_same_length_training_splits,
        each_selected_class_in_each_split_subset=stratified_split,
    )

    if safety_shuffle:
        # Revert to correct outer scope indices
        for split in splits:
            split[0] = sorted(p[split[0]])
            split[1] = sorted(p[split[1]])
    return splits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment