Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
split_by_unique_col
from sklearn.model_selection import train_test_split
import pandas as pd
def shuffle_df(df, random_seed=42):
    """Return every row of ``df`` in a reproducible random order.

    Drawing the whole frame (``frac=1``) without replacement yields a
    permutation of the rows; ``random_seed`` fixes which permutation.
    """
    permuted = df.sample(replace=False, random_state=random_seed, frac=1)
    return permuted
def split_by_unique_col(df, col='patient_id', stratify_cols=None, random_seed=42,
                        valid_size=0.2, test_size=0.2):
    """
    Split ``df`` into train/valid/test so that no ``col`` value spans two splits.

    The unique values of ``col`` are split (optionally stratified), and every
    row of ``df`` then follows its id into the matching split. Each split is
    shuffled before being returned.

    url: https://gist.github.com/wassname/13af904e117fdec775446fedb559c57d

    Args:
        df: DataFrame containing ``col`` and every column in ``stratify_cols``.
        col: Name of the grouping column (e.g. a patient id).
        stratify_cols: Column name, or sequence of column names, to stratify
            the id split on. The first non-null value per id is used.
            ``None`` or empty disables stratification.
        random_seed: Seed used for both id splits and for the final shuffles.
        valid_size: Fraction of unique ids assigned to the validation split.
        test_size: Fraction of unique ids assigned to the test split.

    Returns:
        Tuple ``(train, valid, test)`` of DataFrames holding rows of ``df``.

    Raises:
        ValueError: If ``valid_size`` and ``test_size`` are not both positive
            fractions that sum to less than 1.
    """
    # Normalise stratify_cols to a fresh list: avoids a shared mutable default
    # and lets callers pass a single name or a tuple.
    if stratify_cols is None:
        stratify_cols = []
    elif isinstance(stratify_cols, str):
        stratify_cols = [stratify_cols]
    else:
        stratify_cols = list(stratify_cols)

    if not (valid_size > 0 and test_size > 0 and valid_size + test_size < 1):
        raise ValueError(
            "valid_size and test_size must be positive fractions summing to "
            "less than 1, got valid_size=%r, test_size=%r" % (valid_size, test_size)
        )

    def _strata(ids):
        # train_test_split expects None (not an empty frame) for "no stratification".
        return ids[stratify_cols] if stratify_cols else None

    # One row per unique id (as the index), carrying its stratification values.
    df_ids = df[[col] + stratify_cols].groupby(col).first()

    # First carve off the combined valid+test ids, then divide that holdout.
    # With the defaults this is 0.4 followed by 0.5, i.e. a 60/20/20 split.
    holdout_size = valid_size + test_size
    df_ids_train, df_ids_other = train_test_split(
        df_ids,
        test_size=holdout_size,
        random_state=random_seed,
        stratify=_strata(df_ids),
    )
    df_ids_valid, df_ids_test = train_test_split(
        df_ids_other,
        test_size=test_size / holdout_size,
        random_state=random_seed,
        stratify=_strata(df_ids_other),
    )

    # Route every original row to the split that owns its id.
    train = df[df[col].isin(df_ids_train.index)]
    valid = df[df[col].isin(df_ids_valid.index)]
    test = df[df[col].isin(df_ids_test.index)]

    # Internal sanity check: no id may appear in more than one split.
    train_ids, valid_ids, test_ids = set(train[col]), set(valid[col]), set(test[col])
    assert train_ids.isdisjoint(test_ids)
    assert train_ids.isdisjoint(valid_ids)
    assert test_ids.isdisjoint(valid_ids)

    train = shuffle_df(train, random_seed=random_seed)
    test = shuffle_df(test, random_seed=random_seed)
    valid = shuffle_df(valid, random_seed=random_seed)
    return train, valid, test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment