Skip to content

Instantly share code, notes, and snippets.

@EnkrateiaLucca
Created September 22, 2023 15:07
Show Gist options
  • Save EnkrateiaLucca/4eb217a4e5806b08517b1e655f3c233b to your computer and use it in GitHub Desktop.
Save EnkrateiaLucca/4eb217a4e5806b08517b1e655f3c233b to your computer and use it in GitHub Desktop.
Splits a dataset into train, validation, and test sets.
def get_dataset_partitions_pd(df, train_split=0.8, val_split=0.1, test_split=0.1,
                              target_variable=None, *, random_state=12):
    """Split the rows of ``df`` into train, validation and test partitions.

    Rows are shuffled once with ``df.sample(frac=1, random_state=random_state)``,
    so the same seed always yields the same partitions. Partition boundaries are
    computed with ``int()`` truncation: the first cut is at
    ``int(train_split * n)`` and the second at ``int((1 - test_split) * n)``.
    Rounding leftovers therefore land in the later partitions.

    Args:
        df: DataFrame to partition. It is not modified.
        train_split: Fraction of rows for the training partition.
        val_split: Fraction of rows for the validation partition.
        test_split: Fraction of rows for the test partition.
        target_variable: Optional grouping key (anything accepted by
            ``DataFrame.groupby``, typically a column label). When given, each
            group is split separately with the same fractions, producing
            stratified partitions. Rows whose key is missing (NaN) form their
            own group rather than being dropped.
        random_state: Seed passed to ``DataFrame.sample`` for the shuffle.

    Returns:
        Tuple ``(train_index, val_index, test_index)`` of ``pandas.Index``
        objects holding labels from ``df.index``. In the stratified case the
        labels are concatenated group by group, in ``groupby`` key order.

    Raises:
        ValueError: If any fraction is negative or the fractions do not sum
            to 1 (within floating-point tolerance).
    """
    import math

    splits = (train_split, val_split, test_split)
    # Explicit checks instead of `assert` (which `python -O` strips); the sum
    # is compared with a tolerance because e.g. 0.6 + 0.3 + 0.1 != 1.0 exactly.
    if min(splits) < 0:
        raise ValueError(f"split fractions must be non-negative, got {splits}")
    if not math.isclose(sum(splits), 1.0):
        raise ValueError(
            f"split fractions must sum to 1, got {splits} (sum={sum(splits)})"
        )

    def _cut(index):
        # Slice one already-shuffled Index into its (train, val, test) parts.
        n = len(index)
        train_end = int(train_split * n)
        # Everything before the second cut is train + val, i.e. 1 - test_split.
        val_end = int((1 - test_split) * n)
        return index[:train_end], index[train_end:val_end], index[val_end:]

    # Fixed seed => identical shuffle (and hence identical partitions) per run.
    df_sample = df.sample(frac=1, random_state=random_state)

    if target_variable is None:
        return _cut(df_sample.index)

    # groupby preserves the shuffled row order inside each group; dropna=False
    # keeps rows with a missing target instead of silently discarding them.
    pieces = [
        _cut(group.index)
        for _, group in df_sample.groupby(target_variable, dropna=False)
    ]
    # Start from an empty Index of the right dtype so that a frame with no
    # groups (e.g. an empty DataFrame) yields three empty partitions.
    empty = df_sample.index[:0]
    return tuple(empty.append([piece[k] for piece in pieces]) for k in range(3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment