Last active
April 12, 2020 02:41
-
-
Save jinliangwei/eed8fd564a35deae3b892092f3171866 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/tensor2tensor/utils/trainer_lib.py b/tensor2tensor/utils/trainer_lib.py | |
index 30b00df..8d98c45 100644 | |
--- a/tensor2tensor/utils/trainer_lib.py | |
+++ b/tensor2tensor/utils/trainer_lib.py | |
@@ -85,6 +85,8 @@ def create_hparams(hparams_set, | |
data_dir=None, | |
problem_name=None): | |
"""Create HParams with data_dir and problem hparams, if kwargs provided.""" | |
+ print("create_hparams called, hparams_set = ", hparams_set) | |
+ | |
hparams = registry.hparams(hparams_set) | |
if data_dir: | |
hparams.add_hparam("data_dir", data_dir) | |
@@ -165,7 +167,6 @@ def create_run_config(master="", | |
initial_infeed_sleep_secs=tpu_infeed_sleep_secs, | |
**tpu_config_extra_kwargs) | |
run_config_args["tpu_config"] = tpu_config | |
- | |
config = run_config_cls(**run_config_args) | |
# If not using TPU, add device info for data_parallelism | |
@@ -199,6 +200,7 @@ def create_estimator(model_name, | |
decode_hparams=None, | |
use_tpu=False): | |
"""Create a T2T Estimator.""" | |
+ tf.logging.info("create_estimator, model_name = %s" % model_name) | |
model_fn = t2t_model.T2TModel.make_estimator_model_fn( | |
model_name, hparams, decode_hparams=decode_hparams, use_tpu=use_tpu) | |
@@ -269,13 +271,14 @@ class T2TExperiment(object): | |
"""Custom Experiment class for running distributed experiments.""" | |
def __init__(self, estimator, hparams, train_spec, eval_spec, | |
- use_validation_monitor, decode_hparams=None): | |
+ use_validation_monitor, decode_hparams=None, server=None): | |
self._train_spec = train_spec | |
self._eval_spec = eval_spec | |
self._hparams = hparams | |
self._decode_hparams = decode_hparams | |
self._estimator = estimator | |
self._use_validation_monitor = use_validation_monitor | |
+ self._server = server | |
@property | |
def estimator(self): | |
@@ -352,6 +355,7 @@ class T2TExperiment(object): | |
ValueError: if not enough information is available in the estimator's | |
config to create a server. | |
""" | |
+ tf.logging.info("run_std_server called") | |
config = self._estimator.config | |
if (not config.cluster_spec or not config.task_type or not config.master or | |
config.task_id is None): | |
@@ -378,6 +382,17 @@ class T2TExperiment(object): | |
self.decode() | |
+def create_tf_server(config): | |
+ # Build and immediately start (start=True) a tf.train.Server for this task | |
+ # from config's cluster_spec, task_type, task_id and tf_config; return it. | |
+ server = tf.train.Server( | |
+ config.cluster_spec, | |
+ job_name=config.task_type, | |
+ task_index=config.task_id, | |
+ config=config.tf_config, | |
+ start=True) | |
+ return server | |
+ | |
def create_experiment( | |
run_config, | |
hparams, | |
@@ -408,6 +423,11 @@ def create_experiment( | |
hparams.add_hparam("schedule", schedule) | |
add_problem_hparams(hparams, problem_name) | |
+ server = None | |
+ if getattr(run_config, "cluster_spec") and \ | |
+ schedule != "run_std_server": | |
+ server = create_tf_server(run_config) | |
+ | |
# Estimator | |
estimator = create_estimator( | |
model_name, | |
@@ -457,13 +477,14 @@ def create_experiment( | |
use_early_stopping = ( | |
schedule not in local_schedules and eval_early_stopping_steps) | |
train_hooks, eval_hooks = create_hooks( | |
- use_tfdbg=use_tfdbg, | |
- use_dbgprofile=use_dbgprofile, | |
- dbgprofile_kwargs=dbgprofile_kwargs, | |
- use_validation_monitor=use_validation_monitor, | |
- validation_monitor_kwargs=validation_monitor_kwargs, | |
- use_early_stopping=use_early_stopping, | |
- early_stopping_kwargs=early_stopping_kwargs) | |
+ use_tfdbg=use_tfdbg, | |
+ use_dbgprofile=use_dbgprofile, | |
+ dbgprofile_kwargs=dbgprofile_kwargs, | |
+ use_validation_monitor=use_validation_monitor, | |
+ validation_monitor_kwargs=validation_monitor_kwargs, | |
+ use_early_stopping=use_early_stopping, | |
+ early_stopping_kwargs=early_stopping_kwargs | |
+ ) | |
train_hooks += t2t_model.T2TModel.get_train_hooks(model_name) | |
eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name) | |
@@ -494,7 +515,8 @@ def create_experiment( | |
eval_delay_secs=0 if schedule == "evaluate" else 120, | |
**hooks_kwargs if not use_tpu else {}) | |
return T2TExperiment(estimator, hparams, train_spec, eval_spec, | |
- use_validation_monitor, decode_hparams) | |
+ use_validation_monitor, decode_hparams, | |
+ server) | |
def create_experiment_fn(*args, **kwargs): |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment