
@rreece
Created February 7, 2020 19:05
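Print every trainable variable in a TensorFlow model (the current default graph) with its rank, shape and parameter count, followed by the total number of trainable parameters.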
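```python
import tensorflow as tf  # TensorFlow 1.x graph API; under TF 2.x use `import tensorflow.compat.v1 as tf`


def print_total_parameters():
    """Print name, rank, shape and parameter count of each trainable variable, then the total."""
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is a tf.TensorShape; as_list() yields plain ints, whereas
        # iterating the shape directly yields tf.Dimension objects in TF 1.x
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape.as_list():
            variable_parameters *= dim  # fails if a dimension is unknown (None)
        print('%s dim=%i shape=%s params=%i' % (
            variable.name,
            len(shape),  # rank of the variable
            shape,
            variable_parameters,
        ))
        total_parameters += variable_parameters
    print('total_parameters = %i' % (total_parameters))
    return total_parameters
```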
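The counting rule itself is plain arithmetic: a variable's parameter count is the product of its dimensions, and the model's total is the sum over all variables. The sketch below reproduces that rule in pure Python, without TensorFlow, assuming shapes are given as tuples of ints. The helper `count_parameters` and the example shapes (a single dense layer mapping 784 inputs to 128 units) are hypothetical, chosen only for illustration.

```python
from math import prod  # Python 3.8+


def count_parameters(shapes):
    """Return (per-variable counts, total) for a dict of variable name -> shape tuple.

    Each variable contributes the product of its dimensions, mirroring the
    inner loop of print_total_parameters(); the total is the sum of those.
    """
    per_variable = {name: prod(shape) for name, shape in shapes.items()}
    return per_variable, sum(per_variable.values())


# Hypothetical shapes for one dense layer: a 784x128 kernel plus a 128-element bias.
example_shapes = {
    "dense/kernel:0": (784, 128),
    "dense/bias:0": (128,),
}
per_variable, total = count_parameters(example_shapes)
for name, count in per_variable.items():
    print("%s params=%i" % (name, count))
print("total_parameters = %i" % total)  # 784*128 + 128 = 100480
```

A scalar variable (shape `()`) counts as one parameter, because the empty product is 1; this matches the `variable_parameters = 1` starting value in the TensorFlow version.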