Skip to content

Instantly share code, notes, and snippets.

@trimental
Created March 20, 2023 04:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save trimental/ef99199539f650775bc1d975387b657b to your computer and use it in GitHub Desktop.
Save trimental/ef99199539f650775bc1d975387b657b to your computer and use it in GitHub Desktop.
portable simd example
use std::simd::*;
use half::f16;
/// Eight-lane vector of `i32`; alias for `std::simd::i32x8`.
pub type I32x8 = i32x8;
/// Eight-lane vector of `f32`; alias for `std::simd::f32x8`.
pub type F32x8 = f32x8;
/// Eight-lane vector of `i16`; alias for `std::simd::i16x8`.
pub type I16x8 = i16x8;
/* ------------------ */
/* Loading and storing things */
/* ------------------ */
#[inline]
pub fn load_i16x8(ptr: *const I16x8) -> I16x8 {
let slice: &[i16] = unsafe {
std::slice::from_raw_parts(ptr as *const i16, 8)
};
i16x8::from_slice(slice)
//
}
/// Writes the vector `a` to the memory location `ptr`.
///
/// Although this function is not marked `unsafe`, it writes through `ptr`:
/// the caller must make sure `ptr` is non-null, valid for writes, and
/// aligned for `I16x8`.
#[inline]
pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) {
    // SAFETY: relies on the caller contract documented above.
    unsafe { ptr.write(a) }
}
#[inline]
pub fn load_f32x8(ptr: *const F32x8) -> F32x8 {
let slice: &[f32] = unsafe {
std::slice::from_raw_parts(ptr as *const f32, 8)
};
f32x8::from_slice(slice)
//
}
/// Writes the vector `a` to the memory location `ptr`.
///
/// Although this function is not marked `unsafe`, it writes through `ptr`:
/// the caller must make sure `ptr` is non-null, valid for writes, and
/// aligned for `F32x8`.
#[inline]
pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) {
    // SAFETY: relies on the caller contract documented above.
    unsafe { ptr.write(a) }
}
/// Gathers eight `f32` values from the memory at `ptr`, one per lane, using
/// the per-lane offsets in `indices`.
///
/// Only the first eight elements behind `ptr` are addressable: any lane
/// whose index falls outside `0..8` yields `0.0` (the `f32` default) instead
/// of reading memory. Negative indices wrap to very large values when cast
/// to `usize` and are therefore treated as out of bounds as well.
///
/// NOTE(review): the fixed 8-element window looks like a placeholder; if
/// callers need to gather from a larger table the length has to be supplied
/// from outside — confirm against the call sites.
///
/// Although this function is not marked `unsafe`, it dereferences `ptr`:
/// the caller must make sure `ptr` is non-null, aligned for `f32`, and
/// points to at least eight readable `f32` values.
#[inline]
pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 {
    // SAFETY: relies on the caller contract documented above.
    let window: &[f32] = unsafe { std::slice::from_raw_parts(ptr, 8) };
    Simd::gather_or_default(window, indices.cast::<usize>())
}
/* ------------------ */
/* Conversions */
/* ------------------ */
#[inline]
pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 {
a.cast::<f32>()
}
#[inline]
pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 {
a.cast::<i16>()
}
/*
* Constants, creating from constants
*/
pub fn f32x8_zero() -> F32x8 {
f32x8::default()
}
pub fn i16x8_zero() -> I16x8 {
i16x8::default()
}
pub fn f32x8_singleton(value: f32) -> F32x8 {
f32x8::splat(value)
}
pub fn i32x8_from_values(
val0: i32,
val1: i32,
val2: i32,
val3: i32,
val4: i32,
val5: i32,
val6: i32,
val7: i32,
) -> I32x8 {
i32x8::from_slice(&[val0, val1, val2, val3, val4, val5, val6, val7])
}
/*
* Operations
*/
// FMA
// a * b + c
/// Computes `a * b + c` lane by lane.
///
/// The multiply and the add are written as two separate operations, so the
/// product is rounded before the addition as far as this source is
/// concerned.
pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 {
    let product = a * b;
    product + c
}
// Horizontal sums
/// Adds up the eight lanes of `ymm` and returns the scalar total.
///
/// Lanes are accumulated one after another from lane 0 to lane 7, starting
/// from `0.0`, so the floating-point rounding order is fixed.
#[inline]
pub fn horizontal_sum_f32x8(ymm: F32x8) -> f32 {
    let mut total = 0.0_f32;
    for lane in ymm.to_array() {
        total += lane;
    }
    total
}
/// Adds up the eight lanes of `ymm` with `horizontal_sum_f32x8` and converts
/// the `f32` total to half precision.
#[inline]
pub fn horizontal_sum_and_f32_to_f16(ymm: F32x8) -> f16 {
    let total = horizontal_sum_f32x8(ymm);
    f16::from_f32(total)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment