Skip to content

Instantly share code, notes, and snippets.

@seungrokjung
Created May 5, 2023 12:22
GPTQ quantized matrix-vector kernels (dequantization fused with an fp16 GEMM), made compatible with hipify. The original CUDA code is from https://github.com/oobabooga/GPTQ-for-LLaMa/blob/cuda/quant_cuda_kernel.cu
#include <torch/all.h>
#include <torch/python.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// Software double-precision atomicAdd for devices of compute capability < 6.0,
// which have no native instruction for it. Adapted from the CUDA C++
// Programming Guide:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
// Returns the value stored at `address` before this thread's addition landed.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
__device__ double atomicAdd(double* address, double val) {
  // View the 64-bit target as an integer word so atomicCAS can operate on it.
  unsigned long long int* const word = (unsigned long long int*)address;
  unsigned long long int observed = *word;
  unsigned long long int expected;
  for (;;) {
    expected = observed;
    observed = atomicCAS(
      word,
      expected,
      __double_as_longlong(val + __longlong_as_double(expected))
    );
    // Compare raw bit patterns, not doubles: a NaN in memory would otherwise
    // never compare equal and the loop would spin forever.
    if (observed == expected) {
      break;
    }
  }
  return __longlong_as_double(observed);
}
#endif
// Forward declarations of the quantized mat-vec kernels. Each definition
// appears later in the file, directly after its host-side launch wrapper.

// Generic kernels: input, scales and output share the dispatched scalar_t.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec, const int* __restrict__ mat,
  scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch, int vec_height, int height, int width, int zero_width, int groupsize);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec, const int* __restrict__ mat,
  scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch, int vec_height, int height, int width, int zero_width, int groupsize);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec, const int* __restrict__ mat,
  scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch, int vec_height, int height, int width, int zero_width, int groupsize);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec, const int* __restrict__ mat,
  scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch, int vec_height, int height, int width, int zero_width, int groupsize);

// "Faster" kernels: half2-packed input, float32 scales and output.
__global__ void VecQuant2MatMulKernelFaster(
  const half2* __restrict__ vec, const int* __restrict__ mat,
  float* __restrict__ mul, const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch, int vec_height, int height, int width, int zero_width, int groupsize);

__global__ void VecQuant3MatMulKernelFaster(
  const half2* __restrict__ vec, const int* __restrict__ mat,
  float* __restrict__ mul, const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch, int vec_height, int height, int width, int zero_width, int groupsize);

__global__ void VecQuant4MatMulKernelFaster(
  const half2* __restrict__ vec, const int* __restrict__ mat,
  float* __restrict__ mul, const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch, int vec_height, int height, int width, int zero_width, int groupsize);
// Launch geometry shared by every kernel in this file.
// BLOCKWIDTH is the thread count per block; each thread owns one output column,
// and each block consumes BLOCKWIDTH input rows. BLOCKHEIGHTn is the number of
// packed 32-bit rows of `mat` that hold those BLOCKWIDTH n-bit values,
// i.e. BLOCKWIDTH * n / 32.
constexpr int BLOCKWIDTH = 256;
constexpr int BLOCKHEIGHT2 = BLOCKWIDTH * 2 / 32;  // 16
constexpr int BLOCKHEIGHT3 = BLOCKWIDTH * 3 / 32;  // 24
constexpr int BLOCKHEIGHT4 = BLOCKWIDTH * 4 / 32;  // 32
constexpr int BLOCKHEIGHT8 = BLOCKWIDTH * 8 / 32;  // 64
// Reinterpret a packed signed word as unsigned (same 32 bits) so that the
// right shifts used to unpack it are logical (zero-filling), not arithmetic.
__device__ inline unsigned int as_unsigned(int i) {
  return static_cast<unsigned int>(i);
}
// Host launcher for the generic 2-bit kernel: accumulates vec @ dequant(mat)
// into mul (the kernel uses atomicAdd, so mul must be initialised by the caller).
// Dispatches on vec's floating dtype (float/double); mul and scales must share it.
void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  // Run on the device that owns the inputs and on PyTorch's current stream,
  // not on whatever device / legacy default stream happens to be active.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads, 0, stream>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
// Generic 2-bit quantized mat-vec kernel. For each output column w it adds
//   sum_j (scale[g(j), w] * q[j, w] - scale[g(j), w] * (z[g(j), w] + 1)) * vec[b, j]
// over the BLOCKWIDTH input rows j owned by this block, where g(j) = j / groupsize.
//
// Launch layout (see vecquant2matmul_cuda): grid = (ceil(height / BLOCKHEIGHT2),
// ceil(width / BLOCKWIDTH), batch), block = BLOCKWIDTH threads, one output
// column per thread. `mat` packs 16 two-bit values per int, least-significant
// bits first, so BLOCKHEIGHT2 packed rows cover BLOCKWIDTH input rows. `zeros`
// packs 16 two-bit zero points per int along the output dimension. The result
// is added with atomicAdd, so `mul` must be initialised by the caller.
// NOTE(review): vec/mat are read without a row bound, so the input dimension
// appears to be assumed a multiple of BLOCKWIDTH — confirm at the call site.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage this block's BLOCKWIDTH input elements in shared memory.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  // Threads past the last output column of a partial tile must not touch
  // mat/scales/zeros or write mul: they would read the wrong row, write into
  // the next batch row of mul, or run past the end of the buffers.
  if (w >= width) {
    return;
  }

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 16;  // first input row handled by this block
  int k = 0;

  // Location of this column's 2-bit zero point inside the packed zeros row.
  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];

    i += width;
    k += 16;
  }

  atomicAdd(&mul[b * width + w], res);
}
// Host launcher for the generic 3-bit kernel: accumulates vec @ dequant(mat)
// into mul (the kernel uses atomicAdd, so mul must be initialised by the caller).
// Dispatches on vec's floating dtype (float/double); mul and scales must share it.
void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  // Run on the device that owns the inputs and on PyTorch's current stream,
  // not on whatever device / legacy default stream happens to be active.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads, 0, stream>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
// Generic 3-bit quantized mat-vec kernel. For each output column w it adds
//   sum_j (scale[g(j), w] * q[j, w] - scale[g(j), w] * (z[g(j), w] + 1)) * vec[b, j]
// over the BLOCKWIDTH input rows j owned by this block.
//
// Launch layout (see vecquant3matmul_cuda): grid = (ceil(height / BLOCKHEIGHT3),
// ceil(width / BLOCKWIDTH), batch), block = BLOCKWIDTH threads, one output
// column per thread. Every 3 packed int rows hold 32 three-bit values,
// least-significant bits first, with values 10 and 21 straddling words:
//   row 0: values 0-9 in bits 0-29, low 2 bits of value 10 in bits 30-31
//   row 1: top bit of value 10 in bit 0, values 11-20 in bits 1-30,
//          low bit of value 21 in bit 31
//   row 2: top 2 bits of value 21 in bits 0-1, values 22-31 in bits 2-31
// `zeros` uses the same 32-values-per-3-ints layout along the output dimension.
// The result is added with atomicAdd, so `mul` must be initialised by the caller.
// NOTE(review): scale/zero are looked up once per 32 input rows, and vec/mat are
// read without a row bound, so groupsize appears to be assumed a multiple of 32
// and the input dimension a multiple of BLOCKWIDTH — confirm at the call site.
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage this block's BLOCKWIDTH input elements in shared memory.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  // Threads past the last output column of a partial tile must not touch
  // mat/scales/zeros or write mul: they would read the wrong row, write into
  // the next batch row of mul, or run past the end of the buffers.
  if (w >= width) {
    return;
  }

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = (h / 3) * 32;  // first input row handled by this block
  int k = 0;

  // Locate this column's 3-bit zero point: z_w is the int holding it and
  // z_bit its bit offset. Columns 10 and 21 of each 32-column group straddle
  // two ints and are decoded specially inside the loop (z_bit unused there).
  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit = 0;

  if (z_mod != 10){
    if (z_mod != 21){
      z_bit = z_mod;
      if (z_bit > 21){
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10){
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;
  unsigned int z_tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero;
    if (z_mod == 10) {
      // bits 30-31 of word z_w, bit 0 of word z_w + 1
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = scale * scalar_t((z_tmp) + 1);
    } else if (z_mod == 21){
      // bit 31 of word z_w, bits 0-1 of word z_w + 1
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = scale * scalar_t((z_tmp) + 1);
    } else {
      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    // Row 0: values 0-9.
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    // Value 10 straddles rows 0 and 1.
    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];

    // Row 1: values 11-20.
    k += 11;
    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];

    // Value 21 straddles rows 1 and 2.
    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];

    // Row 2: values 22-31.
    k += 11;
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}
// Host launcher for the generic 4-bit kernel: accumulates vec @ dequant(mat)
// into mul (the kernel uses atomicAdd, so mul must be initialised by the caller).
// Dispatches on vec's floating dtype (float/double); mul and scales must share it.
void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  // Run on the device that owns the inputs and on PyTorch's current stream,
  // not on whatever device / legacy default stream happens to be active.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads, 0, stream>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
// Generic 4-bit quantized mat-vec kernel. For each output column w it adds
//   sum_j (scale[g(j), w] * q[j, w] - scale[g(j), w] * (z[g(j), w] + 1)) * vec[b, j]
// over the BLOCKWIDTH input rows j owned by this block, where g(j) = j / groupsize.
//
// Launch layout (see vecquant4matmul_cuda): grid = (ceil(height / BLOCKHEIGHT4),
// ceil(width / BLOCKWIDTH), batch), block = BLOCKWIDTH threads, one output
// column per thread. `mat` packs 8 four-bit values per int, least-significant
// bits first; `zeros` packs 8 four-bit zero points per int along the output
// dimension. The result is added with atomicAdd, so `mul` must be initialised
// by the caller.
// NOTE(review): vec/mat are read without a row bound, so the input dimension
// appears to be assumed a multiple of BLOCKWIDTH — confirm at the call site.
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage this block's BLOCKWIDTH input elements in shared memory.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  // Threads past the last output column of a partial tile must not touch
  // mat/scales/zeros or write mul: they would read the wrong row, write into
  // the next batch row of mul, or run past the end of the buffers.
  if (w >= width) {
    return;
  }

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 8;  // first input row handled by this block
  int k = 0;

  // Location of this column's 4-bit zero point inside the packed zeros row.
  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];

    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}
// Host launcher for the generic 8-bit kernel: accumulates vec @ dequant(mat)
// into mul (the kernel uses atomicAdd, so mul must be initialised by the caller).
// Dispatches on vec's floating dtype (float/double); mul and scales must share it.
void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  // Run on the device that owns the inputs and on PyTorch's current stream,
  // not on whatever device / legacy default stream happens to be active.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads, 0, stream>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
// Generic 8-bit quantized mat-vec kernel. For each output column w it adds
//   sum_j (scale[g(j), w] * q[j, w] - scale[g(j), w] * (z[g(j), w] + 1)) * vec[b, j]
// over the BLOCKWIDTH input rows j owned by this block, where g(j) = j / groupsize.
//
// Launch layout (see vecquant8matmul_cuda): grid = (ceil(height / BLOCKHEIGHT8),
// ceil(width / BLOCKWIDTH), batch), block = BLOCKWIDTH threads, one output
// column per thread. `mat` packs 4 eight-bit values per int, least-significant
// byte first; `zeros` packs 4 eight-bit zero points per int along the output
// dimension. The result is added with atomicAdd, so `mul` must be initialised
// by the caller.
// NOTE(review): vec/mat are read without a row bound, so the input dimension
// appears to be assumed a multiple of BLOCKWIDTH — confirm at the call site.
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage this block's BLOCKWIDTH input elements in shared memory.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  // Threads past the last output column of a partial tile must not touch
  // mat/scales/zeros or write mul: they would read the wrong row, write into
  // the next batch row of mul, or run past the end of the buffers.
  if (w >= width) {
    return;
  }

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 4;  // first input row handled by this block
  int k = 0;

  // Location of this column's 8-bit zero point inside the packed zeros row.
  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];

    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}
// Host launcher for the half2-vectorised 2-bit kernel. vec's storage is
// reinterpreted as half2 pairs, so it must be float16; vec_height is the
// per-batch row stride of vec counted in half2 elements (the kernel uses it to
// index a half2 pointer). mul and scales are accessed as float32. The kernel
// accumulates with atomicAdd, so mul must be initialised by the caller.
void vecquant2matmul_faster_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize,
  int vec_height
) {
  // The reinterpret cast below is only meaningful for fp16 input; fail loudly
  // instead of silently computing garbage.
  TORCH_CHECK(vec.scalar_type() == at::ScalarType::Half,
              "vecquant2matmul_faster_cuda: vec must be float16");

  // Run on the device that owns the inputs and on PyTorch's current stream,
  // not on whatever device / legacy default stream happens to be active.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int batch = vec.size(0);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  VecQuant2MatMulKernelFaster<<<blocks, threads, 0, stream>>>(
    reinterpret_cast<half2*>(vec.data_ptr()),
    mat.data_ptr<int>(),
    mul.data_ptr<float>(),
    scales.data_ptr<float>(),
    zeros.data_ptr<int>(),
    batch, vec_height, height, width, zero_width, groupsize
  );
}
// half2-vectorised 2-bit quantized mat-vec kernel (fp16 input, float32 scales
// and output). Each 4-bit field of a packed int holds two consecutive 2-bit
// weights; they are decoded together through the deq2 lookup table and
// multiplied against one half2 pair of input elements.
//
// Launch layout (see vecquant2matmul_faster_cuda): grid = (ceil(height /
// BLOCKHEIGHT2), ceil(width / BLOCKWIDTH), batch), block = BLOCKWIDTH threads,
// one output column per thread. Partial sums are accumulated in half2 for 8
// pairs, then widened to float. The result is added with atomicAdd, so `mul`
// must be initialised by the caller.
__global__ void VecQuant2MatMulKernelFaster(
  const half2* __restrict__ vec,
  const int* __restrict__ mat,
  float* __restrict__ mul,
  const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  const int blockwidth2 = BLOCKWIDTH / 2;
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage this block's input slice as BLOCKWIDTH / 2 half2 pairs.
  __shared__ half2 blockvec[blockwidth2];
  if (threadIdx.x < blockwidth2)
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x];

  // Lookup table: deq2[v][*] = (v & 0x3, v >> 2), i.e. both 2-bit weights of a
  // 4-bit field. Replicated across 16 columns and read at column `off`,
  // presumably to spread threads over shared-memory banks.
  __shared__ half2 deq2[16][16];
  int val = threadIdx.x / 16;
  int off = threadIdx.x % 16;
  for (; val < 16; val += BLOCKWIDTH / 16) {
    deq2[val][off] = __halves2half2(
      __int2half_rn(val & 0x3), __int2half_rn(val >> 2)
    );
  }

  int i = width * h + w;
  int g_h = h * 16;  // first input row handled by this block
  int k = 0;

  // Location of this column's 2-bit zero point inside the packed zeros row.
  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  float res = 0;
  unsigned int tmp;

  __syncthreads();

  // Threads past the last output column of a partial tile must not touch
  // mat/scales/zeros or write mul (placed after the barrier so every thread
  // still reaches __syncthreads()).
  if (w >= width) {
    return;
  }

  while (k < blockwidth2) {
    int g = (g_h + (k * 2)) / groupsize;
    float scale_f = scales[g * width + w];
    half2 scale = __float2half2_rn(scale_f);
    // Negated so that hfma2(q, scale, zero) = scale * q - scale * (z + 1).
    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));

    half2 res2 = {};
    tmp = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xf][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 12) & 0xf][off], scale, zero), blockvec[k + 3], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xf][off], scale, zero), blockvec[k + 4], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 20) & 0xf][off], scale, zero), blockvec[k + 5], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xf][off], scale, zero), blockvec[k + 6], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2);
    i += width;
    k += 8;

    // Widen the half2 partial sum and add both lanes. (CUDA's float2 has no
    // operator+=, so the lanes are summed explicitly; this works under HIP too.)
    const float2 partial = __half22float2(res2);
    res += partial.x + partial.y;
  }

  atomicAdd(&mul[b * width + w], res);
}
// Host launcher for the half2-vectorised 3-bit kernel. vec's storage is
// reinterpreted as half2 pairs, so it must be float16; vec_height is the
// per-batch row stride of vec counted in half2 elements (the kernel uses it to
// index a half2 pointer). mul and scales are accessed as float32. The kernel
// accumulates with atomicAdd, so mul must be initialised by the caller.
void vecquant3matmul_faster_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize,
  int vec_height
) {
  // The reinterpret cast below is only meaningful for fp16 input; fail loudly
  // instead of silently computing garbage.
  TORCH_CHECK(vec.scalar_type() == at::ScalarType::Half,
              "vecquant3matmul_faster_cuda: vec must be float16");

  // Run on the device that owns the inputs and on PyTorch's current stream,
  // not on whatever device / legacy default stream happens to be active.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int batch = vec.size(0);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  VecQuant3MatMulKernelFaster<<<blocks, threads, 0, stream>>>(
    reinterpret_cast<half2*>(vec.data_ptr()),
    mat.data_ptr<int>(),
    mul.data_ptr<float>(),
    scales.data_ptr<float>(),
    zeros.data_ptr<int>(),
    batch, vec_height, height, width, zero_width, groupsize
  );
}
// half2-vectorised 3-bit quantized mat-vec kernel (fp16 input, float32 scales
// and output). Consecutive 3-bit weights are decoded two at a time from 6-bit
// fields through the deq2 lookup table and multiplied against one half2 pair
// of input elements. Every 3 packed int rows hold 32 weights (16 pairs); the
// pairs that straddle a word boundary (pairs 5 and 10) are stitched together
// from two consecutive rows.
//
// Launch layout (see vecquant3matmul_faster_cuda): grid = (ceil(height /
// BLOCKHEIGHT3), ceil(width / BLOCKWIDTH), batch), block = BLOCKWIDTH threads,
// one output column per thread. `zeros` packs 32 three-bit zero points per 3
// ints along the output dimension. The result is added with atomicAdd, so
// `mul` must be initialised by the caller.
// NOTE(review): scale/zero are looked up once per 32 input rows, so groupsize
// appears to be assumed a multiple of 32 — confirm at the call site.
__global__ void VecQuant3MatMulKernelFaster(
  const half2* __restrict__ vec,
  const int* __restrict__ mat,
  float* __restrict__ mul,
  const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  const int blockwidth2 = BLOCKWIDTH / 2;
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage this block's input slice as BLOCKWIDTH / 2 half2 pairs.
  __shared__ half2 blockvec[blockwidth2];
  if (threadIdx.x < blockwidth2)
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x];

  // Lookup table: deq2[v][*] = (v & 0x7, v >> 3), i.e. both 3-bit weights of a
  // 6-bit field. Replicated across 32 columns and read at column `off`,
  // presumably to spread threads over shared-memory banks.
  __shared__ half2 deq2[64][32];
  int val = threadIdx.x / 32;
  int off = threadIdx.x % 32;
  for (; val < 64; val += BLOCKWIDTH / 32) {
    deq2[val][off] = __halves2half2(
      __int2half_rn(val & 0x7), __int2half_rn(val >> 3)
    );
  }

  int i = width * h + w;
  int g_h = (h / 3) * 32;  // first input row handled by this block
  int k = 0;

  // Locate this column's 3-bit zero point: z_w is the int holding it and
  // z_bit its bit offset. Columns 10 and 21 of each 32-column group straddle
  // two ints and are decoded specially inside the loop (z_bit unused there).
  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit = 0;

  if (z_mod != 10){
    if (z_mod != 21){
      z_bit = z_mod;
      if (z_bit > 21){
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10){
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  float res = 0;
  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;
  unsigned int z_tmp;

  __syncthreads();

  // Threads past the last output column of a partial tile must not touch
  // mat/scales/zeros or write mul (placed after the barrier so every thread
  // still reaches __syncthreads()).
  if (w >= width) {
    return;
  }

  while (k < blockwidth2) {
    int g = (g_h + (k * 2)) / groupsize;
    float scale_f = scales[g * width + w];
    half2 scale = __float2half2_rn(scale_f);
    // Negated so that hfma2(q, scale, zero) = scale * q - scale * (z + 1).
    half2 zero;
    if (z_mod == 10) {
      // bits 30-31 of word z_w, bit 0 of word z_w + 1
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
    } else if (z_mod == 21){
      // bit 31 of word z_w, bits 0-1 of word z_w + 1
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
    } else {
      zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
    }

    half2 res2 = {};

    // Row 0: pairs 0-4 from bits 0-29.
    tmp1 = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);

    // Pair 5: bits 30-31 of row 0 and bits 0-3 of row 1.
    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c);
    res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2);
    tmp2 >>= 4;
    k += 6;

    // Row 1: pairs 6-9.
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);

    // Pair 10: bits 28-31 of row 1 and bits 0-1 of row 2.
    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30);
    res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2);
    tmp1 >>= 2;
    k += 5;

    // Row 2: pairs 11-15.
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
    i += width;
    k += 5;

    // Widen the half2 partial sum and add both lanes. (CUDA's float2 has no
    // operator+=, so the lanes are summed explicitly; this works under HIP too.)
    const float2 partial = __half22float2(res2);
    res += partial.x + partial.y;
  }

  atomicAdd(&mul[b * width + w], res);
}
// Host launcher for the half2-vectorised 4-bit kernel. vec's storage is
// reinterpreted as half2 pairs, so it must be float16; vec_height is the
// per-batch row stride of vec counted in half2 elements (the kernel uses it to
// index a half2 pointer). mul and scales are accessed as float32. The kernel
// accumulates with atomicAdd, so mul must be initialised by the caller.
void vecquant4matmul_faster_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize,
  int vec_height
) {
  // The reinterpret cast below is only meaningful for fp16 input; fail loudly
  // instead of silently computing garbage.
  TORCH_CHECK(vec.scalar_type() == at::ScalarType::Half,
              "vecquant4matmul_faster_cuda: vec must be float16");

  // Run on the device that owns the inputs and on PyTorch's current stream,
  // not on whatever device / legacy default stream happens to be active.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int batch = vec.size(0);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  VecQuant4MatMulKernelFaster<<<blocks, threads, 0, stream>>>(
    reinterpret_cast<half2*>(vec.data_ptr()),
    mat.data_ptr<int>(),
    mul.data_ptr<float>(),
    scales.data_ptr<float>(),
    zeros.data_ptr<int>(),
    batch, vec_height, height, width, zero_width, groupsize
  );
}
// half2-vectorised 4-bit quantized mat-vec kernel (fp16 input, float32 scales
// and output). Each byte of a packed int holds two consecutive 4-bit weights;
// they are decoded together through the deq2 lookup table and multiplied
// against one half2 pair of input elements.
//
// Launch layout (see vecquant4matmul_faster_cuda): grid = (ceil(height /
// BLOCKHEIGHT4), ceil(width / BLOCKWIDTH), batch), block = BLOCKWIDTH threads,
// one output column per thread. Partial sums are accumulated in half2 for 4
// pairs, then widened to float. The result is added with atomicAdd, so `mul`
// must be initialised by the caller.
__global__ void VecQuant4MatMulKernelFaster(
  const half2* __restrict__ vec,
  const int* __restrict__ mat,
  float* __restrict__ mul,
  const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  const int blockwidth2 = BLOCKWIDTH / 2;
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage this block's input slice as BLOCKWIDTH / 2 half2 pairs.
  __shared__ half2 blockvec[blockwidth2];
  if (threadIdx.x < blockwidth2)
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x];

  // Lookup table: deq2[v][*] = (v & 0xF, v >> 4), i.e. both 4-bit weights of a
  // byte. Replicated across 8 columns and read at column `off`, presumably to
  // spread threads over shared-memory banks.
  __shared__ half2 deq2[256][8];
  int val = threadIdx.x / 8;
  int off = threadIdx.x % 8;
  for (; val < 256; val += BLOCKWIDTH / 8) {
    deq2[val][off] = __halves2half2(
      __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
    );
  }

  int i = width * h + w;
  int g_h = h * 8;  // first input row handled by this block
  int k = 0;

  // Location of this column's 4-bit zero point inside the packed zeros row.
  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  float res = 0;
  unsigned int tmp;

  __syncthreads();

  // Threads past the last output column of a partial tile must not touch
  // mat/scales/zeros or write mul (placed after the barrier so every thread
  // still reaches __syncthreads()).
  if (w >= width) {
    return;
  }

  while (k < blockwidth2) {
    int g = (g_h + (k * 2)) / groupsize;
    float scale_f = scales[g * width + w];
    half2 scale = __float2half2_rn(scale_f);
    // Negated so that hfma2(q, scale, zero) = scale * q - scale * (z + 1).
    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));

    half2 res2 = {};
    tmp = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
    i += width;
    k += 4;

    // Widen the half2 partial sum and add both lanes. (CUDA's float2 has no
    // operator+=, so the lanes are summed explicitly; this works under HIP too.)
    const float2 partial = __half22float2(res2);
    res += partial.x + partial.y;
  }

  atomicAdd(&mul[b * width + w], res);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment