Skip to content

Instantly share code, notes, and snippets.

@dfyz
Last active June 27, 2024 20:09
Show Gist options
  • Save dfyz/e49664a6f3610f5d1fcf1539984b8a4f to your computer and use it in GitHub Desktop.
Save dfyz/e49664a6f3610f5d1fcf1539984b8a4f to your computer and use it in GitHub Desktop.
// nvcc -O2 -std=c++17 -gencode=arch=compute_80,code=sm_80 -I .../cutlass/include -I .../cutlass/tools/util/include --expt-relaxed-constexpr -lcublas -o main main.cu
#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_grouped.h>
#include <cutlass/gemm/kernel/gemm_grouped.h>
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#include <cutlass/util/host_tensor.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <array>
#include <iostream>
#include <string>
#include <vector>
// Evaluates a CUDA runtime call exactly once. If the result is not cudaSuccess, prints the
// call site and cudaGetErrorString() of the result to stderr, then terminates the process
// using the error value as the exit code.
//
// The body is wrapped in do { ... } while (0) so the macro plus its trailing `;` forms one
// statement (safe inside an unbraced if/else). The temporary has a macro-specific name so an
// argument expression that mentions a variable called `status` is not captured by it.
#define CUDA_OR_DIE(call) do { \
    const cudaError_t cudaOrDieStatus_ = (call); \
    if (cudaOrDieStatus_ != cudaSuccess) { \
        std::cerr << "CUDA Error at: " << __FILE__ << ":" << __LINE__ << std::endl; \
        std::cerr << " " << cudaGetErrorString(cudaOrDieStatus_) << std::endl; \
        exit(cudaOrDieStatus_); \
    } \
} while (0)
// Evaluates a cuBLAS call exactly once. If the result is not CUBLAS_STATUS_SUCCESS, prints
// the call site and the numeric status to stderr, then terminates the process with exit
// code 1.
//
// The body is wrapped in do { ... } while (0) so the macro plus its trailing `;` forms one
// statement (safe inside an unbraced if/else). The temporary has a macro-specific name so an
// argument expression that mentions a variable called `status` is not captured by it.
#define CUBLAS_OR_DIE(call) do { \
    const cublasStatus_t cublasOrDieStatus_ = (call); \
    if (cublasOrDieStatus_ != CUBLAS_STATUS_SUCCESS) { \
        std::cerr << "cuBLAS Error at: " << __FILE__ << ":" << __LINE__ << std::endl; \
        std::cerr << " " << cublasOrDieStatus_ << std::endl; \
        exit(1); \
    } \
} while (0)
// Math pipeline and target architecture used for every CUTLASS instantiation below.
using OpClass = cutlass::arch::OpClassTensorOp;
using Arch    = cutlass::arch::Sm80;

// Operand A: bf16, column-major.
using ElementA = cutlass::bfloat16_t;
using LayoutA  = cutlass::layout::ColumnMajor;

// Operand B: bf16, row-major.
using ElementB = cutlass::bfloat16_t;
using LayoutB  = cutlass::layout::RowMajor;

// Output C: bf16, row-major.
using ElementC = cutlass::bfloat16_t;
using LayoutC  = cutlass::layout::RowMajor;

// Accumulation happens in float.
using ElementAccumulator = float;

// CUTLASS default configuration for this operator class / architecture / element-type
// combination. Its nested members (tile shapes, epilogue op, alignments, stage count) are
// what the kernel instantiations in main() read.
using Config = cutlass::gemm::device::DefaultGemmConfiguration<
    OpClass,
    Arch,
    ElementA,
    ElementB,
    ElementC,
    ElementAccumulator
>;
// Allocates a device buffer large enough for every element of `hostSrc`, copies the vector's
// contents into it with a blocking host-to-device cudaMemcpy, and returns the device pointer.
// Nothing in this function releases the buffer; that is left to the caller.
// Any CUDA failure terminates the process via CUDA_OR_DIE.
template <typename T>
T* CopyToDevice(const std::vector<T>& hostSrc) {
    const size_t byteCount = hostSrc.size() * sizeof(T);

    void* rawDevice;
    CUDA_OR_DIE(cudaMalloc(&rawDevice, byteCount));
    CUDA_OR_DIE(cudaMemcpy(rawDevice, hostSrc.data(), byteCount, cudaMemcpyHostToDevice));

    return static_cast<T*>(rawDevice);
}
int main(int argc, char** argv) {
if (argc != 2) {
std::cerr << "Usage: " << argv[0] << " nop|cublas|cutlass|cublas_grouped|cutlass_grouped" << std::endl;
return 1;
}
constexpr int M = 4096;
constexpr int N = 14336;
cutlass::HostTensor<cutlass::bfloat16_t, cutlass::layout::RowMajor> out{{M, N}};
auto* host_out = out.host_data();
auto* device_out = out.device_data();
memset(host_out, 1, sizeof(cutlass::bfloat16_t) * M * N);
out.sync_device();
cudaStream_t stream;
CUDA_OR_DIE(cudaStreamCreate(&stream));
constexpr size_t numKS = 8;
constexpr size_t totalK = 16384;
std::array<int, numKS> kS = {219, 2246, 5, 8103, 1, 1117, 4693, 0};
void* a;
void* b;
void* c;
CUDA_OR_DIE(cudaMalloc(&a, sizeof(ElementA) * totalK * M));
CUDA_OR_DIE(cudaMalloc(&b, sizeof(ElementB) * totalK * N));
CUDA_OR_DIE(cudaMalloc(&c, sizeof(ElementC) * numKS * M * N));
// The problem filling code is copied from the `grouped_gemm` library.
std::vector<int64_t> ldaHost(numKS), offsetsA(numKS);
std::vector<int64_t> ldbHost(numKS), offsetsB(numKS);
std::vector<int64_t> ldcHost(numKS), offsetsC(numKS);
int64_t elementsA = 0, elementsB = 0, elementsC = 0;
std::vector<ElementA *> ptrAHost(numKS);
std::vector<ElementB *> ptrBHost(numKS);
std::vector<ElementC *> ptrCHost(numKS);
std::vector<cutlass::gemm::GemmCoord> problem_sizes_host(numKS);
for (int i = 0; i < numKS; ++i) {
cutlass::gemm::GemmCoord problem{M, N, kS[i]};
problem_sizes_host[i] = problem;
ldaHost[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
ldbHost[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
ldcHost[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
offsetsA[i] = elementsA;
offsetsB[i] = elementsB;
offsetsC[i] = elementsC;
if (kS[i] > 0) {
ptrAHost[i] = (ElementA*)a + offsetsA[i];
ptrBHost[i] = (ElementB*)b + offsetsB[i];
ptrCHost[i] = (ElementC*)c + offsetsC[i];
} else {
ptrAHost[i] = nullptr;
ptrBHost[i] = nullptr;
ptrCHost[i] = nullptr;
}
elementsA += problem.m() * problem.k();
elementsB += problem.k() * problem.n();
elementsC += problem.m() * problem.n();
}
auto lda = CopyToDevice(ldaHost);
auto ldb = CopyToDevice(ldbHost);
auto ldc = CopyToDevice(ldcHost);
auto ptrA = CopyToDevice(ptrAHost);
auto ptrB = CopyToDevice(ptrBHost);
auto ptrC = CopyToDevice(ptrCHost);
std::string mode = argv[1];
if (mode == "cublas") {
cublasHandle_t handle;
CUBLAS_OR_DIE(cublasCreate(&handle));
CUBLAS_OR_DIE(cublasSetStream(handle, stream));
float alpha = 1.0f;
float beta = 0.0f;
CUBLAS_OR_DIE(cublasGemmEx(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
M,
N,
0,
&alpha,
nullptr,
CUDA_R_16BF,
M,
nullptr,
CUDA_R_16BF,
N,
&beta,
device_out,
CUDA_R_16BF,
M,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT
));
CUDA_OR_DIE(cudaStreamSynchronize(stream));
CUBLAS_OR_DIE(cublasDestroy(handle));
out.sync_host();
}
else if (mode == "cutlass") {
using DeviceGemm = cutlass::gemm::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
float,
OpClass,
Arch,
Config::ThreadblockShape,
Config::WarpShape,
Config::InstructionShape,
Config::EpilogueOutputOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
Config::kStages
>;
DeviceGemm gemm;
auto status = gemm({
{M, N, 0},
{nullptr, M},
{nullptr, N},
{device_out, N},
{device_out, N},
{1.0f, 0.0f}
}, nullptr, stream);
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to run GEMM: " << cutlass::cutlassGetStatusString(status) << std::endl;
}
CUDA_OR_DIE(cudaStreamSynchronize(stream));
out.sync_host();
} else if (mode == "cublas_grouped") {
cublasHandle_t handle;
CUBLAS_OR_DIE(cublasCreate(&handle));
CUBLAS_OR_DIE(cublasSetStream(handle, stream));
std::vector<int> msHost(numKS), nsHost(numKS), ksHost(numKS);
for (size_t ii = 0; ii < numKS; ++ii) {
msHost[ii] = problem_sizes_host[ii].m();
nsHost[ii] = problem_sizes_host[ii].n();
ksHost[ii] = problem_sizes_host[ii].k();
}
std::vector<cublasOperation_t> aOpsHost(numKS, CUBLAS_OP_N);
std::vector<cublasOperation_t> bOpsHost(numKS, CUBLAS_OP_T);
std::vector<cutlass::bfloat16_t> alphasHost(numKS, (cutlass::bfloat16_t)1.0f);
std::vector<cutlass::bfloat16_t> betasHost(numKS, (cutlass::bfloat16_t)0.0f);
std::vector<int> groupSizesHost(numKS, 1);
std::vector<int> ldaIntHost(ldaHost.begin(), ldaHost.end());
std::vector<int> ldbIntHost(ldbHost.begin(), ldbHost.end());
std::vector<int> ldcIntHost(ldcHost.begin(), ldcHost.end());
CUBLAS_OR_DIE(cublasGemmGroupedBatchedEx(
handle,
aOpsHost.data(),
bOpsHost.data(),
msHost.data(),
nsHost.data(),
msHost.data(),
alphasHost.data(),
(const void *const *)ptrA,
CUDA_R_16BF,
ldaIntHost.data(),
(const void *const *)ptrB,
CUDA_R_16BF,
ldbIntHost.data(),
betasHost.data(),
(void *const *)ptrC,
CUDA_R_16BF,
ldcIntHost.data(),
numKS,
groupSizesHost.data(),
CUBLAS_COMPUTE_32F
));
CUDA_OR_DIE(cudaStreamSynchronize(stream));
CUBLAS_OR_DIE(cublasDestroy(handle));
CUDA_OR_DIE(cudaMemcpy(host_out, (ElementC*)c + 7 * M * N, sizeof(cutlass::bfloat16_t) * M * N, cudaMemcpyDeviceToHost));
} else if (mode == "cutlass_grouped") {
using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
ElementA,
LayoutA,
cutlass::ComplexTransform::kNone,
Config::kAlignmentA,
ElementB,
LayoutB,
cutlass::ComplexTransform::kNone,
Config::kAlignmentB,
ElementC,
LayoutC,
float,
OpClass,
Arch,
Config::ThreadblockShape,
Config::WarpShape,
Config::InstructionShape,
Config::EpilogueOutputOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
Config::kStages
>::GemmKernel;
using DeviceGemmGrouped = cutlass::gemm::device::GemmGrouped<GroupedGemmKernel>;
auto problemSizes = CopyToDevice(problem_sizes_host);
DeviceGemmGrouped gemm;
typename DeviceGemmGrouped::EpilogueOutputOp::Params epOp(/*alpha=*/1.0f, /*beta=*/0.0f);
auto status = gemm.initialize({
problemSizes,
(int)numKS,
DeviceGemmGrouped::sufficient(),
epOp,
ptrA, ptrB, ptrC, ptrC,
lda, ldb, ldc, ldc,
nullptr,
}, nullptr);
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to initialize grouped GEMM: " << cutlass::cutlassGetStatusString(status) << std::endl;
}
status = gemm.run(stream);
if (status != cutlass::Status::kSuccess) {
std::cerr << "Failed to run grouped GEMM: " << cutlass::cutlassGetStatusString(status) << std::endl;
}
CUDA_OR_DIE(cudaStreamSynchronize(stream));
CUDA_OR_DIE(cudaMemcpy(host_out, (ElementC*)c + 7 * M * N, sizeof(cutlass::bfloat16_t) * M * N, cudaMemcpyDeviceToHost));
} else {
if (mode != "nop") {
std::cerr << "Unknown mode: " << mode << std::endl;
return 1;
}
}
for (size_t ii = 0; ii < M * N; ++ii) {
if (host_out[ii] != 0.0f) {
std::cerr << "Mismatch at index " << ii << ": " << host_out[ii] << std::endl;
break;
}
}
CUDA_OR_DIE(cudaStreamDestroy(stream));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment