@AparaV
Last active January 3, 2018 01:58
import numpy as np


def split(train_dataset):
    '''
    Shuffle data and split into 3 datasets
    1. Training - 60%
    2. Validation - 20%
    3. Testing - 20%

    Expects a pandas DataFrame with 'SalePrice' and 'Id' columns.
    '''
    # Shuffle data
    train_dataset = train_dataset.sample(frac=1)

    # Split at the 60% and 80% row boundaries
    n = len(train_dataset)
    train_end, valid_end = int(.6 * n), int(.8 * n)
    train = train_dataset.iloc[:train_end]
    valid = train_dataset.iloc[train_end:valid_end]
    test = train_dataset.iloc[valid_end:]

    # Convert into numpy arrays
    # (to_numpy() replaces as_matrix(), which was removed in pandas 1.0)
    x_train = train.drop(['SalePrice', 'Id'], axis=1).to_numpy().astype(np.float32)
    y_train = train['SalePrice'].to_numpy().astype(np.float32).reshape((-1, 1))
    x_test = test.drop(['SalePrice', 'Id'], axis=1).to_numpy().astype(np.float32)
    y_test = test['SalePrice'].to_numpy().astype(np.float32).reshape((-1, 1))
    x_valid = valid.drop(['SalePrice', 'Id'], axis=1).to_numpy().astype(np.float32)
    y_valid = valid['SalePrice'].to_numpy().astype(np.float32).reshape((-1, 1))

    # Note the return order: test comes before validation
    return x_train, y_train, x_test, y_test, x_valid, y_valid
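
The 60/20/20 arithmetic above can be checked on a small table. The following is a minimal sketch, not part of the original gist: the toy columns `LotArea` and `YearBuilt`, the helper `split_sizes`, and the fixed `random_state` are all assumptions made for illustration. It reproduces the same shuffle, slice, and convert steps on 10 rows.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: 10 rows with the 'Id' and 'SalePrice' columns the
# function expects, plus two made-up feature columns.
toy = pd.DataFrame({
    'Id': np.arange(1, 11),
    'LotArea': np.linspace(5000.0, 14000.0, 10),
    'YearBuilt': np.arange(1990, 2000),
    'SalePrice': np.linspace(100000.0, 190000.0, 10),
})


def split_sizes(n_rows):
    """Return (train, valid, test) row counts for a 60/20/20 split."""
    train_end, valid_end = int(.6 * n_rows), int(.8 * n_rows)
    return train_end, valid_end - train_end, n_rows - valid_end


# random_state is an assumption added here so the shuffle is repeatable
shuffled = toy.sample(frac=1, random_state=0)
n_train, n_valid, n_test = split_sizes(len(shuffled))

train = shuffled.iloc[:n_train]
valid = shuffled.iloc[n_train:n_train + n_valid]
test = shuffled.iloc[n_train + n_valid:]

# Features drop the target and the row identifier; targets become column vectors
x_train = train.drop(['SalePrice', 'Id'], axis=1).to_numpy().astype(np.float32)
y_train = train['SalePrice'].to_numpy().astype(np.float32).reshape((-1, 1))

print(x_train.shape, y_train.shape, len(valid), len(test))
```

Because the cut points are computed with `int(...)`, any rows left over from rounding down end up in the test slice, so the three parts always cover every row exactly once.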