Created
October 20, 2017 12:25
-
-
Save wassname/b8b75b96d7899ee4fc3ea6559acc0af1 to your computer and use it in GitHub Desktop.
Summarise TensorFlow models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
MODEL: value_function | |
variable shape dtype params | |
0 value_function/conv2d/Variable:0 [1, 3, 3, 2] float32_ref 18 | |
1 value_function/conv2d/Variable_1:0 [2] float32_ref 2 | |
2 value_function/conv2d_1/Variable:0 [1, 49, 2, 20] float32_ref 1960 | |
3 value_function/conv2d_1/Variable_1:0 [20] float32_ref 20 | |
4 value_function/conv2d_2/Variable:0 [1, 1, 21, 1] float32_ref 21 | |
5 value_function/conv2d_2/Variable_1:0 [1] float32_ref 1 | |
total params: 2022 | |
MODEL: baseline/baseline_state | |
variable shape dtype params | |
0 baseline/baseline_state/conv2d/Variable:0 [1, 3, 3, 3] float32_ref 27.0 | |
1 baseline/baseline_state/conv2d/Variable_1:0 [3] float32_ref 3.0 | |
2 baseline/baseline_state/conv2d_1/Variable:0 [1, 50, 3, 20] float32_ref 3000.0 | |
3 baseline/baseline_state/conv2d_1/Variable_1:0 [20] float32_ref 20.0 | |
4 baseline/baseline_state/conv2d_2/Variable:0 [1, 1, 20, 1] float32_ref 20.0 | |
5 baseline/baseline_state/conv2d_2/Variable_1:0 [1] float32_ref 1.0 | |
6 baseline/baseline_state/linear/Variable:0 [150, 1] float32_ref 150.0 | |
7 baseline/baseline_state/linear/Variable_1:0 [1] float32_ref 1.0 | |
8 baseline/baseline_state/beta1_power:0 [] float32_ref 1.0 | |
9 baseline/baseline_state/beta2_power:0 [] float32_ref 1.0 | |
total params: 3224.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# summarize models | |
from IPython.display import display | |
def summarize(scope_name, exclude=('Adam', 'beta1_power', 'beta2_power')):
    """Display a per-variable parameter table for one TensorFlow variable scope.

    Collects the variables that ``tf.contrib.framework.get_variables`` returns
    for ``scope_name``, drops optimizer state, then prints a header line,
    displays a DataFrame with one row per variable (name, shape, dtype,
    element count) and prints the summed element count.

    Args:
        scope_name: Name of the variable scope to summarise,
            e.g. ``'value_function'``.
        exclude: Substrings; any variable whose name contains one of them is
            skipped. The default keeps the original ``'Adam'`` filter and also
            drops ``beta1_power`` / ``beta2_power``. Those are Adam optimizer
            accumulators that otherwise show up in the scope and are counted
            as model parameters.

    Returns:
        None. Output goes to stdout and IPython ``display``.
    """
    with tf.variable_scope(scope_name) as scope:
        variables = tf.contrib.framework.get_variables(scope=scope)

    # Remove optimizer vars. Matching on 'Adam' alone let beta1_power /
    # beta2_power through and inflated the total.
    variables = [var for var in variables
                 if not any(token in var.name for token in exclude)]

    rows = []
    for var in variables:
        shape = var.shape.as_list()
        # np.prod of an empty shape (a scalar variable) is the float 1.0, which
        # turned the whole column into floats; cast so counts stay integers.
        # NOTE(review): assumes every variable has a fully-defined shape
        # (no None dims) - np.prod would fail otherwise.
        rows.append([var.name, shape, var.dtype.name, int(np.prod(shape))])
    params_df = pd.DataFrame(rows, columns=['variable', 'shape', 'dtype', 'params'])

    print('MODEL: ', scope_name, end='')
    display(params_df)
    print('total params:', params_df.params.sum())
# Show a parameter summary for each model scope, in order.
for _scope_name in ('value_function', 'baseline/baseline_state'):
    summarize(_scope_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment