Skip to content

Instantly share code, notes, and snippets.

@seanie12
Created April 15, 2019 09:15
Show Gist options
  • Save seanie12/aed0b201826c500c22aa73c69d9ea0ef to your computer and use it in GitHub Desktop.
Save seanie12/aed0b201826c500c22aa73c69d9ea0ef to your computer and use it in GitHub Desktop.
import numpy as np
import pickle
train_x_file = "./data/train_x"
train_y_file = "./data/train_y"
dev_x_file = "./data/dev_x"
dev_y_file = "./data/dev_y"
input_file = "./data/input.pkl"
label_file = "./data/label.pkl"
# Seed for the train/dev shuffle. None draws fresh entropy each run (the
# original behaviour); set an int to get a reproducible split.
split_seed = None


def split_train_dev(x, y, train_ratio=0.8, seed=None):
    """Shuffle paired inputs/labels and split them into train and dev sets.

    Args:
        x: Inputs whose first axis enumerates examples. Must support
            ``len()`` and indexing by an integer index array (e.g. a NumPy
            array); a plain Python list does not.
        y: Labels aligned with ``x`` along the first axis; same requirements.
        train_ratio: Fraction of examples assigned to the train set. The
            count is rounded down, and the remainder goes to dev.
        seed: Optional seed for the shuffle. ``None`` gives a different
            split on every call.

    Returns:
        A tuple ``(train_x, train_y, dev_x, dev_y)``.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length, or ``train_ratio``
            is outside ``[0, 1]``.
    """
    num_data = len(x)
    if len(y) != num_data:
        # Without this check a length mismatch would silently mis-pair
        # inputs with labels (or truncate one side) after shuffling.
        raise ValueError(
            f"x and y must have the same length, got {num_data} and {len(y)}"
        )
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    num_train = int(num_data * train_ratio)
    # A single permutation is applied to both x and y so every input stays
    # paired with its label.
    shuffled_idx = np.random.default_rng(seed).permutation(num_data)
    shuffled_x = x[shuffled_idx]
    shuffled_y = y[shuffled_idx]

    return (
        shuffled_x[:num_train],
        shuffled_y[:num_train],
        shuffled_x[num_train:],
        shuffled_y[num_train:],
    )


def main():
    """Load the pickled inputs/labels, split them 80/20, and save with np.save."""
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # files from a trusted source.
    with open(input_file, "rb") as f:
        x = pickle.load(f)
    with open(label_file, "rb") as f:
        y = pickle.load(f)

    train_x, train_y, dev_x, dev_y = split_train_dev(x, y, seed=split_seed)

    # np.save appends ".npy" to these extension-less paths.
    np.save(train_x_file, train_x)
    np.save(train_y_file, train_y)
    np.save(dev_x_file, dev_x)
    np.save(dev_y_file, dev_y)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment