Created
June 10, 2017 14:22
-
-
Save mbednarski/0d9bce51cc218dfb88d26a913b0dc5b0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
import sys | |
from sklearn.ensemble import RandomForestClassifier | |
sys.path.append('src') | |
from data import get_featues, get_label | |
class RandomForestModel(object): | |
def __init__(self): | |
self.clf = RandomForestClassifier(n_estimators=50, max_depth=700) | |
self.name = 'RandomForest' | |
def get_params(self): | |
return self.clf.get_params() | |
def train(self, dframe): | |
X = get_featues(dframe) | |
y = get_label(dframe) | |
self.clf.fit(X, y) | |
def predict(self, X): | |
y_pred = self.clf.predict(X) | |
return y_pred | |
def save(self, fname): | |
with open(fname, 'wb') as ofile: | |
pickle.dump(self.clf, ofile, pickle.HIGHEST_PROTOCOL) | |
def load(self, fname): | |
with open(fname, 'rb') as ifile: | |
self.clf = pickle.load(ifile) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment