Last active
June 27, 2024 20:09
-
-
Save dfyz/e49664a6f3610f5d1fcf1539984b8a4f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// nvcc -O2 -std=c++17 -gencode=arch=compute_80,code=sm_80 -I .../cutlass/include -I .../cutlass/tools/util/include --expt-relaxed-constexpr -lcublas -o main main.cu
#include <array>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

#include <cuda_runtime.h>
#include <cublas_v2.h>

#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_grouped.h>
#include <cutlass/gemm/kernel/gemm_grouped.h>
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#include <cutlass/util/host_tensor.h>
// Evaluates a CUDA runtime call exactly once; on any result other than
// cudaSuccess, prints the source location and the runtime's error string to
// stderr and terminates the process with the error code as the exit status.
//
// Wrapped in do { ... } while (0) so that `CUDA_OR_DIE(x);` behaves as a
// single statement (safe inside an unbraced if/else). The local carries a
// macro-specific name so it cannot shadow a variable used in `call`.
#define CUDA_OR_DIE(call)                                                               \
    do {                                                                                \
        const cudaError_t cudaOrDieStatus_ = (call);                                    \
        if (cudaOrDieStatus_ != cudaSuccess) {                                          \
            std::cerr << "CUDA Error at: " << __FILE__ << ":" << __LINE__ << std::endl; \
            std::cerr << " " << cudaGetErrorString(cudaOrDieStatus_) << std::endl;      \
            exit(cudaOrDieStatus_);                                                     \
        }                                                                               \
    } while (0)
// Evaluates a cuBLAS call exactly once; on any result other than
// CUBLAS_STATUS_SUCCESS, prints the source location plus the status (both the
// library's description and the raw numeric code) to stderr and exits with 1.
//
// Wrapped in do { ... } while (0) so that `CUBLAS_OR_DIE(x);` behaves as a
// single statement (safe inside an unbraced if/else). The local carries a
// macro-specific name so it cannot shadow a variable used in `call`.
#define CUBLAS_OR_DIE(call)                                                               \
    do {                                                                                  \
        const cublasStatus_t cublasOrDieStatus_ = (call);                                 \
        if (cublasOrDieStatus_ != CUBLAS_STATUS_SUCCESS) {                                \
            std::cerr << "cuBLAS Error at: " << __FILE__ << ":" << __LINE__ << std::endl; \
            std::cerr << " " << cublasGetStatusString(cublasOrDieStatus_)                 \
                      << " (" << cublasOrDieStatus_ << ")" << std::endl;                  \
            exit(1);                                                                      \
        }                                                                                 \
    } while (0)
// Element types: all three matrices are bfloat16, accumulation is in float.
using ElementA = cutlass::bfloat16_t;
using ElementB = cutlass::bfloat16_t;
using ElementC = cutlass::bfloat16_t;
using ElementAccumulator = float;

// Memory layouts: A is column-major, B and C are row-major.
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;

// Operator class and architecture tag handed to the CUTLASS templates.
using OpClass = cutlass::arch::OpClassTensorOp;
using Arch = cutlass::arch::Sm80;

// CUTLASS's default GEMM configuration for the combination above. main() reads
// the threadblock/warp/instruction shapes, the epilogue op, the stage count and
// the A/B alignments from it.
using Config = cutlass::gemm::device::DefaultGemmConfiguration<
    OpClass, Arch, ElementA, ElementB, ElementC, ElementAccumulator>;
// Allocates a device buffer large enough for every element of `hostSrc` and
// copies the vector's contents into it with a blocking host-to-device memcpy.
//
// Returns the device pointer typed as T*. This function never frees the
// buffer; releasing it is left to the caller. Any CUDA failure terminates the
// process through CUDA_OR_DIE.
template <typename T>
T* CopyToDevice(const std::vector<T>& hostSrc) {
    const size_t byteCount = sizeof(T) * hostSrc.size();
    void* rawDevice;
    CUDA_OR_DIE(cudaMalloc(&rawDevice, byteCount));
    CUDA_OR_DIE(cudaMemcpy(rawDevice, hostSrc.data(), byteCount, cudaMemcpyHostToDevice));
    return static_cast<T*>(rawDevice);
}
// Runs one GEMM experiment selected by argv[1] and then scans an M x N host
// buffer, reporting the first element that is not exactly zero.
//
// Modes:
//   nop             - launch nothing; the host buffer keeps its sentinel fill.
//   cublas          - single cublasGemmEx call with k == 0 writing into `out`.
//   cutlass         - single cutlass::gemm::device::Gemm call with k == 0
//                     writing into `out`.
//   cublas_grouped  - cublasGemmGroupedBatchedEx over the numKS problems below.
//   cutlass_grouped - cutlass::gemm::device::GemmGrouped over the same problems.
//
// The grouped modes copy back the C region that belongs to the last problem
// (the one with k == 0) before the scan.
//
// Returns 1 on a usage error or an unknown mode, 0 otherwise. CUDA and cuBLAS
// failures terminate the process via CUDA_OR_DIE / CUBLAS_OR_DIE; CUTLASS
// failures are reported on stderr and execution continues.
int main(int argc, char** argv) {
    if (argc != 2) {
        std::cerr << "Usage: " << argv[0] << " nop|cublas|cutlass|cublas_grouped|cutlass_grouped" << std::endl;
        return 1;
    }

    // Reject an unknown mode before touching the GPU so that the early return
    // does not leave device allocations and the stream behind.
    const std::string mode = argv[1];
    const bool knownMode = mode == "nop" || mode == "cublas" || mode == "cutlass" ||
                           mode == "cublas_grouped" || mode == "cutlass_grouped";
    if (!knownMode) {
        std::cerr << "Unknown mode: " << mode << std::endl;
        return 1;
    }

    constexpr int M = 4096;
    constexpr int N = 14336;
    constexpr size_t outElements = static_cast<size_t>(M) * N;
    constexpr size_t outBytes = sizeof(cutlass::bfloat16_t) * outElements;

    cutlass::HostTensor<cutlass::bfloat16_t, cutlass::layout::RowMajor> out{{M, N}};
    auto* host_out = out.host_data();
    auto* device_out = out.device_data();
    // Fill every byte with 0x01 so the buffer starts out non-zero; the final
    // scan can then tell whether anything actually overwrote it with zeros.
    std::memset(host_out, 1, outBytes);
    out.sync_device();

    cudaStream_t stream;
    CUDA_OR_DIE(cudaStreamCreate(&stream));

    // Per-problem K extents; they add up to totalK, and the last one is zero.
    constexpr size_t numKS = 8;
    constexpr size_t totalK = 16384;
    const std::array<int, numKS> kS = {219, 2246, 5, 8103, 1, 1117, 4693, 0};

    // Backing storage shared by all problems. Its contents are not initialized
    // here.
    void* a;
    void* b;
    void* c;
    CUDA_OR_DIE(cudaMalloc(&a, sizeof(ElementA) * totalK * M));
    CUDA_OR_DIE(cudaMalloc(&b, sizeof(ElementB) * totalK * N));
    CUDA_OR_DIE(cudaMalloc(&c, sizeof(ElementC) * numKS * M * N));

    // The problem filling code is copied from the `grouped_gemm` library.
    std::vector<int64_t> ldaHost(numKS), offsetsA(numKS);
    std::vector<int64_t> ldbHost(numKS), offsetsB(numKS);
    std::vector<int64_t> ldcHost(numKS), offsetsC(numKS);
    int64_t elementsA = 0, elementsB = 0, elementsC = 0;
    std::vector<ElementA *> ptrAHost(numKS);
    std::vector<ElementB *> ptrBHost(numKS);
    std::vector<ElementC *> ptrCHost(numKS);
    std::vector<cutlass::gemm::GemmCoord> problem_sizes_host(numKS);
    for (size_t i = 0; i < numKS; ++i) {
        cutlass::gemm::GemmCoord problem{M, N, kS[i]};
        problem_sizes_host[i] = problem;

        ldaHost[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
        ldbHost[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
        ldcHost[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);

        offsetsA[i] = elementsA;
        offsetsB[i] = elementsB;
        offsetsC[i] = elementsC;

        if (kS[i] > 0) {
            ptrAHost[i] = static_cast<ElementA*>(a) + offsetsA[i];
            ptrBHost[i] = static_cast<ElementB*>(b) + offsetsB[i];
            ptrCHost[i] = static_cast<ElementC*>(c) + offsetsC[i];
        } else {
            // NOTE(review): a k == 0 problem gets null pointers, so neither
            // grouped path has an address to write its C tile to. The region of
            // `c` that is copied back later is therefore never initialized by
            // anything visible in this file -- confirm this is intended.
            ptrAHost[i] = nullptr;
            ptrBHost[i] = nullptr;
            ptrCHost[i] = nullptr;
        }

        // Widen before multiplying so the products cannot overflow `int`.
        elementsA += static_cast<int64_t>(problem.m()) * problem.k();
        elementsB += static_cast<int64_t>(problem.k()) * problem.n();
        elementsC += static_cast<int64_t>(problem.m()) * problem.n();
    }

    // C region of the last (k == 0) problem; the grouped modes copy it back.
    const ElementC* lastProblemC = static_cast<const ElementC*>(c) + offsetsC.back();

    auto lda = CopyToDevice(ldaHost);
    auto ldb = CopyToDevice(ldbHost);
    auto ldc = CopyToDevice(ldcHost);
    auto ptrA = CopyToDevice(ptrAHost);
    auto ptrB = CopyToDevice(ptrBHost);
    auto ptrC = CopyToDevice(ptrCHost);

    if (mode == "cublas") {
        cublasHandle_t handle;
        CUBLAS_OR_DIE(cublasCreate(&handle));
        CUBLAS_OR_DIE(cublasSetStream(handle, stream));

        float alpha = 1.0f;
        float beta = 0.0f;
        // k == 0 with null A and B: only the beta * C part can affect `out`.
        CUBLAS_OR_DIE(cublasGemmEx(
            handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            M,
            N,
            0,
            &alpha,
            nullptr,
            CUDA_R_16BF,
            M,
            nullptr,
            CUDA_R_16BF,
            N,
            &beta,
            device_out,
            CUDA_R_16BF,
            M,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT
        ));
        CUDA_OR_DIE(cudaStreamSynchronize(stream));
        CUBLAS_OR_DIE(cublasDestroy(handle));
        out.sync_host();
    } else if (mode == "cutlass") {
        using DeviceGemm = cutlass::gemm::device::Gemm<
            ElementA,
            LayoutA,
            ElementB,
            LayoutB,
            ElementC,
            LayoutC,
            ElementAccumulator,
            OpClass,
            Arch,
            Config::ThreadblockShape,
            Config::WarpShape,
            Config::InstructionShape,
            Config::EpilogueOutputOp,
            cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
            Config::kStages
        >;
        DeviceGemm gemm;
        auto status = gemm({
            {M, N, 0},
            {nullptr, M},
            {nullptr, N},
            {device_out, N},
            {device_out, N},
            {1.0f, 0.0f}
        }, nullptr, stream);
        if (status != cutlass::Status::kSuccess) {
            std::cerr << "Failed to run GEMM: " << cutlass::cutlassGetStatusString(status) << std::endl;
        }
        CUDA_OR_DIE(cudaStreamSynchronize(stream));
        out.sync_host();
    } else if (mode == "cublas_grouped") {
        cublasHandle_t handle;
        CUBLAS_OR_DIE(cublasCreate(&handle));
        CUBLAS_OR_DIE(cublasSetStream(handle, stream));

        std::vector<int> msHost(numKS), nsHost(numKS), ksHost(numKS);
        for (size_t ii = 0; ii < numKS; ++ii) {
            msHost[ii] = problem_sizes_host[ii].m();
            nsHost[ii] = problem_sizes_host[ii].n();
            ksHost[ii] = problem_sizes_host[ii].k();
        }
        std::vector<cublasOperation_t> aOpsHost(numKS, CUBLAS_OP_N);
        std::vector<cublasOperation_t> bOpsHost(numKS, CUBLAS_OP_T);
        // With CUBLAS_COMPUTE_32F the scaling factors are fp32, so the
        // alpha/beta arrays must hold floats rather than bfloat16 values.
        std::vector<float> alphasHost(numKS, 1.0f);
        std::vector<float> betasHost(numKS, 0.0f);
        std::vector<int> groupSizesHost(numKS, 1);
        // cuBLAS takes `int` leading dimensions; narrow the 64-bit strides.
        std::vector<int> ldaIntHost(ldaHost.begin(), ldaHost.end());
        std::vector<int> ldbIntHost(ldbHost.begin(), ldbHost.end());
        std::vector<int> ldcIntHost(ldcHost.begin(), ldcHost.end());

        CUBLAS_OR_DIE(cublasGemmGroupedBatchedEx(
            handle,
            aOpsHost.data(),
            bOpsHost.data(),
            msHost.data(),
            nsHost.data(),
            ksHost.data(),  // per-problem K (previously the M array was passed here)
            alphasHost.data(),
            (const void *const *)ptrA,
            CUDA_R_16BF,
            ldaIntHost.data(),
            (const void *const *)ptrB,
            CUDA_R_16BF,
            ldbIntHost.data(),
            betasHost.data(),
            (void *const *)ptrC,
            CUDA_R_16BF,
            ldcIntHost.data(),
            static_cast<int>(numKS),
            groupSizesHost.data(),
            CUBLAS_COMPUTE_32F
        ));
        CUDA_OR_DIE(cudaStreamSynchronize(stream));
        CUBLAS_OR_DIE(cublasDestroy(handle));
        CUDA_OR_DIE(cudaMemcpy(host_out, lastProblemC, outBytes, cudaMemcpyDeviceToHost));
    } else if (mode == "cutlass_grouped") {
        using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
            ElementA,
            LayoutA,
            cutlass::ComplexTransform::kNone,
            Config::kAlignmentA,
            ElementB,
            LayoutB,
            cutlass::ComplexTransform::kNone,
            Config::kAlignmentB,
            ElementC,
            LayoutC,
            ElementAccumulator,
            OpClass,
            Arch,
            Config::ThreadblockShape,
            Config::WarpShape,
            Config::InstructionShape,
            Config::EpilogueOutputOp,
            cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
            Config::kStages
        >::GemmKernel;
        using DeviceGemmGrouped = cutlass::gemm::device::GemmGrouped<GroupedGemmKernel>;

        auto problemSizes = CopyToDevice(problem_sizes_host);

        DeviceGemmGrouped gemm;
        typename DeviceGemmGrouped::EpilogueOutputOp::Params epOp(/*alpha=*/1.0f, /*beta=*/0.0f);
        auto status = gemm.initialize({
            problemSizes,
            (int)numKS,
            DeviceGemmGrouped::sufficient(),
            epOp,
            ptrA, ptrB, ptrC, ptrC,
            lda, ldb, ldc, ldc,
            nullptr,
        }, nullptr);
        if (status != cutlass::Status::kSuccess) {
            std::cerr << "Failed to initialize grouped GEMM: " << cutlass::cutlassGetStatusString(status) << std::endl;
        }
        status = gemm.run(stream);
        if (status != cutlass::Status::kSuccess) {
            std::cerr << "Failed to run grouped GEMM: " << cutlass::cutlassGetStatusString(status) << std::endl;
        }
        CUDA_OR_DIE(cudaStreamSynchronize(stream));
        CUDA_OR_DIE(cudaMemcpy(host_out, lastProblemC, outBytes, cudaMemcpyDeviceToHost));
        CUDA_OR_DIE(cudaFree(problemSizes));
    }
    // mode == "nop": nothing is launched and host_out keeps its sentinel fill.

    // Report only the first element that is not exactly zero.
    for (size_t ii = 0; ii < outElements; ++ii) {
        if (host_out[ii] != 0.0f) {
            std::cerr << "Mismatch at index " << ii << ": " << host_out[ii] << std::endl;
            break;
        }
    }

    // Release everything this function allocated on the device.
    CUDA_OR_DIE(cudaFree(ptrC));
    CUDA_OR_DIE(cudaFree(ptrB));
    CUDA_OR_DIE(cudaFree(ptrA));
    CUDA_OR_DIE(cudaFree(ldc));
    CUDA_OR_DIE(cudaFree(ldb));
    CUDA_OR_DIE(cudaFree(lda));
    CUDA_OR_DIE(cudaFree(c));
    CUDA_OR_DIE(cudaFree(b));
    CUDA_OR_DIE(cudaFree(a));
    CUDA_OR_DIE(cudaStreamDestroy(stream));
    return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment