Moved the cross-replica sum into the log-interval 'if'

This patch skips the per-step cross-replica reduction of logging outputs on TPU; instead, each replica accumulates its logging outputs and sample sizes locally and runs _aggregate_logging_outputs once every log_interval updates, just before logging stats. A standalone sketch of the pattern follows the diff.
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 377ce99b..22756df6 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -20,6 +20,7 @@ from fairseq.file_io import PathManager
 from fairseq.logging import meters, metrics
 from fairseq.nan_detector import NanDetector
 from fairseq.optim import lr_scheduler
+from fairseq.metsumm import metsumm


 logger = logging.getLogger(__name__)
@@ -38,6 +39,8 @@ class Trainer(object):
     def __init__(self, args, task, model, criterion, quantizer=None):
         self.args = args
         self.task = task
+        self.logging_history = []
+        self.cumm_sample_size = 0

         # catalog shared parameters
         shared_params = _catalog_shared_params(model)
@@ -411,6 +414,7 @@ class Trainer(object):
                 logging_outputs.append(logging_output)
                 sample_size += sample_size_i
+##                print("DEBUG_MESSAGE: Inputshape", sample["net_input"]["source"].size())

                 # emptying the CUDA cache after the first step can
                 # reduce the chance of OOM
@@ -450,7 +454,9 @@ class Trainer(object):
             sample_size = float(sample_size)

         # gather logging outputs from all replicas
-        if self._sync_stats():
+        if self._sync_stats() and not self.tpu:
+        #if self._sync_stats():
+
             logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
                 logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
             )
@@ -520,11 +526,21 @@ class Trainer(object):
                 # only log stats every log_interval steps
                 # this causes wps to be misreported when log_interval > 1
+                self.logging_history += logging_outputs
+                self.cumm_sample_size += sample_size
                 logging_output = {}
                 if self.get_num_updates() % self.args.log_interval == 0:
+                    # Doing Cross replica reduce first
+                    logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
+                        self.logging_history, self.cumm_sample_size, ooms, ignore=is_dummy_batch,
+                    )
+                    # then log stats
                     logging_output = self._reduce_and_log_stats(
                         logging_outputs, sample_size, grad_norm,
+                        #self.logging_history, self.cumm_sample_size, grad_norm,
                     )
+                    self.logging_history = []
+                    self.cumm_sample_size = 0

                 # log whenever there's an XLA compilation, since these
                 # slow down training and may indicate opportunities for