

@Swarchal
Created December 21, 2016 11:19
train test split
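import random


def train_test_split(data, labels, test_prop=0.3):
    """roll your own train test split

    Returns [x_train, y_train, x_test, y_test] as lists.
    """
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same length")
    n_test = round(test_prop * len(data))
    n_train = len(data) - n_test
    # shuffle (sample, label) pairs together so each label stays with its sample
    combined = list(zip(data, labels))
    random.shuffle(combined)
    # slice from n_train: combined[-n_test:] would return everything when n_test == 0
    train, test = combined[:n_train], combined[n_train:]
    # zip(*[]) yields nothing to unpack, so fall back to empty tuples for an empty split
    x_train, y_train = zip(*train) if train else ((), ())
    x_test, y_test = zip(*test) if test else ((), ())
    return [list(i) for i in [x_train, y_train, x_test, y_test]]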
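The function above shuffles with the global `random` module, so each call produces a different split unless `random.seed` is set beforehand. Below is a minimal sketch of a reproducible variant, assuming you want the same split for the same seed without touching global random state. `train_test_split_seeded` and its `seed` parameter are hypothetical names, not part of the original gist; it shuffles indices instead of zipped pairs, which makes empty splits fall out naturally.

```python
import random


def train_test_split_seeded(data, labels, test_prop=0.3, seed=None):
    """Hypothetical seeded variant: the same seed gives the same split."""
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same length")
    if not 0.0 <= test_prop <= 1.0:
        raise ValueError("test_prop must be between 0 and 1")
    n_test = round(test_prop * len(data))
    n_train = len(data) - n_test
    # a private RNG instance leaves the global random state untouched
    rng = random.Random(seed)
    indices = list(range(len(data)))
    rng.shuffle(indices)
    train_idx, test_idx = indices[:n_train], indices[n_train:]
    # index both sequences with the same permutation so labels stay aligned
    x_train = [data[i] for i in train_idx]
    y_train = [labels[i] for i in train_idx]
    x_test = [data[i] for i in test_idx]
    y_test = [labels[i] for i in test_idx]
    return [x_train, y_train, x_test, y_test]


if __name__ == "__main__":
    xs = list(range(10))
    ys = [x * x for x in xs]
    x_tr, y_tr, x_te, y_te = train_test_split_seeded(xs, ys, test_prop=0.3, seed=42)
    print(len(x_tr), len(x_te))  # 7 3
    # every label still matches its sample after the split
    assert all(y == x * x for x, y in zip(x_tr + x_te, y_tr + y_te))
```

Calling it twice with the same `seed` returns identical splits; passing `seed=None` seeds the private generator from system entropy, so behaviour matches the unseeded original.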