Created
June 5, 2020 23:44
-
-
Save emfomenk/4bdc70908cc5c30ad2a97e5030a22eaf to your computer and use it in GitHub Desktop.
BF16 gemm-like API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert> | |
#include <cctype> | |
#include <cmath> | |
#include <cstdio> | |
#include <mutex> | |
#include <vector> | |
#include <type_traits> | |
#include "dnnl.hpp" | |
namespace proxy { | |
template <typename c_dt, bool beta_is_zero = true> | |
dnnl::status gemm_bf16bf16(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, | |
dnnl_dim_t K, float alpha, const void *A, dnnl_dim_t lda, | |
const void *B, dnnl_dim_t ldb, float beta, c_dt *C, dnnl_dim_t ldc) { | |
using namespace dnnl; | |
using dims = memory::dims; | |
if ((int)get_effective_cpu_isa() < (int)cpu_isa::avx512_core) | |
return status::unimplemented; | |
static engine eng; | |
static matmul matmul_p; | |
static std::once_flag initialized; | |
static_assert(std::is_same<c_dt, float>::value || | |
std::is_same<c_dt, void>::value, "expect float or void (bf16)"); | |
constexpr memory::data_type bf16 = memory::data_type::bf16; | |
constexpr memory::data_type c_data_type = | |
std::is_same<c_dt, float>::value ? memory::data_type::f32 : bf16; | |
std::call_once(initialized, [=] { | |
eng = engine(engine::kind::cpu, 0); | |
memory::dims rt_rt_dims = {DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL}; | |
memory::dims rt_1_dims = {DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL}; | |
memory::desc ab_md(rt_rt_dims, bf16, rt_rt_dims); | |
memory::desc c_md(rt_rt_dims, c_data_type, rt_1_dims); | |
primitive_attr attr; | |
attr.set_output_scales(/* mask */ 0, {DNNL_RUNTIME_F32_VAL}); | |
if (!beta_is_zero) { | |
assert(beta == 1.f); // current limitation | |
post_ops po; | |
po.append_sum(beta); | |
attr.set_post_ops(po); | |
} | |
matmul::desc matmul_d(ab_md, ab_md, c_md); | |
matmul::primitive_desc matmul_pd(matmul_d, attr, eng, true); | |
if (matmul_pd) matmul_p = matmul(matmul_pd); | |
}); | |
bool ok = (bool)matmul_p | |
&& (!beta_is_zero || beta == 0.f) | |
&& (!!beta_is_zero || beta == 1.f); | |
if (!ok) return status::runtime_error; | |
dims a_strides = tolower(transa) == 'n' ? dims {lda, 1} : dims {1, lda}; | |
dims b_strides = tolower(transb) == 'n' ? dims {ldb, 1} : dims {1, ldb}; | |
memory A_m({{M, K}, bf16, a_strides}, eng, (void *)A); | |
memory B_m({{K, N}, bf16, b_strides}, eng, (void *)B); | |
memory C_m({{M, N}, c_data_type, {ldc, 1}}, eng, (void *)C); | |
// Prepare oneDNN memory for alpha | |
memory alpha_m({{1}, memory::data_type::f32, {1}}, eng, &alpha); | |
stream s(eng); | |
matmul_p.execute(s, | |
{{DNNL_ARG_SRC, A_m}, {DNNL_ARG_WEIGHTS, B_m}, {DNNL_ARG_DST, C_m}, | |
{DNNL_ARG_ATTR_OUTPUT_SCALES, alpha_m}}); | |
s.wait(); | |
return status::success; | |
} | |
} | |
// bf16 x bf16 -> bf16 GEMM: C = alpha * op(A) * op(B).
// Convenience front-end over proxy::gemm_bf16bf16 with beta fixed at 0
// (the default beta_is_zero = true specialization).
dnnl::status gemm_bf16bf16bf16(char transa, char transb, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const void *A, dnnl_dim_t lda,
        const void *B, dnnl_dim_t ldb, void *C, dnnl_dim_t ldc) {
    constexpr float beta = 0.f; // C is overwritten, never accumulated into
    return proxy::gemm_bf16bf16<void>(
            transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
// bf16 x bf16 -> f32 GEMM: C = alpha * op(A) * op(B).
// Convenience front-end over proxy::gemm_bf16bf16 with beta fixed at 0
// (the default beta_is_zero = true specialization).
dnnl::status gemm_bf16bf16f32(char transa, char transb, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const void *A, dnnl_dim_t lda,
        const void *B, dnnl_dim_t ldb, float *C, dnnl_dim_t ldc) {
    constexpr float beta = 0.f; // C is overwritten, never accumulated into
    return proxy::gemm_bf16bf16<float>(
            transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
int main() { | |
const int M = 10, N = 20, K = 30; | |
short A[M * K], B[K * N], C_bf16[M * N]; | |
float C_f32[M * N]; | |
gemm_bf16bf16bf16('N', 'N', M, N, K, 1.f, A, K, B, N, C_bf16, N); | |
gemm_bf16bf16f32('N', 'N', M, N, K, 1.f, A, K, B, N, C_f32, N); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment