@amir-rahnama
Last active November 5, 2020 09:32
Cross validation with pure numpy
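import numpy as np

np.random.seed(0)

def cross_validation(X, y, cv_size=0.1):
    """Randomly split X and y into a training set and a held-out test set,
    and return the corresponding indices as well.

    This draws a single random hold-out split (one fold), not k folds.
    Samples are assumed to lie along axis 0 of both X and y.

    @param: X data in the form of numpy ndarray
    @param: y labels in the form of numpy ndarray (1-D or 2-D)
    @param: cv_size fraction of the samples to put in the test set
    @return: X_train, y_train, train_idx, X_test, y_test, test_idx
    """
    data_size = X.shape[0]
    # Number of test samples, derived from the requested fraction cv_size
    test_size = int(np.round(data_size * cv_size))
    total_idx = np.arange(data_size)
    # Sample test indices without replacement so no row is used twice
    test_idx = np.random.choice(total_idx, size=test_size, replace=False)
    # Every index not drawn for the test set goes to the training set
    train_idx = np.setdiff1d(total_idx, test_idx)
    # Index along axis 0 only, so 1-D label vectors work as well as 2-D arrays
    return X[train_idx], y[train_idx], train_idx, X[test_idx], y[test_idx], test_idx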
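Strictly speaking, `cross_validation` draws one random hold-out split rather than running k-fold cross-validation. If the k-fold version is what you need, the following is a minimal sketch in the same pure-numpy spirit. `kfold_indices` is a hypothetical helper name, and it assumes NumPy's `Generator` API (`np.random.default_rng`) instead of the global seed; like the gist, it yields the indices for each split so they can be reused.

```python
import numpy as np


def kfold_indices(n_samples, n_folds=5, seed=0):
    """Yield (train_idx, test_idx) pairs for k-fold cross-validation.

    Hypothetical helper: shuffles the sample indices once, cuts them into
    n_folds nearly equal chunks, and uses each chunk as the test set once.
    """
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_samples)
    # np.array_split allows n_samples that is not divisible by n_folds
    folds = np.array_split(shuffled, n_folds)
    for k in range(n_folds):
        test_idx = folds[k]
        # All remaining folds together form the training set
        train_idx = np.concatenate([folds[j] for j in range(n_folds) if j != k])
        yield train_idx, test_idx


X = np.arange(20).reshape(10, 2)
y = np.arange(10)
for train_idx, test_idx in kfold_indices(len(X), n_folds=3):
    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]
    print(len(train_idx), len(test_idx))  # prints 6 4, then 7 3, then 7 3
```

Shuffling once and slicing guarantees that every sample lands in exactly one test fold, which is the property that distinguishes k-fold from repeatedly calling a random hold-out split.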
@amir-rahnama (Author)

The reasons for using this instead of scikit-learn's train_test_split (from sklearn.model_selection import train_test_split):

  • it avoids importing a whole library for the sake of one function
  • it also returns the indices that ended up in each split, which comes in handy (for example, to map predictions back to the original rows)
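The second point is easy to illustrate: with the indices in hand, held-out predictions can be written back to their original row positions. A small sketch follows; the data is random, the split is done inline in the same way as the gist, and the "model" is a hypothetical stand-in that simply predicts the mean of the training labels.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
y = rng.normal(size=8)

# Hold-out split that keeps the indices, in the spirit of the gist
test_idx = rng.choice(len(X), size=2, replace=False)
train_idx = np.setdiff1d(np.arange(len(X)), test_idx)

# Stand-in "model" (hypothetical): predict the mean of the training labels
pred_test = np.full(len(test_idx), y[train_idx].mean())

# The indices place each prediction back at its original row; training rows stay NaN
pred_full = np.full(len(X), np.nan)
pred_full[test_idx] = pred_test
```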
