@ultrons · Created June 19, 2020 23:21 · gist 3e3e0d2a6c97c7c9e9a483b6af9a4043
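
A fairseq patch for what looks like vq-wav2vec pre-training on TPU: a larger per-core batch in train-vq.sh, trimmed log keys, explicit deletion of temporaries in distributed_utils.py and trainer.py, and a detach() on the logged perplexity statistic in the Gumbel vector quantizer.
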
diff --git a/examples/wav2vec/train-vq.sh b/examples/wav2vec/train-vq.sh
index e2065b6..cc48fba 100644
--- a/examples/wav2vec/train-vq.sh
+++ b/examples/wav2vec/train-vq.sh
@@ -6,7 +6,7 @@ python \
 --tpu \
 --bf16 \
 --distributed-world-size 8 \
---max-sentences 1 \
+--max-sentences 8 \
 --num-workers 6 \
 --max-update 40000 \
 --save-interval 1 \
@@ -23,8 +23,8 @@ python \
 --conv-feature-layers '[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]' \
 --conv-aggregator-layers '[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]' \
 --activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \
---log-keys '["prob_perplexity","code_perplexity","temp"]' --vq-type gumbel --vq-groups 2 --vq-depth 2 \
 --combine-groups --vq-vars 320 --vq-temp '(2,0.5,0.999995)' --prediction-steps 12 --warmup-updates 1000 \
+--log-keys '["prob_perplexity","temp"]' --vq-type gumbel --vq-groups 2 --vq-depth 2 \
 --log-compression \
 --warmup-init-lr 1e-07 \
 --criterion binary_cross_entropy \
@@ -34,3 +34,4 @@ python \
 --skip-invalid-size-inputs-valid-test \
 --log-interval 20 \
 --log-format simple
+
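
Two things change in the script: --max-sentences (the per-core batch) goes from 1 to 8, which with --distributed-world-size 8 works out to 64 sequences per global step, and "code_perplexity" is dropped from --log-keys so only "prob_perplexity" and "temp" are aggregated across cores. The --vq-temp '(2,0.5,0.999995)' triple is the (start, end, decay) of the Gumbel temperature schedule; a minimal sketch of that schedule, with the function name assumed rather than taken from fairseq:

    # Sketch of the Gumbel temperature anneal implied by --vq-temp '(2,0.5,0.999995)'.
    # The triple is (start, end, decay); names here are illustrative, not fairseq's.
    start, end, decay = 2.0, 0.5, 0.999995

    def curr_temp(num_updates: int) -> float:
        # exponential decay from `start`, floored at `end`
        return max(start * decay ** num_updates, end)

    print(curr_temp(0), curr_temp(40_000))  # 2.0 -> ~1.64 over this run's 40k updates
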
diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py
index 84fcc68..0154467 100644
--- a/fairseq/distributed_utils.py
+++ b/fairseq/distributed_utils.py
@@ -309,4 +309,7 @@ def all_reduce_dict(
             return device_data[key]
         raise KeyError

-    return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
+    result = OrderedDict([(key, get_from_stack(key)) for key in data_keys])
+    del cpu_data
+    del device_data
+    return result
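
The change above builds the result first, then explicitly del-s the two dicts that the nested get_from_stack closed over, dropping the last references so their storage can be freed promptly instead of lingering with the frame. A self-contained sketch of the pattern (names illustrative, not fairseq's):

    from collections import OrderedDict

    def gather(cpu_data, device_data, data_keys):
        # the nested lookup closes over both dicts, keeping them alive in cells
        def get_from_stack(key):
            if key in cpu_data:
                return cpu_data[key]
            if key in device_data:
                return device_data[key]
            raise KeyError(key)

        result = OrderedDict((k, get_from_stack(k)) for k in data_keys)
        del cpu_data     # legal in Python 3 even for closure variables;
        del device_data  # empties the cells so the dicts can be freed now
        return result

    print(gather({"loss": 1.5}, {"ntokens": 8}, ["loss", "ntokens"]))
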
diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py
index 2efc10e..945411e 100644
--- a/fairseq/modules/gumbel_vector_quantizer.py
+++ b/fairseq/modules/gumbel_vector_quantizer.py
@@ -136,6 +136,7 @@ class GumbelVectorQuantizer(nn.Module):
         avg_probs = torch.softmax(
             x.view(bsz * tsz, self.groups, -1).float(), dim=-1
         ).mean(dim=0)
+        avg_probs = avg_probs.detach()
         result["prob_perplexity"] = torch.exp(
             -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
         ).sum()
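
In this setup prob_perplexity is consumed only through --log-keys, so the added detach() keeps its softmax/entropy chain out of the backward pass. A standalone sketch of the effect, with toy sizes that are illustrative rather than the model's:

    import torch

    bsz, tsz, groups, n_vars = 2, 4, 2, 320   # toy shapes only
    x = torch.randn(bsz * tsz, groups, n_vars, requires_grad=True)

    avg_probs = torch.softmax(x.float(), dim=-1).mean(dim=0)
    avg_probs = avg_probs.detach()            # the one-line change above
    prob_perplexity = torch.exp(
        -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
    ).sum()
    assert not prob_perplexity.requires_grad  # stat is logged, never back-propagated
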
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index a6b4f0c..c67895b 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -525,7 +525,7 @@ class Trainer(object):
         # only log stats every log_interval steps
         # this causes wps to be misreported when log_interval > 1
-        self.logging_history += logging_outputs
+        self.logging_history.extend(logging_outputs)
         self.cumm_sample_size += sample_size
         logging_output = {}
         if self.get_num_updates() % self.args.log_interval == 0:
@@ -851,6 +851,7 @@ class Trainer(object):
             logging_outputs = [{k: data['logging_outputs_' + k] for k in log_keys}]
         else:
             logging_outputs = []
+        del data
         return logging_outputs, extra_stats_to_sum

     def _check_grad_norms(self, grad_norm):
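
For Python lists, += and .extend() perform the same in-place append, so the first hunk is a readability change rather than a behavioral one; the del data in the second hunk drops the all-reduced stats dict as soon as the logging outputs have been pulled out of it, presumably to reduce host-memory pressure. A quick check of the list equivalence:

    history = []
    alias = history
    history += [{"loss": 1.0}]       # list.__iadd__ mutates in place
    history.extend([{"loss": 0.9}])  # identical effect
    assert alias is history and len(history) == 2
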