@wassname
Created October 20, 2017 12:25
Summarise TensorFlow models
MODEL: value_function
   variable                              shape           dtype        params
0  value_function/conv2d/Variable:0      [1, 3, 3, 2]    float32_ref  18
1  value_function/conv2d/Variable_1:0    [2]             float32_ref  2
2  value_function/conv2d_1/Variable:0    [1, 49, 2, 20]  float32_ref  1960
3  value_function/conv2d_1/Variable_1:0  [20]            float32_ref  20
4  value_function/conv2d_2/Variable:0    [1, 1, 21, 1]   float32_ref  21
5  value_function/conv2d_2/Variable_1:0  [1]             float32_ref  1
total params: 2022
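Each entry in the `params` column is the product of the variable's shape, so the totals can be checked by hand. Below is a stdlib-only sketch that uses the value_function shapes from the table above. It assumes Python 3.8+ for `math.prod`.

```python
from math import prod

# Shapes copied from the value_function table above
value_function_shapes = {
    "value_function/conv2d/Variable:0": [1, 3, 3, 2],
    "value_function/conv2d/Variable_1:0": [2],
    "value_function/conv2d_1/Variable:0": [1, 49, 2, 20],
    "value_function/conv2d_1/Variable_1:0": [20],
    "value_function/conv2d_2/Variable:0": [1, 1, 21, 1],
    "value_function/conv2d_2/Variable_1:0": [1],
}

# A variable's parameter count is the product of its dimensions,
# e.g. [1, 49, 2, 20] -> 1 * 49 * 2 * 20 = 1960
counts = {name: prod(shape) for name, shape in value_function_shapes.items()}
total = sum(counts.values())
print("total params:", total)
```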
MODEL: baseline/baseline_state
   variable                                       shape           dtype        params
0  baseline/baseline_state/conv2d/Variable:0      [1, 3, 3, 3]    float32_ref  27.0
1  baseline/baseline_state/conv2d/Variable_1:0    [3]             float32_ref  3.0
2  baseline/baseline_state/conv2d_1/Variable:0    [1, 50, 3, 20]  float32_ref  3000.0
3  baseline/baseline_state/conv2d_1/Variable_1:0  [20]            float32_ref  20.0
4  baseline/baseline_state/conv2d_2/Variable:0    [1, 1, 20, 1]   float32_ref  20.0
5  baseline/baseline_state/conv2d_2/Variable_1:0  [1]             float32_ref  1.0
6  baseline/baseline_state/linear/Variable:0      [150, 1]        float32_ref  150.0
7  baseline/baseline_state/linear/Variable_1:0    [1]             float32_ref  1.0
8  baseline/baseline_state/beta1_power:0          []              float32_ref  1.0
9  baseline/baseline_state/beta2_power:0          []              float32_ref  1.0
total params: 3224.0
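Two details in the baseline table are worth spelling out.

- **Float column.** The `params` column is float because `np.prod` of an empty shape (`[]`, a scalar variable) returns the float `1.0`. That single float upcasts the whole column.
- **Optimizer state.** `beta1_power` and `beta2_power` are Adam's scalar accumulators. Their names do not contain `Adam`, so the name filter in the code keeps them.

The sketch below recomputes both totals. It uses `math.prod`, which returns the integer `1` for an empty shape, so everything stays an integer. The suffix-based filter is a hypothetical illustration, not part of the gist.

```python
from math import prod

# Shapes copied from the baseline/baseline_state table above
baseline_shapes = {
    "baseline/baseline_state/conv2d/Variable:0": [1, 3, 3, 3],
    "baseline/baseline_state/conv2d/Variable_1:0": [3],
    "baseline/baseline_state/conv2d_1/Variable:0": [1, 50, 3, 20],
    "baseline/baseline_state/conv2d_1/Variable_1:0": [20],
    "baseline/baseline_state/conv2d_2/Variable:0": [1, 1, 20, 1],
    "baseline/baseline_state/conv2d_2/Variable_1:0": [1],
    "baseline/baseline_state/linear/Variable:0": [150, 1],
    "baseline/baseline_state/linear/Variable_1:0": [1],
    "baseline/baseline_state/beta1_power:0": [],
    "baseline/baseline_state/beta2_power:0": [],
}

# Hypothetical filter: treat Adam slots and Adam's beta accumulators as optimizer state
OPTIMIZER_SUFFIXES = ("beta1_power:0", "beta2_power:0")

def is_optimizer_state(name):
    return "Adam" in name or name.endswith(OPTIMIZER_SUFFIXES)

# prod([]) == 1, so each scalar accumulator contributes one "parameter"
all_total = sum(prod(shape) for shape in baseline_shapes.values())
model_total = sum(
    prod(shape) for name, shape in baseline_shapes.items() if not is_optimizer_state(name)
)
print(all_total, model_total)
```

The difference between the two totals is exactly the two scalar accumulators.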
# summarize models (TensorFlow 1.x only: tf.contrib was removed in TF 2)
import numpy as np
import pandas as pd
import tensorflow as tf
from IPython.display import display


def summarize(scope_name):
    """Display each variable under scope_name with its shape, dtype and parameter count."""
    with tf.variable_scope(scope_name) as scope:
        variables = tf.contrib.framework.get_variables(scope=scope)
    # Remove optimizer slot variables (names containing 'Adam').
    # Adam's beta1_power/beta2_power accumulators do not match and are still listed.
    variables = [var for var in variables if 'Adam' not in var.name]
    # One row per variable. np.prod([]) returns 1.0 for scalars, which makes the column float.
    params = [[var.name, var.shape.as_list(), var.dtype.name, np.prod(var.shape.as_list())]
              for var in variables]
    params_df = pd.DataFrame(params, columns=['variable', 'shape', 'dtype', 'params'])
    print('MODEL:', scope_name)
    display(params_df)
    print('total params:', params_df.params.sum())


summarize('value_function')
summarize('baseline/baseline_state')
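Below is a stdlib-only sketch of the same summary logic, decoupled from TensorFlow, pandas and IPython so it can run and be tested on plain `(name, shape, dtype)` triples.

- **Hypothetical names.** `summarize_shapes` and its `exclude` parameter are hypothetical.
- **The TF step.** Only building the triples needs TensorFlow. It uses the same attributes the gist reads: `var.name`, `var.shape.as_list()` and `var.dtype.name`.
- **Made-up data.** The Adam slot name in the usage call is invented for illustration.

```python
from math import prod


def summarize_shapes(scope_name, variables, exclude=("Adam",)):
    """Print a per-variable parameter table and return the total.

    variables: iterable of (name, shape, dtype) triples, e.g. built with
    [(v.name, v.shape.as_list(), v.dtype.name) for v in tf_variables].
    exclude: substrings marking variables to skip (hypothetical parameter).
    """
    rows = [
        (name, shape, dtype, prod(shape))
        for name, shape, dtype in variables
        if not any(token in name for token in exclude)
    ]
    print("MODEL:", scope_name)
    print(f"{'':>3} {'variable':<40} {'shape':<16} {'dtype':<12} params")
    for i, (name, shape, dtype, n) in enumerate(rows):
        print(f"{i:>3} {name:<40} {str(shape):<16} {dtype:<12} {n}")
    total = sum(row[3] for row in rows)
    print("total params:", total)
    return total


# Usage with illustrative triples; the Adam slot variable is skipped
total = summarize_shapes("value_function", [
    ("value_function/conv2d/Variable:0", [1, 3, 3, 2], "float32_ref"),
    ("value_function/conv2d/Variable_1:0", [2], "float32_ref"),
    ("value_function/conv2d/Variable/Adam:0", [1, 3, 3, 2], "float32_ref"),
])
```

Passing `exclude=("Adam", "beta1_power", "beta2_power")` would also drop Adam's scalar accumulators.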