# Source: GitHub gist caleb-kaiser/79eb48ec545bcdaa6c801573ef60b745
# (created October 31, 2019 15:06).
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than it appears; review it in an editor
# that reveals hidden Unicode characters.
"""Convert a pickled scikit-learn model to ONNX and upload it to S3.

Usage::

    python <this script> path/to/model.pkl

The converted model is written to ``./model.onnx`` and then uploaded, using
the ``dev`` AWS profile, to ``<S3_UPLOAD_PATH>/model.onnx`` (or to the bucket
root when ``S3_UPLOAD_PATH`` names only a bucket).
"""
import pickle  # Python 3's pickle already uses the C accelerator (old cPickle)
import re
import sys

S3_UPLOAD_PATH = "s3://YOUR/BUCKET"  #@param {type:"string"}

# Shape declared for the ONNX graph input, as [batch, n_features]. Kept from
# the original export (one row of four features); adjust if the model was
# trained on a different number of features.
INPUT_SHAPE = (1, 4)

# Local file name of the exported model; also used as the S3 object name.
ONNX_FILENAME = "model.onnx"

# s3://<bucket>[/<optional/prefix>]
_S3_PATH_RE = re.compile(r"s3://([^/]+)(?:/(.*))?")


def parse_s3_path(path):
    """Split an ``s3://bucket[/prefix]`` URL into ``(bucket, prefix)``.

    Leading/trailing slashes are stripped from the prefix; the prefix is ""
    when the URL names only a bucket.

    Raises:
        ValueError: if *path* is not of the form ``s3://bucket[/prefix]``.
    """
    match = _S3_PATH_RE.fullmatch(path)
    if match is None:
        raise ValueError("not an s3:// URL: {!r}".format(path))
    return match.group(1), (match.group(2) or "").strip("/")


def s3_object_key(prefix, filename=ONNX_FILENAME):
    """Return the S3 key for *filename* under *prefix* (bucket root if empty)."""
    return "{}/{}".format(prefix, filename) if prefix else filename


def _fail(message):
    """Print *message* in red on stderr and exit with status 1."""
    print("\033[91m{}\033[00m".format(message), file=sys.stderr)
    sys.exit(1)


def convert_to_onnx(model, input_shape=INPUT_SHAPE):
    """Convert a fitted scikit-learn *model* with skl2onnx and return the ONNX model.

    *input_shape* is the declared shape of the single float input tensor.
    """
    # Imported here so argument/S3-path errors surface without loading skl2onnx.
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType

    initial_type = [("float_input", FloatTensorType(list(input_shape)))]
    return convert_sklearn(model, initial_types=initial_type)


def main(argv=None):
    """Run the export: load the pickle, convert to ONNX, upload to S3.

    *argv* defaults to ``sys.argv[1:]``; its first element is the path of the
    pickled model. Exits with status 1 on a missing argument or an invalid
    ``S3_UPLOAD_PATH``.
    """
    argv = sys.argv[1:] if argv is None else argv
    if not argv:
        _fail("usage: {} path/to/model.pkl".format(sys.argv[0]))

    # Validate the destination before doing any expensive work; the original
    # printed an error here but carried on and crashed on an undefined name.
    try:
        bucket, prefix = parse_s3_path(S3_UPLOAD_PATH)
    except ValueError:
        _fail("ERROR: Invalid s3 path (should be of the form "
              "s3://my-bucket or s3://my-bucket/path/to/dir)")

    # Kept from the original script; not referenced directly, since
    # pickle.load imports whatever modules the pickled model needs.
    from sklearn.ensemble import RandomForestClassifier  # noqa: F401

    model_path = argv[0]
    # SECURITY: pickle.load can execute arbitrary code. Only load model
    # files from a trusted source.
    with open(model_path, "rb") as fd:
        model = pickle.load(fd)

    onx = convert_to_onnx(model)
    with open(ONNX_FILENAME, "wb") as f:
        f.write(onx.SerializeToString())

    import boto3

    s3 = boto3.Session(profile_name="dev").client("s3")
    key = s3_object_key(prefix)
    # flush so the progress text is visible while the upload is running
    print("Uploading s3://{}/{} ...".format(bucket, key), end="", flush=True)
    s3.upload_file(ONNX_FILENAME, bucket, key)
    print(" ✓")
    print("\nUploaded model to s3://{}/{}".format(bucket, key))


if __name__ == "__main__":
    main()
# (GitHub page footer: sign up for free, or sign in if you already have an
# account, to comment on this gist.)