Skip to content

Instantly share code, notes, and snippets.

@aymericdelab
Created October 5, 2019 16:18
Show Gist options
  • Save aymericdelab/719a4a406283d021e25c6c59c0944137 to your computer and use it in GitHub Desktop.
Save aymericdelab/719a4a406283d021e25c6c59c0944137 to your computer and use it in GitHub Desktop.
What we need to add to an entry-point script so that a model we train on AWS can also be trained and exported on Azure.
def main():
    """Entry point for training the tf.estimator model on Azure ML.

    Parses command-line arguments (mirroring the AWS entry point) so the
    same training script can run on Azure: the data folder arrives as a
    mount point, training runs for ``--steps`` steps, and the trained
    model is exported for later retrieval from Azure storage.
    """
    # ---- argument parsing ----
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-folder', type=str, dest='data_folder',
                        help='data folder mounting point')
    parser.add_argument('--batch-size', type=int, dest='batch_size',
                        default=50, help='mini batch size for training')
    parser.add_argument('--learning-rate', type=float, dest='learning_rate',
                        default=0.001, help='learning rate')
    parser.add_argument('--prefix', type=str, dest='prefix',
                        help='target path when uploading data to Azure storage')
    parser.add_argument('--steps', type=int, dest='steps',
                        help='number of steps')
    args = parser.parse_args()

    print('training dataset is stored here:', args.data_folder)

    # Show the logging hooks during tf.estimator training (TF 1.x API).
    tf.logging.set_verbosity(tf.logging.INFO)

    # Model output directory (checkpoints/events land here).
    output_dir = 'tmp/output/'

    # Hyper-parameters forwarded to model_fn and train_input_fn.
    params = {'learning_rate': args.learning_rate,
              'batch_size': args.batch_size}

    # BUG FIX: the original read `estimator = estimator=tf.estimator.Estimator(...)`
    # — a garbled double assignment; a single assignment is what was intended.
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=output_dir,
                                       params=params)

    # Log the softmax probabilities every 50 training steps.
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log,
                                              every_n_iter=50)

    # Location of the training data inside the mounted data folder.
    training_dir = os.path.join(args.data_folder, args.prefix)

    # Train the estimator.
    # NOTE(review): `train_input_fn(training_dir, params)` is CALLED here, so
    # `input_fn=` receives its return value — this is only correct if
    # train_input_fn returns a zero-argument input function; verify against
    # its definition (not visible in this chunk).
    estimator.train(input_fn=train_input_fn(training_dir, params),
                    steps=args.steps,
                    hooks=[logging_hook])

    # Export the trained model to Azure storage so it can be retrieved in a
    # local notebook for evaluation/deployment.
    estimator.export_saved_model('saved_model', serving_input_fn)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment