Gist dalgu90/a9952dfd372cbe1cdc529b204329e189 (count_ckpt_param.py)
#!/usr/bin/env python
import sys

import tensorflow as tf
import numpy as np

if len(sys.argv) == 2:
    ckpt_fpath = sys.argv[1]
else:
    print('Usage: python count_ckpt_param.py path-to-ckpt')
    sys.exit(1)

# Open TensorFlow ckpt (TF 1.x API; in TF 2.x use tf.compat.v1.train.NewCheckpointReader)
reader = tf.train.NewCheckpointReader(ckpt_fpath)

print('\nCount the number of parameters in ckpt file(%s)' % ckpt_fpath)
param_map = reader.get_variable_to_shape_map()
total_count = 0
for k, v in param_map.items():
    # Skip optimizer slot variables (Momentum) and the training step counter
    if 'Momentum' not in k and 'global_step' not in k:
        temp = np.prod(v)  # number of elements in a variable of shape v
        total_count += temp
        print('%s: %s => %d' % (k, str(v), temp))

print('Total Param Count: %d' % total_count)
I have a question: why do you have the line "if 'Momentum' not in k and 'global_step' not in k:"? Why don't we include those variables when counting the number of parameters?
Those variables are created for training and are not considered model parameters. Model parameters are used not only in training but also in inference/testing, which needs neither the optimizer's training state nor the current training iteration count.
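To see how much these training-only variables can inflate the count, here is a minimal sketch on a hypothetical variable-to-shape map for one dense layer (784 inputs, 10 outputs). It assumes the TF 1.x convention where MomentumOptimizer keeps one accumulator per trainable variable, with the same shape, named "<variable>/Momentum". The helper `count` and all variable names are illustrative, not taken from the gist.

```python
import numpy as np

# Hypothetical checkpoint contents for a 784 -> 10 dense layer trained with
# tf.train.MomentumOptimizer in TF 1.x (names are illustrative).
shape_map = {
    "dense/kernel": [784, 10],
    "dense/bias": [10],
    "dense/kernel/Momentum": [784, 10],  # optimizer accumulator, same shape as kernel
    "dense/bias/Momentum": [10],         # optimizer accumulator, same shape as bias
    "global_step": [],                   # scalar step counter
}


def count(shape_map, skip=("Momentum", "global_step")):
    """Sum element counts, skipping names that contain any substring in `skip`."""
    return sum(int(np.prod(shape)) for name, shape in shape_map.items()
               if not any(s in name for s in skip))


model_params = count(shape_map)         # 784*10 + 10 = 7850
all_values = count(shape_map, skip=())  # 2*7850 + 1 = 15701
print(model_params, all_values)
```

With momentum, counting every stored value roughly doubles the result, which is why the gist filters those names out.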
Thanks @dtlam26 for answering. To add to that, the variable named global_step is used in TF 1.x to count the number of optimizer steps taken so far.
There may also be other variables that are not part of the model itself, created by models or frameworks built on top of TF, which you have to exclude manually.
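Because such variables have to be ruled out by hand, it can help to list exactly what a filter drops before trusting the total. The sketch below uses suffix-anchored regular expressions instead of the gist's substring test. The patterns are assumptions based on common TF 1.x names (Momentum and Adam slot variables, Adam's beta1_power/beta2_power accumulators, global_step); check them against the names your own checkpoint prints. `split_variables`, `total_count`, and the example map are hypothetical.

```python
import re

import numpy as np

# Assumed TF 1.x naming conventions; verify against your checkpoint's names.
DEFAULT_EXCLUDE = (
    r"/Momentum$",        # MomentumOptimizer slot variables
    r"/Adam(_\d+)?$",     # AdamOptimizer slots (first and second moments)
    r"^beta[12]_power$",  # AdamOptimizer power accumulators
    r"^global_step$",     # optimizer step counter
)


def split_variables(shape_map, patterns=DEFAULT_EXCLUDE):
    """Partition a {name: shape} map into (kept, excluded) using regex search."""
    regexes = [re.compile(p) for p in patterns]
    kept, excluded = {}, {}
    for name, shape in shape_map.items():
        target = excluded if any(r.search(name) for r in regexes) else kept
        target[name] = shape
    return kept, excluded


def total_count(shape_map):
    """Total number of scalar values across all shapes in the map."""
    return sum(int(np.prod(shape)) for shape in shape_map.values())


# Hypothetical checkpoint contents for one conv layer trained with Adam.
shape_map = {
    "conv1/weights": [3, 3, 3, 16],
    "conv1/weights/Adam": [3, 3, 3, 16],
    "conv1/weights/Adam_1": [3, 3, 3, 16],
    "beta1_power": [],
    "beta2_power": [],
    "global_step": [],
}

kept, excluded = split_variables(shape_map)
for name in sorted(excluded):
    print("excluded: %s" % name)  # review these before trusting the total
print("Total Param Count: %d" % total_count(kept))  # 3*3*3*16 = 432
```

Anchoring on the name suffix is stricter than `'Momentum' not in k`: a model variable that merely contains the word somewhere in its path is kept rather than silently dropped.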
When I execute the script I get the error below; I assume this script no longer works in TF2:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-9-deb2ed05f78c> in <module>()
6
7 # Open TensorFlow ckpt
----> 8 reader = tf.train.NewCheckpointReader(ckpt_fpath)
9
10 print('\nCount the number of parameters in ckpt file(%s)' % ckpt_fpath)
AttributeError: module 'tensorflow._api.v2.train' has no attribute 'NewCheckpointReader'
You should switch to tf.compat.v1.train, my dear 💪
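Following the tf.compat.v1.train suggestion, here is a minimal sketch of the gist's counting loop that runs under TF2. The function names `count_params` and `count_ckpt_params` are hypothetical. The reader is passed in as any object with a `get_variable_to_shape_map()` method, so the counting logic can be reused or tested without TensorFlow; `tf.train.load_checkpoint` should also return a suitable reader if you prefer a non-compat API. Assumption: the checkpoint is a name-based TF 1.x checkpoint. Object-based TF2 checkpoints (tf.train.Checkpoint, Keras save_weights) use a different key scheme, so print the names first and adjust `skip` accordingly.

```python
import numpy as np

# Same default filter as the original gist.
SKIP = ("Momentum", "global_step")


def count_params(reader, skip=SKIP, verbose=True):
    """Sum element counts of checkpoint variables whose names contain none of `skip`.

    `reader` is anything with a get_variable_to_shape_map() method, such as the
    object returned by tf.compat.v1.train.NewCheckpointReader.
    """
    total = 0
    for name, shape in sorted(reader.get_variable_to_shape_map().items()):
        if any(s in name for s in skip):
            continue
        n = int(np.prod(shape))  # number of elements in this variable
        total += n
        if verbose:
            print('%s: %s => %d' % (name, shape, n))
    return total


def count_ckpt_params(ckpt_fpath, skip=SKIP):
    """Open a checkpoint with the TF1-compatible reader and count its parameters."""
    import tensorflow as tf  # imported lazily so count_params works without TF
    reader = tf.compat.v1.train.NewCheckpointReader(ckpt_fpath)
    return count_params(reader, skip)
```

For example, print('Total Param Count: %d' % count_ckpt_params('path/to/model.ckpt')), where the path is a placeholder for your own checkpoint prefix.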