Last active
November 28, 2021 04:49
-
-
Save yaoyaoding/4d38208febea78c7c0cccab2e23b9458 to your computer and use it in GitHub Desktop.
Extract the launch configuration from PrimFunc.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm
from tvm import te, tir


def extract_launch_config(prim_func: tir.PrimFunc):
    """
    Extract the GPU launch configuration of the given prim func.

    Parameters
    ----------
    prim_func : tvm.tir.PrimFunc
        The prim func to analyze.

    Returns
    -------
    ret : dict
        A dict mapping each launch dimension name to its extent, e.g.
        {'threadIdx.x': 10, ...}. Dimensions never bound stay at 1.
    """
    launch_dims = dict.fromkeys(
        ('blockIdx.x', 'blockIdx.y', 'blockIdx.z',
         'threadIdx.x', 'threadIdx.y', 'threadIdx.z'), 1)

    def visit(stmt):
        # Thread extents appear as AttrStmt nodes whose annotated node
        # is the bound IterVar; its thread_tag names the launch dimension.
        if not isinstance(stmt, tir.stmt.AttrStmt):
            return
        if not isinstance(stmt.node, tir.expr.IterVar):
            return
        tag = stmt.node.thread_tag
        if tag in launch_dims:
            launch_dims[tag] = stmt.value

    tir.stmt_functor.post_order_visit(prim_func.body, visit)
    return launch_dims
def split(stage, axis, factors):
    """Split `axis` of `stage` by `factors` (outermost factor first).

    Returns len(factors) + 1 axes ordered outermost-first.
    """
    outer = axis
    collected = []
    # Split from the innermost factor outward so each factor becomes
    # the extent of one inner axis.
    for factor in reversed(factors):
        outer, inner = stage.split(outer, factor)
        collected.append(inner)
    collected.append(outer)
    collected.reverse()
    return collected
def bind_thread(stage, axes, tags):
    """Bind each axis to the thread axis named by the tag at the same position."""
    for axis, thread_tag in zip(axes, tags):
        stage.bind(axis, te.thread_axis(thread_tag))
def matmul(n, block_size=16, tx=8, ty=4, tk=32):
    """Build a tiled GPU schedule for C = A @ B with fixed square size n.

    Parameters
    ----------
    n : int
        Matrix dimension; all matrices are n x n.
    block_size : int
        Threads per block along each of x and y.
    tx, ty : int
        Output elements computed per thread along x / y.
    tk : int
        Split factor for the reduction axis k.

    Returns
    -------
    (te.Schedule, [A, B, C])
        Schedule and tensor arguments suitable for tvm.lower/tvm.build.
    """
    A = te.placeholder(shape=(n, n), name='A')
    B = te.placeholder(shape=(n, n), name='B')
    k = te.reduce_axis(dom=(0, n), name='k')
    C = te.compute((n, n), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name='C')
    s = te.create_schedule(C.op)
    assert isinstance(s, te.Schedule)
    assert isinstance(s[C], te.Stage)
    # Stage reads of A and B through "shared" scope, then "local" scope
    # (local is typically registers on GPU targets — depends on codegen).
    A_shared = s.cache_read(A, "shared", [C])
    A_local = s.cache_read(A_shared, "local", [C])
    B_shared = s.cache_read(B, "shared", [C])
    B_local = s.cache_read(B_shared, "local", [C])
    C_local = s.cache_write(C, "local")
    # Tile the output: each block covers a (block_size*tx) x (block_size*ty)
    # region; each thread computes a tx x ty sub-tile.
    x, y = s[C].op.axis
    xb, xo, xi = split(s[C], x, (block_size, tx))
    yb, yo, yi = split(s[C], y, (block_size, ty))
    s[C].reorder(xb, yb, xo, yo, xi, yi)  # leaves: xb, yb, xo, yo, xi, yi, k
    bind_thread(s[C], (yb, xb, yo, xo), ("blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y"))
    # Accumulate into the local cache at the per-thread tile, splitting the
    # reduction axis k by tk.
    s[C_local].compute_at(s[C], yo)
    yi, xi = s[C_local].op.axis
    k, = s[C_local].op.reduce_axis
    ko, ki = s[C_local].split(k, tk)
    s[C_local].reorder(ko, ki, yi, xi)

    def optimize_read_cache(shared, local):
        # Place the shared-scope load at ko and the local-scope copy at ki.
        s[shared].compute_at(s[C_local], ko)
        s[local].compute_at(s[C_local], ki)
        y, x = s[shared].op.axis
        # Note that we must split into block_size parts to reuse
        # the previous axis threads
        yo, yi = s[shared].split(y, nparts=block_size)
        xo, xi = s[shared].split(x, nparts=block_size)
        s[shared].reorder(yo, xo, yi, xi)
        bind_thread(s[shared], (yo, xo), ("threadIdx.y", "threadIdx.x"))

    optimize_read_cache(A_shared, A_local)
    optimize_read_cache(B_shared, B_local)
    return s, [A, B, C]
def matmul_v2(block_size=16, tx=8, ty=4, tk=32):
    """Build the same tiled matmul schedule as `matmul`, but with a
    symbolic (dynamic) matrix size n = te.var.

    Parameters
    ----------
    block_size : int
        Threads per block along each of x and y.
    tx, ty : int
        Output elements computed per thread along x / y.
    tk : int
        Split factor for the reduction axis k.

    Returns
    -------
    (te.Schedule, [n, A, B, C], tuple)
        The schedule, its arguments (including the size var n), and an
        empty tuple (unused by the caller in this file).
    """
    # n is a symbolic var, so it must be passed as an argument to the kernel.
    n = te.var(name='n')
    A = te.placeholder(shape=(n, n), name='A')
    B = te.placeholder(shape=(n, n), name='B')
    k = te.reduce_axis(dom=(0, n), name='k')
    C = te.compute((n, n), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name='C')
    s = te.create_schedule(C.op)
    assert isinstance(s, te.Schedule)
    assert isinstance(s[C], te.Stage)
    # Stage reads of A and B through "shared" scope, then "local" scope.
    A_shared = s.cache_read(A, "shared", [C])
    A_local = s.cache_read(A_shared, "local", [C])
    B_shared = s.cache_read(B, "shared", [C])
    B_local = s.cache_read(B_shared, "local", [C])
    C_local = s.cache_write(C, "local")
    # Tile the output: each block covers a (block_size*tx) x (block_size*ty)
    # region; each thread computes a tx x ty sub-tile.
    x, y = s[C].op.axis
    xb, xo, xi = split(s[C], x, (block_size, tx))
    yb, yo, yi = split(s[C], y, (block_size, ty))
    s[C].reorder(xb, yb, xo, yo, xi, yi)  # leaves: xb, yb, xo, yo, xi, yi, k
    bind_thread(s[C], (yb, xb, yo, xo), ("blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y"))
    # Accumulate into the local cache at the per-thread tile, splitting k by tk.
    s[C_local].compute_at(s[C], yo)
    yi, xi = s[C_local].op.axis
    k, = s[C_local].op.reduce_axis
    ko, ki = s[C_local].split(k, tk)
    s[C_local].reorder(ko, ki, yi, xi)

    def optimize_read_cache(shared, local):
        # Place the shared-scope load at ko and the local-scope copy at ki.
        s[shared].compute_at(s[C_local], ko)
        s[local].compute_at(s[C_local], ki)
        y, x = s[shared].op.axis
        # Note that we must split into block_size parts to reuse
        # the previous axis threads
        yo, yi = s[shared].split(y, nparts=block_size)
        xo, xi = s[shared].split(x, nparts=block_size)
        s[shared].reorder(yo, xo, yi, xi)
        bind_thread(s[shared], (yo, xo), ("threadIdx.y", "threadIdx.x"))

    optimize_read_cache(A_shared, A_local)
    optimize_read_cache(B_shared, B_local)
    return s, [n, A, B, C], tuple()
def get_te_schedule(name):
    """Build the (schedule, args) pair for a named workload.

    Parameters
    ----------
    name : str
        One of 'add', 'matmul', 'matmul_v2'.

    Returns
    -------
    (te.Schedule, list)
        Schedule and its tensor arguments, suitable for tvm.lower.

    Raises
    ------
    ValueError
        If `name` is not a known workload.
    """
    n = 2048  # matrix size used by the matmul workloads
    if name == 'add':
        A = te.placeholder((10,), name='A')
        B = te.placeholder((10,), name='B')
        C = te.compute((10,), lambda i: A[i] + B[i], name='C')
        sch = te.create_schedule([C.op])
        sch[C].bind(C.op.axis[0], te.thread_axis("threadIdx.x"))
        return sch, [A, B, C]
    elif name == 'matmul':
        return matmul(n=n)
    elif name == 'matmul_v2':
        # matmul_v2 also returns a third (empty) element; it is unused here.
        sch, args, _inputs = matmul_v2()
        return sch, args
    else:
        # Fix: the original raised a bare ValueError() with no message,
        # which gives the caller no hint about what went wrong.
        raise ValueError(f"unknown workload name: {name!r}")
if __name__ == '__main__':
    # Lower each workload to a TIR module and print the launch configuration
    # extracted from its 'main' prim func.
    for workload in ['add', 'matmul', 'matmul_v2']:
        sch, args = get_te_schedule(workload)
        with tvm.transform.PassContext(opt_level=3):
            lib = tvm.lower(sch, args=args)
        print(extract_launch_config(lib['main']))
    # Output:
    # {'blockIdx.x': 1, 'blockIdx.y': 1, 'blockIdx.z': 1, 'threadIdx.x': 10, 'threadIdx.y': 1, 'threadIdx.z': 1}
    # {'blockIdx.x': 32, 'blockIdx.y': 16, 'blockIdx.z': 1, 'threadIdx.x': 16, 'threadIdx.y': 16, 'threadIdx.z': 1}
    # {'blockIdx.x': floordiv((n + 63), 64), 'blockIdx.y': floordiv((n + 127), 128), 'blockIdx.z': 1, 'threadIdx.x': 16,
    # 'threadIdx.y': 16, 'threadIdx.z': 1}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment