// Walsh-Hadamard transform using SIMD intrinsics: 8-point kernels built with AVX/FMA intrinsics, plus plain-C variants.
#include <immintrin.h>

// Data layout shared by every kernel in this file: MR rows of NR contiguous
// doubles, row k starting at offset k * NR. Each kernel applies an 8-point
// Walsh-Hadamard transform independently to every row.
#define NR (8)  // Row length: elements per row, which is also the row stride.
#define MR (75) // Number of rows.

// Every kernel hard-codes an 8-element row (two __m256d halves, or eight
// scalar terms), so any other NR would silently produce wrong results.
#if NR != 8
#error "NR must be 8: the Walsh-Hadamard kernels are written for 8-point rows"
#endif
// First butterfly stage, "sum" half: lane-wise a + b.
// Applied to the low and high halves of an 8-element row, this yields
// [x0 + x4, x1 + x5, x2 + x6, x3 + x7]. Requires AVX.
static inline __m256d transform1a(__m256d a, __m256d b)
{
    return _mm256_add_pd(a, b);
}
// First butterfly stage, "difference" half: lane-wise a - b.
// Applied to the low and high halves of an 8-element row, this yields
// [x0 - x4, x1 - x5, x2 - x6, x3 - x7]. Requires AVX.
static inline __m256d transform1b(__m256d a, __m256d b)
{
    return _mm256_sub_pd(a, b);
}
// Second butterfly stage on one 4-lane vector (elements 2 apart).
// aamm must hold [1, 1, -1, -1] (lane 0 first). For a = [a0, a1, a2, a3]
// the result is [a0 + a2, a1 + a3, a0 - a2, a1 - a3].
// Requires AVX and FMA. The fused multiply-add rounds exactly like a plain
// add/sub here, because multiplying by +/-1 is exact.
static inline __m256d transform2(const __m256d aamm, __m256d a)
{
    // Swap the two 128-bit halves: [a2, a3, a0, a1].
    const __m256d a2301 = _mm256_permute2f128_pd(a, a, 0x01);
    // a * [1, 1, -1, -1] + [a2, a3, a0, a1]
    return _mm256_fmadd_pd(a, aamm, a2301);
}
// Third butterfly stage on one 4-lane vector (neighbouring elements).
// amam must hold [1, -1, 1, -1] (lane 0 first). For a = [a0, a1, a2, a3]
// the result is [a0 + a1, a0 - a1, a2 + a3, a2 - a3]. Requires AVX and FMA.
static inline __m256d transform3(const __m256d amam, __m256d a)
{
    // Swap neighbours inside each 128-bit half: [a1, a0, a3, a2].
    // Control 0x5 == 0b0101: a set bit i makes lane i take the odd element
    // of its pair.
    const __m256d a1032 = _mm256_permute_pd(a, 0x5);
    // a * [1, -1, 1, -1] + [a1, a0, a3, a2]
    return _mm256_fmadd_pd(a, amam, a1032);
}
static void wht_decomposed_vectors(double* A, double* C) | |
{ | |
const __m256d aamm = _mm256_set_pd(-1.0, -1.0, 1.0, 1.0); | |
const __m256d amam = _mm256_set_pd(-1.0, 1.0, -1.0, 1.0); | |
double *a = A; | |
double *c = C; | |
int i; | |
for (i = 0; i < MR - 4; i += 4) { | |
// Load the input rows. | |
const __m256d a0l = _mm256_load_pd(a + (0 * NR) + 0); | |
const __m256d a0h = _mm256_load_pd(a + (0 * NR) + 4); | |
const __m256d a1l = _mm256_load_pd(a + (1 * NR) + 0); | |
const __m256d a1h = _mm256_load_pd(a + (1 * NR) + 4); | |
const __m256d a2l = _mm256_load_pd(a + (2 * NR) + 0); | |
const __m256d a2h = _mm256_load_pd(a + (2 * NR) + 4); | |
const __m256d a3l = _mm256_load_pd(a + (3 * NR) + 0); | |
const __m256d a3h = _mm256_load_pd(a + (3 * NR) + 4); | |
// Apply the first transformation. | |
const __m256d c00l = transform1a(a0l, a0h); | |
const __m256d c00h = transform1b(a0l, a0h); | |
const __m256d c01l = transform1a(a1l, a1h); | |
const __m256d c01h = transform1b(a1l, a1h); | |
const __m256d c02l = transform1a(a2l, a2h); | |
const __m256d c02h = transform1b(a2l, a2h); | |
const __m256d c03l = transform1a(a3l, a3h); | |
const __m256d c03h = transform1b(a3l, a3h); | |
// Apply the second transformation. | |
const __m256d c10l = transform2(aamm, c00l); | |
const __m256d c10h = transform2(aamm, c00h); | |
const __m256d c11l = transform2(aamm, c01l); | |
const __m256d c11h = transform2(aamm, c01h); | |
const __m256d c12l = transform2(aamm, c02l); | |
const __m256d c12h = transform2(aamm, c02h); | |
const __m256d c13l = transform2(aamm, c03l); | |
const __m256d c13h = transform2(aamm, c03h); | |
// Apply the third transformation. | |
const __m256d c20l = transform3(amam, c10l); | |
const __m256d c20h = transform3(amam, c10h); | |
const __m256d c21l = transform3(amam, c11l); | |
const __m256d c21h = transform3(amam, c11h); | |
const __m256d c22l = transform3(amam, c12l); | |
const __m256d c22h = transform3(amam, c12h); | |
const __m256d c23l = transform3(amam, c13l); | |
const __m256d c23h = transform3(amam, c13h); | |
// Write the output rows. | |
_mm256_store_pd(c + (0 * NR) + 0, c20l); | |
_mm256_store_pd(c + (0 * NR) + 4, c20h); | |
_mm256_store_pd(c + (1 * NR) + 0, c21l); | |
_mm256_store_pd(c + (1 * NR) + 4, c21h); | |
_mm256_store_pd(c + (2 * NR) + 0, c22l); | |
_mm256_store_pd(c + (2 * NR) + 4, c22h); | |
_mm256_store_pd(c + (3 * NR) + 0, c23l); | |
_mm256_store_pd(c + (3 * NR) + 4, c23h); | |
// Adjust the pointers for the next four rows. | |
a += 4 * NR; | |
c += 4 * NR; | |
} | |
for (; i < MR; i++) { | |
const __m256d a0l = _mm256_load_pd(a + (0 * NR) + 0); | |
const __m256d a0h = _mm256_load_pd(a + (0 * NR) + 4); | |
const __m256d c00l = transform1a(a0l, a0h); | |
const __m256d c00h = transform1b(a0l, a0h); | |
const __m256d c10l = transform2(aamm, c00l); | |
const __m256d c10h = transform2(aamm, c00h); | |
const __m256d c20l = transform3(amam, c10l); | |
const __m256d c20h = transform3(amam, c10h); | |
_mm256_store_pd(c + (0 * NR) + 0, c20l); | |
_mm256_store_pd(c + (0 * NR) + 4, c20h); | |
a += 1 * NR; | |
c += 1 * NR; | |
} | |
} | |
static void wht_decomposed_novectors(double* A, double* C) | |
{ | |
double *a = A; | |
double *c = C; | |
for (int i = 0; i < MR; ++i) { | |
double c00, c01, c02, c03, c04, c05, c06, c07; | |
double c10, c11, c12, c13, c14, c15, c16, c17; | |
c00 = a[0] + a[4]; | |
c01 = a[1] + a[5]; | |
c02 = a[2] + a[6]; | |
c03 = a[3] + a[7]; | |
c04 = a[0] - a[4]; | |
c05 = a[1] - a[5]; | |
c06 = a[2] - a[6]; | |
c07 = a[3] - a[7]; | |
c10 = c00 + c02; | |
c11 = c01 + c03; | |
c12 = c00 - c02; | |
c13 = c01 - c03; | |
c14 = c04 + c06; | |
c15 = c05 + c07; | |
c16 = c04 - c06; | |
c17 = c05 - c07; | |
c[0] = c10 + c11; | |
c[1] = c10 - c11; | |
c[2] = c12 + c13; | |
c[3] = c12 - c13; | |
c[4] = c14 + c15; | |
c[5] = c14 - c15; | |
c[6] = c16 + c17; | |
c[7] = c16 - c17; | |
a += NR; | |
c += NR; | |
} | |
} | |
// 4-point Walsh-Hadamard transform of one vector, in natural order.
// aamm must hold [1, 1, -1, -1] and amam [1, -1, 1, -1] (lane 0 first).
// For a = [a0, a1, a2, a3] the result is
//   [a0 + a1 + a2 + a3, a0 - a1 + a2 - a3, a0 + a1 - a2 - a3, a0 - a1 - a2 + a3].
// Requires AVX and FMA.
static inline __m256d wht4x4(const __m256d aamm, const __m256d amam, __m256d a)
{
    // [a1, a0, a3, a2]  (control 0x5 == 0b0101: swap neighbours per 128-bit half)
    const __m256d a1032 = _mm256_permute_pd(a, 0x5);
    // [a0 + a1, a0 - a1, a2 + a3, a2 - a3]
    const __m256d a01012323 = _mm256_fmadd_pd(a, amam, a1032);
    // [a2 + a3, a2 - a3, a0 + a1, a0 - a1]  (swap the 128-bit halves)
    const __m256d a23230101 = _mm256_permute2f128_pd(a01012323, a01012323, 0x01);
    // [a0 + a1 + a2 + a3, a0 - a1 + a2 - a3, a0 + a1 - a2 - a3, a0 - a1 - a2 + a3]
    return _mm256_fmadd_pd(a01012323, aamm, a23230101);
}
static void wht_composed_vectors(double* A, double* C) | |
{ | |
const __m256d aamm = _mm256_set_pd(-1.0, -1.0, 1.0, 1.0); | |
const __m256d amam = _mm256_set_pd(-1.0, 1.0, -1.0, 1.0); | |
double *a = A; | |
double *c = C; | |
for (int i = 0; i < MR; ++i) { | |
const __m256d al = _mm256_load_pd(a + 0); | |
const __m256d ah = _mm256_load_pd(a + 4); | |
const __m256d cl = wht4x4(aamm, amam, al); | |
const __m256d ch = wht4x4(aamm, amam, ah); | |
_mm256_store_pd(c + 0, _mm256_add_pd(cl, ch)); | |
_mm256_store_pd(c + 4, _mm256_sub_pd(cl, ch)); | |
a += NR; | |
c += NR; | |
} | |
} | |
static void wht_composed_novectors(double* A, double* C) | |
{ | |
double *a = A; | |
double *c = C; | |
for (int i = 0; i < MR; ++i) { | |
c[0] = a[0] + a[1] + a[2] + a[3] + a[4] + a[5] + a[6] + a[7]; | |
c[1] = a[0] - a[1] + a[2] - a[3] + a[4] - a[5] + a[6] - a[7]; | |
c[2] = a[0] + a[1] - a[2] - a[3] + a[4] + a[5] - a[6] - a[7]; | |
c[3] = a[0] - a[1] - a[2] + a[3] + a[4] - a[5] - a[6] + a[7]; | |
c[4] = a[0] + a[1] + a[2] + a[3] - a[4] - a[5] - a[6] - a[7]; | |
c[5] = a[0] - a[1] + a[2] - a[3] - a[4] + a[5] - a[6] + a[7]; | |
c[6] = a[0] + a[1] - a[2] - a[3] - a[4] - a[5] + a[6] + a[7]; | |
c[7] = a[0] - a[1] - a[2] + a[3] - a[4] + a[5] + a[6] - a[7]; | |
a += NR; | |
c += NR; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment