Skip to content

Instantly share code, notes, and snippets.

@apaszke
Forked from szagoruyko/pytorch-graphviz.py
Last active October 19, 2020 05:04

Revisions

  1. apaszke revised this gist Feb 9, 2017. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion pytorch-graphviz.py
    Original file line number Diff line number Diff line change
    @@ -9,7 +9,7 @@ def add_nodes(var):
    if isinstance(var, Variable):
    dot.node(str(id(var)), str(var.size()), fillcolor='lightblue')
    else:
    dot.node(str(id(var)), str(type(var)))
    dot.node(str(id(var)), type(var).__name__)
    seen.add(var)
    if hasattr(var, 'previous_functions'):
    for u in var.previous_functions:
  2. @szagoruyko szagoruyko revised this gist Jan 11, 2017. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion pytorch-graphviz.py
    Original file line number Diff line number Diff line change
    @@ -13,7 +13,7 @@ def add_nodes(var):
    seen.add(var)
    if hasattr(var, 'previous_functions'):
    for u in var.previous_functions:
    dot.edge(str(id(var)), str(id(u[0])))
    dot.edge(str(id(u[0])), str(id(var)))
    add_nodes(u[0])

    add_nodes(R.creator)
  3. @szagoruyko szagoruyko revised this gist Jan 11, 2017. 1 changed file with 3 additions and 4 deletions.
    7 changes: 3 additions & 4 deletions pytorch-graphviz.py
    Original file line number Diff line number Diff line change
    @@ -1,16 +1,15 @@
    from graphviz import Digraph

    dot = Digraph(comment='LRP')
    dot = Digraph(comment='LRP', node_attr={'style': 'filled', 'shape': 'box'})#, 'fillcolor': 'lightblue'})

    seen = set()

    def add_nodes(var):
    if var not in seen:
    if isinstance(var, Variable):
    name = str(var.size())
    dot.node(str(id(var)), str(var.size()), fillcolor='lightblue')
    else:
    name = str(type(var))
    dot.node(str(id(var)), name)
    dot.node(str(id(var)), str(type(var)))
    seen.add(var)
    if hasattr(var, 'previous_functions'):
    for u in var.previous_functions:
  4. @szagoruyko szagoruyko revised this gist Jan 11, 2017. 1 changed file with 10 additions and 5 deletions.
    15 changes: 10 additions & 5 deletions pytorch-graphviz.py
    Original file line number Diff line number Diff line change
    @@ -5,12 +5,17 @@
    seen = set()

    def add_nodes(var):
    if hasattr(var, 'previous_functions') and var not in seen:
    dot.node(str(id(var)), str(type(var)))
    if var not in seen:
    if isinstance(var, Variable):
    name = str(var.size())
    else:
    name = str(type(var))
    dot.node(str(id(var)), name)
    seen.add(var)
    for u in var.previous_functions:
    dot.edge(str(id(var)), str(id(u[0])))
    add_nodes(u[0])
    if hasattr(var, 'previous_functions'):
    for u in var.previous_functions:
    dot.edge(str(id(var)), str(id(u[0])))
    add_nodes(u[0])

    add_nodes(R.creator)

  5. @szagoruyko szagoruyko created this gist Jan 11, 2017.
    17 changes: 17 additions & 0 deletions pytorch-graphviz.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,17 @@
    from graphviz import Digraph

    dot = Digraph(comment='LRP')

    seen = set()

    def add_nodes(var):
    if hasattr(var, 'previous_functions') and var not in seen:
    dot.node(str(id(var)), str(type(var)))
    seen.add(var)
    for u in var.previous_functions:
    dot.edge(str(id(var)), str(id(u[0])))
    add_nodes(u[0])

    add_nodes(R.creator)

    dot.save('/tmp/lrp.dot')