Skip to content

Instantly share code, notes, and snippets.

@caleb-kaiser
Last active October 28, 2019 20:46
Show Gist options
  • Save caleb-kaiser/666aef3aa60ab74738a09f2154dd2e2d to your computer and use it in GitHub Desktop.
Save caleb-kaiser/666aef3aa60ab74738a09f2154dd2e2d to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import sys
import re
import boto3
# AWS credentials and destination consumed by the S3 upload at the end of this script.
# The trailing `#@param` markers look like notebook form-field annotations; keep them
# byte-for-byte if this file is run as a notebook cell (TODO confirm).
# NOTE(review): do not commit real credentials into these literals.
AWS_ACCESS_KEY_ID = "" #@param {type:"string"}
AWS_SECRET_ACCESS_KEY = "" #@param {type:"string"}
# Only the bucket name (the text between "s3://" and the next "/") is read from this value.
S3_UPLOAD_PATH = "s3://your/bucket" #@param {type:"string"}
# Prefer the C-accelerated pickler on Python 2 and fall back to the standard
# module on Python 3. ImportError is caught (rather than ModuleNotFoundError)
# because it exists on every Python version and is the parent class of
# ModuleNotFoundError, which was only added in Python 3.6.
try:
    import cPickle as pickle  # python2
except ImportError:
    import pickle  # python3
# Python 2 only: re-expose sys.setdefaultencoding (removed from the module at
# interpreter start-up) and switch the default encoding to UTF-8.
# On Python 3 the builtin `reload` does not exist, so the NameError branch is
# taken and nothing is changed.
try:  # python2
    reload(sys)
except NameError:
    pass
else:
    sys.setdefaultencoding('utf-8')
def usage(msg):
    """Write an optional message and the command-line help to stderr, then exit.

    Args:
        msg: Text printed (followed by a blank line) before the help text.
            Skipped when falsy.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    if msg:
        sys.stderr.write('{}\n\n'.format(msg))

    sys.stderr.write('python train_model.py features seed model\n\n')
    sys.stderr.write('\tfeatures \t input features and labels pickle file.\n')
    sys.stderr.write('\tseed \t\t random state (integer). Example: 20170423\n')
    sys.stderr.write('\tmodel \t\t output model pickle file.\n')
    sys.exit(1)
# Expect exactly three positional arguments after the script name:
# features pickle, integer seed, output model path.
if len(sys.argv) != 4:
    usage('Wrong number of arguments. Usage:')

# `input_path` instead of `input` so the builtin of that name is not shadowed.
input_path = sys.argv[1]
output = sys.argv[3]
try:
    seed = int(sys.argv[2])
except ValueError:
    # Report a bad seed through the same help path as a wrong argument count
    # instead of letting a raw traceback escape.
    usage('seed must be an integer, got {!r}.'.format(sys.argv[2]))

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only run this against feature files from a trusted source.
with open(input_path, 'rb') as fd:
    matrix = pickle.load(fd)

# Column 1 holds the labels and columns 2+ hold the features; column 0 is not
# used here. The `.toarray()` call implies a sparse-matrix-like object
# (presumably scipy.sparse; confirm against the producer of this file).
# ravel() keeps the labels 1-D even when the matrix has a single row, where
# squeeze() would collapse them to a 0-d array.
labels = np.ravel(matrix[:, 1].toarray())
x = matrix[:, 2:]

sys.stderr.write('Input matrix size {}\n'.format(matrix.shape))
sys.stderr.write('X matrix size {}\n'.format(x.shape))
sys.stderr.write('Y matrix size {}\n'.format(labels.shape))
# Train a random forest on the loaded features. The tree count and worker
# count are fixed; only the random seed comes from the command line.
clf = RandomForestClassifier(n_estimators=700, n_jobs=6, random_state=seed)
clf.fit(x, labels)

# Persist the fitted classifier to the path given as the third CLI argument.
with open(output, 'wb') as fd:
    pickle.dump(clf, fd)
# Imported next to their only use, as in the original script.
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Declare the ONNX input with the number of feature columns the classifier was
# actually trained on, instead of a hard-coded width of 4. The batch dimension
# stays at 1 as before.
n_features = x.shape[1]
initial_type = [('float_input', FloatTensorType([1, n_features]))]
onx = convert_sklearn(clf, initial_types=initial_type)

# Serialize the converted model into the current working directory.
with open("model.onnx", "wb") as f:
    f.write(onx.SerializeToString())
# The bucket is the first path component after "s3://". A slash after the
# bucket name is optional, and a malformed path is reported explicitly
# instead of failing with AttributeError on a None match.
bucket_match = re.search(r"s3://([^/]+)", S3_UPLOAD_PATH)
if bucket_match is None:
    raise ValueError(
        "S3_UPLOAD_PATH must look like s3://bucket[/prefix], got {!r}".format(
            S3_UPLOAD_PATH
        )
    )
bucket = bucket_match.group(1)

# Blank credential fields are passed as None so boto3 falls back to its
# default credential lookup instead of signing with empty strings.
s3 = boto3.client(
    "s3",
    aws_access_key_id=AWS_ACCESS_KEY_ID or None,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY or None,
)

filepath = './model.onnx'
# NOTE(review): the object is stored at the bucket root; any prefix present in
# S3_UPLOAD_PATH is not applied to the key. Confirm this is intended.
filekey = 'model.onnx'

print("Uploading s3://{}/{} ...".format(bucket, filekey), end='')
# Flush so the progress text is visible while the upload is still running.
sys.stdout.flush()
s3.upload_file(filepath, bucket, filekey)
print(" ✓")
print("\nUploaded model export directory to " + S3_UPLOAD_PATH)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment