Skip to content

Instantly share code, notes, and snippets.

@jakechen
Created August 27, 2017 20:22
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jakechen/f46bc82184a98fc7de3b0633f73b766a to your computer and use it in GitHub Desktop.
Save jakechen/f46bc82184a98fc7de3b0633f73b766a to your computer and use it in GitHub Desktop.
Save a trained MXNet model to S3, then retrieve it and use it for prediction
import os
import tempfile

import boto3
import mxnet as mx
from mxnet.io import NDArrayIter
def predict_from_s3(record, bucket_name, s3_symbol_key, s3_params_key, context=None):
    """Grab an MXNet network definition from an S3 bucket and use it for prediction on a single record.

    The network graph (Symbol) and its parameters are downloaded into a
    private, per-call temporary directory. They are loaded into a new
    Module bound for inference only, and the temporary directory is removed
    when the function returns.

    Keyword arguments:
    record -- the record to predict from; passed directly as the data source of
              ``NDArrayIter`` (presumably an array with a leading batch axis --
              confirm against callers)
    bucket_name -- bucket where your MXNet network is stored
    s3_symbol_key -- key to your MXNet Symbol in S3
    s3_params_key -- key to your MXNet Parameters in S3
    context -- MXNet device context to run the model on; defaults to
               ``mx.gpu(0)``, the device previously hard-coded here

    Returns:
    the result of ``Module.predict`` on ``record``.
    """
    if context is None:
        context = mx.gpu(0)

    bucket = boto3.resource('s3').Bucket(bucket_name)

    # A fresh temporary directory per call avoids littering the working
    # directory and keeps concurrent calls from overwriting each other's files.
    with tempfile.TemporaryDirectory() as tmp_dir:
        symbol_path = os.path.join(tmp_dir, 'temp_symbol.mxnet')
        params_path = os.path.join(tmp_dir, 'temp_params.mxnet')
        bucket.download_file(s3_symbol_key, symbol_path)
        bucket.download_file(s3_params_key, params_path)

        sym = mx.symbol.load(symbol_path)  # loads network graph
        mod = mx.mod.Module(sym, context=context)  # new Module from the loaded graph
        # Bind for inference only, using the input shapes the record will produce.
        mod.bind(NDArrayIter(record).provide_data, for_training=False)
        mod.load_params(params_path)
        y_pred = mod.predict(NDArrayIter(record))

    return y_pred
# Assumes that a model has been trained prior to the following code, so that
# `sym` (the network Symbol) and `mod` (the fitted Module) already exist.
# Training will look something like this:
#
# mod = mx.mod.Module(sym)
# mod.fit(...)
local_symbol_path = "your_local_symbol_path"  # temp path to export your network graph
local_params_path = "your_local_params_path"  # temp path to export your network parameters i.e. weights
bucket_name = "your_bucket_here"  # S3 bucket to save your network to
s3_symbol_key = "your_s3_symbol_key"  # S3 key to save your network graph under
s3_params_key = "your_s3_params_key"  # S3 key to save your network parameters i.e. weights under
# Save network to local
sym.save(local_symbol_path)
mod.save_params(local_params_path)
# Upload to S3
import boto3  # also imported at the top of the file; kept so this snippet runs standalone
s3 = boto3.resource('s3')
bucket = s3.Bucket(bucket_name)  # one Bucket handle reused for both uploads
bucket.upload_file(local_symbol_path, s3_symbol_key)
bucket.upload_file(local_params_path, s3_params_key)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment