Skip to content

Instantly share code, notes, and snippets.

@Laurawly
Created February 29, 2020 02:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Laurawly/2ca295db566bb59966cae73fa0b051c2 to your computer and use it in GitHub Desktop.
#include <mma.h>
// TVM-generated GEMM kernel using tensor-core WMMA with 8x8x32 fragments over
// 4-bit signed operands (nvcuda::wmma::experimental::precision::s4) and int32
// accumulation. Sub-byte WMMA is an experimental feature; requires SM75+ —
// TODO confirm against the target this schedule was generated for.
//
// A and B are packed: eight s4 values per int32 word, which is why every
// global/shared index expression is divided by 8. `compute` receives plain
// int32 results.
//
// Expected launch shape (inferred from the index arithmetic — NOTE(review):
// verify against the TVM host-side launch code): threadIdx.x in [0,32),
// threadIdx.y in [0,2), threadIdx.z in [0,2); grid indexed by blockIdx.x
// (rows of A / compute) and blockIdx.y (rows of B / columns of compute).
// No bounds checks are emitted: TVM guarantees the launch config exactly
// tiles the problem size.
extern "C" __global__ void default_function_kernel0( int* __restrict__ A, int* __restrict__ B, int* __restrict__ compute) {
  // Two accumulator fragments per thread group (one per i-tile).
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 8, 8, 32, int> compute_wmma_accumulator[2];
  // Staging tiles in shared memory; 512 int32 words = 4096 packed s4 values each.
  __shared__ int A_shared[512];
  __shared__ int B_shared[512];
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 8, 8, 32, nvcuda::wmma::experimental::precision::s4, nvcuda::wmma::row_major> A_shared_wmma_matrix_a[2];
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 8, 8, 32, nvcuda::wmma::experimental::precision::s4, nvcuda::wmma::col_major> B_shared_wmma_matrix_b[1];
  // BUG FIX: the original loop bound was 1, leaving compute_wmma_accumulator[1]
  // uninitialized even though the mma_sync loop (i_c < 2) and the store loop
  // (i_inner < 2) below both use it. Initialize BOTH fragments.
  for (int i_c_init = 0; i_c_init < 2; ++i_c_init) {
    // Integer accumulator: fill with the integer literal 0 (the original
    // passed the float literal 0.000000e+00f, which only works via implicit
    // conversion).
    (void)nvcuda::wmma::fill_fragment(compute_wmma_accumulator[i_c_init], 0);
  }
  for (int k1_outer = 0; k1_outer < 2; ++k1_outer) {
    // Barrier before overwriting the shared tiles consumed by the previous
    // k1_outer iteration's load_matrix_sync calls.
    __syncthreads();
    // Cooperative copy of this k-slice of A into shared memory (packed words).
    for (int ax0_inner = 0; ax0_inner < 2; ++ax0_inner) {
      for (int ax2_ax3_fused_inner = 0; ax2_ax3_fused_inner < 8; ++ax2_ax3_fused_inner) {
        A_shared[((((((((int)threadIdx.y) * 2048) + (ax0_inner * 1024)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + ax2_ax3_fused_inner)) / 8] = A[((((((((((int)blockIdx.x) * 8192) + (((int)threadIdx.y) * 4096)) + (ax0_inner * 2048)) + (k1_outer * 1024)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + ax2_ax3_fused_inner)) / 8];
      }
    }
    // Cooperative copy of this k-slice of B into shared memory (packed words).
    for (int ax0_inner1 = 0; ax0_inner1 < 2; ++ax0_inner1) {
      for (int ax2_ax3_fused_inner1 = 0; ax2_ax3_fused_inner1 < 8; ++ax2_ax3_fused_inner1) {
        B_shared[((((((((int)threadIdx.y) * 2048) + (ax0_inner1 * 1024)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + ax2_ax3_fused_inner1)) / 8] = B[((((((((((int)blockIdx.y) * 8192) + (((int)threadIdx.y) * 4096)) + (ax0_inner1 * 2048)) + (k1_outer * 1024)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + ax2_ax3_fused_inner1)) / 8];
      }
    }
    // Barrier between the shared-memory writes above and the fragment loads below.
    __syncthreads();
    for (int k1_inner = 0; k1_inner < 4; ++k1_inner) {
      // Load the two A-fragments for this inner k-step (leading dim 32 s4 elems).
      for (int ax0 = 0; ax0 < 2; ++ax0) {
        (void)nvcuda::wmma::load_matrix_sync(A_shared_wmma_matrix_a[ax0], ((int *)A_shared + (((((int)threadIdx.y) * 2048) + (ax0 * 1024)) + (k1_inner * 256))/ 8), 32);
      }
      // Single B-fragment shared by both accumulators.
      (void)nvcuda::wmma::load_matrix_sync(B_shared_wmma_matrix_b[0], ((int *)B_shared + ((((int)threadIdx.z) * 1024) + (k1_inner * 256)) / 8), 32);
      // Accumulate: acc[i] += A[i] * B for both i-tiles.
      for (int i_c = 0; i_c < 2; ++i_c) {
        (void)nvcuda::wmma::mma_sync(compute_wmma_accumulator[i_c], A_shared_wmma_matrix_a[i_c], B_shared_wmma_matrix_b[0], compute_wmma_accumulator[i_c]);
      }
    }
  }
  // Write both 8x8 int32 result tiles to global memory, row-major, ldm = 8.
  for (int i_inner = 0; i_inner < 2; ++i_inner) {
    (void)nvcuda::wmma::store_matrix_sync(((int *)compute + (((((((int)blockIdx.x) * 2048) + (((int)threadIdx.y) * 1024)) + (i_inner * 512)) + (((int)blockIdx.y) * 256)) + (((int)threadIdx.z) * 64))), compute_wmma_accumulator[i_inner], 8, nvcuda::wmma::mem_row_major);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment