@yzh119
Created August 5, 2021 08:00
ell spmm with multi-level tiling
import tvm
from tvm import tir
from tvm.script import ty
from tvm.tir.schedule.schedule import Schedule
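
# Blocked-ELL SpMM: each block-row io stores `ell_cols` dense (block_size, block_size)
# blocks in A, with their block-column ids kept in `indices`; the kernel computes
#   C[io, ii, j] += A[io, ko, ii, ki] * B[indices[io, ko], ki, j]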
@tvm.script.tir
def ell_spmm(indices_: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None:
    mb = tir.var('int32')
    n = tir.var('int32')
    kb = tir.var('int32')
    block_size = tir.var('int32')
    ell_cols = tir.var('int32')
    indices = tir.match_buffer(indices_, [mb, ell_cols], 'int32')
    A = tir.match_buffer(a_data, [mb, ell_cols, block_size, block_size], 'float32')
    B = tir.match_buffer(b, [kb, block_size, n], 'float32')
    C = tir.match_buffer(c, [mb, block_size, n], 'float32')
    with tir.block([mb, tir.reduce_axis(0, ell_cols), tir.reduce_axis(0, block_size), block_size, n], 'spmm') as [io, ko, ki, ii, j]:
        with tir.init():
            C[io, ii, j] = 0.
        C[io, ii, j] = C[io, ii, j] + A[io, ko, ii, ki] * B[indices[io, ko], ki, j]
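
# Reference semantics for the kernel above, as a NumPy sketch (helper added for
# illustration; assumes NumPy and the buffer shapes declared in `ell_spmm`).
# Useful as a numerical oracle when testing the scheduled kernel.
def ell_spmm_numpy(indices, A, B):
    import numpy as np  # local import: only needed by this reference helper
    mb, ell_cols, block_size, _ = A.shape
    _, _, n = B.shape
    C = np.zeros((mb, block_size, n), dtype=A.dtype)
    for io in range(mb):
        for ko in range(ell_cols):
            # A[io, ko] is the dense (block_size, block_size) block at ELL slot ko of
            # block-row io; indices[io, ko] is its block-column in the dense matrix.
            C[io] += A[io, ko] @ B[indices[io, ko]]
    return C
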
def schedule(sch: Schedule) -> None:
    block = sch.get_block("spmm")
    io, ko, ki, ii, j = sch.get_loops(block)
    # split out 16x16x16 inner tiles matching the tensor-core fragment shape
    ii, i_tc = sch.split(ii, factors=[None, 16])
    ki, k_tc = sch.split(ki, factors=[None, 16])
    j, j_tc = sch.split(j, factors=[None, 16])
    sch.reorder(
        io, j, ko, ii, ki,
        i_tc, j_tc, k_tc,
    )
    block_inner = sch.blockize(i_tc)
    block_outer, block_inner = block_inner, block
    del block
    # multi-level tiling: sample tile sizes for the outer spatial and reduction loops
    i0, i1, i2, i3 = sch.split(io, factors=sch.sample_perfect_tile(io, n=4))
    j0, j1, j2, j3, j4 = sch.split(j, factors=sch.sample_perfect_tile(j, n=5))
    k0, k1 = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))
    sch.reorder(
        # fmt: off
        i0, j0,  # s => blockIdx.x
        i1, j1,  # s => vthread
        i2, j2,  # s => threadIdx.x
        # cache_write here
        k0,      # r
        # vectorized cooperative fetching here
        k1,      # r
        i3, j3,  # s
        ki,      # r
        ii, j4,  # s
        # fmt: on
    )
    block_idx = sch.fuse(i0, j0)
    vthread = sch.fuse(i1, j1)
    thread_idx = sch.fuse(i2, j2)
    sch.bind(block_idx, "blockIdx.x")
    sch.bind(vthread, "vthread")
    sch.bind(thread_idx, "threadIdx.x")
    block_write_c = sch.cache_write(block_outer, 0, "local")
    block_outer, block_write_c = block_write_c, block_outer
    sch.reverse_compute_at(block_write_c, thread_idx)

    def fetch_to_shared(block, idx, ndim):
        # cooperatively fetch read buffer `idx` of `block` into shared memory,
        # vectorizing the innermost loads by a factor of 4
        block_read = sch.cache_read(block, idx, "shared")
        sch.compute_at(block_read, k0)
        fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
        fused_0, fused_1 = sch.split(fused, factors=[None, 4])
        sch.mark_loop(fused_0, "loop_type", "lazy_cooperative_fetch")
        sch.vectorize(fused_1)

    fetch_to_shared(block_outer, 1, 2)
    fetch_to_shared(block_outer, 2, 2)
    # read indices from global to local
    indices_read = sch.cache_read(block_inner, 3, "local")
    loop = sch.get_loops(block_outer)[-1]
    sch.compute_at(indices_read, loop)
    # step 3. postproc-rewrite-tensorize
    # step 3.1. cache read
    loop = sch.get_loops(block_outer)[-1]
    block_read_a = sch.cache_read(block_inner, 1, "wmma.matrix_a")
    block_read_b = sch.cache_read(block_inner, 2, 'wmma.matrix_b')
    sch.compute_at(block_read_a, loop)
    sch.compute_at(block_read_b, loop)
    # step 3.2. cache write
    block_write_c = sch.cache_write(block_outer, 0, 'wmma.accumulator')
    block_outer, block_write_c = block_write_c, block_outer
    sch.reverse_compute_at(block_write_c, loop)
    # step 3.3. decompose
    loop = sch.get_loops(block_outer)[3]
    block_init_c = sch.decompose_reduction(block_outer, loop)
    print(tvm.script.asscript(sch.mod['main']))
if __name__ == '__main__':
    f = ell_spmm
    m, n, k = 4096, 512, 4096
    block_size = 32
    ell_cols = 16
    indices_, a_data, b, c = f.params
    # specialize the symbolic shape variables to concrete buffer shapes
    f = f.specialize({indices_: tir.decl_buffer([m // block_size, ell_cols], 'int32'),
                      a_data: tir.decl_buffer([m // block_size, ell_cols, block_size, block_size]),
                      b: tir.decl_buffer([k // block_size, block_size, n]),
                      c: tir.decl_buffer([m // block_size, block_size, n])})
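    # With these sizes: mb = m // block_size = 128 and kb = k // block_size = 128, so
    #   indices: (128, 16)          int32  (block-column ids in [0, 128))
    #   A:       (128, 16, 32, 32)  float32
    #   B:       (128, 32, 512)     float32
    #   C:       (128, 32, 512)     float32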
    sch = Schedule(f, debug_mode=True, traced=True)
    schedule(sch)