@blu3r4y
Created May 18, 2021
Grouped train test split used by Dynatrace - SAL - LIT.AI.JKU in the NAD 2021 challenge
# Copyright 2021
# Dynatrace Research
# SAL Silicon Austria Labs
# LIT Artificial Intelligence Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from itertools import chain
from typing import Optional, Union

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

log = logging.getLogger(__name__)


def grouped_train_test_split(*arrays,
                             groups: Union[np.ndarray, pd.Series],
                             labels: Union[np.ndarray, pd.Series],
                             test_size: float = 0.25, random_state: Optional[int] = None,
                             max_reshuffle: int = 100, test_size_eps: float = 0.2):
"""
Very similar to `sklearn.model_selection.train_test_split` (even has the same syntax),
but will actually split the `groups`, and remap them, resulting in a split that will always keep
the same groups together in one fold.
Outputs are similar to sklearn, so use like this:
>>> X = np.arange(100) # some feature vector
>>> y = np.random.randint(2, size=100) # two randomly assigned classes [0, 1]
>>> g = np.random.randint(10, size=100) # ten randomly assigned groups
>>>
>>> X_train, X_test, y_train, y_test = grouped_train_test_split(X, y, groups=g, labels=y, test_size=0.25))
:param arrays: Allowed inputs are numpy arrays or pandas series.
:param groups: A one-dimensional array or series of the same length as the inputs that holds group labels,
which will be used to always keep the groups together within each split.
:param labels: A one-dimensional array or series of the same length as the inputs that holds class labels,
which will be used to preserve the class label distribution in the train and test splits.
:param test_size: The size of the test split as number between 0 and 1. (default: 0.25)
:param random_state: Some integer for deterministic sampling. (default: random integer)
:param max_reshuffle: In order to get the desired splitting ratio we will try a few different.
random splits until we get it right (we try multiple times, because only after the re-mapping
we know the true split ratio). This parameter indicates how often we will try that. (default: 100)
:param test_size_eps: The tolerance that we allow between `test_size` and the true test size. (default: 0.2)
"""
    if random_state is None:
        random_state = np.random.randint(1_000_000)

    # shapes must match
    assert all(len(arr) == len(groups) for arr in arrays)
    assert all(len(arr) == len(labels) for arr in arrays)

    groups = pd.Series(groups)
    labels = pd.Series(labels)
    nclasses = labels.nunique()

    # compute the median label within each group (just as a quick'n'dirty majority vote)
    groups_and_labels = pd.DataFrame({"group": groups, "stratum": labels})
    majorities = groups_and_labels.groupby("group").median().astype(int)
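    # (note: for binary labels the median is exactly the majority class; with more
    # than two classes it is only a rough approximation of a majority vote)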

    # this number shall be close to `test_size` eventually
    true_test_size = np.full(nclasses, -1)
    train_mask, test_mask = None, None

    # loop until we stay within the desired bounds
    # or reach the reshuffle limit ...
    nshuffle = 0
    while np.any(np.abs(true_test_size - test_size) > test_size_eps) \
            and nshuffle < max_reshuffle:
        nshuffle += 1

        # first, split the group labels and try to stratify a little bit
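        # (the seed is offset by the reshuffle counter so that every retry sees a
        # different split, while the whole procedure stays deterministic)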
        train_groups, test_groups = train_test_split(majorities.index, test_size=test_size,
                                                     stratify=majorities["stratum"],
                                                     random_state=random_state + nshuffle)

        # remap the masks
        train_mask, test_mask = groups.isin(train_groups), groups.isin(test_groups)

        # compute resulting label distribution
        train_distr = labels[train_mask].value_counts()
        test_distr = labels[test_mask].value_counts()

        # preserve the number of classes in each split
        splits_have_all_classes = len(train_distr) == len(test_distr) == nclasses

        # check the difference in the distributions
        # (as long as we at least got all classes in each split as well)
        if splits_have_all_classes:
            true_test_size = test_distr / (test_distr + train_distr)
        else:
            true_test_size = np.full(nclasses, -1)

        # TODO: we could save the closest split we make instead of the last one here ...

    # warn if we still couldn't achieve the desired split
    if np.any(np.abs(true_test_size - test_size) > test_size_eps):
        log.warning(f"the true test sizes of the grouped split are still {true_test_size} "
                    f"after {max_reshuffle} tries to re-shuffle the groups "
                    f"(you wanted {test_size:.2f} +/- {test_size_eps:.2f})")
    else:
        log.info(f"achieved true test sizes: {true_test_size}")

    return list(chain.from_iterable(
        (arr[train_mask], arr[test_mask])
        for arr in arrays
    ))
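To illustrate the guarantee the function makes, here is a minimal usage sketch (not part of the original gist); since the group array can simply be passed along as one of the positional `arrays`, we can check directly that no group ends up in both folds:

import numpy as np

X = np.arange(100)                   # some feature vector
y = np.random.randint(2, size=100)   # two randomly assigned classes [0, 1]
g = np.random.randint(10, size=100)  # ten randomly assigned groups

X_train, X_test, y_train, y_test, g_train, g_test = grouped_train_test_split(
    X, y, g, groups=g, labels=y, test_size=0.25, random_state=42)

# every group lives entirely in one fold, never in both
assert set(g_train).isdisjoint(set(g_test))
print(f"test fraction: {len(X_test) / len(X):.2f}")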
Requirements:

numpy>=1.19
pandas>=1.2
scikit-learn>=0.23