
@dalgu90
Last active September 8, 2022 00:40
Count the number of parameters in a TensorFlow checkpoint file. (Usage: python count_ckpt_param.py path-to-ckpt)
#!/usr/bin/env python
import sys
import tensorflow as tf
import numpy as np

# Expect exactly one argument: the checkpoint path (prefix)
if len(sys.argv) == 2:
    ckpt_fpath = sys.argv[1]
else:
    print('Usage: python count_ckpt_param.py path-to-ckpt')
    sys.exit(1)

# Open TensorFlow ckpt
reader = tf.train.NewCheckpointReader(ckpt_fpath)

print('\nCount the number of parameters in ckpt file(%s)' % ckpt_fpath)
# Map of variable name -> shape (list of ints)
param_map = reader.get_variable_to_shape_map()
total_count = 0
for k, v in param_map.items():
    # Skip optimizer slot variables (Momentum) and the step counter
    if 'Momentum' not in k and 'global_step' not in k:
        temp = np.prod(v)  # number of elements in this variable
        total_count += temp
        print('%s: %s => %d' % (k, str(v), temp))
print('Total Param Count: %d' % total_count)
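The counting step in the script above can be separated from TensorFlow and checked on its own. The sketch below is a minimal illustration under stated assumptions: `count_params` and the sample shape map are hypothetical, and the map only mimics what `get_variable_to_shape_map()` returns. It applies the same substring filter as the script. It uses `math.prod` instead of `np.prod` so a scalar shape `[]` counts as the integer 1 rather than the float 1.0.

```python
import math
from typing import Dict, Iterable, List, Tuple


def count_params(
    shape_map: Dict[str, List[int]],
    exclude: Iterable[str] = ("Momentum", "global_step"),
) -> Tuple[int, Dict[str, int]]:
    """Sum element counts over a name -> shape map, skipping excluded names.

    A variable is skipped if any string in `exclude` occurs in its name,
    mirroring the substring test in the script above.
    """
    exclude = tuple(exclude)  # allow generators to be iterated repeatedly
    per_var: Dict[str, int] = {}
    for name, shape in shape_map.items():
        if any(pattern in name for pattern in exclude):
            continue
        # math.prod(()) == 1, so a scalar variable counts as one parameter.
        per_var[name] = math.prod(shape)
    return sum(per_var.values()), per_var


# Hypothetical shape map in the style of get_variable_to_shape_map()
shapes = {
    "conv1/weights": [3, 3, 3, 64],
    "conv1/biases": [64],
    "conv1/weights/Momentum": [3, 3, 3, 64],
    "global_step": [],
}
total, per_var = count_params(shapes)
print(total)  # 1792 (= 3*3*3*64 + 64)
```

Only the weights and biases are counted; the Momentum slot and `global_step` are filtered out exactly as in the original loop.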
@dtlam26
dtlam26 commented Aug 13, 2021

When executing the script I get the error below; I assume this script no longer works in TF2:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-9-deb2ed05f78c> in <module>()
      6 
      7 # Open TensorFlow ckpt
----> 8 reader = tf.train.NewCheckpointReader(ckpt_fpath)
      9 
     10 print('\nCount the number of parameters in ckpt file(%s)' % ckpt_fpath)

AttributeError: module 'tensorflow._api.v2.train' has no attribute 'NewCheckpointReader'

You should switch to tf.compat.v1.train.NewCheckpointReader, my dear 💪
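The tf.compat.v1 route keeps the TF1 API. A native TF2 alternative is sketched below, assuming `tf.train.load_checkpoint`, which accepts a checkpoint prefix or directory and returns a reader with the same `get_variable_to_shape_map()` method. The function names are hypothetical. The extra exclusion substrings for TF2 object-based checkpoints (`OPTIMIZER_SLOT`, `_CHECKPOINTABLE_OBJECT_GRAPH`, `save_counter`, `optimizer/`) are assumptions based on common key layouts. The original script does not handle them, so adjust the list to the keys your checkpoint actually contains.

```python
import math

# Substrings marking non-model entries. ASSUMPTION: the first two match the
# TF1 script; the rest reflect typical TF2 object-based checkpoint keys.
EXCLUDE = (
    "Momentum",
    "global_step",
    "OPTIMIZER_SLOT",
    "_CHECKPOINTABLE_OBJECT_GRAPH",
    "save_counter",
    "optimizer/",
)


def is_model_param(name: str) -> bool:
    """Return True if a checkpoint key looks like a model variable."""
    return not any(pattern in name for pattern in EXCLUDE)


def count_ckpt_params(ckpt_path: str) -> int:
    """Print per-variable element counts and return the total (TF2 sketch)."""
    # Imported here so is_model_param() can be used without TensorFlow.
    import tensorflow as tf

    reader = tf.train.load_checkpoint(ckpt_path)
    total = 0
    for name, shape in sorted(reader.get_variable_to_shape_map().items()):
        if not is_model_param(name):
            continue
        count = math.prod(shape)  # math.prod([]) == 1 for scalars
        total += count
        print("%s: %s => %d" % (name, shape, count))
    print("Total Param Count: %d" % total)
    return total
```

Calling `count_ckpt_params('path-to-ckpt')` prints the same per-variable lines and total as the original script. Wrapping it in the original `sys.argv` check restores the command-line usage.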
