Created
April 12, 2020 20:10
-
-
Save eikendev/bcadad1df680cc8b96d4e2ec99e556d2 to your computer and use it in GitHub Desktop.
Walsh-Hadamard transform using SIMD intrinsics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <immintrin.h> | |
// Row layout shared by every wht_* kernel below: each row holds the 8 inputs
// (or outputs) of one 8-point Walsh-Hadamard transform.
#define NR (8)  // Number of columns: row length and row stride, in doubles.
#define MR (75) // Number of rows: independent 8-point transforms per call.
// Stage-1 butterfly, sum side: lane-wise a + b.
// The kernels call it with the low and high halves of a row
// ([x0..x3], [x4..x7]) to form [x0+x4, x1+x5, x2+x6, x3+x7].
// (No `const` on the return type: qualifiers on a by-value return are
// ignored and only produce -Wignored-qualifiers warnings.)
static inline __m256d transform1a(__m256d a, __m256d b)
{
    return _mm256_add_pd(a, b);
}
// Stage-1 butterfly, difference side: lane-wise a - b.
// The kernels call it with the low and high halves of a row
// ([x0..x3], [x4..x7]) to form [x0-x4, x1-x5, x2-x6, x3-x7].
// (No `const` on the return type: it is ignored for by-value returns.)
static inline __m256d transform1b(__m256d a, __m256d b)
{
    return _mm256_sub_pd(a, b);
}
// Stage-2 butterfly on elements two apart within one vector:
//   [a0, a1, a2, a3] -> [a0 + a2, a1 + a3, a0 - a2, a1 - a3]
// aamm must hold the sign pattern [+1, +1, -1, -1] in lane order, i.e.
// _mm256_set_pd(-1.0, -1.0, 1.0, 1.0) (set_pd lists lanes high to low).
static inline __m256d transform2(const __m256d aamm, __m256d a)
{
    // Swap the 128-bit halves -> [a2, a3, a0, a1].
    // imm 0x01: result low half = a's high half, result high half = a's low half.
    // (Hex instead of a 0b binary literal, which is not ISO C before C23.)
    const __m256d a2301 = _mm256_permute2f128_pd(a, a, 0x01);
    // a * [+1, +1, -1, -1] + [a2, a3, a0, a1] = [a0 + a2, a1 + a3, a0 - a2, a1 - a3].
    // Multiplying by +-1 is exact and FMA rounds once, so each lane equals
    // the plain add/sub result.
    return _mm256_fmadd_pd(a, aamm, a2301);
}
// Stage-3 butterfly on adjacent elements within one vector:
//   [a0, a1, a2, a3] -> [a0 + a1, a0 - a1, a2 + a3, a2 - a3]
// amam must hold the sign pattern [+1, -1, +1, -1] in lane order, i.e.
// _mm256_set_pd(-1.0, 1.0, -1.0, 1.0) (set_pd lists lanes high to low).
static inline __m256d transform3(const __m256d amam, __m256d a)
{
    // Swap neighbours inside each 128-bit lane -> [a1, a0, a3, a2].
    // Bit k of the immediate picks the odd (1) or even (0) element of the
    // lane pair for result element k; 0x5 = 0b0101 -> a1, a0, a3, a2.
    const __m256d a1032 = _mm256_permute_pd(a, 0x5);
    // a * [+1, -1, +1, -1] + [a1, a0, a3, a2] = [a0 + a1, a0 - a1, a2 + a3, a2 - a3].
    return _mm256_fmadd_pd(a, amam, a1032);
}
static void wht_decomposed_vectors(double* A, double* C) | |
{ | |
const __m256d aamm = _mm256_set_pd(-1.0, -1.0, 1.0, 1.0); | |
const __m256d amam = _mm256_set_pd(-1.0, 1.0, -1.0, 1.0); | |
double *a = A; | |
double *c = C; | |
int i; | |
for (i = 0; i < MR - 4; i += 4) { | |
// Load the input rows. | |
const __m256d a0l = _mm256_load_pd(a + (0 * NR) + 0); | |
const __m256d a0h = _mm256_load_pd(a + (0 * NR) + 4); | |
const __m256d a1l = _mm256_load_pd(a + (1 * NR) + 0); | |
const __m256d a1h = _mm256_load_pd(a + (1 * NR) + 4); | |
const __m256d a2l = _mm256_load_pd(a + (2 * NR) + 0); | |
const __m256d a2h = _mm256_load_pd(a + (2 * NR) + 4); | |
const __m256d a3l = _mm256_load_pd(a + (3 * NR) + 0); | |
const __m256d a3h = _mm256_load_pd(a + (3 * NR) + 4); | |
// Apply the first transformation. | |
const __m256d c00l = transform1a(a0l, a0h); | |
const __m256d c00h = transform1b(a0l, a0h); | |
const __m256d c01l = transform1a(a1l, a1h); | |
const __m256d c01h = transform1b(a1l, a1h); | |
const __m256d c02l = transform1a(a2l, a2h); | |
const __m256d c02h = transform1b(a2l, a2h); | |
const __m256d c03l = transform1a(a3l, a3h); | |
const __m256d c03h = transform1b(a3l, a3h); | |
// Apply the second transformation. | |
const __m256d c10l = transform2(aamm, c00l); | |
const __m256d c10h = transform2(aamm, c00h); | |
const __m256d c11l = transform2(aamm, c01l); | |
const __m256d c11h = transform2(aamm, c01h); | |
const __m256d c12l = transform2(aamm, c02l); | |
const __m256d c12h = transform2(aamm, c02h); | |
const __m256d c13l = transform2(aamm, c03l); | |
const __m256d c13h = transform2(aamm, c03h); | |
// Apply the third transformation. | |
const __m256d c20l = transform3(amam, c10l); | |
const __m256d c20h = transform3(amam, c10h); | |
const __m256d c21l = transform3(amam, c11l); | |
const __m256d c21h = transform3(amam, c11h); | |
const __m256d c22l = transform3(amam, c12l); | |
const __m256d c22h = transform3(amam, c12h); | |
const __m256d c23l = transform3(amam, c13l); | |
const __m256d c23h = transform3(amam, c13h); | |
// Write the output rows. | |
_mm256_store_pd(c + (0 * NR) + 0, c20l); | |
_mm256_store_pd(c + (0 * NR) + 4, c20h); | |
_mm256_store_pd(c + (1 * NR) + 0, c21l); | |
_mm256_store_pd(c + (1 * NR) + 4, c21h); | |
_mm256_store_pd(c + (2 * NR) + 0, c22l); | |
_mm256_store_pd(c + (2 * NR) + 4, c22h); | |
_mm256_store_pd(c + (3 * NR) + 0, c23l); | |
_mm256_store_pd(c + (3 * NR) + 4, c23h); | |
// Adjust the pointers for the next four rows. | |
a += 4 * NR; | |
c += 4 * NR; | |
} | |
for (; i < MR; i++) { | |
const __m256d a0l = _mm256_load_pd(a + (0 * NR) + 0); | |
const __m256d a0h = _mm256_load_pd(a + (0 * NR) + 4); | |
const __m256d c00l = transform1a(a0l, a0h); | |
const __m256d c00h = transform1b(a0l, a0h); | |
const __m256d c10l = transform2(aamm, c00l); | |
const __m256d c10h = transform2(aamm, c00h); | |
const __m256d c20l = transform3(amam, c10l); | |
const __m256d c20h = transform3(amam, c10h); | |
_mm256_store_pd(c + (0 * NR) + 0, c20l); | |
_mm256_store_pd(c + (0 * NR) + 4, c20h); | |
a += 1 * NR; | |
c += 1 * NR; | |
} | |
} | |
static void wht_decomposed_novectors(double* A, double* C) | |
{ | |
double *a = A; | |
double *c = C; | |
for (int i = 0; i < MR; ++i) { | |
double c00, c01, c02, c03, c04, c05, c06, c07; | |
double c10, c11, c12, c13, c14, c15, c16, c17; | |
c00 = a[0] + a[4]; | |
c01 = a[1] + a[5]; | |
c02 = a[2] + a[6]; | |
c03 = a[3] + a[7]; | |
c04 = a[0] - a[4]; | |
c05 = a[1] - a[5]; | |
c06 = a[2] - a[6]; | |
c07 = a[3] - a[7]; | |
c10 = c00 + c02; | |
c11 = c01 + c03; | |
c12 = c00 - c02; | |
c13 = c01 - c03; | |
c14 = c04 + c06; | |
c15 = c05 + c07; | |
c16 = c04 - c06; | |
c17 = c05 - c07; | |
c[0] = c10 + c11; | |
c[1] = c10 - c11; | |
c[2] = c12 + c13; | |
c[3] = c12 - c13; | |
c[4] = c14 + c15; | |
c[5] = c14 - c15; | |
c[6] = c16 + c17; | |
c[7] = c16 - c17; | |
a += NR; | |
c += NR; | |
} | |
} | |
// Full 4-point Walsh-Hadamard transform of one vector [a0, a1, a2, a3]:
//   [a0 + a1 + a2 + a3, a0 - a1 + a2 - a3, a0 + a1 - a2 - a3, a0 - a1 - a2 + a3]
// aamm must be [+1, +1, -1, -1] and amam [+1, -1, +1, -1] in lane order,
// i.e. _mm256_set_pd(-1.0, -1.0, 1.0, 1.0) and _mm256_set_pd(-1.0, 1.0, -1.0, 1.0).
// (No `const` on the return type: it is ignored for by-value returns.)
static inline __m256d wht4x4(const __m256d aamm, const __m256d amam, __m256d a)
{
    // Swap neighbours inside each 128-bit lane -> [a1, a0, a3, a2]
    // (imm 0x5 = 0b0101; hex because binary literals are not ISO C before C23).
    const __m256d a1032 = _mm256_permute_pd(a, 0x5);
    // [a0 + a1, a0 - a1, a2 + a3, a2 - a3]
    const __m256d a01012323 = _mm256_fmadd_pd(a, amam, a1032);
    // Swap the 128-bit halves -> [a2 + a3, a2 - a3, a0 + a1, a0 - a1]
    const __m256d a23230101 = _mm256_permute2f128_pd(a01012323, a01012323, 0x01);
    // [a0 + a1 + a2 + a3, a0 - a1 + a2 - a3, a0 + a1 - a2 - a3, a0 - a1 - a2 + a3]
    return _mm256_fmadd_pd(a01012323, aamm, a23230101);
}
static void wht_composed_vectors(double* A, double* C) | |
{ | |
const __m256d aamm = _mm256_set_pd(-1.0, -1.0, 1.0, 1.0); | |
const __m256d amam = _mm256_set_pd(-1.0, 1.0, -1.0, 1.0); | |
double *a = A; | |
double *c = C; | |
for (int i = 0; i < MR; ++i) { | |
const __m256d al = _mm256_load_pd(a + 0); | |
const __m256d ah = _mm256_load_pd(a + 4); | |
const __m256d cl = wht4x4(aamm, amam, al); | |
const __m256d ch = wht4x4(aamm, amam, ah); | |
_mm256_store_pd(c + 0, _mm256_add_pd(cl, ch)); | |
_mm256_store_pd(c + 4, _mm256_sub_pd(cl, ch)); | |
a += NR; | |
c += NR; | |
} | |
} | |
static void wht_composed_novectors(double* A, double* C) | |
{ | |
double *a = A; | |
double *c = C; | |
for (int i = 0; i < MR; ++i) { | |
c[0] = a[0] + a[1] + a[2] + a[3] + a[4] + a[5] + a[6] + a[7]; | |
c[1] = a[0] - a[1] + a[2] - a[3] + a[4] - a[5] + a[6] - a[7]; | |
c[2] = a[0] + a[1] - a[2] - a[3] + a[4] + a[5] - a[6] - a[7]; | |
c[3] = a[0] - a[1] - a[2] + a[3] + a[4] - a[5] - a[6] + a[7]; | |
c[4] = a[0] + a[1] + a[2] + a[3] - a[4] - a[5] - a[6] - a[7]; | |
c[5] = a[0] - a[1] + a[2] - a[3] - a[4] + a[5] - a[6] + a[7]; | |
c[6] = a[0] + a[1] - a[2] - a[3] - a[4] - a[5] + a[6] + a[7]; | |
c[7] = a[0] - a[1] - a[2] + a[3] - a[4] + a[5] + a[6] - a[7]; | |
a += NR; | |
c += NR; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment