Last active
June 1, 2024 02:06
-
-
Save fabiovila/c0a41d8618beec0be0000670160638b8 to your computer and use it in GitHub Desktop.
Torch nn.Linear equivalent in C with cblas_sgemm ( bias = False ).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cblas.h> | |
#include <omp.h> | |
// gcc -shared -o lib.so -fPIC lib.c -I/opt/intel/oneapi/mkl/2024.1/include/ -fopenmp -DMKL_ILP64 -m64 -I"$MKLROOT/include" | |
void clinear(float *A, int ashape[3], float *B, int bshape[3], float *C) { | |
int M = ashape[1]; | |
int N = bshape[0]; | |
int K = ashape[2]; | |
#pragma omp parallel num_threads(2) // the performance declined when using more than 2 threads in my cpu | |
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, K, 1.0, A, K, B, K,0.0, C, N); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import ctypes | |
import torch | |
import torch.nn as nn | |
lib = ctypes.CDLL('./lib.so')
# Declare the signature of the C function:
#   void clinear(float *A, int ashape[3], float *B, int bshape[3], float *C)
lib.clinear.argtypes = [
    ctypes.POINTER(ctypes.c_float),
    ctypes.POINTER(ctypes.c_int32),
    ctypes.POINTER(ctypes.c_float),
    ctypes.POINTER(ctypes.c_int32),
    ctypes.POINTER(ctypes.c_float),
]
# clinear returns void; without this ctypes assumes an int return value.
lib.clinear.restype = None
def CLinear(x, w):
    """Apply a bias-free linear layer, ``x @ w.T``, via the C ``clinear`` routine.

    Args:
        x: Array-like input whose last axis holds the input features. All
            leading axes are treated as rows of one 2-D matrix, so any number
            of batch dimensions is accepted.
        w: 2-D array-like weight of shape ``(out_features, in_features)``.

    Returns:
        A float32 ``numpy.ndarray`` of shape ``x.shape[:-1] + (out_features,)``.

    Raises:
        ValueError: If ``w`` is not 2-D, or if the last axis of ``x`` does not
            match ``w.shape[1]``.
    """
    # The C side reads raw float32 buffers with dense row-major strides, so
    # convert/copy here instead of letting other layouts be misinterpreted.
    x_arr = np.ascontiguousarray(x, dtype=np.float32)
    w_arr = np.ascontiguousarray(w, dtype=np.float32)

    if w_arr.ndim != 2:
        raise ValueError(
            f"w must be 2-D (out_features, in_features), got shape {w_arr.shape}"
        )
    out_features, in_features = w_arr.shape
    if x_arr.shape[-1] != in_features:
        # clinear takes the inner dimension from x only; a mismatch would make
        # it read w with the wrong stride.
        raise ValueError(
            f"last axis of x ({x_arr.shape[-1]}) does not match "
            f"w.shape[1] ({in_features})"
        )

    lead_shape = x_arr.shape[:-1]
    rows = int(np.prod(lead_shape, dtype=np.int64))
    output_array = np.zeros(lead_shape + (out_features,), dtype=np.float32)
    if output_array.size == 0 or in_features == 0:
        # Nothing to compute (an empty sum is zero); skip the BLAS call, which
        # would otherwise receive zero-sized dimensions.
        return output_array

    # clinear reads M from ashape[1] and K from ashape[2], so the leading axes
    # are folded into a single row count.
    xshape = np.array((1, rows, in_features), dtype=np.int32)
    # clinear only reads bshape[0]; padded to the three ints its prototype
    # declares.
    wshape = np.array((out_features, in_features, 1), dtype=np.int32)

    float_p = ctypes.POINTER(ctypes.c_float)
    int_p = ctypes.POINTER(ctypes.c_int32)
    lib.clinear(
        x_arr.ctypes.data_as(float_p),
        xshape.ctypes.data_as(int_p),
        w_arr.ctypes.data_as(float_p),
        wshape.ctypes.data_as(int_p),
        output_array.ctypes.data_as(float_p),
    )
    # Return the array that owns the buffer rather than a view rebuilt from a
    # raw pointer.
    return output_array
# Demo: run CLinear on a small input and compare it with torch's nn.Linear.
bx = 10       # size of the last axis of x (input features)
bo = 10 * 2   # number of output features
rx = 10 + 1   # number of rows in x
# nn.Linear takes (in_features, out_features): the input width bx comes first,
# so that the weight has shape (bo, bx) and matches the last axis of x.
m = nn.Linear(bx, bo, bias=False)
x = np.arange(bx * rx, dtype=np.float32).reshape(1, rx, bx)
w = m.weight.detach().numpy()
out = CLinear(x, w)
# Sanity check against the torch reference on the same input and weight.
with torch.no_grad():
    expected = m(torch.from_numpy(x)).numpy()
np.testing.assert_allclose(out, expected, rtol=1e-4, atol=1e-3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment