Skip to content

Instantly share code, notes, and snippets.

@eikendev
Created April 12, 2020 20:10
Show Gist options
  • Save eikendev/bcadad1df680cc8b96d4e2ec99e556d2 to your computer and use it in GitHub Desktop.
Save eikendev/bcadad1df680cc8b96d4e2ec99e556d2 to your computer and use it in GitHub Desktop.
Walsh-Hadamard transform using SIMD intrinsics
#include <immintrin.h>
// Each "row" is NR contiguous doubles fed to one 8-point transform; NR is
// also the stride between consecutive rows. MR is how many rows one call of
// the kernels below processes.
#define NR (8)  // Elements per row (transform length and row stride).
#define MR (75) // Number of rows transformed per call.
// Every kernel below hard-codes an 8-point transform (two 4-lane AVX halves
// per row, eight scalar temporaries), so any other NR would silently read and
// write the wrong elements.
#if NR != 8
#error "The Walsh-Hadamard kernels in this file require NR == 8."
#endif
/*
 * Sum side of the first butterfly stage: lane-wise a + b.
 * The kernels call it with the low half x[0..3] and high half x[4..7] of a
 * row, giving x[k] + x[k + 4] for k = 0..3.
 * (No `const` on the by-value return type: it is ignored by the compiler and
 * only triggers -Wignored-qualifiers.)
 */
static inline __m256d transform1a(__m256d a, __m256d b)
{
    return _mm256_add_pd(a, b);
}
/*
 * Difference side of the first butterfly stage: lane-wise a - b.
 * The kernels call it with the low half x[0..3] and high half x[4..7] of a
 * row, giving x[k] - x[k + 4] for k = 0..3.
 * (No `const` on the by-value return type: it is ignored by the compiler and
 * only triggers -Wignored-qualifiers.)
 */
static inline __m256d transform1b(__m256d a, __m256d b)
{
    return _mm256_sub_pd(a, b);
}
/*
 * Butterfly between lanes two apart within one 4-lane vector:
 *   [a0, a1, a2, a3] -> [a0 + a2, a1 + a3, a0 - a2, a1 - a3]
 *
 * aamm: sign mask whose lanes, low to high, are [+1, +1, -1, -1]
 *       (i.e. _mm256_set_pd(-1.0, -1.0, 1.0, 1.0), high lane listed first).
 * a:    input vector.
 * Uses an FMA instruction, so the target must support FMA.
 */
static inline __m256d transform2(const __m256d aamm, __m256d a)
{
    // Swap the two 128-bit halves -> [a2, a3, a0, a1].
    // imm 0x01: low half <- high half of a, high half <- low half of a.
    const __m256d a2301 = _mm256_permute2f128_pd(a, a, 0x01);
    // a * [+1, +1, -1, -1] + [a2, a3, a0, a1]
    //   = [a0 + a2, a1 + a3, a0 - a2, a1 - a3]
    return _mm256_fmadd_pd(a, aamm, a2301);
}
/*
 * Butterfly between neighbouring lanes within one 4-lane vector:
 *   [a0, a1, a2, a3] -> [a0 + a1, a0 - a1, a2 + a3, a2 - a3]
 *
 * amam: sign mask whose lanes, low to high, are [+1, -1, +1, -1]
 *       (i.e. _mm256_set_pd(-1.0, 1.0, -1.0, 1.0), high lane listed first).
 * a:    input vector.
 * Uses an FMA instruction, so the target must support FMA.
 */
static inline __m256d transform3(const __m256d amam, __m256d a)
{
    // Swap neighbours inside each 128-bit half -> [a1, a0, a3, a2].
    // imm 0x5 (bits 0101): lanes 0 and 2 take the odd element, lanes 1 and 3
    // the even one.
    const __m256d a1032 = _mm256_permute_pd(a, 0x5);
    // a * [+1, -1, +1, -1] + [a1, a0, a3, a2]
    //   = [a0 + a1, a0 - a1, a2 + a3, a2 - a3]
    return _mm256_fmadd_pd(a, amam, a1032);
}
/*
 * 8-point Walsh-Hadamard transform of MR rows, computed as three AVX
 * butterfly stages:
 *   transform1a/transform1b (elements 4 apart),
 *   transform2 (elements 2 apart),
 *   transform3 (neighbouring elements).
 * Outputs come in the same element order as wht_decomposed_novectors.
 *
 * A: input, MR rows of NR doubles stored back to back. The stages assume
 *    NR == 8, i.e. two 4-lane halves per row.
 * C: output, same layout as A.
 *
 * _mm256_load_pd/_mm256_store_pd require 32-byte-aligned addresses, so A and
 * C must be 32-byte aligned; with NR == 8 every row then stays aligned.
 * Each iteration loads all of its rows before storing any, so A == C
 * (in-place) is supported.
 */
static void wht_decomposed_vectors(double* A, double* C)
{
    // Sign masks. Lanes, low to high: aamm = [+1, +1, -1, -1] and
    // amam = [+1, -1, +1, -1] (_mm256_set_pd lists the high lane first).
    const __m256d aamm = _mm256_set_pd(-1.0, -1.0, 1.0, 1.0);
    const __m256d amam = _mm256_set_pd(-1.0, 1.0, -1.0, 1.0);
    double *a = A;
    double *c = C;
    int i;
    // Four rows per iteration. The bound `i + 4 <= MR` (rather than
    // `i < MR - 4`) sends every complete group of four through this
    // vectorised path, including the last one when MR is a multiple of 4.
    for (i = 0; i + 4 <= MR; i += 4) {
        // Load the input rows (low and high halves).
        const __m256d a0l = _mm256_load_pd(a + (0 * NR) + 0);
        const __m256d a0h = _mm256_load_pd(a + (0 * NR) + 4);
        const __m256d a1l = _mm256_load_pd(a + (1 * NR) + 0);
        const __m256d a1h = _mm256_load_pd(a + (1 * NR) + 4);
        const __m256d a2l = _mm256_load_pd(a + (2 * NR) + 0);
        const __m256d a2h = _mm256_load_pd(a + (2 * NR) + 4);
        const __m256d a3l = _mm256_load_pd(a + (3 * NR) + 0);
        const __m256d a3h = _mm256_load_pd(a + (3 * NR) + 4);
        // Stage 1: butterflies between elements 4 apart.
        const __m256d c00l = transform1a(a0l, a0h);
        const __m256d c00h = transform1b(a0l, a0h);
        const __m256d c01l = transform1a(a1l, a1h);
        const __m256d c01h = transform1b(a1l, a1h);
        const __m256d c02l = transform1a(a2l, a2h);
        const __m256d c02h = transform1b(a2l, a2h);
        const __m256d c03l = transform1a(a3l, a3h);
        const __m256d c03h = transform1b(a3l, a3h);
        // Stage 2: butterflies between elements 2 apart.
        const __m256d c10l = transform2(aamm, c00l);
        const __m256d c10h = transform2(aamm, c00h);
        const __m256d c11l = transform2(aamm, c01l);
        const __m256d c11h = transform2(aamm, c01h);
        const __m256d c12l = transform2(aamm, c02l);
        const __m256d c12h = transform2(aamm, c02h);
        const __m256d c13l = transform2(aamm, c03l);
        const __m256d c13h = transform2(aamm, c03h);
        // Stage 3: butterflies between neighbouring elements.
        const __m256d c20l = transform3(amam, c10l);
        const __m256d c20h = transform3(amam, c10h);
        const __m256d c21l = transform3(amam, c11l);
        const __m256d c21h = transform3(amam, c11h);
        const __m256d c22l = transform3(amam, c12l);
        const __m256d c22h = transform3(amam, c12h);
        const __m256d c23l = transform3(amam, c13l);
        const __m256d c23h = transform3(amam, c13h);
        // Write the output rows.
        _mm256_store_pd(c + (0 * NR) + 0, c20l);
        _mm256_store_pd(c + (0 * NR) + 4, c20h);
        _mm256_store_pd(c + (1 * NR) + 0, c21l);
        _mm256_store_pd(c + (1 * NR) + 4, c21h);
        _mm256_store_pd(c + (2 * NR) + 0, c22l);
        _mm256_store_pd(c + (2 * NR) + 4, c22h);
        _mm256_store_pd(c + (3 * NR) + 0, c23l);
        _mm256_store_pd(c + (3 * NR) + 4, c23h);
        // Advance to the next four rows.
        a += 4 * NR;
        c += 4 * NR;
    }
    // Tail: the remaining MR % 4 rows, one at a time.
    for (; i < MR; i++) {
        const __m256d a0l = _mm256_load_pd(a + 0);
        const __m256d a0h = _mm256_load_pd(a + 4);
        const __m256d c00l = transform1a(a0l, a0h);
        const __m256d c00h = transform1b(a0l, a0h);
        const __m256d c10l = transform2(aamm, c00l);
        const __m256d c10h = transform2(aamm, c00h);
        const __m256d c20l = transform3(amam, c10l);
        const __m256d c20h = transform3(amam, c10h);
        _mm256_store_pd(c + 0, c20l);
        _mm256_store_pd(c + 4, c20h);
        a += NR;
        c += NR;
    }
}
/*
 * Scalar reference for the factorised 8-point Walsh-Hadamard transform.
 * Each of the MR rows (NR == 8 contiguous doubles) goes through three
 * butterfly stages, pairing elements 4, then 2, then 1 apart.
 * A is the input, C the output (same layout). Every input element of a row
 * is read before that row's outputs are written.
 */
static void wht_decomposed_novectors(double* A, double* C)
{
    for (int row = 0; row < MR; ++row) {
        const double *x = &A[row * NR];
        double *y = &C[row * NR];

        /* Stage 1: pair element k with element k + 4. */
        const double p0 = x[0] + x[4], p1 = x[1] + x[5];
        const double p2 = x[2] + x[6], p3 = x[3] + x[7];
        const double m0 = x[0] - x[4], m1 = x[1] - x[5];
        const double m2 = x[2] - x[6], m3 = x[3] - x[7];

        /* Stage 2: pair element k with element k + 2 inside each half. */
        const double pp0 = p0 + p2, pp1 = p1 + p3;
        const double pm0 = p0 - p2, pm1 = p1 - p3;
        const double mp0 = m0 + m2, mp1 = m1 + m3;
        const double mm0 = m0 - m2, mm1 = m1 - m3;

        /* Stage 3: pair neighbouring elements and emit the row. */
        y[0] = pp0 + pp1;
        y[1] = pp0 - pp1;
        y[2] = pm0 + pm1;
        y[3] = pm0 - pm1;
        y[4] = mp0 + mp1;
        y[5] = mp0 - mp1;
        y[6] = mm0 + mm1;
        y[7] = mm0 - mm1;
    }
}
/*
 * 4-point Walsh-Hadamard transform of the four lanes of `a`:
 *   [a0, a1, a2, a3] -> [a0 + a1 + a2 + a3, a0 - a1 + a2 - a3,
 *                        a0 + a1 - a2 - a3, a0 - a1 - a2 + a3]
 *
 * aamm: sign mask, lanes low to high [+1, +1, -1, -1].
 * amam: sign mask, lanes low to high [+1, -1, +1, -1].
 * Uses FMA instructions, so the target must support FMA.
 */
static inline __m256d wht4x4(const __m256d aamm, const __m256d amam, __m256d a)
{
    // [a1, a0, a3, a2]: imm 0x5 (bits 0101) swaps neighbours in each half.
    const __m256d a1032 = _mm256_permute_pd(a, 0x5);
    // [a0 + a1, a0 - a1, a2 + a3, a2 - a3]
    const __m256d a01012323 = _mm256_fmadd_pd(a, amam, a1032);
    // [a2 + a3, a2 - a3, a0 + a1, a0 - a1]: imm 0x01 swaps the 128-bit halves.
    const __m256d a23230101 = _mm256_permute2f128_pd(a01012323, a01012323, 0x01);
    // [a0 + a1 + a2 + a3, a0 - a1 + a2 - a3, a0 + a1 - a2 - a3, a0 - a1 - a2 + a3]
    return _mm256_fmadd_pd(a01012323, aamm, a23230101);
}
/*
 * 8-point Walsh-Hadamard transform of MR rows, built from two 4-point
 * transforms. wht4x4 is applied to each 4-lane half of a row, then one final
 * butterfly gives out[0..3] = lo + hi and out[4..7] = lo - hi.
 *
 * A: input, MR rows of NR (== 8) doubles stored back to back.
 * C: output, same layout.
 * _mm256_load_pd/_mm256_store_pd require 32-byte-aligned addresses, so A and
 * C must be 32-byte aligned.
 */
static void wht_composed_vectors(double* A, double* C)
{
    // Sign masks. Lanes, low to high: [+1, +1, -1, -1] and [+1, -1, +1, -1].
    const __m256d signs_ppmm = _mm256_set_pd(-1.0, -1.0, 1.0, 1.0);
    const __m256d signs_pmpm = _mm256_set_pd(-1.0, 1.0, -1.0, 1.0);

    for (int row = 0; row < MR; row++) {
        const double *src = A + row * NR;
        double *dst = C + row * NR;

        const __m256d lo = wht4x4(signs_ppmm, signs_pmpm, _mm256_load_pd(src));
        const __m256d hi = wht4x4(signs_ppmm, signs_pmpm, _mm256_load_pd(src + 4));

        _mm256_store_pd(dst, _mm256_add_pd(lo, hi));
        _mm256_store_pd(dst + 4, _mm256_sub_pd(lo, hi));
    }
}
/*
 * Scalar 8-point Walsh-Hadamard transform in its direct (unfactorised) form:
 * each output element is a signed sum of all eight inputs of its row.
 * A is the input, C the output; both hold MR rows of NR (== 8) doubles.
 */
static void wht_composed_novectors(double* A, double* C)
{
    double *a = A;
    double *c = C;
    for (int i = 0; i < MR; ++i) {
        // Read the whole row before writing any output. Otherwise the store
        // to c[0] would clobber a[0] when A == C (the other kernels in this
        // file all read a full row before storing), and the compiler would
        // have to reload a[] after every store because c may alias it.
        const double a0 = a[0], a1 = a[1], a2 = a[2], a3 = a[3];
        const double a4 = a[4], a5 = a[5], a6 = a[6], a7 = a[7];
        // Same left-to-right evaluation order as before, so results are
        // bit-identical for non-overlapping A and C.
        c[0] = a0 + a1 + a2 + a3 + a4 + a5 + a6 + a7;
        c[1] = a0 - a1 + a2 - a3 + a4 - a5 + a6 - a7;
        c[2] = a0 + a1 - a2 - a3 + a4 + a5 - a6 - a7;
        c[3] = a0 - a1 - a2 + a3 + a4 - a5 - a6 + a7;
        c[4] = a0 + a1 + a2 + a3 - a4 - a5 - a6 - a7;
        c[5] = a0 - a1 + a2 - a3 - a4 + a5 - a6 + a7;
        c[6] = a0 + a1 - a2 - a3 - a4 - a5 + a6 + a7;
        c[7] = a0 - a1 - a2 + a3 - a4 + a5 + a6 - a7;
        a += NR;
        c += NR;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment