Created
August 6, 2023 21:31
-
-
Save malfet/d9aaf3faf8b62e073f963085aa7d629b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// My attempt at FP8 matmul implementation | |
#include <iostream> | |
#include <vector> | |
#include <numeric> | |
#include <cublasLt.h> | |
#include <cuda_fp8.h> | |
#include <stdio.h> | |
// Throw a descriptive runtime_error when a cuBLASLt call did not succeed.
void checkStatus(cublasStatus_t status) {
	if (status != 0) {
		throw std::runtime_error("cublasLt call failed " + std::to_string(status));
	}
}
// Throw a descriptive runtime_error when a CUDA runtime call did not succeed.
void checkStatus(cudaError_t status) {
	if (status != 0) {
		throw std::runtime_error("cudart call failed " + std::to_string(status));
	}
}
// RAII owner of a device-memory array of `num` elements of type T.
// Copying is disabled because the wrapper uniquely owns the allocation.
template<typename T>
struct GPUMem {
	// Allocate room for num_ elements on the current device; throws on failure.
	GPUMem(size_t num_): num(num_) {
		checkStatus(cudaMalloc(&ptr, sizeof(T)*num));
	}
	GPUMem(const GPUMem&) = delete;            // delete copy constructor (single owner)
	GPUMem& operator=(const GPUMem&) = delete; // delete copy assignment for the same reason
	~GPUMem() {
		auto rc = cudaFree(ptr);
		if (rc) {
			// Bug fix: the original wrote `"..." + rc`, which performs pointer
			// arithmetic on the string literal instead of printing the code.
			std::cerr << "Failed to free GPU mem " << rc << std::endl;
		}
	}
	// Copy host data to the device. Copies at most `num` elements so a
	// too-long host vector cannot overrun the device buffer.
	void toGPU(const std::vector<T>& data) {
		checkStatus(cudaMemcpy(ptr, data.data(), sizeof(T)*std::min(data.size(), num), cudaMemcpyHostToDevice));
	}
	// Copy the entire device buffer back into a freshly allocated host vector.
	// cudaMemcpy is blocking, so this also synchronizes with prior work on
	// the default stream.
	std::vector<T> toCPU() const {
		std::vector<T> data(num);
		checkStatus(cudaMemcpy(data.data(), ptr, sizeof(T)*num, cudaMemcpyDeviceToHost));
		return data;
	}
	operator T*() const { return ptr; } // implicit decay to the raw device pointer
	T* data() const { return ptr; }     // raw device pointer (never dereference on host)
	size_t size() const { return num; } // element count, not bytes
private:
	T *ptr;     // device pointer returned by cudaMalloc
	size_t num; // number of elements of type T
};
// Kernel: element-wise float -> fp8 (e4m3) conversion.
// Expects a single-block launch with exactly one thread per element
// (so at most 1024 elements). `in` is read-only: const + __restrict__
// lets the compiler use the read-only data cache.
__global__ void convert_to_fp8_e4m3(const float* __restrict__ in, __nv_fp8_e4m3* __restrict__ out) {
	out[threadIdx.x] = __nv_fp8_e4m3(in[threadIdx.x]);
}
// Upload `data` to the device and convert it element-wise into the fp8
// buffer `ptr`. Single-block launch: data.size() must not exceed 1024.
// Takes the vector by const reference (the original copied it by value).
void full_fp8_e4m3(const std::vector<float>& data, GPUMem<__nv_fp8_e4m3>& ptr) {
	GPUMem<float> gpu_data(data.size());
	gpu_data.toGPU(data);
	// 0 bytes of dynamic shared memory (the original requested 1 byte by mistake).
	convert_to_fp8_e4m3<<<1, data.size(), 0>>>(gpu_data, ptr);
	checkStatus(cudaGetLastError());      // surface launch-configuration errors
	checkStatus(cudaDeviceSynchronize()); // surface async execution errors
}
// Kernel: element-wise fp8 (e4m3) -> float conversion.
// Expects a single-block launch with exactly one thread per element
// (so at most 1024 elements). `in` is read-only: const + __restrict__
// lets the compiler use the read-only data cache.
__global__ void convert_from_fp8_e4m3(const __nv_fp8_e4m3* __restrict__ in, float* __restrict__ out) {
	out[threadIdx.x] = float(in[threadIdx.x]);
}
// Convert the fp8 device buffer back to float and copy it to the host.
// Single-block launch: ptr.size() must not exceed 1024.
std::vector<float> get_fp8_e4m3(GPUMem<__nv_fp8_e4m3>& ptr) {
	GPUMem<float> gpu_data(ptr.size());
	// 0 bytes of dynamic shared memory (the original requested 1 byte by mistake).
	convert_from_fp8_e4m3<<<1, ptr.size(), 0>>>(ptr, gpu_data);
	checkStatus(cudaGetLastError()); // surface launch-configuration errors
	// toCPU() issues a blocking cudaMemcpy, which synchronizes with the kernel.
	return gpu_data.toCPU();
}
// Typed convenience overload: set a matmul-descriptor attribute from a value,
// forwarding the (address, size) pair the underlying C API expects and
// translating a failure status into an exception.
template<typename T>
void cublasLtMatmulDescSetAttribute(cublasLtMatmulDesc_t desc, cublasLtMatmulDescAttributes_t attr, const T value) {
	const auto status = cublasLtMatmulDescSetAttribute(desc, attr, &value, sizeof(value));
	checkStatus(status);
}
// D = alpha * op(A) * B in FP8 e4m3 via cublasLt (alpha=1, beta=0, no C input).
// cuBLAS is column-major: A is k x m stored so that op(A)=A^T is m x k,
// B is k x n, D is m x n. `weights` supplies device-side float slots in
// order: A scale, B scale, C scale, D scale, amax(D) output.
void gemm_blaslt(GPUMem<__nv_fp8_e4m3>& A, GPUMem<__nv_fp8_e4m3>& B, GPUMem<__nv_fp8_e4m3>& D, int m, int n, int k, GPUMem<float>& weights) {
	// Workspace is allocated once and reused across calls.
	static GPUMem<char> workspace(32*1024*1024);
	cublasLtHandle_t handle = nullptr;
	cublasLtMatmulDesc_t operationDesc = nullptr;
	cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL;
	cublasLtMatmulPreference_t preference = NULL;
	int returnedResults = 0;
	cublasLtMatmulHeuristicResult_t heuristicResult = {};
	checkStatus(cublasLtCreate(&handle));
	// FP32 compute/scale type is required for FP8 matmuls.
	checkStatus(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
	// FP8 requires TN layout: A transposed, B not.
	cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T);
	cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N);
	// Device pointers to the per-tensor scale factors and the amax(D) output.
	cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, weights);
	cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, weights+1);
	cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, weights+2);
	cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, weights+3);
	cublasLtMatmulDescSetAttribute<float*>(operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, weights+4);
	checkStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8F_E4M3, m, k, m));
	checkStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8F_E4M3, k, n, k));
	checkStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16BF, m, n, m));
	checkStatus(cublasLtMatrixLayoutCreate(&Ddesc, CUDA_R_8F_E4M3, m, n, m));
	cublasLtMatmulDescSetAttribute<int8_t>(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, 0);
	cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_DEFAULT);
	checkStatus(cublasLtMatmulPreferenceCreate(&preference));
	auto workspaceSize = workspace.size();
	checkStatus(cublasLtMatmulPreferenceSetAttribute(
		preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
		&workspaceSize, sizeof(workspaceSize)));
	checkStatus(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc,
	                                           Ddesc, preference, 1, &heuristicResult,
	                                           &returnedResults));
	if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
	float alpha = 1.0f;
	float beta = 0.0f;
	checkStatus(cublasLtMatmul(handle,
	                           operationDesc,
	                           &alpha,
	                           A,
	                           Adesc,
	                           B,
	                           Bdesc,
	                           &beta,
	                           nullptr,
	                           Cdesc,
	                           D,
	                           Ddesc,
	                           &heuristicResult.algo,
	                           workspace,
	                           workspaceSize,
	                           0));
	// Release cublasLt resources (the original leaked every one of these on
	// each call). NOTE(review): an exception thrown above still leaks them;
	// RAII wrappers would close that gap.
	checkStatus(cublasLtMatmulPreferenceDestroy(preference));
	checkStatus(cublasLtMatrixLayoutDestroy(Ddesc));
	checkStatus(cublasLtMatrixLayoutDestroy(Cdesc));
	checkStatus(cublasLtMatrixLayoutDestroy(Bdesc));
	checkStatus(cublasLtMatrixLayoutDestroy(Adesc));
	checkStatus(cublasLtMatmulDescDestroy(operationDesc));
	checkStatus(cublasLtDestroy(handle));
}
// Demo driver: build two 16x16 FP8 matrices, run the cublasLt FP8 matmul,
// and read the result back. Returns nonzero with a readable message on any
// CUDA/cuBLAS failure (checkStatus throws; previously this terminated via
// an uncaught exception).
int main() {
	try {
		constexpr unsigned n = 16;
		GPUMem<__nv_fp8_e4m3> A(n*n);
		GPUMem<__nv_fp8_e4m3> B(n*n);
		GPUMem<__nv_fp8_e4m3> C(n*n);
		GPUMem<float> weights(5); // A/B/C/D scale factors + amax(D) slot
		{
			// Init: A = 1, 2, ..., n*n; B = all 0.2; all scales = 1.
			std::vector<float> adat(n*n);
			std::iota(adat.begin(), adat.end(), 1.0f);
			full_fp8_e4m3(adat, A);
			std::vector<float> bdat(n*n, .2f);
			full_fp8_e4m3(bdat, B);
			weights.toGPU(std::vector<float>(weights.size(), 1.0f));
		}
		gemm_blaslt(A, B, C, n, n, n, weights);
		get_fp8_e4m3(C); // result currently discarded; kept for parity with original
	} catch (const std::exception& e) {
		std::cerr << "FP8 matmul demo failed: " << e.what() << std::endl;
		return 1;
	}
	return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment