@tiandiao123
Created August 8, 2023 08:17
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm_batched.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/util/device_memory.h>

// Batched GEMM: C[b] = alpha * A[b] @ B[b], with fp16 inputs and fp32 accumulation/output.
// Expects A of shape [batch, M, K] (row-major) and B of shape [batch, N, K], which
// CUTLASS reads as a column-major K x N operand.
torch::Tensor bmm_fp16_fp16_f32(torch::Tensor A, torch::Tensor B, float alpha) {
    TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
    TORCH_CHECK(A.scalar_type() == torch::kHalf && B.scalar_type() == torch::kHalf,
                "A and B must be fp16");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous(), "A and B must be contiguous");
    TORCH_CHECK(A.size(2) == B.size(2), "A and B must share the same K dimension");

    int batch_size = A.size(0);
    int M = A.size(1);
    int N = B.size(1);
    int K = A.size(2);

    auto C = torch::empty({batch_size, M, N},
                          torch::dtype(torch::kFloat32).device(A.device()));

    // Leading dimensions of the contiguous innermost axis of each operand.
    int lda = A.size(2);  // K, row-major M x K
    int ldb = B.size(2);  // K, column-major K x N
    int ldc = C.size(2);  // N, row-major M x N

    // A is row-major [M, K]; B is interpreted as column-major [K, N]
    // (i.e., a PyTorch-contiguous [N, K] tensor); C is row-major [M, N].
    using LayoutInputA = cutlass::layout::RowMajor;
    using LayoutInputB = cutlass::layout::ColumnMajor;
    using LayoutOutput = cutlass::layout::RowMajor;

    using ElementInputA = cutlass::half_t;   // CUTLASS fp16 type, bit-compatible with at::Half
    using ElementInputB = cutlass::half_t;
    using ElementOutput = float;
    using ElementAccumulator = float;        // fp32 accumulation is the usual choice for fp16 GEMM
    using ElementComputeEpilogue = float;

    // Epilogue: D = alpha * accumulator + beta * C, vectorized to 128-bit accesses.
    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementComputeEpilogue>;

    // Tensor-core batched GEMM for SM80. The MMA instruction shape for fp16 on
    // Ampere is m16n8k16 (16x8x32 is an int8 shape).
    using Gemm = cutlass::gemm::device::GemmBatched<
        ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, ElementOutput,
        LayoutOutput, ElementAccumulator, cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<256, 128, 64>,  // threadblock tile
        cutlass::gemm::GemmShape<64, 64, 64>,    // warp tile
        cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
        EpilogueOp>;

    // Distance (in elements) between consecutive matrices within each batched tensor.
    long long int batch_stride_A = static_cast<long long int>(M) * K;
    long long int batch_stride_B = static_cast<long long int>(N) * K;
    long long int batch_stride_C = static_cast<long long int>(M) * N;

    Gemm gemm_op;

    // at::Half and cutlass::half_t share the same bit layout, so the device
    // pointers can be reinterpreted directly.
    typename Gemm::Arguments arguments{
        {M, N, K},
        {reinterpret_cast<ElementInputA const *>(A.data_ptr<at::Half>()), lda},
        batch_stride_A,
        {reinterpret_cast<ElementInputB const *>(B.data_ptr<at::Half>()), ldb},
        batch_stride_B,
        {C.data_ptr<ElementOutput>(), ldc},
        batch_stride_C,
        {C.data_ptr<ElementOutput>(), ldc},
        batch_stride_C,
        {alpha, 0.0f},  // epilogue scalars: alpha, beta (beta = 0 ignores the C source)
        batch_size};

    // Query and allocate any scratch space the kernel needs.
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("CUTLASS cannot implement this problem (check shapes and alignment)");
    }

    status = gemm_op.initialize(arguments, workspace.get());
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("CUTLASS kernel initialization failed");
    }

    // Launch on PyTorch's current CUDA stream rather than the default stream.
    status = gemm_op(at::cuda::getCurrentCUDAStream());
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("CUTLASS kernel launch failed");
    }

    return C;
}
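
// A minimal sketch of how this kernel could be exposed to Python as a torch
// extension. The Python-visible function name and the module name used in the
// usage example below are illustrative assumptions, not part of the original gist.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("bmm_fp16_fp16_f32", &bmm_fp16_fp16_f32,
          "Batched GEMM (fp16 x fp16 -> fp32) via CUTLASS",
          py::arg("A"), py::arg("B"), py::arg("alpha") = 1.0f);
}

// Example Python usage, assuming the extension is built and importable as
// `bmm_cutlass` (note B is laid out as [batch, N, K]):
//
//   import torch, bmm_cutlass
//   A = torch.randn(8, 128, 64, device="cuda", dtype=torch.float16)
//   B = torch.randn(8, 256, 64, device="cuda", dtype=torch.float16)
//   C = bmm_cutlass.bmm_fp16_fp16_f32(A, B, 1.0)  # shape [8, 128, 256], float32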