Created
July 2, 2024 17:07
-
-
Save ChillFish8/80f6959dc8ea32680c2192f48d551a75 to your computer and use it in GitHub Desktop.
A simple-ish AVX2 implementation of a dot product for two vectors.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use core::arch::x86_64::*; | |
#[target_feature(enable = "avx2")] | |
unsafe fn avx2_dot_product(vector_a: &[f32], vector_b: &[f32]) -> f32 { | |
assert_eq!(vector_a.len(), vector_b.len(), "Vectors must be equal in length"); | |
let dims = vector_a.len(); | |
// We want to ensure we're going in steps of 8 elements | |
// and then handle the remainder separately. | |
let remainder = dims % 8; | |
let vector_a_ptr = vector_a.as_ptr(); | |
let vector_b_ptr = vector_b.as_ptr(); | |
let mut acc = _mm256_setzero_ps(); | |
let mut i = 0; | |
while i < (dims - remainder) { | |
// We use `_mm256_loadu_ps` here to load memory without an alignment requirement, | |
// in reality, this instruction has identical performance to `_mm256_load_ps` | |
// on modern CPUs (and by modern I mean most CPUs within the last 10+ years) | |
let a = _mm256_loadu_ps(vector_a_ptr.add(i)); | |
let b = _mm256_loadu_ps(vector_b_ptr.add(i)); | |
let res = _mm256_mul_ps(a, b); | |
acc = _mm256_add_ps(acc, res); | |
// BONUS: If we have the `fma` CPU flag enabled you can swap the above lines to: | |
// acc = _mm256_fmadd_ps(a, b, acc); | |
i += 8; | |
} | |
// Handle the remainder of the data that doesn't fit into a AVX2 register. | |
// | |
// Now If we wanted to go crazy we could use SSE to handle another 4 elements | |
// in a single set of instructions, but I have never been able to measure an actual | |
// useful difference in performance compared to just rolling up the last `n` elements | |
// in a simple loop. | |
let mut total = avx2_sum_register(acc); | |
while i < dims { | |
let a = *vector_a.get_unchecked(i); | |
let b = *vector_b.get_unchecked(i); | |
let res = a * b; | |
total += res; | |
i += 1; | |
} | |
total | |
} | |
/// Performs a horizontal sum of all eight `f32` lanes of a single 256-bit register.
///
/// The register is split into its two 128-bit halves, which are added together.
/// The resulting four lanes are then folded pairwise twice until lane 0 holds the
/// total. With lanes `r0..r7`, the result is computed as
/// `((r0 + r4) + (r2 + r6)) + ((r1 + r5) + (r3 + r7))`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx` target feature
/// (every AVX2-capable CPU does). The attribute below lets these intrinsics be
/// inlined even when this function is not itself inlined into an AVX caller.
#[inline]
#[target_feature(enable = "avx")]
unsafe fn avx2_sum_register(reg: __m256) -> f32 {
    // 8 lanes -> 4: add the upper 128 bits onto the lower 128 bits.
    let high = _mm256_extractf128_ps::<1>(reg);
    let low = _mm256_castps256_ps128(reg);
    let sum_quad = _mm_add_ps(high, low);

    // 4 lanes -> 2: `movehl` moves lanes [2, 3] down into lanes [0, 1].
    let sum_dual = _mm_add_ps(sum_quad, _mm_movehl_ps(sum_quad, sum_quad));

    // 2 lanes -> 1: the shuffle brings lane 1 into lane 0, and `add_ss` adds only lane 0.
    let sum = _mm_add_ss(sum_dual, _mm_shuffle_ps::<0x1>(sum_dual, sum_dual));
    _mm_cvtss_f32(sum)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment