Skip to content

Instantly share code, notes, and snippets.

@N-McA
Created April 23, 2018 18:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save N-McA/c6d5f3e7cbd7b9c4d44126430b823a10 to your computer and use it in GitHub Desktop.
Save N-McA/c6d5f3e7cbd7b9c4d44126430b823a10 to your computer and use it in GitHub Desktop.
# Opening of the generated TikZ picture. Written as a regular (non-raw)
# string, so the LaTeX backslash is doubled; the resulting text is the same
# as the raw-string spelling. The `%%` line is a LaTeX comment.
header = '\n\\begin{tikzpicture}[node distance = 2mm, auto]\n%% Auto Generated\n'
# Hand-written example node, kept as a reference for what one generated node
# looks like. Spelled as adjacent string literals (backslashes doubled); the
# resulting text is identical to the raw-string form.
raw_b = (
    '\n'
    '\\node [block, below= of glove] (conv1) {\n'
    '\\begin{tabular}{cc}\n'
    'Conv1D & Input: $n$x100 \\\\\n'
    '64x5 Dilation 1 & Output: $n$x64 \\\\\n'
    '\\end{tabular}\n'
    '};\n'
)
# `raw_b` with every brace and every backslash doubled, in one pass. Doubled
# braces survive str.format as single braces; doubled backslashes match how
# the text would be written inside a non-raw string literal (presumably how
# the node templates in this file were derived — not used at runtime).
b_escaped = raw_b.translate(str.maketrans({'{': '{{', '}': '}}', '\\': '\\\\'}))
# str.format templates that render one TikZ node per layer. Literal braces are
# doubled ({{ }}) so format() emits them as single braces. Placeholders:
# {block_id}, {layer_name}, {layer_params}, {output_size}, plus
# {prev_block_id} and {input_size} in the chained template. Raw strings keep
# each LaTeX backslash exactly as it will be printed.

# First node: not positioned relative to any other node; shows no input size.
first_block = r'''
\node [block, text width=8cm] ({block_id}) {{
\begin{{tabular}}{{cc}}
{layer_name} & \\
{layer_params} & Output: ${output_size}$ \\
\end{{tabular}}
}};
'''
# Every later node: placed `below= of` the node named by {prev_block_id}.
template_block = r'''
\node [block, text width=8cm, below= of {prev_block_id}] ({block_id}) {{
\begin{{tabular}}{{cc}}
{layer_name} & Input: ${input_size}$ \\
{layer_params} & Output: ${output_size}$ \\
\end{{tabular}}
}};
'''
# Set to True to print one sample node filled with dummy values, for checking
# the template by eye. Off by default: the sample goes to the same stdout as
# the real diagram and is printed before `header`, so it would land outside
# the tikzpicture and reference a node `Foo` that does not exist, breaking
# the generated LaTeX.
PREVIEW_TEMPLATE = False
if PREVIEW_TEMPLATE:
    print(template_block.format(
        layer_name='Conv',
        layer_params='Foo',
        prev_block_id='Foo',
        block_id='bar',
        input_size='bing',
        output_size='bish'
    ))
# Closing of the generated TikZ picture (regular string, doubled backslash).
footer = '\n\\end{tikzpicture}\n'

# Attribute names that get_p_string reports whenever a layer has them.
# NOTE(review): these look like Keras layer constructor arguments — confirm
# against the model in use.
interesting_params = (
    'alpha units filters kernel_size strides dilation_rate activation'.split()
)
def name_or_str(x):
    """Return a short, human-readable label for a layer attribute value.

    Callables whose ``str()`` is an angle-bracket repr -- ``<function f at
    0x...>``, ``<built-in function len>``, ``<class 'int'>``, bound methods --
    are reduced to their ``__name__``, so no memory address ends up in the
    diagram. Every other value is rendered with ``str()`` and stripped of
    surrounding whitespace.
    """
    s = str(x).strip()
    # Checking callable/__name__ (not just the '<function' prefix) covers
    # built-ins and classes, and no longer raises AttributeError for a plain
    # string that merely starts with '<function'.
    if s.startswith('<') and callable(x) and hasattr(x, '__name__'):
        return x.__name__
    return s
# LaTeX reserved characters mapped to text-mode escapes. Parameter names such
# as `kernel_size` contain underscores, which are an error ("Missing $
# inserted") in the text-mode tabular cell where this string is placed.
_LATEX_ESCAPES = str.maketrans({
    '\\': r'\textbackslash{}',
    '{': r'\{',
    '}': r'\}',
    '$': r'\$',
    '&': r'\&',
    '%': r'\%',
    '#': r'\#',
    '_': r'\_',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
    '<': r'\textless{}',
    '>': r'\textgreater{}',
})


def latex_escape(text):
    """Return *text* with LaTeX reserved characters escaped for text mode.

    A single ``str.translate`` pass is used, so backslashes introduced by one
    escape (e.g. ``\\_``) are never escaped a second time.
    """
    return text.translate(_LATEX_ESCAPES)


def get_p_string(layer):
    """Summarise the attributes named in ``interesting_params`` for a layer.

    Each listed attribute present in ``layer.__dict__`` becomes
    ``name=value`` (value rendered by ``name_or_str``). The pieces are sorted
    alphabetically, joined with ``", "`` and LaTeX-escaped so the result can
    be placed directly in the node's tabular. Returns ``""`` when none of the
    listed attributes is present.
    """
    attrs = layer.__dict__
    ps = sorted(
        '{}={}'.format(p, name_or_str(attrs[p]))
        for p in interesting_params
        if p in attrs
    )
    # Escape after sorting so the ordering matches the unescaped text.
    return latex_escape(', '.join(ps))
def shape_string(shape):
    r"""Render a layer shape as the math-mode label used in the diagram.

    The first (batch) dimension is always shown as ``b``. For rank-3 shapes the
    result is ``b, \ell, <last dim>``; for rank-2 shapes it is ``b, \ell``
    (the second dimension is shown as ``\ell`` whatever its value --
    presumably a sequence length; confirm for your model). Other ranks show
    each later dimension as its size, or ``?`` when unknown (``None``). A
    list/tuple of shapes (multi-input layers) is rendered shape by shape,
    separated by ``; ``. An empty shape gives ``""``.
    """
    # Multi-input layers report a list of shapes rather than a single shape.
    if shape and all(isinstance(s, (tuple, list)) for s in shape):
        return '; '.join(shape_string(s) for s in shape)
    rank = len(shape)
    # Raw strings: '\e' in a normal literal is an invalid escape sequence.
    if rank == 3:
        return r'b, \ell, {}'.format(shape[-1])
    if rank == 2:
        return r'b, \ell'
    if rank == 0:
        return ''
    # Previously fell through and returned None, printed as "$None$".
    rest = ['?' if d is None else str(d) for d in shape[1:]]
    return ', '.join(['b'] + rest)
# Emit the whole diagram to stdout: header, one node per entry of
# model.layers (in that order, each placed below the previous one), footer.
# NOTE(review): `model` is not defined in this file -- presumably a Keras
# model built earlier (e.g. in a notebook cell); confirm. The vertical
# chaining assumes a purely sequential model.
print(header, end='')
for layer_n, layer in enumerate(model.layers):
    # Fields shared by both node templates.
    fields = dict(
        layer_name=layer.__class__.__name__,
        layer_params=get_p_string(layer),
        block_id='layer_{}'.format(layer_n),
        output_size=shape_string(layer.output_shape),
    )
    if layer_n == 0:
        # The first node anchors the chain and shows no input size.
        b = first_block.format(**fields)
    else:
        b = template_block.format(
            prev_block_id='layer_{}'.format(layer_n - 1),
            input_size=shape_string(layer.input_shape),
            **fields
        )
    print(b, end='')
# With an empty model.layers the loop body never runs, giving an empty but
# well-formed tikzpicture (previously an IndexError after the header).
print(footer, end='')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment