Last active
September 19, 2019 06:00
-
-
Save pervognsen/4aebd1cc0f81b8b56a601bf890616b42 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Characters that must be backslash-escaped inside a double-quoted DOT string.
_DOT_STRING_ESCAPES = str.maketrans({ch: '\\' + ch for ch in '\\"'})
# Record-shape labels additionally treat braces, vertical bars and angle
# brackets as structure, so literal occurrences must be escaped too.
_RECORD_FIELD_ESCAPES = str.maketrans({ch: '\\' + ch for ch in '\\"{}|<>'})


def _escape_dot_string(text):
    """Return ``str(text)`` with backslashes and double quotes escaped for a quoted DOT label."""
    return str(text).translate(_DOT_STRING_ESCAPES)


def _escape_record_field(text):
    """Return ``str(text)`` escaped so it appears literally inside a ``record`` label field."""
    return str(text).translate(_RECORD_FIELD_ESCAPES)


class DotGenerator(Visitor):
    """Visitor that renders a node/module graph as Graphviz DOT statements.

    Each visit method appends vertex and edge statements to ``self.lines``
    and returns the DOT name that other vertices use to refer to the visited
    node (a plain vertex name, or ``vertex:port`` for module outputs).
    """

    def __init__(self):
        super().__init__()
        self.lines = []  # DOT statements, in emission order
        self.next_id = 0  # counter for auto-generated vertex names

    def make_name(self, node, name=None):
        """Record and return the DOT name for *node*.

        When *name* is omitted, a fresh name ``n0``, ``n1``, ... is allocated.
        The name is stored in ``self.values[node]`` before any operands are
        visited. NOTE(review): presumably the base ``Visitor`` consults
        ``self.values`` so a shared operand is emitted only once — confirm.
        """
        if name is None:
            name = "n%d" % self.next_id
            self.next_id += 1
        self.values[node] = name
        return name

    def vertex(self, name, shape, label):
        """Append a vertex statement; *label* must already be DOT-escaped."""
        self.lines.append('%s [ shape = %s, label = "%s" ];' % (name, shape, label))

    def edge(self, from_name, to_name):
        """Append an edge from the east side of *from_name* to the west side of *to_name*."""
        self.lines.append('%s:e -> %s:w;' % (from_name, to_name))

    def ConstantNode(self, node):
        """Render a constant as a circle labelled with its value."""
        name = self.make_name(node)
        self.vertex(name, 'circle', _escape_dot_string(node.value))
        return name

    def BinaryNode(self, node):
        """Render a binary op as a record with input ports ``i0``/``i1`` and the operator."""
        name = self.make_name(node)
        # Escape every special character of the operator, not only the first:
        # multi-character ops such as '<<' would otherwise break the record.
        self.vertex(name, 'record', '{{<i0>|<i1>}|%s}' % _escape_record_field(node.op))
        self.edge(self(node.left), name + ':i0')
        self.edge(self(node.right), name + ':i1')
        return name

    def IndexNode(self, node):
        """Render an index operation as a box labelled ``[index]``."""
        name = self.make_name(node)
        self.vertex(name, 'box', '[%d]' % node.index)
        self.edge(self(node.operand), name)
        return name

    def SliceNode(self, node):
        """Render a slice operation as a box labelled ``[start:stop]``."""
        name = self.make_name(node)
        self.vertex(name, 'box', '[%d:%d]' % (node.start, node.stop))
        self.edge(self(node.operand), name)
        return name

    def ConcatNode(self, node):
        """Render a concatenation as a record with one input port ``i<k>`` per operand."""
        name = self.make_name(node)
        ports = '|'.join('<i%d>' % i for i in range(len(node.operands)))
        self.vertex(name, 'record', '{{%s}|}' % ports)
        for i, operand in enumerate(node.operands):
            self.edge(self(operand), '%s:i%d' % (name, i))
        return name

    def InputNode(self, node):
        """Render an input as a right-arrow labelled with its name."""
        name = self.make_name(node)
        self.vertex(name, 'rarrow', _escape_dot_string(node.name))
        return name

    def OutputNode(self, node):
        """Render an output as a right-arrow, fed by its operand when one is set."""
        name = self.make_name(node)
        self.vertex(name, 'rarrow', _escape_dot_string(node.name))
        # Explicit presence check: don't rely on the operand node's truthiness.
        if node.operand is not None:
            self.edge(self(node.operand), name)
        return name

    def ModuleOutputNode(self, node):
        """Return a ``module_vertex:output_port`` reference; no vertex of its own is emitted."""
        return self.make_name(node, '%s:%s' % (self(node.module), node.name))

    def Module(self, module):
        """Render a module as a record: input ports | type name | output ports.

        Edges are drawn from each connected source into the matching input
        port, taken from ``module._connections`` (input node -> source).
        """
        name = self.make_name(module)
        inputs = '|'.join(
            '<%s> %s' % (input_name, _escape_record_field(input_name))
            for input_name in module._inputs
        )
        outputs = '|'.join(
            '<%s> %s' % (output_name, _escape_record_field(output_name))
            for output_name in module._outputs
        )
        type_name = _escape_record_field(type(module).__name__)
        self.vertex(name, 'record', '{{%s}|%s|{%s}}' % (inputs, type_name, outputs))
        for input_node, node in module._connections.items():
            self.edge(self(node), '%s:%s' % (name, input_node.name))
        return name

    def default(self, x):
        """Route ``Module`` instances to :meth:`Module`; defer everything else to the base class.

        NOTE(review): presumably the base Visitor dispatches on the exact type
        name, so user-defined Module subclasses land here — confirm.
        """
        if isinstance(x, Module):
            return self.Module(x)
        else:
            return super().default(x)
def generate_dot_file(module):
    """Render *module*'s output graph as the text of a Graphviz DOT file.

    Every value of ``module._outputs`` is visited with a fresh
    ``DotGenerator``. The collected statements are wrapped in a
    left-to-right (``rankdir = "LR"``) ``digraph`` named after the module.

    Args:
        module: A module class or instance exposing an ``_outputs`` mapping.

    Returns:
        The DOT source as a string, terminated by a newline.
    """
    generator = DotGenerator()
    for node in module._outputs.values():
        generator(node)
    # Classes have ``__name__`` but ordinary instances do not, so fall back to
    # the type's name (the same name DotGenerator.Module uses for instances).
    graph_name = str(getattr(module, '__name__', type(module).__name__))
    # A bare '"' would terminate the quoted graph ID early.
    graph_name = graph_name.replace('"', '\\"')
    body = '\n'.join(generator.lines)
    return 'digraph "%s" {\nrankdir = "LR";\n%s\n}\n' % (graph_name, body)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment