import tvm
tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
tz = tvm.thread_axis("threadIdx.z")
bx = tvm.thread_axis("blockIdx.x")
by = tvm.thread_axis("blockIdx.y")
bz = tvm.thread_axis("blockIdx.z")
M = 64
N = 64
K = 1024
k = tvm.reduce_axis((0, K), 'k')
A = tvm.placeholder((M, K), name='A')
B = tvm.placeholder((N, K), name='B')
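# C[m, n] = sum over k of A[m, k] * B[n, k], i.e. C = A @ B^T; both operands
# keep K as their innermost (fastest-moving) dimension.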
C = tvm.compute(
    (M, N),
    lambda m, n: tvm.sum(A[m, k] * B[n, k], axis=k),
    name='C'
)
s = tvm.create_schedule(C.op)
deps = [A, B, C]
print(tvm.lower(s, deps, simple_mode=True))
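# The default schedule above is plain nested serial loops; printing it first
# gives a baseline to compare against the scheduled version lowered below.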
"""
Each block will:
Load from A [m_tile*m_t_tile, k_tile*k_unroll]
Load from B [n_tile*n_t_tile, k_tile*k_unroll]
produce C results of [m_tile*m_t_tile, n_tile*n_t_tile]
Each thread will:
Load from A[m_t_tile, k_unroll]
Load from B[n_t_tile, k_unroll]
Produce C [m_t_tile, n_t_tile]
"""
m_tile = 2  # Maps to threadIdx.z
n_tile = 2  # Maps to threadIdx.y
k_tile = 2  # Maps to threadIdx.x
# Tile for each thread
m_t_tile = 2
n_t_tile = 2
k_unroll = 4  # Number of elements per vectorized load
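# Worked launch configuration from the values above (M = N = 64):
#   each block produces a (m_tile*m_t_tile) x (n_tile*n_t_tile) = 4 x 4 tile of C,
#   so the grid is (64/4) x (64/4) = 16 x 16 blocks;
#   each block runs k_tile * n_tile * m_tile = 2 * 2 * 2 = 8 threads (threadIdx.x/y/z);
#   per reduction step a block reads a 4 x (k_tile*k_unroll) = 4 x 8 slice from each of A and B.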
s = tvm.create_schedule(C.op)
# C is the final reduction stage; it is responsible for the
# cross-thread reduction.
k, = s[C].op.reduce_axis
ko, ki = s[C].split(k, factor=k_tile*k_unroll)
kio, kii = s[C].split(ki, factor=k_unroll)
# This is the second reduction stage. Each thread sums its
# local values; it has k_unroll values to reduce.
s[C].set_store_predicate(tx.var.equal(0))
CF = s.rfactor(C, kio, -1)
s[CF].set_scope("shared")
# First reduction stage. This stage loads vector values from A/B
# and reduces them across the k dimension, leaving k_unroll
# partial sums per thread.
ko, kii = s[CF].op.reduce_axis
CF2 = s.rfactor(CF, kii, -1)
s[CF2].set_scope("local")
#Set up our block and thread tiling
m, n = s[C].op.axis
kio, = s[C].op.reduce_axis
mo, no, mi, ni = s[C].tile(m, n, m_tile*m_t_tile, n_tile*n_t_tile)
mio, nio, mii, nii = s[C].tile(mi, ni, m_t_tile, n_t_tile)
# mo, no are the outer dimensions to loop over; they get bound to grid dimensions.
# mio, nio are the block tile dimensions.
# mii, nii are the thread tile dimensions.
# kio is the innermost dimension, so we do one cross-thread reduction at a time
# (see the loop-nest sketch after the thread bindings below).
s[C].reorder(mo, no, mio, nio, mii, nii, kio)
# The loop level at which each block's work is rooted
block_point = nio
# Anchor (compute_at) the other reduction stages at this point
s[CF].compute_at(s[C], block_point)
s[CF2].compute_at(s[C], block_point)
s[C].bind(mo, by)
s[C].bind(no, bx)
s[C].bind(mio, tz)
s[C].bind(nio, ty)
s[C].bind(kio, tx)
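"""
A sketch of the loop structure of the C stage after the reorder and bindings
above (illustrative only, not the exact lowered IR):

  mo  -> blockIdx.y    picks the block's 4-row slice of C
  no  -> blockIdx.x    picks the block's 4-column slice of C
  mio -> threadIdx.z   picks which m_t_tile rows of the tile this thread covers
  nio -> threadIdx.y   picks which n_t_tile columns of the tile this thread covers
  mii, nii             serial loops over the thread's 2 x 2 outputs
  kio -> threadIdx.x   cross-thread reduction over the partial sums held in CF
"""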
mii, nii, kio = s[CF].op.axis
ko, = s[CF].op.reduce_axis
s[CF].bind(kio, tx)
mii, nii, kio, kii = s[CF2].op.axis
ko, = s[CF2].op.reduce_axis
s[CF2].reorder(ko, kio, mii, nii, kii)
s[CF2].compute_at(s[C], block_point)
s[CF2].vectorize(kii)
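# The vectorize above is over kii, whose extent is k_unroll = 4; this should
# allow the CUDA backend to emit vectorized (e.g. float4) accesses for CF2's
# innermost dimension.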
#s[CF2].bind(ko, tx) can't do this
"""
There's also an issue here if the above is fixed with
https://github.com/apache/incubator-tvm/pull/4270
#rA = s.cache_read(A, "shared", CF2)
#s[rA].compute_at(s[CF2], ko)
"""
print(tvm.lower(s, deps, simple_mode=True))
fcuda = tvm.build(s, deps, "cuda")
print(fcuda.imported_modules[0].get_source())
"""
#Validation but I wrote it in PyTorch to check perf against.
import torch
import numpy
ctx = tvm.context('cuda', 0)
_A = torch.FloatTensor(M, K).uniform_()
_B = torch.FloatTensor(N, K).uniform_()
tvm_out = tvm.nd.array(numpy.zeros( (M, N), dtype=C.dtype), ctx)
torch_tensors = [_A, _B]
np_tensors = [tensor.numpy() for tensor in torch_tensors]
tvm_tensors = [tvm.nd.array(tensor, ctx) for tensor in np_tensors]
torch_tensors = [tensor.cuda() for tensor in torch_tensors]
_A, _B = torch_tensors
torch_out = torch.mm(_A, _B.transpose(0, 1))
tvm_tensors.append(tvm_out)
fcuda(*tvm_tensors)
tvm.testing.assert_allclose(tvm_out.asnumpy(), torch_out.cpu().numpy(), rtol=1e-6, atol=.01)
print("Finished")
"""