Skip to content

Instantly share code, notes, and snippets.

@appendjeff
Last active December 1, 2016 18:06
Show Gist options
  • Select an option

  • Save appendjeff/77ec9d2912f2a0c44e3840ea9f63cceb to your computer and use it in GitHub Desktop.

Select an option

Save appendjeff/77ec9d2912f2a0c44e3840ea9f63cceb to your computer and use it in GitHub Desktop.
"""@datascientist: mnest"""
import os
import time
import logging
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib
class DigitClassifier(object):
    """
    Random-forest digit classifier (predicts a digit 0-9 from pixel values).

    Uses files named train.csv and test.csv to train/test the classifier.
    This object assumes that these files live next to this module.
    Download the training csv files here:
    https://www.kaggle.com/c/digit-recognizer/data

    The fitted classifier is persisted with joblib (pickle-based) to
    ``digit_classifier.pkl`` next to this module. It is reused on later
    instantiations until the file is missing, empty, or older than
    ``MAX_AGE_DAYS`` days, at which point a new classifier is trained.
    """
    _Classifier = RandomForestClassifier
    _persistent_classifier_f_name = 'digit_classifier.pkl'

    # Prediction input description, e.g. [4, 0, 0, 210, ...]:
    # NUM_OF_PIXELS integers, each within the inclusive PIXEL_RANGE.
    PIXEL_RANGE = (0, 255)
    NUM_OF_PIXELS = 784

    # A persisted classifier older than this many days is retrained.
    MAX_AGE_DAYS = 100

    def __init__(self, train_now=False, n_estimators=1000):
        """Train a new classifier or load the persisted one.

        :param train_now: force training even when a usable persisted
            classifier exists.
        :param n_estimators: number of estimators passed to the classifier
            when training; unused when a persisted classifier is loaded.
        """
        self.logger = logging.getLogger(__name__)
        self.pwd = os.path.dirname(os.path.abspath(__file__))
        self.n_estimators = n_estimators
        self.pers_f_path = os.path.join(
            self.pwd, self._persistent_classifier_f_name)
        if train_now or self._should_train_now():
            self.logger.debug('Training new digit classifier.')
            self._train()
        else:
            age = self.get_age_of_old_classifier()
            # Lazy logger args: the string is only built if DEBUG is enabled.
            self.logger.debug('Using old classifier that is %s days old', age)
            self._read_classifier()

    def _train(self, persist=True):
        """Fit a new classifier on train.csv and optionally persist it.

        The first column of train.csv is used as the label; every remaining
        column is used as a pixel feature.

        :param persist: when true, write the fitted classifier to
            ``self.pers_f_path`` with joblib.
        """
        train_path = os.path.join(self.pwd, 'train.csv')
        train = pd.read_csv(train_path)
        labels = train.iloc[:, 0].values
        train_x = train.iloc[:, 1:].values

        self.classifier = self._Classifier(n_estimators=self.n_estimators)
        self.classifier.fit(train_x, labels)

        if persist:
            joblib.dump(self.classifier, self.pers_f_path)

    def _read_classifier(self):
        """Load the persisted classifier from ``self.pers_f_path``."""
        # joblib.load unpickles the file, which can execute arbitrary code:
        # only load files written by this class, never untrusted ones.
        self.classifier = joblib.load(self.pers_f_path)

    def test(self):
        """Predict a digit for every row of test.csv.

        :returns: the classifier's predictions, one per row of test.csv.
        """
        test_path = os.path.join(self.pwd, 'test.csv')
        test = pd.read_csv(test_path)
        # test.csv has no label column: every column is a pixel feature.
        # (Previously this read ``self.test``, i.e. this bound method.)
        test_x = test.values
        return self.classifier.predict(test_x)

    def predict_digit(self, arr):
        """
        Given an array of length 784 where the values are
        integers between 0 and 255 (representing grayscaled pixels)
        this will return the predicted digit (0-9).

        Returns None (and logs at INFO level) if the classifier rejects
        the input with a ValueError, e.g. because of a wrong length.
        """
        try:
            # predict expects a 2-D batch, so wrap the single sample.
            return self.classifier.predict([arr])[0]
        except ValueError:
            self.logger.info('predict_digit failed with ValueError')
            return None

    def random_digit_test(self):
        """Check that the classifier runs on random pixel data.

        :returns: the digit predicted for a random input, or None if
            prediction failed.
        """
        low, high = self.PIXEL_RANGE
        # randint's upper bound is exclusive; add 1 so `high` is included.
        random_arr = np.random.randint(low, high + 1, size=self.NUM_OF_PIXELS)
        return self.predict_digit(random_arr)

    def get_age_of_old_classifier(self):
        """Return how old the persisted classifier file is, in days.

        Age is measured from the file's modification time.

        :raises OSError: if the persisted file does not exist.
        """
        sec_diff = abs(time.time() - os.path.getmtime(self.pers_f_path))
        return sec_diff / float(24 * 60 * 60)

    def _should_train_now(self):
        """Decide from the persisted file's state whether to retrain.

        :returns: True if the persisted file is missing, empty, or older
            than ``MAX_AGE_DAYS`` days; otherwise False.
        """
        if not os.path.exists(self.pers_f_path):
            return True
        if os.path.getsize(self.pers_f_path) == 0:
            return True
        if self.get_age_of_old_classifier() > self.MAX_AGE_DAYS:
            return True
        # The persisted classifier is usable. (This previously returned the
        # undefined name `days_old`, raising NameError.)
        return False
@appendjeff
Copy link
Author

Example of converting a scikit-learn classifier script into a Python class.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment