Created
October 5, 2019 16:18
-
-
Save aymericdelab/719a4a406283d021e25c6c59c0944137 to your computer and use it in GitHub Desktop.
What we need to add to an entry-point script so that the model we run on AWS can also run on Azure
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _parse_args():
    """Parse the command-line arguments passed to the training job.

    ``--data-folder`` and ``--prefix`` are required because ``main`` joins
    them with ``os.path.join``, which raises ``TypeError`` when either is
    ``None``. Declaring them required turns that late crash into an
    immediate argparse usage error; invocations that already worked are
    unaffected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-folder', type=str, dest='data_folder', required=True,
                        help='data folder mounting point')
    parser.add_argument('--batch-size', type=int, dest='batch_size', default=50,
                        help='mini batch size for training')
    parser.add_argument('--learning-rate', type=float, dest='learning_rate', default=0.001,
                        help='learning rate')
    parser.add_argument('--prefix', type=str, dest='prefix', required=True,
                        help='target path when uploading data to Azure storage')
    # No default: steps=None is forwarded to Estimator.train, which then
    # trains until the input function reports end of input.
    parser.add_argument('--steps', type=int, dest='steps',
                        help='number of steps')
    return parser.parse_args()


def main():
    """Train the tf.estimator model and export it as a SavedModel.

    The steps are:

    * Read hyperparameters (batch size, learning rate, steps) and data
      locations from the command line.
    * Train a ``tf.estimator.Estimator`` built from ``model_fn``, logging the
      ``softmax_tensor`` op every 50 iterations.
    * Export the trained model to the ``saved_model`` directory via
      ``serving_input_fn``.

    ``model_fn``, ``train_input_fn`` and ``serving_input_fn`` are defined
    elsewhere in the script.
    """
    args = _parse_args()
    data_folder = args.data_folder
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    prefix = args.prefix
    steps = args.steps
    print('training dataset is stored here:', data_folder)

    # Show the logging hooks during the tf.estimator training.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Checkpoints/events are written here (relative to the working directory).
    # NOTE(review): the original intent was for this to end up in Azure blob
    # storage; confirm the run configuration actually uploads this directory.
    output_dir = 'tmp/output/'

    # Hyperparameters handed to model_fn (and to train_input_fn below).
    params = {'learning_rate': learning_rate,
              'batch_size': batch_size}

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=output_dir,
                                       params=params)

    # Print the named tensor's value periodically during training.
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(
        tensors=tensors_to_log,
        every_n_iter=50)

    # Training data lives under <data_folder>/<prefix>.
    training_dir = os.path.join(data_folder, prefix)

    # NOTE(review): Estimator.train expects ``input_fn`` to be a *callable*.
    # This line passes the *result* of train_input_fn(...), which only works if
    # train_input_fn returns a function. If it returns a dataset or a
    # (features, labels) pair, wrap the call instead:
    # ``input_fn=lambda: train_input_fn(training_dir, params)``.
    # Confirm against the definition of train_input_fn.
    estimator.train(
        input_fn=train_input_fn(training_dir, params),
        steps=steps,
        hooks=[logging_hook])

    # Export the trained model so it can be retrieved later (e.g. from a local
    # notebook) for evaluation/deployment. The path is relative to the working
    # directory; whether it reaches Azure storage depends on the run setup.
    estimator.export_saved_model('saved_model', serving_input_fn)
if __name__ == "__main__":
    # Start training only when the file is executed as a script (e.g. as the
    # cloud job's entry point), not when it is imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment