@kadeng
Created December 20, 2023 17:53
Cutlass Error repro cases
#!/bin/bash
# Change the environment variables below to point to your CUTLASS and CUDA Toolkit installations,
# then run this script, passing one of the standalone repro_N.cu files as an argument.
# It will compile and run that example.
set -x
export REPRO_CUTLASS_PATH=/home/klondenberg/github/pytorch/pytorch/third_party/cutlass
export REPRO_CUDA_PATH=/home/klondenberg/local/cuda121
$REPRO_CUDA_PATH/bin/nvcc -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -w -gencode=arch=compute_90a,code=[sm_90a,compute_90a] -O1 -std=c++17 --expt-relaxed-constexpr -Xcompiler=-fPIC --use_fast_math -Xcompiler=-fno-strict-aliasing -Xcompiler -fvisibility=hidden -Xcompiler=-Wconversion -I${REPRO_CUTLASS_PATH}/include -I${REPRO_CUTLASS_PATH}/tools/library/include -I${REPRO_CUTLASS_PATH}/tools/library/src -I${REPRO_CUTLASS_PATH}/tools/util/include -L${REPRO_CUDA_PATH}/lib64 -L${REPRO_CUDA_PATH}/lib64/stubs -lcuda -lcudart -DGENERATE_STANDALONE_RUNNER -DNDEBUG -DCUTLASS_DEBUG_TRACE_LEVEL=1 -o "${@}.exe" "$@"
"./${@}.exe"
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using bfloat16 = nv_bfloat16;
using namespace cute;
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \
cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
throw std::runtime_error(msg); \
} \
}
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
CUTLASS_HOST_DEVICE
T operator()(T val) const { return val; }
};
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecializedCooperative;
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
"Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = float;
using ElementD = cutlass::half_t;
using ElementC = cutlass::half_t;
using TileShapeMNK = cute::Shape<cute::_128, cute::_256, cute::_64>;
using ClusterShapeMNK = cute::Shape<cute::_2,cute::_1,cute::_1>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShapeMNK,
EpilogueTileType,
ElementC,
ElementD,
EpilogueScheduleType
>;
using ADDMM_EVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
ElementD, ElementAcc, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc,
ElementAcc, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
>>;
using EVT_expr_1 = ADDMM_EVT;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue_functor = cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,EVT_expr_1>;
;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShapeMNK,
ClusterShapeMNK,
EpilogueTileType,
float, float,
cutlass::half_t, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::ColumnMajor, 8,
EpilogueScheduleType,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue_functor
>::CollectiveOp;
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
cute::Shape<cute::_128, cute::_256, cute::_64>,
cute::Shape<cute::_2,cute::_1,cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_mainloop,
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_epilogue,
cutlass::gemm::StreamKScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma :
public cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma>;
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT int cuda_cutlass_gemm_1(const half* Bias, const half* X, const half* W, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
try {
int64_t B = 1;
int64_t M = 1024L;
int64_t K = 256L;
int64_t N = 109760L;
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::ElementAccumulator;
using coord_t = cutlass::gemm::GemmCoord::Index;
static cutlass::KernelHardwareInfo hw_info;
if (hw_info.sm_count == 0) {
// @TODO kadeng: Add support for Multi-GPU machines with heterogeneous SM counts
// for now we just pick the SM count of the first GPU
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
}
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type::Arguments arguments;
// Initialize GemmUniversal3xInstance arguments.
arguments = {
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
{
static_cast<coord_t>(N),
static_cast<coord_t>(M),
static_cast<coord_t>(K),
static_cast<coord_t>(B)
}, // ProblemShape problem_shape
{
(cutlass::half_t*)(W), // ElementA const* ptr_A
{
256L /* stride_x0 */,
cute::Int<1>{} /* stride_x1 */,
0 /* batch_stride_x */
}, // StrideA dA
(cutlass::half_t*)(X), // ElementB const* ptr_B
{
cute::Int<1>{} /* stride_w1 */,
1024L /* stride_w0 */,
0 /* batch_stride_w */
}, // StrideB dB
}, // MainloopArguments mainloop
// see https://tinyurl.com/4rk89z48
{
{
{ // ADDMM Arguments: ternary op : beta * C + (alpha * acc)
{{static_cast<ElementAcc>(1.000000)}}, // leaf op+args : beta
{}, // leaf op+args : C
{ // binary op : alpha * acc
{{static_cast<ElementAcc>(1.000000)}}, // leaf op+args : alpha
{}, // leaf op+args : acc
{} // binary args : multiplies
}, // end binary op
{} // ternary args : multiply_add
} // end ternary op
}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
(cutlass::half_t*)(Bias), // ElementC const* ptr_C
{
cute::Int<1>{} /* stride_bias0 */,
0L /* stride_bias1 */,
0 /* batch_stride_bias */
}, // StrideC dC
(cutlass::half_t*)(Y), // ElementD const* ptr_D
{
cute::Int<1>{} /* stride_y0 */,
109760L /* stride_y1 */,
0 /* batch_stride_y */
}, // StrideD dD
}, // EpilogueArguments epilogue,
hw_info
};
cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma_device_type gemm_op;
if (workspace_size) {
*workspace_size = gemm_op.get_workspace_size(arguments);
return 0;
}
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
{
if (!X) {
int64_t X_size = 262144L;
if (X_size > 0) {
throw std::runtime_error("input X is null but size is not 0!");
}
}
}
{
if (!W) {
int64_t W_size = 28098560L;
if (W_size > 0) {
throw std::runtime_error("input W is null but size is not 0!");
}
}
}
{
if (!Bias) {
int64_t Bias_size = 112394240L;
if (Bias_size > 0) {
throw std::runtime_error("input Bias is null but size is not 0!");
}
}
}
{
if (!Y) {
int64_t Y_size = 112394240L;
if (Y_size > 0) {
throw std::runtime_error("input Y is null but size is not 0!");
}
}
}
{
auto status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
}
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
{
// Query the maximum number of active blocks per SM for this kernel. With
// CUTLASS_DEBUG_TRACE_LEVEL == 1 the result is printed inside the function itself,
// so no explicit print statement is needed here.
gemm_op.maximum_active_blocks();
}
#endif
#endif
{
auto status = gemm_op.initialize(arguments, workspace, stream);
CUTLASS_CHECK(status);
}
{
auto status = gemm_op(stream);
CUTLASS_CHECK(status);
}
}
catch (std::exception& e) {
std::cerr << "Runtime error: " << e.what() << std::endl;
return -1;
}
catch (...) {
return -1;
}
return 0;
}
}
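// Usage sketch for the calling convention described above: call once with a non-null
// workspace_size pointer to query the required workspace, then call again with
// workspace_size == nullptr and a workspace buffer to actually run the GEMM.
// The macro guard and function name below are hypothetical additions and not part of the repro.
#ifdef REPRO_USAGE_SKETCH
static int example_two_phase_call(const half* Bias, const half* X, const half* W, half* Y,
                                  cudaStream_t stream) {
  size_t workspace_size = 0;
  // First call: only populates workspace_size, no kernel is launched.
  cuda_cutlass_gemm_1(Bias, X, W, Y, &workspace_size, nullptr, stream);
  // Allocate the workspace and launch; passing nullptr for workspace_size selects the run path.
  cutlass::DeviceAllocation<uint8_t> workspace(workspace_size);
  return cuda_cutlass_gemm_1(Bias, X, W, Y, nullptr, workspace.get(), stream);
}
#endif  // REPRO_USAGE_SKETCH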
#ifdef GENERATE_STANDALONE_RUNNER
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed, float max=1.0, float min=-1.0) {
if (block.size()<=0) return false;
Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
extern "C" int run_standalone(uint64_t seed, int repetitions) {
std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
size_t workspace_size = 0;
size_t* workspace_size_ptr = &workspace_size;
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t; // may not be void
using ElementD = cutlass::half_t;
cutlass::DeviceAllocation<ElementA> X_data(262144);
initialize_block(X_data, seed++);
cutlass::DeviceAllocation<ElementB> W_data(28098560);
initialize_block(W_data, seed++);
cutlass::DeviceAllocation<ElementC> Bias_data(109760);
initialize_block(Bias_data, seed++);
cutlass::DeviceAllocation<ElementD> Y_data(112394240);
cutlass::DeviceAllocation<uint8_t> workspace_data;
// Call once with workspace_size_ptr set to get workspace size
std::cout << "Calling once to get workspace size" << std::endl;
cuda_cutlass_gemm_1(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;
// Allocate workspace if necessary
if (workspace_size > 0) {
workspace_data.reset(workspace_size);
std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
}
std::cout << "Calling Kernel as cuda_cutlass_gemm_1(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;" << std::endl;
workspace_size_ptr = nullptr;
for (int i=0; i<repetitions; i++) {
cuda_cutlass_gemm_1(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Device synchronize failed with error "
<< cudaGetErrorString(result) << std::endl;
return result;
}
return 0;
}
int main(int argc, char** argv) {
// warmup
run_standalone(1, 2);
// repeat
return run_standalone(2, 10);
}
#endif
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using bfloat16 = nv_bfloat16;
using namespace cute;
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \
cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
throw std::runtime_error(msg); \
} \
}
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
CUTLASS_HOST_DEVICE
T operator()(T val) const { return val; }
};
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
"Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = float;
using ElementD = cutlass::half_t;
using ElementC = cutlass::half_t;
using TileShapeMNK = cute::Shape<cute::_64, cute::_32, cute::_32>;
using ClusterShapeMNK = cute::Shape<cute::_1,cute::_1,cute::_1>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShapeMNK,
EpilogueTileType,
ElementC,
ElementD,
EpilogueScheduleType
>;
using ADDMM_EVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
ElementD, ElementAcc, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc,
ElementAcc, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
>>;
using EVT_expr_1 = ADDMM_EVT;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor = cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,EVT_expr_1>;
;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShapeMNK,
ClusterShapeMNK,
EpilogueTileType,
float, float,
cutlass::half_t, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::ColumnMajor, 8,
EpilogueScheduleType,
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor
>::CollectiveOp;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::half_t, cutlass::layout::ColumnMajor, 8,
float,
cute::Shape<cute::_64, cute::_32, cute::_32>,
cute::Shape<cute::_1,cute::_1,cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong
>::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_mainloop,
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_epilogue,
cutlass::gemm::PersistentScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma :
public cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma>;
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT int cuda_cutlass_gemm_0(const half* Bias, const half* X, const half* W, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
try {
int64_t B = 1;
int64_t M = 1024L;
int64_t K = 5952L;
int64_t N = 1024L;
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_device_type::ElementAccumulator;
using coord_t = cutlass::gemm::GemmCoord::Index;
static cutlass::KernelHardwareInfo hw_info;
if (hw_info.sm_count == 0) {
// @TODO kadeng: Add support for Multi-GPU machines with heterogeneous SM counts
// for now we just pick the SM count of the first GPU
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
}
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_device_type::Arguments arguments;
// Initialize GemmUniversal3xInstance arguments.
arguments = {
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
{
static_cast<coord_t>(N),
static_cast<coord_t>(M),
static_cast<coord_t>(K),
static_cast<coord_t>(B)
}, // ProblemShape problem_shape
{
(cutlass::half_t*)(W), // ElementA const* ptr_A
{
5952L /* stride_x0 */,
cute::Int<1>{} /* stride_x1 */,
0 /* batch_stride_x */
}, // StrideA dA
(cutlass::half_t*)(X), // ElementB const* ptr_B
{
5952L /* stride_w1 */,
cute::Int<1>{} /* stride_w0 */,
0 /* batch_stride_w */
}, // StrideB dB
}, // MainloopArguments mainloop
// see https://tinyurl.com/4rk89z48
{
{
{ // ADDMM Arguments: ternary op : beta * C + (alpha * acc)
{{static_cast<ElementAcc>(1.000000)}}, // leaf op+args : beta
{}, // leaf op+args : C
{ // binary op : alpha * acc
{{static_cast<ElementAcc>(1.000000)}}, // leaf op+args : alpha
{}, // leaf op+args : acc
{} // binary args : multiplies
}, // end binary op
{} // ternary args : multiply_add
} // end ternary op
}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
(cutlass::half_t*)(Bias), // ElementC const* ptr_C
{
cute::Int<1>{} /* stride_bias0 */,
0L /* stride_bias1 */,
0 /* batch_stride_bias */
}, // StrideC dC
(cutlass::half_t*)(Y), // ElementD const* ptr_D
{
cute::Int<1>{} /* stride_y0 */,
1024L /* stride_y1 */,
0 /* batch_stride_y */
}, // StrideD dD
}, // EpilogueArguments epilogue,
hw_info
};
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_1x1x1_0_tnn_align8_warpspecialized_pingpong_epi_tma_device_type gemm_op;
if (workspace_size) {
*workspace_size = gemm_op.get_workspace_size(arguments);
return 0;
}
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
{
if (!X) {
int64_t X_size = 6094848L;
if (X_size > 0) {
throw std::runtime_error("input X is null but size is not 0!");
}
}
}
{
if (!W) {
int64_t W_size = 6094848L;
if (W_size > 0) {
throw std::runtime_error("input W is null but size is not 0!");
}
}
}
{
if (!Bias) {
int64_t Bias_size = 1048576L;
if (Bias_size > 0) {
throw std::runtime_error("input Bias is null but size is not 0!");
}
}
}
{
if (!Y) {
int64_t Y_size = 1048576L;
if (Y_size > 0) {
throw std::runtime_error("input Y is null but size is not 0!");
}
}
}
{
auto status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
}
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
{
// Query the maximum number of active blocks per SM for this kernel. With
// CUTLASS_DEBUG_TRACE_LEVEL == 1 the result is printed inside the function itself,
// so no explicit print statement is needed here.
gemm_op.maximum_active_blocks();
}
#endif
#endif
{
auto status = gemm_op.initialize(arguments, workspace, stream);
CUTLASS_CHECK(status);
}
{
auto status = gemm_op(stream);
CUTLASS_CHECK(status);
}
}
catch (std::exception& e) {
std::cerr << "Runtime error: " << e.what() << std::endl;
return -1;
}
catch (...) {
return -1;
}
return 0;
}
}
#ifdef GENERATE_STANDALONE_RUNNER
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed, float max=1.0, float min=-1.0) {
if (block.size()<=0) return false;
Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
extern "C" int run_standalone(uint64_t seed, int repetitions) {
std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
size_t workspace_size = 0;
size_t* workspace_size_ptr = &workspace_size;
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t; // may not be void
using ElementD = cutlass::half_t;
cutlass::DeviceAllocation<ElementA> X_data(6094848);
initialize_block(X_data, seed++);
cutlass::DeviceAllocation<ElementB> W_data(6094848);
initialize_block(W_data, seed++);
cutlass::DeviceAllocation<ElementC> Bias_data(1024);
initialize_block(Bias_data, seed++);
cutlass::DeviceAllocation<ElementD> Y_data(1048576);
cutlass::DeviceAllocation<uint8_t> workspace_data;
// Call once with workspace_size_ptr set to get workspace size
std::cout << "Calling once to get workspace size" << std::endl;
cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;
// Allocate workspace if necessary
if (workspace_size > 0) {
workspace_data.reset(workspace_size);
std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
}
std::cout << "Calling Kernel as cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;" << std::endl;
workspace_size_ptr = nullptr;
for (int i=0; i<repetitions; i++) {
cuda_cutlass_gemm_0(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Device synchronize failed with error "
<< cudaGetErrorString(result) << std::endl;
return result;
}
return 0;
}
int main(int argc, char** argv) {
// warmup
run_standalone(1, 2);
// repeat
return run_standalone(2, 10);
}
#endif
#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif
using bfloat16 = nv_bfloat16;
using namespace cute;
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \
cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \
throw std::runtime_error(msg); \
} \
}
// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
CUTLASS_HOST_DEVICE
T operator()(T val) const { return val; }
};
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
static_assert(cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecialized> ||
cute::is_same_v<EpilogueScheduleType, cutlass::epilogue::TmaWarpSpecializedCooperative>,
"Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue");
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = float;
using ElementD = cutlass::half_t;
using ElementC = cutlass::half_t;
using TileShapeMNK = cute::Shape<cute::_64, cute::_32, cute::_32>;
using ClusterShapeMNK = cute::Shape<cute::_2,cute::_1,cute::_1>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShapeMNK,
EpilogueTileType,
ElementC,
ElementD,
EpilogueScheduleType
>;
using ADDMM_EVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
ElementD, ElementAcc, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc,
ElementAcc, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
>>;
using EVT_expr_1 = ADDMM_EVT;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor = cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,EVT_expr_1>;
;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShapeMNK,
ClusterShapeMNK,
EpilogueTileType,
float, float,
cutlass::half_t, cutlass::layout::ColumnMajor, 8,
cutlass::half_t, cutlass::layout::ColumnMajor, 8,
EpilogueScheduleType,
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue_functor
>::CollectiveOp;
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_mainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
cutlass::half_t, cutlass::layout::RowMajor, 8,
cutlass::half_t, cutlass::layout::RowMajor, 8,
float,
cute::Shape<cute::_64, cute::_32, cute::_32>,
cute::Shape<cute::_2,cute::_1,cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue::SharedStorage)>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong
>::CollectiveOp;
// Gemm operator cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_base = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_mainloop,
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_epilogue,
cutlass::gemm::PersistentScheduler>;
// Define named type
struct cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma :
public cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_base { };
using cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma>;
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT int cuda_cutlass_gemm_1(const half* Bias, const half* X, const half* W, half* Y, size_t* workspace_size, uint8_t* workspace, cudaStream_t stream) {
try {
int64_t B = 1;
int64_t M = 1024L;
int64_t K = 256L;
int64_t N = 109760L;
using ElementComputeEpilogue = cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type::ElementAccumulator;
using coord_t = cutlass::gemm::GemmCoord::Index;
static cutlass::KernelHardwareInfo hw_info;
if (hw_info.sm_count == 0) {
// @TODO kadeng: Add support for Multi-GPU machines with heterogeneous SM counts
// for now we just pick the SM count of the first GPU
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
}
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type::Arguments arguments;
// Initialize GemmUniversal3xInstance arguments.
arguments = {
cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
{
static_cast<coord_t>(N),
static_cast<coord_t>(M),
static_cast<coord_t>(K),
static_cast<coord_t>(B)
}, // ProblemShape problem_shape
{
(cutlass::half_t*)(W), // ElementA const* ptr_A
{
256L /* stride_x0 */,
cute::Int<1>{} /* stride_x1 */,
0 /* batch_stride_x */
}, // StrideA dA
(cutlass::half_t*)(X), // ElementB const* ptr_B
{
cute::Int<1>{} /* stride_w1 */,
1024L /* stride_w0 */,
0 /* batch_stride_w */
}, // StrideB dB
}, // MainloopArguments mainloop
// see https://tinyurl.com/4rk89z48
{
{
{ // ADDMM Arguments: ternary op : beta * C + (alpha * acc)
{{static_cast<ElementAcc>(1.000000)}}, // leaf op+args : beta
{}, // leaf op+args : C
{ // binary op : alpha * acc
{{static_cast<ElementAcc>(1.000000)}}, // leaf op+args : alpha
{}, // leaf op+args : acc
{} // binary args : multiplies
}, // end binary op
{} // ternary args : multiply_add
} // end ternary op
}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
(cutlass::half_t*)(Bias), // ElementC const* ptr_C
{
cute::Int<1>{} /* stride_bias0 */,
0L /* stride_bias1 */,
0 /* batch_stride_bias */
}, // StrideC dC
(cutlass::half_t*)(Y), // ElementD const* ptr_D
{
cute::Int<1>{} /* stride_y0 */,
109760L /* stride_y1 */,
0 /* batch_stride_y */
}, // StrideD dD
}, // EpilogueArguments epilogue,
hw_info
};
cutlass3x_sm90_tensorop_s64x32x16gemm_f16_f16_f32_f16_f16_64x32x32_2x1x1_0_ttn_align8_warpspecialized_pingpong_epi_tma_device_type gemm_op;
if (workspace_size) {
*workspace_size = gemm_op.get_workspace_size(arguments);
return 0;
}
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
{
if (!X) {
int64_t X_size = 262144L;
if (X_size > 0) {
throw std::runtime_error("input X is null but size is not 0!");
}
}
}
{
if (!W) {
int64_t W_size = 28098560L;
if (W_size > 0) {
throw std::runtime_error("input W is null but size is not 0!");
}
}
}
{
if (!Bias) {
int64_t Bias_size = 112394240L;
if (Bias_size > 0) {
throw std::runtime_error("input Bias is null but size is not 0!");
}
}
}
{
if (!Y) {
int64_t Y_size = 112394240L;
if (Y_size > 0) {
throw std::runtime_error("input Y is null but size is not 0!");
}
}
}
{
auto status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
}
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
{
// Query the maximum number of active blocks per SM for this kernel. With
// CUTLASS_DEBUG_TRACE_LEVEL == 1 the result is printed inside the function itself,
// so no explicit print statement is needed here.
gemm_op.maximum_active_blocks();
}
#endif
#endif
{
auto status = gemm_op.initialize(arguments, workspace, stream);
CUTLASS_CHECK(status);
}
{
auto status = gemm_op(stream);
CUTLASS_CHECK(status);
}
}
catch (std::exception& e) {
std::cerr << "Runtime error: " << e.what() << std::endl;
return -1;
}
catch (...) {
return -1;
}
return 0;
}
}
#ifdef GENERATE_STANDALONE_RUNNER
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed, float max=1.0, float min=-1.0) {
if (block.size()<=0) return false;
Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
extern "C" int run_standalone(uint64_t seed, int repetitions) {
std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
size_t workspace_size = 0;
size_t* workspace_size_ptr = &workspace_size;
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t; // may not be void
using ElementD = cutlass::half_t;
cutlass::DeviceAllocation<ElementA> X_data(262144);
initialize_block(X_data, seed++);
cutlass::DeviceAllocation<ElementB> W_data(28098560);
initialize_block(W_data, seed++);
cutlass::DeviceAllocation<ElementC> Bias_data(109760);
initialize_block(Bias_data, seed++);
cutlass::DeviceAllocation<ElementD> Y_data(112394240);
cutlass::DeviceAllocation<uint8_t> workspace_data;
// Call once with workspace_size_ptr set to get workspace size
std::cout << "Calling once to get workspace size" << std::endl;
cuda_cutlass_gemm_1(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;
// Allocate workspace if necessary
if (workspace_size > 0) {
workspace_data.reset(workspace_size);
std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
}
std::cout << "Calling Kernel as cuda_cutlass_gemm_1(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;" << std::endl;
workspace_size_ptr = nullptr;
for (int i=0; i<repetitions; i++) {
cuda_cutlass_gemm_1(((const half*)X_data.get()), ((const half*)W_data.get()), ((const half*)Bias_data.get()), ((half*)Y_data.get()), workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);;
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Device synchronize failed with error "
<< cudaGetErrorString(result) << std::endl;
return result;
}
return 0;
}
int main(int argc, char** argv) {
// warmup
run_standalone(1, 2);
// repeat
return run_standalone(2, 10);
}
#endif
#!/bin/bash
# Change the environment variables below to point to your CUTLASS and CUDA Toolkit installations,
# then run this script, passing one of the standalone repro_N.cu files as an argument.
# It creates a debug build of that example and runs it under compute-sanitizer.
set -x
export REPRO_CUTLASS_PATH=/home/klondenberg/github/pytorch/pytorch/third_party/cutlass
export REPRO_CUDA_PATH=/home/klondenberg/local/cuda121
$REPRO_CUDA_PATH/bin/nvcc -g -G -t=0 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -w -gencode=arch=compute_90a,code=[sm_90a,compute_90a] -O1 -std=c++17 --expt-relaxed-constexpr -Xcompiler=-fPIC --use_fast_math -Xcompiler=-fno-strict-aliasing -Xcompiler -fvisibility=hidden -Xcompiler=-Wconversion -I${REPRO_CUTLASS_PATH}/include -I${REPRO_CUTLASS_PATH}/tools/library/include -I${REPRO_CUTLASS_PATH}/tools/library/src -I${REPRO_CUTLASS_PATH}/tools/util/include -L${REPRO_CUDA_PATH}/lib64 -L${REPRO_CUDA_PATH}/lib64/stubs -lcuda -lcudart -DGENERATE_STANDALONE_RUNNER -DNDEBUG -DCUTLASS_DEBUG_TRACE_LEVEL=1 -o "${@}.debug.exe" "$@"
"${REPRO_CUDA_PATH}/bin/compute-sanitizer" "./${@}.debug.exe"