Last active
December 1, 2016 18:06
-
-
Save appendjeff/77ec9d2912f2a0c44e3840ea9f63cceb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """@datascientist: mnest""" | |
| import os | |
| import time | |
| import logging | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.externals import joblib | |
class DigitClassifier(object):
    """
    Random-forest digit classifier for the Kaggle "Digit Recognizer" data.

    Uses files named train.csv and test.csv to train/test the classifier.
    These files are expected to live next to this module. Download them here:
    https://www.kaggle.com/c/digit-recognizer/data

    The fitted classifier is persisted next to this module with joblib
    (pickle-based) and reused until it is older than MAX_CLASSIFIER_AGE_DAYS.
    """
    _Classifier = RandomForestClassifier
    _persistent_classifier_f_name = 'digit_classifier.pkl'
    # Prediction input description, e.g. [4, 0, 0, 210, ...]:
    # NUM_OF_PIXELS integers, each inside the inclusive PIXEL_RANGE.
    PIXEL_RANGE = (0, 255)
    NUM_OF_PIXELS = 784
    # A persisted classifier older than this many days is retrained.
    MAX_CLASSIFIER_AGE_DAYS = 100

    def __init__(self, train_now=False, n_estimators=1000):
        """Train a new classifier or load the persisted one.

        :param train_now: force training even if a usable persisted
            classifier exists.
        :param n_estimators: number of trees used when training a new
            classifier (ignored when the persisted one is loaded).
        """
        self.logger = logging.getLogger(__name__)
        self.pwd = os.path.dirname(os.path.abspath(__file__))
        self.n_estimators = n_estimators
        self.pers_f_path = os.path.join(
            self.pwd, self._persistent_classifier_f_name)
        if train_now or self._should_train_now():
            self.logger.debug('Training new digit classifier.')
            self._train()
        else:
            age = self.get_age_of_old_classifier()
            # Lazy %-args: the message is only formatted if DEBUG is enabled.
            self.logger.debug('Using old classifier that is %s days old', age)
            self._read_classifier()

    def _train(self, persist=True):
        """Fit the classifier on train.csv and optionally persist it.

        The first column of train.csv is used as the label; all remaining
        columns are used as pixel features.

        :param persist: when True, dump the fitted classifier to
            ``self.pers_f_path`` so later instances can reuse it.
        """
        train_path = os.path.join(self.pwd, 'train.csv')
        train = pd.read_csv(train_path)
        labels = train.iloc[:, 0].values
        train_x = train.iloc[:, 1:].values
        self.classifier = self._Classifier(n_estimators=self.n_estimators)
        self.classifier.fit(train_x, labels)
        if persist:
            joblib.dump(self.classifier, self.pers_f_path)

    def _read_classifier(self):
        """Load the persisted classifier from ``self.pers_f_path``.

        NOTE: joblib.load unpickles the file, which can execute arbitrary
        code; only load a file this class wrote itself.
        """
        self.classifier = joblib.load(self.pers_f_path)

    def test(self):
        """Predict a digit for every row of test.csv.

        :returns: array of predicted digits, one per row of test.csv.
        """
        test_path = os.path.join(self.pwd, 'test.csv')
        test = pd.read_csv(test_path)
        # Every column of the freshly loaded DataFrame is a pixel feature
        # (test.csv is assumed to carry no label column).
        test_x = test.values
        return self.classifier.predict(test_x)

    def predict_digit(self, arr):
        """
        Given an array of length NUM_OF_PIXELS (784) whose values are
        integers in PIXEL_RANGE (0-255, grayscale pixels), return the
        predicted digit (0-9).

        :returns: the predicted label, or None if the classifier rejects the
            input with a ValueError (e.g. wrong number of pixels).
        """
        try:
            return self.classifier.predict([arr])[0]
        except ValueError:
            # Best-effort API: report the failure (with traceback) instead
            # of propagating it to the caller.
            self.logger.info('predict_digit failed with ValueError',
                             exc_info=True)
            return None

    def random_digit_test(self):
        """Smoke-test the classifier on one random image.

        :returns: whatever predict_digit returns for NUM_OF_PIXELS random
            pixel values drawn uniformly from the inclusive PIXEL_RANGE.
        """
        low, high = self.PIXEL_RANGE
        # np.random.randint's upper bound is exclusive; add 1 so that `high`
        # itself (e.g. 255) can be produced.
        random_arr = np.random.randint(low, high + 1, size=self.NUM_OF_PIXELS)
        return self.predict_digit(random_arr)

    def get_age_of_old_classifier(self):
        """Return how old the persisted classifier file is, in days.

        Age is measured from the file's modification time; abs() keeps the
        result non-negative if the mtime lies in the future.

        :raises OSError: if the persisted file does not exist.
        """
        sec_diff = abs(time.time() - os.path.getmtime(self.pers_f_path))
        return sec_diff / float(24 * 60 * 60)

    def _should_train_now(self):
        """Return True if no usable persisted classifier exists.

        A new classifier is needed when the persisted file is missing, empty,
        or older than MAX_CLASSIFIER_AGE_DAYS; otherwise return False.
        """
        if not os.path.exists(self.pers_f_path):
            return True
        if os.path.getsize(self.pers_f_path) < 1:
            return True
        if self.get_age_of_old_classifier() > self.MAX_CLASSIFIER_AGE_DAYS:
            return True
        # Persisted classifier is present, non-empty and fresh enough.
        return False
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example of converting an sklearn classifier script into a Python class