-
-
Save reuben/b68b9085f7b293580f8431156a33daa9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/DeepSpeech.py b/DeepSpeech.py | |
index 758e4670..2d3dbbff 100755 | |
--- a/DeepSpeech.py | |
+++ b/DeepSpeech.py | |
@@ -749,7 +749,12 @@ def export(): | |
output_names = ",".join(output_names_tensors + output_names_ops) | |
# Create a saver using variables from the above newly created graph | |
- saver = tfv1.train.Saver() | |
+ def fixup(name): | |
+ if name.startswith('cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/'): | |
+ return name.replace('cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/', 'lstm_fused_cell/') | |
+ return name | |
+ mapping = {fixup(v.op.name): v for v in tfv1.global_variables()} | |
+ saver = tfv1.train.Saver(mapping) | |
# Restore variables from training checkpoint | |
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.