Skip to content

Instantly share code, notes, and snippets.

@lukmanr
Last active October 29, 2018 17:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lukmanr/d3efc692da61a5f48f856bbde644ce1e to your computer and use it in GitHub Desktop.
TF Model Optimization code 2
def run_experiment(hparams, train_data, train_labels, run_config,
                   create_estimator_fn=create_estimator,
                   eval_data=None, eval_labels=None):
    """Build an estimator and run ``tf.estimator.train_and_evaluate`` on it.

    Args:
      hparams: hyperparameter object; this function reads ``batch_size``,
        ``max_training_steps`` and ``eval_throttle_secs`` from it.
      train_data: array fed to the model as the ``'input_image'`` feature.
      train_labels: labels matching ``train_data``.
      run_config: run configuration passed through to ``create_estimator_fn``.
      create_estimator_fn: callable ``(hparams, run_config) -> estimator``.
      eval_data: optional held-out features used for evaluation. When omitted,
        the training data is reused (the original behaviour). In that case the
        reported eval metrics only measure fit on the training set.
      eval_labels: labels matching ``eval_data``; must be given together with it.

    Returns:
      The estimator after training and evaluation.

    Raises:
      ValueError: if exactly one of ``eval_data`` / ``eval_labels`` is given.
    """
    if (eval_data is None) != (eval_labels is None):
        raise ValueError('eval_data and eval_labels must be provided together')
    if eval_data is None:
        # Backward-compatible fallback: evaluate on the training set.
        eval_data, eval_labels = train_data, train_labels

    train_spec = tf.estimator.TrainSpec(
        input_fn=tf.estimator.inputs.numpy_input_fn(
            x={'input_image': train_data},
            y=train_labels,
            batch_size=hparams.batch_size,
            num_epochs=None,  # cycle indefinitely; max_steps bounds training
            shuffle=True),
        max_steps=hparams.max_training_steps,
    )
    eval_spec = tf.estimator.EvalSpec(
        input_fn=tf.estimator.inputs.numpy_input_fn(
            x={'input_image': eval_data},
            y=eval_labels,
            batch_size=hparams.batch_size,
            num_epochs=1,  # a single ordered pass over the evaluation data
            shuffle=False),
        steps=None,  # evaluate until the input_fn is exhausted
        throttle_secs=hparams.eval_throttle_secs,
    )

    tf.logging.set_verbosity(tf.logging.INFO)

    time_start = datetime.utcnow()
    print('Experiment started at {}'.format(time_start.strftime('%H:%M:%S')))
    print('.......................................')

    estimator = create_estimator_fn(hparams, run_config)
    tf.estimator.train_and_evaluate(
        estimator=estimator,
        train_spec=train_spec,
        eval_spec=eval_spec,
    )

    time_end = datetime.utcnow()
    print('.......................................')
    print('Experiment finished at {}'.format(time_end.strftime('%H:%M:%S')))
    print('')
    time_elapsed = time_end - time_start
    print('Experiment elapsed time: {} seconds'.format(time_elapsed.total_seconds()))
    return estimator
def train_and_export_model(train_data, train_labels):
    """Train a fresh Keras-based estimator and export it as a SavedModel.

    Anything already under ``os.path.join(MODELS_LOCATION, MODEL_NAME)`` is
    deleted first, so every call trains from scratch.

    Args:
      train_data: training features, fed as the ``'input_image'`` feature.
        NOTE(review): presumably shaped like the serving placeholder below
        (``[N, 28, 28]`` float32) — confirm against the caller.
      train_labels: labels matching ``train_data``.

    Returns:
      The export base directory ``<model_dir>/export``. ``export_savedmodel``
      writes the actual SavedModel into a subdirectory it creates there.
    """
    model_dir = os.path.join(MODELS_LOCATION, MODEL_NAME)

    hparams = tf.contrib.training.HParams(
        batch_size=100,
        hidden_units=[1024],
        num_conv_layers=2,
        init_filters=64,
        # NOTE(review): whether this is a drop rate or a keep probability
        # depends on create_estimator_keras (not visible here) — confirm.
        dropout=0.85,
        max_training_steps=50,
        eval_throttle_secs=10,
        learning_rate=1e-3,
        debug=True,
    )

    run_config = tf.estimator.RunConfig(
        tf_random_seed=19830610,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=3,
        model_dir=model_dir,
    )

    # Wipe artifacts from earlier runs so training starts from scratch.
    if tf.gfile.Exists(model_dir):
        print('Removing previous artifacts...')
        tf.gfile.DeleteRecursively(model_dir)
    # Use tf.gfile (not os.makedirs) to match the Exists/DeleteRecursively calls
    # above. This keeps non-local model locations such as gs:// paths working.
    tf.gfile.MakeDirs(model_dir)

    estimator = run_experiment(hparams, train_data, train_labels, run_config,
                               create_estimator_keras)

    def make_serving_input_receiver_fn():
        # Raw serving input: clients send float32 images of shape [None, 28, 28]
        # directly, keyed by the same 'input_image' feature name used in training.
        inputs = {'input_image': tf.placeholder(
            shape=[None, 28, 28], dtype=tf.float32, name='serving_input_image')}
        return tf.estimator.export.build_raw_serving_input_receiver_fn(inputs)

    export_dir = os.path.join(model_dir, 'export')
    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)
    estimator.export_savedmodel(
        export_dir_base=export_dir,
        serving_input_receiver_fn=make_serving_input_receiver_fn(),
    )
    return export_dir
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment