Skip to content

Instantly share code, notes, and snippets.

@caleb-kaiser
Created October 31, 2019 15:06
Show Gist options
  • Save caleb-kaiser/79eb48ec545bcdaa6c801573ef60b745 to your computer and use it in GitHub Desktop.
"""Convert a pickled scikit-learn model to ONNX and upload it to S3.

Usage: python <this script> MODEL_PICKLE

The pickled model is converted with skl2onnx, written to ``model.onnx`` in the
current directory, and uploaded to S3_UPLOAD_PATH with the AWS profile named
in AWS_PROFILE.
"""
import re
import sys

try:
    import cPickle as pickle  # python2
except ImportError:  # ModuleNotFoundError (py3.6+) is a subclass of ImportError
    import pickle  # python3

S3_UPLOAD_PATH = "s3://YOUR/BUCKET" #@param {type:"string"}

# Local file the ONNX model is written to. Also used as the S3 object name
# when S3_UPLOAD_PATH names only a bucket or a "directory" (trailing "/").
ONNX_FILENAME = "model.onnx"

# Named AWS credentials profile passed to boto3.Session for the upload.
AWS_PROFILE = "dev"

# Input width used when the model does not expose ``n_features_in_``;
# the original script always assumed 4 features.
DEFAULT_N_FEATURES = 4


def parse_s3_path(path, default_filename=ONNX_FILENAME):
    """Split an ``s3://bucket[/key]`` URI into ``(bucket, key)``.

    If the key part is empty or ends with ``/`` (i.e. it names a bucket or a
    prefix rather than an object), *default_filename* is appended to it, so
    ``s3://b`` and ``s3://b/`` both map to ``("b", default_filename)``.

    Raises:
        ValueError: if *path* is not of the form ``s3://bucket[/key]``.
    """
    match = re.fullmatch(r"s3://([^/]+)/?(.*)", path)
    if match is None:
        raise ValueError(
            "Invalid s3 path (should be of the form s3://my-bucket/path/to/file)"
        )
    bucket, key = match.groups()
    if not key or key.endswith("/"):
        key += default_filename
    return bucket, key


def convert_to_onnx(model_path, output_path=ONNX_FILENAME):
    """Load a pickled scikit-learn model and write it to *output_path* as ONNX.

    The ONNX graph gets a single float input named ``float_input`` of shape
    ``[1, n_features]``. ``n_features`` is taken from the model's
    ``n_features_in_`` attribute, or DEFAULT_N_FEATURES if the model lacks it.
    """
    # Third-party imports are deferred so that argument / S3-path validation
    # (and parse_s3_path) work without the ML stack installed.
    # RandomForestClassifier is not referenced directly; importing it fails
    # fast with a clear ImportError if scikit-learn is missing.
    from sklearn.ensemble import RandomForestClassifier  # noqa: F401
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only convert model files that come from a trusted source.
    with open(model_path, "rb") as fd:
        model = pickle.load(fd)

    n_features = getattr(model, "n_features_in_", DEFAULT_N_FEATURES)
    initial_type = [("float_input", FloatTensorType([1, n_features]))]
    onx = convert_sklearn(model, initial_types=initial_type)
    with open(output_path, "wb") as f:
        f.write(onx.SerializeToString())


def upload_to_s3(local_path, bucket, key, profile_name=AWS_PROFILE):
    """Upload *local_path* to ``s3://bucket/key`` using the named AWS profile."""
    import boto3  # deferred for the same reason as in convert_to_onnx

    session = boto3.Session(profile_name=profile_name)
    s3 = session.client("s3")
    print(f"Uploading s3://{bucket}/{key} ...", end="")
    s3.upload_file(local_path, bucket, key)
    print(" ✓")


def _print_error(message):
    """Print *message* to stderr in red (ANSI escape codes)."""
    print("\033[91m{}\033[00m".format(message), file=sys.stderr)


def main(argv=None):
    """Convert the model named in argv[1] and upload it; return an exit status.

    Returns 2 on a usage error, 1 on an invalid S3_UPLOAD_PATH, 0 on success.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        _print_error("ERROR: missing argument (usage: python <script> MODEL_PICKLE)")
        return 2
    model_path = argv[1]

    # Validate the destination before doing any work. Previously an invalid
    # path only printed an error, then crashed with NameError on `bucket`.
    try:
        bucket, key = parse_s3_path(S3_UPLOAD_PATH)
    except ValueError as err:
        _print_error(f"ERROR: {err}")
        return 1

    convert_to_onnx(model_path, ONNX_FILENAME)
    upload_to_s3(ONNX_FILENAME, bucket, key)
    print(f"\nUploaded model to s3://{bucket}/{key}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment