Skip to content

Instantly share code, notes, and snippets.

@emfomenk
Created June 5, 2020 23:44
Show Gist options
  • Save emfomenk/4bdc70908cc5c30ad2a97e5030a22eaf to your computer and use it in GitHub Desktop.
BF16 gemm-like API
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <mutex>
#include <vector>
#include <type_traits>
#include "dnnl.hpp"
namespace proxy {

// Computes C := alpha * op(A) * op(B) [+ beta * C] via the oneDNN matmul
// primitive, where A and B hold bf16 data in row-major layout.
//
// Template parameters:
//   c_dt         - element type of C: `float` for an f32 destination, or
//                  `void` for a bf16 destination (bf16 has no C++ type here).
//   beta_is_zero - compile-time switch: true  -> beta must be 0.f,
//                                       false -> beta must be 1.f (the sum
//                  post-op is baked into the cached primitive, so only these
//                  two values are supported).
//
// Parameters follow the classic GEMM convention: transa/transb are 'n'/'N'
// for non-transposed operands (anything else means transposed), M/N/K are
// the problem sizes, lda/ldb/ldc the leading dimensions, and alpha/beta the
// scaling factors. Returns status::unimplemented when the CPU lacks
// avx512_core (required for bf16), status::runtime_error when the cached
// primitive failed to build or beta is inconsistent with beta_is_zero, and
// status::success otherwise.
//
// The engine and matmul primitive are created once per template
// instantiation (thread-safe via std::call_once) with fully runtime
// dimensions, so one cached primitive serves all problem sizes.
template <typename c_dt, bool beta_is_zero = true>
dnnl::status gemm_bf16bf16(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
        dnnl_dim_t K, float alpha, const void *A, dnnl_dim_t lda,
        const void *B, dnnl_dim_t ldb, float beta, c_dt *C, dnnl_dim_t ldc) {
    using namespace dnnl;
    using dims = memory::dims;

    // bf16 compute requires avx512_core; bail out early on older CPUs.
    if ((int)get_effective_cpu_isa() < (int)cpu_isa::avx512_core)
        return status::unimplemented;

    static engine eng;
    static matmul matmul_p;
    static std::once_flag initialized;

    static_assert(std::is_same<c_dt, float>::value ||
            std::is_same<c_dt, void>::value, "expect float or void (bf16)");

    constexpr memory::data_type bf16 = memory::data_type::bf16;
    constexpr memory::data_type c_data_type =
            std::is_same<c_dt, float>::value ? memory::data_type::f32 : bf16;

    std::call_once(initialized, [=] {
        eng = engine(engine::kind::cpu, 0);

        // All dims and the A/B strides are resolved at execution time.
        memory::dims rt_rt_dims = {DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL};
        // C is always row-major with strides {ldc, 1}, so the innermost
        // stride is known to be 1 at primitive-creation time (this is what
        // the "_1" in the name means; it was previously a runtime value).
        memory::dims rt_1_dims = {DNNL_RUNTIME_DIM_VAL, 1};
        memory::desc ab_md(rt_rt_dims, bf16, rt_rt_dims);
        memory::desc c_md(rt_rt_dims, c_data_type, rt_1_dims);

        primitive_attr attr;
        // alpha is supplied at execution time through a runtime output scale.
        attr.set_output_scales(/* mask */ 0, {DNNL_RUNTIME_F32_VAL});
        if (!beta_is_zero) {
            assert(beta == 1.f); // current limitation
            post_ops po;
            po.append_sum(beta);
            attr.set_post_ops(po);
        }

        matmul::desc matmul_d(ab_md, ab_md, c_md);
        // allow_empty = true: leave matmul_p default-constructed on failure.
        matmul::primitive_desc matmul_pd(matmul_d, attr, eng, true);
        if (matmul_pd) matmul_p = matmul(matmul_pd);
    });

    // Reject calls when the primitive is unavailable or beta does not match
    // the compile-time beta_is_zero contract (0.f or 1.f respectively).
    bool ok = (bool)matmul_p
            && (!beta_is_zero || beta == 0.f)
            && (!!beta_is_zero || beta == 1.f);
    if (!ok) return status::runtime_error;

    // 'n'/'N' -> row-major strides {ld, 1}; otherwise transposed: {1, ld}.
    dims a_strides = tolower(transa) == 'n' ? dims {lda, 1} : dims {1, lda};
    dims b_strides = tolower(transb) == 'n' ? dims {ldb, 1} : dims {1, ldb};

    memory A_m({{M, K}, bf16, a_strides}, eng, (void *)A);
    memory B_m({{K, N}, bf16, b_strides}, eng, (void *)B);
    memory C_m({{M, N}, c_data_type, {ldc, 1}}, eng, (void *)C);

    // Prepare oneDNN memory for alpha (bound to the runtime output scale).
    memory alpha_m({{1}, memory::data_type::f32, {1}}, eng, &alpha);

    stream s(eng);
    matmul_p.execute(s,
            {{DNNL_ARG_SRC, A_m}, {DNNL_ARG_WEIGHTS, B_m}, {DNNL_ARG_DST, C_m},
            {DNNL_ARG_ATTR_OUTPUT_SCALES, alpha_m}});
    s.wait();

    return status::success;
}

} // namespace proxy
// C (bf16) = alpha * op(A) * op(B), with bf16 A and B.
// The bf16 destination cannot accumulate, so beta is fixed to zero.
dnnl::status gemm_bf16bf16bf16(char transa, char transb, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const void *A, dnnl_dim_t lda,
        const void *B, dnnl_dim_t ldb, void *C, dnnl_dim_t ldc) {
    constexpr float beta = 0.f;
    return proxy::gemm_bf16bf16<void>(transa, transb, M, N, K, alpha,
            A, lda, B, ldb, beta, C, ldc);
}
// C (f32) = alpha * op(A) * op(B), with bf16 A and B.
// This overload always overwrites C (beta == 0).
dnnl::status gemm_bf16bf16f32(char transa, char transb, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const void *A, dnnl_dim_t lda,
        const void *B, dnnl_dim_t ldb, float *C, dnnl_dim_t ldc) {
    constexpr float beta = 0.f;
    return proxy::gemm_bf16bf16<float>(transa, transb, M, N, K, alpha,
            A, lda, B, ldb, beta, C, ldc);
}
int main() {
const int M = 10, N = 20, K = 30;
short A[M * K], B[K * N], C_bf16[M * N];
float C_f32[M * N];
gemm_bf16bf16bf16('N', 'N', M, N, K, 1.f, A, K, B, N, C_bf16, N);
gemm_bf16bf16f32('N', 'N', M, N, K, 1.f, A, K, B, N, C_f32, N);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment