#include <cublas_v2.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
#include <torch/types.h>
#include <c10/util/Half.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
#include <stdexcept>
// The code section below describes matrix layout of input and output matrices. Column Major for
// Matrix A, Row Major for Matrix B and Row Major for Matrix C
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm80;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<256, 128, 64>; // <- threadblock tile M = 256, N = 128, K = 64
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// This code section describes how threadblocks are scheduled on the GPU
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                    // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value, // <- number of elements per vectorized
                                                      //    memory access. For the float output
                                                      //    used here this is 4 (128 / 32). It also
                                                      //    becomes the vector width of math
                                                      //    instructions in the epilogue.
    ElementAccumulator,                               // <- data type of accumulator
    ElementComputeEpilogue>;                          // <- data type for alpha/beta in linear combination function
// Number of pipelines you want to use
constexpr int NumStages = 2;
// Put all the created template variables together to instantiate the device-level Gemm template
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                         LayoutInputA,
                                         ElementInputB,
                                         LayoutInputB,
                                         ElementOutput,
                                         LayoutOutput,
                                         ElementAccumulator,
                                         MMAOp,
                                         SmArch,
                                         ShapeMMAThreadBlock,
                                         ShapeMMAWarp,
                                         ShapeMMAOp,
                                         EpilogueOp,
                                         SwizzleThreadBlock,
                                         NumStages,
                                         128 / cutlass::sizeof_bits<ElementInputA>::value, // <- alignment of A (8 half elements)
                                         128 / cutlass::sizeof_bits<ElementInputB>::value, // <- alignment of B (8 half elements)
                                         true>; // <- enable serial split-K reduction, required because split_k_slices > 1 below
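
// Note: the gemm_splitk_parallel.h header included above provides an alternative
// device-level kernel, cutlass::gemm::device::GemmSplitKParallel, which performs
// the split-K reduction in parallel rather than serially. A minimal sketch is
// shown below (assuming the same leading template parameters as device::Gemm,
// as in CUTLASS example 06_splitk_gemm); it is not used by the code that follows:
//
// using GemmSplitKP = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
//                                                               LayoutInputA,
//                                                               ElementInputB,
//                                                               LayoutInputB,
//                                                               ElementOutput,
//                                                               LayoutOutput,
//                                                               ElementAccumulator,
//                                                               MMAOp,
//                                                               SmArch,
//                                                               ShapeMMAThreadBlock,
//                                                               ShapeMMAWarp,
//                                                               ShapeMMAOp,
//                                                               EpilogueOp>;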
// This function currently targets SM80 (A100) GPUs only
void gemm_fp16_cutlass(torch::Tensor& A,
                       torch::Tensor& B,
                       torch::Tensor& C,
                       torch::Tensor& D,
                       float alpha_val,
                       float beta_val)
{
    const int length_m = A.size(0);
    const int length_k = A.size(1);
    const int length_n = B.size(1);
    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
    // Initialize alpha and beta for the linear combination D = alpha * A * B + beta * C
    ElementComputeEpilogue alpha = ElementComputeEpilogue(alpha_val);
    ElementComputeEpilogue beta = ElementComputeEpilogue(beta_val);
    // Split the K dimension into 4 partitions (serial split-K reduction)
    int split_k_slices = 4;
    // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
    // the instantiated CUTLASS kernel
    typename Gemm::Arguments arguments{
        problem_size,                                                            // <- problem size of matrix multiplication
        {reinterpret_cast<cutlass::half_t*>(A.data_ptr<at::Half>()), length_m},  // <- ref to matrix A on device (column-major, leading dim M)
        {reinterpret_cast<cutlass::half_t*>(B.data_ptr<at::Half>()), length_n},  // <- ref to matrix B on device (row-major, leading dim N)
        {C.data_ptr<ElementOutput>(), length_n},                                 // <- ref to matrix C on device (row-major, leading dim N)
        {D.data_ptr<ElementOutput>(), length_n},                                 // <- ref to matrix D on device (row-major, leading dim N)
        {alpha, beta},                                                           // <- tuple of alpha and beta
        split_k_slices};                                                         // <- k-dimension split factor
    // Using the arguments, query for extra workspace required for matrix multiplication computation
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    // Instantiate CUTLASS kernel depending on templates
    Gemm gemm_op;
    // Initialize CUTLASS kernel with arguments and workspace pointer
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
    CUTLASS_CHECK(status);
    // Launch initialized CUTLASS kernel
    status = gemm_op();
    CUTLASS_CHECK(status);
}
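
// ---------------------------------------------------------------------------
// Example usage (a minimal sketch added for illustration, not part of the
// original kernel). It assumes this file is built as a PyTorch CUDA extension;
// the function name example_usage and the problem sizes below are arbitrary.
// Computes D = alpha * (A @ B) + beta * C with fp16 inputs and fp32 outputs.
// Note: LayoutInputA above is column-major, so the kernel reads A's buffer in
// column-major order; preparing A's data in that order is left to the caller.
#include <torch/extension.h> // for PYBIND11_MODULE / TORCH_EXTENSION_NAME (would normally sit at the top of the file)

void example_usage()
{
    const int M = 512, N = 256, K = 128;
    auto half_opts  = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    auto float_opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);

    torch::Tensor A = torch::randn({M, K}, half_opts);   // fp16 input, treated as column-major M x K
    torch::Tensor B = torch::randn({K, N}, half_opts);   // fp16 input, row-major K x N
    torch::Tensor C = torch::zeros({M, N}, float_opts);  // fp32 source matrix, row-major M x N
    torch::Tensor D = torch::empty({M, N}, float_opts);  // fp32 output matrix, row-major M x N

    gemm_fp16_cutlass(A, B, C, D, /*alpha_val=*/1.0f, /*beta_val=*/0.0f);
}

// Optional Python binding so the kernel can be called from PyTorch
// (the module name is supplied by the extension build via TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gemm_fp16_cutlass", &gemm_fp16_cutlass,
          "CUTLASS fp16 GEMM: D = alpha * A @ B + beta * C");
}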