@maxrohleder
Last active January 12, 2023 08:19
TensorFlow to PyTorch conversion
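import os
import pickle

import tensorflow as tf  # tensorflow 1.x (under TF 2.x: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior())

'''
Expected checkpoint layout:
<base_folder>
├───checkpoint
├───<model_name>.meta
├───<model_name>.data-00000-of-00001
└───<model_name>.index
'''
base_folder = '<base_folder>'  # placeholder: directory holding the checkpoint files
model_name = '<model_name>'    # placeholder: checkpoint prefix

# First load the meta graph, then restore the weights from the latest checkpoint
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(os.path.join(base_folder, model_name + '.meta'))
    saver.restore(sess, tf.train.latest_checkpoint(base_folder))

    # Collect all trainable variables and store their values in a dict keyed by
    # variable name (of the form 'scope/name:0'). Non-trainable variables, such as
    # batch-norm moving averages, are not included; they live in tf.global_variables().
    trainable_vars = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    weights = {v.name: sess.run(v) for v in trainable_vars}  # sess.run returns a NumPy array

with open('weights.pickle', 'wb') as handle:
    pickle.dump(weights, handle, protocol=pickle.HIGHEST_PROTOCOL)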
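The pickle above only covers the TensorFlow side. A minimal sketch of the next step, assuming the model uses standard conv2d and dense kernels: load the dict back, strip the ':0' output suffix TensorFlow appends to variable names, and transpose kernels into PyTorch's layout. TF conv2d kernels are (H, W, in, out) while PyTorch Conv2d weights are (out, in, H, W); TF dense kernels are (in, out) while PyTorch Linear weights are (out, in). The helper names load_tf_weights and to_torch_layout are hypothetical, and depthwise, transposed-conv and RNN weights would need their own rules.

```python
import pickle

import numpy as np


def load_tf_weights(path):
    """Load the pickled {variable name: ndarray} dict and drop TF's ':0' suffix."""
    with open(path, 'rb') as handle:
        weights = pickle.load(handle)
    return {(name[:-2] if name.endswith(':0') else name): value
            for name, value in weights.items()}


def to_torch_layout(array):
    """Transpose a TF kernel into PyTorch's layout (hypothetical helper).

    Assumes plain conv2d / dense kernels:
      - 4-D conv kernels: TF (H, W, in, out) -> PyTorch (out, in, H, W)
      - 2-D dense kernels: TF (in, out)      -> PyTorch (out, in)
      - anything else (biases, batch-norm scale/offset) is returned unchanged
    """
    if array.ndim == 4:
        return np.ascontiguousarray(array.transpose(3, 2, 0, 1))
    if array.ndim == 2:
        return np.ascontiguousarray(array.T)
    return array
```

Mapping the cleaned TF names onto the PyTorch module's parameter names is model-specific; for a hypothetical layer, 'conv1/kernel' might become 'conv1.weight' and 'conv1/bias' 'conv1.bias'. Each converted array can then be wrapped with torch.from_numpy and the resulting dict passed to model.load_state_dict.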