Diff of Intermediate Changes
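Note: this gist is an intermediate debugging diff against fairseq's wav2vec training path on TPU (torch_xla). The changes wrap suspect regions in metsumm(...) calls to localize XLA compilations and host-device transfers, replace per-step .item()/float() host syncs with lazy tensors, pad audio to a fixed shape, and batch stat reduction per log interval. The fairseq/metsumm.py helper imported throughout is not included in the gist; below is a minimal sketch of what it plausibly looks like, assuming torch_xla's debug metrics API (the line filtering and message format are guesses, not the author's code).

import torch_xla.debug.metrics as met

def metsumm(tag=""):
    # Print only the compile-count and aten fallback/transfer lines from the
    # full torch_xla metrics report, prefixed with the caller-supplied tag.
    for line in met.metrics_report().split("\n"):
        if "CompileTime" in line or "aten::" in line:
            print("{} {}".format(tag, line.strip()))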
diff --git a/fairseq/criterions/binary_cross_entropy.py b/fairseq/criterions/binary_cross_entropy.py
index 557f50bd..3b2d0d0f 100644
--- a/fairseq/criterions/binary_cross_entropy.py
+++ b/fairseq/criterions/binary_cross_entropy.py
@@ -8,8 +8,9 @@ import math
import torch
import torch.nn.functional as F
-from fairseq import utils
+from fairseq import utils, metrics
from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.metsumm import metsumm
@register_criterion('binary_cross_entropy')
@@ -41,7 +42,7 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
- logits = model.get_logits(net_output).float()
+ logits = model.get_logits(net_output).float() #DEB: No incr
target = model.get_targets(sample, net_output)
weights = None
@@ -57,7 +58,10 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
else:
loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",)
- sample_size = target.numel() if self.infonce else target.long().sum().item()
+ metsumm("DEBUG_MESSAGE: Before target sum long")
+ sample_size = target.numel() if self.infonce else target.sum().long()
+ metsumm("DEBUG_MESSAGE: After target sum long")
+ #sample_size = target.numel()
losses.append(loss)
if self.loss_weights is not None and hasattr(model, "get_extra_losses"):
@@ -73,8 +77,9 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
loss += p
losses.append(p)
+ #'loss': loss.item() if reduce else loss,
logging_output = {
- 'loss': loss.item() if reduce else loss,
+ 'loss': loss.data,
'ntokens': sample_size,
'nsentences': logits.size(0),
'sample_size': sample_size,
@@ -82,11 +87,13 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
for lk in self.log_keys:
if lk in net_output:
- logging_output[lk] = float((net_output[lk]))
+ #logging_output[lk] = float((net_output[lk]))
+ logging_output[lk] = net_output[lk] #float causes 3 comp, 2 aten calls
if len(losses) > 1:
for i, l in enumerate(losses):
- logging_output[f'loss_{i}'] = l.item()
+ #logging_output[f'loss_{i}'] = l.item()
+ logging_output[f'loss_{i}'] = l
if self.infonce:
with torch.no_grad():
@@ -109,34 +116,64 @@ class BinaryCrossEntropyCriterion(FairseqCriterion):
logging_output['target'] = target.cpu().numpy()
return loss, sample_size, logging_output
+# @staticmethod
+# def aggregate_logging_outputs(logging_outputs):
+# """Aggregate logging outputs from data parallel training."""
+# loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
+# ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
+# nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs))
+# sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))
+# agg_output = {
+# 'loss': loss_sum / sample_size / math.log(2),
+# 'ntokens': ntokens,
+# 'nsentences': nsentences,
+# 'sample_size': sample_size,
+# }
+# if sample_size != ntokens:
+# agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
+#
+# correct = sum(log.get("correct", 0) for log in logging_outputs)
+# total = sum(log.get("count", 0) for log in logging_outputs)
+# if total > 0:
+# agg_output['accuracy'] = correct / total
+#
+# builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
+#
+# for k in logging_outputs[0]:
+# if k not in builtin_keys:
+# val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs)
+# if k.startswith('loss'):
+# val = val / ntokens if ntokens > 0 else float('nan')
+# agg_output[k] = val
+#
+# return agg_output
+
@staticmethod
- def aggregate_logging_outputs(logging_outputs):
+ def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
- loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
- ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
- nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs))
- sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))
- agg_output = {
- 'loss': loss_sum / sample_size / math.log(2),
- 'ntokens': ntokens,
- 'nsentences': nsentences,
- 'sample_size': sample_size,
- }
+ metsumm("DEBUG_MESSAGE: before loss extract")
+ #loss_ = [log['loss'].item() for log in logging_outputs] # Results in 6X fewer aten::_local_scalar_dense
+ #loss_ = torch.stack([log['loss'] for log in logging_outputs]) # Results in 6X fewer aten::_local_scalar_dense
+ #loss_sum = loss_.sum()
+ loss_ = [log.get('loss', 0) for log in logging_outputs]
+ loss_sum = sum(loss_)
+ metsumm("DEBUG_MESSAGE: After loss extract {}".format(loss_))
+ ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+ sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
if sample_size != ntokens:
- agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
-
+ metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), round=3)
+ metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
+ metsumm("DEBUG_MESSAGE: After loss scaling")
correct = sum(log.get("correct", 0) for log in logging_outputs)
total = sum(log.get("count", 0) for log in logging_outputs)
if total > 0:
- agg_output['accuracy'] = correct / total
-
- builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
-
- for k in logging_outputs[0]:
- if k not in builtin_keys:
- val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs)
- if k.startswith('loss'):
- val = val / ntokens if ntokens > 0 else float('nan')
- agg_output[k] = val
+ metrics.log_scalar('accuracy', correct / total, round=3)
- return agg_output
+ @staticmethod
+ def logging_outputs_can_be_summed() -> bool:
+ """
+ Whether the logging outputs returned by `forward` can be summed
+ across workers prior to calling `reduce_metrics`. Setting this
+ to True will improve distributed training speed.
+ """ | |
+ return True | |
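Note: the recurring change in the hunk above is that, on XLA, every .item() or float() on a device tensor cuts the lazy graph and forces a device-to-host copy (counted as aten::_local_scalar_dense in the metrics report). Keeping 'loss' and the per-criterion losses as tensors defers that sync until reduce_metrics. A two-line illustration of the pattern, not taken from the gist:

import torch

loss = torch.ones(1)   # stand-in for a training loss living on an XLA device
eager = loss.item()    # forces one graph execution + host transfer per call
lazy = loss.data       # stays on device; materialized once, during reduction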
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py
index 40cbc206..e0616ea5 100644
--- a/fairseq/data/audio/raw_audio_dataset.py
+++ b/fairseq/data/audio/raw_audio_dataset.py
@@ -151,6 +151,13 @@ class FileAudioDataset(RawAudioDataset):
fname = os.path.join(self.root_dir, self.fnames[index])
wav, curr_sample_rate = sf.read(fname)
- feats = torch.from_numpy(wav).float()
+ wav_padded = np.zeros(self.max_sample_size)
+ if wav.shape[0] > self.max_sample_size:
+ wav_padded = wav[:self.max_sample_size]
+ else:
+ wav_padded[:wav.shape[0]] = wav
+ #feats = torch.from_numpy(wav).float()
+ feats = torch.from_numpy(wav_padded).float()
+ #print("DEBUG_MESSAGE: WAV SHAPE", wav.shape, wav_padded.shape)
feats = self.postprocess(feats, curr_sample_rate)
return {"id": index, "source": feats}
diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py
index 78e6d4d2..73dbccf7 100644
--- a/fairseq/logging/meters.py
+++ b/fairseq/logging/meters.py
@@ -123,6 +123,7 @@ class TimeMeter(Meter):
self.start = time.perf_counter()
self.n = n
self.i = 0
+ print("DEBUG_MESSAGE: Time Meter created at:", self.start)
def update(self, val=1):
self.n = type_as(self.n, val) + val
@@ -145,6 +146,7 @@ class TimeMeter(Meter):
@property
def avg(self):
+ print("DEBUG_MESSAGE: Elapsed Time:", self.n, self.elapsed_time)
return self.n / self.elapsed_time
@property
@@ -153,6 +155,7 @@ class TimeMeter(Meter):
@property
def smoothed_value(self) -> float:
+ print("DEBUG_MESSAGE: Smooth Value Called")
val = self.avg
if self.round is not None and val is not None:
val = safe_round(val, self.round)
@@ -209,6 +212,7 @@ class StopwatchMeter(Meter):
@property
def smoothed_value(self) -> float:
+ print("DEBUG_MESSAGE: Meter Key:", self.avg, self.sum, self.n)
val = self.avg if self.sum > 0 else self.elapsed_time
if self.round is not None and val is not None:
val = safe_round(val, self.round)
@@ -256,6 +260,7 @@ class MetersDict(OrderedDict):
def get_smoothed_value(self, key: str) -> float:
"""Get a single smoothed value."""
meter = self[key]
+ print("DEBUG_MESSAGE: Meter Key:", key, meter)
if isinstance(meter, MetersDict._DerivedMeter):
return meter.fn(self)
else:
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 5036cfe2..e8f84f6e 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -9,6 +9,7 @@ import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
+from fairseq.metsumm import metsumm
class FairseqTask(object):
@@ -392,8 +393,11 @@ class FairseqTask(object):
)
else:
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ metsumm("DEBUG_MESSAGE: After ntoken extraction")
metrics.log_scalar("wpb", ntokens, priority=180, round=1)
+ metsumm("DEBUG_MESSAGE: After wpb extraction")
metrics.log_speed("wps", ntokens, priority=90, round=1)
+ metsumm("DEBUG_MESSAGE: After wps extraction")
if not any("nsentences" in log for log in logging_outputs):
warnings.warn(
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 377ce99b..11399bbe 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -20,6 +20,7 @@ from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics
from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler
+from fairseq.metsumm import metsumm
logger = logging.getLogger(__name__)
@@ -38,6 +39,8 @@ class Trainer(object):
def __init__(self, args, task, model, criterion, quantizer=None):
self.args = args
self.task = task
+ self.logging_history = []
+ self.cumm_sample_size = 0
# catalog shared parameters
shared_params = _catalog_shared_params(model)
@@ -399,6 +402,7 @@ class Trainer(object):
try:
with maybe_no_sync():
# forward and backward
+ metsumm("DEBUG_MESSAGE: Before TASK.Train Step")
loss, sample_size_i, logging_output = self.task.train_step(
sample=sample,
model=self.model,
@@ -407,10 +411,12 @@ class Trainer(object):
update_num=self.get_num_updates(),
ignore_grad=is_dummy_batch,
)
+ metsumm("DEBUG_MESSAGE: After TASK.Train Step")
del loss
logging_outputs.append(logging_output)
sample_size += sample_size_i
+ print("DEBUG_MESSAGE: Input shape", sample["net_input"]["source"].size())
# emptying the CUDA cache after the first step can
# reduce the chance of OOM
@@ -451,6 +457,7 @@ class Trainer(object):
# gather logging outputs from all replicas
if self._sync_stats():
+
logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
)
@@ -483,7 +490,9 @@ class Trainer(object):
self._check_grad_norms(grad_norm)
# take an optimization step
+ metsumm("DEBUG_MESSAGE: Before Optimizer Step")
self.optimizer.step()
+ metsumm("DEBUG_MESSAGE: After Optimizer Step")
except FloatingPointError:
# re-run the forward and backward pass with hooks attached to print out where it fails
with NanDetector(self.model):
@@ -504,11 +513,14 @@ class Trainer(object):
raise e
# Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
+ metsumm("DEBUG_MESSAGE: Before Additional Opt")
if hasattr(self.model, 'perform_additional_optimizer_actions'):
+ metsumm("DEBUG_MESSAGE: Before Additional Opt: Opt Action")
if hasattr(self.optimizer, 'fp32_params'):
self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
else:
self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)
+ metsumm("DEBUG_MESSAGE: After Additional Opt: Opt Action")
if not overflow or self.args.distributed_wrapper == 'SlowMo':
self.set_num_updates(self.get_num_updates() + 1)
@@ -516,15 +528,24 @@ class Trainer(object):
if self.tpu:
# mark step on TPUs
import torch_xla.core.xla_model as xm
+ metsumm("DEBUG_MESSAGE: Before MARK Step")
xm.mark_step()
+ metsumm("DEBUG_MESSAGE: After MARK Step")
# only log stats every log_interval steps
# this causes wps to be misreported when log_interval > 1
+ self.logging_history += logging_outputs
+ self.cumm_sample_size += sample_size
logging_output = {}
if self.get_num_updates() % self.args.log_interval == 0:
+ metsumm("DEBUG_MESSAGE: Before Additional Opt reduce log stat")
logging_output = self._reduce_and_log_stats(
- logging_outputs, sample_size, grad_norm,
+ #logging_outputs, sample_size, grad_norm,
+ self.logging_history, self.cumm_sample_size, grad_norm,
)
+ self.logging_history = []
+ self.cumm_sample_size = 0
+ metsumm("DEBUG_MESSAGE: After Additional Opt reduce log stat")
# log whenever there's an XLA compilation, since these
# slow down training and may indicate opportunities for
@@ -546,6 +567,7 @@ class Trainer(object):
) == 0
):
torch.cuda.empty_cache()
+ metsumm("DEBUG_MESSAGE: After Additional Opt")
if self.args.fp16:
metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)
@@ -867,7 +889,9 @@ class Trainer(object):
with metrics.aggregate() as agg:
if logging_outputs is not None:
+ metsumm("DEBUG_MESSAGE: Before reduce meterics") | |
self.task.reduce_metrics(logging_outputs, self.get_criterion()) | |
+ metsumm("DEBUG_MESSAGE: After reduce meterics") | |
del logging_outputs
# support legacy interface
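Note: the trainer changes buffer logging outputs in self.logging_history and only hand them to _reduce_and_log_stats every log_interval updates, so the host sync implied by metric reduction is paid once per interval rather than once per step. A self-contained schematic of that pattern (train_step and reduce_and_log_stats are placeholders, not fairseq's API):

def train_step():
    # Placeholder standing in for trainer.train_step(samples).
    return {"loss": 1.0}, 32

def reduce_and_log_stats(history, sample_size):
    # Placeholder standing in for Trainer._reduce_and_log_stats.
    print(len(history), "outputs,", sample_size, "samples reduced")

log_interval, num_updates = 10, 100
history, cumm_sample_size = [], 0
for update in range(1, num_updates + 1):
    logging_output, sample_size = train_step()
    history.append(logging_output)
    cumm_sample_size += sample_size
    if update % log_interval == 0:
        reduce_and_log_stats(history, cumm_sample_size)  # one sync per interval
        history, cumm_sample_size = [], 0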
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index ee08d6fe..7fa350e4 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -28,6 +28,7 @@ from fairseq.data import iterators
from fairseq.logging import meters, metrics, progress_bar
from fairseq.trainer import Trainer
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
+from fairseq.metsumm import metsumm
logging.basicConfig(
@@ -202,15 +203,23 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf):
valid_subsets = args.valid_subset.split(',')
for samples in progress:
+ #print("DEBUG_MESSAGE:", len(samples))
+ #for i in samples:
+ # print(i['net_input']['source'].size())
+ #continue
with metrics.aggregate('train_inner'):
+ metsumm("DEBUG_MESSAGE: Before Main Train Step.")
log_output = trainer.train_step(samples)
+ metsumm("DEBUG_MESSAGE: After Main Train Step.")
if log_output is None: # OOM, overflow, ...
continue
# log mid-epoch stats
num_updates = trainer.get_num_updates()
if num_updates % args.log_interval == 0:
+ metsumm("DEBUG_MESSAGE: Before Get Training Stat")
stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
+ metsumm("DEBUG_MESSAGE: After Get Training Stat")
progress.log(stats, tag='train_inner', step=num_updates)
# reset mid-epoch stats after each log interval
@@ -218,9 +227,12 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf):
metrics.reset_meters('train_inner')
end_of_epoch = not itr.has_next()
+ metsumm("DEBUG_MESSAGE: Before Valid Losses Computation:")
valid_losses = validate_and_save(
args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
)
+ metsumm("DEBUG_MESSAGE: After Valid Losses Computation:")
+
if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
break
@@ -375,7 +387,7 @@ def cli_main(modify_parser=None):
xmp.spawn(
fn=distributed_main,
args=(args, ),
- nprocs=8, # use all 8 TPU cores
+ nprocs=1, # debug on a single TPU core (was 8)
)
else:
# single GPU training
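Note on the last hunk: with xmp.spawn running eight processes, each TPU core compiles independently and the DEBUG_MESSAGE and metrics output from all replicas interleave, so dropping to nprocs=1 gives one readable stream while debugging. Presumably nprocs=8 is restored for real training runs on a full v3-8 host.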