Skip to content

Instantly share code, notes, and snippets.

@fabiovila
Last active June 1, 2024 02:06
Show Gist options
  • Save fabiovila/c0a41d8618beec0be0000670160638b8 to your computer and use it in GitHub Desktop.
Save fabiovila/c0a41d8618beec0be0000670160638b8 to your computer and use it in GitHub Desktop.
Torch nn.Linear equivalent in C using cblas_sgemm (bias=False).
#include <cblas.h>
#include <omp.h>
// gcc -shared -o lib.so -fPIC lib.c -I/opt/intel/oneapi/mkl/2024.1/include/ -fopenmp -DMKL_ILP64 -m64 -I"$MKLROOT/include"
/*
 * Compute C = A * B^T in single precision: torch's nn.Linear with bias=False.
 *
 * A       row-major M x K input, where M = ashape[1] and K = ashape[2].
 *         ashape[0] is not read, so exactly one M x K slab is multiplied.
 * ashape  shape of A; only elements 1 and 2 are used.
 * B       row-major N x K weight (the nn.Linear.weight layout), N = bshape[0].
 *         Its column count is assumed to be K; the caller must guarantee this.
 * bshape  shape of B; only element 0 is used (a 2-element array is enough).
 * C       row-major M x N output; fully overwritten because beta = 0.
 *
 * NOTE(review): the build line above passes no -l flag, so cblas_sgemm must be
 * resolvable when lib.so is loaded. Confirm the intended BLAS is linked.
 */
void clinear(float *A, int ashape[3], float *B, int bshape[3], float *C) {
    const int M = ashape[1];
    const int N = bshape[0];
    const int K = ashape[2];

    /*
     * sgemm is called exactly once, outside any OpenMP parallel region.
     * Wrapping it in "#pragma omp parallel" would make every thread of the
     * team run the full GEMM and write the same C concurrently: duplicated
     * work and a data race.
     * Multithreading is the BLAS library's job; tune it through its own
     * settings (e.g. OMP_NUM_THREADS / MKL_NUM_THREADS / OPENBLAS_NUM_THREADS).
     */
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, K,
                1.0f, A, K, B, K, 0.0f, C, N);
}
import numpy as np
import ctypes
import torch
import torch.nn as nn
# The path contains a slash, so it is resolved relative to the current working
# directory (where the gcc command in the C source writes lib.so).
lib = ctypes.CDLL('./lib.so')
# Declare the signature of the C function
#   void clinear(float *A, int ashape[3], float *B, int bshape[3], float *C)
# so ctypes checks the pointer types passed by CLinear.
lib.clinear.argtypes = [
    ctypes.POINTER(ctypes.c_float),  # A: input rows
    ctypes.POINTER(ctypes.c_int32),  # ashape
    ctypes.POINTER(ctypes.c_float),  # B: weight
    ctypes.POINTER(ctypes.c_int32),  # bshape
    ctypes.POINTER(ctypes.c_float),  # C: output buffer
]
# clinear returns void; without this ctypes assumes a C int return value.
lib.clinear.restype = None
def CLinear(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Return ``x @ w.T`` computed by the C ``clinear`` routine.

    This mirrors ``torch.nn.Linear(in_features, out_features, bias=False)``:
    ``w`` uses the layout of ``nn.Linear.weight``.

    Args:
        x: Input whose last axis holds the ``K`` input features. All leading
            axes are flattened into rows, so ``(1, rows, K)``,
            ``(batch, rows, K)`` and ``(rows, K)`` are all accepted.
        w: Weight matrix of shape ``(N, K)``.

    Returns:
        A new float32 array of shape ``x.shape[:-1] + (N,)``.

    Raises:
        ValueError: If ``w`` is not 2-D or the last axis of ``x`` does not
            match ``w.shape[1]``.
    """
    # clinear reads raw row-major float32 buffers, so float64 input or a
    # strided view (e.g. a transpose) would otherwise be silently misread.
    x = np.ascontiguousarray(x, dtype=np.float32)
    w = np.ascontiguousarray(w, dtype=np.float32)
    if w.ndim != 2:
        raise ValueError(
            f"w must be 2-D (out_features, in_features), got shape {w.shape}"
        )
    if x.shape[-1] != w.shape[1]:
        raise ValueError(
            f"last axis of x {x.shape} must equal w.shape[1] ({w.shape[1]})"
        )

    n_out, k = w.shape
    out_shape = x.shape[:-1] + (n_out,)
    # Empty products need no BLAS call; skipping it also avoids handing sgemm
    # a leading dimension of 0.
    if x.size == 0 or w.size == 0:
        return np.zeros(out_shape, dtype=np.float32)

    rows = x.size // k
    # clinear reads ashape[1], ashape[2] and bshape[0]. Folding every leading
    # axis of x into ashape[1] multiplies all rows in a single call.
    xshape = np.array((1, rows, k), dtype=np.int32)
    wshape = np.array((n_out, k), dtype=np.int32)
    out = np.zeros((rows, n_out), dtype=np.float32)

    float_p = ctypes.POINTER(ctypes.c_float)
    int_p = ctypes.POINTER(ctypes.c_int32)
    # xshape / wshape / out are held in locals so their buffers outlive the call.
    lib.clinear(
        x.ctypes.data_as(float_p),
        xshape.ctypes.data_as(int_p),
        w.ctypes.data_as(float_p),
        wshape.ctypes.data_as(int_p),
        out.ctypes.data_as(float_p),
    )
    return out.reshape(out_shape)
# Demo: run CLinear with the weights of a torch nn.Linear and check that both
# produce the same output.
bx = 10      # in_features: size of the last axis of x
bo = 10 * 2  # out_features
rx = 10 + 1  # number of rows fed through the layer
# nn.Linear takes (in_features, out_features). x's last axis has bx entries, so
# the layer must be Linear(bx, bo); Linear(bo, bx) would expect 20 inputs.
m = nn.Linear(bx, bo, bias=False)
x_t = torch.arange(bx * rx, dtype=torch.float32).reshape(1, rx, bx)
w = m.weight.detach().numpy()  # shape (bo, bx)
x = x_t.numpy()                # shape (1, rx, bx), float32
out = CLinear(x, w)            # shape (1, rx, bo)
# Both sides are float32 GEMMs that may sum in a different order, so allow
# small rounding differences.
np.testing.assert_allclose(out, m(x_t).detach().numpy(), rtol=1e-4, atol=1e-3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment