Skip to content

Instantly share code, notes, and snippets.

@oraoto
Last active September 28, 2018 07:07
Show Gist options
  • Save oraoto/798b3989d4984eed0daf595d4f9e5360 to your computer and use it in GitHub Desktop.
Save oraoto/798b3989d4984eed0daf595d4f9e5360 to your computer and use it in GitHub Desktop.
nnabla graph visualization
import nnabla as nn
import nnabla.functions as F
import numpy as np
import matplotlib.pyplot as plt
import graphviz as gv
def draw_graph(v, hide_params=True, op_as_edge=False):
    """Build a graphviz ``Digraph`` of the nnabla graph that produces ``v``.

    Args:
        v: An ``nn.Variable`` whose producing graph is drawn, or a list/tuple
            of variables. A sequence is first joined with ``F.sink`` so that
            one traversal covers every root; the sink itself is not drawn.
        hide_params: If True, variables found in
            ``nn.get_parameters(grad_only=False)`` are left out of the drawing.
        op_as_edge: If True, each function is drawn as labelled edges from
            every (drawn) input variable to every output variable, instead of
            as a node of its own. Parameters are always hidden in this mode.

    Returns:
        A ``graphviz.Digraph``; call ``.view()`` or ``.render()`` on it.
    """
    graph = gv.Digraph()
    # Materialized once: parameter membership is tested for every variable
    # visited.
    param_vars = list(nn.get_parameters(grad_only=False).values())
    # Per-function-name counter used to give each function node a unique name.
    layer_count = {}
    # (variable, node name) pairs for every variable already declared.
    # Matching uses ``==`` (as the original code did) rather than ``is``.
    # NOTE(review): presumably f.inputs/f.outputs hand back fresh Python
    # wrappers on each access, so identity would not match -- TODO confirm.
    known = []

    if isinstance(v, (list, tuple)):
        v = F.sink(*v)
    if op_as_edge:
        # Edges are drawn input -> output, so parameter inputs are dropped.
        hide_params = True

    def add_variable(n, prefix):
        """Return the node name for variable ``n``, declaring it on first use.

        ``prefix`` (the function name plus Input/Output) is only used to build
        the name the first time ``n`` is seen. Returns None when ``n`` is a
        parameter and parameters are hidden.
        """
        is_param = n in param_vars
        if hide_params and is_param:
            return None
        for var, name in known:
            if var == n:
                return name

        n_name = prefix + str(id(n))
        known.append((n, n_name))
        attrs = {
            'label': str(n.shape),
            'style': 'filled',
            'shape': 'box',
            'align': 'center',
            'fillcolor': '#f8baff' if n.need_grad else '#cbffba',
        }
        if is_param:
            # Same pink as trainable variables, with an alpha channel so that
            # parameters stand out as translucent.
            attrs['fillcolor'] = '#f8baff75'
        graph.node(n_name, **attrs)
        return n_name

    def visit(f):
        """Draw one function ``f`` and connect it to its variables."""
        layer_count[f.name] = layer_count.get(f.name, 0) + 1
        # The sink added for sequence input is a traversal artifact.
        if f.info.type_name == 'Sink':
            return
        f_name = f.name + '_' + str(layer_count[f.name])
        if not op_as_edge:
            graph.node(f_name)

        inputs = []
        for inp in f.inputs:
            n_id = add_variable(inp, f_name + '_Input')
            if n_id is None:
                continue
            inputs.append(n_id)
            if not op_as_edge:
                graph.edge(n_id, f_name)

        outputs = []
        for oup in f.outputs:
            n_id = add_variable(oup, f_name + '_Output')
            if n_id is None:
                continue
            outputs.append(n_id)
            if not op_as_edge:
                graph.edge(f_name, n_id)

        if op_as_edge:
            for i in inputs:
                for o in outputs:
                    graph.edge(i, o, label=f_name)

    v.visit(visit)
    return graph
#%%
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
from graph import draw_graph
def rnn(xs, h0, hidden=32, n_classes=10):
    """Unroll a simple tanh RNN over ``xs`` and classify the last state.

    Each step computes ``h = tanh(affine_x2h(x) + affine_h2h(h))``. The
    ``x2h`` affine has no bias; the ``h2h`` affine has one. Weights are shared
    across steps because every step re-enters the same ``rnn/x2h`` and
    ``rnn/h2h`` parameter scopes.

    Args:
        xs: Sequence of input variables, one per time step.
        h0: Initial hidden state; presumably shaped (batch, hidden) -- TODO
            confirm against callers.
        hidden: Width of the hidden state.
        n_classes: Output width of the final affine classifier. The default of
            10 is the previously hard-coded value.

    Returns:
        ``(y, hs)``: the classifier output computed from the final hidden
        state (``h0`` itself if ``xs`` is empty), and the list of per-step
        hidden states.
    """
    hs = []
    with nn.parameter_scope("rnn"):
        h = h0
        for x in xs:
            with nn.parameter_scope("x2h"):
                x2h = PF.affine(x, hidden, with_bias=False)
            with nn.parameter_scope("h2h"):
                h2h = PF.affine(h, hidden)
            h = F.tanh(x2h + h2h)
            hs.append(h)
    # NOTE(review): the scraped source lost its indentation, so whether this
    # scope originally sat inside the "rnn" scope is a guess. It only affects
    # the classifier's parameter names.
    with nn.parameter_scope("classifier"):
        y = PF.affine(h, n_classes)
    return y, hs
# Three time steps, each a [28, 28] input variable.
seq_x = [nn.Variable([28, 28]) for _ in range(3)]
# Initial hidden state: leading dim matches the inputs, trailing dim the width.
h0 = nn.Variable((28, 32))
y, hs = rnn(seq_x, h0, 32)

# First view: draw the graph with parameter variables included.
g = draw_graph(y, hide_params=False)
g.view()
#%%
# Second view: same graph, parameters hidden (draw_graph's default).
g = draw_graph(y, hide_params=True)
g.view()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment