#include <cublas_v2.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include <sstream>
#include <stdexcept>

// CUTLASS_CHECK lives in the CUTLASS examples rather than the library headers,
// so it is defined here: convert a failing cutlass::Status into an exception.
#define CUTLASS_CHECK(status)                                        \
  {                                                                  \
    cutlass::Status error = status;                                  \
    if (error != cutlass::Status::kSuccess) {                        \
      std::ostringstream oss;                                        \
      oss << "CUTLASS error: " << cutlassGetStatusString(error);     \
      throw std::runtime_error(oss.str());                           \
    }                                                                \
  }
// The code section below describes the matrix layout of the input and output
// matrices: column-major for matrix A, row-major for matrix B, and row-major
// for matrices C and D.
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;
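// Note on layouts: a column-major (M, K) tensor has the same memory layout as
// a row-major (K, M) tensor, so a sketch of producing a suitable A (assuming a
// contiguous torch tensor `A_row` of shape (K, M); the name is illustrative):
//
//   torch::Tensor A_col = A_row.t(); // (M, K) view over column-major storage
//
// B and C can be passed as ordinary contiguous (row-major) torch tensors.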
// The code section below describes the data types of the input and output
// matrices and of the computation between elements of the input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
// CUTLASS expects its own 16-bit type; torch's at::Half has the identical bit
// layout and is reinterpret_cast to cutlass::half_t inside bmm_fp16_cutlass.
using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
using MMAOp = cutlass::arch::OpClassTensorOp; // <- use Tensor Core MMA operations
using SmArch = cutlass::arch::Sm80; // <- target NVIDIA Ampere GPUs (e.g., A100)
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
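// A worked check of the tile hierarchy: the 128x128 threadblock tile divided
// into 64x64 warp tiles gives (128/64) * (128/64) = 4 warps per threadblock.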
// This code section describes the size of a single Tensor Core MMA instruction
// (16x8x16 is the fp16 Tensor Core instruction shape on SM80)
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;
// This code section describes the epilogue: an elementwise linear combination
// D = alpha * accumulator + beta * C
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value, // <- number of elements per
                                                      // vectorized memory access. For the
                                                      // float output used here this is
                                                      // 128 / 32 = 4 elements (one 128-bit
                                                      // access). This also becomes the
                                                      // vector width of math instructions
                                                      // in the epilogue
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
// Combine all the template parameters above into the GemmSplitKParallel type
using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
                                                       LayoutInputA,
                                                       ElementInputB,
                                                       LayoutInputB,
                                                       ElementOutput,
                                                       LayoutOutput,
                                                       ElementAccumulator,
                                                       MMAOp,
                                                       SmArch,
                                                       ShapeMMAThreadBlock,
                                                       ShapeMMAWarp,
                                                       ShapeMMAOp,
                                                       EpilogueOp>;
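// Background on split-K parallel GEMM: the K dimension is split into
// split_k_slices chunks, each chunk's threadblocks write partial sums to a
// workspace buffer, and a separate reduction kernel combines the partials
// into D. This trades extra memory traffic for more parallelism, which helps
// when M and N are small relative to K.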
// Computes D = alpha * (A @ B) with fp16 inputs and fp32 accumulation/output,
// using a split-K parallel CUTLASS kernel. Despite the name, this is a single
// (non-batched) GEMM. It currently only works on SM80 GPUs such as the A100.
torch::Tensor bmm_fp16_cutlass(torch::Tensor& A, torch::Tensor& B, torch::Tensor& C, float alpha) {
  const int length_m = A.size(0); // A is (M, K)
  const int length_k = A.size(1);
  const int length_n = B.size(1); // B is (K, N); C and D are (M, N)
  auto D = torch::empty({length_m, length_n}, torch::dtype(torch::kFloat32).device(A.device()));
  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
  // Split the K dimension into 16 slices, reduced by a second kernel
  int split_k_slices = 16;
  // Initialize alpha and beta for the epilogue's linear combination; beta is
  // fixed to 0, so C does not contribute to the output here
  ElementComputeEpilogue alpha_ep = ElementComputeEpilogue(alpha); // <- caller-supplied alpha
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);
  // torch stores half precision as at::Half; reinterpret as cutlass::half_t,
  // which has the identical 16-bit layout
  auto a_ptr = reinterpret_cast<ElementInputA const*>(A.data_ptr<at::Half>());
  auto b_ptr = reinterpret_cast<ElementInputB const*>(B.data_ptr<at::Half>());
  // Create a tuple of gemm kernel arguments. This is later passed to launch the
  // instantiated CUTLASS kernel. Each {pointer, leading dimension} pair forms a
  // TensorRef: ld = M for the column-major A, ld = N for the row-major B, C, D.
  typename Gemm::Arguments arguments{
      problem_size,                             // <- problem size of matrix multiplication
      {a_ptr, length_m},                        // <- reference to matrix A on device
      {b_ptr, length_n},                        // <- reference to matrix B on device
      {C.data_ptr<ElementOutput>(), length_n},  // <- reference to matrix C on device
      {D.data_ptr<ElementOutput>(), length_n},  // <- reference to matrix D on device
      {alpha_ep, beta},                         // <- tuple of alpha and beta
      split_k_slices};                          // <- k-dimension split factor
  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
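  // (The workspace holds the per-slice partial accumulators, roughly
  // M * N * split_k_slices floats, which the reduction kernel then consumes.)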
  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm_op;
  // Initialize CUTLASS kernel with arguments and workspace pointer
  cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);
  // Launch initialized CUTLASS kernel
  status = gemm_op();
  CUTLASS_CHECK(status);
  return D;
}
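
// ---------------------------------------------------------------------------
// A minimal usage sketch, not part of the kernel code above. It assumes the
// file is linked against libtorch and run on an SM80 GPU; the shapes and the
// helper name `example_usage` are illustrative.
void example_usage() {
  const int M = 512, N = 512, K = 512;
  // A must be column-major (M, K): transposing a contiguous (K, M) tensor
  // yields exactly that memory layout (see the layout note near the top).
  torch::Tensor A =
      torch::randn({K, M}, torch::dtype(torch::kHalf).device(torch::kCUDA)).t();
  torch::Tensor B =
      torch::randn({K, N}, torch::dtype(torch::kHalf).device(torch::kCUDA));
  torch::Tensor C =
      torch::zeros({M, N}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
  torch::Tensor D = bmm_fp16_cutlass(A, B, C, /*alpha=*/1.0f);
  std::cout << "D mean: " << D.mean().item<float>() << std::endl;
}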