import tvm

tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
tz = tvm.thread_axis("threadIdx.z")
bx = tvm.thread_axis("blockIdx.x")
by = tvm.thread_axis("blockIdx.y")
bz = tvm.thread_axis("blockIdx.z")

M = 64
N = 64
K = 1024

k = tvm.reduce_axis((0, K), 'k')
A = tvm.placeholder((M, K), name='A')
B = tvm.placeholder((N, K), name='B')
C = tvm.compute(
    (M, N),
    lambda m, n:
        tvm.sum(
            A[m, k]
            * B[n, k],
            axis=k),
    name='C'
)

s = tvm.create_schedule(C.op)
deps = [A, B, C]
print(tvm.lower(s, deps, simple_mode=True))
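#For reference, the unscheduled lowering above is just the naive triple
#loop, roughly (a sketch, not the exact printed IR):
#  for m in 0..M:
#    for n in 0..N:
#      C[m, n] = 0
#      for k in 0..K:
#        C[m, n] += A[m, k] * B[n, k]
#Note B is laid out [N, K], so this computes A @ B^T.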
""" | |
Each block will: | |
Load from A [m_tile*m_t_tile, k_tile*k_unroll] | |
Load from B [n_tile*n_t_tile, k_tile*k_unroll] | |
produce C results of [m_tile*m_t_tile, n_tile*n_t_tile] | |
Each thread will: | |
Load from A[m_t_tile, k_unroll] | |
Load from B[n_t_tile, k_unroll] | |
Produce C [m_t_tile, n_t_tile] | |
""" | |
m_tile = 2 #Maps to threadIdx.z
n_tile = 2 #Maps to threadIdx.y
k_tile = 2 #Maps to threadIdx.x
#Tile for each thread
m_t_tile = 2
n_t_tile = 2
k_unroll = 4 #The number of elements to vectorize loads over
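#With these values each block launches (tx, ty, tz) = (k_tile, n_tile, m_tile)
#= (2, 2, 2) = 8 threads, and the grid (from the mo/no binds below) is
#(N / (n_tile*n_t_tile)) x (M / (m_tile*m_t_tile)) = 16 x 16 blocks.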

s = tvm.create_schedule(C.op)
#C is the final reduction stage; it's responsible for the
#cross-thread reduction.
k, = s[C].op.reduce_axis
ko, ki = s[C].split(k, factor=k_tile*k_unroll)
kio, kii = s[C].split(ki, factor=k_unroll)
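#After the two splits the reduction index decomposes as
#  k = ko * (k_tile*k_unroll) + kio * k_unroll + kii
#so ko walks K in chunks of k_tile*k_unroll = 8, kio selects one of the
#k_tile = 2 threads, and kii indexes the k_unroll = 4 vectorized elements.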
#This is the second reduction stage. Each thread sums its
#local values; it has k_unroll values to reduce.
s[C].set_store_predicate(tx.var.equal(0))
CF = s.rfactor(C, kio, -1)
s[CF].set_scope("shared")
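#Sketch of what this rfactor did (factor_axis = -1 appends kio as the
#last spatial axis):
#  CF[m, n, kio] = sum over (ko, kii) of A[..] * B[..]
#  C[m, n]       = sum over kio of CF[m, n, kio]
#The store predicate above makes only threadIdx.x == 0 write C out.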
#First reduction stage. This stage loads vector values from A/B
#and reduces that vector across the k dimension. Result is
#k_unroll values.
ko, kii = s[CF].op.reduce_axis
CF2 = s.rfactor(CF, kii, -1)
s[CF2].set_scope("local")
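#Sketch of the second rfactor:
#  CF2[m, n, kio, kii] = sum over ko of A[..] * B[..]
#  CF[m, n, kio]       = sum over kii of CF2[m, n, kio, kii]
#CF2 stays per-thread in registers ("local"); CF is staged through
#shared memory so kio can be reduced across threads.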
#Set up our block and thread tiling
m, n = s[C].op.axis
kio, = s[C].op.reduce_axis
mo, no, mi, ni = s[C].tile(m, n, m_tile*m_t_tile, n_tile*n_t_tile)
mio, nio, mii, nii = s[C].tile(mi, ni, m_t_tile, n_t_tile)
#mo, no are outer dimensions to loop over; they get bound to grid dimensions
#mio, nio are the block tile dimensions
#mii, nii are the thread tile dimensions
#kio is the innermost dimension so we do one cross-thread reduction at a time
s[C].reorder(mo, no, mio, nio, mii, nii, kio)
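#Resulting loop nest over C (a sketch with the extents these tile sizes give):
#  mo, no in 16 x 16      -> bound to blockIdx.y/x below
#    mio, nio in 2 x 2    -> bound to threadIdx.z/y
#      mii, nii in 2 x 2  -> per-thread output tile
#        kio in 2         -> bound to threadIdx.x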
#Where does the block start?
block_point = nio
#Attach the other reduction stages at this point
s[CF].compute_at(s[C], block_point)
s[CF2].compute_at(s[C], block_point)
s[C].bind(mo, by)
s[C].bind(no, bx)
s[C].bind(mio, tz)
s[C].bind(nio, ty)
s[C].bind(kio, tx)
mii, nii, kio = s[CF].op.axis
ko, = s[CF].op.reduce_axis
s[CF].bind(kio, tx)
mii, nii, kio, kii = s[CF2].op.axis
ko, = s[CF2].op.reduce_axis
s[CF2].reorder(ko, kio, mii, nii, kii)
s[CF2].vectorize(kii)
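#With k_unroll = 4 float32 values, vectorizing kii should let the CUDA
#backend emit float4 loads here (assuming it can prove the accesses are
#contiguous and aligned).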
#s[CF2].bind(ko, tx) can't do this
""" | |
There's also an issue here if the above is fixed with | |
https://github.com/apache/incubator-tvm/pull/4270 | |
#rA = s.cache_read(A, "shared", CF2) | |
#s[rA].compute_at(s[CF2], ko) | |
""" | |
print(tvm.lower(s, deps, simple_mode=True))
fcuda = tvm.build(s, deps, "cuda")
print(fcuda.imported_modules[0].get_source())
""" | |
#Validation but I wrote it in PyTorch to check perf against. | |
import torch | |
import numpy | |
ctx = tvm.context('cuda', 0) | |
_A = torch.FloatTensor(M, K).uniform_() | |
_B = torch.FloatTensor(N, K).uniform_() | |
tvm_out = tvm.nd.array(numpy.zeros( (M, N), dtype=C.dtype), ctx) | |
torch_tensors = [_A, _B] | |
np_tensors = [tensor.numpy() for tensor in torch_tensors] | |
tvm_tensors = [tvm.nd.array(tensor, ctx) for tensor in np_tensors] | |
torch_tensors = [tensor.cuda() for tensor in torch_tensors] | |
_A, _B = torch_tensors | |
torch_out = torch.mm(_A, _B.transpose(0, 1)) | |
tvm_tensors.append(tvm_out) | |
fcuda(*tvm_tensors) | |
tvm.testing.assert_allclose(tvm_out.asnumpy(), torch_out.cpu().numpy(), rtol=1e-6, atol=.01) | |
print("Finished") | |
""" |