Skip to content

Instantly share code, notes, and snippets.

@nicolasvasilache
Last active June 24, 2021 06:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nicolasvasilache/691ef992404c49dc9b5d543c4aa6db38 to your computer and use it in GitHub Desktop.
Save nicolasvasilache/691ef992404c49dc9b5d543c4aa6db38 to your computer and use it in GitHub Desktop.
Minimal Linalg + MKL shim
#include <assert.h>
#include <iostream>
#include "intel_mkl_dnn/include/mkldnn.hpp"
#include "intel_mkl_dnn/include/mkldnn_types.h"
#include "llvm-project/mlir/test/mlir-cpu-runner/include/cblas.h"
#include "llvm-project/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h"
/// Fills a rank-0 f32 view: stores `f` into the single element that lives
/// `X->offset` elements past `X->data`.
extern "C" void linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X,
                                        float f) {
  float *const element = X->data + X->offset;
  *element = f;
}
/// Fills every element of the 1-D strided f32 view `X` with `f`.
///
/// Element `i`, for `i` in `[0, X->sizes[0])`, is written at
/// `X->data[X->offset + i * X->strides[0]]`. A non-positive extent writes
/// nothing.
extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
                                          float f) {
  float *const base = X->data + X->offset;
  const auto stride = X->strides[0];
  // Deliberately non-const so that `decltype(extent)` below is a mutable type.
  auto extent = X->sizes[0];
  // Index with the extent's own type: a narrower `unsigned` counter would wrap
  // around (and never terminate) for extents that do not fit in 32 bits.
  for (decltype(extent) i = 0; i < extent; ++i) {
    base[i * stride] = f;
  }
}
/// Fills every element of the 2-D strided f32 view `X` with `f`.
///
/// Each row `i`, for `i` in `[0, X->sizes[0])`, is described as a 1-D view
/// over the same buffer (offset `X->offset + i * X->strides[0]`, extent
/// `X->sizes[1]`, stride `X->strides[1]`) and handed to
/// `linalg_fill_viewsxf32_f32`.
extern "C" void linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
                                            float f) {
  // Everything but the offset is the same for every row, so set it up once.
  StridedMemRefType<float, 1> row;
  row.data = X->data;
  row.sizes[0] = X->sizes[1];
  row.strides[0] = X->strides[1];

  const auto rowStride = X->strides[0];
  // Deliberately non-const so that `decltype(numRows)` below is a mutable type.
  auto numRows = X->sizes[0];
  // Index with the extent's own type: a narrower `unsigned` counter would wrap
  // around (and never terminate) for extents that do not fit in 32 bits.
  for (decltype(numRows) i = 0; i < numRows; ++i) {
    row.offset = X->offset + i * rowStride;
    linalg_fill_viewsxf32_f32(&row, f);
  }
}
/// Accumulating matrix product on row-major 2-D f32 views: C += A * B,
/// delegated to `mkldnn_sgemm` (alpha = beta = 1).
///
/// With A of shape m x k, B of shape k x n and C of shape m x n, the following
/// preconditions are checked with `assert` (i.e. only when NDEBUG is not
/// defined):
///   - the innermost stride of A, B and C is 1;
///   - the shapes of A, B and C agree and no extent is negative;
///   - every extent and row stride passed to the library fits in an `int`.
///
/// If any extent is zero the function returns without touching C. If
/// `mkldnn_sgemm` does not report success, the view metadata and the GEMM
/// arguments are printed to `std::cerr` and the process exits with status 1.
extern "C" void linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
    StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
    StridedMemRefType<float, 2> *C) {
  // Rows must be contiguous: only the row stride is forwarded to the library.
  assert(A->strides[1] == 1);
  assert(B->strides[1] == 1);
  assert(C->strides[1] == 1);

  // Shape compatibility: (m x k) * (k x n) -> (m x n).
  assert(A->sizes[1] == B->sizes[0]);
  assert(C->sizes[0] == A->sizes[0]);
  assert(C->sizes[1] == B->sizes[1]);
  assert(C->sizes[0] >= 0);
  assert(C->sizes[1] >= 0);
  assert(A->sizes[1] >= 0);

  // An empty product contributes nothing to C.
  if (A->sizes[0] == 0 || A->sizes[1] == 0 || B->sizes[0] == 0 ||
      B->sizes[1] == 0 || C->sizes[0] == 0 || C->sizes[1] == 0)
    return;

  // The library takes plain `int` arguments by pointer.
  int m = static_cast<int>(C->sizes[0]);
  int n = static_cast<int>(C->sizes[1]);
  int k = static_cast<int>(A->sizes[1]);
  int lda = static_cast<int>(A->strides[0]);
  int ldb = static_cast<int>(B->strides[0]);
  int ldc = static_cast<int>(C->strides[0]);

  // Catch silent truncation when narrowing to `int`.
  assert(m == C->sizes[0]);
  assert(n == C->sizes[1]);
  assert(k == A->sizes[1]);
  assert(lda == A->strides[0]);
  assert(ldb == B->strides[0]);
  assert(ldc == C->strides[0]);

  // mkldnn_sgemm BLAS API is column major! A row-major p x q buffer with row
  // stride ld is, read column-major, the q x p transpose with leading
  // dimension ld. So instead of C = A * B we ask for C^T = B^T * A^T:
  // no transposition flags, operands swapped, m and n swapped. The n x m
  // column-major result is exactly the row-major m x n matrix C.
  // (Requesting 'T','T' with A, B in the original order would store A * B
  // column-major into C, i.e. transposed from the caller's point of view.)
  char noTranspose = 'N';
  float alpha = 1.0f;
  float beta = 1.0f;
  auto status = mkldnn_sgemm(&noTranspose, &noTranspose, &n, &m, &k, &alpha,
                             B->data + B->offset, &ldb, A->data + A->offset,
                             &lda, &beta, C->data + C->offset, &ldc);
  if (status == mkldnn_success)
    return;

  // Failure: dump everything needed to reproduce the call, then abort the run.
  printMemRefMetaData(std::cerr, *A);
  printMemRefMetaData(std::cerr, *B);
  printMemRefMetaData(std::cerr, *C);
  std::cerr << "MKLDNN failed with " << static_cast<int>(status) << "\n";
  std::cerr << "m: " << m << "\n";
  std::cerr << "n: " << n << "\n";
  std::cerr << "k: " << k << "\n";
  std::cerr << "lda: " << lda << "\n";
  std::cerr << "ldb: " << ldb << "\n";
  std::cerr << "ldc: " << ldc << "\n";
  exit(1);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment