@ultrons
Last active May 26, 2020 19:10
Diff of Intermediate Changes
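
Note: the patch below imports fairseq/metsumm.py, which is not included in this gist. For reference, a minimal sketch of such a helper, assuming it only tags and prints a small subset of torch_xla's debug metrics report so that consecutive calls can be compared; the filtering shown here is an assumption, not the author's actual implementation:

# fairseq/metsumm.py -- hypothetical sketch; the real helper is not part of this diff.
# Assumes torch_xla is installed; met.metrics_report() returns the standard XLA
# debug report as a string.
import torch_xla.debug.metrics as met


def metsumm(tag=""):
    # Print only the lines useful for spotting recompilations (CompileTime),
    # device execution (ExecuteTime), and aten:: fallbacks / host syncs.
    for line in met.metrics_report().splitlines():
        if "CompileTime" in line or "ExecuteTime" in line or "aten::" in line:
            print("{} {}".format(tag, line.strip()))
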
diff --git a/fairseq/criterions/binary_cross_entropy.py b/fairseq/criterions/binary_cross_entropy.py
index 557f50bd..3b2d0d0f 100644
--- a/fairseq/criterions/binary_cross_entropy.py
+++ b/fairseq/criterions/binary_cross_entropy.py
@@ -8,8 +8,9 @@ import math
import torch
import torch.nn.functional as F
-from fairseq import utils
+from fairseq import utils, metrics
from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.metsumm import metsumm
@register_criterion('binary_cross_entropy')
@@ -41,7 +42,7 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
- logits = model.get_logits(net_output).float()
+ logits = model.get_logits(net_output).float() #DEB: No incr
target = model.get_targets(sample, net_output)
weights = None
@@ -57,7 +58,10 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
else:
loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",)
- sample_size = target.numel() if self.infonce else target.long().sum().item()
+ metsumm("DEBUG_MESSAGE: Before target sum long")
+ sample_size = target.numel() if self.infonce else target.sum().long()
+ metsumm("DEBUG_MESSAGE: After target sum long")
+ #sample_size = target.numel()
losses.append(loss)
if self.loss_weights is not None and hasattr(model, "get_extra_losses"):
@@ -73,8 +77,9 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
loss += p
losses.append(p)
+ #'loss': loss.item() if reduce else loss,
logging_output = {
- 'loss': loss.item() if reduce else loss,
+ 'loss': loss.data,
'ntokens': sample_size,
'nsentences': logits.size(0),
'sample_size': sample_size,
@@ -82,11 +87,13 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
for lk in self.log_keys:
if lk in net_output:
- logging_output[lk] = float((net_output[lk]))
+ #logging_output[lk] = float((net_output[lk]))
+ logging_output[lk] = net_output[lk] #float causes 3 comp, 2 aten calls
if len(losses) > 1:
for i, l in enumerate(losses):
- logging_output[f'loss_{i}'] = l.item()
+ #logging_output[f'loss_{i}'] = l.item()
+ logging_output[f'loss_{i}'] = l
if self.infonce:
with torch.no_grad():
@@ -109,34 +116,64 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
logging_output['target'] = target.cpu().numpy()
return loss, sample_size, logging_output
+# @staticmethod
+# def aggregate_logging_outputs(logging_outputs):
+# """Aggregate logging outputs from data parallel training."""
+# loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
+# ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
+# nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs))
+# sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))
+# agg_output = {
+# 'loss': loss_sum / sample_size / math.log(2),
+# 'ntokens': ntokens,
+# 'nsentences': nsentences,
+# 'sample_size': sample_size,
+# }
+# if sample_size != ntokens:
+# agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
+#
+# correct = sum(log.get("correct", 0) for log in logging_outputs)
+# total = sum(log.get("count", 0) for log in logging_outputs)
+# if total > 0:
+# agg_output['accuracy'] = correct / total
+#
+# builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
+#
+# for k in logging_outputs[0]:
+# if k not in builtin_keys:
+# val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs)
+# if k.startswith('loss'):
+# val = val / ntokens if ntokens > 0 else float('nan')
+# agg_output[k] = val
+#
+# return agg_output
+
@staticmethod
- def aggregate_logging_outputs(logging_outputs):
+ def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
- ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
- nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs))
- sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))
- agg_output = {
- 'loss': loss_sum / sample_size / math.log(2),
- 'ntokens': ntokens,
- 'nsentences': nsentences,
- 'sample_size': sample_size,
- }
+ metsumm("DEBUG_MESSAGE: before loss extract")
+ #loss_ = [log['loss'].item() for log in logging_outputs] # Result in 6X fewer aten:_local_scalar_dense
+ #loss_ = torch.stack([log['loss'] for log in logging_outputs]) # Result in 6X fewer aten:_local_scalar_dense
+ #loss_sum = loss_.sum()
+ loss_ = [log.get('loss', 0) for log in logging_outputs]
+ loss_sum = sum(loss_)
+ metsumm("DEBUG_MESSAGE: After loss extract {}".format(loss_))
+ ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+ sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
if sample_size != ntokens:
- agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
-
+ metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), round=3)
+ metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
+ metsumm("DEBUG_MESSAGE: After loss scaling")
correct = sum(log.get("correct", 0) for log in logging_outputs)
total = sum(log.get("count", 0) for log in logging_outputs)
if total > 0:
- agg_output['accuracy'] = correct / total
-
- builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
-
- for k in logging_outputs[0]:
- if k not in builtin_keys:
- val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs)
- if k.startswith('loss'):
- val = val / ntokens if ntokens > 0 else float('nan')
- agg_output[k] = val
+ metrics.log_scalar('accuracy', correct / total, round=3)
- return agg_output
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+ to True will improve distributed training speed.
+ """
+ return True
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py
index 40cbc206..e0616ea5 100644
--- a/fairseq/data/audio/raw_audio_dataset.py
+++ b/fairseq/data/audio/raw_audio_dataset.py
@@ -151,6 +151,13 @@ class FileAudioDataset(RawAudioDataset):
fname = os.path.join(self.root_dir, self.fnames[index])
wav, curr_sample_rate = sf.read(fname)
- feats = torch.from_numpy(wav).float()
+ wav_padded = np.zeros(self.max_sample_size)
+ if wav.shape[0] > self.max_sample_size:
+ wav_padded = wav[:self.max_sample_size]
+ else:
+ wav_padded[:wav.shape[0]]=wav
+ #feats = torch.from_numpy(wav).float()
+ feats = torch.from_numpy(wav_padded).float()
+ #print("DEBUG_MESSAGE:i WAV SHAPE", wav.shape, wav_padded.shape)
feats = self.postprocess(feats, curr_sample_rate)
return {"id": index, "source": feats}
diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py
index 78e6d4d2..73dbccf7 100644
--- a/fairseq/logging/meters.py
+++ b/fairseq/logging/meters.py
@@ -123,6 +123,7 @@ class TimeMeter(Meter):
self.start = time.perf_counter()
self.n = n
self.i = 0
+ print("DEBUG_MESSAGE: Time Meter created at:", self.start)
def update(self, val=1):
self.n = type_as(self.n, val) + val
@@ -145,6 +146,7 @@ class TimeMeter(Meter):
@property
def avg(self):
+ print("DEBUG_MESSAGE: Elapsed Time:",self.n, self.elapsed_time)
return self.n / self.elapsed_time
@property
@@ -153,6 +155,7 @@ class TimeMeter(Meter):
@property
def smoothed_value(self) -> float:
+ print("DEBUG_MESSAGE: Smooth Value Called")
val = self.avg
if self.round is not None and val is not None:
val = safe_round(val, self.round)
@@ -209,6 +212,7 @@ class StopwatchMeter(Meter):
@property
def smoothed_value(self) -> float:
+ print("DEBUG_MESSAGE: Meter Key:" , self.avg, self.sum, self.n)
val = self.avg if self.sum > 0 else self.elapsed_time
if self.round is not None and val is not None:
val = safe_round(val, self.round)
@@ -256,6 +260,7 @@ class MetersDict(OrderedDict):
def get_smoothed_value(self, key: str) -> float:
"""Get a single smoothed value."""
meter = self[key]
+ print("DEBUG_MESSAGE: Meter Key:" , key, meter)
if isinstance(meter, MetersDict._DerivedMeter):
return meter.fn(self)
else:
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 5036cfe2..e8f84f6e 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -9,6 +9,7 @@ import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
+from fairseq.metsumm import metsumm
class FairseqTask(object):
@@ -392,8 +393,11 @@ class FairseqTask(object):
)
else:
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ metsumm("DEBUG_MESSAGE: After ntoken extraction")
metrics.log_scalar("wpb", ntokens, priority=180, round=1)
+ metsumm("DEBUG_MESSAGE: After wpb extraction")
metrics.log_speed("wps", ntokens, priority=90, round=1)
+ metsumm("DEBUG_MESSAGE: After wpsntoken extraction")
if not any("nsentences" in log for log in logging_outputs):
warnings.warn(
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 377ce99b..11399bbe 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -20,6 +20,7 @@ from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics
from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler
+from fairseq.metsumm import metsumm
logger = logging.getLogger(__name__)
@@ -38,6 +39,8 @@ class Trainer(object):
def __init__(self, args, task, model, criterion, quantizer=None):
self.args = args
self.task = task
+ self.logging_history = []
+ self.cumm_sample_size = 0
# catalog shared parameters
shared_params = _catalog_shared_params(model)
@@ -399,6 +402,7 @@ class Trainer(object):
try:
with maybe_no_sync():
# forward and backward
+ metsumm("DEBUG_MESSAGE: Before TASK.Train Step")
loss, sample_size_i, logging_output = self.task.train_step(
sample=sample,
model=self.model,
@@ -407,10 +411,12 @@ class Trainer(object):
update_num=self.get_num_updates(),
ignore_grad=is_dummy_batch,
)
+ metsumm("DEBUG_MESSAGE: After TASK.Train Step")
del loss
logging_outputs.append(logging_output)
sample_size += sample_size_i
+ print("DEBUG_MESSAGE: Inputshape", sample["net_input"]["source"].size())
# emptying the CUDA cache after the first step can
# reduce the chance of OOM
@@ -451,6 +457,7 @@ class Trainer(object):
# gather logging outputs from all replicas
if self._sync_stats():
+
logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
)
@@ -483,7 +490,9 @@ class Trainer(object):
self._check_grad_norms(grad_norm)
# take an optimization step
+ metsumm("DEBUG_MESSAGE: Before Optimizer Step")
self.optimizer.step()
+ metsumm("DEBUG_MESSAGE: After Optimizer Step")
except FloatingPointError:
# re-run the forward and backward pass with hooks attached to print out where it fails
with NanDetector(self.model):
@@ -504,11 +513,14 @@ class Trainer(object):
raise e
# Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
+ metsumm("DEBUG_MESSAGE: Before Additional Opt")
if hasattr(self.model, 'perform_additional_optimizer_actions'):
+ metsumm("DEBUG_MESSAGE: Before Additional Opt: Opt Action")
if hasattr(self.optimizer, 'fp32_params'):
self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
else:
self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)
+ metsumm("DEBUG_MESSAGE: After Additional Opt: Opt Action")
if not overflow or self.args.distributed_wrapper == 'SlowMo':
self.set_num_updates(self.get_num_updates() + 1)
@@ -516,15 +528,24 @@ class Trainer(object):
if self.tpu:
# mark step on TPUs
import torch_xla.core.xla_model as xm
+ metsumm("DEBUG_MESSAGE: Before MARK Step")
xm.mark_step()
+ metsumm("DEBUG_MESSAGE: After MARK Step")
# only log stats every log_interval steps
# this causes wps to be misreported when log_interval > 1
+ self.logging_history += logging_outputs
+ self.cumm_sample_size += sample_size
logging_output = {}
if self.get_num_updates() % self.args.log_interval == 0:
+ metsumm("DEBUG_MESSAGE: Before Additional Opt reduce log stat")
logging_output = self._reduce_and_log_stats(
- logging_outputs, sample_size, grad_norm,
+ #logging_outputs, sample_size, grad_norm,
+ self.logging_history, self.cumm_sample_size, grad_norm,
)
+ self.logging_history = []
+ self.cumm_sample_size = 0
+ metsumm("DEBUG_MESSAGE: After Additional Opt reduce log stat")
# log whenever there's an XLA compilation, since these
# slow down training and may indicate opportunities for
@@ -546,6 +567,7 @@ class Trainer(object):
) == 0
):
torch.cuda.empty_cache()
+ metsumm("DEBUG_MESSAGE: After Additional Opt")
if self.args.fp16:
metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)
@@ -867,7 +889,9 @@ class Trainer(object):
with metrics.aggregate() as agg:
if logging_outputs is not None:
+ metsumm("DEBUG_MESSAGE: Before reduce meterics")
self.task.reduce_metrics(logging_outputs, self.get_criterion())
+ metsumm("DEBUG_MESSAGE: After reduce meterics")
del logging_outputs
# support legacy interface
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index ee08d6fe..7fa350e4 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -28,6 +28,7 @@ from fairseq.data import iterators
from fairseq.logging import meters, metrics, progress_bar
from fairseq.trainer import Trainer
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
+from fairseq.metsumm import metsumm
logging.basicConfig(
@@ -202,15 +203,23 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf):
valid_subsets = args.valid_subset.split(',')
for samples in progress:
+ #print("DEBUG_MESSAGE:", len(samples))
+ #for i in samples:
+ # print(i['net_input']['source'].size())
+ #continue
with metrics.aggregate('train_inner'):
+ metsumm("DEBUG_MESSAGE: Before Main Train Step.")
log_output = trainer.train_step(samples)
+ metsumm("DEBUG_MESSAGE: After Main Train Step.")
if log_output is None: # OOM, overflow, ...
continue
# log mid-epoch stats
num_updates = trainer.get_num_updates()
if num_updates % args.log_interval == 0:
+ metsumm("DEBUG_MESSAGE: Before Get Training Stat")
stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
+ metsumm("DEBUG_MESSAGE: After Get Training Stat")
progress.log(stats, tag='train_inner', step=num_updates)
# reset mid-epoch stats after each log interval
@@ -218,9 +227,12 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf):
metrics.reset_meters('train_inner')
end_of_epoch = not itr.has_next()
+ metsumm("DEBUG_MESSAGE: Before Valid Losses Computation:")
valid_losses = validate_and_save(
args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
)
+ metsumm("DEBUG_MESSAGE: After Valid Losses Computation:")
+
if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
break
@@ -375,7 +387,7 @@ def cli_main(modify_parser=None):
xmp.spawn(
fn=distributed_main,
args=(args, ),
- nprocs=8, # use all 8 TPU cores
+ nprocs=1, # debug: use a single TPU core instead of all 8
)
else:
# single GPU training