
@michel-steuwer
Last active September 11, 2020 15:50
TVM Matrix Multiplication example
# Optimized algorithm
import tvm

# Matrix sizes are not fixed in the gist; these are illustrative choices.
# N must be a multiple of 32 for the packed layout below to be exact.
M = N = K = 1024

k = tvm.reduce_axis((0, K), 'k')
A = tvm.placeholder((M, K), name='A')
B = tvm.placeholder((K, N), name='B')
# Pack B into 32-wide panels so the innermost dimension is contiguous in memory
pB = tvm.compute((N // 32, K, 32), lambda x, y, z: B[y, x * 32 + z], name='pB')
C = tvm.compute((M, N),
                lambda x, y: tvm.sum(A[x, k] * pB[y // 32, k, tvm.indexmod(y, 32)], axis=k),
                name='C')
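The packed layout can be checked without TVM at all. The NumPy sketch below (sizes are illustrative assumptions, not from the gist) builds `pB` exactly as the `tvm.compute` does and confirms that reading `pB[y // 32, k, y % 32]` in place of `B[k, y]` reproduces the ordinary matrix product:

```python
import numpy as np

# Small illustrative sizes; N must be a multiple of 32
M, N, K = 64, 64, 64
rng = np.random.default_rng(0)
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))

# Packed copy of B: pB[x, y, z] = B[y, x*32 + z]
pB = np.empty((N // 32, K, 32))
for x in range(N // 32):
    for y in range(K):
        for z in range(32):
            pB[x, y, z] = B[y, x * 32 + z]

# The optimized compute reads pB[y // 32, k, y % 32] instead of B[k, y]
C = np.empty((M, N))
for x in range(M):
    for y in range(N):
        C[x, y] = sum(A[x, k] * pB[y // 32, k, y % 32] for k in range(K))

assert np.allclose(C, A @ B)
```

The point of the packing is that, for a fixed output column tile `y // 32`, consecutive values of `y` hit consecutive addresses in `pB`, which is what makes the later `vectorize` step effective.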
# Parallel schedule
s = tvm.create_schedule(C.op)

# Accumulate each 32x32 tile of C in a local write cache, written back per tile
CC = s.cache_write(C, 'global')
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], 32, 32)
s[CC].compute_at(s[C], yo)

# Split the reduction axis and reorder so ki can be unrolled and yc vectorized
xc, yc = s[CC].op.axis
k, = s[CC].op.reduce_axis
ko, ki = s[CC].split(k, factor=4)
s[CC].reorder(ko, xc, ki, yc)
s[CC].unroll(ki)
s[CC].vectorize(yc)

# Parallelize over the outer row tiles of C
s[C].parallel(xo)

# Vectorize and parallelize the packing of B as well
x, y, z = s[pB].op.axis
s[pB].vectorize(z)
s[pB].parallel(x)