@ibeltagy — last active March 19, 2020
Batched matrix multiplication, R[l, i, j] = sum over k of A[l, i, k] * B[l, j, k] (i.e. A @ B^T per batch), written as a TVM tensor-expression kernel, compiled for CUDA, wrapped as a PyTorch function via DLPack, and timed against torch.bmm.
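For reference, a minimal PyTorch equivalent of the computation below (an illustrative sketch using the same shapes as the benchmark; not part of the original gist):

import torch

A = torch.randn(24, 16384, 64, device='cuda')   # (bsz, d1, d3)
B = torch.randn(24, 512, 64, device='cuda')     # (bsz, d2, d3)
R_ref = torch.einsum('bik,bjk->bij', A, B)      # same result as A.bmm(B.transpose(1, 2))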
import time

import torch
import tvm
from tvm import te
from tvm.contrib import dlpack


def _compile_function(b0: int = 4, b1: int = 4, b2: int = 16):
    """Compile a CUDA kernel computing R[l, i, j] = sum_k A[l, i, k] * B[l, j, k]."""
    bsz = te.var('bsz')
    d1 = te.var('d1')
    d2 = te.var('d2')
    d3 = te.var('d3')
    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')  # first tensor
    B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')  # second tensor
    k = te.reduce_axis((0, d3), name='k')  # dimension to sum over
    output_shape = (bsz, d1, d2)  # shape of the result tensor
    algorithm = lambda l, i, j: te.sum(A[l, i, k] * B[l, j, k], axis=k)
    R = te.compute(output_shape, algorithm, name='R')
    s = te.create_schedule(R.op)
    print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [A, B], simple_mode=True)))

    # split the reduction axis and factor the partial sums into a separate stage
    ko, ki = s[R].split(R.op.reduce_axis[0], factor=b0)
    RF = s.rfactor(R, ki)

    # split the two output axes and bind them to CUDA blocks and threads
    j_outer, j_inner = s[R].split(s[R].op.axis[1], factor=b1)
    i_outer, i_inner = s[R].split(s[R].op.axis[2], factor=b2)
    s[R].bind(j_outer, te.thread_axis("blockIdx.x"))
    s[R].bind(j_inner, te.thread_axis("threadIdx.y"))
    s[R].bind(i_outer, te.thread_axis("blockIdx.y"))
    s[R].bind(i_inner, te.thread_axis("threadIdx.z"))

    # cross-thread reduction over threadIdx.x; only thread 0 stores the final value
    tx = te.thread_axis("threadIdx.x")
    s[R].bind(s[R].op.reduce_axis[0], tx)
    s[RF].compute_at(s[R], s[R].op.reduce_axis[0])
    s[R].set_store_predicate(tx.var.equal(0))
    print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [A, B], simple_mode=True)))

    return tvm.build(s, [A, B, R], target='cuda', target_host='llvm', name='mm')
if __name__ == "__main__":
    mm_fun = _compile_function()
    mm_fun_pytorch = dlpack.to_pytorch_func(mm_fun)  # wrap it as a pytorch function

    bsz = 24
    d1 = 16384
    d2 = 512
    d3 = 64
    time1 = time2 = 0
    A = torch.randn(bsz, d1, d3, device='cuda')
    B = torch.randn(bsz, d2, d3, device='cuda')
    R = A.new_empty(bsz, d1, d2)  # allocate memory for the result tensor

    for i in range(50):
        # time the TVM kernel
        torch.cuda.synchronize()
        start = time.time()
        mm_fun_pytorch(A, B, R)
        torch.cuda.synchronize()
        time1 += time.time() - start

        # time the equivalent PyTorch batched matmul
        torch.cuda.synchronize()
        start = time.time()
        A.bmm(B.transpose(dim0=1, dim1=2))
        torch.cuda.synchronize()
        time2 += time.time() - start

        if i < 3:  # first three calls are usually slower than the rest; discard them
            time1 = time2 = 0
        else:
            # speedup of the TVM kernel over PyTorch (>1 means TVM is faster)
            print('TVM: {0:.5f}s, PyTorch: {1:.5f}s, Speedup: {2:.5f}x'.format(time1, time2, time2 / time1))
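    # Optional sanity check (a suggested addition, not in the original gist): the
    # TVM kernel and PyTorch's bmm should agree up to floating-point tolerance.
    expected = A.bmm(B.transpose(dim0=1, dim1=2))
    print('results match:', torch.allclose(R, expected, atol=1e-3))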