-
-
Save KireinaHoro/93a934edc3e472ccc7c592b88e915a59 to your computer and use it in GitHub Desktop.
Matrix multiplication tensorize
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Traceback (most recent call last): | |
File "gemm.py", line 106, in <module> | |
print(tvm.lower(s, [A, B, C], simple_mode=True)) | |
File "/home/jsteward/work/tvm/python/tvm/build_module.py", line 392, in lower | |
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) | |
File "tvm/_ffi/_cython/./function.pxi", line 304, in tvm._ffi._cy3.core.FunctionBase.__call__ | |
File "tvm/_ffi/_cython/./function.pxi", line 249, in tvm._ffi._cy3.core.FuncCall | |
File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL | |
tvm._ffi.base.TVMError: Traceback (most recent call last): | |
[bt] (8) /home/jsteward/work/tvm/build/libtvm.so(tvm::NodeFunctor<tvm::Stmt (tvm::runtime::ObjectRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::operator()(tvm::runtime::ObjectRef const&, tvm::Stmt const&, tvm::ir::IRMutator*) const+0x62) [0x7f62d4095212] | |
[bt] (7) /home/jsteward/work/tvm/build/libtvm.so(+0x6642eb) [0x7f62d42dd2eb] | |
[bt] (6) /home/jsteward/work/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate_(tvm::ir::For const*, tvm::Stmt const&)+0xb9) [0x7f62d42df4c9] | |
[bt] (5) /home/jsteward/work/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate(tvm::Stmt)+0x5b) [0x7f62d409537b] | |
[bt] (4) /home/jsteward/work/tvm/build/libtvm.so(tvm::NodeFunctor<tvm::Stmt (tvm::runtime::ObjectRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::operator()(tvm::runtime::ObjectRef const&, tvm::Stmt const&, tvm::ir::IRMutator*) const+0x62) [0x7f62d4095212] | |
[bt] (3) /home/jsteward/work/tvm/build/libtvm.so(+0x66424b) [0x7f62d42dd24b] | |
[bt] (2) /home/jsteward/work/tvm/build/libtvm.so(tvm::ir::StorageFlattener::Mutate_(tvm::ir::AttrStmt const*, tvm::Stmt const&)+0x86e) [0x7f62d438d45e] | |
[bt] (1) /home/jsteward/work/tvm/build/libtvm.so(tvm::ir::StorageFlattener::HandleBufferBindScope(tvm::ir::AttrStmt const*)+0xafa) [0x7f62d438a79a] | |
[bt] (0) /home/jsteward/work/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x43) [0x7f62d4024013] | |
File "../src/pass/storage_flatten.cc", line 432 | |
TVMError: Check failed: slice->strides.size() == 0U (2 vs. 0) : Trying to bind compact buffer to strided one strides=[512, 1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm | |
from tvm import te | |
import numpy as np | |
#target = "llvm -device=riscv_cpu -target=riscv64-unknown-linux-gnu -mfloat-abi=soft" | |
target = "llvm" | |
dtype = "int8" | |
DIM = 8 | |
N, M, L = 512, 512, 512 | |
A = tvm.placeholder((N, L), name='A', dtype=dtype) | |
B = tvm.placeholder((L, M), name='B', dtype=dtype) | |
k = tvm.reduce_axis((0, L), name='k') | |
C = tvm.compute((N, M), lambda i, j: | |
tvm.sum(A[i, k] * B[k, j], axis=k), name='C') | |
s = tvm.create_schedule(C.op) | |
factor = 16 | |
x, y = C.op.axis | |
z, = C.op.reduce_axis | |
yo, yi = s[C].split(y, factor=factor) | |
xo, xi = s[C].split(x, factor=factor) | |
zo, zi = s[C].split(z, factor=factor) | |
s[C].reorder(xo, yo, zo, xi, yi, zi) | |
def intrinsic_gemm(i, j, k): | |
""" | |
(i, k) * (k, j) | |
""" | |
a = tvm.placeholder((i, k), name='a', dtype=dtype) | |
b = tvm.placeholder((k, j), name='b', dtype=dtype) | |
kk = tvm.reduce_axis((0, k), name='k') | |
c = tvm.compute((i, j), lambda ii, jj: | |
tvm.sum(A[ii, kk] * B[kk, jj], axis=kk), name='c') | |
strideA = tvm.var("sA") | |
Ab = tvm.decl_buffer(a.shape, a.dtype, | |
name="A", | |
offset_factor=1, | |
strides=[strideA, 1]) | |
strideB = tvm.var("sB") | |
Bb = tvm.decl_buffer(b.shape, b.dtype, | |
name="B", | |
offset_factor=1, | |
strides=[strideB, 1]) | |
strideC = tvm.var("sC") | |
Cb = tvm.decl_buffer(c.shape, c.dtype, | |
name="C", | |
offset_factor=1, | |
strides=[strideC, 1]) | |
# Reduce-update pattern | |
def intrin_func(ins, outs): | |
aa, bb = ins | |
cc, = outs | |
print(aa.strides, bb.strides, cc.strides) | |
def _body(): | |
ib = tvm.ir_builder.create() | |
ib.emit(tvm.call_extern("int32", "matmul_kernel", | |
aa.access_ptr("r"), | |
bb.access_ptr("r"), | |
cc.access_ptr("r"), | |
cc.access_ptr("w"), | |
i // DIM, j // DIM, k // DIM, | |
i % DIM, j % DIM, k % DIM, | |
strideA, strideB, strideC, strideC, | |
False, False)) | |
return ib.get() | |
def _reset(): | |
ib = tvm.ir_builder.create() | |
ib.emit(tvm.call_extern("int32", "matmul_reset", | |
cc.access_ptr("w"), | |
i // DIM, j // DIM, | |
i % DIM, j % DIM, | |
strideC)) | |
return ib.get() | |
# standalone (without reduce axis split), reset, update | |
return _body(), _reset(), _body() | |
with tvm.build_config(offset_factor=1): | |
return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}, name="sp_gemm") | |
gemm = intrinsic_gemm(factor, factor, factor) | |
s[C].tensorize(xi, gemm) | |
# Implementation for the gemv_update function - C code | |
def intrinsic_impls(filename): | |
with open(filename, "r") as f: | |
cc_code = f.read() | |
from tvm.contrib import util, clang | |
clang.find_clang(required=True) | |
temp = util.tempdir() | |
ll_path = temp.relpath("temp.ll") | |
# Create LLVM ir from C++ source code | |
import os | |
ll_code = clang.create_llvm(cc_code, output=ll_path, options=[ | |
"-O3", | |
f"-I{os.path.dirname(os.path.realpath(filename))}" | |
]) | |
return ll_code | |
s[C].pragma(xo, "import_llvm", intrinsic_impls("kernel-cpu.c")) | |
#s[C].pragma(yo, "epilogue", "do_fence") | |
print(tvm.lower(s, [A, B, C], simple_mode=True)) | |
func = tvm.build(s, [A, B, C], target=target, name="gemm") | |
out_llvm_ir = "gemm.ll" | |
with open(out_llvm_ir, "w") as f: | |
f.write(func.get_source()) | |
print(f"Written LLVM IR to {out_llvm_ir}.") | |
if target == "llvm": | |
from topi.util import get_const_tuple | |
dtype = A.dtype | |
ctx = tvm.context("cpu", 0) | |
ti = np.iinfo("int8") | |
a = np.random.uniform(low=ti.min, high=ti.max, size=get_const_tuple(A.shape)).astype(dtype) | |
b = np.random.uniform(low=ti.min, high=ti.max, size=get_const_tuple(B.shape)).astype(dtype) | |
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), ctx) | |
func(tvm.nd.array(a, ctx), tvm.nd.array(b, ctx), c) | |
np.testing.assert_allclose(c.asnumpy(), np.dot(a, b), rtol=0) | |
print("Kernel executed and passed test.") | |
else: | |
print(f"Not running kernel for target '{target}'.'") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "gemmini_params.h" | |
#include <assert.h> | |
#include <stdbool.h> | |
#include <stddef.h> | |
// fence | |
#define gemmini_fence() asm volatile("fence") | |
void do_fence() { gemmini_fence(); } | |
// CPU version of the kernel that runs on the device. | |
// Kernel size requirements for scratchpad size are enforced. | |
// C(i, j) = A(i, k) * B(k, j) + D(i, j) | |
// i = I * DIM, j = j * DIM, k = k * DIM | |
// A_row_len: row length for A, used for stride | |
// | |
// this function is expected to be called from CPU, so no need for padding | |
int32_t matmul_reset(elem_t *C, size_t I, size_t J, size_t pad_I, size_t pad_J, | |
size_t C_row_len) { | |
for (int i = 0; i < I * DIM - pad_I; ++i) { | |
for (int j = 0; j < J * DIM - pad_J; ++j) { | |
C[i * C_row_len + j] = 0; | |
} | |
} | |
} | |
int32_t matmul_kernel(const elem_t *A, const elem_t *B, const acc_t *D, | |
elem_t *C, size_t I, size_t J, size_t K, size_t pad_I, | |
size_t pad_J, size_t pad_K, size_t A_row_len, | |
size_t B_row_len, size_t D_row_len, size_t C_row_len, | |
bool no_bias, bool repeating_bias) { | |
assert(!repeating_bias && "repeating bias not supported"); | |
assert((no_bias || D) && "bias requested but D is NULL"); | |
if (no_bias) { | |
for (int i = 0; i < I * DIM - pad_I; ++i) { | |
for (int j = 0; j < J * DIM - pad_J; ++j) { | |
C[i * C_row_len + j] = 0; | |
for (int k = 0; k < K * DIM - pad_K; ++k) { | |
C[i * C_row_len + j] += A[i * A_row_len + k] * B[k * B_row_len + j]; | |
} | |
} | |
} | |
} else { | |
for (int i = 0; i < I * DIM - pad_I; ++i) { | |
for (int j = 0; j < J * DIM - pad_J; ++j) { | |
C[i * C_row_len + j] = D[i * D_row_len + j]; | |
for (int k = 0; k < K * DIM - pad_K; ++k) { | |
C[i * C_row_len + j] += A[i * A_row_len + k] * B[k * B_row_len + j]; | |
} | |
} | |
} | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment