Skip to content

Instantly share code, notes, and snippets.

@adiell
Created March 25, 2019 09:51
Show Gist options
  • Save adiell/0d20212a20ae3cabe38fadf25ebddd1b to your computer and use it in GitHub Desktop.
Save adiell/0d20212a20ae3cabe38fadf25ebddd1b to your computer and use it in GitHub Desktop.
SageMaker example
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.optimizers import Adam
import numpy as np
import argparse
import os
import json
import logging
import tensorflow as tf
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model.signature_def_utils import predict_signature_def
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.keras import backend as K
# Configure the root logger so that INFO-level and higher records are emitted.
logging.basicConfig(level=logging.INFO)
def parse_args(argv=None):
    """Parse the training script's command-line options.

    Every option falls back to an environment variable when its flag is
    not given on the command line:

    * ``--train-data-dir``  -> ``SM_CHANNEL_TRAINING``
    * ``--output-data-dir`` -> ``SM_OUTPUT_DATA_DIR``
    * ``--model-data-dir``  -> ``SM_MODEL_DIR``
    * ``--training-env``    -> ``SM_TRAINING_ENV``

    Unrecognised arguments are ignored rather than rejected.

    The parsed namespace is stored in the module-level ``args`` (existing
    callers read it from there) and is also returned, so new callers do not
    have to depend on the global.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``train_data_dir``, ``output_data_dir``,
        ``model_data_dir`` and ``training_env``. A field is ``None`` when
        neither the flag nor its environment variable is set.
    """
    global args
    env_backed_options = (
        ('--train-data-dir', 'SM_CHANNEL_TRAINING'),
        ('--output-data-dir', 'SM_OUTPUT_DATA_DIR'),
        ('--model-data-dir', 'SM_MODEL_DIR'),
        ('--training-env', 'SM_TRAINING_ENV'),
    )
    parser = argparse.ArgumentParser()
    for flag, env_var in env_backed_options:
        parser.add_argument(flag, type=str, default=os.environ.get(env_var))
    # parse_known_args tolerates extra arguments this script does not declare.
    args, _ = parser.parse_known_args(argv)
    return args
