Skip to content

Instantly share code, notes, and snippets.

@ShigekiKarita
Forked from belltailjp/simd.cpp
Last active May 22, 2018 06:59
Show Gist options
  • Save ShigekiKarita/15092e7f2c3f96ba007a336dd11f36b3 to your computer and use it in GitHub Desktop.
Save ShigekiKarita/15092e7f2c3f96ba007a336dd11f36b3 to your computer and use it in GitHub Desktop.
SSE,AVX組み込み関数を用いたベクトルの内積計算高速化の実験コード
// origin https://gist.githubusercontent.com/belltailjp/4653695/raw/1cf8b5cbb6c3b4d4f9374b8b1ccae702867543ef/simd.cpp
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
#include <xmmintrin.h>
#include <immintrin.h>
// #include <boost/format.hpp>
// #include <osakana/stopwatch.hpp>
// Scalar reference dot product: accumulates vec1[k] * vec2[k] for k = 0..n-1
// in index order, starting from zero. Used as the ground truth for the SIMD
// kernels below.
template <typename T>
T dot_normal(const T *vec1, const T *vec2, unsigned n)
{
    T acc = 0;
    const T *const stop = vec1 + n;
    while (vec1 != stop)
    {
        acc += *vec1 * *vec2;
        ++vec1;
        ++vec2;
    }
    return acc;
}
// Dot product of two float arrays using SSE (4 lanes per instruction).
//
// Preconditions: vec1 and vec2 must be 16-byte aligned (_mm_load_ps is an
// aligned load). n may be any value: the first n - n % 4 elements are
// processed with SIMD, and the remaining 0..3 elements are added by a scalar
// tail loop, so no element past index n - 1 is ever read.
float dot_sse(const float *vec1, const float *vec2, unsigned n)
{
    __m128 u = _mm_setzero_ps();
    const unsigned n4 = n - n % 4; // largest multiple of 4 not exceeding n
    unsigned i = 0;
    for (; i < n4; i += 4)
    {
        __m128 w = _mm_load_ps(&vec1[i]);
        __m128 x = _mm_load_ps(&vec2[i]);
        u = _mm_add_ps(u, _mm_mul_ps(w, x));
    }
    // Horizontal reduction of the four partial sums.
    alignas(16) float t[4];
    _mm_store_ps(t, u);
    float sum = t[0] + t[1] + t[2] + t[3];
    // Scalar tail for the last n % 4 elements.
    for (; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
// Dot product of two float arrays using AVX (8 lanes per instruction).
//
// The target attribute lets this kernel compile even when the translation
// unit is not built with -mavx. The caller must still ensure the CPU
// supports AVX.
// Preconditions: vec1 and vec2 must be 32-byte aligned (_mm256_load_ps is an
// aligned load). n may be any value: the first n - n % 8 elements use SIMD
// and the remaining 0..7 are handled by a scalar tail, so nothing past index
// n - 1 is read.
__attribute__((target("avx")))
float dot_avx(const float *vec1, const float *vec2, unsigned n)
{
    __m256 u = _mm256_setzero_ps();
    const unsigned n8 = n - n % 8; // largest multiple of 8 not exceeding n
    unsigned i = 0;
    for (; i < n8; i += 8)
    {
        __m256 w = _mm256_load_ps(&vec1[i]);
        __m256 x = _mm256_load_ps(&vec2[i]);
        u = _mm256_add_ps(u, _mm256_mul_ps(w, x));
    }
    // Horizontal reduction of the eight partial sums.
    alignas(32) float t[8];
    _mm256_store_ps(t, u);
    float sum = t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];
    // Scalar tail for the last n % 8 elements.
    for (; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
// Dot product of two float arrays using AVX with two independent
// accumulators (16 floats per iteration), which hides add latency.
//
// The target attribute lets this kernel compile without a global -mavx. The
// caller must still ensure the CPU supports AVX.
// Preconditions: vec1 and vec2 must be 32-byte aligned (aligned loads). n may
// be any value. The main loop covers n - n % 16 elements. At most one
// leftover 8-wide chunk goes into the first accumulator. A scalar tail then
// handles the final 0..7 elements, so nothing past index n - 1 is read.
__attribute__((target("avx")))
float dot_avx_2(const float *vec1, const float *vec2, unsigned n)
{
    __m256 u1 = _mm256_setzero_ps();
    __m256 u2 = _mm256_setzero_ps();
    const unsigned n16 = n - n % 16; // largest multiple of 16 not exceeding n
    unsigned i = 0;
    for (; i < n16; i += 16)
    {
        __m256 w1 = _mm256_load_ps(&vec1[i]);
        __m256 w2 = _mm256_load_ps(&vec1[i + 8]);
        __m256 x1 = _mm256_load_ps(&vec2[i]);
        __m256 x2 = _mm256_load_ps(&vec2[i + 8]);
        u1 = _mm256_add_ps(u1, _mm256_mul_ps(w1, x1));
        u2 = _mm256_add_ps(u2, _mm256_mul_ps(w2, x2));
    }
    // One remaining full 8-lane chunk, if present.
    if (n - i >= 8)
    {
        __m256 w = _mm256_load_ps(&vec1[i]);
        __m256 x = _mm256_load_ps(&vec2[i]);
        u1 = _mm256_add_ps(u1, _mm256_mul_ps(w, x));
        i += 8;
    }
    u1 = _mm256_add_ps(u1, u2);
    // Horizontal reduction of the eight partial sums.
    alignas(32) float t[8];
    _mm256_store_ps(t, u1);
    float sum = t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];
    // Scalar tail for the last 0..7 elements.
    for (; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
//FMA版
float dot_avx_fma(const float *vec1, const float *vec2, unsigned n)
{
__m256 u1 = {0};
__m256 u2 = {0};
for(unsigned i = 0; i < n; i += 16)
{
__m256 w1 = _mm256_load_ps(&vec1[i]);
__m256 w2 = _mm256_load_ps(&vec1[i + 8]);
__m256 x1 = _mm256_load_ps(&vec2[i]);
__m256 x2 = _mm256_load_ps(&vec2[i + 8]);
//FMA命令で加算と乗算を行うけど,Haswellアーキテクチャ待ち(´・ω・`)
u1 = _mm256_fmadd_ps(w1, x1, u1);
u2 = _mm256_fmadd_ps(w2, x2, u2);
}
u1 = _mm256_add_ps(u1, u2);
//レジスタから書き戻し
// __attribute__((aligned(32)))
alignas(alignof(u1)) float t[8] = {0};
_mm256_store_ps(t, u1);
return t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];
}
// FMA version (double).
// Dot product of two double arrays using AVX fused multiply-add
// (_mm256_fmadd_pd) with two independent accumulators (8 doubles per loop).
//
// The target attribute lets this kernel compile without global -mavx -mfma.
// The caller must still ensure the CPU supports AVX and FMA3 (Haswell or
// later on Intel).
// Preconditions: vec1 and vec2 must be 32-byte aligned (aligned loads). n may
// be any value. The main loop covers n - n % 8 elements. At most one
// leftover 4-wide chunk is fused into the first accumulator. A scalar tail
// then handles the final 0..3 elements, so nothing past index n - 1 is read.
__attribute__((target("avx,fma")))
double dot_avx_fma(const double *vec1, const double *vec2, unsigned n)
{
    __m256d u1 = _mm256_setzero_pd();
    __m256d u2 = _mm256_setzero_pd();
    const unsigned n8 = n - n % 8; // largest multiple of 8 not exceeding n
    unsigned i = 0;
    for (; i < n8; i += 8)
    {
        __m256d w1 = _mm256_load_pd(&vec1[i]);
        __m256d w2 = _mm256_load_pd(&vec1[i + 4]);
        __m256d x1 = _mm256_load_pd(&vec2[i]);
        __m256d x2 = _mm256_load_pd(&vec2[i + 4]);
        // u += w * x in a single fused multiply-add instruction.
        u1 = _mm256_fmadd_pd(w1, x1, u1);
        u2 = _mm256_fmadd_pd(w2, x2, u2);
    }
    // One remaining full 4-lane chunk, if present.
    if (n - i >= 4)
    {
        u1 = _mm256_fmadd_pd(_mm256_load_pd(&vec1[i]), _mm256_load_pd(&vec2[i]), u1);
        i += 4;
    }
    u1 = _mm256_add_pd(u1, u2);
    // Write the register back to memory for the horizontal reduction.
    alignas(__m256d) double t[4];
    _mm256_store_pd(t, u1);
    double sum = t[0] + t[1] + t[2] + t[3];
    // Scalar tail for the last 0..3 elements.
    for (; i < n; ++i)
        sum += vec1[i] * vec2[i];
    return sum;
}
/*
float dot_avx512_fma(const float *vec1, const float *vec2, unsigned n)
{
__m512 u1 = {0};
__m512 u2 = {0};
for(unsigned i = 0; i < n; i += 32)
{
__m512 w1 = _mm512_load_ps(&vec1[i]);
__m512 w2 = _mm512_load_ps(&vec1[i + 16]);
__m512 x1 = _mm512_load_ps(&vec2[i]);
__m512 x2 = _mm512_load_ps(&vec2[i + 16]);
//FMA命令で加算と乗算を行うけど,Haswellアーキテクチャ待ち(´・ω・`)
u1 = _mm512_fmadd_ps(w1, x1, u1);
u2 = _mm512_fmadd_ps(w2, x2, u2);
}
u1 = _mm512_add_ps(u1, u2);
//レジスタから書き戻し
__attribute__((aligned(32))) float t[16] = {0};
_mm512_store_ps(t, u1);
float ret = 0;
for (unsigned i = 0; i < 16; ++i) {
ret += t[i];
}
return ret;
}
*/
#include <chrono>
// Repeatedly invokes the callable `t` until at least `ms` milliseconds of
// wall-clock time have elapsed, and returns the mean time per call in
// (fractional) milliseconds.
//
// - The callable always runs at least once, so ms == 0 yields a finite
//   per-call time instead of 0/0.
// - steady_clock is used because it is guaranteed monotonic.
// - The elapsed time is kept as a fractional millisecond count rather than
//   being truncated to whole milliseconds before dividing.
template<class T>
double calc_for_a_moment(T t, unsigned ms)
{
    using clock = std::chrono::steady_clock;
    const auto budget = std::chrono::milliseconds(ms);
    const auto start = clock::now();
    auto now = start;
    unsigned long long cnt = 0;
    // Results are accumulated into a volatile so the optimizer cannot
    // eliminate the calls being timed.
    volatile double sum = 0;
    do
    {
        sum = sum + t();
        now = clock::now();
        ++cnt;
    } while (now - start < budget);
    const std::chrono::duration<double, std::milli> elapsed = now - start;
    return elapsed.count() / static_cast<double>(cnt);
}
// Checks that `b` agrees with the reference value `a` to within relative
// tolerance `eps`, i.e. |a - b| <= eps * |a|. On mismatch it prints the
// values to stderr and throws std::runtime_error("<name> is wrong answer").
//
// The comparison is multiplied through by |a| instead of dividing by a, so a
// zero reference is handled without 0/0. It is written as a negated `<=` so
// that a NaN in either argument (which compares false) is reported as a
// failure rather than silently passing.
template <typename T>
void assert_approx(T a, T b, const std::string &name, T eps=1e-4) {
    const T diff = std::fabs(a - b);
    if (!(diff <= eps * std::fabs(a))) {
        std::cerr << "relative error |" << a << " - " << b << "| / |" << a
                  << "| > " << eps << std::endl;
        throw std::runtime_error(name + " is wrong answer");
    }
}
// Benchmarks the scalar reference dot product against the AVX/FMA kernel for
// vector lengths len_begin, len_begin * len_fact, ..., up to len_end.
// - Prints the mean time per call (ms) for each implementation.
// - Verifies the SIMD result against the reference.
// - Returns 1 if any check fails.
int main()
{
    const unsigned len_begin = 8;
    const unsigned len_end = 512 * 1024;
    const unsigned len_fact = 2;
    const unsigned run_ms = 100;
    std::mt19937 rng;
    std::uniform_real_distribution<> dst(-1, 1);
    using F = double;
    // The SIMD kernels use aligned loads, so each working array must start on
    // a 32-byte boundary. Each std::vector owns its (over-allocated) storage.
    // std::align returns the first 32-byte-aligned address inside it that
    // still has room for `count` elements.
    const std::size_t pad = 8; // 8 * sizeof(double) = 64 bytes >= worst-case shift of 24
    auto aligned_data = [](std::vector<F> &buf, std::size_t count) -> F * {
        void *p = buf.data();
        std::size_t space = buf.size() * sizeof(F);
        return static_cast<F *>(std::align(32, count * sizeof(F), p, space));
    };
    try
    {
        for (unsigned len = len_begin; len <= len_end; len *= len_fact)
        {
            std::vector<F> buf1(len + pad);
            std::vector<F> buf2(len + pad);
            F *vec1 = aligned_data(buf1, len);
            F *vec2 = aligned_data(buf2, len);
            std::generate(vec1, vec1 + len, [&rng, &dst](){ return dst(rng); });
            std::generate(vec2, vec2 + len, [&rng, &dst](){ return dst(rng); });
            // Each timing column is the mean time per call in milliseconds.
            std::printf("len:\t%u\n"
                        "ref:\t%lf\n"
                        // "sse:\t%lf\n"
                        // "avx:\t%lf\n"
                        // "avx2:\t%lf\n"
                        "avxfma:\t%lf\n",
                        len,
                        calc_for_a_moment([vec1, vec2, len](){ return dot_normal(vec1, vec2, len); }, run_ms),
                        // calc_for_a_moment([vec1, vec2, len](){ return dot_sse (vec1, vec2, len); }, run_ms),
                        // calc_for_a_moment([vec1, vec2, len](){ return dot_avx (vec1, vec2, len); }, run_ms),
                        // calc_for_a_moment([vec1, vec2, len](){ return dot_avx_2 (vec1, vec2, len); }, run_ms),
                        calc_for_a_moment([vec1, vec2, len](){ return dot_avx_fma(vec1, vec2, len); }, run_ms)
            );
            auto expected = dot_normal(vec1, vec2, len);
            // assert_approx(expected, dot_sse(vec1, vec2, len), "sse");
            assert_approx(expected, dot_avx_fma(vec1, vec2, len), "avx_fma");
        }
    }
    catch (const std::exception &e)
    {
        std::cerr << e.what() << std::endl;
        return 1;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment