@MattMcMurray
Created November 9, 2017 21:55
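import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# appt_data is assumed to be a pandas DataFrame loaded earlier, with a numeric 'Age' column
# Let's create an age category: ages 1-10 -> 1, 11-20 -> 2, ..., 91-100 -> 10
age_cat = np.ceil(appt_data['Age'] / 10)
# Let's group anybody over 100 years old into the top (91-100) category, as they are outliers.
# age_cat counts decades, so the cap is 10 rather than 100
age_cat = age_cat.clip(upper=10)
appt_data['AgeCategory'] = age_cat
# Create a test set that is 20% of all rows, with the same AgeCategory proportions as the full data
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(appt_data, appt_data['AgeCategory']):
    # split() yields positional indices, so select rows with iloc rather than loc
    strat_train_set = appt_data.iloc[train_index]
    strat_test_set = appt_data.iloc[test_index]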
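Because the age category is `np.ceil(age / 10)`, it counts decades rather than years. This means the "over 100" cutoff has to be written in decades: anyone older than 100 gets a category of 11 or more. The small check below runs on a handful of made-up ages that sit on the bin edges. It is a sketch only. The helper `age_category` and the sample ages are hypothetical and simply mirror the binning used above.

```python
import numpy as np
import pandas as pd


def age_category(ages: pd.Series, cap: int = 10) -> pd.Series:
    """Hypothetical helper: bin ages into decades (1-10 -> 1, ..., 91-100 -> 10).

    Ages above 100 would produce 11, 12, ...; clip() folds them into `cap`.
    An age of exactly 0 ends up in its own category 0.
    """
    return np.ceil(ages / 10).clip(upper=cap)


# Made-up sample ages, chosen to hit the bin edges
ages = pd.Series([0, 5, 10, 11, 55, 99, 100, 101, 115])
print(age_category(ages).tolist())
# [0.0, 1.0, 1.0, 2.0, 6.0, 10.0, 10.0, 10.0, 10.0]

# Capping at 100 instead of 10 leaves the over-100 ages in categories 11 and 12
uncapped = np.ceil(ages / 10)
print(uncapped.where(uncapped < 100, 100).tolist())
# [0.0, 1.0, 1.0, 2.0, 6.0, 10.0, 10.0, 11.0, 12.0]
```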
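To confirm that a stratified split did its job, compare each age category's share of the test set with its share of the full data. The sketch below does that on synthetic data. The 5,000 random ages, the `df` frame standing in for `appt_data`, and the `proportions` helper are all hypothetical. It calls `next()` on the splitter instead of looping, because `n_splits=1` yields exactly one pair of index arrays. Those arrays are positions, so rows are selected with `iloc`.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical stand-in for appt_data: 5,000 synthetic ages between 0 and 115
rng = np.random.default_rng(0)
df = pd.DataFrame({"Age": rng.integers(0, 116, size=5000)})
df["AgeCategory"] = np.ceil(df["Age"] / 10).clip(upper=10)

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# With n_splits=1 the generator yields a single (train, test) pair of positional indices
train_index, test_index = next(split.split(df, df["AgeCategory"]))
strat_train_set = df.iloc[train_index]
strat_test_set = df.iloc[test_index]


def proportions(frame: pd.DataFrame) -> pd.Series:
    """Share of rows in each AgeCategory, sorted by category."""
    return frame["AgeCategory"].value_counts(normalize=True).sort_index()


comparison = pd.DataFrame(
    {"overall": proportions(df), "test": proportions(strat_test_set)}
)
comparison["abs_diff"] = (comparison["overall"] - comparison["test"]).abs()
print(len(strat_train_set), len(strat_test_set))  # 4000 1000
print(comparison)
```

`StratifiedShuffleSplit` allocates test rows category by category. As a result, each category's row count in the test set is within one row of its exact 20% share, so the `abs_diff` column stays well under one percentage point here.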