Created
February 6, 2024 05:09
-
-
Save zhuhaozhe/83481575093357ca603cfdd66d3fd37f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cpp_fused_mul_sum_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'float*'], ''' | |
#include "/tmp/torchinductor_root/lg/clghje745biezhrbrw5fghxqjaj76ck5jms7466s4ax63eruswf5.h" | |
// Generated fused multiply + sum-reduction kernel.
//
// Writes to out_ptr0[0] the sum, over x0 and x1 in [0, 50), of
//     in_ptr0[x1 + 50*x0] * in_ptr1[x0 + 50*x1].
//
// in_ptr0, in_ptr1: read-only float inputs; offsets 0..2499 are accessed.
// out_ptr0:         receives the single float result at index 0.
//
// The x0 range is split in two: x0 in [0, 48) is processed 16 values at a time
// in vector lanes, and x0 in {48, 49} is processed with scalar code.
// NOTE(review): the vector path treats one at::vec::Vectorized<float> as 16
// consecutive floats - confirm the lane count against the included header.
extern "C" void kernel(const float* in_ptr0, | |
const float* in_ptr1, | |
float* out_ptr0) | |
{ | |
{ | |
{ | |
// Final accumulators: a scalar sum and a vector of lane-wise sums.
float tmp_acc0 = 0; | |
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0); | |
// One scalar partial-sum slot per thread id (64 slots, matching num_threads(64)
// below). All slots are zeroed first, so a slot that no thread writes adds 0.
float tmp_acc0_arr[64]; | |
for (int tid = 0; tid < 64; tid++) | |
{ | |
tmp_acc0_arr[tid] = 0; | |
} | |
// Same per-thread layout for the vector partial sums.
at::vec::Vectorized<float> tmp_acc0_vec_arr[64]; | |
for (int tid = 0; tid < 64; tid++) | |
{ | |
tmp_acc0_vec_arr[tid] = at::vec::Vectorized<float>(0); | |
} | |
#pragma omp parallel num_threads(64) | |
{ | |
// Each thread accumulates into its own locals and publishes them once, at the
// end of the parallel region, into the slot selected by its thread id.
int tid = omp_get_thread_num(); | |
float tmp_acc0_local = 0; | |
at::vec::Vectorized<float> tmp_acc0_vec_local = at::vec::Vectorized<float>(0); | |
// Vector path: the x0 loop has three iterations (0, 16, 32), shared among threads.
#pragma omp for | |
for(long x0=static_cast<long>(0L); x0<static_cast<long>(48L); x0+=static_cast<long>(16L)) | |
{ | |
// Full 16-wide tiles along x1: x1 = 0, 16, 32.
for(long x1=static_cast<long>(0L); x1<static_cast<long>(48L); x1+=static_cast<long>(16L)) | |
{ | |
// Scratch buffer for one 16x16 tile of in_ptr0.
float tmp0[16*16] __attribute__ ((aligned (16))); | |
// Transpose the tile that starts at in_ptr0[x1 + 50*x0] into tmp0.
// NOTE(review): 50 and 16 are presumably the source and destination row
// strides of transpose_mxn - confirm against the at::vec header.
at::vec::transpose_mxn<float,16,16>(in_ptr0 + static_cast<long>(x1 + (50L*x0)), static_cast<long>(50L), tmp0, 16); | |
for (long x1_inner = 0; x1_inner < 16; x1_inner++) | |
{ | |
// tmp1: 16 floats from row x1_inner of the transposed tile.
auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + static_cast<long>(16L*x1_inner)); | |
// tmp2: contiguous floats of in_ptr1 starting at x0 + 50*(x1 + x1_inner).
auto tmp2 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x0 + (50L*x1) + (50L*x1_inner))); | |
auto tmp3 = tmp1 * tmp2; | |
tmp_acc0_vec_local = tmp_acc0_vec_local + tmp3; | |
} | |
} | |
// x1 tail: x1 = 48 and 49 do not fill a 16-wide tile, so they are handled one
// x1 value at a time.
for(long x1=static_cast<long>(48L); x1<static_cast<long>(50L); x1+=static_cast<long>(1L)) | |
{ | |
// Immediately-invoked lambda: gathers the 16 strided values
// in_ptr0[x1 + 50*(x0 + x0_inner)] into a local buffer and loads them as one vector.
auto tmp0 = | |
[&] | |
{ | |
__at_align__ std::array<float, 16> tmpbuf; | |
#pragma GCC unroll 16 | |
for (long x0_inner = 0; x0_inner < 16; x0_inner++) | |
{ | |
tmpbuf[x0_inner] = in_ptr0[static_cast<long>(x1 + (50L*x0) + (50L*x0_inner))]; | |
} | |
return at::vec::Vectorized<float>::loadu(tmpbuf.data()); | |
} | |
() | |
; | |
// Matching in_ptr1 values are contiguous, starting at x0 + 50*x1.
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x0 + (50L*x1))); | |
auto tmp2 = tmp0 * tmp1; | |
tmp_acc0_vec_local = tmp_acc0_vec_local + tmp2; | |
} | |
} | |
// Scalar path: x0 = 48 and 49, over every x1 in [0, 50).
#pragma omp for | |
for(long x0=static_cast<long>(48L); x0<static_cast<long>(50L); x0+=static_cast<long>(1L)) | |
{ | |
for(long x1=static_cast<long>(0L); x1<static_cast<long>(50L); x1+=static_cast<long>(1L)) | |
{ | |
auto tmp0 = in_ptr0[static_cast<long>(x1 + (50L*x0))]; | |
auto tmp1 = in_ptr1[static_cast<long>(x0 + (50L*x1))]; | |
// decltype(tmp0) is float here, so the cast keeps the product as float.
auto tmp2 = decltype(tmp0)(tmp0 * tmp1); | |
tmp_acc0_local = tmp_acc0_local + tmp2; | |
} | |
} | |
// Publish this thread's partial sums into its own slots.
tmp_acc0_arr[tid] = tmp_acc0_local; | |
tmp_acc0_vec_arr[tid] = tmp_acc0_vec_local; | |
} | |
// After the parallel region: combine the per-thread partials sequentially, in
// thread-id order.
for (int tid = 0; tid < 64; tid++) | |
{ | |
tmp_acc0 = tmp_acc0 + tmp_acc0_arr[tid]; | |
} | |
for (int tid = 0; tid < 64; tid++) | |
{ | |
tmp_acc0_vec = tmp_acc0_vec + tmp_acc0_vec_arr[tid]; | |
} | |
// Reduce the vector accumulator with addition and fold it into the scalar total.
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec); | |
// Store the final result.
out_ptr0[static_cast<long>(0L)] = static_cast<float>(tmp_acc0); | |
} | |
} | |
} | |
''') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment