Skip to content

Instantly share code, notes, and snippets.

@zmjjmz
Created September 11, 2018 23:28
Show Gist options
  • Save zmjjmz/ba033696f7d84df119725ff187817b22 to your computer and use it in GitHub Desktop.
Save zmjjmz/ba033696f7d84df119725ff187817b22 to your computer and use it in GitHub Desktop.
Sagemaker Multi-input repro
import os
import json
import numpy
import tensorflow
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
# Report which TensorFlow build this script is running against.
print("Tensorflow version: {0}".format(tensorflow.VERSION))
# NOTE(review): presumably meant to silence TensorFlow's native (C++) logging;
# it is set after `tensorflow` has already been imported -- confirm it still
# takes effect at this point.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Number of synthetic examples generated below.
DATA_SIZE = 1024
# Batch size used by the Estimator input function.
BATCH_SIZE = 32
# Number of passes over the synthetic data (also used to derive training steps).
N_EPOCHS = 1
# Width of each of the two input feature vectors.
EMBED_DIM = 5
# Random float32 features, keyed by the names of the model's two Input layers.
DATA_DICT = {
'input1': numpy.random.rand(DATA_SIZE, EMBED_DIM).astype('float32'),
'input2': numpy.random.rand(DATA_SIZE, EMBED_DIM).astype('float32'),
}
# Random binary class ids expanded by to_categorical into 2-column labels.
LABELS = tensorflow.keras.utils.to_categorical(
numpy.random.randint(2, size=(DATA_SIZE,)), num_classes=2)
def get_input_fn():
    """Build an Estimator input function over the in-memory synthetic data.

    Returns:
        The callable produced by ``tensorflow.estimator.inputs.numpy_input_fn``
        for ``DATA_DICT`` and ``LABELS``, with shuffling enabled, batches of
        ``BATCH_SIZE`` and ``N_EPOCHS`` epochs.
    """
    feed_options = {
        'x': DATA_DICT,
        'y': LABELS,
        'shuffle': True,
        'batch_size': BATCH_SIZE,
        'num_epochs': N_EPOCHS,
    }
    return tensorflow.estimator.inputs.numpy_input_fn(**feed_options)
def train_input_fn(training_dir, params):
    """Return the training input built from the module-level synthetic data.

    Both ``training_dir`` and ``params`` are accepted but not used; the data
    comes from ``get_input_fn``.
    """
    make_batches = get_input_fn()
    return make_batches()
def eval_input_fn(training_dir, params):
    """Return the evaluation input built from the module-level synthetic data.

    Both ``training_dir`` and ``params`` are accepted but not used; evaluation
    draws from the same source as training via ``get_input_fn``.
    """
    make_batches = get_input_fn()
    return make_batches()
def serving_input_fn(params):
    """Build the raw serving input receiver for the model's two inputs.

    ``params`` is accepted but not used.

    Returns:
        The result of calling the function created by
        ``build_raw_serving_input_receiver_fn`` for float32 placeholders of
        shape ``[None, EMBED_DIM]`` named ``input1`` and ``input2``.
    """
    # The feature names are hard-coded here: annoyingly, the model cannot be
    # handed to this function to pull the names from its input layers.
    feature_placeholders = {}
    for feature_name in ('input1', 'input2'):
        feature_placeholders[feature_name] = tensorflow.placeholder(
            tensorflow.float32, shape=[None, EMBED_DIM])
    receiver_fn = build_raw_serving_input_receiver_fn(feature_placeholders)
    return receiver_fn()
def input_fn(serialized_input, content_type):
    """Deserialize a JSON request body into TensorProto(s).

    Args:
        serialized_input: The raw request payload (JSON text).
        content_type: MIME type of the payload; only ``"application/json"``
            is supported.

    Returns:
        A dict mapping each key to a TensorProto when the decoded JSON is an
        object; otherwise a single TensorProto built from the decoded value.

    Raises:
        ValueError: If ``content_type`` is not ``"application/json"``.
    """
    if content_type != "application/json":
        # Fail loudly rather than silently handing ``None`` to the caller.
        raise ValueError(
            "Unsupported content type: {0!r}".format(content_type))

    deserialized_input = json.loads(serialized_input)
    if isinstance(deserialized_input, dict):
        return {
            key: tensorflow.make_tensor_proto(value)
            for key, value in deserialized_input.items()
        }
    # Convert the decoded value, not the raw JSON text.
    return tensorflow.make_tensor_proto(deserialized_input)
def keras_model_fn(hyperparameters):
    """Build and compile the two-input Keras classifier.

    The inputs ``input1`` and ``input2`` (each of width ``EMBED_DIM``) are
    multiplied element-wise and fed to a 2-unit softmax layer named
    ``classes``. ``hyperparameters`` is accepted but not used.

    Returns:
        The compiled ``tensorflow.keras`` model (SGD optimizer, categorical
        cross-entropy loss, accuracy metric).
    """
    layers = tensorflow.keras.layers

    first_input, second_input = (
        layers.Input(shape=(EMBED_DIM,), name=input_name)
        for input_name in ('input1', 'input2')
    )
    product = layers.Multiply(name='multiply')([first_input, second_input])

    classifier = layers.Dense(
        2, input_shape=(EMBED_DIM,), activation='softmax', name='classes')
    class_probs = classifier(product)

    net = tensorflow.keras.models.Model(
        inputs=[first_input, second_input], outputs=[class_probs])
    net.compile(
        optimizer='sgd',
        loss={'classes': 'categorical_crossentropy'},
        metrics={'classes': 'accuracy'},
    )
    return net
if __name__ == "__main__":
    # Driver: train this file as a SageMaker TensorFlow entry point in local
    # mode, deploy it to a local endpoint, send one two-input prediction
    # request and print the response.
    #
    # These imports stay inside the script path so that importing this module
    # for its model/input functions does not pull them in.
    import sagemaker
    from sagemaker.tensorflow import TensorFlow
    import boto3

    boto_session = boto3.Session(profile_name='sagemaker')
    role = sagemaker.Session(
        boto_session=boto_session).get_caller_identity_arn()

    tf_estimator = TensorFlow(
        entry_point=__file__, role=role,
        # Floor division keeps the step count an integer on Python 3 as well.
        training_steps=N_EPOCHS * (DATA_SIZE // BATCH_SIZE),
        evaluation_steps=1, train_instance_count=1,
        train_instance_type='local',
        # seems to be broken for 1.10
        framework_version='1.9.0',
        checkpoint_path='s3://sagemaker-omc-test/test/sagemaker_multiin/checkpoints',
        base_job_name='sagemaker-multiin-test')
    tf_estimator.fit('s3://sagemaker-omc-test/test/sagemaker_multiin/data')

    predictor = tf_estimator.deploy(
        initial_instance_count=1, instance_type='local',
        endpoint_name='sagemaker_multiin_test')
    try:
        prediction = predictor.predict({
            'input1': numpy.random.rand(1, EMBED_DIM).astype('float32').tolist(),
            'input2': numpy.random.rand(1, EMBED_DIM).astype('float32').tolist(),
        })
    finally:
        # Always tear the endpoint down, even if the prediction call fails.
        predictor.delete_endpoint()
    print(prediction)
@zmjjmz
Copy link
Author

zmjjmz commented Sep 11, 2018

Obviously you'll have to change the S3 bucket and roles to your own if you want to reproduce this. If you remove the `input_fn`, you'll get an error in serving (`ERROR in serving: Unsupported request data format:`), but it will work if you include it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment