Skip to content

Instantly share code, notes, and snippets.

@duttashi
Created February 18, 2021 07:53
Show Gist options
  • Save duttashi/1f254a58b9a24d85bf795df43de00f29 to your computer and use it in GitHub Desktop.
Save duttashi/1f254a58b9a24d85bf795df43de00f29 to your computer and use it in GitHub Desktop.
Function to split a dataset into three sets (train, validate, test)
# create a custom function to split data into 3 sets
import numpy as np
import pandas as pd
def train_validate_test_split(df, train_percent=.6, validate_percent=.2, seed=None):
    """Randomly partition the rows of ``df`` into train, validate and test sets.

    The rows are shuffled once; the first ``train_percent`` of them become the
    training set, the next ``validate_percent`` the validation set, and the
    remainder the test set. Split sizes are truncated to whole rows, so any
    rounding remainder ends up in the test set.

    Args:
        df: The pandas DataFrame to split. Any index is supported (labels
            need not be integers, sorted, or unique).
        train_percent: Fraction of rows (0..1) for the training set.
        validate_percent: Fraction of rows (0..1) for the validation set.
        seed: Seed for the shuffle; anything ``numpy.random.RandomState``
            accepts. ``None`` gives a different split on every call.

    Returns:
        A ``(train, validate, test)`` tuple of DataFrames.

    Raises:
        ValueError: If a fraction is negative or the two fractions sum to
            more than 1.
    """
    # Small tolerance so float sums such as 0.7 + 0.3 are not rejected.
    if (train_percent < 0 or validate_percent < 0
            or train_percent + validate_percent > 1 + 1e-9):
        raise ValueError(
            "train_percent and validate_percent must be non-negative "
            "and sum to at most 1"
        )

    n_rows = len(df.index)

    # Use a private generator so the caller's global NumPy random state is
    # left untouched; for a given seed it yields the same permutation that
    # seeding the global generator would.
    rng = np.random.RandomState(seed)

    # Shuffle row *positions*, not index labels: ``iloc`` is positional, so
    # permuting ``df.index`` only works for a default 0..n-1 index.
    perm = rng.permutation(n_rows)

    train_end = int(train_percent * n_rows)
    validate_end = int(validate_percent * n_rows) + train_end

    train = df.iloc[perm[:train_end]]
    validate = df.iloc[perm[train_end:validate_end]]
    test = df.iloc[perm[validate_end:]]
    return train, validate, test
# usage
if __name__ == "__main__":
    # Guarded so that importing this file for the helper does not run the demo.
    # Seed the global generator so the random demo frame is reproducible.
    np.random.seed([3, 1415])
    df = pd.DataFrame(np.random.rand(10, 5), columns=list('ABCDE'))
    print(df)
    # The split takes its randomness from its own ``seed`` argument, so the
    # global seed above does not make it reproducible; pass one explicitly.
    train, validate, test = train_validate_test_split(df, seed=1415)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment