Skip to content

Instantly share code, notes, and snippets.

@python273
Last active July 5, 2024 01:26
Show Gist options
  • Save python273/0dc136fbc63559188ab279c07329e891 to your computer and use it in GitHub Desktop.
Save python273/0dc136fbc63559188ab279c07329e891 to your computer and use it in GitHub Desktop.
TinyJit visualization (work in progress)
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import JIT
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters
class TinyNet:
    """Two-layer MLP without biases: Linear(784 -> 128), leaky ReLU, Linear(128 -> 10)."""

    def __init__(self):
        # Attribute names l1/l2 and their creation order are kept stable because
        # get_parameters() walks this object to find the trainable tensors.
        self.l1 = nn.Linear(784, 128, bias=False)
        self.l2 = nn.Linear(128, 10, bias=False)

    def __call__(self, x):
        """Run the forward pass on *x* and return the output of the second layer."""
        hidden = self.l1(x).leakyrelu()
        return self.l2(hidden)
# NOTE(review): JIT=2 presumably makes TinyJit keep one ExecItem per kernel in
# jit_cache (no graph batching), which the visualizer further down relies on
# when it reads ei.prg.p.uops -- confirm against tinygrad's TinyJit.
JIT.value = 2

# Model plus a plain SGD optimizer over all of its parameters.
net = TinyNet()
optim = SGD(get_parameters(net))
@TinyJit
def train_step(batch, labels):
    """Run one SGD step on the mean squared error between net(batch) and labels.

    Returns nothing; the effect is the in-place parameter update done by optim.
    """
    optim.zero_grad()
    prediction = net(batch)
    mse = prediction.sub(labels).square().mean()
    mse.backward()
    optim.step()
# Fixed random batch (64 x 784) and targets (64 x 10), reused on every call so
# the jitted function always sees the same shapes.
ia, ib = Tensor.randn(64, 784), Tensor.randn(64, 10)

# NOTE(review): TinyJit presumably needs more than one call before it has
# captured kernels into train_step.jit_cache, and Tensor.train() presumably
# enables training mode for the optimizer -- confirm both against tinygrad.
with Tensor.train():
    for _ in range(3):
        r = train_step(ia, ib)
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.engine.realize import BufferCopy
from dataclasses import dataclass
from collections import defaultdict
import re
import os
import networkx as nx
# ANSI escape sequence: either a two-byte escape (ESC followed by one of
# @-Z \ ] ^ _) or a CSI sequence (ESC [ parameter bytes, intermediate bytes,
# final byte). Compiled once at import time rather than on every call.
_ANSI_ESCAPE_RE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')


def strip_coloring(text):
    """Return *text* with ANSI escape sequences (e.g. terminal colors) removed.

    CSI sequences such as ``ESC[31m`` are removed whole. Two-byte escapes such
    as ``ESC M`` are removed without also consuming the character after them.
    """
    return _ANSI_ESCAPE_RE.sub('', text)
@dataclass()
class BufInfo:
    """Access flags for one buffer argument of a kernel, derived from its uops."""

    loaded: bool = False  # True once some LOAD uop reads this buffer
    stored: bool = False  # True once some STORE uop writes this buffer
def _buf_access(uops):
    """Return a defaultdict mapping buffer index -> BufInfo for the LOAD/STORE uops in *uops*.

    Loads/stores whose target is a DEFINE_LOCAL are ignored; every other target
    is asserted to be a DEFINE_GLOBAL.
    """
    info = defaultdict(BufInfo)
    for uop in uops:
        if uop.op not in (UOps.LOAD, UOps.STORE):
            continue
        target = uop.src[0]
        if target.op is UOps.DEFINE_LOCAL:
            continue
        assert target.op is UOps.DEFINE_GLOBAL
        # NOTE(review): arg[0] is assumed to be the kernel's buffer position,
        # matching the index into ei.bufs used below -- confirm in tinygrad.
        if uop.op is UOps.LOAD:
            info[target.arg[0]].loaded = True
        else:
            info[target.arg[0]].stored = True
    return info


G = nx.DiGraph()

# Group the JIT input substitutions by kernel index once, instead of rescanning
# all of train_step.input_replace for every kernel. Keys are indexed as k[0]
# (kernel index) and k[1] (buffer position); the value v names the input.
inputs_by_kernel = defaultdict(list)
for k, v in train_step.input_replace.items():
    inputs_by_kernel[k[0]].append((k[1], v))

for ex, ei in enumerate(train_step.jit_cache):
    # TODO: runners without compiled uops (e.g. BufferCopy) are not handled yet;
    # ei.prg.p.uops will fail for them.
    # if isinstance(ei.prg, BufferCopy): continue
    buf_info = _buf_access(ei.prg.p.uops)

    prg_id = f'{ex} {id(ei.prg)}'
    G.add_node(prg_id, label=f'#{ex} ' + strip_coloring(ei.prg.display_name))

    for bufx, buf in enumerate(ei.bufs):
        if buf is None:
            continue
        # if buf.size == 1: continue
        # ':' is replaced in labels, presumably because pydot treats it as a
        # node:port separator -- confirm.
        G.add_node(id(buf), label=str(buf).replace(':', '_'), shape='rectangle')
        # Edge direction follows data flow: buffer -> kernel for reads,
        # kernel -> buffer for writes.
        if buf_info[bufx].loaded:
            G.add_edge(id(buf), prg_id, label=str(bufx))
        if buf_info[bufx].stored:
            G.add_edge(prg_id, id(buf), label=str(bufx))

    for bufx, input_idx in inputs_by_kernel.get(ex, ()):
        G.add_node(f'input_{input_idx}', shape='rectangle')
        G.add_edge(f'input_{input_idx}', prg_id, label=str(bufx))  # TODO: dir
import subprocess

# Write the kernel/buffer graph as Graphviz DOT, then render it to SVG.
fn = '_jit'
nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
# An argument list avoids going through a shell. check=True surfaces a failed
# render (CalledProcessError, or FileNotFoundError if `dot` is not installed)
# instead of ignoring the exit status.
subprocess.run(['dot', '-Tsvg', f'{fn}.dot', '-o', f'{fn}.svg'], check=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment