Skip to content

Instantly share code, notes, and snippets.

@ChillFish8
Created July 2, 2024 17:07
Show Gist options
  • Save ChillFish8/80f6959dc8ea32680c2192f48d551a75 to your computer and use it in GitHub Desktop.
Save ChillFish8/80f6959dc8ea32680c2192f48d551a75 to your computer and use it in GitHub Desktop.
A simple-ish AVX2 implementation of a dot product for two vectors.
use core::arch::x86_64::*;
/// Computes the dot product of `vector_a` and `vector_b` using 256-bit wide
/// SIMD multiplies and adds.
///
/// The inputs are walked in blocks of eight `f32`s; each block is multiplied
/// lane-wise and accumulated into a single 256-bit register. That register is
/// then reduced to a scalar and any trailing elements (fewer than eight) are
/// folded in one at a time.
///
/// # Panics
///
/// Panics if the two slices differ in length.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn avx2_dot_product(vector_a: &[f32], vector_b: &[f32]) -> f32 {
    assert_eq!(vector_a.len(), vector_b.len(), "Vectors must be equal in length");

    // Number of `f32` lanes held by one 256-bit register.
    const LANES: usize = 8;

    // Split both inputs into full 8-element blocks plus a short tail.
    let blocks_a = vector_a.chunks_exact(LANES);
    let blocks_b = vector_b.chunks_exact(LANES);
    let tail_a = blocks_a.remainder();
    let tail_b = blocks_b.remainder();

    let mut lane_sums = _mm256_setzero_ps();
    for (block_a, block_b) in blocks_a.zip(blocks_b) {
        // `_mm256_loadu_ps` has no alignment requirement; per the original
        // author's note it performs the same as `_mm256_load_ps` on CPUs from
        // roughly the last decade. Every block is exactly `LANES` long, so
        // each load stays inside its slice.
        let lhs = _mm256_loadu_ps(block_a.as_ptr());
        let rhs = _mm256_loadu_ps(block_b.as_ptr());

        // With the `fma` target feature also enabled, this multiply + add
        // pair could instead be `_mm256_fmadd_ps(lhs, rhs, lane_sums)`.
        lane_sums = _mm256_add_ps(lane_sums, _mm256_mul_ps(lhs, rhs));
    }

    // Collapse the eight partial sums into one scalar, then fold in the
    // leftover elements that did not fill a whole register. A 4-wide SSE step
    // could cover part of the tail, but the original author measured no
    // useful gain over this plain scalar loop.
    let mut total = avx2_sum_register(lane_sums);
    for (&x, &y) in tail_a.iter().zip(tail_b) {
        total += x * y;
    }
    total
}
/// Performs a horizontal sum of a single 256-bit register, returning the
/// total of its eight `f32` lanes.
///
/// The register is folded in half three times: the upper 128 bits are added
/// onto the lower 128 bits, then the upper two of the remaining four lanes
/// are added onto the lower two, and finally lane 1 is added onto lane 0.
///
/// # Safety
///
/// The CPU executing this function must support AVX. Callers that have
/// already verified AVX2 support satisfy this requirement.
// The attribute matters: without it this function is compiled for the
// baseline target, and the AVX intrinsics below cannot be inlined into its
// body.
#[target_feature(enable = "avx")]
unsafe fn avx2_sum_register(reg: __m256) -> f32 {
    // 8 lanes -> 4 lanes: upper 128-bit half plus lower 128-bit half.
    let upper_quad = _mm256_extractf128_ps::<1>(reg);
    let lower_quad = _mm256_castps256_ps128(reg);
    let sum_quad = _mm_add_ps(upper_quad, lower_quad);

    // 4 lanes -> 2 lanes: move lanes 2 and 3 down onto lanes 0 and 1 and add.
    let upper_dual = _mm_movehl_ps(sum_quad, sum_quad);
    let sum_dual = _mm_add_ps(sum_quad, upper_dual);

    // 2 lanes -> 1 lane: bring lane 1 into lane 0 and add only the low lane.
    let lane_one = _mm_shuffle_ps::<0x1>(sum_dual, sum_dual);
    let sum = _mm_add_ss(sum_dual, lane_one);

    _mm_cvtss_f32(sum)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment