Skip to content

Instantly share code, notes, and snippets.

@taylanbil
Created October 29, 2020 18:29
Show Gist options
  • Save taylanbil/1e41b03261cbbdb9900b51874e5da532 to your computer and use it in GitHub Desktop.
$ git diff 621e8341 > ~/debug/20201005-wav2vec2/latestdiff.diff
(torch-xla-1.7)
taylanbil at beefy-pytorch-xla-eu in ~/debug/20201005-wav2vec2/fairseq (w2v2●●)
$ cat ~/debug/20201005-wav2vec2/latestdiff.diff
diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py
index 019db622..c316147f 100644
--- a/fairseq/criterions/wav2vec_criterion.py
+++ b/fairseq/criterions/wav2vec_criterion.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.utils import index_put
@register_criterion('wav2vec')
@@ -41,9 +42,18 @@ class Wav2vecCriterion(FairseqCriterion):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
+
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)
+ if logits.device.type == 'xla':
+ # tpu-comment: since dynamic shapes lead to recompilations on xla,
+ # we don't shrink tensors using mask_indices.
+ # Instead, we do the following when computing loss:
+ mi = sample['net_input']['mask_indices'].reshape(logits.size(0))
+ target = index_put(target, ~mi, -1)  # -1 matches ignore_index of cross_entropy below
+
+ # XXX: handle weights on xla.
weights = None
if hasattr(model, 'get_target_weights') and not self.infonce:
weights = model.get_target_weights(target, net_output)
@@ -53,11 +63,24 @@ class Wav2vecCriterion(FairseqCriterion):
losses = []
if self.infonce:
- loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",)
+ loss = F.cross_entropy(
+ logits, target, reduction="sum" if reduce else "none",
+ ignore_index=-1,
+ )
else:
- loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",)
+ loss = F.binary_cross_entropy_with_logits(
+ logits, target.float(), weights,
+ reduction="sum" if reduce else "none",
+ # no ignore_index here: BCE-with-logits has no such kwarg (xla -1 targets are not masked)
+ )
- sample_size = target.numel() if self.infonce else target.long().sum().item()
+ if 'sample_size' in sample and self.infonce:
+ sample_size = sample['sample_size']
+ elif 'mask_indices' in sample['net_input'] and self.infonce:
+ # XXX: what happens if not self.infonce?
+ sample_size = sample['net_input']['mask_indices'].sum()
+ else:
+ sample_size = target.numel() if self.infonce else target.long().sum()
losses.append(loss)
if self.loss_weights is not None:
@@ -75,7 +98,8 @@ class Wav2vecCriterion(FairseqCriterion):
losses.append(p)
logging_output = {
- 'loss': loss.item() if reduce else loss,
+ #'loss': loss.item() if reduce else loss,
+ 'loss': loss.detach(),
'ntokens': sample_size,
'nsentences': sample['id'].numel(),
'sample_size': sample_size,
@@ -83,11 +107,14 @@ class Wav2vecCriterion(FairseqCriterion):
for lk in self.log_keys:
if lk in net_output:
- logging_output[lk] = float((net_output[lk]))
+ value = net_output[lk]
+ if not torch.is_tensor(value) or value.device.type != 'xla':
+ value = float(value)
+ logging_output[lk] = value
if len(losses) > 1:
for i, l in enumerate(losses):
- logging_output[f'loss_{i}'] = l.item()
+ logging_output[f'loss_{i}'] = l.detach()
if self.infonce:
with torch.no_grad():
@@ -99,13 +126,14 @@ class Wav2vecCriterion(FairseqCriterion):
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
- corr = max.long().sum().item() - both.long().sum().item()
- count = max.numel()
+ # corr = max.long().sum().item() - both.long().sum().item()
+ corr = max.long().sum() - both.long().sum()
+ count = float(max.numel())
logging_output["correct"] = corr
logging_output["count"] = count
- if log_pred:
+ if log_pred and logits.device.type != 'xla':
logging_output['logits'] = logits.cpu().numpy()
logging_output['target'] = target.cpu().numpy()
return loss, sample_size, logging_output
@@ -132,7 +160,7 @@ class Wav2vecCriterion(FairseqCriterion):
if total > 0:
metrics.log_derived(
"accuracy",
- lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5)
+ lambda meters: meters["_correct"].sum / meters["_total"].sum
if meters["_total"].sum > 0
else float("nan"),
)
@@ -154,4 +182,4 @@ class Wav2vecCriterion(FairseqCriterion):
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
- return False
+ return True
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py
index 675b0956..2114c93a 100644
--- a/fairseq/data/audio/raw_audio_dataset.py
+++ b/fairseq/data/audio/raw_audio_dataset.py
@@ -12,7 +12,8 @@ import sys
import torch
import torch.nn.functional as F
-from .. import FairseqDataset
+from .. import FairseqDataset, BaseWrapperDataset
+from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
logger = logging.getLogger(__name__)
@@ -27,6 +28,8 @@ class RawAudioDataset(FairseqDataset):
min_length=0,
pad=False,
normalize=False,
+ compute_mask_indices=False,
+ args=None,
):
super().__init__()
@@ -40,6 +43,12 @@ class RawAudioDataset(FairseqDataset):
self.pad = pad
self.shuffle = shuffle
self.normalize = normalize
+ self.compute_mask_indices = compute_mask_indices
+ if self.compute_mask_indices:
+ self.args = args
+ self._features_size_map = {}
+ self._C = self.args.encoder_embed_dim
+ self._conv_feature_layers = eval(self.args.conv_feature_layers)  # NOTE(review): eval() of a config string; only safe for trusted args
def __getitem__(self, index):
raise NotImplementedError()
@@ -71,6 +80,45 @@ class RawAudioDataset(FairseqDataset):
end = size - diff + start
return wav[start:end]
+ def _compute_mask_indices(self, dims, padding_mask):
+ B, T, C = dims
+ mask_indices, mask_channel_indices = None, None
+ if self.args.mask_prob > 0:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.args.mask_prob,
+ self.args.mask_length,
+ self.args.mask_selection,
+ self.args.mask_other,
+ min_masks=2,
+ no_overlap=self.args.no_mask_overlap,
+ min_space=self.args.mask_min_space,
+ )
+ mask_indices = torch.from_numpy(mask_indices)
+ if self.args.mask_channel_prob > 0:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.args.mask_channel_prob,
+ self.args.mask_channel_length,
+ self.args.mask_channel_selection,
+ self.args.mask_channel_other,
+ no_overlap=self.args.no_mask_channel_overlap,
+ min_space=self.args.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+
+ return mask_indices, mask_channel_indices
+
+ @staticmethod
+ def _bucket_tensor(tensor, num_pad, value):
+ return F.pad(tensor, (0, num_pad), value=value)
+
def collater(self, samples):
samples = [
s
@@ -106,9 +154,53 @@ class RawAudioDataset(FairseqDataset):
collated_sources[i] = self.crop_to_max_size(source, target_size)
input = {"source": collated_sources}
+ out = {"id": torch.LongTensor([s["id"] for s in samples])}
if self.pad:
input["padding_mask"] = padding_mask
- return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input}
+
+ if hasattr(self, 'num_buckets') and self.num_buckets > 0:
+ assert self.pad, "Cannot bucket without padding first."
+ bucket = max(self._bucketed_sizes[s['id']] for s in samples)
+ num_pad = bucket - collated_sources.size(-1)
+ if num_pad:
+ input['source'] = self._bucket_tensor(
+ collated_sources, num_pad, 0
+ )
+ input['padding_mask'] = self._bucket_tensor(
+ padding_mask, num_pad, True
+ )
+
+ if self.compute_mask_indices:
+ B = input['source'].size(0)
+ T = self._get_mask_indices_dims(input['source'].size(-1))
+ padding_mask_reshaped = input['padding_mask'].clone()
+ extra = padding_mask_reshaped.size(1) % T
+ if extra > 0:
+ padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
+ padding_mask_reshaped = padding_mask_reshaped.view(
+ padding_mask_reshaped.size(0), T, -1
+ )
+ padding_mask_reshaped = padding_mask_reshaped.all(-1)
+ mask_indices, mask_channel_indices = self._compute_mask_indices(
+ (B, T, self._C), padding_mask_reshaped,
+ )
+ input["mask_indices"] = mask_indices
+ input['padding_counts'] = padding_mask_reshaped.sum(-1).tolist()  # fully padded frames per row (not masked ones)
+ input["mask_channel_indices"] = mask_channel_indices
+ out['sample_size'] = mask_indices.sum().item()
+
+ out["net_input"] = input
+ return out
+
+ def _get_mask_indices_dims(self, size, padding=0, dilation=1):
+ if size not in self._features_size_map:
+ L_in = size
+ for (_, kernel_size, stride) in self._conv_feature_layers:
+ L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1
+ L_out = 1 + L_out // stride
+ L_in = L_out
+ self._features_size_map[size] = L_out
+ return self._features_size_map[size]
def num_tokens(self, index):
return self.size(index)
@@ -144,6 +236,9 @@ class FileAudioDataset(RawAudioDataset):
min_length=0,
pad=False,
normalize=False,
+ compute_mask_indices=False,
+ args=None,
+ num_buckets=0,
):
super().__init__(
sample_rate=sample_rate,
@@ -153,6 +248,8 @@ class FileAudioDataset(RawAudioDataset):
min_length=min_length,
pad=pad,
normalize=normalize,
+ compute_mask_indices=compute_mask_indices,
+ args=args,
)
self.fnames = []
@@ -169,8 +266,26 @@ class FileAudioDataset(RawAudioDataset):
continue
self.fnames.append(items[0])
self.sizes.append(sz)
+ self.set_bucket_info(num_buckets)
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
+ def set_bucket_info(self, num_buckets):
+ self.num_buckets = num_buckets
+ if self.num_buckets > 0:
+ self._collated_sizes = np.minimum(
+ np.array(self.sizes), self.max_sample_size,
+ )
+ self.buckets = get_buckets(
+ self._collated_sizes, self.num_buckets,
+ )
+ self._bucketed_sizes = get_bucketed_sizes(
+ self._collated_sizes, self.buckets
+ )
+ logger.info(
+ f"{len(self.buckets)} bucket(s) for the audio dataset: "
+ f"{self.buckets}"
+ )
+
def __getitem__(self, index):
import soundfile as sf
diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py
index 6f53d011..e4ce3c40 100644
--- a/fairseq/data/bucket_pad_length_dataset.py
+++ b/fairseq/data/bucket_pad_length_dataset.py
@@ -7,6 +7,7 @@ import numpy as np
import torch.nn.functional as F
from fairseq.data import BaseWrapperDataset
+from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
class BucketPadLengthDataset(BaseWrapperDataset):
@@ -30,42 +31,43 @@ class BucketPadLengthDataset(BaseWrapperDataset):
num_buckets,
pad_idx,
left_pad,
+ tensor_key=None,  # if set, items are dict-like and only item[tensor_key] is padded
):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad
assert num_buckets > 0
- self.buckets = np.unique(
- np.percentile(
- sizes,
- np.linspace(0, 100, num_buckets + 1),
- interpolation='lower',
- )[1:]
- )
+ self.buckets = get_buckets(sizes, num_buckets)
+ self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
+ self._tensor_key = tensor_key
- def get_bucketed_sizes(orig_sizes, buckets):
- sizes = np.copy(orig_sizes)
- assert np.min(sizes) >= 0
- start_val = -1
- for end_val in buckets:
- mask = (sizes > start_val) & (sizes <= end_val)
- sizes[mask] = end_val
- start_val = end_val
- return sizes
+ def _set_tensor(self, item, val):  # with a tensor_key this writes into item in place
+ if self._tensor_key is None:
+ return val
+ item[self._tensor_key] = val
+ return item
- self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
+ def _get_tensor(self, item):
+ if self._tensor_key is None:
+ return item
+ return item[self._tensor_key]
- def __getitem__(self, index):
- item = self.dataset[index]
- bucket_size = self._bucketed_sizes[index]
- num_pad = bucket_size - item.size(-1)
+ def _pad(self, tensor, bucket_size, dim=-1):  # NOTE(review): F.pad below always pads the last dim; dim only picks which size is read
+ num_pad = bucket_size - tensor.size(dim)
return F.pad(
- item,
+ tensor,
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
value=self.pad_idx,
)
+ def __getitem__(self, index):
+ item = self.dataset[index]
+ bucket_size = self._bucketed_sizes[index]
+ tensor = self._get_tensor(item)
+ padded = self._pad(tensor, bucket_size)
+ return self._set_tensor(item, padded)
+
@property
def sizes(self):
return self._bucketed_sizes
diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py
index 57991a88..e6818219 100644
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -272,6 +272,7 @@ def post_process(sentence: str, symbol: str):
sentence = (sentence + " ").replace(symbol, "").rstrip()
return sentence
+
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
@@ -283,6 +284,7 @@ def compute_mask_indices(
no_overlap: bool = False,
min_space: int = 0,
) -> np.ndarray:
+# NOTE(review): stale alternative annotation removed; this function returns only `mask` (np.ndarray)
"""
Computes random mask spans for a given shape
@@ -393,3 +395,25 @@ def compute_mask_indices(
mask[i, mask_idc] = True
return mask
+
+
+def get_buckets(sizes, num_buckets):  # bucket upper bounds: unique 'lower' percentiles of sizes (0th dropped)
+ buckets = np.unique(
+ np.percentile(
+ sizes,
+ np.linspace(0, 100, num_buckets + 1),
+ interpolation='lower',
+ )[1:]
+ )
+ return buckets
+
+
+def get_bucketed_sizes(orig_sizes, buckets):  # rounds each size up to the smallest bucket bound >= it; larger sizes unchanged
+ sizes = np.copy(orig_sizes)
+ assert np.min(sizes) >= 0
+ start_val = -1
+ for end_val in buckets:
+ mask = (sizes > start_val) & (sizes <= end_val)
+ sizes[mask] = end_val
+ start_val = end_val
+ return sizes
diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py
index 7ee89adc..0060d8c1 100644
--- a/fairseq/distributed_utils.py
+++ b/fairseq/distributed_utils.py
@@ -113,7 +113,6 @@ def distributed_init(args):
args.device_id = xm.get_local_ordinal()
args.distributed_rank = xm.get_ordinal()
xm.rendezvous('distributed_init') # wait for all workers
- xm.mark_step()
if is_master(args):
logging.getLogger().setLevel(logging.INFO)
@@ -182,7 +181,10 @@ def call_main(args, main, **kwargs):
xmp.spawn(
fn=distributed_main,
args=(main, args, kwargs),
- nprocs=8, # use all 8 TPU cores
+ # tpu-comment:
+ # one TPU VM has 8 devices, so at most 8 processes are spawned here;
+ # the remaining workers (world size > 8) are driven by xm.distributed.xla_dist.
+ nprocs=min(args.distributed_world_size, 8),
)
else:
# single GPU main
diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py
index 6ca1d201..0bac4c25 100644
--- a/fairseq/logging/metrics.py
+++ b/fairseq/logging/metrics.py
@@ -289,3 +289,11 @@ def load_state_dict(state_dict):
for name, agg_state in state_dict.items():
_aggregators[name] = MetersDict()
_aggregators[name].load_state_dict(agg_state)
+
+
+def xla_metrics_report():  # prints torch_xla's metrics report; silent no-op if torch_xla is not installed
+ try:
+ import torch_xla.debug.metrics as met
+ print(met.metrics_report())
+ except ImportError:
+ return
diff --git a/fairseq/metsumm.py b/fairseq/metsumm.py
new file mode 100644
index 00000000..ce0cfa05
--- /dev/null
+++ b/fairseq/metsumm.py
@@ -0,0 +1,18 @@
+# FIXME: remove this file
+def metsumm(stepno=''):  # debug helper: prints CompileTime / aten:: counters from torch_xla metrics
+ if hasattr(metsumm, 'STEPNO'):
+ metsumm.STEPNO += stepno.lower()=="before forward"  # counts calls tagged 'before forward'
+ else:
+ metsumm.STEPNO = 0
+ try:
+ import torch_xla.debug.metrics as met
+ x = met.metrics_report().split('\n')
+ for i, line in enumerate(x):
+ if 'CompileTime' in line or 'aten::' in line:
+ key = line.split()[-1]
+ value = x[i+1].split()[-1]
+ print('step {}-{}, key {}, value {}'.format(
+ metsumm.STEPNO, stepno, key, value)
+ )
+ except (ImportError, RuntimeError):  # no-op without torch_xla instead of crashing
+ return
diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py
index 226f035b..be3f8df0 100644
--- a/fairseq/models/wav2vec/wav2vec2.py
+++ b/fairseq/models/wav2vec/wav2vec2.py
@@ -27,7 +27,7 @@ from fairseq.modules import (
TransposeLast,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
-from fairseq.utils import buffered_arange
+from fairseq.utils import buffered_arange, index_put
@register_model("wav2vec2")
@@ -405,47 +405,66 @@ class Wav2Vec2Model(BaseFairseqModel):
return cls(args)
- def apply_mask(self, x, padding_mask):
+ def apply_mask(
+ self, x, padding_mask,
+ mask_indices=None, mask_channel_indices=None,
+ ):
B, T, C = x.shape
if self.mask_prob > 0:
- mask_indices = compute_mask_indices(
- (B, T),
- padding_mask,
- self.mask_prob,
- self.mask_length,
- self.mask_selection,
- self.mask_other,
- min_masks=2,
- no_overlap=self.no_mask_overlap,
- min_space=self.mask_min_space,
- )
- mask_indices = torch.from_numpy(mask_indices).to(x.device)
- x[mask_indices] = self.mask_emb
+ if mask_indices is None:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ x = index_put(x, mask_indices, self.mask_emb)
else:
mask_indices = None
if self.mask_channel_prob > 0:
- mask_channel_indices = compute_mask_indices(
- (B, C),
- None,
- self.mask_channel_prob,
- self.mask_channel_length,
- self.mask_channel_selection,
- self.mask_channel_other,
- no_overlap=self.no_mask_channel_overlap,
- min_space=self.mask_channel_min_space,
- )
- mask_channel_indices = (
- torch.from_numpy(mask_channel_indices)
- .to(x.device)
- .unsqueeze(1)
- .expand(-1, T, -1)
- )
- x[mask_channel_indices] = 0
+ if mask_channel_indices is None:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x = index_put(x, mask_channel_indices, 0)
return x, mask_indices
- def sample_negatives(self, y, num):
+ def _get_neg_idxs(self, high, size, padding_counts=None):
+ if padding_counts is None:
+ neg_idxs = torch.randint(low=0, high=high-1, size=size)
+ else:
+ bsz, l = size
+ num = l // self.n_negatives
+ assert len(padding_counts) == bsz
+ neg_idxs = [
+ torch.randint(low=0, high=num-pc-1, size=(l,))  # 1-D per row so torch.stack yields (bsz, l)
+ for pc in padding_counts
+ ]
+ neg_idxs = torch.stack(neg_idxs)
+ return neg_idxs
+
+ def sample_negatives(self, y, num, padding_counts=None):
if self.n_negatives == 0 and self.cross_sample_negatives == 0:
return y.new(0)
@@ -466,8 +485,9 @@ class Wav2Vec2Model(BaseFairseqModel):
.flatten()
)
- neg_idxs = torch.randint(
- low=0, high=high - 1, size=(bsz, self.n_negatives * num)
+ neg_idxs = self._get_neg_idxs(
+ high, (bsz, self.n_negatives * num),
+ padding_counts=padding_counts,
)
neg_idxs[neg_idxs >= tszs] += 1
@@ -511,14 +531,21 @@ class Wav2Vec2Model(BaseFairseqModel):
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
- logits /= self.logit_temp
+ logits = logits / self.logit_temp
- if neg_is_pos.any():
- logits[1:][neg_is_pos] = float("-inf")
+ if logits.device.type == 'xla' or neg_is_pos.any():
+ # finite fill: xla index_put computes value * mask, so -inf would give nan (-inf * 0)
+ fillval = -float(2**30)
+ if not hasattr(self, '_inftensor'):
+ self._inftensor = torch.tensor(fillval).to(x.device)
+ logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)
return logits
- def forward(self, source, padding_mask=None, mask=True, features_only=False):
+ def forward(
+ self, source, padding_mask=None, mask=True, features_only=False,
+ mask_indices=None, mask_channel_indices=None, padding_counts=None,
+ ):
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
@@ -562,9 +589,17 @@ class Wav2Vec2Model(BaseFairseqModel):
features = self.project_inp(features)
if mask:
- x, mask_indices = self.apply_mask(features, padding_mask)
- if mask_indices is not None:
- y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1))
+ x, mask_indices = self.apply_mask(
+ features, padding_mask,
+ mask_indices=mask_indices,
+ mask_channel_indices=mask_channel_indices,
+ )
+ if x.device.type != 'xla' and mask_indices is not None:
+ # tpu-comment: reducing the size in a dynamic way causes
+ # too many recompilations on xla.
+ y = unmasked_features[mask_indices].view(
+ unmasked_features.size(0), -1, unmasked_features.size(-1)
+ )
else:
y = unmasked_features
else:
@@ -588,12 +623,18 @@ class Wav2Vec2Model(BaseFairseqModel):
y = self.project_q(y)
if self.negatives_from_everywhere:
- neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False)
- negs, _ = self.sample_negatives(neg_cands, y.size(1))
+ neg_cands, *_ = self.quantizer(
+ unmasked_features, produce_targets=False,
+ )
+ negs, _ = self.sample_negatives(
+ neg_cands, y.size(1), padding_counts=padding_counts,
+ )
negs = self.project_q(negs)
else:
- negs, _ = self.sample_negatives(y, y.size(1))
+ negs, _ = self.sample_negatives(
+ y, y.size(1), padding_counts=padding_counts,
+ )
if self.codebook_negatives > 0:
cb_negs = self.quantizer.sample_from_codebook(
@@ -608,17 +649,23 @@ class Wav2Vec2Model(BaseFairseqModel):
y = self.project_q(y)
if self.negatives_from_everywhere:
- negs, _ = self.sample_negatives(unmasked_features, y.size(1))
+ negs, _ = self.sample_negatives(
+ unmasked_features, y.size(1), padding_counts=padding_counts,
+ )
negs = self.project_q(negs)
else:
- negs, _ = self.sample_negatives(y, y.size(1))
+ negs, _ = self.sample_negatives(
+ y, y.size(1), padding_counts=padding_counts,
+ )
- x = x[mask_indices].view(x.size(0), -1, x.size(-1))
+ if x.device.type != 'xla':
+ # tpu-comment: reducing the size in a dynamic way causes
+ # too many recompilations on xla.
+ x = x[mask_indices].view(x.size(0), -1, x.size(-1))
if self.target_glu:
y = self.target_glu(y)
negs = self.target_glu(negs)
-
x = self.final_proj(x)
x = self.compute_preds(x, y, negs)
@@ -811,7 +858,7 @@ class TransformerEncoder(nn.Module):
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
- x += x_conv
+ x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py
index 01ddd229..35a08af1 100644
--- a/fairseq/modules/gumbel_vector_quantizer.py
+++ b/fairseq/modules/gumbel_vector_quantizer.py
@@ -160,6 +160,7 @@ class GumbelVectorQuantizer(nn.Module):
avg_probs = torch.softmax(
x.view(bsz * tsz, self.groups, -1).float(), dim=-1
).mean(dim=0)
+ avg_probs = avg_probs.detach()  # NOTE(review): cuts gradients through prob_perplexity; if it feeds a training loss (e.g. a diversity penalty) that term stops training — confirm intended
result["prob_perplexity"] = torch.exp(
-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
).sum()
diff --git a/fairseq/options.py b/fairseq/options.py
index 171c6796..ce6edc34 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -238,6 +238,7 @@ def get_parser(desc, default_task="translation"):
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
parser.add_argument('--tpu', action='store_true', help='use TPU instead of CUDA')
+ parser.add_argument('--xla-metrics-debug', action='store_true', help='Print XLA debug info')  # read as args.xla_metrics_debug in fairseq_cli/train.py
parser.add_argument('--bf16', action='store_true', help='use bfloat16; implies --tpu')
parser.add_argument('--fp16', action='store_true', help='use FP16')
parser.add_argument('--memory-efficient-bf16', action='store_true',
diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py
index f3363746..2683e705 100644
--- a/fairseq/tasks/audio_pretraining.py
+++ b/fairseq/tasks/audio_pretraining.py
@@ -8,7 +8,9 @@
import os
import sys
-from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset
+from fairseq.data import (
+ FileAudioDataset, Dictionary, AddTargetDataset, BucketPadLengthDataset
+)
from . import FairseqTask, register_task
@@ -31,6 +33,14 @@ class AudioPretrainingTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
+ parser.add_argument(
+ '--num-batch-buckets', default=0, type=int,
+ help=(
+ 'if >0, then bucket source and target lengths into N '
+ 'buckets and pad accordingly; this is useful on TPUs '
+ 'to minimize the number of compilations'
+ ),
+ )
parser.add_argument("data", help="path to data directory")
parser.add_argument(
"--sample-rate",
@@ -103,6 +113,9 @@ class AudioPretrainingTask(FairseqTask):
min_length=self.args.min_sample_size,
pad=self.args.labels is not None or self.args.enable_padding,
normalize=self.args.normalize,
+ compute_mask_indices=self.args.tpu,  # on TPU, precompute mask indices in the collater
+ args=self.args,
+ num_buckets=self.args.num_batch_buckets or int(self.args.tpu),  # TPU default: 1 bucket, i.e. pad every batch to the max collated size
)
if self.args.labels:
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index a91d12fd..37d578e3 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -347,7 +347,7 @@ class Trainer(object):
num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
num_workers=self.args.num_workers,
- epoch=epoch
+ epoch=epoch,
)
def get_valid_iterator(
@@ -368,7 +368,7 @@ class Trainer(object):
seed=self.args.seed,
num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank,
- num_workers=self.args.num_workers
+ num_workers=self.args.num_workers,
)
def begin_epoch(self, epoch):
@@ -422,6 +422,7 @@ class Trainer(object):
try:
with maybe_no_sync():
# forward and backward
+
loss, sample_size_i, logging_output = self.task.train_step(
sample=sample,
model=self.model,
@@ -462,8 +463,7 @@ class Trainer(object):
# before marking step can lead to OOM errors.
# To handle gradient accumulation use case, we explicitly
# mark step here for every forward pass without a backward pass
- import torch_xla.core.xla_model as xm
- xm.mark_step()
+ self._xla_markstep_and_send_to_cpu()
if is_dummy_batch:
if torch.is_tensor(sample_size):
@@ -548,16 +548,23 @@ class Trainer(object):
if self.tpu:
# mark step on TPUs
- import torch_xla.core.xla_model as xm
- xm.mark_step()
# only log stats every log_interval steps
# this causes wps to be misreported when log_interval > 1
logging_output = {}
if self.get_num_updates() % self.args.log_interval == 0:
+ logging_outputs = self._xla_markstep_and_send_to_cpu(  # materialize XLA tensors on CPU before reducing
+ logging_outputs
+ )
logging_output = self._reduce_and_log_stats(
logging_outputs, sample_size, grad_norm,
)
+ self._xla_markstep_and_send_to_cpu()
+ # FIXME: taylan when I put step closure, logging outputs is shrunk..
+ #xm.add_step_closure(
+ # self._reduce_and_log_stats,
+ # args=(logging_outputs, sample_size, grad_norm)
+ #)
# log whenever there's an XLA compilation, since these
# slow down training and may indicate opportunities for
@@ -593,9 +600,7 @@ class Trainer(object):
if self._dummy_batch == "DUMMY":
self._dummy_batch = sample
if self.tpu:
- import torch_xla.core.xla_model as xm
- xm.rendezvous('valid_step') # wait for all workers
- xm.mark_step()
+ self._xla_markstep_and_send_to_cpu()
with torch.no_grad():
self.model.eval()
@@ -641,6 +646,8 @@ class Trainer(object):
)
# log validation stats
+ if self.tpu:
+ logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs)
logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
return logging_output
@@ -778,8 +785,9 @@ class Trainer(object):
utils.set_torch_seed(seed)
def _sync_stats(self):
- # Return True if it's using multiple GPUs and DDP or multiple GPUs with
- # BMUF and it's a bmuf sync with warmup iterations completed before.
+ # Return True if it's using multiple devices and DDP
+ # or multiple devices with BMUF and it's a bmuf sync
+ # with warmup iterations completed before.
if self.data_parallel_world_size == 1:
return False
elif self.args.use_bmuf:
@@ -912,7 +920,11 @@ class Trainer(object):
)
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
- if grad_norm is not None:
+ # tpu-comment: grad_norm is a tensor in XLA
+ if (
+ (not torch.is_tensor(grad_norm) and grad_norm is not None)
+ or (torch.is_tensor(grad_norm) and not torch.isnan(grad_norm))
+ ):
metrics.log_speed("ups", 1., priority=100, round=2)
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
if self.args.clip_norm > 0:
@@ -968,6 +980,13 @@ class Trainer(object):
logging.info("NOTE: XLA compilation detected; {}".format(message))
self._num_xla_compiles = num_xla_compiles
+ def _xla_markstep_and_send_to_cpu(self, data=None):  # runs xm.mark_step(); returns data moved to CPU, or None if data is None
+ import torch_xla.core.xla_model as xm
+ xm.mark_step()
+ if data is not None:
+ from fairseq.utils import xla_device_to_cpu
+ return xla_device_to_cpu(data)
+
def _catalog_shared_params(module, memo=None, prefix=''):
if memo is None:
diff --git a/fairseq/utils.py b/fairseq/utils.py
index f6886033..7e8db9f6 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -81,6 +81,7 @@ def move_to_cuda(sample):
def move_to_cpu(sample):
+
def _move_to_cpu(tensor):
# PyTorch has poor support for half tensors (float16) on CPU.
# Move any such tensors to float32.
@@ -252,6 +253,9 @@ def convert_padding_direction(
def item(tensor):
+ # tpu-comment: making this a no-op for xla devices.
+ if torch.is_tensor(tensor) and tensor.device.type == 'xla':
+ return tensor.detach()
if hasattr(tensor, "item"):
return tensor.item()
if hasattr(tensor, "__getitem__"):
@@ -560,6 +564,23 @@ def get_tpu_device(args):
return xm.xla_device()
+def index_put(tensor, indices, value):  # tensor[indices] = value; on xla uses mask arithmetic instead of an in-place indexed write
+ if tensor.device.type == 'xla':
+ for _ in range(indices.dim(), tensor.dim()):
+ indices = indices.unsqueeze(-1)
+ if indices.size(-1) < tensor.size(-1):
+ indices = indices.expand_as(tensor)
+ tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)  # NOTE: inf/nan in tensor or value turn into nan here (x * 0)
+ else:
+ tensor[indices] = value
+ return tensor
+
+
+def xla_device_to_cpu(dat):  # NOTE(review): relies on the private torch_xla helper xm._maybe_convert_to_cpu
+ import torch_xla.core.xla_model as xm
+ return xm._maybe_convert_to_cpu(dat)
+
+
class CudaEnvironment(object):
def __init__(self):
cur_device = torch.cuda.current_device()
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index 806e4bc5..832f4746 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -96,7 +96,7 @@ def main(args):
"training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
)
logger.info(
- "max tokens per GPU = {} and max sentences per GPU = {}".format(
+ "max tokens per device = {} and max sentences per device = {}".format(
args.max_tokens, args.max_sentences
)
)
@@ -106,9 +106,7 @@ def main(args):
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
if args.tpu:
import torch_xla.core.xla_model as xm
-
xm.rendezvous("load_checkpoint") # wait for all workers
- xm.mark_step()
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
@@ -167,7 +165,6 @@ def tpu_data_loader(args, itr):
import torch_xla.distributed.parallel_loader as pl
xm.rendezvous("tpu_data_loader") # wait for all workers
- xm.mark_step()
device = utils.get_tpu_device(args)
return iterators.CountingIterator(
pl.ParallelLoader(itr, [device]).per_device_loader(device),
@@ -211,6 +208,12 @@ def train(args, trainer, task, epoch_itr):
should_stop = False
num_updates = trainer.get_num_updates()
for i, samples in enumerate(progress):
+
+ from fairseq.metsumm import metsumm as m
+ if args.tpu and not i % 50:  # torch_xla may be absent on GPU/CPU runs, so only import it on TPU
+ import torch_xla.core.xla_model as xm
+ if xm.is_master_ordinal():
+ m(str(i))
with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
"train_step-%d" % i
):
@@ -226,7 +229,10 @@ def train(args, trainer, task, epoch_itr):
# reset mid-epoch stats after each log interval
# the end-of-epoch stats will still be preserved
+ # FIXME: taylan reset in closure!!!!!!!!!!!!!!!!!
metrics.reset_meters("train_inner")
+ if args.xla_metrics_debug:
+ metrics.xla_metrics_report()
end_of_epoch = not itr.has_next()
valid_losses, should_stop = validate_and_save(
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment