Skip to content

Instantly share code, notes, and snippets.

@vlasenkoalexey
Last active March 13, 2020 18:13
Show Gist options
  • Save vlasenkoalexey/323fc431a024fd1c42496d78cbe4d6a5 to your computer and use it in GitHub Desktop.
Save vlasenkoalexey/323fc431a024fd1c42496d78cbe4d6a5 to your computer and use it in GitHub Desktop.
def train_estimator_linear(model_dir, learning_rate=0.001):
    """Train and evaluate a ``tf.estimator.LinearClassifier``.

    Builds the classifier from ``create_feature_columns()``, then runs
    ``tf.estimator.train_and_evaluate``:

    - training uses ``get_dataset('train')`` for ``get_max_steps()`` steps;
    - evaluation uses ``get_dataset('test')`` with ``EvalSpec``'s default settings.

    Args:
        model_dir: Directory passed to the estimator as its ``model_dir``.
            When the global ``ARGS.profiler`` flag is truthy, a
            ``ProfilerHook`` writes its output to ``<model_dir>/profiler``.
        learning_rate: Learning rate for the ``GradientDescentOptimizer``.
            Defaults to 0.001, the value previously hard-coded here.

    Returns:
        None.
    """
    # ARGS is only read in this function, so no ``global`` declaration is needed.
    # Compute the step budget once so the logged value and the value used by
    # TrainSpec are guaranteed to agree.
    max_steps = get_max_steps()
    logging.info('training for %s steps', max_steps)

    # Write summaries every 10 steps rather than RunConfig's default interval.
    config = tf.estimator.RunConfig(save_summary_steps=10)

    hooks = []
    if ARGS.profiler:
        # save_steps comes from get_training_steps_per_epoch(), presumably one
        # profile per epoch -- confirm against that helper.
        hooks.append(tf.estimator.ProfilerHook(
            save_steps=get_training_steps_per_epoch(),
            output_dir=os.path.join(model_dir, "profiler"),
            show_dataflow=True,
            show_memory=True))

    estimator = tf.estimator.LinearClassifier(
        feature_columns=create_feature_columns(),
        optimizer=GradientDescentOptimizer(learning_rate=learning_rate),
        model_dir=model_dir,
        config=config,
    )

    logging.info('training and evaluating linear estimator model')
    tf.estimator.train_and_evaluate(
        estimator,
        # input_fn must be a callable, so each dataset factory call is wrapped
        # in a zero-argument lambda.
        train_spec=tf.estimator.TrainSpec(
            input_fn=lambda: get_dataset('train'),
            max_steps=max_steps,
            hooks=hooks),
        eval_spec=tf.estimator.EvalSpec(input_fn=lambda: get_dataset('test')))
    logging.info('done evaluating estimator model')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment