Skip to content

Instantly share code, notes, and snippets.

@hanneshapke
Created March 9, 2020 18:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hanneshapke/b91e10de15b9ef3eae10f9928f040b77 to your computer and use it in GitHub Desktop.
Save hanneshapke/b91e10de15b9ef3eae10f9928f040b77 to your computer and use it in GitHub Desktop.
def run_fn(fn_args: TrainerFnArgs, *, batch_size: int = 32):
    """Train the model under a MirroredStrategy and export it as a SavedModel.

    Loads the Transform output, builds training and evaluation datasets with
    ``_input_fn``, and creates the model via ``get_model`` inside a
    ``tf.distribute.MirroredStrategy`` scope. It then fits the model and
    saves it to ``fn_args.serving_model_dir``. The saved model has a
    ``serving_default`` signature that takes a 1-D batch of ``tf.string``
    tensors named ``examples``.

    Args:
        fn_args: Trainer arguments. This function reads ``transform_output``,
            ``train_files``, ``eval_files``, ``train_steps``, ``eval_steps``
            and ``serving_model_dir`` from it.
        batch_size: Passed as the third argument to ``_input_fn`` for both
            the training and the evaluation dataset. Defaults to 32.
            Presumably this is the batch size; confirm against ``_input_fn``.
            It is keyword-only so the existing ``run_fn(fn_args)`` call
            keeps working.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    # Use the same value for both datasets so they cannot drift apart.
    train_dataset = _input_fn(
        fn_args.train_files, tf_transform_output, batch_size)
    eval_dataset = _input_fn(
        fn_args.eval_files, tf_transform_output, batch_size)

    # The model (and thus its variables) is created inside the strategy
    # scope so that MirroredStrategy can mirror the variables across
    # replicas. Keras fit() picks up the strategy from the model, so fit
    # itself does not need to run inside the scope.
    # NOTE(review): fit() is called below without a compile() here, so
    # get_model presumably also compiles the model. Compilation must happen
    # inside this scope as well; confirm.
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = get_model(tf_transform_output=tf_transform_output)

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
    )

    # Export a serving signature that accepts a variable-length 1-D batch
    # (shape=[None]) of strings named 'examples'.
    serve_fn = _get_serve_tf_examples_fn(model, tf_transform_output)
    signatures = {
        'serving_default': serve_fn.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')),
    }
    model.save(
        fn_args.serving_model_dir,
        save_format='tf',
        signatures=signatures,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment