Created
February 29, 2020 02:22
-
-
Save Laurawly/2ca295db566bb59966cae73fa0b051c2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <mma.h> | |
// Int4 (s4) tensor-core GEMM-style kernel using 8x8x32 WMMA fragments.
// A and B are packed int4 data addressed through int* (each 32-bit int packs
// eight 4-bit values, hence the "/ 8" on every element index); `compute`
// receives 32-bit int accumulator results.
// requires SM75+ (sub-byte experimental WMMA precision::s4).
// NOTE(review): indexing presumably assumes blockDim.x == 32 with small
// threadIdx.y/threadIdx.z extents — confirm against the host launch code.
extern "C" __global__ void default_function_kernel0( int* __restrict__ A, int* __restrict__ B, int* __restrict__ compute) {
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 8, 8, 32, int> compute_wmma_accumulator[2];
  __shared__ int A_shared[512];
  __shared__ int B_shared[512];
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 8, 8, 32, nvcuda::wmma::experimental::precision::s4, nvcuda::wmma::row_major> A_shared_wmma_matrix_a[2];
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 8, 8, 32, nvcuda::wmma::experimental::precision::s4, nvcuda::wmma::col_major> B_shared_wmma_matrix_b[1];
  // BUG FIX: both accumulator fragments are consumed below (the mma loop runs
  // i_c < 2 and the store loop runs i_inner < 2), but the original fill loop
  // only zeroed fragment [0], leaving [1] uninitialized. Zero both.
  // Also use the integer literal 0 to match the int accumulator type.
  for (int i_c_init = 0; i_c_init < 2; ++i_c_init) {
    (void)nvcuda::wmma::fill_fragment(compute_wmma_accumulator[i_c_init], 0);
  }
  for (int k1_outer = 0; k1_outer < 2; ++k1_outer) {
    // Barrier before overwriting the shared tiles: the previous iteration's
    // load_matrix_sync reads must have completed in every thread.
    __syncthreads();
    // Stage a tile of A into shared memory (packed int4, so 8 nibbles per int).
    for (int ax0_inner = 0; ax0_inner < 2; ++ax0_inner) {
      for (int ax2_ax3_fused_inner = 0; ax2_ax3_fused_inner < 8; ++ax2_ax3_fused_inner) {
        A_shared[((((((((int)threadIdx.y) * 2048) + (ax0_inner * 1024)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + ax2_ax3_fused_inner)) / 8] = A[((((((((((int)blockIdx.x) * 8192) + (((int)threadIdx.y) * 4096)) + (ax0_inner * 2048)) + (k1_outer * 1024)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + ax2_ax3_fused_inner)) / 8];
      }
    }
    // Stage the corresponding tile of B (col-major fragment source).
    for (int ax0_inner1 = 0; ax0_inner1 < 2; ++ax0_inner1) {
      for (int ax2_ax3_fused_inner1 = 0; ax2_ax3_fused_inner1 < 8; ++ax2_ax3_fused_inner1) {
        B_shared[((((((((int)threadIdx.y) * 2048) + (ax0_inner1 * 1024)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + ax2_ax3_fused_inner1)) / 8] = B[((((((((((int)blockIdx.y) * 8192) + (((int)threadIdx.y) * 4096)) + (ax0_inner1 * 2048)) + (k1_outer * 1024)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + ax2_ax3_fused_inner1)) / 8];
      }
    }
    // Barrier between the shared-memory writes above and the fragment loads below.
    __syncthreads();
    for (int k1_inner = 0; k1_inner < 4; ++k1_inner) {
      // Load two row-major A fragments and one col-major B fragment per k-step
      // (leading dimension 32 = int4 elements per shared-memory row).
      for (int ax0 = 0; ax0 < 2; ++ax0) {
        (void)nvcuda::wmma::load_matrix_sync(A_shared_wmma_matrix_a[ax0], ((int *)A_shared + (((((int)threadIdx.y) * 2048) + (ax0 * 1024)) + (k1_inner * 256))/ 8), 32);
      }
      (void)nvcuda::wmma::load_matrix_sync(B_shared_wmma_matrix_b[0], ((int *)B_shared + ((((int)threadIdx.z) * 1024) + (k1_inner * 256)) / 8), 32);
      // Accumulate: acc[i] += A_frag[i] * B_frag for both output sub-tiles.
      for (int i_c = 0; i_c < 2; ++i_c) {
        (void)nvcuda::wmma::mma_sync(compute_wmma_accumulator[i_c], A_shared_wmma_matrix_a[i_c], B_shared_wmma_matrix_b[0], compute_wmma_accumulator[i_c]);
      }
    }
  }
  // Write both accumulator tiles to global memory (row-major, ldm 8).
  for (int i_inner = 0; i_inner < 2; ++i_inner) {
    (void)nvcuda::wmma::store_matrix_sync(((int *)compute + (((((((int)blockIdx.x) * 2048) + (((int)threadIdx.y) * 1024)) + (i_inner * 512)) + (((int)blockIdx.y) * 256)) + (((int)threadIdx.z) * 64))), compute_wmma_accumulator[i_inner], 8, nvcuda::wmma::mem_row_major);
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment