
@joetm
Forked from dalgu90/count_ckpt_param.py
Created July 7, 2022 08:46
Count the number of parameters in a TensorFlow checkpoint file. (Usage: python count_ckpt_param.py path-to-ckpt)
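Each variable's parameter count is the product of its shape's dimensions, and the total is the sum over all counted variables. The minimal sketch below shows that arithmetic on a hypothetical shape map for a single dense layer. The names `dense/kernel` and `dense/bias` and their shapes are illustrative and were not read from a real checkpoint.

```python
import numpy as np

# Hypothetical shape map, in the same {name: shape list} form that
# get_variable_to_shape_map() returns for a real checkpoint.
shape_map = {
    "dense/kernel": [784, 256],
    "dense/bias": [256],
    "global_step": [],
}

total = 0
for name, shape in sorted(shape_map.items()):
    if "global_step" in name:
        continue  # training-step counter, not a model weight
    count = int(np.prod(shape))  # e.g. 784 * 256 = 200704
    total += count
    print(f"{name}: {shape} => {count}")

print(f"Total Param Count: {total}")  # 200704 + 256 = 200960
```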
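```python
#!/usr/bin/env python
"""Count the number of parameters in a TensorFlow checkpoint."""
import sys

import numpy as np
import tensorflow as tf

if len(sys.argv) == 2:
    ckpt_fpath = sys.argv[1]
else:
    print('Usage: python count_ckpt_param.py path-to-ckpt')
    sys.exit(1)

# Open the TensorFlow checkpoint. tf.train.load_checkpoint accepts a checkpoint
# prefix or a directory (in which case the latest checkpoint there is used).
reader = tf.train.load_checkpoint(ckpt_fpath)
print('\nCount the number of parameters in ckpt file (%s)' % ckpt_fpath)

# Map of variable name -> shape (a list of ints).
param_map = reader.get_variable_to_shape_map()
total_count = 0
for k, v in sorted(param_map.items()):
    # Skip Momentum optimizer slots and the global step counter.
    # Slots of other optimizers (e.g. Adam) are still counted.
    if 'Momentum' not in k and 'global_step' not in k:
        # Parameters in one variable = product of its dimensions.
        # np.prod of an empty (scalar) shape is 1.0, so cast to int.
        temp = int(np.prod(v))
        total_count += temp
        print('%s: %s => %d' % (k, str(v), temp))
print('Total Param Count: %d' % total_count)
```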
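The script excludes only names containing `Momentum` or `global_step`, so a checkpoint saved with another optimizer (for example Adam) will have its slot variables counted as if they were model parameters. Below is a minimal sketch of a more general filter that works on any `{name: shape}` dict. The function `count_params` and the `EXCLUDE_PATTERNS` list are hypothetical; the substrings reflect typical TF1 and TF2 key names as an assumption, so check the keys in your own checkpoint before relying on them.

```python
import numpy as np

# Assumed substrings marking non-model entries (optimizer slots and
# bookkeeping). Adjust to match the keys in your own checkpoint.
EXCLUDE_PATTERNS = (
    "Momentum",                      # TF1 MomentumOptimizer slots
    "Adam",                          # TF1 Adam slots (.../Adam, .../Adam_1)
    "beta1_power", "beta2_power",    # TF1 Adam accumulators
    "global_step",                   # training-step counter
    ".OPTIMIZER_SLOT",               # TF2 object-based optimizer slots
    "_CHECKPOINTABLE_OBJECT_GRAPH",  # TF2 object-graph metadata entry
)


def count_params(shape_map, exclude=EXCLUDE_PATTERNS):
    """Return (total, per_variable) for a {name: shape} mapping,
    skipping every name that contains any of the `exclude` substrings."""
    per_variable = {}
    for name, shape in shape_map.items():
        if any(pattern in name for pattern in exclude):
            continue
        per_variable[name] = int(np.prod(shape))
    return sum(per_variable.values()), per_variable
```

With the `reader` from the script, `count_params(reader.get_variable_to_shape_map())` would return the filtered total together with a per-variable breakdown. Substring matching is deliberately simple; a model layer whose name happens to contain one of these patterns would also be skipped.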