def define_model():
    """Build and compile the small regression network used by this example.

    The network is a ``Sequential`` stack of ``Dense(10)`` taking 2 input
    features (layer name ``input_layer``) followed by ``Dense(1)`` (layer
    name ``output_layer``). No activation argument is passed to either
    layer. It is compiled with a default-constructed ``Adam`` optimizer and
    mean-squared-error loss, and its summary is printed.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential()
    model.add(Dense(10, input_shape=(2,), name='input_layer'))
    model.add(Dense(1, name='output_layer'))
    model.compile(optimizer=Adam(), loss='mse')
    # summary() prints the table itself and returns None; wrapping it in
    # print() additionally emitted a stray "None" line.
    model.summary()
    return model
def build_model(X, y, batch_size=1, epochs=10):
    """Create the model from ``define_model()`` and fit it on ``X`` / ``y``.

    Args:
        X: Training inputs, passed straight to ``model.fit``.
        y: Training targets, passed straight to ``model.fit``.
        batch_size: Batch size for ``model.fit``. Defaults to 1, the value
            that was previously hard-coded.
        epochs: Number of epochs for ``model.fit``. Defaults to 10, the
            value that was previously hard-coded.

    Returns:
        The fitted model.
    """
    model = define_model()
    model.fit(
        X, y,
        batch_size=batch_size,
        epochs=epochs,
        verbose=2,
    )
    return model
def generate_data(seed = 23):
    """Create a reproducible synthetic regression data set.

    Draws 300 two-feature rows from ``np.random.randn`` and derives the
    target as ``2 * feature_0 + feature_1``.

    Args:
        seed: Value passed to ``np.random.seed`` before sampling.

    Returns:
        Tuple ``(X, y)`` where ``X`` has shape ``(300, 2)`` and ``y`` has
        shape ``(300,)``.
    """
    # This seeds NumPy's global generator, so later np.random calls made by
    # the caller are affected as well.
    np.random.seed(seed)
    features = np.random.randn(300, 2)
    first_column, second_column = features[:, 0], features[:, 1]
    targets = first_column * 2 + second_column
    return features, targets
def save_model_using_simple_save(model, export_path):
    """Export ``model`` as a SavedModel via ``tf.saved_model.simple_save``.

    The exported signature maps the key ``'input'`` to ``model.input`` and
    the key ``'output'`` to ``model.output``.

    Args:
        model: Keras model whose ``input`` / ``output`` tensors are exported.
        export_path: Destination directory (string) for the SavedModel.
    """
    print("Saving using simple save to " + export_path)
    sess = tf.keras.backend.get_session()
    # Do not use `with sess:` here. Leaving a Session context manager closes
    # the session, and this is the session Keras shares with the model, so
    # any later predict/fit/save on the model would fail. Entering only the
    # graph context keeps the default graph without closing anything.
    with sess.graph.as_default():
        tf.saved_model.simple_save(
            sess,
            export_path,
            inputs={'input': model.input},
            outputs={'output': model.output})
def save_model_using_builder(model, export_path):
    """Export ``model`` as a SavedModel using ``SavedModelBuilder``.

    A predict signature is registered under ``"serving_default"`` with the
    input key ``"inputs"`` (``model.input``) and the output key ``"score"``
    (``model.output``). The meta graph is tagged with
    ``tag_constants.SERVING``.

    Args:
        model: Keras model whose ``input`` / ``output`` tensors are exported.
        export_path: Destination directory (string) for the SavedModel.
    """
    print("Saving using builder to " + export_path)
    saved_model_builder = builder.SavedModelBuilder(export_path)
    signature = predict_signature_def(
        inputs={"inputs": model.input}, outputs={"score": model.output})
    sess = K.get_session()
    # Do not use `with sess:` here. Leaving a Session context manager closes
    # the session, and this is the session Keras shares with the model, so
    # any later use of the model would fail. Entering only the graph context
    # keeps the default graph without closing anything.
    with sess.graph.as_default():
        # Save the meta graph and variables
        saved_model_builder.add_meta_graph_and_variables(
            sess=sess,
            tags=[tag_constants.SERVING],
            signature_def_map={"serving_default": signature})
    saved_model_builder.save()
if __name__ == '__main__':
    # Training entry point: parse options, fit on synthetic data, and export
    # the model under <model_data_dir>/export/Servo/1.
    parse_args()

    # --training-env defaults to the SM_TRAINING_ENV environment variable and
    # is None when that variable is unset; json.loads(None) would raise
    # TypeError, so fall back to an empty configuration in that case.
    training_env = json.loads(args.training_env) if args.training_env else {}
    hyperparameters = training_env.get('hyperparameters', {})
    logging.info('Hyperparameters: %s', hyperparameters)

    X, y = generate_data()
    model = build_model(X, y)

    export_path = os.path.join(args.model_data_dir, 'export', 'Servo', '1')
    save_model_using_builder(model, export_path)
import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker.transformer import Transformer
import pandas as pd
# Placeholder value passed as the estimator's `role`; replace it before
# running.
role = 'XXX'

if __name__ == '__main__':
    # Launcher script: train main_script.py, try it on a real-time endpoint,
    # then run two batch-transform jobs.
    estimator = TensorFlow(
        entry_point='main_script.py',
        role=role,
        py_version='py3',
        framework_version='1.12.0',
        train_instance_count=1,
        train_instance_type='ml.m4.xlarge',
        base_job_name='demo',
    )
    estimator.fit(wait=True)

    predictor = estimator.deploy(
        initial_instance_count=1,
        instance_type='ml.c4.xlarge',
        endpoint_name='my-demo',
    )
    try:
        preds = predictor.predict([[1, 1], [2, 1], [3, 1]])
        print(preds)
    finally:
        # Always tear the endpoint down, even when predict() raises, so a
        # failed request does not leave it running.
        predictor.delete_endpoint()

    # NOTE(review): this relies on the private `_current_job_name` attribute
    # and presumably on deploy() having registered a model under that name;
    # confirm against the installed sagemaker SDK version.
    transformer = Transformer(
        base_transform_job_name='Batch-Transform-demo',
        model_name=estimator._current_job_name,
        instance_count=1,
        instance_type='ml.c4.xlarge',
    )

    # Two header-less, index-less CSV inputs with identical rows (0.5, 0.5):
    # one of 100 rows and one of 1,000,000 rows.
    pd.DataFrame({'x1': 100 * [0.5], 'x2': 100 * [0.5]}).to_csv(
        'small_file.csv', header=False, index=False)
    pd.DataFrame({'x1': 1000000 * [0.5], 'x2': 1000000 * [0.5]}).to_csv(
        'large_file.csv', header=False, index=False)

    session = sagemaker.Session()
    small_file = session.upload_data(path='small_file.csv', key_prefix='transformer_demo')
    large_file = session.upload_data(path='large_file.csv', key_prefix='transformer_demo')

    transformer.transform(small_file, content_type='text/csv', split_type='Line')
    transformer.transform(large_file, content_type='text/csv', split_type='Line')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment