Last active
October 16, 2022 02:06
-
-
Save wassname/f34321d4797a356a82802bdfb935e6cd to your computer and use it in GitHub Desktop.
split stratify pandas by unique
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
If you want to split and sample at the same time use something else. | |
but in timeseries sometimes you want to split by time, then resample to get balanced weights | |
@url:https://gist.github.com/wassname/f34321d4797a356a82802bdfb935e6cd/edit | |
@author:wassname | |
@lic: meh | |
""" | |
# The annotation is a string so that defining this function does not require
# pandas to already be bound to the name `pd` at import time.
def rebalance(d: "pd.DataFrame", y_col='y_cls', replace=True):
    """Resample the rows of ``d`` so that every label in ``y_col`` is equally likely.

    Each row is weighted by ``1 / (number of rows sharing its label)``, which
    gives every class the same total weight. Rows are then drawn *with
    replacement* using those weights, so the result is balanced in
    expectation rather than exactly. Rows whose label is missing get weight
    zero and are never drawn.

    Args:
        d: DataFrame to resample.
        y_col: Name of the column holding the class labels.
        replace: Selects the size of the result (rows are always drawn with
            replacement, whatever this value is):

            * ``True`` / ``1``: same number of rows as ``d`` (some classes are
              over-sampled and some under-sampled).
            * greater than ``1``: ``largest class count * number of classes``
              (over-sample the smaller classes).
            * anything else (``False`` / ``0``): ``smallest class count *
              number of classes`` (under-sample the larger classes).

    Returns:
        A new DataFrame of sampled rows. Index labels come from ``d`` and may
        repeat. If ``d`` has no labelled rows an empty copy is returned.

    Side effects:
        Prints the class proportions of the resampled frame.
    """
    labels = d[y_col]

    # Work from raw class counts; deriving sizes from normalised float weights
    # loses rows to int() truncation and is wrong for more than two classes.
    counts = labels.value_counts()
    counts = counts[counts > 0]  # ignore unused categories of a categorical label
    if counts.empty:
        # Nothing to sample from (empty frame or every label missing).
        return d.iloc[0:0].copy()

    # Per-row weight = 1 / class count. Passed to sample() as a plain array so
    # pandas does not try to re-align on the index, which fails when the index
    # has duplicate labels (e.g. a frame that was already resampled).
    weights = labels.map(1.0 / counts).astype(float).fillna(0.0).to_numpy()

    n_classes = len(counts)
    if replace == 1:
        # keep the frame the same size
        n = len(d)
    elif replace > 1:
        # size for "all of the biggest class, smaller classes over-sampled"
        n = int(counts.max()) * n_classes
    else:
        # size for "all of the smallest class, larger classes under-sampled"
        n = int(counts.min()) * n_classes

    d2 = d.sample(n=n, weights=weights, replace=True)
    print(d2[y_col].value_counts(normalize=True).sort_index())
    return d2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
An example of stratified split by unique id | |
For example there might be a repeat patient id `individualID`, and you want to make sure it only appears in the train and not the test. | |
You also want to make sure the train and test have the same proportion from each `Country`. | |
url: https://gist.github.com/wassname/f34321d4797a356a82802bdfb935e6cd | |
author: wassname | |
lic: meh | |
""" | |
from sklearn.model_selection import train_test_split | |
def split_by_unique_col(df_train, col='individualID', stratify_cols=('Country',), *,
                        valid_size=0.2, test_size=0.2, random_state=None):
    """Split rows into train/valid/test so that no ``col`` value spans two splits.

    The unique values of ``col`` are split (not the rows), stratified on
    ``stratify_cols``; every row then follows its id into exactly one split.
    The stratification value of an id is the first non-null value seen for it.

    Args:
        df_train: DataFrame containing ``col`` and all of ``stratify_cols``.
        col: Column holding the id that must not be shared between splits.
        stratify_cols: Column name, or sequence of column names, to stratify on.
        valid_size: Fraction of the unique ids assigned to the validation split.
        test_size: Fraction of the unique ids assigned to the test split.
        random_state: Seed passed to ``train_test_split``. When ``None``, a
            module-level ``random_seed`` is used if one is defined (this is
            what earlier versions of this function always read); otherwise
            the split is not seeded.

    Returns:
        ``(train, valid, test)`` DataFrames, each a row subset of ``df_train``.

    Raises:
        ValueError: if the fractions are not positive or leave no ids for train.
    """
    if isinstance(stratify_cols, str):
        stratify_cols = [stratify_cols]
    else:
        stratify_cols = list(stratify_cols)

    holdout_size = valid_size + test_size
    if not (valid_size > 0 and test_size > 0 and holdout_size < 1):
        raise ValueError(
            "valid_size and test_size must be positive fractions summing to less than 1, "
            f"got valid_size={valid_size!r}, test_size={test_size!r}"
        )

    if random_state is None:
        # Backward compatibility with callers that define a global `random_seed`.
        random_state = globals().get('random_seed')

    # Make a dataframe of unique ids, with our stratification data.
    # NOTE(review): groupby drops rows whose `col` is missing by default, so
    # such rows presumably end up in none of the splits - confirm this is wanted.
    df_ids = df_train[[col] + stratify_cols].groupby(col).first()

    # First carve off the combined valid+test ids, then divide that holdout.
    df_ids_train, df_ids_other = train_test_split(
        df_ids,
        test_size=holdout_size,
        random_state=random_state,
        stratify=df_ids[stratify_cols],
    )
    df_ids_vals, df_ids_test = train_test_split(
        df_ids_other,
        test_size=test_size / holdout_size,
        random_state=random_state,
        stratify=df_ids_other[stratify_cols],
    )

    # Map the id-level split back onto the rows.
    train = df_train[df_train[col].isin(df_ids_train.index)]
    valid = df_train[df_train[col].isin(df_ids_vals.index)]
    test = df_train[df_train[col].isin(df_ids_test.index)]

    # Sanity check: no id may appear in more than one split.
    assert not set(train[col]).intersection(set(test[col]))
    assert not set(train[col]).intersection(set(valid[col]))
    assert not set(test[col]).intersection(set(valid[col]))
    return train, valid, test
# Example usage; `df_train` must already be defined by the caller.
train, valid, test = split_by_unique_col(df_train)

# Eyeball check for stratification: the three distributions should look alike.
print(train.Country.value_counts(normalize=True))
print(valid.Country.value_counts(normalize=True))
# Printed explicitly: as a bare expression this line shows nothing outside a notebook.
print(test.Country.value_counts(normalize=True))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment