Skip to content

Instantly share code, notes, and snippets.

@pervognsen
Last active September 19, 2019 06:00
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save pervognsen/4aebd1cc0f81b8b56a601bf890616b42 to your computer and use it in GitHub Desktop.
Save pervognsen/4aebd1cc0f81b8b56a601bf890616b42 to your computer and use it in GitHub Desktop.
class DotGenerator(Visitor):
    """Render a node graph as Graphviz DOT statements.

    Calling the generator on a node (``self(node)``, dispatched by the
    ``Visitor`` base) appends that node's vertex, and recursively the
    vertices and edges of its operands, to ``self.lines``. It returns the
    DOT name that an edge should use to refer to the node. Edges leave the
    east side of the producer and enter the west side of the consumer.

    NOTE(review): assumes ``Visitor.__init__`` creates the ``self.values``
    dict and that ``Visitor.__call__`` returns the cached entry for a node it
    has already seen, so shared nodes are drawn only once. Confirm against
    the ``Visitor`` definition.
    """

    def __init__(self):
        super().__init__()
        self.lines = []  # DOT statements, in the order they were emitted
        self.next_id = 0  # suffix for the next auto-generated vertex name

    def make_name(self, node, name=None):
        """Record and return the DOT name for *node*.

        A fresh ``n<k>`` name is generated unless *name* is supplied. The
        name is stored in ``self.values`` immediately, i.e. before the
        calling visit method recurses into the node's operands.
        """
        if name is None:
            name = "n%d" % self.next_id
            self.next_id += 1
        self.values[node] = name
        return name

    @staticmethod
    def _escape_record(text):
        """Backslash-escape every character that is structural in a record label.

        In a Graphviz record label, braces, ``|``, ``<``, ``>`` and spaces
        delimit fields and ports, so each occurrence must be escaped to
        appear literally. A backslash is escaped too, so it stays literal.
        All other characters pass through unchanged; this avoids accidental
        escape sequences such as ``\\n`` for an operator spelled ``not``.
        """
        return ''.join('\\' + ch if ch in '{}|<> \\' else ch for ch in str(text))

    def vertex(self, name, shape, label):
        """Emit a vertex statement named *name* with the given shape and label."""
        # Inside a DOT quoted string the only escape is \" , so a bare double
        # quote is the one character that would end the label early.
        label = str(label).replace('"', '\\"')
        self.lines.append('%s [ shape = %s, label = "%s" ];' % (name, shape, label))

    def edge(self, from_name, to_name):
        """Emit an edge from the east side of *from_name* to the west side of *to_name*.

        Either name may already carry a record port (``n3:i0``); the compass
        point is appended after it.
        """
        self.lines.append('%s:e -> %s:w;' % (from_name, to_name))

    def ConstantNode(self, node):
        """Draw a constant as a circle labelled with its value."""
        name = self.make_name(node)
        self.vertex(name, 'circle', node.value)
        return name

    def BinaryNode(self, node):
        """Draw a binary operation as a record with input ports i0/i1 and the operator."""
        name = self.make_name(node)
        # Escape the whole operator, not just its first character: multi-char
        # operators such as '<<' or '||' contain record-label syntax.
        op_label = self._escape_record(node.op)
        self.vertex(name, 'record', '{{<i0>|<i1>}|%s}' % op_label)
        self.edge(self(node.left), name + ':i0')
        self.edge(self(node.right), name + ':i1')
        return name

    def IndexNode(self, node):
        """Draw a single-element index as a box labelled ``[index]``."""
        name = self.make_name(node)
        self.vertex(name, 'box', '[%d]' % node.index)
        self.edge(self(node.operand), name)
        return name

    def SliceNode(self, node):
        """Draw a slice as a box labelled ``[start:stop]``."""
        name = self.make_name(node)
        self.vertex(name, 'box', '[%d:%d]' % (node.start, node.stop))
        self.edge(self(node.operand), name)
        return name

    def ConcatNode(self, node):
        """Draw a concatenation as a record with one input port per operand."""
        name = self.make_name(node)
        ports = '|'.join('<i%d>' % i for i in range(len(node.operands)))
        self.vertex(name, 'record', '{{%s}|}' % ports)
        for i, operand in enumerate(node.operands):
            self.edge(self(operand), '%s:i%d' % (name, i))
        return name

    def InputNode(self, node):
        """Draw an input as an arrow labelled with its name."""
        name = self.make_name(node)
        self.vertex(name, 'rarrow', node.name)
        return name

    def OutputNode(self, node):
        """Draw an output as an arrow, with an edge from its driver if it has one."""
        name = self.make_name(node)
        self.vertex(name, 'rarrow', node.name)
        # Compare against None explicitly: operand nodes may overload
        # __len__/__bool__, which would make a truthiness test unreliable.
        if node.operand is not None:
            self.edge(self(node.operand), name)
        return name

    def ModuleOutputNode(self, node):
        """Resolve a submodule output to the matching port on the module's vertex.

        No vertex of its own is emitted; the returned name is
        ``<module vertex>:<output name>``.
        """
        return self.make_name(node, '%s:%s' % (self(node.module), node.name))

    def Module(self, module):
        """Draw a submodule as a record: input ports, type name, output ports.

        Each entry of ``module._connections`` maps an input node to the node
        driving it, and becomes an edge into the matching input port.
        """
        name = self.make_name(module)
        inputs = '|'.join('<%s> %s' % (input_name, input_name) for input_name in module._inputs)
        outputs = '|'.join('<%s> %s' % (output_name, output_name) for output_name in module._outputs)
        self.vertex(name, 'record', '{{%s}|%s|{%s}}' % (inputs, type(module).__name__, outputs))
        for input_node, node in module._connections.items():
            self.edge(self(node), '%s:%s' % (name, input_node.name))
        return name

    def default(self, x):
        """Route Module instances to :meth:`Module`; defer anything else to the base."""
        # `Module` here is the module-level class, not the method above
        # (class-body names are not in scope inside methods).
        if isinstance(x, Module):
            return self.Module(x)
        else:
            return super().default(x)
def generate_dot_file(module):
    """Return a Graphviz DOT ``digraph`` describing *module*'s dataflow.

    Each node in ``module._outputs.values()`` is visited with a fresh
    ``DotGenerator``. Everything reachable from those outputs is drawn, with
    ``rankdir = "LR"``.

    The graph is named after ``module.__name__`` when the object has that
    attribute. Otherwise it falls back to the object's type name, so module
    instances work as well as objects that carry ``__name__``.
    """
    generator = DotGenerator()
    for output in module._outputs.values():
        generator(output)
    graph_name = getattr(module, '__name__', type(module).__name__)
    # The name sits inside a DOT quoted string, where only \" is an escape.
    graph_name = str(graph_name).replace('"', '\\"')
    body = '\n'.join(generator.lines)
    return 'digraph "%s" {\nrankdir = "LR";\n%s\n}\n' % (graph_name, body)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment