@malfet
Created August 6, 2023 21:31
// My attempt at an FP8 matmul implementation using cublasLt
#include <algorithm>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include <cublasLt.h>
#include <cuda_fp8.h>

void checkStatus(cublasStatus_t status) {
  if (status != CUBLAS_STATUS_SUCCESS)
    throw std::runtime_error("cublasLt call failed " + std::to_string(status));
}

void checkStatus(cudaError_t status) {
  if (status != cudaSuccess)
    throw std::runtime_error("cudart call failed " + std::to_string(status));
}
// RAII wrapper around a device allocation of `num` elements of type T
template<typename T>
struct GPUMem {
  GPUMem(size_t num_): num(num_) {
    checkStatus(cudaMalloc(&ptr, sizeof(T) * num));
  }
  GPUMem(const GPUMem&) = delete; // non-copyable: the wrapper owns the allocation
  ~GPUMem() {
    auto rc = cudaFree(ptr);
    if (rc) {
      std::cerr << "Failed to free GPU mem " << rc << std::endl;
    }
  }
  // Copy at most `num` elements from host to device
  void toGPU(const std::vector<T>& data) {
    checkStatus(cudaMemcpy(ptr, data.data(), sizeof(T) * std::min(data.size(), num), cudaMemcpyHostToDevice));
  }
  // Copy the entire allocation back to the host
  std::vector<T> toCPU() const {
    std::vector<T> data(num);
    checkStatus(cudaMemcpy(data.data(), ptr, sizeof(T) * num, cudaMemcpyDeviceToHost));
    return data;
  }
  operator T*() const { return ptr; }
  T* data() const { return ptr; }
  size_t size() const { return num; }
private:
  T *ptr;
  size_t num;
};
// Element-wise float -> fp8 (e4m3) cast, one thread per element
__global__ void convert_to_fp8_e4m3(const float* in, __nv_fp8_e4m3* out) {
  out[threadIdx.x] = __nv_fp8_e4m3(in[threadIdx.x]);
}

// Upload float data to the device and convert it to fp8
// (single-block launch, so limited to 1024 elements)
void full_fp8_e4m3(const std::vector<float>& data, GPUMem<__nv_fp8_e4m3>& ptr) {
  GPUMem<float> gpu_data(data.size());
  gpu_data.toGPU(data);
  convert_to_fp8_e4m3<<<1, data.size()>>>(gpu_data, ptr);
  checkStatus(cudaDeviceSynchronize());
}

// Element-wise fp8 (e4m3) -> float cast, one thread per element
__global__ void convert_from_fp8_e4m3(const __nv_fp8_e4m3* in, float* out) {
  out[threadIdx.x] = float(in[threadIdx.x]);
}

// Convert fp8 device data back to float and download it to the host
std::vector<float> get_fp8_e4m3(const GPUMem<__nv_fp8_e4m3>& ptr) {
  GPUMem<float> gpu_data(ptr.size());
  convert_from_fp8_e4m3<<<1, ptr.size()>>>(ptr, gpu_data);
  return gpu_data.toCPU();
}
// Convenience overload: set a matmul descriptor attribute from a typed value
template<typename T>
void cublasLtMatmulDescSetAttribute(cublasLtMatmulDesc_t desc, cublasLtMatmulDescAttributes_t attr, const T value) {
  checkStatus(cublasLtMatmulDescSetAttribute(desc, attr, &value, sizeof(value)));
}
// Compute D = A^T @ B in fp8 via cublasLt
// (fp8 matmul requires the "TN" layout: A transposed, B non-transposed)
void gemm_blaslt(GPUMem<__nv_fp8_e4m3>& A, GPUMem<__nv_fp8_e4m3>& B, GPUMem<__nv_fp8_e4m3>& D, int m, int n, int k, GPUMem<float>& weights) {
  static GPUMem<char> workspace(32*1024*1024);
  cublasLtHandle_t handle = nullptr;
  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  checkStatus(cublasLtCreate(&handle));
  checkStatus(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
  cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T);
  cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N);
  // Per-tensor scale factors (device pointers); amax_d receives the maximum
  // absolute value of the result before it is scaled back down to fp8
  cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, weights);
  cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, weights+1);
  cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, weights+2);
  cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, weights+3);
  cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, weights+4);
  cublasLtMatmulDescSetAttribute<int8_t>(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, 0);
  cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_DEFAULT);
  // Layouts describe the matrices as stored: since A is transposed, its stored
  // shape is k x m (the original used m x k, which only works when m == k)
  checkStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8F_E4M3, k, m, k));
  checkStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8F_E4M3, k, n, k));
  // C must be bf16 (or fp16) when D is fp8
  checkStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16BF, m, n, m));
  checkStatus(cublasLtMatrixLayoutCreate(&Ddesc, CUDA_R_8F_E4M3, m, n, m));
  checkStatus(cublasLtMatmulPreferenceCreate(&preference));
  auto workspaceSize = workspace.size();
  checkStatus(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspaceSize, sizeof(workspaceSize)));
  checkStatus(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc,
                                             Ddesc, preference, 1, &heuristicResult,
                                             &returnedResults));
  if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
  float alpha = 1.0;
  float beta = 0.0;
  checkStatus(cublasLtMatmul(handle,
                             operationDesc,
                             &alpha,
                             A,
                             Adesc,
                             B,
                             Bdesc,
                             &beta,
                             nullptr, // C is unused since beta == 0
                             Cdesc,
                             D,
                             Ddesc,
                             &heuristicResult.algo,
                             workspace,
                             workspaceSize,
                             0));
  // Release cublasLt resources
  checkStatus(cublasLtMatmulPreferenceDestroy(preference));
  checkStatus(cublasLtMatrixLayoutDestroy(Ddesc));
  checkStatus(cublasLtMatrixLayoutDestroy(Cdesc));
  checkStatus(cublasLtMatrixLayoutDestroy(Bdesc));
  checkStatus(cublasLtMatrixLayoutDestroy(Adesc));
  checkStatus(cublasLtMatmulDescDestroy(operationDesc));
  checkStatus(cublasLtDestroy(handle));
}
int main() {
  constexpr unsigned n = 16;
  GPUMem<__nv_fp8_e4m3> A(n*n);
  GPUMem<__nv_fp8_e4m3> B(n*n);
  GPUMem<__nv_fp8_e4m3> C(n*n);
  GPUMem<float> weights(5); // a_scale, b_scale, c_scale, d_scale, amax_d
  {
    // Init: A = 1, 2, ..., n*n; B = all 0.2; all scales = 1.0
    std::vector<float> adat(n*n);
    std::iota(adat.begin(), adat.end(), 1.0);
    full_fp8_e4m3(adat, A);
    std::vector<float> bdat(n*n, .2);
    full_fp8_e4m3(bdat, B);
    weights.toGPU(std::vector<float>(weights.size(), 1.0));
  }
  gemm_blaslt(A, B, C, n, n, n, weights);
  // Read the fp8 result back as floats and print it
  auto result = get_fp8_e4m3(C);
  for (unsigned i = 0; i < n; ++i) {
    for (unsigned j = 0; j < n; ++j)
      std::cout << result[i * n + j] << " ";
    std::cout << std::endl;
  }
  return 0;
}
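
To build and run it, something along these lines should work (a sketch: it assumes the code is saved as fp8_matmul.cu, a CUDA 12.x toolkit, and an FP8-capable GPU, i.e. Ada (sm_89) or Hopper (sm_90), since cublasLt only supports FP8 matmul on those architectures):

nvcc -arch=sm_90 -o fp8_matmul fp8_matmul.cu -lcublasLt
./fp8_matmul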