@ma7555
Last active June 19, 2020 14:02
Create stratified train/test/validation splits for a pandas dataframe
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit


def get_stf_ttv(data, targets, train_size=0.8, random_state=555):
    '''
    Get stratified train/test/validation splits.

    The rows left over after the train split are divided equally between
    test and validation. With train_size=0.8 this gives 80% train,
    10% test and 10% validation.

    Parameters:
        data (pd.DataFrame): rows to split
        targets (pd.Series): class labels used for stratification, aligned with data
        train_size (float): fraction of rows assigned to the train split
        random_state (int): seed for reproducible shuffling

    Returns:
        train_index (np.array): positional indices, for use with .iloc
        test_index (np.array): positional indices, for use with .iloc
        val_index (np.array): positional indices, for use with .iloc
    '''
    # First split: train vs. the combined test + validation pool
    sss = StratifiedShuffleSplit(n_splits=1, train_size=train_size, random_state=random_state)
    train_index, test_valid_index = next(sss.split(data, targets))

    # Second split: halve the pool, stratifying on the pool's own labels
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=random_state)
    test_pos, val_pos = next(sss.split(data.iloc[test_valid_index], targets.iloc[test_valid_index]))

    # Map positions within the pool back to positions in the full data,
    # so all three results are positional indices like train_index
    test_index = test_valid_index[test_pos]
    val_index = test_valid_index[val_pos]
    return train_index, test_index, val_index


# data and targets are an existing DataFrame and its label Series
train_index, test_index, val_index = get_stf_ttv(data, targets, random_state=555)
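
For comparison, here is a minimal self-contained sketch of the same two-stage idea on toy data. It is not the gist's function. It swaps in scikit-learn's `train_test_split` with `stratify=` in place of `StratifiedShuffleSplit`, and the DataFrame, the 70/30 class ratio, and the variable names are assumptions made up for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 100 rows, two classes at a 70/30 ratio
data = pd.DataFrame({"feature": np.arange(100)})
targets = pd.Series([0] * 70 + [1] * 30)

# Work with row positions so the results can be used with .iloc
positions = np.arange(len(data))

# Stage 1: 80% train, 20% held out, preserving the class ratio
train_pos, rest_pos = train_test_split(
    positions, train_size=0.8, stratify=targets, random_state=555
)

# Stage 2: halve the held-out rows into test and validation,
# stratifying on the labels of the held-out rows only
test_pos, val_pos = train_test_split(
    rest_pos, test_size=0.5, stratify=targets.iloc[rest_pos], random_state=555
)

train_df = data.iloc[train_pos]
test_df = data.iloc[test_pos]
val_df = data.iloc[val_pos]
print(len(train_df), len(test_df), len(val_df))
```

Because every result is a row position, `.iloc` selects the right rows even when the DataFrame index is not a plain 0..n-1 range.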