Skip to content

Instantly share code, notes, and snippets.

@azarnyx
Last active May 10, 2021 15:39
Show Gist options
  • Save azarnyx/7bda05dc2e95a3c47b7d353e85ed259c to your computer and use it in GitHub Desktop.
Save azarnyx/7bda05dc2e95a3c47b7d353e85ed259c to your computer and use it in GitHub Desktop.
# The only media type the endpoint's request/response handlers accept.
JSON_CONTENT_TYPE: str = "application/json"
def predict_fn(input, model):
    """Score one deserialized request with the loaded classifier.

    Args:
        input: feature matrix to score; only the first row's probabilities
            are reported.
        model: fitted estimator exposing ``predict_proba``.

    Returns:
        A JSON string of the form ``{"proba": "[p0, p1, ...]"}``. The value is
        the *string* form of the first row's class probabilities (kept as a
        string for compatibility with existing clients).
    """
    # Convert NumPy scalars to built-in floats before stringifying: under
    # NumPy >= 2, str() of a list of NumPy scalars renders "np.float64(0.25)"
    # instead of "0.25", which corrupts the string clients parse.
    first_row = [float(p) for p in model.predict_proba(input)[0]]
    return json.dumps({"proba": str(first_row)})
def model_fn(model_dir):
    """Load the serialized classifier from ``model_dir``.

    Args:
        model_dir: directory containing the ``sklearnclf.joblib`` artifact.

    Returns:
        Whatever ``load`` returns for that file (presumably a fitted
        scikit-learn estimator via ``joblib.load`` — confirm the import).
    """
    artifact_path = os.path.join(model_dir, 'sklearnclf.joblib')
    return load(artifact_path)
def input_fn(request_body, content_type=JSON_CONTENT_TYPE):
    """Deserialize a request body into a single-row feature matrix.

    Args:
        request_body: a JSON document (``json.loads`` accepts str or bytes)
            whose top-level object has a ``"text"`` field.
        content_type: request media type. Parameters such as
            ``; charset=utf-8`` are ignored when matching, so
            ``application/json; charset=utf-8`` is accepted.

    Returns:
        ``get_embedding(text).reshape(1, -1)``, i.e. one row of features.

    Raises:
        ValueError: if the media type is not JSON. ValueError subclasses
            ``Exception``, so callers catching ``Exception`` still work.
        KeyError: if the decoded object has no ``"text"`` key.
    """
    logger.info('Deserializing the input data.')
    # Compare only the media type, case-insensitively, ignoring parameters.
    media_type = (content_type or '').split(';', 1)[0].strip().lower()
    if media_type != JSON_CONTENT_TYPE:
        raise ValueError('Requested unsupported ContentType in content_type: {}'.format(content_type))
    # The body is a single JSON object (not JSON Lines).
    payload = json.loads(request_body)
    text = payload["text"]
    # assumes get_embedding returns a NumPy array (it must support .reshape)
    # — TODO confirm.
    return get_embedding(text).reshape(1, -1)
def output_fn(prediction, accept=JSON_CONTENT_TYPE):
    """Serialize a prediction into a response body.

    Args:
        prediction: the value produced by the prediction step.
        accept: requested response media type. Parameters such as
            ``; charset=utf-8`` are ignored when matching.

    Returns:
        A ``(body, accept)`` tuple, where ``body`` is ``json.dumps(prediction)``
        and ``accept`` is echoed back unchanged.

    Raises:
        ValueError: if the requested media type is not JSON. ValueError
            subclasses ``Exception``, so callers catching ``Exception`` still
            work.
    """
    logger.info('Serializing the generated output.')
    # Compare only the media type, case-insensitively, ignoring parameters.
    media_type = (accept or '').split(';', 1)[0].strip().lower()
    if media_type != JSON_CONTENT_TYPE:
        raise ValueError('Requested unsupported ContentType in Accept: {}'.format(accept))
    # NOTE(review): predict_fn in this file already returns a JSON string, so
    # this json.dumps yields a JSON-encoded string literal (double encoding);
    # clients must json-decode twice. Kept as-is for wire compatibility —
    # confirm with consumers before returning the payload directly.
    return json.dumps(prediction), accept
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment