Skip to content

Instantly share code, notes, and snippets.

@zhuhaozhe
Created February 6, 2024 05:09
Show Gist options
  • Save zhuhaozhe/83481575093357ca603cfdd66d3fd37f to your computer and use it in GitHub Desktop.
Save zhuhaozhe/83481575093357ca603cfdd66d3fd37f to your computer and use it in GitHub Desktop.
# Compiled C++ kernel registered through async_compile.cpp_pybinding with the
# C argument types (const float* in_ptr0, const float* in_ptr1, float* out_ptr0).
# The include path (/tmp/torchinductor_root/...) suggests this is TorchInductor
# generated code.
#
# What the kernel source below computes (single scalar written to out_ptr0[0]):
#   out_ptr0[0] = sum over x0 in [0, 50), x1 in [0, 50) of
#                 in_ptr0[x1 + 50*x0] * in_ptr1[x0 + 50*x1]
# If both inputs are row-major 50x50 matrices A and B (assumption about the
# caller's layout -- TODO confirm), this is mathematically sum(A * B^T), i.e.
# trace(A @ B); floating-point summation order differs from a naive loop.
#
# Structure of the kernel:
#   * x0 in [0, 48) is handled 16 lanes at a time. For x1 in [0, 48) a 16x16
#     tile of in_ptr0 is transposed into a stack buffer (transpose_mxn) so each
#     row can be multiplied against a contiguous vector load from in_ptr1; for
#     the x1 tail [48, 50) the 16 strided in_ptr0 elements are gathered into a
#     small buffer instead.
#   * x0 in [48, 50) is handled with scalar code over every x1.
#   * The reduction runs in an OpenMP parallel region requesting 64 threads.
#     Each thread accumulates into private scalar/vector locals, stores them
#     in the slot tmp_acc0_arr / tmp_acc0_vec_arr[omp_get_thread_num()], and
#     the 64 zero-initialised slots are summed serially afterwards, so slots
#     of threads that were not created (or got no iterations) contribute 0.
#
# NOTE: the triple-quoted string is compiled verbatim -- any edit to it,
# including adding C++ comments, changes the generated kernel.
cpp_fused_mul_sum_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'float*'], '''
#include "/tmp/torchinductor_root/lg/clghje745biezhrbrw5fghxqjaj76ck5jms7466s4ax63eruswf5.h"
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
float* out_ptr0)
{
{
{
float tmp_acc0 = 0;
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
float tmp_acc0_arr[64];
for (int tid = 0; tid < 64; tid++)
{
tmp_acc0_arr[tid] = 0;
}
at::vec::Vectorized<float> tmp_acc0_vec_arr[64];
for (int tid = 0; tid < 64; tid++)
{
tmp_acc0_vec_arr[tid] = at::vec::Vectorized<float>(0);
}
#pragma omp parallel num_threads(64)
{
int tid = omp_get_thread_num();
float tmp_acc0_local = 0;
at::vec::Vectorized<float> tmp_acc0_vec_local = at::vec::Vectorized<float>(0);
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(48L); x0+=static_cast<long>(16L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(48L); x1+=static_cast<long>(16L))
{
float tmp0[16*16] __attribute__ ((aligned (16)));
at::vec::transpose_mxn<float,16,16>(in_ptr0 + static_cast<long>(x1 + (50L*x0)), static_cast<long>(50L), tmp0, 16);
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + static_cast<long>(16L*x1_inner));
auto tmp2 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x0 + (50L*x1) + (50L*x1_inner)));
auto tmp3 = tmp1 * tmp2;
tmp_acc0_vec_local = tmp_acc0_vec_local + tmp3;
}
}
for(long x1=static_cast<long>(48L); x1<static_cast<long>(50L); x1+=static_cast<long>(1L))
{
auto tmp0 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x0_inner = 0; x0_inner < 16; x0_inner++)
{
tmpbuf[x0_inner] = in_ptr0[static_cast<long>(x1 + (50L*x0) + (50L*x0_inner))];
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data());
}
()
;
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x0 + (50L*x1)));
auto tmp2 = tmp0 * tmp1;
tmp_acc0_vec_local = tmp_acc0_vec_local + tmp2;
}
}
#pragma omp for
for(long x0=static_cast<long>(48L); x0<static_cast<long>(50L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(50L); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x1 + (50L*x0))];
auto tmp1 = in_ptr1[static_cast<long>(x0 + (50L*x1))];
auto tmp2 = decltype(tmp0)(tmp0 * tmp1);
tmp_acc0_local = tmp_acc0_local + tmp2;
}
}
tmp_acc0_arr[tid] = tmp_acc0_local;
tmp_acc0_vec_arr[tid] = tmp_acc0_vec_local;
}
for (int tid = 0; tid < 64; tid++)
{
tmp_acc0 = tmp_acc0 + tmp_acc0_arr[tid];
}
for (int tid = 0; tid < 64; tid++)
{
tmp_acc0_vec = tmp_acc0_vec + tmp_acc0_vec_arr[tid];
}
tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
out_ptr0[static_cast<long>(0L)] = static_cast<float>(tmp_acc0);
}
}
}
''')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment