Skip to content

Instantly share code, notes, and snippets.

@vmoens
Created December 17, 2024 11:28
Show Gist options
  • Save vmoens/ab18938154997b94f6940b1fadde8999 to your computer and use it in GitHub Desktop.
from pathlib import Path

import pydot
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
# Toy three-stage pipeline; the graph-building code in this file draws the
# key dependencies between its stages.
seq = Seq(
    # Stage 0: reads "input", writes "intermediate".
    Mod(
        lambda a: a + 1,
        in_keys=["input"],
        out_keys=["intermediate"],
    ),
    # Stage 1: reads "input" and "intermediate", writes "out_0".
    Mod(
        lambda a, b: (a * b).sqrt(),
        in_keys=["input", "intermediate"],
        out_keys=["out_0"],
    ),
    # Stage 2: reads "out_0" and "intermediate", writes "out_1".
    # NOTE(review): ``u - u`` is identically zero and ignores the second
    # input; confirm whether ``u - v`` was intended.
    Mod(
        lambda u, v: u - u,
        in_keys=["out_0", "intermediate"],
        out_keys=["out_1"],
    ),
)
def edges(seq):
    """Return the module-to-module data-flow edges of ``seq``.

    For every key written by module ``i``, an edge is recorded to each later
    module ``j`` that reads that key. The scan for that key stops at the
    first later module that also writes it, because modules after that
    point see the new value rather than the one produced by ``i``.

    Args:
        seq: an iterable of modules, each exposing ``in_keys`` and
            ``out_keys`` (e.g. a ``TensorDictSequential``).

    Returns:
        A list of ``(key, (producer_index, consumer_index))`` tuples, ordered
        by producer, then by the producer's ``out_keys``, then by consumer.
    """
    # Materialize once so later modules can be sliced cheaply; slicing a
    # container module directly may construct a new object on every call.
    modules = list(seq)
    result = []
    for i, producer in enumerate(modules):
        # Loop-invariant for all of this producer's keys.
        later = modules[i + 1:]
        for out_key in producer.out_keys:
            for j, consumer in enumerate(later, start=i + 1):
                if out_key in consumer.in_keys:
                    result.append((out_key, (i, j)))
                # A later writer shadows this producer's value from here on.
                if out_key in consumer.out_keys:
                    break
    return result
def build_graph(seq):
    """Render the key data flow of ``seq`` as a ``pydot`` directed graph.

    Nodes:
      * one plaintext node per key in ``seq.in_keys`` and ``seq.out_keys``;
      * one box node per module, named by its position in ``seq.module``.

    Edges:
      * blue: sequence input key -> each module that reads it, and
        module -> each sequence output key it writes;
      * red: module-to-module flow reported by ``edges(seq)``. When the
        carried key is itself a sequence output, the red edge starts at the
        key's node instead of at the producer, because the producer -> key
        edge is already drawn in blue.

    Args:
        seq: a sequence module exposing ``in_keys``, ``out_keys`` and
            ``module`` (the list of its sub-modules).

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    # No ``style`` is set on edges: Graphviz has no "arrow" style (valid
    # ones include solid, dashed, dotted, bold), and a digraph already
    # draws arrowheads by default.
    graph = pydot.Dot(
        "my_graph", graph_type="digraph", bgcolor="white", splines="curved"
    )

    for key in seq.in_keys:
        graph.add_node(pydot.Node(key, label=key, shape="plaintext"))
    for key in seq.out_keys:
        graph.add_node(pydot.Node(key, label=key, shape="plaintext"))

    # NOTE(review): module nodes are named "0", "1", ...; a key with one of
    # those names would be merged with the corresponding module node.
    for i, module in enumerate(seq.module):
        name = str(i)
        graph.add_node(pydot.Node(name, shape="box"))
        for key in seq.in_keys:
            if key in module.in_keys:
                graph.add_edge(pydot.Edge(key, name, color="blue"))
        for key in seq.out_keys:
            if key in module.out_keys:
                graph.add_edge(pydot.Edge(name, key, color="blue"))

    for key, (producer, consumer) in edges(seq):
        if key not in seq.out_keys:
            graph.add_edge(
                pydot.Edge(
                    str(producer),
                    str(consumer),
                    color="red",
                    label=key,
                    decorate=True,
                )
            )
        else:
            # The key already has its own node fed by a blue edge from the
            # producer, so route the red edge from that node.
            graph.add_edge(pydot.Edge(key, str(consumer), color="red"))
    return graph


graph = build_graph(seq)
# Resolves to the original /Users/<user>/Downloads location on macOS,
# but no longer hard-codes one user's home directory.
output_path = Path.home() / "Downloads" / "my_graph.png"
output_path.parent.mkdir(parents=True, exist_ok=True)
graph.write_png(str(output_path))
@vmoens
Copy link
Author

vmoens commented Dec 17, 2024

my_graph

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment