@yaoyaoding
Last active November 28, 2021 04:49
Extract the launch configuration from PrimFunc.
import tvm
from tvm import te, tir


def extract_launch_config(prim_func: tir.PrimFunc):
    """
    Extract the launch configuration of the given prim_func.

    Parameters
    ----------
    prim_func : tvm.tir.PrimFunc
        The prim func to analyze.

    Returns
    -------
    ret : dict
        A dict that maps each launch configuration name to its extent, e.g. {'threadIdx.x': 10, ...}.
    """
    config = {
        'blockIdx.x': 1, 'blockIdx.y': 1, 'blockIdx.z': 1,
        'threadIdx.x': 1, 'threadIdx.y': 1, 'threadIdx.z': 1
    }

    def find_attrs(node):
        # In lowered TIR, thread bindings show up as AttrStmt nodes attached to an
        # IterVar whose thread_tag names the launch dimension (e.g. 'threadIdx.x').
        if isinstance(node, tir.stmt.AttrStmt):
            attr_stmt = node
            if isinstance(attr_stmt.node, tir.expr.IterVar):
                iter_var = attr_stmt.node
                tag = iter_var.thread_tag
                if tag in config:
                    config[tag] = attr_stmt.value

    tir.stmt_functor.post_order_visit(prim_func.body, find_attrs)
    return config
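

# A minimal convenience sketch on top of extract_launch_config (launch_dims is not a name
# from the original gist): it assumes static shapes, where the extracted extents are
# tir.IntImm constants; for dynamic shapes (see matmul_v2 below) the entries stay
# symbolic PrimExprs and are returned unchanged.
def launch_dims(prim_func: tir.PrimFunc):
    config = extract_launch_config(prim_func)

    def as_int(v):
        # Plain Python ints come from the defaults above; IntImm comes from lowered TIR.
        if isinstance(v, tir.IntImm):
            return int(v.value)
        return v

    grid = tuple(as_int(config[k]) for k in ('blockIdx.x', 'blockIdx.y', 'blockIdx.z'))
    block = tuple(as_int(config[k]) for k in ('threadIdx.x', 'threadIdx.y', 'threadIdx.z'))
    return grid, block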


def split(stage, axis, factors):
    """Split `axis` into len(factors) + 1 nested loops, returned outermost to innermost;
    the inner loops have extents equal to `factors`."""
    axes = []
    for f in reversed(factors):
        axis, x = stage.split(axis, f)
        axes.append(x)
    axes.append(axis)
    return list(reversed(axes))


def bind_thread(stage, axes, tags):
    """Bind each axis to the GPU thread axis named by the corresponding tag."""
    for axis, tag in zip(axes, tags):
        stage.bind(axis, te.thread_axis(tag))
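

# A small, hypothetical demonstration of the two helpers above (the name _demo_split is
# an assumption, not from the original gist); extents mirror the matmul schedule below.
def _demo_split():
    X = te.placeholder((2048,), name='X')
    Y = te.compute((2048,), lambda i: X[i] + 1.0, name='Y')
    s = te.create_schedule(Y.op)
    # Splitting a length-2048 axis with factors (16, 8) gives loops of extents
    # (2048 // (16 * 8), 16, 8) = (16, 16, 8), outermost first.
    outer, mid, inner = split(s[Y], Y.op.axis[0], (16, 8))
    bind_thread(s[Y], (outer, mid), ("blockIdx.x", "threadIdx.x"))
    mod = tvm.lower(s, [X, Y])
    print(extract_launch_config(mod['main']))  # expect blockIdx.x = 16, threadIdx.x = 16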


def matmul(n, block_size=16, tx=8, ty=4, tk=32):
    A = te.placeholder(shape=(n, n), name='A')
    B = te.placeholder(shape=(n, n), name='B')
    k = te.reduce_axis(dom=(0, n), name='k')
    C = te.compute((n, n), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name='C')
    s = te.create_schedule(C.op)
    assert isinstance(s, te.Schedule)
    assert isinstance(s[C], te.Stage)

    A_shared = s.cache_read(A, "shared", [C])
    A_local = s.cache_read(A_shared, "local", [C])
    B_shared = s.cache_read(B, "shared", [C])
    B_local = s.cache_read(B_shared, "local", [C])
    C_local = s.cache_write(C, "local")

    x, y = s[C].op.axis
    xb, xo, xi = split(s[C], x, (block_size, tx))
    yb, yo, yi = split(s[C], y, (block_size, ty))
    s[C].reorder(xb, yb, xo, yo, xi, yi)  # leaves: xb, yb, xo, yo, xi, yi, k
    bind_thread(s[C], (yb, xb, yo, xo), ("blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y"))

    s[C_local].compute_at(s[C], yo)
    yi, xi = s[C_local].op.axis
    k, = s[C_local].op.reduce_axis
    ko, ki = s[C_local].split(k, tk)
    s[C_local].reorder(ko, ki, yi, xi)

    def optimize_read_cache(shared, local):
        s[shared].compute_at(s[C_local], ko)
        s[local].compute_at(s[C_local], ki)
        y, x = s[shared].op.axis
        # Note that we must split into block_size parts to reuse
        # the previous axis threads
        yo, yi = s[shared].split(y, nparts=block_size)
        xo, xi = s[shared].split(x, nparts=block_size)
        s[shared].reorder(yo, xo, yi, xi)
        bind_thread(s[shared], (yo, xo), ("threadIdx.y", "threadIdx.x"))

    optimize_read_cache(A_shared, A_local)
    optimize_read_cache(B_shared, B_local)
    return s, [A, B, C]
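

# A minimal build sketch for the schedule above, assuming a CUDA-enabled TVM build and a
# GPU (kept as comments so this script still runs on CPU-only machines):
#
#     s, args = matmul(n=2048)
#     mod = tvm.build(s, args, target='cuda')
#     print(mod.imported_modules[0].get_source())   # generated CUDA kernel source
#
# tvm.lower alone, as used in __main__ below, is already enough to read off the launch
# configuration with extract_launch_config.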


def matmul_v2(block_size=16, tx=8, ty=4, tk=32):
    n = te.var(name='n')
    A = te.placeholder(shape=(n, n), name='A')
    B = te.placeholder(shape=(n, n), name='B')
    k = te.reduce_axis(dom=(0, n), name='k')
    C = te.compute((n, n), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name='C')
    s = te.create_schedule(C.op)
    assert isinstance(s, te.Schedule)
    assert isinstance(s[C], te.Stage)

    A_shared = s.cache_read(A, "shared", [C])
    A_local = s.cache_read(A_shared, "local", [C])
    B_shared = s.cache_read(B, "shared", [C])
    B_local = s.cache_read(B_shared, "local", [C])
    C_local = s.cache_write(C, "local")

    x, y = s[C].op.axis
    xb, xo, xi = split(s[C], x, (block_size, tx))
    yb, yo, yi = split(s[C], y, (block_size, ty))
    s[C].reorder(xb, yb, xo, yo, xi, yi)  # leaves: xb, yb, xo, yo, xi, yi, k
    bind_thread(s[C], (yb, xb, yo, xo), ("blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y"))

    s[C_local].compute_at(s[C], yo)
    yi, xi = s[C_local].op.axis
    k, = s[C_local].op.reduce_axis
    ko, ki = s[C_local].split(k, tk)
    s[C_local].reorder(ko, ki, yi, xi)

    def optimize_read_cache(shared, local):
        s[shared].compute_at(s[C_local], ko)
        s[local].compute_at(s[C_local], ki)
        y, x = s[shared].op.axis
        # Note that we must split into block_size parts to reuse
        # the previous axis threads
        yo, yi = s[shared].split(y, nparts=block_size)
        xo, xi = s[shared].split(x, nparts=block_size)
        s[shared].reorder(yo, xo, yi, xi)
        bind_thread(s[shared], (yo, xo), ("threadIdx.y", "threadIdx.x"))

    optimize_read_cache(A_shared, A_local)
    optimize_read_cache(B_shared, B_local)
    return s, [n, A, B, C], tuple()
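

# matmul_v2 uses a symbolic size n, so the extracted launch configuration stays symbolic,
# e.g. blockIdx.x = floordiv(n + 63, 64) (see the output below). A hedged sketch of fixing
# it for a concrete size by substituting n and simplifying (assumes n is the only free var):
#
#     sch, args, _ = matmul_v2()
#     cfg = extract_launch_config(tvm.lower(sch, args=args)['main'])
#     n_var = args[0]
#     ana = tvm.arith.Analyzer()
#     resolved = {
#         tag: ana.simplify(tir.stmt_functor.substitute(v, {n_var: tir.IntImm('int32', 2048)}))
#              if isinstance(v, tir.PrimExpr) else v
#         for tag, v in cfg.items()
#     }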


def get_te_schedule(name):
    n = 2048
    if name == 'add':
        A = te.placeholder((10,), name='A')
        B = te.placeholder((10,), name='B')
        C = te.compute((10,), lambda i: A[i] + B[i], name='C')
        sch = te.create_schedule([C.op])
        sch[C].bind(C.op.axis[0], te.thread_axis("threadIdx.x"))
        return sch, [A, B, C]
    elif name == 'matmul':
        return matmul(n=n)
    elif name == 'matmul_v2':
        sch, args, inputs = matmul_v2()
        return sch, args
    else:
        raise ValueError('unknown workload: {}'.format(name))


if __name__ == '__main__':
    for workload in ['add', 'matmul', 'matmul_v2']:
        sch, args = get_te_schedule(workload)
        with tvm.transform.PassContext(opt_level=3):
            lib = tvm.lower(sch, args=args)
        print(extract_launch_config(lib['main']))

# Output:
# {'blockIdx.x': 1, 'blockIdx.y': 1, 'blockIdx.z': 1, 'threadIdx.x': 10, 'threadIdx.y': 1, 'threadIdx.z': 1}
# {'blockIdx.x': 32, 'blockIdx.y': 16, 'blockIdx.z': 1, 'threadIdx.x': 16, 'threadIdx.y': 16, 'threadIdx.z': 1}
# {'blockIdx.x': floordiv((n + 63), 64), 'blockIdx.y': floordiv((n + 127), 128), 'blockIdx.z': 1, 'threadIdx.x': 16,
#  'threadIdx.y': 16, 'threadIdx.z': 1}