Skip to content

Instantly share code, notes, and snippets.

@jinliangwei
Last active April 12, 2020 02:41
Show Gist options
  • Save jinliangwei/eed8fd564a35deae3b892092f3171866 to your computer and use it in GitHub Desktop.
Save jinliangwei/eed8fd564a35deae3b892092f3171866 to your computer and use it in GitHub Desktop.
diff --git a/tensor2tensor/utils/trainer_lib.py b/tensor2tensor/utils/trainer_lib.py
index 30b00df..8d98c45 100644
--- a/tensor2tensor/utils/trainer_lib.py
+++ b/tensor2tensor/utils/trainer_lib.py
@@ -85,6 +85,8 @@ def create_hparams(hparams_set,
data_dir=None,
problem_name=None):
"""Create HParams with data_dir and problem hparams, if kwargs provided."""
+  tf.logging.info("create_hparams called, hparams_set = %s", hparams_set)
+
hparams = registry.hparams(hparams_set)
if data_dir:
hparams.add_hparam("data_dir", data_dir)
@@ -165,7 +167,6 @@ def create_run_config(master="",
initial_infeed_sleep_secs=tpu_infeed_sleep_secs,
**tpu_config_extra_kwargs)
run_config_args["tpu_config"] = tpu_config
-
config = run_config_cls(**run_config_args)
# If not using TPU, add device info for data_parallelism
@@ -199,6 +200,7 @@ def create_estimator(model_name,
decode_hparams=None,
use_tpu=False):
"""Create a T2T Estimator."""
+ tf.logging.info("create_estimator, model_name = %s" % model_name)
model_fn = t2t_model.T2TModel.make_estimator_model_fn(
model_name, hparams, decode_hparams=decode_hparams, use_tpu=use_tpu)
@@ -269,13 +271,14 @@ class T2TExperiment(object):
"""Custom Experiment class for running distributed experiments."""
def __init__(self, estimator, hparams, train_spec, eval_spec,
- use_validation_monitor, decode_hparams=None):
+ use_validation_monitor, decode_hparams=None, server=None):
self._train_spec = train_spec
self._eval_spec = eval_spec
self._hparams = hparams
self._decode_hparams = decode_hparams
self._estimator = estimator
self._use_validation_monitor = use_validation_monitor
+ self._server = server  # tf.train.Server built in create_experiment, or None
@property
def estimator(self):
@@ -352,6 +355,7 @@ class T2TExperiment(object):
ValueError: if not enough information is available in the estimator's
config to create a server.
"""
+ tf.logging.info("run_std_server called")
config = self._estimator.config
if (not config.cluster_spec or not config.task_type or not config.master or
config.task_id is None):
@@ -378,6 +382,17 @@ class T2TExperiment(object):
self.decode()
+def create_tf_server(config):
+  """Starts and returns a tf.train.Server for this task of config.cluster_spec."""
+  tf.logging.info("task_type=%s, task_id=%s", config.task_type, config.task_id)
+  server = tf.train.Server(
+      config.cluster_spec,
+      job_name=config.task_type,
+      task_index=config.task_id,
+      config=config.tf_config,
+      start=True)
+  return server
+
def create_experiment(
run_config,
hparams,
@@ -408,6 +423,11 @@ def create_experiment(
hparams.add_hparam("schedule", schedule)
add_problem_hparams(hparams, problem_name)
+  server = None  # stays None without a cluster_spec or under run_std_server
+  if (getattr(run_config, "cluster_spec", None) and
+      schedule != "run_std_server"):
+    server = create_tf_server(run_config)
+
# Estimator
estimator = create_estimator(
model_name,
@@ -457,13 +477,14 @@ def create_experiment(
use_early_stopping = (
schedule not in local_schedules and eval_early_stopping_steps)
train_hooks, eval_hooks = create_hooks(
- use_tfdbg=use_tfdbg,
- use_dbgprofile=use_dbgprofile,
- dbgprofile_kwargs=dbgprofile_kwargs,
- use_validation_monitor=use_validation_monitor,
- validation_monitor_kwargs=validation_monitor_kwargs,
- use_early_stopping=use_early_stopping,
- early_stopping_kwargs=early_stopping_kwargs)
+ use_tfdbg=use_tfdbg,
+ use_dbgprofile=use_dbgprofile,
+ dbgprofile_kwargs=dbgprofile_kwargs,
+ use_validation_monitor=use_validation_monitor,
+ validation_monitor_kwargs=validation_monitor_kwargs,
+ use_early_stopping=use_early_stopping,
+ early_stopping_kwargs=early_stopping_kwargs
+ )
train_hooks += t2t_model.T2TModel.get_train_hooks(model_name)
eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name)
@@ -494,7 +515,8 @@ def create_experiment(
eval_delay_secs=0 if schedule == "evaluate" else 120,
**hooks_kwargs if not use_tpu else {})
return T2TExperiment(estimator, hparams, train_spec, eval_spec,
- use_validation_monitor, decode_hparams)
+ use_validation_monitor, decode_hparams,
+ server)
def create_experiment_fn(*args, **kwargs):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment