Skip to content

Instantly share code, notes, and snippets.

@vinx13
Last active January 28, 2021 05:55
Show Gist options
  • Save vinx13/7e9741252efcab5f2435bb1abb76a265 to your computer and use it in GitHub Desktop.
Save vinx13/7e9741252efcab5f2435bb1abb76a265 to your computer and use it in GitHub Desktop.
import tvm
from tvm import tir
from tvm.script import ty
# Toy TVMScript workload: for each i in [0, 128), accumulate A[j] into C[i]
# for j in [0, B[i]) with a sum reducer -- a ragged reduction whose extent is
# read from the int32 buffer B.
@tvm.script.tir
def foo(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128,), "float32")
    B = tir.match_buffer(b, (128,), "int32")
    C = tir.match_buffer(c, (128,), "float32")
    # Commutative sum reducer whose identity element is float32 0.
    reducer = tir.comm_reducer(lambda x, y: (x + y), tir.float32(0))
    for i in range(0, 128):
        for j in range(0, B[i]):
            # Block 'C': spatial iter ii (extent 128) and reduce iter jj whose
            # extent is B[i]. NOTE(review): the extent refers to the outer loop
            # variable `i`; after scheduling in sch2() the printed script still
            # reads B[i] even though `i` no longer exists -- this is the
            # behavior the script reproduces.
            with tir.block([128, tir.reduce_axis(0, B[i])], 'C') as [ii, jj]:
                reducer.step(C[ii], A[jj])
def sch2():
    """Build and return a schedule of ``foo``.

    The steps are: look up block 'C', cache-write it to 'local' scope, split
    the loop of the ``c_block`` handle by a factor of 4, and compute the
    cache-stage handle at the resulting outer loop.

    Returns:
        The scheduled function (``sched.func``).
    """
    sched = tir.create_schedule(foo)
    c_block = sched.get_block('C')
    # Handle returned by cache_write; it is the one moved by compute_at below.
    cache_block = sched.cache_write(c_block, 0, 'local')
    # NOTE(review): get_axes' result is handed directly to split; presumably
    # this works because the block has a single loop -- confirm against the
    # schedule API in use.
    axes = sched.get_axes(c_block)
    outer, _inner = sched.split(axes, factor=4)
    sched.compute_at(cache_block, outer)
    return sched.func
# Print the TVMScript text of the function scheduled by sch2().
print(tvm.script.asscript(sch2()))
# Captured output of the print above, kept as an unused string literal for
# reference. The "<<<<" annotations mark where the reduce-axis extent still
# reads B[i] after scheduling, although the loops are now ax0_outer/ax0/ax1;
# note the free `i = tir.var("int32")` declaration near the top of the output.
'''
@tvm.script.tir
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
# var definition
i = tir.var("int32")
C = tir.match_buffer(c, [128], elem_offset=0, align=128, offset_factor=1)
A = tir.match_buffer(a, [128], elem_offset=0, align=128, offset_factor=1)
B = tir.match_buffer(b, [128], dtype="int32", elem_offset=0, align=128, offset_factor=1)
reducer = tir.comm_reducer(lambda x, y: (x + y), tir.float32(0))
# body
with tir.block([], "root") as []:
tir.reads([])
tir.writes([])
C_local = tir.buffer_allocate([128], elem_offset=0, scope="local", align=128, offset_factor=1)
for ax0_outer in range(0, 32):
for ax0, ax1 in tir.grid(4, B[i]): <<<< REDUCE AXIS IS NOT UPDATED, I assume B[i] should be B[ax0_outer * 4 + ax0]
with tir.block([128, tir.reduce_axis(0, B[i])], "C") as [ii, jj]: <<<< REDUCE AXIS IS NOT UPDATED
tir.bind(ii, ((ax0_outer*4) + ax0))
tir.bind(jj, ax1)
tir.reads([C_local[ii:(ii + 1)], A[jj:(jj + 1)]])
tir.writes([C_local[ii:(ii + 1)]])
reducer.step(C_local[ii], A[jj])
for ax0_inner in range(0, 4):
with tir.block([128], "C_local") as [v0]:
tir.bind(v0, ((ax0_outer*4) + ax0_inner))
tir.reads([C_local[v0:(v0 + 1)]])
tir.writes([C[v0:(v0 + 1)]])
C[v0] = C_local[v0]
'''
# Under the 'llvm' target context, re-create the schedule with a fresh call
# to sch2(), print its lowered form, then build it.
# NOTE(review): because the lowered body still references the free var `i`
# (see the captured output), this build presumably fails. That failure is the
# point of this reproducer -- confirm.
with tvm.target.Target('llvm'):
    func = sch2()
    print(tvm.lower(func, None, simple_mode=True))
    tvm.build(func)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment