Skip to content

Instantly share code, notes, and snippets.

@alperyeg
Last active August 19, 2022 11:39
Show Gist options
  • Save alperyeg/ca5e5e9b5ffb442a9ce5caca7c8399c1 to your computer and use it in GitHub Desktop.
Save alperyeg/ca5e5e9b5ffb442a9ce5caca7c8399c1 to your computer and use it in GitHub Desktop.
Loading the MNIST dataset with scikit-learn
from sklearn.datasets import fetch_openml
from sklearn.utils import check_random_state
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
def fetch_data(test_size=10000, randomize=False, standardize=True,
               random_state=0):
    """Load MNIST (``mnist_784``) from OpenML and split it into train/test sets.

    Parameters
    ----------
    test_size : int or float, default=10000
        Forwarded to ``train_test_split``. The split is done without
        shuffling, so the last ``test_size`` samples (in the current sample
        order) form the test set.
    randomize : bool, default=False
        If True, apply a random permutation to the samples before splitting.
    standardize : bool, default=True
        If True, fit a ``StandardScaler`` on the training features only and
        apply that same scaling to the test features.
    random_state : int, RandomState instance or None, default=0
        Seed/generator for the permutation used when ``randomize`` is True.
        Passed to ``check_random_state``; the default reproduces the
        original fixed seed of 0.

    Returns
    -------
    X_train, y_train, X_test, y_test
        NumPy arrays. Note the order differs from ``train_test_split``'s
        ``(X_train, X_test, y_train, y_test)``.
    """
    # as_frame=False: newer scikit-learn returns a pandas DataFrame by default,
    # and ``X[permutation]`` on a DataFrame indexes *columns*, not rows, which
    # raises a KeyError. Requesting NumPy arrays keeps row indexing valid.
    X, y = fetch_openml('mnist_784', version=1, return_X_y=True,
                        as_frame=False)

    if randomize:
        rng = check_random_state(random_state)
        permutation = rng.permutation(X.shape[0])
        X = X[permutation]
        y = y[permutation]

    # shuffle=False keeps the (possibly permuted) order, so the tail of the
    # data becomes the test set.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, shuffle=False)

    if standardize:
        # Fit on training data only to avoid leaking test statistics.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

    return X_train, y_train, X_test, y_test
if __name__ == '__main__':
    # Script entry point: download MNIST and unpack the four splits in the
    # order fetch_data returns them (X_train, y_train, X_test, y_test).
    splits = fetch_data()
    train_data, train_labels, test_data, test_labels = splits
@rezamarzban
Copy link

Good!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment