@ibeltagy
Created March 20, 2020 23:36
Lowering:
=====================
// attr [R] storage_scope = "global"
allocate R[float32 * ((bsz*d1)*d2)]
produce R {
  // attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = bsz
  // attr [R.local] storage_scope = "local"
  allocate R.local[float32 * 64]
  // attr [A.shared] storage_scope = "shared"
  allocate A.shared[float32 * 512]
  // attr [B.shared] storage_scope = "shared"
  allocate B.shared[float32 * 512]
  // attr [A.shared.local] storage_scope = "local"
  allocate A.shared.local[float32 * 8]
  // attr [B.shared.local] storage_scope = "local"
  allocate B.shared.local[float32 * 8]
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = floordiv((d1 + 63), 64)
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = floordiv((d2 + 63), 64)
  // attr [iter_var(threadIdx.y, range(min=0, ext=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, range(min=0, ext=8), threadIdx.x)] thread_extent = 8
  produce R.local {
    for (i.c.init, 0, 8) {
      for (j.c.init, 0, 8) {
        R.local[((i.c.init*8) + j.c.init)] = 0f
      }
    }
    for (k.outer, 0, floordiv((d3 + 7), 8)) {
      produce A.shared {
        for (ax1.inner, 0, 8) {
          if (likely(((((blockIdx.y*64) + (threadIdx.y*8)) + ax1.inner) < d1))) {
            if (likely((((k.outer*8) + threadIdx.x) < d3))) {
              A.shared[(((threadIdx.y*64) + (ax1.inner*8)) + threadIdx.x)] = A[(((blockIdx.z*stride) + ((((blockIdx.y*64) + (threadIdx.y*8)) + ax1.inner)*stride)) + (((k.outer*8) + threadIdx.x)*stride))]
            }
          }
        }
      }
      produce B.shared {
        for (ax1.inner, 0, 8) {
          if (likely(((((blockIdx.x*64) + (threadIdx.y*8)) + ax1.inner) < d2))) {
            if (likely((((k.outer*8) + threadIdx.x) < d3))) {
              B.shared[(((threadIdx.y*64) + (ax1.inner*8)) + threadIdx.x)] = B[(((blockIdx.z*stride) + ((((blockIdx.x*64) + (threadIdx.y*8)) + ax1.inner)*stride)) + (((k.outer*8) + threadIdx.x)*stride))]
            }
          }
        }
      }
      for (k.inner, 0, 8) {
        produce A.shared.local {
          for (ax1, 0, 8) {
            if (likely(((((blockIdx.y*64) + (threadIdx.y*8)) + ax1) < d1))) {
              if (likely((((k.outer*8) + k.inner) < d3))) {
                A.shared.local[ax1] = A.shared[(((threadIdx.y*64) + (ax1*8)) + k.inner)]
              }
            }
          }
        }
        produce B.shared.local {
          for (ax1, 0, 8) {
            if (likely(((((blockIdx.x*64) + (threadIdx.x*8)) + ax1) < d2))) {
              if (likely((((k.outer*8) + k.inner) < d3))) {
                B.shared.local[ax1] = B.shared[(((threadIdx.x*64) + (ax1*8)) + k.inner)]
              }
            }
          }
        }
        for (i.c, 0, 8) {
          for (j.c, 0, 8) {
            if (likely((((k.outer*8) + k.inner) < d3))) {
              if (likely(((((blockIdx.y*64) + (threadIdx.y*8)) + i.c) < d1))) {
                if (likely(((((blockIdx.x*64) + (threadIdx.x*8)) + j.c) < d2))) {
                  if (likely((((k.outer*8) + k.inner) < d3))) {
                    R.local[((i.c*8) + j.c)] = (R.local[((i.c*8) + j.c)] + (A.shared.local[i.c]*B.shared.local[j.c]))
                  }
                }
              }
            }
          }
        }
      }
    }
  }
  for (i.inner.inner, 0, 8) {
    for (j.inner.inner, 0, 8) {
      if (likely(((((blockIdx.y*64) + (threadIdx.y*8)) + i.inner.inner) < d1))) {
        if (likely(((((blockIdx.x*64) + (threadIdx.x*8)) + j.inner.inner) < d2))) {
          if (likely(((((blockIdx.y*64) + (threadIdx.y*8)) + i.inner.inner) < d1))) {
            if (likely(((((blockIdx.x*64) + (threadIdx.x*8)) + j.inner.inner) < d2))) {
              R[((((blockIdx.x*64) + (threadIdx.x*8)) + (((((blockIdx.y*64) + (threadIdx.y*8)) + (blockIdx.z*d1)) + i.inner.inner)*d2)) + j.inner.inner)] = R.local[((i.inner.inner*8) + j.inner.inner)]
            }
          }
        }
      }
    }
  }
}
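
The lowering above is the familiar tiled batched-matmul pattern: blockIdx.z walks the batch, each thread block owns a 64x64 tile of R, each of the 8x8 threads accumulates an 8x8 micro-tile in R.local, and 64x8 slices of A and B are staged through shared memory every 8 reduction steps. The sketch below is a hedged reconstruction of a te schedule that lowers to IR of this shape; it is not the gist's actual scripts/band_mm_minimal.py, and the axis names, split factors, and the assumption that bsz, d1, d2, d3 arrive as symbolic te.var extents are inferred from the dump.

import tvm
from tvm import te

def _compile_function_basic(bsz, d1, d2, d3):
    # bsz, d1, d2, d3 are assumed to be symbolic dims (te.var); that is what
    # produces the floordiv(...) grid extents and likely(...) guards in the dump.
    A = te.placeholder((bsz, d1, d3), name='A')
    B = te.placeholder((bsz, d2, d3), name='B')
    k = te.reduce_axis((0, d3), name='k')
    R = te.compute((bsz, d1, d2),
                   lambda b, i, j: te.sum(A[b, i, k] * B[b, j, k], axis=k),
                   name='R')

    s = te.create_schedule(R.op)
    RL = s.cache_write(R, 'local')         # R.local: 8x8 accumulator per thread
    AA = s.cache_read(A, 'shared', [RL])   # A.shared: 64x8 tile per block
    BB = s.cache_read(B, 'shared', [RL])   # B.shared: 64x8 tile per block
    AL = s.cache_read(AA, 'local', [RL])   # A.shared.local: 8-element strip
    BL = s.cache_read(BB, 'local', [RL])   # B.shared.local: 8-element strip

    # One batch element per blockIdx.z, a 64x64 output tile per (blockIdx.y,
    # blockIdx.x), and an 8x8 micro-tile per (threadIdx.y, threadIdx.x).
    b, i, j = s[R].op.axis
    by, i = s[R].split(i, factor=64)
    bx, j = s[R].split(j, factor=64)
    ty, i = s[R].split(i, factor=8)
    tx, j = s[R].split(j, factor=8)
    s[R].reorder(b, by, bx, ty, tx, i, j)
    s[R].bind(b, te.thread_axis('blockIdx.z'))
    s[R].bind(by, te.thread_axis('blockIdx.y'))
    s[R].bind(bx, te.thread_axis('blockIdx.x'))
    s[R].bind(ty, te.thread_axis('threadIdx.y'))
    s[R].bind(tx, te.thread_axis('threadIdx.x'))

    # Accumulate in registers; split the reduction into k.outer (shared-memory
    # staging) and k.inner (register reuse).
    s[RL].compute_at(s[R], tx)
    _, ic, jc = s[RL].op.axis
    ko, ki = s[RL].split(s[RL].op.reduce_axis[0], factor=8)
    s[RL].reorder(ko, ki, ic, jc)

    s[AA].compute_at(s[RL], ko)
    s[BB].compute_at(s[RL], ko)
    s[AL].compute_at(s[RL], ki)
    s[BL].compute_at(s[RL], ki)

    # Cooperative shared-memory loads: the 8x8 thread grid copies a 64x8 tile,
    # each thread handling 8 rows (ax1.inner in the dump) of one column.
    for stage in (AA, BB):
        _, y, x = s[stage].op.axis
        yo, yi = s[stage].split(y, factor=8)
        s[stage].bind(yo, te.thread_axis('threadIdx.y'))
        s[stage].bind(x, te.thread_axis('threadIdx.x'))

    return tvm.build(s, [A, B, R], target='cuda', target_host='llvm', name='mm1')

With symbolic (or non-tile-aligned) extents, every split adds the likely(...) bound checks seen above, and that is what the build trips over next:
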
Traceback (most recent call last):
  File "scripts/band_mm_minimal.py", line 166, in <module>
    mm_fun_basic = _compile_function_basic(bsz, d1, d2, d3)
  File "scripts/band_mm_minimal.py", line 132, in _compile_function_basic
    return tvm.build(s, [A, B, R], target='cuda', target_host='llvm', name='mm1')
  File "/usr/tvm/python/tvm/driver/build_module.py", line 407, in build
    fhost, mdev = _build_for_device(flist, tar, target_host)
  File "/usr/tvm/python/tvm/driver/build_module.py", line 236, in _build_for_device
    func = ir_pass.ThreadSync(func, "shared")
  File "/usr/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) /usr/tvm/build/libtvm.so(tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const+0x57) [0x7f3fd992ef17]
[bt] (7) /usr/tvm/build/libtvm.so(tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)+0x7d) [0x7f3fd9b926ad]
[bt] (6) /usr/tvm/build/libtvm.so(tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const+0x57) [0x7f3fd992ef17]
[bt] (5) /usr/tvm/build/libtvm.so(tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const+0x57) [0x7f3fd992ef17]
[bt] (4) /usr/tvm/build/libtvm.so(tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)+0x7d) [0x7f3fd9b926ad]
[bt] (3) /usr/tvm/build/libtvm.so(tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const+0x57) [0x7f3fd992ef17]
[bt] (2) /usr/tvm/build/libtvm.so(tvm::tir::StorageAccessVisitor::VisitStmt_(tvm::tir::ForNode const*)+0x118) [0x7f3fd9a8d1f8]
[bt] (1) /usr/tvm/build/libtvm.so(tvm::tir::ThreadSyncPlanner::Summarize(std::vector<tvm::tir::StorageAccessVisitor::StmtEntry, std::allocator<tvm::tir::StorageAccessVisitor::StmtEntry> >, tvm::tir::ForNode const*)+0x6ae) [0x7f3fd9b362ee]
[bt] (0) /usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f3fd984f642]
File "/usr/tvm/src/tir/pass/storage_sync.cc", line 93
TVMError: Check failed: condition_counter() == 0 (1 vs. 0) : Cannot insert syncs inside condition
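
The failing check in storage_sync.cc (condition_counter() == 0) means the ThreadSync pass wanted to place __syncthreads() barriers around the A.shared and B.shared producers but found them nested under the likely(...) guards, and it refuses to put a barrier inside a conditional that not every thread is guaranteed to reach. Those guards exist only because the split extents (d1 and d2 by 64, d3 by 8) are not known to be multiples of the tile factors. A hedged workaround, not necessarily what was done here: build with concrete, tile-aligned dimensions (or pad the inputs up to such sizes) so the bound checks fold away and the barriers can be hoisted to loop level.

# Illustrative only: hypothetical tile-aligned sizes (multiples of 64 for d1/d2,
# of 8 for d3); with these, the likely(...) guards disappear from the lowering
# and ThreadSync can insert the shared-memory barriers.
bsz, d1, d2, d3 = 2, 512, 512, 64
mm_fun_basic = _compile_function_basic(bsz, d1, d2, d3)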