Last active November 5, 2020 09:32
Cross validation with pure numpy
import numpy as np

np.random.seed(0)  # fix the seed so the split is reproducible


def cross_validation(X, y, cv_size=0.1):
    """Split a numpy ndarray into train and test sets (a single hold-out split)
    and return the corresponding indices as well

    @param: X data in the form of numpy ndarray
    @param: y labels in the form of numpy ndarray
    @param: cv_size fraction of the samples held out as the test set
    """
    data_size = X.shape[0]
    # Number of test samples, derived from cv_size
    test_size = int(np.round(data_size * cv_size))
    total_idx = np.arange(0, data_size)
    # Sample test indices without replacement; the remaining indices form the training set
    test_idx = np.random.choice(total_idx, replace=False, size=test_size)
    train_idx = np.setdiff1d(total_idx, test_idx)
    # Index along the first axis only, so 1-D label arrays work as well as 2-D ones
    return X[train_idx], y[train_idx], train_idx, X[test_idx], y[test_idx], test_idx
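The function above draws one random train/test split. If several folds are wanted, the same idea can be extended in pure numpy. The following is only a minimal sketch under stated assumptions: `k_fold_indices`, `n_folds`, and `seed` are hypothetical names that do not appear in the gist, and the toy `X` and `y` arrays exist only for illustration.

```python
import numpy as np


def k_fold_indices(n_samples, n_folds=5, seed=0):
    """Yield (train_idx, test_idx) pairs for k-fold cross validation.

    Hypothetical helper, not part of the original gist.
    """
    if n_folds < 2 or n_folds > n_samples:
        raise ValueError("n_folds must be between 2 and n_samples")
    rng = np.random.RandomState(seed)
    shuffled = rng.permutation(n_samples)
    # array_split tolerates n_samples not being divisible by n_folds
    folds = np.array_split(shuffled, n_folds)
    for i in range(n_folds):
        test_idx = folds[i]
        # Every fold except the i-th one goes into the training set
        train_idx = np.concatenate(folds[:i] + folds[i + 1:])
        yield train_idx, test_idx


# Toy data, for illustration only
X = np.arange(20).reshape(10, 2)
y = np.arange(10)

splits = list(k_fold_indices(X.shape[0], n_folds=5))
for train_idx, test_idx in splits:
    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]
```

Each sample lands in exactly one test fold, so every row is used for evaluation once and for training in the remaining folds.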
The reason for using this instead of sklearn's `from sklearn.model_selection import train_test_split`: