Skip to content

Instantly share code, notes, and snippets.

@fxsjy
Last active October 13, 2021 01:24
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save fxsjy/5574345 to your computer and use it in GitHub Desktop.
Save fxsjy/5574345 to your computer and use it in GitHub Desktop.
mnist with sklearn
import numpy
import random
from numpy import arange
#from classification import *
from sklearn import metrics
from sklearn.datasets import fetch_mldata
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
import time
def run(n_train=60000, n_test=10000, n_estimators=10, n_jobs=2,
        random_state=None, mnist=None):
    """Train a random forest on MNIST and print evaluation metrics.

    The dataset is split contiguously: rows ``[0, n_train)`` form the
    training set and rows ``[n_train, n_train + n_test)`` form the test set.
    With the defaults this presumably corresponds to the standard
    60,000 / 10,000 MNIST split, assuming the loaded dataset keeps that
    ordering -- confirm for your data source.

    Parameters
    ----------
    n_train : int
        Number of leading rows used for training.
    n_test : int
        Number of rows, immediately after the training rows, used for testing.
    n_estimators : int
        Number of trees in the random forest.
    n_jobs : int
        Number of parallel jobs passed to ``RandomForestClassifier``.
    random_state : int or None
        Seed for the forest. ``None`` (the default) keeps the original,
        non-deterministic behaviour.
    mnist : object or None
        Pre-loaded dataset exposing ``.data`` and ``.target`` that support
        row slicing (e.g. NumPy arrays). When ``None``, the dataset is
        downloaded with ``fetch_mldata('MNIST original')``.

    Returns
    -------
    dict
        Keys ``"precision"``, ``"recall"``, ``"f1"`` and ``"accuracy"``.

    Raises
    ------
    ValueError
        If ``n_train`` or ``n_test`` is not positive, or the dataset has
        fewer than ``n_train + n_test`` rows.
    """
    if mnist is None:
        mnist = fetch_mldata('MNIST original')

    n_total = len(mnist.data)
    if n_train <= 0 or n_test <= 0:
        raise ValueError("n_train and n_test must be positive")
    if n_train + n_test > n_total:
        raise ValueError(
            "dataset has {} rows, need n_train + n_test = {}".format(
                n_total, n_train + n_test))

    # Contiguous split. The test range starts at n_train (not n_train + 1),
    # so exactly n_test samples are used and row n_train is not skipped.
    X_train, y_train = mnist.data[:n_train], mnist.target[:n_train]
    test_rows = slice(n_train, n_train + n_test)
    X_test, y_test = mnist.data[test_rows], mnist.target[test_rows]

    # Single-argument print(...) behaves the same on Python 2 and 3.
    print("Applying a learning algorithm...")
    clf = RandomForestClassifier(n_estimators=n_estimators, n_jobs=n_jobs,
                                 random_state=random_state)
    clf.fit(X_train, y_train)

    print("Making predictions...")
    y_pred = clf.predict(X_test)

    print("Evaluating results...")
    # The targets are multiclass; recent scikit-learn defaults to
    # average='binary', which raises for multiclass input. 'weighted' is
    # believed to be the old default this script relied on -- verify.
    precision = metrics.precision_score(y_test, y_pred, average="weighted")
    recall = metrics.recall_score(y_test, y_pred, average="weighted")
    f1 = metrics.f1_score(y_test, y_pred, average="weighted")
    # Same value as clf.score(X_test, y_test), without predicting twice.
    accuracy = metrics.accuracy_score(y_test, y_pred)

    print("Precision: \t{}".format(precision))
    print("Recall: \t{}".format(recall))
    print("F1 score: \t{}".format(f1))
    print("Mean accuracy: \t{}".format(accuracy))

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy,
    }
if __name__ == "__main__":
    # Time the whole train/evaluate run, including the dataset download.
    start_time = time.time()
    results = run()
    end_time = time.time()
    # Single-argument print(...) is valid on both Python 2 and Python 3.
    print("Overall running time: {}".format(end_time - start_time))
@timkofu
Copy link

timkofu commented Jul 18, 2019

Had the same issue with fetch_mldata(). After reading this Stack Overflow answer https://stackoverflow.com/a/51301798/433717, I downloaded the dataset from Kaggle.

@EyjanHuang
Copy link

There is a problem: if I use the original .idx files, convert them into a matrix, and train on it, training takes a very long time. Is there a better solution?

@codeshoper
Copy link

I can't import the dataset; whenever I try, this error pops up:
ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets'

@sri-vishnu-001
Copy link

I can't import the dataset; whenever I try, this error pops up:
ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets'

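Use fetch_openml in place of fetch_mldata, and use 'mnist_784' in place of 'MNIST original'. On recent scikit-learn versions, also pass as_frame=False so that mnist.data and mnist.target are NumPy arrays that can be indexed as in the script above (the targets are returned as strings).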

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment