Created
June 19, 2020 23:21
-
-
Save ultrons/3e3e0d2a6c97c7c9e9a483b6af9a4043 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/examples/wav2vec/train-vq.sh b/examples/wav2vec/train-vq.sh
index e2065b6..cc48fba 100644
--- a/examples/wav2vec/train-vq.sh
+++ b/examples/wav2vec/train-vq.sh
@@ -6,7 +6,7 @@ python \
--tpu \
--bf16 \
--distributed-world-size 8 \
---max-sentences 1 \
+--max-sentences 8 \
--num-workers 6 \
--max-update 40000 \
--save-interval 1 \
@@ -23,8 +23,8 @@ python \
--conv-feature-layers '[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]' \
--conv-aggregator-layers '[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]' \
--activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \
- --log-keys '["prob_perplexity","code_perplexity","temp"]' --vq-type gumbel --vq-groups 2 --vq-depth 2 \
--combine-groups --vq-vars 320 --vq-temp '(2,0.5,0.999995)' --prediction-steps 12 --warmup-updates 1000 \
+--log-keys '["prob_perplexity","temp"]' --vq-type gumbel --vq-groups 2 --vq-depth 2 \
--log-compression \
--warmup-init-lr 1e-07 \
--criterion binary_cross_entropy \
@@ -34,3 +34,4 @@ python \
--skip-invalid-size-inputs-valid-test \
--log-interval 20 \
--log-format simple
+
diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py
index 84fcc68..0154467 100644
--- a/fairseq/distributed_utils.py
+++ b/fairseq/distributed_utils.py
@@ -309,4 +309,7 @@ def all_reduce_dict(
return device_data[key]
raise KeyError
- return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
+ result = OrderedDict([(key, get_from_stack(key)) for key in data_keys])
+ del cpu_data
+ del device_data
+ return result
diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py
index 2efc10e..945411e 100644
--- a/fairseq/modules/gumbel_vector_quantizer.py
+++ b/fairseq/modules/gumbel_vector_quantizer.py
@@ -136,6 +136,7 @@ class GumbelVectorQuantizer(nn.Module):
avg_probs = torch.softmax(
x.view(bsz * tsz, self.groups, -1).float(), dim=-1
).mean(dim=0)
+ avg_probs = avg_probs.detach()
result["prob_perplexity"] = torch.exp(
-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
).sum()
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index a6b4f0c..c67895b 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -525,7 +525,7 @@ class Trainer(object):
# only log stats every log_interval steps
# this causes wps to be misreported when log_interval > 1
- self.logging_history += logging_outputs
+ self.logging_history.extend(logging_outputs)
self.cumm_sample_size += sample_size
logging_output = {}
if self.get_num_updates() % self.args.log_interval == 0:
@@ -851,6 +851,7 @@ class Trainer(object):
logging_outputs = [{k: data['logging_outputs_' + k] for k in log_keys}]
else:
logging_outputs = []
+ del data
return logging_outputs, extra_stats_to_sum
def _check_grad_norms(self, grad_norm):
(torch-xla-nightly) sivaibhav@pytorch-dev-3:~/fairseq$ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment