Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active October 16, 2022 02:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wassname/f34321d4797a356a82802bdfb935e6cd to your computer and use it in GitHub Desktop.
Save wassname/f34321d4797a356a82802bdfb935e6cd to your computer and use it in GitHub Desktop.
split stratify pandas by unique
"""
If you want to split and sample at the same time, use something else.
But with time series you sometimes want to split by time first, then resample each split to get balanced class weights.
@url:https://gist.github.com/wassname/f34321d4797a356a82802bdfb935e6cd/edit
@author:wassname
@lic: meh
"""
def rebalance(d: "pd.DataFrame", y_col='y_cls', replace=True, random_state=None):
    """Resample the rows of ``d`` so every label in ``y_col`` is equally likely.

    Rows are drawn with replacement, each weighted by the inverse of its
    class count, so every class has the same expected share of the result.
    The class proportions of the result are printed before returning.

    Args:
        d: Frame to resample.
        y_col: Name of the column holding the class labels.
        replace: Selects the size of the result (rows are always drawn with
            replacement, whatever this value is):
            ``1`` / ``True`` -- same number of rows as ``d``;
            greater than 1 -- largest class count times the number of
            classes (keep the biggest class, oversample the smaller ones);
            ``0`` / ``False`` -- smallest class count times the number of
            classes (undersample the larger classes).
        random_state: Forwarded to ``DataFrame.sample`` for reproducible draws.

    Returns:
        A new frame of sampled rows. An empty frame is returned when ``d``
        has no labelled rows.
    """
    # The annotation above is a string so that defining this function does
    # not require the name `pd` to exist at definition time.
    labels = d[y_col]
    counts = labels.value_counts()
    # categorical labels also report categories that never occur
    counts = counts[counts > 0]
    if counts.empty:
        # nothing to weight by; sample() would reject all-zero weights
        return d.iloc[0:0].copy()

    n_classes = len(counts)
    if replace == 1:
        # keep the frame the same size: some over- and some undersampling
        # to get an equal proportion of each label
        n = len(d)
    elif replace > 1:
        # keep all of the biggest class, oversample the smaller ones
        n = int(counts.max()) * n_classes
    else:
        # undersample the larger classes down to the smallest one
        n = int(counts.min()) * n_classes

    # Per-row weight = 1 / (size of the row's class); sample() normalises
    # the weights itself. Unlabelled rows map to NaN, which sample() treats
    # as weight zero.
    weights = 1.0 / labels.map(counts).astype(float)
    # Weights are passed positionally so a duplicated index cannot break
    # the label alignment that sample() performs on Series weights.
    d2 = d.sample(
        n=n,
        weights=weights.to_numpy(),
        replace=True,
        random_state=random_state,
    )
    print(d2[y_col].value_counts(normalize=True).sort_index())
    return d2
"""
An example of stratified split by unique id
For example, a patient id `individualID` may repeat across many rows, and you want to make sure each id appears in only one of train, valid, or test.
You also want to make sure the train and test have the same proportion from each `Country`.
url: https://gist.github.com/wassname/f34321d4797a356a82802bdfb935e6cd
author: wassname
lic: meh
"""
from sklearn.model_selection import train_test_split
def split_by_unique_col(df_train, col='individualID', stratify_cols=('Country',), *,
                        random_state=None, holdout_size=0.4, test_share=0.5):
    """Split a frame into train/valid/test so that no id is shared between splits.

    The unique values of ``col`` are split (not the rows), stratified on
    ``stratify_cols``; each row then follows its id. With the default sizes
    this gives roughly 60% / 20% / 20% of the ids. Rows whose ``col`` is
    missing belong to no id group and therefore end up in none of the splits.

    Args:
        df_train: Frame containing ``col`` and every column in ``stratify_cols``.
        col: Column holding the id that must not leak across splits.
        stratify_cols: Column name, or sequence of column names, to stratify
            the id split on. An empty sequence disables stratification.
        random_state: Seed passed to ``train_test_split``. When ``None``, a
            module-level ``random_seed`` is used if one is defined (this is
            what earlier versions of the function relied on).
        holdout_size: ``test_size`` of the first split, i.e. the share of
            ids that leave train.
        test_share: ``test_size`` of the second split, i.e. the share of the
            held-out ids that go to test (the rest go to valid).

    Returns:
        Tuple ``(train, valid, test)`` of row subsets of ``df_train``.
    """
    if isinstance(stratify_cols, str):
        stratify_cols = [stratify_cols]
    else:
        # copy into a list so `[col] + ...` works for tuples too
        stratify_cols = list(stratify_cols)

    if random_state is None:
        # backward compatibility with the former implicit global
        random_state = globals().get('random_seed')

    # One row per unique id, carrying the first non-null value of each
    # stratification column for that id.
    df_ids = df_train[[col] + stratify_cols].groupby(col).first()

    def _strata(ids):
        # train_test_split expects None when no stratification is wanted
        return ids[stratify_cols] if stratify_cols else None

    # split up the unique ids, stratifying
    df_ids_train, df_ids_other = train_test_split(
        df_ids,
        test_size=holdout_size,
        random_state=random_state,
        stratify=_strata(df_ids),
    )
    df_ids_valid, df_ids_test = train_test_split(
        df_ids_other,
        test_size=test_share,
        random_state=random_state,
        stratify=_strata(df_ids_other),
    )

    # route every row to the split that owns its id
    row_ids = df_train[col]
    train = df_train[row_ids.isin(df_ids_train.index)]
    valid = df_train[row_ids.isin(df_ids_valid.index)]
    test = df_train[row_ids.isin(df_ids_test.index)]

    # make sure there is no overlap
    assert set(train[col]).isdisjoint(test[col])
    assert set(train[col]).isdisjoint(valid[col])
    assert set(test[col]).isdisjoint(valid[col])
    return train, valid, test
train, valid, test = split_by_unique_col(df_train)
# Eyeball check for stratification: the Country proportions should look
# alike in all three splits.
print(train.Country.value_counts(normalize=True))
print(valid.Country.value_counts(normalize=True))
# Printed explicitly: a bare expression is only displayed when it is the
# last line of a notebook cell and is silently discarded in a script.
print(test.Country.value_counts(normalize=True))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment