@ultrons · Created May 29, 2020 02:42
Moved cross-replica sum into the log-interval 'if'
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 377ce99b..22756df6 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -20,6 +20,7 @@ from fairseq.file_io import PathManager
 from fairseq.logging import meters, metrics
 from fairseq.nan_detector import NanDetector
 from fairseq.optim import lr_scheduler
+from fairseq.metsumm import metsumm
 
 
 logger = logging.getLogger(__name__)
@@ -38,6 +39,8 @@ class Trainer(object):
     def __init__(self, args, task, model, criterion, quantizer=None):
         self.args = args
         self.task = task
+        self.logging_history = []
+        self.cumm_sample_size = 0
 
         # catalog shared parameters
         shared_params = _catalog_shared_params(model)
@@ -411,6 +414,7 @@ class Trainer(object):
                 logging_outputs.append(logging_output)
                 sample_size += sample_size_i
+## print("DEBUG_MESSAGE: Inputshape", sample["net_input"]["source"].size())
 
                 # emptying the CUDA cache after the first step can
                 # reduce the chance of OOM
@@ -450,7 +454,9 @@ class Trainer(object):
         sample_size = float(sample_size)
 
         # gather logging outputs from all replicas
-        if self._sync_stats():
+        if self._sync_stats() and not self.tpu:
+        #if self._sync_stats():
+
             logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
                 logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
             )
@@ -520,11 +526,21 @@ class Trainer(object):
 
             # only log stats every log_interval steps
             # this causes wps to be misreported when log_interval > 1
+            self.logging_history += logging_outputs
+            self.cumm_sample_size += sample_size
             logging_output = {}
             if self.get_num_updates() % self.args.log_interval == 0:
+                # Doing Cross replica reduce first
+                logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
+                    self.logging_history, self.cumm_sample_size, ooms, ignore=is_dummy_batch,
+                )
+                # then log stats
                 logging_output = self._reduce_and_log_stats(
                     logging_outputs, sample_size, grad_norm,
+                    #self.logging_history, self.cumm_sample_size, grad_norm,
                 )
+                self.logging_history = []
+                self.cumm_sample_size = 0
 
             # log whenever there's an XLA compilation, since these
             # slow down training and may indicate opportunities for
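
For readers skimming the patch: the idea is to stop aggregating logging outputs across replicas on every training step and instead accumulate them locally, reducing only on the steps where stats are actually logged. The sketch below is a minimal, standalone illustration of that pattern, not fairseq code; `LogAccumulator`, `all_reduce_outputs`, and `log_fn` are hypothetical stand-ins for the new `logging_history` / `cumm_sample_size` Trainer state and for `_aggregate_logging_outputs` / `_reduce_and_log_stats` from the diff.

# Minimal sketch (assumed names, not fairseq APIs) of deferring cross-replica
# aggregation of logging stats to the log-interval boundary.

class LogAccumulator:
    """Accumulates per-step logging outputs locally between logging steps."""

    def __init__(self, log_interval):
        self.log_interval = log_interval
        self.logging_history = []   # per-step logging outputs, kept on this replica
        self.cumm_sample_size = 0   # running sample count since the last log

    def step(self, num_updates, logging_outputs, sample_size,
             all_reduce_outputs, log_fn):
        # Every step: accumulate locally, no cross-replica communication.
        self.logging_history += logging_outputs
        self.cumm_sample_size += sample_size

        if num_updates % self.log_interval != 0:
            return {}  # not a logging step, skip the expensive reduction

        # Logging step: cross-replica reduce once over the whole interval,
        # then compute and log stats from the reduced values.
        outputs, total_sample_size = all_reduce_outputs(
            self.logging_history, self.cumm_sample_size
        )
        stats = log_fn(outputs, total_sample_size)

        # Reset local accumulators for the next interval.
        self.logging_history = []
        self.cumm_sample_size = 0
        return stats

On TPU, reducing and logging stats typically pulls values back to the host, which forces the lazily built XLA graph to execute; doing this once per `log_interval` updates instead of every update (and skipping the per-step `_sync_stats()` path via the `not self.tpu` guard) cuts down those synchronization points, at the cost of stats that describe the whole interval rather than a single step.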