training a dumb classifier
# Importing the dataset.
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
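# Creating independent and dependent variables.
X, y = mnist['data'], mnist['target']
# fetch_openml returns the labels as strings; cast them to integers so that (y == 5) can match.
y = y.astype(int)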
# Splitting the data into training set and test set.
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
"""
The MNIST training set is already shuffled for us, so every cross-validation fold should
contain a similar mix of digits and no fold should be missing a digit.
"""
# Creating the target vectors for a binary "5 vs. not-5" task.
y_train_5 = (y_train == 5) # True for all 5s, False for all other digits.
y_test_5 = (y_test == 5)
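"""
The earlier note about shuffling can be checked on a toy example. The sketch below is not
run on MNIST: it assumes synthetic labels and plain contiguous folds (what an unshuffled
k-fold split produces), and the helper fraction_of_5s_per_fold is a hypothetical name
introduced only for this illustration.
"""
```python
import numpy as np

# Synthetic stand-in for the MNIST labels (an assumption, not the real data):
# 6,000 labels with 600 copies of each digit 0-9.
rng = np.random.default_rng(42)
labels_sorted = np.repeat(np.arange(10), 600)      # all 0s first, then all 1s, ...
labels_shuffled = rng.permutation(labels_sorted)   # same labels in random order

def fraction_of_5s_per_fold(labels, n_folds=3):
    """Split labels into contiguous folds and return the share of 5s in each."""
    return [float(np.mean(fold == 5)) for fold in np.array_split(labels, n_folds)]

sorted_fractions = fraction_of_5s_per_fold(labels_sorted)
shuffled_fractions = fraction_of_5s_per_fold(labels_shuffled)
print("sorted:  ", sorted_fractions)    # every 5 lands in the middle fold
print("shuffled:", shuffled_fractions)  # each fold holds roughly 10% 5s
```
"""
With sorted labels, two of the three folds contain no 5s at all, so a score computed on
them would say nothing about detecting 5s. Shuffling spreads the 5s evenly across folds.
"""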
"""
Building a dumb classifier that just classifies every single image in the “not-5” class.
"""
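import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_score

class Never5Classifier(BaseEstimator):
    def fit(self, X, y=None):
        # Nothing to learn; return self to follow the scikit-learn estimator convention.
        return self

    def predict(self, X):
        # Predict False ("not-5") for every sample, as a 1-D array.
        return np.zeros(len(X), dtype=bool)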
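never_5_clf = Never5Classifier()
# Only about 10% of the images are 5s, so always predicting "not-5" scores
# roughly 90% accuracy on every fold.
print(cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy"))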
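"""
The same kind of baseline is available in scikit-learn as DummyClassifier, swapped in here
for the hand-written class. The sketch below is a minimal illustration on synthetic data,
not MNIST: X_demo and y_demo are hypothetical names, and the 10% positive rate is an
assumption chosen to resemble the share of 5s.
"""
```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import cross_val_score

# Synthetic stand-in for the "5 vs. not-5" task (an assumption, not MNIST):
# 3,000 samples, 10% of which are positives.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(3000, 4))   # the baseline ignores the features
y_demo = np.zeros(3000, dtype=bool)
y_demo[:300] = True
y_demo = rng.permutation(y_demo)

# Always predicts the most frequent class, which is False here.
baseline = DummyClassifier(strategy="most_frequent")
scores = cross_val_score(baseline, X_demo, y_demo, cv=3, scoring="accuracy")
negative_share = 1 - y_demo.mean()

print(scores)           # per-fold accuracy of the baseline
print(negative_share)   # share of negatives in the labels
```
"""
The baseline's accuracy simply mirrors the share of negatives, which is why accuracy alone
is a weak measure on a skewed dataset like this one.
"""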