Skip to content

Instantly share code, notes, and snippets.

@yzh119
Created July 30, 2021 14:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yzh119/33786c349be41731c0dd7acacac2dd58 to your computer and use it in GitHub Desktop.
Save yzh119/33786c349be41731c0dd7acacac2dd58 to your computer and use it in GitHub Desktop.
Sparse workloads in TIR
import tvm
from tvm import tir
from tvm.script import ty
@tvm.script.tir
def csr_spmm(indptr_: ty.handle, indices_: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Sparse (CSR) x dense matrix multiply, written as TensorIR blocks:
    #   C[i, j] = sum over stored entries vk of row i of A[vk] * B[indices[vk], j]
    # The reduction range of each row is read from `indptr`, so the inner
    # loop's domain is data-dependent.
    #
    # Symbolic sizes; they are bound through the buffer shapes below.
    m = tir.var('int32')  # rows of the sparse operand and of C
    n = tir.var('int32')  # columns of B and C
    k = tir.var('int32')  # rows of B (range addressed by `indices`)
    nnz = tir.var('int32')  # number of stored non-zero entries
    # Row pointer: the entries of row i occupy positions [indptr[i], indptr[i + 1]).
    indptr = tir.match_buffer(indptr_, [m + 1], 'int32')
    # Column index of each stored entry; it selects the row of B to multiply.
    indices = tir.match_buffer(indices_, [nnz], 'int32')
    # Stored values, aligned position-by-position with `indices`.
    A = tir.match_buffer(a_data, [nnz], 'float32')
    B = tir.match_buffer(b, [k, n], 'float32')
    C = tir.match_buffer(c, [m, n], 'float32')
    # One outer block instance per output element (vi, vj).
    # NOTE(review): this block declares no reduce_axis but carries the init,
    # while the reduction lives in the nested block; confirm the TVMScript
    # version in use accepts this placement.
    with tir.block([m, n], 'spmm_outer') as [vi, vj]:
        with tir.init():
            C[vi, vj] = 0.
        # Reduce over the non-zeros of row vi; the bounds come from indptr.
        with tir.block([tir.reduce_axis(indptr[vi], indptr[vi + 1])], 'spmm_inner') as [vk]:
            C[vi, vj] = C[vi, vj] + A[vk] * B[indices[vk], vj]
@tvm.script.tir
def csr_sddmm(row_: ty.handle, col_: ty.handle, a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Sampled dense-dense matrix multiply. For every listed position eid:
    #   C[eid] = sum over vk of A[row[eid], vk] * B[vk, col[eid]]
    # which is entry (row[eid], col[eid]) of the dense product A @ B.
    # NOTE(review): the result is not scaled by any sparse values; only the
    # sampling positions are taken as input. Confirm this is intended.
    # NOTE(review): despite the `csr_` prefix, positions are given as explicit
    # per-entry row/col arrays (COO-style) rather than a row pointer.
    m = tir.var('int32')  # rows of A
    n = tir.var('int32')  # columns of B
    k = tir.var('int32')  # shared inner dimension of A and B
    nnz = tir.var('int32')  # number of sampled positions
    # Row and column coordinate of each sampled position.
    row = tir.match_buffer(row_, [nnz,], 'int32')
    col = tir.match_buffer(col_, [nnz,], 'int32')
    A = tir.match_buffer(a, [m, k], 'float32')
    B = tir.match_buffer(b, [k, n], 'float32')
    # One output value per sampled position.
    C = tir.match_buffer(c, [nnz,], 'float32')
    # Spatial axis over positions (eid), reduction over the inner dimension (vk).
    with tir.block([nnz, tir.reduce_axis(0, k)], 'sddmm') as [eid, vk]:
        with tir.init():
            C[eid] = 0.
        C[eid] = C[eid] + A[row[eid], vk] * B[vk, col[eid]]
@tvm.script.tir
def bsr_spmm(indptr_: ty.handle, indices_: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Block-sparse (BSR) x dense matrix multiply with square
    # block_size x block_size blocks. For block-row io, in-block row ii and
    # column j:
    #   C[io, ii, j] = sum over stored blocks ko of block-row io, and over ki,
    #                  of A[ko, ii, ki] * B[indices[ko], ki, j]
    mb = tir.var('int32')  # number of block rows of the sparse operand / C
    n = tir.var('int32')  # columns of B and C
    kb = tir.var('int32')  # number of block rows of B
    nnzb = tir.var('int32')  # number of stored blocks
    block_size = tir.var('int32')  # edge length of each square block
    # Block-row pointer: the blocks of block-row io occupy [indptr[io], indptr[io + 1]).
    indptr = tir.match_buffer(indptr_, [mb + 1], 'int32')
    # Block-column index of each stored block; it selects the block row of B.
    indices = tir.match_buffer(indices_, [nnzb], 'int32')
    # Dense contents of each stored block, indexed [block, row-in-block, col-in-block].
    A = tir.match_buffer(a_data, [nnzb, block_size, block_size], 'float32')
    # Dense operands with their row dimension split into (block, offset-in-block).
    B = tir.match_buffer(b, [kb, block_size, n], 'float32')
    C = tir.match_buffer(c, [mb, block_size, n], 'float32')
    # Outer block: spatial io, ii, j plus the in-block reduction ki.
    # Note the binding order [io, ki, ii, j]: ki is the second iterator.
    with tir.block([mb, tir.reduce_axis(0, block_size), block_size, n], 'spmm_outer') as [io, ki, ii, j]:
        with tir.init():
            C[io, ii, j] = 0.
        # Reduce over the stored blocks of block-row io; the bounds come from indptr.
        with tir.block([tir.reduce_axis(indptr[io], indptr[io + 1])], 'spmm_inner') as [ko]:
            C[io, ii, j] = C[io, ii, j] + A[ko, ii, ki] * B[indices[ko], ki, j]
@tvm.script.tir
def bsr_sddmm(row_: ty.handle, col_: ty.handle, a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Block-sampled dense-dense matrix multiply. For each listed block bid,
    # computes the full block_size x block_size tile of A @ B at
    # block-row row[bid] and block-column col[bid]:
    #   C[bid, vi, vj] = sum over vk of A[row[bid], vi, vk] * B[vk, col[bid], vj]
    mb = tir.var('int32')  # number of block rows of A
    nb = tir.var('int32')  # number of block columns of B
    k = tir.var('int32')  # shared inner dimension
    nnzb = tir.var('int32')  # number of sampled blocks
    block_size = tir.var('int32')  # edge length of each square block
    # Block-row and block-column coordinate of each sampled block.
    row = tir.match_buffer(row_, [nnzb,], 'int32')
    col = tir.match_buffer(col_, [nnzb,], 'int32')
    # A with its row dimension split into (block, offset-in-block).
    A = tir.match_buffer(a, [mb, block_size, k], 'float32')
    # B with its column dimension split into (block, offset-in-block).
    B = tir.match_buffer(b, [k, nb, block_size], 'float32')
    # One dense tile per sampled block.
    C = tir.match_buffer(c, [nnzb, block_size, block_size], 'float32')
    # Spatial axes bid, vi, vj; reduction over the inner dimension vk.
    with tir.block([nnzb, block_size, block_size, tir.reduce_axis(0, k)], 'sddmm') as [bid, vi, vj, vk]:
        with tir.init():
            C[bid, vi, vj] = 0.
        C[bid, vi, vj] = C[bid, vi, vj] + A[row[bid], vi, vk] * B[vk, col[bid], vj]
@tvm.script.tir
def ell_spmm(indices_: ty.handle, a_data: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Blocked-ELL sparse x dense matrix multiply. Every block-row io stores
    # exactly `ell_cols` square blocks, so no row pointer is needed:
    #   C[io, ii, j] = sum over ko < ell_cols and ki < block_size of
    #                  A[io, ko, ii, ki] * B[indices[io, ko], ki, j]
    # NOTE(review): no masking of padded slots is visible here. If a block-row
    # has fewer than ell_cols real blocks, the padding presumably must hold
    # zero blocks with in-range indices. Confirm with the producer of the data.
    mb = tir.var('int32')  # number of block rows of the sparse operand / C
    n = tir.var('int32')  # columns of B and C
    kb = tir.var('int32')  # number of block rows of B
    block_size = tir.var('int32')  # edge length of each square block
    ell_cols = tir.var('int32')  # stored blocks per block-row
    # Block-column index of the ko-th stored block of block-row io.
    indices = tir.match_buffer(indices_, [mb, ell_cols], 'int32')
    # Dense contents of each stored block, indexed [block-row, slot, row-in-block, col-in-block].
    A = tir.match_buffer(a_data, [mb, ell_cols, block_size, block_size], 'float32')
    B = tir.match_buffer(b, [kb, block_size, n], 'float32')
    C = tir.match_buffer(c, [mb, block_size, n], 'float32')
    # Single block: spatial io, ii, j and reductions ko (slot) and ki (in-block).
    # Binding order is [io, ko, ki, ii, j].
    with tir.block([mb, tir.reduce_axis(0, ell_cols), tir.reduce_axis(0, block_size), block_size, n], 'spmm') as [io, ko, ki, ii, j]:
        with tir.init():
            C[io, ii, j] = 0.
        C[io, ii, j] = C[io, ii, j] + A[io, ko, ii, ki] * B[indices[io, ko], ki, j]
if __name__ == '__main__':
    # Print each kernel defined in this file, in definition order.
    for prim_func in (csr_spmm, csr_sddmm, bsr_spmm, bsr_sddmm, ell_spmm):
        print(prim_func)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.