Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
train_model_PYthon
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import sys
try: import cPickle as pickle # python2
except: import pickle # python3
from scipy import sparse
from numpy import loadtxt
import feather as ft
def main(argv=None):
    """Train a random-forest classifier from a Feather matrix and pickle it.

    Args:
        argv: Full argument vector including the program name, in the order
            ``[prog, INPUT_MATRIX_FILE, SEED, OUTPUT_MODEL_FILE]``.
            Defaults to ``sys.argv``.

    Raises:
        SystemExit: With status 1 when the number of arguments is wrong
            (a usage message is written to stderr first).
        ValueError: When SEED cannot be parsed as an integer.
    """
    if argv is None:
        argv = sys.argv

    if len(argv) != 4:
        sys.stderr.write('Arguments error. Usage:\n')
        sys.stderr.write('\tpython train_model.py INPUT_MATRIX_FILE SEED OUTPUT_MODEL_FILE\n')
        sys.exit(1)

    # Avoid the names `input`/`output` so the builtin `input` is not shadowed.
    input_path = argv[1]
    seed = int(argv[2])
    output_path = argv[3]

    df = ft.read_dataframe(input_path)

    # Selecting a single column by name yields a 1-D Series, which is the
    # target shape `fit` expects. The previous `labels.ix[:, 0]` applied a
    # two-axis indexer to that Series and relied on `.ix`, which pandas
    # removed in 1.0.
    labels = df['label']
    # Every column except 'label' is used as a feature.
    features = df.loc[:, df.columns != 'label']

    # The seed makes the trained forest reproducible for a given input.
    clf = RandomForestClassifier(n_estimators=100, n_jobs=2, random_state=seed)
    clf.fit(features, labels)

    # Binary mode is required by pickle on Python 3.
    with open(output_path, 'wb') as fd:
        pickle.dump(clf, fd)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.