Last active
March 19, 2020 03:23
-
-
Save ibeltagy/61ccddd83fdd3956573289297efd7f13 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import torch
import tvm
from tvm import te
from tvm.contrib import dlpack
def _compile_function(b0: int = 4, b1: int = 4, b2: int = 16, verbose: bool = True):
    """Compile a CUDA kernel computing ``R[n, i, j] = sum_k A[n, i, k] * B[n, j, k]``.

    This is a batched ``A @ B^T``: ``A`` has shape ``(bsz, d1, d3)``, ``B`` has
    shape ``(bsz, d2, d3)`` and ``R`` has shape ``(bsz, d1, d2)``, all float32.
    Every dimension is a symbolic ``te.var``, so the compiled kernel is not
    specialized to one set of sizes.

    Schedule:
      * the reduction axis ``k`` is split by ``b0`` and rfactored; the
        remaining length-``b0`` reduction is bound to ``threadIdx.x`` so ``b0``
        threads cooperate on each output element;
      * the ``d1`` axis is tiled by ``b1`` (``blockIdx.x`` / ``threadIdx.y``);
      * the ``d2`` axis is tiled by ``b2`` (``blockIdx.y`` / ``threadIdx.z``);
      * the batch axis is not bound to any thread axis.
    A thread block therefore has ``b0 * b1 * b2`` threads.

    Args:
        b0: Split factor of the reduction axis (threads along ``threadIdx.x``).
        b1: Tile size along ``d1`` (threads along ``threadIdx.y``).
        b2: Tile size along ``d2`` (threads along ``threadIdx.z``).
        verbose: If True (the default), print the lowered IR before and after
            the GPU schedule is applied.

    Returns:
        The module produced by ``tvm.build``; its function ``mm`` takes
        ``(A, B, R)`` and writes the result into the caller-provided ``R``.

    Raises:
        ValueError: If a factor is not positive, or the thread block would
            exceed CUDA's per-block limits (1024 threads, ``threadIdx.z <= 64``).
    """
    for arg_name, factor in (('b0', b0), ('b1', b1), ('b2', b2)):
        if factor < 1:
            raise ValueError('{} must be a positive integer, got {}'.format(arg_name, factor))
    # Fail early with a clear message instead of an opaque launch failure.
    if b2 > 64 or b0 * b1 * b2 > 1024:
        raise ValueError(
            'thread block of b0*b1*b2 = {} threads with threadIdx.z = {} exceeds CUDA limits '
            '(<= 1024 threads per block, threadIdx.z <= 64)'.format(b0 * b1 * b2, b2))

    # NOTE(review): te.var defaults to int32, so flattened index arithmetic
    # presumably overflows once a tensor reaches 2**31 elements -- confirm
    # before using much larger shapes.
    bsz = te.var('bsz')
    d1 = te.var('d1')
    d2 = te.var('d2')
    d3 = te.var('d3')
    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')  # first tensor
    B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')  # second tensor
    k = te.reduce_axis((0, d3), name='k')  # shared last dimension, summed over

    def _dot(n, i, j):
        return te.sum(A[n, i, k] * B[n, j, k], axis=k)

    R = te.compute((bsz, d1, d2), _dot, name='R')
    s = te.create_schedule(R.op)
    # R is included in the argument list (as in tvm.build below) so the dumped
    # IR shows the same signature as the kernel that is actually built.
    if verbose:
        print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [A, B, R], simple_mode=True)))

    # Split k into (k_outer, k_inner) and rfactor k_inner: RF computes b0
    # partial sums per output element, and R reduces over those b0 partials.
    _, k_inner = s[R].split(R.op.reduce_axis[0], factor=b0)
    RF = s.rfactor(R, k_inner)

    # rfactor rewrites R's stage, so its axes are read from s[R].op from here on.
    i_outer, i_inner = s[R].split(s[R].op.axis[1], factor=b1)  # d1 axis
    j_outer, j_inner = s[R].split(s[R].op.axis[2], factor=b2)  # d2 axis
    s[R].bind(i_outer, te.thread_axis("blockIdx.x"))
    s[R].bind(i_inner, te.thread_axis("threadIdx.y"))
    s[R].bind(j_outer, te.thread_axis("blockIdx.y"))
    s[R].bind(j_inner, te.thread_axis("threadIdx.z"))

    # The final length-b0 reduction runs across threadIdx.x; each thread
    # computes its partial sum (RF) at that reduction point.
    tx = te.thread_axis("threadIdx.x")
    s[R].bind(s[R].op.reduce_axis[0], tx)
    s[RF].compute_at(s[R], s[R].op.reduce_axis[0])
    # Only the thread with threadIdx.x == 0 stores the reduced value.
    s[R].set_store_predicate(tx.var.equal(0))

    if verbose:
        print('Lowering with GPU splits: \n ===================== \n{}'.format(
            tvm.lower(s, [A, B, R], simple_mode=True)))
    return tvm.build(s, [A, B, R], target='cuda', target_host='llvm', name='mm')
def _cuda_time(fn):
    """Return the wall-clock seconds taken by one call to ``fn()``.

    The CUDA device is synchronized before the clock starts and before it
    stops, so earlier queued GPU work is excluded and the asynchronous work
    launched by ``fn`` is included.
    """
    torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution interval clock
    fn()
    torch.cuda.synchronize()
    return time.perf_counter() - start


def main():
    """Verify the TVM kernel against a float64 reference, then time it against torch.bmm.

    Raises:
        RuntimeError: If the TVM result for the checked batches differs from
            the reference by more than the tolerance.
    """
    warmup_iters = 3  # first calls are usually slower than the rest; not timed
    timed_iters = 47

    mm_fun = _compile_function()
    mm_fun_pytorch = dlpack.to_pytorch_func(mm_fun)  # wrap it as a pytorch function

    bsz = 24
    d1 = 16384
    d2 = 512
    d3 = 64
    A = torch.randn(bsz, d1, d3, device='cuda')
    B = torch.randn(bsz, d2, d3, device='cuda')
    B_t = B.transpose(1, 2)  # a view; created once, outside the timed region
    # Preallocate both outputs (bsz * d1 * d2 float32 each, ~0.8 GB at these
    # sizes) so neither implementation pays for output allocation while timed.
    R = A.new_empty(bsz, d1, d2)
    R_ref = A.new_empty(bsz, d1, d2)

    def run_tvm():
        mm_fun_pytorch(A, B, R)

    def run_torch():
        torch.bmm(A, B_t, out=R_ref)

    # The timing comparison is meaningless if the kernel is wrong. Check the
    # first and last batch against a float64 bmm. float64 keeps the reference
    # free of reduced-precision float32 matmul modes. The tolerance is loose
    # relative to float32 accumulation over d3 terms.
    run_tvm()
    torch.cuda.synchronize()
    check = [0, bsz - 1]
    expected = torch.bmm(A[check].double(), B[check].double().transpose(1, 2))
    max_err = (R[check].double() - expected).abs().max().item()
    if max_err > 1e-3:
        raise RuntimeError(
            'TVM kernel output differs from torch.bmm: max abs error {:.3e}'.format(max_err))

    for _ in range(warmup_iters):
        _cuda_time(run_tvm)
        _cuda_time(run_torch)

    tvm_time = torch_time = 0.0
    for _ in range(timed_iters):  # interleave the two so both see similar GPU state
        tvm_time += _cuda_time(run_tvm)
        torch_time += _cuda_time(run_torch)

    # The ratio is TVM time / PyTorch time: values < 1 mean TVM is faster.
    print('TVM: {0:.5f}s, PyTorch: {1:.5f}s over {2} calls, TVM/PyTorch time ratio: {3:.5f}x'.format(
        tvm_time, torch_time, timed_iters, tvm_time / torch_time))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment