Created
April 29, 2024 16:21
-
-
Save bczhc/a2d0acf55aff4bd339b8b8f22254bc00 to your computer and use it in GitHub Desktop.
SIMD玩弄
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// set(CMAKE_CXX_FLAGS "-mavx -mavx512vl -mavx512dq") | |
#include <sys/time.h>

#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include <immintrin.h>
#include <xmmintrin.h>

using namespace std;
// Return lane `i` (0..3) of a 4 x float SSE vector as a scalar.
// The shuffle broadcasts lane i into every lane, so lane 0 (the lowest,
// which _mm_cvtss_f32 reads) holds the requested value.
template<unsigned i>
float vectorGetByIndex(__m128 V) {
    // _MM_SHUFFLE packs four 2-bit selectors; i >= 4 would overflow them.
    static_assert(i < 4, "__m128 has only four float lanes (0..3)");
    V = _mm_shuffle_ps(V, V, _MM_SHUFFLE(i, i, i, i));
    return _mm_cvtss_f32(V);
}

// Plain array holding the four float lanes of an __m128, lane 0 first.
struct Vec4 {
    float v[4];
};

// Copy the four lanes of `v` into a Vec4 (lane k -> result.v[k]).
// Filled element-wise rather than with a designated initializer, which is
// C++20-only syntax.
Vec4 mm128_to_vec4(__m128 v) {
    Vec4 out{};
    out.v[0] = vectorGetByIndex<0>(v);
    out.v[1] = vectorGetByIndex<1>(v);
    out.v[2] = vectorGetByIndex<2>(v);
    out.v[3] = vectorGetByIndex<3>(v);
    return out;
}
struct Vec4d { | |
double v[4]; | |
}; | |
Vec4d m256d_to_vec4(__m256d v) { | |
// Extracting the lower and upper halves of the vector | |
__m128d lower = _mm256_extractf64x2_pd(v, 0); // Extracts the lower 128 bits | |
__m128d upper = _mm256_extractf64x2_pd(v, 1); // Extracts the upper 128 bits | |
// Extracting individual elements from the lower and upper halves | |
double lower_elements[2]; | |
double upper_elements[2]; | |
_mm_storeu_pd(lower_elements, lower); | |
_mm_storeu_pd(upper_elements, upper); | |
return Vec4d{ | |
.v = { | |
lower_elements[0], | |
lower_elements[1], | |
upper_elements[0], | |
upper_elements[1], | |
} | |
}; | |
} | |
// Write the four lanes of `v` to standard output, separated by single
// spaces, then end the line with std::endl (which also flushes).
void print_vec4(Vec4 v) {
    const float *lanes = v.v;
    for (int k = 0; k < 4; ++k) {
        std::cout << lanes[k];
        if (k != 3) {
            std::cout << ' ';
        }
    }
    std::cout << std::endl;
}
// Composite trapezoidal rule: approximate the integral of `f` over [a, b]
// using `segments` equal sub-intervals of width h = (b - a) / segments:
//
//     h * ( (f(a) + f(b)) / 2 + sum_{i=1}^{segments-1} f(a + i*h) )
//
// Nodes are computed from their integer index instead of by repeatedly
// adding h. Repeated addition drifts, so a `d < b` loop can run one extra
// sub-interval; e.g. [0, 1] with 10 segments used to sum 11 trapezoids.
// Each interior node is evaluated once.
//
// Returns 0.0 when segments == 0. When a > b the result is the negated
// integral over [b, a].
double integrate1(double (*f)(double), double a, double b, uint64_t segments) {
    if (segments == 0) {
        return 0.0;
    }
    const double delta = (b - a) / static_cast<double>(segments);
    double interior = 0.0;
    for (uint64_t i = 1; i < segments; ++i) {
        interior += f(a + static_cast<double>(i) * delta);
    }
    return delta * (0.5 * (f(a) + f(b)) + interior);
}
double integrate2(double (*f)(double), double a, double b, uint64_t segments) { | |
auto delta = (b - a) / (double) segments; | |
auto interval = b - a; | |
auto f_mm256 = [&](Vec4d x) { | |
auto xs = x.v; | |
return _mm256_set_pd(f(xs[0]), f(xs[1]), f(xs[2]), f(xs[3])); | |
}; | |
auto f_mm256_sse = [&](__m256d xs) { | |
return f_mm256(m256d_to_vec4(xs)); | |
}; | |
auto s1 = a, s2 = a + interval / 4.0, s3 = a + interval / 4.0 * 2.0, s4 = a + interval / 4.0 * 3.0; | |
auto m1 = _mm256_set_pd(s1, s2, s3, s4); | |
auto m3 = _mm256_set_pd(delta, delta, delta, delta); | |
auto m4 = _mm256_add_pd(m1, m3); | |
auto m5 = _mm256_set_pd(2.0, 2.0, 2.0, 2.0); | |
auto sum = _mm256_setzero_pd(); | |
for (double d = s1; d < s2; d += delta) { | |
m1 = _mm256_add_pd(m1, m3); | |
m4 = _mm256_add_pd(m1, m3); | |
auto m6 = _mm256_div_pd(_mm256_mul_pd(_mm256_add_pd(f_mm256_sse(m1), f_mm256_sse(m4)), m3), m5); | |
sum = _mm256_add_pd(sum, m6); | |
} | |
const Vec4d &vec4 = m256d_to_vec4(sum); | |
auto v = vec4.v; | |
return v[0] + v[1] + v[2] + v[3]; | |
} | |
double integrate3(__m256d (*f)(__m256d), double a, double b, uint64_t segments) { | |
auto delta = (b - a) / (double) segments; | |
auto interval = b - a; | |
auto s1 = a, s2 = a + interval / 4.0, s3 = a + interval / 4.0 * 2.0, s4 = a + interval / 4.0 * 3.0; | |
auto m1 = _mm256_set_pd(s1, s2, s3, s4); | |
auto m3 = _mm256_set_pd(delta, delta, delta, delta); | |
auto m4 = _mm256_add_pd(m1, m3); | |
auto m5 = _mm256_set_pd(2.0, 2.0, 2.0, 2.0); | |
auto sum = _mm256_setzero_pd(); | |
for (double d = s1; d < s2; d += delta) { | |
m1 = _mm256_add_pd(m1, m3); | |
m4 = _mm256_add_pd(m1, m3); | |
auto m6 = _mm256_div_pd(_mm256_mul_pd(_mm256_add_pd(f(m1), f(m4)), m3), m5); | |
sum = _mm256_add_pd(sum, m6); | |
} | |
const Vec4d &vec4 = m256d_to_vec4(sum); | |
auto v = vec4.v; | |
return v[0] + v[1] + v[2] + v[3]; | |
} | |
// Current wall-clock time from gettimeofday(), converted to milliseconds.
// Both fields are widened to uint64_t before scaling. Otherwise
// tv_sec * 1000 would be computed in time_t and overflow (undefined
// behaviour) on platforms where time_t is 32 bits.
uint64_t timestamp_ms() {
    timeval t{};
    gettimeofday(&t, nullptr);
    return static_cast<uint64_t>(t.tv_sec) * 1000u + static_cast<uint64_t>(t.tv_usec) / 1000u;
}
// Elapsed-time reporter: captures the current time on construction, and
// print() writes the milliseconds elapsed since then to stdout as "<N>ms",
// followed by std::endl (which flushes).
// It uses std::chrono::steady_clock, which is monotonic. A wall clock such
// as gettimeofday() can be stepped backwards (e.g. by NTP), and an unsigned
// difference would then wrap to a huge bogus duration.
class Timer {
    std::chrono::steady_clock::time_point start{std::chrono::steady_clock::now()};
public:
    Timer() = default;

    void print() const {
        const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - start);
        std::cout << elapsed.count() << "ms" << std::endl;
    }
};
static __m256d M1 = _mm256_set_pd(3.0, 3.0, 3.0, 3.0); | |
int main() { | |
float sum = 0.0; | |
for (float i = 1; i <= 10000; ++i) { | |
sum += i; | |
} | |
cout << sum << endl; | |
__m128 sum2 = _mm_setzero_ps(); | |
__m128 adder = _mm_set_ps(1.0, 2.0, 3.0, 4.0); | |
__m128 c4 = _mm_set_ps(4.0, 4.0, 4.0, 4.0); | |
for (int i = 0; i < 10000 / 4; ++i) { | |
sum2 = _mm_add_ps(sum2, adder); | |
adder = _mm_add_ps(adder, c4); | |
} | |
auto v = mm128_to_vec4(sum2); | |
auto v1 = v.v; | |
print_vec4(v); | |
cout << v1[0] + v1[1] + v1[2] + v1[3] << endl; | |
auto fn = [](double x) { | |
return sqrt(pow(sin(x), cos(x))); | |
}; | |
Timer t1; | |
cout << integrate1(fn, 1.0, 3.0, 10000000) << endl; | |
t1.print(); | |
Timer t2; | |
cout << integrate2(fn, 1.0, 3.0, 10000000) << endl; | |
t2.print(); | |
auto fn2 = [](double x) { | |
return 3.0 / (x * x); | |
}; | |
auto fn2_m256d = [](__m256d x) { | |
auto m1 = _mm256_mul_pd(x, x); | |
return _mm256_div_pd(M1, m1); | |
}; | |
Timer t3; | |
cout << integrate1(fn2, 2.0, 4.0, 200000000) << endl; | |
t3.print(); | |
Timer t4; | |
cout << integrate3(fn2_m256d, 2.0, 4.0, 200000000) << endl; | |
t4.print(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment