#include <iostream>

#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"  // from the CUTLASS examples directory
// Half-precision type provided by CUTLASS
using cutlass::half_t;

// Runs C = alpha * A * B + beta * C on Tensor Cores with fp16 inputs/outputs
// and fp32 accumulation. Returns false if the kernel cannot be built or
// launched for this problem size.
bool gemm_fp16_tensorcore(
    int length_m,
    int length_n,
    int length_k,
    cutlass::HostTensor<half_t, cutlass::layout::RowMajor>& tensor_a,
    cutlass::HostTensor<half_t, cutlass::layout::ColumnMajor>& tensor_b,
    cutlass::HostTensor<half_t, cutlass::layout::RowMajor>& tensor_c) {
  // Element types: fp16 inputs and output, fp32 accumulation and epilogue math
  using ElementAccumulator = float;
  using ElementComputeEpilogue = float;
  using ElementInputA = half_t;
  using ElementInputB = half_t;
  using ElementOutput = half_t;

  // Layouts: A is row-major, B is column-major, C/D are row-major
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

  // Target Tensor Core operations on Turing (SM75)
  using MMAOp = cutlass::arch::OpClassTensorOp;
  using SmArch = cutlass::arch::Sm75;

  // Tile shapes for an fp16 Tensor Core GEMM on SM75. Note: <128, 256, 64> /
  // <64, 64, 64> / <8, 8, 16> are the Turing int8 shapes; fp16 on SM75 uses
  // the m16n8k8 MMA instruction, and a K-tile of 32 keeps the double-buffered
  // shared-memory footprint within Turing's per-threadblock limit.
  using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 256, 32>;
  using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;

  // Threadblock swizzling and epilogue (D = alpha * accumulator + beta * C)
  using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,  // vector width of epilogue memory access
      ElementAccumulator,
      ElementComputeEpilogue>;

  // Number of pipeline stages (2 is the standard value for SM75)
  constexpr int NumStages = 2;
  // Compose the device-level GEMM kernel from the pieces above
  using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                           LayoutInputA,
                                           ElementInputB,
                                           LayoutInputB,
                                           ElementOutput,
                                           LayoutOutput,
                                           ElementAccumulator,
                                           MMAOp,
                                           SmArch,
                                           ShapeMMAThreadBlock,
                                           ShapeMMAWarp,
                                           ShapeMMAOp,
                                           EpilogueOp,
                                           SwizzleThreadBlock,
                                           NumStages>;
  // Problem size (M, N, K)
  cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);

  // Epilogue scalars: D = alpha * (A * B) + beta * C
  ElementComputeEpilogue alpha = 1.0f;
  ElementComputeEpilogue beta = 0.0f;

  // No split-K: the K dimension is processed by a single slice
  int split_k_slices = 1;

  typename Gemm::Arguments arguments{
      problem_size,
      tensor_a.device_ref(),
      tensor_b.device_ref(),
      tensor_c.device_ref(),  // source C
      tensor_c.device_ref(),  // destination D (written in place over C)
      {alpha, beta},
      split_k_slices};

  // Query and allocate any extra device workspace the kernel needs
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  Gemm gemm_op;

  // Verify that the kernel can handle this problem before launching it
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  // Launch the GEMM kernel
  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  // Wait for the kernel to finish, then copy the result back to the host
  cudaDeviceSynchronize();
  tensor_c.sync_host();

  return true;
}
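
// -----------------------------------------------------------------------------
// Minimal usage sketch (not part of the original gist): allocates the host
// tensors, fills A/B with ones and C with zeros, copies them to the device,
// and calls gemm_fp16_tensorcore. The problem sizes below are illustrative.
// -----------------------------------------------------------------------------
int main() {
  int length_m = 256, length_n = 256, length_k = 128;

  cutlass::HostTensor<half_t, cutlass::layout::RowMajor> tensor_a(
      cutlass::MatrixCoord(length_m, length_k));
  cutlass::HostTensor<half_t, cutlass::layout::ColumnMajor> tensor_b(
      cutlass::MatrixCoord(length_k, length_n));
  cutlass::HostTensor<half_t, cutlass::layout::RowMajor> tensor_c(
      cutlass::MatrixCoord(length_m, length_n));

  // With A and B filled with ones and beta = 0, each output element should equal K.
  cutlass::reference::host::TensorFill(tensor_a.host_view(), half_t(1.0f));
  cutlass::reference::host::TensorFill(tensor_b.host_view(), half_t(1.0f));
  cutlass::reference::host::TensorFill(tensor_c.host_view(), half_t(0.0f));

  // Copy host data to device memory before launching the kernel.
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();

  if (!gemm_fp16_tensorcore(length_m, length_n, length_k,
                            tensor_a, tensor_b, tensor_c)) {
    std::cerr << "CUTLASS GEMM failed" << std::endl;
    return -1;
  }

  // Spot-check one element of the result (expected value: length_k).
  std::cout << "C[0][0] = " << float(tensor_c.at(cutlass::MatrixCoord(0, 0)))
            << std::endl;
  return 0;
}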