Skip to content

Instantly share code, notes, and snippets.

@reuben
Created August 29, 2019 18:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save reuben/b68b9085f7b293580f8431156a33daa9 to your computer and use it in GitHub Desktop.
diff --git a/DeepSpeech.py b/DeepSpeech.py
index 758e4670..2d3dbbff 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -749,7 +749,18 @@ def export():
     output_names = ",".join(output_names_tensors + output_names_ops)
 
     # Create a saver using variables from the above newly created graph
-    saver = tfv1.train.Saver()
+    # Saver(var_list=dict) restores each value from the checkpoint entry named by
+    # its key, so graph variables under the CuDNN-compatible LSTM scope are looked
+    # up under 'lstm_fused_cell/' instead; all other names are used unchanged.
+    cudnn_prefix = 'cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/'
+    fused_prefix = 'lstm_fused_cell/'
+    def fixup(name):
+        # Swap only the leading scope; str.replace would also rewrite later matches.
+        if name.startswith(cudnn_prefix):
+            return fused_prefix + name[len(cudnn_prefix):]
+        return name
+    mapping = {fixup(v.op.name): v for v in tfv1.global_variables()}
+    saver = tfv1.train.Saver(mapping)
 
     # Restore variables from training checkpoint
     checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment