@dalgu90
Last active September 8, 2022 00:40
Count the number of parameters in a TensorFlow checkpoint file. (Usage: python count_ckpt_param.py path-to-ckpt)
```python
#!/usr/bin/env python
import sys

import numpy as np
import tensorflow as tf

if len(sys.argv) == 2:
    ckpt_fpath = sys.argv[1]
else:
    print('Usage: python count_ckpt_param.py path-to-ckpt')
    sys.exit(1)

# Open TensorFlow ckpt (TF 1.x API; under TF 2.x the same reader is
# available as tf.compat.v1.train.NewCheckpointReader)
reader = tf.train.NewCheckpointReader(ckpt_fpath)

print('\nCount the number of parameters in ckpt file(%s)' % ckpt_fpath)
# Maps each variable name to its shape (a list of dimension sizes)
param_map = reader.get_variable_to_shape_map()
total_count = 0
for k, v in param_map.items():
    # Skip Momentum optimizer slot variables and the global_step counter
    if 'Momentum' not in k and 'global_step' not in k:
        temp = np.prod(v)  # element count = product of the dimensions
        total_count += temp
        print('%s: %s => %d' % (k, str(v), temp))

print('Total Param Count: %d' % total_count)
```
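Each variable contributes the product of its dimensions, so a [784, 10] kernel counts as 7840 parameters. The following TensorFlow-free sketch mirrors the loop above on a hypothetical shape map (the variable names are made up for illustration), which is handy for checking the arithmetic without a real checkpoint:

```python
import numpy as np


def count_params(shape_map):
    """Sum element counts, skipping the same names as the script."""
    total = 0
    for name, shape in shape_map.items():
        if 'Momentum' in name or 'global_step' in name:
            continue
        # np.prod([]) == 1.0, so a scalar variable counts as 1 parameter
        total += int(np.prod(shape))
    return total


# Hypothetical checkpoint contents: variable name -> shape
example = {
    'dense/kernel': [784, 10],
    'dense/kernel/Momentum': [784, 10],
    'dense/bias': [10],
    'global_step': [],
}
print(count_params(example))  # 7850
```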
@dalgu90
Author

dalgu90 commented Apr 14, 2021

I have a question. Why do you have this line: `if 'Momentum' not in k and 'global_step' not in k:`? Why don't we count those variables when counting the number of parameters?

Thanks @dtlam26 for answering. To add to that: in TF v1.x, the variable named global_step holds the number of optimizer steps taken so far, so it is not a model parameter.
Also, models or frameworks built on top of TF may create other variables that are not part of the model itself; you have to rule those out manually.
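Because the right exclusions depend on your optimizer and framework, it can help to list what gets skipped rather than dropping it silently. A minimal sketch, assuming you supply your own substrings: `split_params` is a hypothetical helper, and the `'Adam'` entry in the usage example is an assumption about how an optimizer might name its slot variables, not something the original script handles.

```python
import numpy as np


def split_params(shape_map, skip_substrings=('Momentum', 'global_step')):
    """Partition variables into (counted, skipped) dicts of name -> element count."""
    counted, skipped = {}, {}
    for name, shape in shape_map.items():
        n = int(np.prod(shape))  # element count of this variable
        if any(s in name for s in skip_substrings):
            skipped[name] = n
        else:
            counted[name] = n
    return counted, skipped


shapes = {  # hypothetical checkpoint contents
    'conv/weights': [3, 3, 3, 16],
    'conv/weights/Adam': [3, 3, 3, 16],
    'conv/weights/Adam_1': [3, 3, 3, 16],
    'global_step': [],
}
counted, skipped = split_params(shapes, ('Adam', 'global_step'))
print(sum(counted.values()))  # 432
print(sorted(skipped))  # review these names before trusting the total
```

Printing the skipped names makes it easy to spot a substring that accidentally matches a real weight.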

@Jsevillamol

When executing the script I get the error below; I assume this script no longer works in TF2:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-9-deb2ed05f78c> in <module>()
      6 
      7 # Open TensorFlow ckpt
----> 8 reader = tf.train.NewCheckpointReader(ckpt_fpath)
      9 
     10 print('\nCount the number of parameters in ckpt file(%s)' % ckpt_fpath)

AttributeError: module 'tensorflow._api.v2.train' has no attribute 'NewCheckpointReader'
```

@dtlam26

dtlam26 commented Aug 13, 2021

You should switch to tf.compat.v1.train, i.e. `tf.compat.v1.train.NewCheckpointReader(ckpt_fpath)`, my dear 💪
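A TF 2.x-friendly variant of the script might look like the sketch below, assuming TensorFlow 2.x is installed. It opens the checkpoint with `tf.train.load_checkpoint`, which returns a reader exposing the same `get_variable_to_shape_map()` method (`tf.compat.v1.train.NewCheckpointReader` works as well). The function names are hypothetical, and TensorFlow is imported inside `count_ckpt_params` so the counting helper runs without it. One caveat, stated as an assumption: checkpoints written by TF 2.x object-based saving (`tf.train.Checkpoint`, Keras) name variables differently from TF 1.x ones, so the default Momentum/global_step filter may not exclude optimizer state there; check the printed names and adjust `skip_substrings`.

```python
import sys

import numpy as np


def count_shape_map(shape_map, skip_substrings=('Momentum', 'global_step')):
    """Print per-variable counts and return the total, skipping matching names."""
    total = 0
    for name, shape in sorted(shape_map.items()):
        if any(s in name for s in skip_substrings):
            continue
        n = int(np.prod(shape))  # element count = product of the dimensions
        total += n
        print('%s: %s => %d' % (name, shape, n))
    print('Total Param Count: %d' % total)
    return total


def count_ckpt_params(ckpt_fpath, skip_substrings=('Momentum', 'global_step')):
    """Open a checkpoint file (or directory) with the TF 2.x API and count it."""
    import tensorflow as tf  # imported lazily; only needed for real checkpoints

    reader = tf.train.load_checkpoint(ckpt_fpath)
    return count_shape_map(reader.get_variable_to_shape_map(), skip_substrings)


if __name__ == '__main__':
    if len(sys.argv) == 2:
        count_ckpt_params(sys.argv[1])
    else:
        print('Usage: python count_ckpt_param.py path-to-ckpt')
```

It is invoked the same way as the original, e.g. `python count_ckpt_param.py path-to-ckpt`.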
