Created
October 29, 2020 18:29
-
-
Save taylanbil/1e41b03261cbbdb9900b51874e5da532 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ git diff 621e8341 > ~/debug/20201005-wav2vec2/latestdiff.diff | |
(torch-xla-1.7) | |
taylanbil at beefy-pytorch-xla-eu in ~/debug/20201005-wav2vec2/fairseq (w2v2●●) | |
$ cat ~/debug/20201005-wav2vec2/latestdiff.diff | |
diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py | |
index 019db622..c316147f 100644 | |
--- a/fairseq/criterions/wav2vec_criterion.py | |
+++ b/fairseq/criterions/wav2vec_criterion.py | |
@@ -10,6 +10,7 @@ import torch.nn.functional as F | |
from fairseq import metrics, utils | |
from fairseq.criterions import FairseqCriterion, register_criterion | |
+from fairseq.utils import index_put | |
@register_criterion('wav2vec') | |
@@ -41,9 +42,18 @@ class Wav2vecCriterion(FairseqCriterion): | |
3) logging outputs to display while training | |
""" | |
net_output = model(**sample['net_input']) | |
+ | |
logits = model.get_logits(net_output).float() | |
target = model.get_targets(sample, net_output) | |
+ if logits.device.type == 'xla': | |
+ # tpu-comment: since dynamic shapes lead to recompilations on xla, | |
+ # we don't shrink tensors using mask_indices. | |
+ # Instead, we do the following when computing loss: | |
+ mi = sample['net_input']['mask_indices'].reshape(logits.size(0)) | |
+ target = index_put(target, ~mi, -1) | |
+ | |
+ # XXX: handle weights on xla. | |
weights = None | |
if hasattr(model, 'get_target_weights') and not self.infonce: | |
weights = model.get_target_weights(target, net_output) | |
@@ -53,11 +63,24 @@ class Wav2vecCriterion(FairseqCriterion): | |
losses = [] | |
if self.infonce: | |
- loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",) | |
+ loss = F.cross_entropy( | |
+ logits, target, reduction="sum" if reduce else "none", | |
+ ignore_index=-1, | |
+ ) | |
else: | |
- loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",) | |
+ loss = F.binary_cross_entropy_with_logits( | |
+ logits, target.float(), weights, | |
+ reduction="sum" if reduce else "none", | |
+ ignore_index=-1, | |
+ ) | |
- sample_size = target.numel() if self.infonce else target.long().sum().item() | |
+ if 'sample_size' in sample and self.infonce: | |
+ sample_size = sample['sample_size'] | |
+ elif 'mask_indices' in sample['net_input'] and self.infonce: | |
+ # XXX: what happens if not self.infonce? | |
+ sample_size = sample['net_input']['mask_indices'].sum() | |
+ else: | |
+ sample_size = target.numel() if self.infonce else target.long().sum() | |
losses.append(loss) | |
if self.loss_weights is not None: | |
@@ -75,7 +98,8 @@ class Wav2vecCriterion(FairseqCriterion): | |
losses.append(p) | |
logging_output = { | |
- 'loss': loss.item() if reduce else loss, | |
+ #'loss': loss.item() if reduce else loss, | |
+ 'loss': loss.detach(), | |
'ntokens': sample_size, | |
'nsentences': sample['id'].numel(), | |
'sample_size': sample_size, | |
@@ -83,11 +107,14 @@ class Wav2vecCriterion(FairseqCriterion): | |
for lk in self.log_keys: | |
if lk in net_output: | |
- logging_output[lk] = float((net_output[lk])) | |
+ value = net_output[lk] | |
+ if not torch.is_tensor(value) or value.device.type != 'xla': | |
+ value = float(value) | |
+ logging_output[lk] = value | |
if len(losses) > 1: | |
for i, l in enumerate(losses): | |
- logging_output[f'loss_{i}'] = l.item() | |
+ logging_output[f'loss_{i}'] = l.detach() | |
if self.infonce: | |
with torch.no_grad(): | |
@@ -99,13 +126,14 @@ class Wav2vecCriterion(FairseqCriterion): | |
max = logits.argmax(-1) == 0 | |
min = logits.argmin(-1) == 0 | |
both = max & min | |
- corr = max.long().sum().item() - both.long().sum().item() | |
- count = max.numel() | |
+ # corr = max.long().sum().item() - both.long().sum().item() | |
+ corr = max.long().sum() - both.long().sum() | |
+ count = float(max.numel()) | |
logging_output["correct"] = corr | |
logging_output["count"] = count | |
- if log_pred: | |
+ if log_pred and logits.device.type != 'xla': | |
logging_output['logits'] = logits.cpu().numpy() | |
logging_output['target'] = target.cpu().numpy() | |
return loss, sample_size, logging_output | |
@@ -132,7 +160,7 @@ class Wav2vecCriterion(FairseqCriterion): | |
if total > 0: | |
metrics.log_derived( | |
"accuracy", | |
- lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5) | |
+ lambda meters: meters["_correct"].sum / meters["_total"].sum | |
if meters["_total"].sum > 0 | |
else float("nan"), | |
) | |
@@ -154,4 +182,4 @@ class Wav2vecCriterion(FairseqCriterion): | |
across workers prior to calling `reduce_metrics`. Setting this | |
to True will improves distributed training speed. | |
""" | |
- return False | |
+ return True | |
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py | |
index 675b0956..2114c93a 100644 | |
--- a/fairseq/data/audio/raw_audio_dataset.py | |
+++ b/fairseq/data/audio/raw_audio_dataset.py | |
@@ -12,7 +12,8 @@ import sys | |
import torch | |
import torch.nn.functional as F | |
-from .. import FairseqDataset | |
+from .. import FairseqDataset, BaseWrapperDataset | |
+from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes | |
logger = logging.getLogger(__name__) | |
@@ -27,6 +28,8 @@ class RawAudioDataset(FairseqDataset): | |
min_length=0, | |
pad=False, | |
normalize=False, | |
+ compute_mask_indices=False, | |
+ args=None, | |
): | |
super().__init__() | |
@@ -40,6 +43,12 @@ class RawAudioDataset(FairseqDataset): | |
self.pad = pad | |
self.shuffle = shuffle | |
self.normalize = normalize | |
+ self.compute_mask_indices = compute_mask_indices | |
+ if self.compute_mask_indices: | |
+ self.args = args | |
+ self._features_size_map = {} | |
+ self._C = self.args.encoder_embed_dim | |
+ self._conv_feature_layers = eval(self.args.conv_feature_layers) | |
def __getitem__(self, index): | |
raise NotImplementedError() | |
@@ -71,6 +80,45 @@ class RawAudioDataset(FairseqDataset): | |
end = size - diff + start | |
return wav[start:end] | |
+ def _compute_mask_indices(self, dims, padding_mask): | |
+ B, T, C = dims | |
+ mask_indices, mask_channel_indices = None, None | |
+ if self.args.mask_prob > 0: | |
+ mask_indices = compute_mask_indices( | |
+ (B, T), | |
+ padding_mask, | |
+ self.args.mask_prob, | |
+ self.args.mask_length, | |
+ self.args.mask_selection, | |
+ self.args.mask_other, | |
+ min_masks=2, | |
+ no_overlap=self.args.no_mask_overlap, | |
+ min_space=self.args.mask_min_space, | |
+ ) | |
+ mask_indices = torch.from_numpy(mask_indices) | |
+ if self.args.mask_channel_prob > 0: | |
+ mask_channel_indices = compute_mask_indices( | |
+ (B, C), | |
+ None, | |
+ self.args.mask_channel_prob, | |
+ self.args.mask_channel_length, | |
+ self.args.mask_channel_selection, | |
+ self.args.mask_channel_other, | |
+ no_overlap=self.args.no_mask_channel_overlap, | |
+ min_space=self.args.mask_channel_min_space, | |
+ ) | |
+ mask_channel_indices = ( | |
+ torch.from_numpy(mask_channel_indices) | |
+ .unsqueeze(1) | |
+ .expand(-1, T, -1) | |
+ ) | |
+ | |
+ return mask_indices, mask_channel_indices | |
+ | |
+ @staticmethod | |
+ def _bucket_tensor(tensor, num_pad, value): | |
+ return F.pad(tensor, (0, num_pad), value=value) | |
+ | |
def collater(self, samples): | |
samples = [ | |
s | |
@@ -106,9 +154,53 @@ class RawAudioDataset(FairseqDataset): | |
collated_sources[i] = self.crop_to_max_size(source, target_size) | |
input = {"source": collated_sources} | |
+ out = {"id": torch.LongTensor([s["id"] for s in samples])} | |
if self.pad: | |
input["padding_mask"] = padding_mask | |
- return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} | |
+ | |
+ if hasattr(self, 'num_buckets') and self.num_buckets > 0: | |
+ assert self.pad, "Cannot bucket without padding first." | |
+ bucket = max(self._bucketed_sizes[s['id']] for s in samples) | |
+ num_pad = bucket - collated_sources.size(-1) | |
+ if num_pad: | |
+ input['source'] = self._bucket_tensor( | |
+ collated_sources, num_pad, 0 | |
+ ) | |
+ input['padding_mask'] = self._bucket_tensor( | |
+ padding_mask, num_pad, True | |
+ ) | |
+ | |
+ if self.compute_mask_indices: | |
+ B = input['source'].size(0) | |
+ T = self._get_mask_indices_dims(input['source'].size(-1)) | |
+ padding_mask_reshaped = input['padding_mask'].clone() | |
+ extra = padding_mask_reshaped.size(1) % T | |
+ if extra > 0: | |
+ padding_mask_reshaped = padding_mask_reshaped[:, :-extra] | |
+ padding_mask_reshaped = padding_mask_reshaped.view( | |
+ padding_mask_reshaped.size(0), T, -1 | |
+ ) | |
+ padding_mask_reshaped = padding_mask_reshaped.all(-1) | |
+ mask_indices, mask_channel_indices = self._compute_mask_indices( | |
+ (B, T, self._C), padding_mask_reshaped, | |
+ ) | |
+ input["mask_indices"] = mask_indices | |
+ input['padding_counts'] = input['mask_indices'].sum(-1).tolist() | |
+ input["mask_channel_indices"] = mask_channel_indices | |
+ out['sample_size'] = mask_indices.sum().item() | |
+ | |
+ out["net_input"] = input | |
+ return out | |
+ | |
+ def _get_mask_indices_dims(self, size, padding=0, dilation=1): | |
+ if size not in self._features_size_map: | |
+ L_in = size | |
+ for (_, kernel_size, stride) in self._conv_feature_layers: | |
+ L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1 | |
+ L_out = 1 + L_out // stride | |
+ L_in = L_out | |
+ self._features_size_map[size] = L_out | |
+ return self._features_size_map[size] | |
def num_tokens(self, index): | |
return self.size(index) | |
@@ -144,6 +236,9 @@ class FileAudioDataset(RawAudioDataset): | |
min_length=0, | |
pad=False, | |
normalize=False, | |
+ compute_mask_indices=False, | |
+ args=None, | |
+ num_buckets=0, | |
): | |
super().__init__( | |
sample_rate=sample_rate, | |
@@ -153,6 +248,8 @@ class FileAudioDataset(RawAudioDataset): | |
min_length=min_length, | |
pad=pad, | |
normalize=normalize, | |
+ compute_mask_indices=compute_mask_indices, | |
+ args=args, | |
) | |
self.fnames = [] | |
@@ -169,8 +266,26 @@ class FileAudioDataset(RawAudioDataset): | |
continue | |
self.fnames.append(items[0]) | |
self.sizes.append(sz) | |
+ self.set_bucket_info(num_buckets) | |
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") | |
+ def set_bucket_info(self, num_buckets): | |
+ self.num_buckets = num_buckets | |
+ if self.num_buckets > 0: | |
+ self._collated_sizes = np.minimum( | |
+ np.array(self.sizes), self.max_sample_size, | |
+ ) | |
+ self.buckets = get_buckets( | |
+ self._collated_sizes, self.num_buckets, | |
+ ) | |
+ self._bucketed_sizes = get_bucketed_sizes( | |
+ self._collated_sizes, self.buckets | |
+ ) | |
+ logger.info( | |
+ f"{len(self.buckets)} bucket(s) for the audio dataset: " | |
+ f"{self.buckets}" | |
+ ) | |
+ | |
def __getitem__(self, index): | |
import soundfile as sf | |
diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py | |
index 6f53d011..e4ce3c40 100644 | |
--- a/fairseq/data/bucket_pad_length_dataset.py | |
+++ b/fairseq/data/bucket_pad_length_dataset.py | |
@@ -7,6 +7,7 @@ import numpy as np | |
import torch.nn.functional as F | |
from fairseq.data import BaseWrapperDataset | |
+from fairseq.data.data_utils import get_buckets, get_bucketed_sizes | |
class BucketPadLengthDataset(BaseWrapperDataset): | |
@@ -30,42 +31,43 @@ class BucketPadLengthDataset(BaseWrapperDataset): | |
num_buckets, | |
pad_idx, | |
left_pad, | |
+ tensor_key=None, | |
): | |
super().__init__(dataset) | |
self.pad_idx = pad_idx | |
self.left_pad = left_pad | |
assert num_buckets > 0 | |
- self.buckets = np.unique( | |
- np.percentile( | |
- sizes, | |
- np.linspace(0, 100, num_buckets + 1), | |
- interpolation='lower', | |
- )[1:] | |
- ) | |
+ self.buckets = get_buckets(sizes, num_buckets) | |
+ self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) | |
+ self._tensor_key = tensor_key | |
- def get_bucketed_sizes(orig_sizes, buckets): | |
- sizes = np.copy(orig_sizes) | |
- assert np.min(sizes) >= 0 | |
- start_val = -1 | |
- for end_val in buckets: | |
- mask = (sizes > start_val) & (sizes <= end_val) | |
- sizes[mask] = end_val | |
- start_val = end_val | |
- return sizes | |
+ def _set_tensor(self, item, val): | |
+ if self._tensor_key is None: | |
+ return val | |
+ item[self._tensor_key] = val | |
+ return item | |
- self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) | |
+ def _get_tensor(self, item): | |
+ if self._tensor_key is None: | |
+ return item | |
+ return item[self._tensor_key] | |
- def __getitem__(self, index): | |
- item = self.dataset[index] | |
- bucket_size = self._bucketed_sizes[index] | |
- num_pad = bucket_size - item.size(-1) | |
+ def _pad(self, tensor, bucket_size, dim=-1): | |
+ num_pad = bucket_size - tensor.size(dim) | |
return F.pad( | |
- item, | |
+ tensor, | |
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), | |
value=self.pad_idx, | |
) | |
+ def __getitem__(self, index): | |
+ item = self.dataset[index] | |
+ bucket_size = self._bucketed_sizes[index] | |
+ tensor = self._get_tensor(item) | |
+ padded = self._pad(tensor, bucket_size) | |
+ return self._set_tensor(item, padded) | |
+ | |
@property | |
def sizes(self): | |
return self._bucketed_sizes | |
diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py | |
index 57991a88..e6818219 100644 | |
--- a/fairseq/data/data_utils.py | |
+++ b/fairseq/data/data_utils.py | |
@@ -272,6 +272,7 @@ def post_process(sentence: str, symbol: str): | |
sentence = (sentence + " ").replace(symbol, "").rstrip() | |
return sentence | |
+ | |
def compute_mask_indices( | |
shape: Tuple[int, int], | |
padding_mask: Optional[torch.Tensor], | |
@@ -283,6 +284,7 @@ def compute_mask_indices( | |
no_overlap: bool = False, | |
min_space: int = 0, | |
) -> np.ndarray: | |
+#) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
""" | |
Computes random mask spans for a given shape | |
@@ -393,3 +395,25 @@ def compute_mask_indices( | |
mask[i, mask_idc] = True | |
return mask | |
+ | |
+ | |
+def get_buckets(sizes, num_buckets): | |
+ buckets = np.unique( | |
+ np.percentile( | |
+ sizes, | |
+ np.linspace(0, 100, num_buckets + 1), | |
+ interpolation='lower', | |
+ )[1:] | |
+ ) | |
+ return buckets | |
+ | |
+ | |
+def get_bucketed_sizes(orig_sizes, buckets): | |
+ sizes = np.copy(orig_sizes) | |
+ assert np.min(sizes) >= 0 | |
+ start_val = -1 | |
+ for end_val in buckets: | |
+ mask = (sizes > start_val) & (sizes <= end_val) | |
+ sizes[mask] = end_val | |
+ start_val = end_val | |
+ return sizes | |
diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py | |
index 7ee89adc..0060d8c1 100644 | |
--- a/fairseq/distributed_utils.py | |
+++ b/fairseq/distributed_utils.py | |
@@ -113,7 +113,6 @@ def distributed_init(args): | |
args.device_id = xm.get_local_ordinal() | |
args.distributed_rank = xm.get_ordinal() | |
xm.rendezvous('distributed_init') # wait for all workers | |
- xm.mark_step() | |
if is_master(args): | |
logging.getLogger().setLevel(logging.INFO) | |
@@ -182,7 +181,10 @@ def call_main(args, main, **kwargs): | |
xmp.spawn( | |
fn=distributed_main, | |
args=(main, args, kwargs), | |
- nprocs=8, # use all 8 TPU cores | |
+ # tpu-comment: | |
+ # A single TPU VM exposes 8 devices, so at most 8 processes are spawned here. | |
+ # The rest is driven by torch_xla.distributed.xla_dist. | |
+ nprocs=min(args.distributed_world_size, 8), | |
) | |
else: | |
# single GPU main | |
diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py | |
index 6ca1d201..0bac4c25 100644 | |
--- a/fairseq/logging/metrics.py | |
+++ b/fairseq/logging/metrics.py | |
@@ -289,3 +289,11 @@ def load_state_dict(state_dict): | |
for name, agg_state in state_dict.items(): | |
_aggregators[name] = MetersDict() | |
_aggregators[name].load_state_dict(agg_state) | |
+ | |
+ | |
+def xla_metrics_report(): | |
+ try: | |
+ import torch_xla.debug.metrics as met | |
+ print(met.metrics_report()) | |
+ except ImportError: | |
+ return | |
diff --git a/fairseq/metsumm.py b/fairseq/metsumm.py | |
new file mode 100644 | |
index 00000000..ce0cfa05 | |
--- /dev/null | |
+++ b/fairseq/metsumm.py | |
@@ -0,0 +1,18 @@ | |
+# FIXME: remove this file | |
+def metsumm(stepno=''): | |
+ if hasattr(metsumm, 'STEPNO'): | |
+ metsumm.STEPNO += stepno.lower()=="before forward" | |
+ else: | |
+ metsumm.STEPNO = 0 | |
+ try: | |
+ import torch_xla.debug.metrics as met | |
+ x = met.metrics_report().split('\n') | |
+ for i, line in enumerate(x): | |
+ if 'CompileTime' in line or 'aten::' in line: | |
+ key = line.split()[-1] | |
+ value = x[i+1].split()[-1] | |
+ print('step {}-{}, key {}, value {}'.format( | |
+ metsumm.STEPNO, stepno, key, value) | |
+ ) | |
+ except RuntimeError: | |
+ return | |
diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py | |
index 226f035b..be3f8df0 100644 | |
--- a/fairseq/models/wav2vec/wav2vec2.py | |
+++ b/fairseq/models/wav2vec/wav2vec2.py | |
@@ -27,7 +27,7 @@ from fairseq.modules import ( | |
TransposeLast, | |
) | |
from fairseq.modules.transformer_sentence_encoder import init_bert_params | |
-from fairseq.utils import buffered_arange | |
+from fairseq.utils import buffered_arange, index_put | |
@register_model("wav2vec2") | |
@@ -405,47 +405,66 @@ class Wav2Vec2Model(BaseFairseqModel): | |
return cls(args) | |
- def apply_mask(self, x, padding_mask): | |
+ def apply_mask( | |
+ self, x, padding_mask, | |
+ mask_indices=None, mask_channel_indices=None, | |
+ ): | |
B, T, C = x.shape | |
if self.mask_prob > 0: | |
- mask_indices = compute_mask_indices( | |
- (B, T), | |
- padding_mask, | |
- self.mask_prob, | |
- self.mask_length, | |
- self.mask_selection, | |
- self.mask_other, | |
- min_masks=2, | |
- no_overlap=self.no_mask_overlap, | |
- min_space=self.mask_min_space, | |
- ) | |
- mask_indices = torch.from_numpy(mask_indices).to(x.device) | |
- x[mask_indices] = self.mask_emb | |
+ if mask_indices is None: | |
+ mask_indices = compute_mask_indices( | |
+ (B, T), | |
+ padding_mask, | |
+ self.mask_prob, | |
+ self.mask_length, | |
+ self.mask_selection, | |
+ self.mask_other, | |
+ min_masks=2, | |
+ no_overlap=self.no_mask_overlap, | |
+ min_space=self.mask_min_space, | |
+ ) | |
+ mask_indices = torch.from_numpy(mask_indices).to(x.device) | |
+ x = index_put(x, mask_indices, self.mask_emb) | |
else: | |
mask_indices = None | |
if self.mask_channel_prob > 0: | |
- mask_channel_indices = compute_mask_indices( | |
- (B, C), | |
- None, | |
- self.mask_channel_prob, | |
- self.mask_channel_length, | |
- self.mask_channel_selection, | |
- self.mask_channel_other, | |
- no_overlap=self.no_mask_channel_overlap, | |
- min_space=self.mask_channel_min_space, | |
- ) | |
- mask_channel_indices = ( | |
- torch.from_numpy(mask_channel_indices) | |
- .to(x.device) | |
- .unsqueeze(1) | |
- .expand(-1, T, -1) | |
- ) | |
- x[mask_channel_indices] = 0 | |
+ if mask_channel_indices is None: | |
+ mask_channel_indices = compute_mask_indices( | |
+ (B, C), | |
+ None, | |
+ self.mask_channel_prob, | |
+ self.mask_channel_length, | |
+ self.mask_channel_selection, | |
+ self.mask_channel_other, | |
+ no_overlap=self.no_mask_channel_overlap, | |
+ min_space=self.mask_channel_min_space, | |
+ ) | |
+ mask_channel_indices = ( | |
+ torch.from_numpy(mask_channel_indices) | |
+ .to(x.device) | |
+ .unsqueeze(1) | |
+ .expand(-1, T, -1) | |
+ ) | |
+ x = index_put(x, mask_channel_indices, 0) | |
return x, mask_indices | |
- def sample_negatives(self, y, num): | |
+ def _get_neg_idxs(self, high, size, padding_counts=None): | |
+ if padding_counts is None: | |
+ neg_idxs = torch.randint(low=0, high=high-1, size=size) | |
+ else: | |
+ bsz, l = size | |
+ num = l // self.n_negatives | |
+ assert len(padding_counts) == bsz | |
+ neg_idxs = [ | |
+ torch.randint(low=0, high=num-pc-1, size=(1, l)) | |
+ for pc in padding_counts | |
+ ] | |
+ neg_idxs = torch.stack(neg_idxs) | |
+ return neg_idxs | |
+ | |
+ def sample_negatives(self, y, num, padding_counts=None): | |
if self.n_negatives == 0 and self.cross_sample_negatives == 0: | |
return y.new(0) | |
@@ -466,8 +485,9 @@ class Wav2Vec2Model(BaseFairseqModel): | |
.flatten() | |
) | |
- neg_idxs = torch.randint( | |
- low=0, high=high - 1, size=(bsz, self.n_negatives * num) | |
+ neg_idxs = self._get_neg_idxs( | |
+ high, (bsz, self.n_negatives * num), | |
+ padding_counts=padding_counts, | |
) | |
neg_idxs[neg_idxs >= tszs] += 1 | |
@@ -511,14 +531,21 @@ class Wav2Vec2Model(BaseFairseqModel): | |
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) | |
- logits /= self.logit_temp | |
+ logits = logits / self.logit_temp | |
- if neg_is_pos.any(): | |
- logits[1:][neg_is_pos] = float("-inf") | |
+ if logits.device.type == 'xla' or neg_is_pos.any(): | |
+ #pass | |
+ fillval = -float(2**30) | |
+ if not hasattr(self, '_inftensor'): | |
+ self._inftensor = torch.tensor(fillval).to(x.device) | |
+ logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) | |
return logits | |
- def forward(self, source, padding_mask=None, mask=True, features_only=False): | |
+ def forward( | |
+ self, source, padding_mask=None, mask=True, features_only=False, | |
+ mask_indices=None, mask_channel_indices=None, padding_counts=None, | |
+ ): | |
if self.feature_grad_mult > 0: | |
features = self.feature_extractor(source) | |
@@ -562,9 +589,17 @@ class Wav2Vec2Model(BaseFairseqModel): | |
features = self.project_inp(features) | |
if mask: | |
- x, mask_indices = self.apply_mask(features, padding_mask) | |
- if mask_indices is not None: | |
- y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1)) | |
+ x, mask_indices = self.apply_mask( | |
+ features, padding_mask, | |
+ mask_indices=mask_indices, | |
+ mask_channel_indices=mask_channel_indices, | |
+ ) | |
+ if x.device.type != 'xla' and mask_indices is not None: | |
+ # tpu-comment: reducing the size in a dynamic way causes | |
+ # too many recompilations on xla. | |
+ y = unmasked_features[mask_indices].view( | |
+ unmasked_features.size(0), -1, unmasked_features.size(-1) | |
+ ) | |
else: | |
y = unmasked_features | |
else: | |
@@ -588,12 +623,18 @@ class Wav2Vec2Model(BaseFairseqModel): | |
y = self.project_q(y) | |
if self.negatives_from_everywhere: | |
- neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False) | |
- negs, _ = self.sample_negatives(neg_cands, y.size(1)) | |
+ neg_cands, *_ = self.quantizer( | |
+ unmasked_features, produce_targets=False, | |
+ ) | |
+ negs, _ = self.sample_negatives( | |
+ neg_cands, y.size(1), padding_counts=padding_counts, | |
+ ) | |
negs = self.project_q(negs) | |
else: | |
- negs, _ = self.sample_negatives(y, y.size(1)) | |
+ negs, _ = self.sample_negatives( | |
+ y, y.size(1), padding_counts=padding_counts, | |
+ ) | |
if self.codebook_negatives > 0: | |
cb_negs = self.quantizer.sample_from_codebook( | |
@@ -608,17 +649,23 @@ class Wav2Vec2Model(BaseFairseqModel): | |
y = self.project_q(y) | |
if self.negatives_from_everywhere: | |
- negs, _ = self.sample_negatives(unmasked_features, y.size(1)) | |
+ negs, _ = self.sample_negatives( | |
+ unmasked_features, y.size(1), padding_counts=padding_counts, | |
+ ) | |
negs = self.project_q(negs) | |
else: | |
- negs, _ = self.sample_negatives(y, y.size(1)) | |
+ negs, _ = self.sample_negatives( | |
+ y, y.size(1), padding_counts=padding_counts, | |
+ ) | |
- x = x[mask_indices].view(x.size(0), -1, x.size(-1)) | |
+ if x.device.type != 'xla': | |
+ # tpu-comment: reducing the size in a dynamic way causes | |
+ # too many recompilations on xla. | |
+ x = x[mask_indices].view(x.size(0), -1, x.size(-1)) | |
if self.target_glu: | |
y = self.target_glu(y) | |
negs = self.target_glu(negs) | |
- | |
x = self.final_proj(x) | |
x = self.compute_preds(x, y, negs) | |
@@ -811,7 +858,7 @@ class TransformerEncoder(nn.Module): | |
x_conv = self.pos_conv(x.transpose(1, 2)) | |
x_conv = x_conv.transpose(1, 2) | |
- x += x_conv | |
+ x = x + x_conv | |
if not self.layer_norm_first: | |
x = self.layer_norm(x) | |
diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py | |
index 01ddd229..35a08af1 100644 | |
--- a/fairseq/modules/gumbel_vector_quantizer.py | |
+++ b/fairseq/modules/gumbel_vector_quantizer.py | |
@@ -160,6 +160,7 @@ class GumbelVectorQuantizer(nn.Module): | |
avg_probs = torch.softmax( | |
x.view(bsz * tsz, self.groups, -1).float(), dim=-1 | |
).mean(dim=0) | |
+ avg_probs = avg_probs.detach() | |
result["prob_perplexity"] = torch.exp( | |
-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) | |
).sum() | |
diff --git a/fairseq/options.py b/fairseq/options.py | |
index 171c6796..ce6edc34 100644 | |
--- a/fairseq/options.py | |
+++ b/fairseq/options.py | |
@@ -238,6 +238,7 @@ def get_parser(desc, default_task="translation"): | |
help='pseudo random number generator seed') | |
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA') | |
parser.add_argument('--tpu', action='store_true', help='use TPU instead of CUDA') | |
+ parser.add_argument('--xla-metrics-debug', action='store_true', help='Print XLA debug info') | |
parser.add_argument('--bf16', action='store_true', help='use bfloat16; implies --tpu') | |
parser.add_argument('--fp16', action='store_true', help='use FP16') | |
parser.add_argument('--memory-efficient-bf16', action='store_true', | |
diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py | |
index f3363746..2683e705 100644 | |
--- a/fairseq/tasks/audio_pretraining.py | |
+++ b/fairseq/tasks/audio_pretraining.py | |
@@ -8,7 +8,9 @@ | |
import os | |
import sys | |
-from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset | |
+from fairseq.data import ( | |
+ FileAudioDataset, Dictionary, AddTargetDataset, BucketPadLengthDataset | |
+) | |
from . import FairseqTask, register_task | |
@@ -31,6 +33,14 @@ class AudioPretrainingTask(FairseqTask): | |
@staticmethod | |
def add_args(parser): | |
"""Add task-specific arguments to the parser.""" | |
+ parser.add_argument( | |
+ '--num-batch-buckets', default=0, type=int, | |
+ help=( | |
+ 'if >0, then bucket source and target lengths into N ' | |
+ 'buckets and pad accordingly; this is useful on TPUs ' | |
+ 'to minimize the number of compilations' | |
+ ), | |
+ ) | |
parser.add_argument("data", help="path to data directory") | |
parser.add_argument( | |
"--sample-rate", | |
@@ -103,6 +113,9 @@ class AudioPretrainingTask(FairseqTask): | |
min_length=self.args.min_sample_size, | |
pad=self.args.labels is not None or self.args.enable_padding, | |
normalize=self.args.normalize, | |
+ compute_mask_indices=self.args.tpu, | |
+ args=self.args, | |
+ num_buckets=self.args.num_batch_buckets or int(self.args.tpu), | |
) | |
if self.args.labels: | |
diff --git a/fairseq/trainer.py b/fairseq/trainer.py | |
index a91d12fd..37d578e3 100644 | |
--- a/fairseq/trainer.py | |
+++ b/fairseq/trainer.py | |
@@ -347,7 +347,7 @@ class Trainer(object): | |
num_shards=self.data_parallel_world_size if shard_batch_itr else 1, | |
shard_id=self.data_parallel_rank if shard_batch_itr else 0, | |
num_workers=self.args.num_workers, | |
- epoch=epoch | |
+ epoch=epoch, | |
) | |
def get_valid_iterator( | |
@@ -368,7 +368,7 @@ class Trainer(object): | |
seed=self.args.seed, | |
num_shards=self.data_parallel_world_size, | |
shard_id=self.data_parallel_rank, | |
- num_workers=self.args.num_workers | |
+ num_workers=self.args.num_workers, | |
) | |
def begin_epoch(self, epoch): | |
@@ -422,6 +422,7 @@ class Trainer(object): | |
try: | |
with maybe_no_sync(): | |
# forward and backward | |
+ | |
loss, sample_size_i, logging_output = self.task.train_step( | |
sample=sample, | |
model=self.model, | |
@@ -462,8 +463,7 @@ class Trainer(object): | |
# before marking step can lead to OOM errors. | |
# To handle gradient accumulation use case, we explicitly | |
# mark step here for every forward pass without a backward pass | |
- import torch_xla.core.xla_model as xm | |
- xm.mark_step() | |
+ self._xla_markstep_and_send_to_cpu() | |
if is_dummy_batch: | |
if torch.is_tensor(sample_size): | |
@@ -548,16 +548,23 @@ class Trainer(object): | |
if self.tpu: | |
# mark step on TPUs | |
- import torch_xla.core.xla_model as xm | |
- xm.mark_step() | |
# only log stats every log_interval steps | |
# this causes wps to be misreported when log_interval > 1 | |
logging_output = {} | |
if self.get_num_updates() % self.args.log_interval == 0: | |
+ logging_outputs = self._xla_markstep_and_send_to_cpu( | |
+ logging_outputs | |
+ ) | |
logging_output = self._reduce_and_log_stats( | |
logging_outputs, sample_size, grad_norm, | |
) | |
+ self._xla_markstep_and_send_to_cpu() | |
+ # FIXME(taylan): when this is moved into a step closure, the logging outputs get shrunk. | |
+ #xm.add_step_closure( | |
+ # self._reduce_and_log_stats, | |
+ # args=(logging_outputs, sample_size, grad_norm) | |
+ #) | |
# log whenever there's an XLA compilation, since these | |
# slow down training and may indicate opportunities for | |
@@ -593,9 +600,7 @@ class Trainer(object): | |
if self._dummy_batch == "DUMMY": | |
self._dummy_batch = sample | |
if self.tpu: | |
- import torch_xla.core.xla_model as xm | |
- xm.rendezvous('valid_step') # wait for all workers | |
- xm.mark_step() | |
+ self._xla_markstep_and_send_to_cpu() | |
with torch.no_grad(): | |
self.model.eval() | |
@@ -641,6 +646,8 @@ class Trainer(object): | |
) | |
# log validation stats | |
+ if self.tpu: | |
+ logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) | |
logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) | |
return logging_output | |
@@ -778,8 +785,9 @@ class Trainer(object): | |
utils.set_torch_seed(seed) | |
def _sync_stats(self): | |
- # Return True if it's using multiple GPUs and DDP or multiple GPUs with | |
- # BMUF and it's a bmuf sync with warmup iterations completed before. | |
+ # Return True if it's using multiple devices and DDP | |
+ # or multiple devices with BMUF and it's a bmuf sync | |
+ # with warmup iterations completed before. | |
if self.data_parallel_world_size == 1: | |
return False | |
elif self.args.use_bmuf: | |
@@ -912,7 +920,11 @@ class Trainer(object): | |
) | |
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): | |
- if grad_norm is not None: | |
+ # tpu-comment: grad_norm is a tensor in XLA | |
+ if ( | |
+ (not torch.is_tensor(grad_norm) and grad_norm is not None) | |
+ or (torch.is_tensor(grad_norm) and not torch.isnan(grad_norm)) | |
+ ): | |
metrics.log_speed("ups", 1., priority=100, round=2) | |
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) | |
if self.args.clip_norm > 0: | |
@@ -968,6 +980,13 @@ class Trainer(object): | |
logging.info("NOTE: XLA compilation detected; {}".format(message)) | |
self._num_xla_compiles = num_xla_compiles | |
+ def _xla_markstep_and_send_to_cpu(self, data=None): | |
+ import torch_xla.core.xla_model as xm | |
+ xm.mark_step() | |
+ if data is not None: | |
+ from fairseq.utils import xla_device_to_cpu | |
+ return xla_device_to_cpu(data) | |
+ | |
def _catalog_shared_params(module, memo=None, prefix=''): | |
if memo is None: | |
diff --git a/fairseq/utils.py b/fairseq/utils.py | |
index f6886033..7e8db9f6 100644 | |
--- a/fairseq/utils.py | |
+++ b/fairseq/utils.py | |
@@ -81,6 +81,7 @@ def move_to_cuda(sample): | |
def move_to_cpu(sample): | |
+ | |
def _move_to_cpu(tensor): | |
# PyTorch has poor support for half tensors (float16) on CPU. | |
# Move any such tensors to float32. | |
@@ -252,6 +253,9 @@ def convert_padding_direction( | |
def item(tensor): | |
+ # tpu-comment: making this a no-op for xla devices. | |
+ if torch.is_tensor(tensor) and tensor.device.type == 'xla': | |
+ return tensor.detach() | |
if hasattr(tensor, "item"): | |
return tensor.item() | |
if hasattr(tensor, "__getitem__"): | |
@@ -560,6 +564,23 @@ def get_tpu_device(args): | |
return xm.xla_device() | |
+def index_put(tensor, indices, value): | |
+ if tensor.device.type == 'xla': | |
+ for _ in range(indices.dim(), tensor.dim()): | |
+ indices = indices.unsqueeze(-1) | |
+ if indices.size(-1) < tensor.size(-1): | |
+ indices = indices.expand_as(tensor) | |
+ tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) | |
+ else: | |
+ tensor[indices] = value | |
+ return tensor | |
+ | |
+ | |
+def xla_device_to_cpu(dat): | |
+ import torch_xla.core.xla_model as xm | |
+ return xm._maybe_convert_to_cpu(dat) | |
+ | |
+ | |
class CudaEnvironment(object): | |
def __init__(self): | |
cur_device = torch.cuda.current_device() | |
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py | |
index 806e4bc5..832f4746 100644 | |
--- a/fairseq_cli/train.py | |
+++ b/fairseq_cli/train.py | |
@@ -96,7 +96,7 @@ def main(args): | |
"training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) | |
) | |
logger.info( | |
- "max tokens per GPU = {} and max sentences per GPU = {}".format( | |
+ "max tokens per device = {} and max sentences per device = {}".format( | |
args.max_tokens, args.max_sentences | |
) | |
) | |
@@ -106,9 +106,7 @@ def main(args): | |
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) | |
if args.tpu: | |
import torch_xla.core.xla_model as xm | |
- | |
xm.rendezvous("load_checkpoint") # wait for all workers | |
- xm.mark_step() | |
# Train until the learning rate gets too small | |
max_epoch = args.max_epoch or math.inf | |
@@ -167,7 +165,6 @@ def tpu_data_loader(args, itr): | |
import torch_xla.distributed.parallel_loader as pl | |
xm.rendezvous("tpu_data_loader") # wait for all workers | |
- xm.mark_step() | |
device = utils.get_tpu_device(args) | |
return iterators.CountingIterator( | |
pl.ParallelLoader(itr, [device]).per_device_loader(device), | |
@@ -211,6 +208,12 @@ def train(args, trainer, task, epoch_itr): | |
should_stop = False | |
num_updates = trainer.get_num_updates() | |
for i, samples in enumerate(progress): | |
+ | |
+ from fairseq.metsumm import metsumm as m | |
+ if not i % 50: | |
+ import torch_xla.core.xla_model as xm | |
+ if xm.is_master_ordinal(): | |
+ m(str(i)) | |
with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( | |
"train_step-%d" % i | |
): | |
@@ -226,7 +229,10 @@ def train(args, trainer, task, epoch_itr): | |
# reset mid-epoch stats after each log interval | |
# the end-of-epoch stats will still be preserved | |
+ # FIXME(taylan): do this reset in a step closure. | |
metrics.reset_meters("train_inner") | |
+ if args.xla_metrics_debug: | |
+ metrics.xla_metrics_report() | |
end_of_epoch = not itr.has_next() | |
valid_losses, should_stop = validate_and_save( |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment