Created
September 11, 2018 23:28
-
-
Save zmjjmz/ba033696f7d84df119725ff187817b22 to your computer and use it in GitHub Desktop.
Sagemaker Multi-input repro
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
import numpy | |
import tensorflow | |
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn | |
# Report which TensorFlow build is running this script.
print("Tensorflow version: {0}".format(tensorflow.VERSION))

# Quiet TensorFlow's native (C++) logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Size of the synthetic dataset and the training schedule.
DATA_SIZE = 1024
BATCH_SIZE = 32
N_EPOCHS = 1
EMBED_DIM = 5

# Two random float32 feature matrices of shape (DATA_SIZE, EMBED_DIM), keyed
# by the names of the model's input layers.
DATA_DICT = {
    feature_name: numpy.random.rand(DATA_SIZE, EMBED_DIM).astype('float32')
    for feature_name in ('input1', 'input2')
}

# Random binary class ids, one-hot encoded into two columns.
LABELS = tensorflow.keras.utils.to_categorical(
    numpy.random.randint(2, size=(DATA_SIZE,)),
    num_classes=2,
)
def get_input_fn():
    """Build a numpy-backed estimator input function over the synthetic data.

    Returns:
        The callable produced by ``numpy_input_fn`` for ``DATA_DICT`` and
        ``LABELS``, shuffled, batched by ``BATCH_SIZE`` and repeated for
        ``N_EPOCHS`` epochs.
    """
    numpy_input_options = dict(
        x=DATA_DICT,
        y=LABELS,
        batch_size=BATCH_SIZE,
        num_epochs=N_EPOCHS,
        shuffle=True,
    )
    return tensorflow.estimator.inputs.numpy_input_fn(**numpy_input_options)
def train_input_fn(training_dir, params):
    """Build a fresh input function and return the result of calling it.

    ``training_dir`` and ``params`` are accepted but not used; the data comes
    from the in-memory arrays behind ``get_input_fn``.
    """
    build_batches = get_input_fn()
    return build_batches()
def eval_input_fn(training_dir, params):
    """Build a fresh input function and return the result of calling it.

    ``training_dir`` and ``params`` are accepted but not used; evaluation
    reads the same in-memory arrays as training via ``get_input_fn``.
    """
    build_batches = get_input_fn()
    return build_batches()
def serving_input_fn(params):
    """Return the raw serving input receiver for the two model inputs.

    ``params`` is accepted but not used.
    """
    # annoyingly I can't really give this the model and have it pull the name
    # from the input layer
    feature_placeholders = {}
    for feature_name in ('input1', 'input2'):
        # One float32 placeholder per input, with an unconstrained batch size.
        feature_placeholders[feature_name] = tensorflow.placeholder(
            tensorflow.float32, shape=[None, EMBED_DIM])

    receiver_fn = build_raw_serving_input_receiver_fn(feature_placeholders)
    return receiver_fn()
def input_fn(serialized_input, content_type):
    """Deserialize a JSON request payload into TensorProto(s).

    Args:
        serialized_input: The serialized request payload (a JSON document).
        content_type: MIME type of the payload; only "application/json" is
            handled.

    Returns:
        A dict mapping each top-level key to a TensorProto when the decoded
        payload is a JSON object; otherwise a single TensorProto built from
        the decoded payload.

    Raises:
        ValueError: If ``content_type`` is not "application/json".
    """
    if content_type != "application/json":
        # Fail loudly rather than silently handing None to the caller.
        raise ValueError(
            "Unsupported content type: {0!r}".format(content_type))

    deserialized_input = json.loads(serialized_input)
    if isinstance(deserialized_input, dict):
        return {
            name: tensorflow.make_tensor_proto(value)
            for name, value in deserialized_input.items()
        }
    # Convert the decoded value, not the raw JSON text, so that a JSON array
    # becomes a numeric tensor instead of a string tensor.
    return tensorflow.make_tensor_proto(deserialized_input)
def keras_model_fn(hyperparameters):
    """Build and compile the two-input Keras classifier.

    The inputs 'input1' and 'input2' are multiplied element-wise and fed to a
    two-unit softmax layer named 'classes'. ``hyperparameters`` is accepted
    but not used.

    Returns:
        The compiled Keras model.
    """
    layers = tensorflow.keras.layers

    first_input = layers.Input(shape=(EMBED_DIM,), name='input1')
    second_input = layers.Input(shape=(EMBED_DIM,), name='input2')

    product = layers.Multiply(name='multiply')([first_input, second_input])

    classifier = layers.Dense(
        2,
        input_shape=(EMBED_DIM,),
        activation='softmax',
        name='classes',
    )
    class_probabilities = classifier(product)

    model = tensorflow.keras.models.Model(
        inputs=[first_input, second_input],
        outputs=[class_probabilities],
    )
    # Loss and metric are keyed by the output layer's name.
    model.compile(
        optimizer='sgd',
        loss={'classes': 'categorical_crossentropy'},
        metrics={'classes': 'accuracy'},
    )
    return model
if __name__ == "__main__":
    # do the sagemaker thang: train in local mode, deploy a local endpoint,
    # send one prediction request, then tear the endpoint down.
    import sagemaker
    from sagemaker.tensorflow import TensorFlow
    import boto3

    role = sagemaker.Session(boto_session=boto3.Session(
        profile_name='sagemaker')).get_caller_identity_arn()
    tf_estimator = TensorFlow(
        entry_point=__file__,
        role=role,
        # Floor division keeps the step count an integer on Python 3 too.
        training_steps=N_EPOCHS * (DATA_SIZE // BATCH_SIZE),
        evaluation_steps=1,
        train_instance_count=1,
        train_instance_type='local',
        # seems to be broken for 1.10
        framework_version='1.9.0',
        checkpoint_path='s3://sagemaker-omc-test/test/sagemaker_multiin/checkpoints',
        base_job_name='sagemaker-multiin-test',
    )
    tf_estimator.fit('s3://sagemaker-omc-test/test/sagemaker_multiin/data')

    predictor = tf_estimator.deploy(
        initial_instance_count=1,
        instance_type='local',
        endpoint_name='sagemaker_multiin_test',
    )
    try:
        prediction = predictor.predict({
            'input1': numpy.random.rand(1, EMBED_DIM).astype('float32').tolist(),
            'input2': numpy.random.rand(1, EMBED_DIM).astype('float32').tolist(),
        })
    finally:
        # Always remove the endpoint, even when the prediction request fails.
        predictor.delete_endpoint()
    print(prediction)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Obviously you'll have to change the S3 bucket and roles to be your own if you want to reproduce this. If you remove the `input_fn`, you'll get an `ERROR in serving: Unsupported request data format:` error, but it will work if you include it.