Skip to content

Instantly share code, notes, and snippets.

@ronlobo
Last active February 14, 2020 17:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ronlobo/6b39a02f1adca0050e62561ab0f0077f to your computer and use it in GitHub Desktop.
Rust Fearless SIMD bundled for Codingame - https://github.com/raphlinus/fearless_simd
/// Scalar fallback backend: the SIMD traits implemented for plain `f32`,
/// with `u32` serving as the one-lane mask type (all-ones = true, 0 = false).
pub mod fallback {
    use std::ptr;
    use crate::traits::{SimdF32, SimdMask32};
    impl SimdF32 for f32 {
        type Raw = f32;
        type Mask = u32;
        // A scalar is a one-lane vector.
        #[inline]
        fn width(self) -> usize { 1 }
        #[inline]
        fn floor(self) -> f32 { f32::floor(self) }
        #[inline]
        fn ceil(self) -> f32 { f32::ceil(self) }
        // Round via floor(x + ~0.5); the epsilon nudge keeps values just
        // below an exact .5 boundary from being pushed over it.
        // NOTE(review): the SIMD backends use the hardware nearest-int mode
        // (ties to even), so results can differ at exact .5 inputs — confirm
        // the discrepancy is acceptable.
        #[inline]
        fn round(self) -> f32 { f32::floor(self + (0.5 - 0.25 * ::std::f32::EPSILON)) }
        #[inline]
        fn abs(self) -> f32 { f32::abs(self) }
        #[inline]
        fn min(self, b: f32) -> f32 { f32::min(self, b) }
        #[inline]
        fn max(self, b: f32) -> f32 { f32::max(self, b) }
        // No fast approximations on the scalar path: full precision only.
        #[inline]
        fn recip(self) -> f32 { f32::recip(self) }
        #[inline]
        fn rsqrt(self) -> f32 { f32::recip(f32::sqrt(self)) }
        #[inline]
        fn splat(self, x: f32) -> f32 { x }
        // Lane 0 holds 0.0; there are no further lanes.
        #[inline]
        fn steps(self) -> f32 { 0.0 }
        #[inline]
        unsafe fn from_raw(raw: f32) -> f32 { raw }
        // Caller guarantees `p` points to at least one readable f32.
        #[inline]
        unsafe fn load(p: *const f32) -> f32 { ptr::read(p) }
        // Caller guarantees `p` points to at least one writable f32.
        #[inline]
        unsafe fn store(self, p: *mut f32) { ptr::write(p, self); }
        #[inline]
        unsafe fn create() -> f32 { 0.0 }
        // Full-width bitmask result, mirroring SIMD comparison semantics.
        #[inline]
        fn eq(self, other: f32) -> u32 {
            if self == other { !0 } else { 0 }
        }
    }
    impl SimdMask32 for u32 {
        type Raw = u32;
        type F32 = f32;
        // Only the sign bit is consulted, matching the `blendv` convention
        // used by the SIMD backends.
        #[inline]
        fn select(self, a: f32, b: f32) -> f32 {
            if self & 0x80000000 != 0 { a } else { b }
        }
    }
}
/// Core trait definitions for the width-generic SIMD abstraction.
pub mod traits {
    use std::ops::{Add, Sub, Mul, Div, Neg, BitAnd, Deref};
    /// A SIMD vector of `f32` lanes (width 1 for the scalar fallback).
    ///
    /// Arithmetic is required in both vector⊕vector and vector⊕`f32` form;
    /// implementations splat the scalar operand across all lanes.
    pub trait SimdF32: Sized + Copy + Clone
        + Add<Self, Output=Self> + Add<f32, Output=Self>
        + Sub<Self, Output=Self> + Sub<f32, Output=Self>
        + Mul<Self, Output=Self> + Mul<f32, Output=Self>
        // Fixed: the second bound on this line was a duplicated
        // `Mul<f32, Output=Self>`, which left `Self / f32` unusable in
        // generic code even though every backend implements `Div<f32>`.
        + Div<Self, Output=Self> + Div<f32, Output=Self>
        + Neg<Output=Self>
        // `f32 op Self` bounds would also be desirable but are kept
        // commented out (they currently can't be expressed usefully):
        /*
        where f32: Add<Self, Output=Self>,
        f32: Sub<Self, Output=Self>,
        f32: Mul<Self, Output=Self>,
        */
    {
        /// The architecture-level register type (e.g. `__m256`).
        type Raw: From<Self>;
        /// Companion mask type produced by comparisons.
        type Mask: SimdMask32<F32=Self>;
        /// Number of lanes in this vector type.
        fn width(self) -> usize;
        fn floor(self) -> Self;
        fn ceil(self) -> Self;
        fn round(self) -> Self;
        fn abs(self) -> Self;
        fn min(self, other: Self) -> Self;
        fn max(self, other: Self) -> Self;
        // Reciprocal at graded accuracy (suffix = minimum bits of precision).
        // Each level defaults to the next more accurate one, so a backend
        // only overrides the levels it can accelerate.
        fn recip8(self) -> Self { self.recip11() }
        fn recip11(self) -> Self { self.recip14() }
        fn recip14(self) -> Self { self.recip16() }
        fn recip16(self) -> Self { self.recip22() }
        fn recip22(self) -> Self { self.recip() }
        /// Fully accurate reciprocal.
        fn recip(self) -> Self;
        // Reciprocal square root, graded exactly like `recip*`.
        fn rsqrt8(self) -> Self { self.rsqrt11() }
        fn rsqrt11(self) -> Self { self.rsqrt14() }
        fn rsqrt14(self) -> Self { self.rsqrt16() }
        fn rsqrt16(self) -> Self { self.rsqrt22() }
        fn rsqrt22(self) -> Self { self.rsqrt() }
        fn rsqrt(self) -> Self;
        /// Broadcasts `x` to every lane; `self` only supplies the width.
        fn splat(self, x: f32) -> Self;
        /// Lane i holds `i as f32` (0.0, 1.0, 2.0, …).
        fn steps(self) -> Self;
        /// # Safety
        /// `raw` must be a valid value of the backend's register type.
        unsafe fn from_raw(raw: Self::Raw) -> Self;
        /// # Safety
        /// `p` must be valid for reading `width()` consecutive `f32`s.
        unsafe fn load(p: *const f32) -> Self;
        /// Checked wrapper around `load`; panics if `slice` is too short.
        fn from_slice(self, slice: &[f32]) -> Self {
            unsafe {
                assert!(slice.len() >= self.width());
                Self::load(slice.as_ptr())
            }
        }
        /// # Safety
        /// `p` must be valid for writing `width()` consecutive `f32`s.
        unsafe fn store(self, p: *mut f32);
        /// Checked wrapper around `store`; panics if `slice` is too short.
        fn write_to_slice(self, slice: &mut [f32]) {
            unsafe {
                assert!(slice.len() >= self.width());
                self.store(slice.as_mut_ptr());
            }
        }
        /// Returns the zero vector.
        ///
        /// # Safety
        /// Constructing a value asserts the backend's instruction set is
        /// actually available on the running CPU.
        unsafe fn create() -> Self;
        /// Lanewise equality; each true lane is all-ones in the mask.
        fn eq(self, other: Self) -> Self::Mask;
    }
    /// Mask companion to a `SimdF32` type.
    pub trait SimdMask32: Sized + Copy + Clone
        + BitAnd<Self, Output=Self>
        where Self::Raw: From<Self>,
    {
        type Raw;
        type F32: SimdF32<Mask=Self>;
        /// Lanewise `mask ? a : b`.
        fn select(self, a: Self::F32, b: Self::F32) -> Self::F32;
    }
    /// A fixed 4-lane `f32` vector with array-style lane access via `Deref`.
    pub trait F32x4: Sized + Copy + Clone
        + Add<Self, Output=Self>
        + Mul + Mul<f32, Output=Self>
        + Deref<Target=[f32; 4]>
        where Self::Raw: From<Self>,
        /*
        [f32; 4]: From<Self>,
        */
    {
        type Raw;
        /// # Safety
        /// Constructing a value asserts the backend's instruction set is
        /// actually available on the running CPU.
        unsafe fn create() -> Self;
        /// # Safety
        /// `raw` must be a valid value of the backend's register type.
        unsafe fn from_raw(raw: Self::Raw) -> Self;
        /// Builds a vector from an array; `self` only supplies the backend.
        fn new(self, array: [f32; 4]) -> Self;
        /// Copies the lanes out into a plain array.
        fn as_vec(self) -> [f32; 4];
    }
}
/// Capability-passing combinators: computations written once, generic over
/// the SIMD width, instantiated per backend by the dispatch layer.
pub mod combinators {
    use crate::traits::{SimdF32, F32x4};
    /// A reusable vector→vector function, callable at any `SimdF32` width.
    pub trait SimdFnF32 {
        fn call<S: SimdF32>(&mut self, x: S) -> S;
    }
    /// A one-shot computation handed a `SimdF32` "capability" value whose
    /// concrete type selects the backend.
    pub trait ThunkF32 {
        fn call<S: SimdF32>(self, cap: S);
    }
    /// As `ThunkF32`, but for the 4-lane `F32x4` interface.
    pub trait ThunkF32x4 {
        fn call<S: F32x4>(self, cap: S);
    }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// AVX backend: 8 `f32` lanes per `__m256` register.
///
/// Every intrinsic call goes through a `#[target_feature(enable = "avx")]`
/// shim, so all entry points are `unsafe`: the caller must have verified at
/// runtime (e.g. `is_x86_feature_detected!("avx")`) that AVX is available.
pub mod avx {
    use std::mem;
    use std::ops::{Add, Sub, Mul, Div, Neg, BitAnd, Not};
    use crate::traits::{SimdF32, SimdMask32};
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    // 8-lane f32 vector.
    #[derive(Clone, Copy)]
    pub struct AvxF32(__m256);
    // 8-lane comparison mask; a "true" lane is all-ones (sign bit set).
    #[derive(Clone, Copy)]
    pub struct AvxMask32(__m256);
    // ---- Raw intrinsic shims ------------------------------------------
    // `#[target_feature]` lets the compiler emit AVX instructions in these
    // bodies even when the crate-wide target features don't include AVX.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_add_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_add_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_sub_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_sub_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_mul_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_mul_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_div_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_div_ps(a, b)
    }
    // Broadcast a scalar to all 8 lanes.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_set1_ps(a: f32) -> __m256 {
        _mm256_set1_ps(a)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_floor_ps(a: __m256) -> __m256 {
        _mm256_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_ceil_ps(a: __m256) -> __m256 {
        _mm256_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)
    }
    // Round to nearest integer (ties to even, per the Intel docs for this
    // rounding mode).
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_round_nearest_ps(a: __m256) -> __m256 {
        _mm256_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_min_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_min_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_max_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_max_ps(a, b)
    }
    // Hardware reciprocal approximation (~11 bits of precision).
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_rcp_ps(a: __m256) -> __m256 {
        _mm256_rcp_ps(a)
    }
    // Hardware reciprocal-sqrt approximation (~11 bits of precision).
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_rsqrt_ps(a: __m256) -> __m256 {
        _mm256_rsqrt_ps(a)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_sqrt_ps(a: __m256) -> __m256 {
        _mm256_sqrt_ps(a)
    }
    // Lanewise (!a) & b.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_andnot_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_andnot_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_and_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_and_ps(a, b)
    }
    // Lanes in memory order: lane 0 = a, ..., lane 7 = h.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_setr_ps(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> __m256 {
        _mm256_setr_ps(a, b, c, d, e, f, g, h)
    }
    // Unaligned load/store of 8 f32s.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_loadu_ps(p: *const f32) -> __m256 {
        _mm256_loadu_ps(p)
    }
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_storeu_ps(p: *mut f32, a: __m256) {
        _mm256_storeu_ps(p, a);
    }
    // _CMP_EQ_UQ = equal, Unordered, Quiet: a lane is also "true" when
    // either operand is NaN.
    // NOTE(review): the scalar fallback's `==` yields false for NaN, so the
    // backends disagree on NaN inputs — confirm which is intended.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_cmpeq_ps(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_EQ_UQ)
    }
    // Only used (with -1) to build an all-ones register for `Not`.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_set1_epi32(a: i32) -> __m256i {
        _mm256_set1_epi32(a)
    }
    // Per lane: sign bit of c set -> take b, else a.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn avx_blendv_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
        _mm256_blendv_ps(a, b, c)
    }
    // ---- Operator overloads -------------------------------------------
    // Three forms per operator: vector op vector, vector op scalar, and
    // scalar op vector; scalars are broadcast to all lanes.
    impl Add for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn add(self, other: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32(avx_add_ps(self.0, other.0))
            }
        }
    }
    impl Add<f32> for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn add(self, other: f32) -> AvxF32 {
            unsafe {
                AvxF32(avx_add_ps(self.0, avx_set1_ps(other)))
            }
        }
    }
    impl Add<AvxF32> for f32 {
        type Output = AvxF32;
        #[inline]
        fn add(self, other: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32(avx_add_ps(avx_set1_ps(self), other.0))
            }
        }
    }
    impl Sub for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn sub(self, other: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32(avx_sub_ps(self.0, other.0))
            }
        }
    }
    impl Sub<f32> for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn sub(self, other: f32) -> AvxF32 {
            unsafe {
                AvxF32(avx_sub_ps(self.0, avx_set1_ps(other)))
            }
        }
    }
    impl Sub<AvxF32> for f32 {
        type Output = AvxF32;
        #[inline]
        fn sub(self, other: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32(avx_sub_ps(avx_set1_ps(self), other.0))
            }
        }
    }
    impl Mul for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn mul(self, other: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32(avx_mul_ps(self.0, other.0))
            }
        }
    }
    impl Mul<f32> for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn mul(self, other: f32) -> AvxF32 {
            unsafe {
                AvxF32(avx_mul_ps(self.0, avx_set1_ps(other)))
            }
        }
    }
    impl Mul<AvxF32> for f32 {
        type Output = AvxF32;
        #[inline]
        fn mul(self, other: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32(avx_mul_ps(avx_set1_ps(self), other.0))
            }
        }
    }
    impl Div for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn div(self, other: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32(avx_div_ps(self.0, other.0))
            }
        }
    }
    impl Div<f32> for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn div(self, other: f32) -> AvxF32 {
            unsafe {
                AvxF32(avx_div_ps(self.0, avx_set1_ps(other)))
            }
        }
    }
    impl Div<AvxF32> for f32 {
        type Output = AvxF32;
        #[inline]
        fn div(self, other: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32(avx_div_ps(avx_set1_ps(self), other.0))
            }
        }
    }
    // Negation as 0.0 - x.
    impl Neg for AvxF32 {
        type Output = AvxF32;
        #[inline]
        fn neg(self) -> AvxF32 {
            unsafe {
                AvxF32(avx_sub_ps(avx_set1_ps(0.0), self.0))
            }
        }
    }
    impl From<AvxF32> for __m256 {
        #[inline]
        fn from(x: AvxF32) -> __m256 {
            x.0
        }
    }
    impl SimdF32 for AvxF32 {
        type Raw = __m256;
        type Mask = AvxMask32;
        #[inline]
        fn width(self) -> usize { 8 }
        #[inline]
        fn floor(self: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_floor_ps(self.0)) }
        }
        #[inline]
        fn ceil(self: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_ceil_ps(self.0)) }
        }
        #[inline]
        fn round(self: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_round_nearest_ps(self.0)) }
        }
        // abs = clear the sign bit: andnot(-0.0, x) = !sign_mask & x.
        #[inline]
        fn abs(self: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_andnot_ps(avx_set1_ps(-0.0), self.0)) }
        }
        #[inline]
        fn min(self: AvxF32, b: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_min_ps(self.0, b.0)) }
        }
        #[inline]
        fn max(self: AvxF32, b: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_max_ps(self.0, b.0)) }
        }
        // ~11-bit reciprocal straight from the hardware approximation.
        #[inline]
        fn recip11(self: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_rcp_ps(self.0)) }
        }
        // One Newton-Raphson refinement of the hardware estimate:
        // est' = 2*est - x*est^2, roughly doubling the bits of precision.
        #[inline]
        fn recip22(self: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32({
                    let est = avx_rcp_ps(self.0);
                    let muls = avx_mul_ps(self.0, avx_mul_ps(est, est));
                    avx_sub_ps(avx_add_ps(est, est), muls)
                })
            }
        }
        // Exact reciprocal via full division.
        #[inline]
        fn recip(self: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_div_ps(avx_set1_ps(1.0), self.0)) }
        }
        #[inline]
        fn rsqrt11(self: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_rsqrt_ps(self.0)) }
        }
        // One Newton-Raphson refinement: est' = 0.5*est*(3 - x*est^2).
        #[inline]
        fn rsqrt22(self: AvxF32) -> AvxF32 {
            unsafe {
                AvxF32({
                    let est = avx_rsqrt_ps(self.0);
                    let r_est = avx_mul_ps(self.0, est);
                    let half_est = avx_mul_ps(avx_set1_ps(0.5), est);
                    let muls = avx_mul_ps(r_est, est);
                    let three_minus_muls = avx_sub_ps(avx_set1_ps(3.0), muls);
                    avx_mul_ps(half_est, three_minus_muls)
                })
            }
        }
        // Exact 1/sqrt(x) via full sqrt and division.
        #[inline]
        fn rsqrt(self: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_div_ps(avx_set1_ps(1.0), avx_sqrt_ps(self.0))) }
        }
        #[inline]
        fn splat(self, x: f32) -> AvxF32 {
            unsafe { AvxF32(avx_set1_ps(x)) }
        }
        // Lane i holds i as f32.
        #[inline]
        fn steps(self) -> AvxF32 {
            unsafe { AvxF32(avx_setr_ps(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)) }
        }
        #[inline]
        unsafe fn from_raw(raw: __m256) -> AvxF32 {
            AvxF32(raw)
        }
        #[inline]
        unsafe fn load(p: *const f32) -> AvxF32 {
            AvxF32(avx_loadu_ps(p))
        }
        #[inline]
        unsafe fn store(self, p: *mut f32) {
            avx_storeu_ps(p, self.0);
        }
        #[inline]
        unsafe fn create() -> AvxF32 {
            AvxF32(avx_set1_ps(0.0))
        }
        #[inline]
        fn eq(self, other: AvxF32) -> AvxMask32 {
            unsafe { AvxMask32(avx_cmpeq_ps(self.0, other.0)) }
        }
    }
    impl From<AvxMask32> for __m256 {
        #[inline]
        fn from(x: AvxMask32) -> __m256 {
            x.0
        }
    }
    impl BitAnd for AvxMask32 {
        type Output = AvxMask32;
        #[inline]
        fn bitand(self, other: AvxMask32) -> AvxMask32 {
            unsafe { AvxMask32(avx_and_ps(self.0, other.0)) }
        }
    }
    impl Not for AvxMask32 {
        type Output = AvxMask32;
        // andnot(x, all-ones) computes !x; set1_epi32(-1) is the all-ones
        // register, reinterpreted as packed floats.
        #[inline]
        fn not(self) -> AvxMask32 {
            unsafe { AvxMask32(avx_andnot_ps(self.0, mem::transmute(avx_set1_epi32(-1)))) }
        }
    }
    impl SimdMask32 for AvxMask32 {
        type Raw = __m256;
        type F32 = AvxF32;
        // blendv takes from its second operand (`a` here) where the mask's
        // sign bit is set — matching the convention that "true" selects `a`.
        #[inline]
        fn select(self, a: AvxF32, b: AvxF32) -> AvxF32 {
            unsafe { AvxF32(avx_blendv_ps(b.0, a.0, self.0)) }
        }
    }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// SSE4.2 backend: 4 `f32` lanes per `__m128` register, plus the `F32x4`
/// array-style vector type.
///
/// Every intrinsic call goes through a `#[target_feature(enable = "sse4.2")]`
/// shim; callers must have verified SSE4.2 support at runtime.
pub mod sse42 {
    use std::mem;
    use std::ops::{Add, Sub, Mul, Div, Neg, BitAnd, Not, Deref};
    use crate::traits::{SimdF32, SimdMask32, F32x4};
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    // 4-lane f32 vector.
    #[derive(Clone, Copy)]
    pub struct Sse42F32(__m128);
    // 4-lane comparison mask; a "true" lane is all-ones (sign bit set).
    #[derive(Clone, Copy)]
    pub struct Sse42Mask32(__m128);
    // 4-lane vector with array-style access (the `F32x4` interface).
    #[derive(Clone, Copy)]
    pub struct Sse42F32x4(__m128);
    // ---- Raw intrinsic shims ------------------------------------------
    // `#[target_feature]` lets the compiler emit SSE4.2 instructions here
    // even when the crate-wide target features don't include it.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_add_ps(a: __m128, b: __m128) -> __m128 {
        _mm_add_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_sub_ps(a: __m128, b: __m128) -> __m128 {
        _mm_sub_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_mul_ps(a: __m128, b: __m128) -> __m128 {
        _mm_mul_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_div_ps(a: __m128, b: __m128) -> __m128 {
        _mm_div_ps(a, b)
    }
    // Broadcast a scalar to all 4 lanes.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_set1_ps(a: f32) -> __m128 {
        _mm_set1_ps(a)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_floor_ps(a: __m128) -> __m128 {
        _mm_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_ceil_ps(a: __m128) -> __m128 {
        _mm_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)
    }
    // Round to nearest integer (ties to even, per the Intel docs for this
    // rounding mode).
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_round_nearest_ps(a: __m128) -> __m128 {
        _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_min_ps(a: __m128, b: __m128) -> __m128 {
        _mm_min_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_max_ps(a: __m128, b: __m128) -> __m128 {
        _mm_max_ps(a, b)
    }
    // Hardware reciprocal approximation (~11 bits of precision).
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_rcp_ps(a: __m128) -> __m128 {
        _mm_rcp_ps(a)
    }
    // Hardware reciprocal-sqrt approximation (~11 bits of precision).
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_rsqrt_ps(a: __m128) -> __m128 {
        _mm_rsqrt_ps(a)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_sqrt_ps(a: __m128) -> __m128 {
        _mm_sqrt_ps(a)
    }
    // Lanewise (!a) & b.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_andnot_ps(a: __m128, b: __m128) -> __m128 {
        _mm_andnot_ps(a, b)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_and_ps(a: __m128, b: __m128) -> __m128 {
        _mm_and_ps(a, b)
    }
    // Lanes in memory order: lane 0 = a, ..., lane 3 = d.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_setr_ps(a: f32, b: f32, c: f32, d: f32) -> __m128 {
        _mm_setr_ps(a, b, c, d)
    }
    // Unaligned load/store of 4 f32s.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_loadu_ps(p: *const f32) -> __m128 {
        _mm_loadu_ps(p)
    }
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_storeu_ps(p: *mut f32, a: __m128) {
        _mm_storeu_ps(p, a);
    }
    // _CMP_EQ_UQ = equal, Unordered, Quiet: a lane is also "true" when
    // either operand is NaN.
    // NOTE(review): `_mm_cmp_ps` is the AVX-encoded (vcmpps) comparison; on
    // a CPU with SSE4.2 but without AVX this can fault. Consider
    // `_mm_cmpeq_ps` if AVX can't be assumed — but note it is *ordered*
    // (NaN compares false), which would change NaN behavior.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_cmpeq_ps(a: __m128, b: __m128) -> __m128 {
        _mm_cmp_ps(a, b, _CMP_EQ_UQ)
    }
    // Only used (with -1) to build an all-ones register for `Not`.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_set1_epi32(a: i32) -> __m128i {
        _mm_set1_epi32(a)
    }
    // Per lane: sign bit of c set -> take b, else a.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn sse42_blendv_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
        _mm_blendv_ps(a, b, c)
    }
    // ---- Operator overloads -------------------------------------------
    // Three forms per operator: vector op vector, vector op scalar, and
    // scalar op vector; scalars are broadcast to all lanes.
    impl Add for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn add(self, other: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_add_ps(self.0, other.0))
            }
        }
    }
    impl Add<f32> for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn add(self, other: f32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_add_ps(self.0, sse42_set1_ps(other)))
            }
        }
    }
    impl Add<Sse42F32> for f32 {
        type Output = Sse42F32;
        #[inline]
        fn add(self, other: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_add_ps(sse42_set1_ps(self), other.0))
            }
        }
    }
    impl Sub for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn sub(self, other: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_sub_ps(self.0, other.0))
            }
        }
    }
    impl Sub<f32> for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn sub(self, other: f32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_sub_ps(self.0, sse42_set1_ps(other)))
            }
        }
    }
    impl Sub<Sse42F32> for f32 {
        type Output = Sse42F32;
        #[inline]
        fn sub(self, other: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_sub_ps(sse42_set1_ps(self), other.0))
            }
        }
    }
    impl Mul for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn mul(self, other: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_mul_ps(self.0, other.0))
            }
        }
    }
    impl Mul<f32> for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn mul(self, other: f32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_mul_ps(self.0, sse42_set1_ps(other)))
            }
        }
    }
    impl Mul<Sse42F32> for f32 {
        type Output = Sse42F32;
        #[inline]
        fn mul(self, other: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_mul_ps(sse42_set1_ps(self), other.0))
            }
        }
    }
    impl Div for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn div(self, other: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_div_ps(self.0, other.0))
            }
        }
    }
    impl Div<f32> for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn div(self, other: f32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_div_ps(self.0, sse42_set1_ps(other)))
            }
        }
    }
    impl Div<Sse42F32> for f32 {
        type Output = Sse42F32;
        #[inline]
        fn div(self, other: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_div_ps(sse42_set1_ps(self), other.0))
            }
        }
    }
    // Negation as 0.0 - x.
    impl Neg for Sse42F32 {
        type Output = Sse42F32;
        #[inline]
        fn neg(self) -> Sse42F32 {
            unsafe {
                Sse42F32(sse42_sub_ps(sse42_set1_ps(0.0), self.0))
            }
        }
    }
    impl From<Sse42F32> for __m128 {
        #[inline]
        fn from(x: Sse42F32) -> __m128 {
            x.0
        }
    }
    impl SimdF32 for Sse42F32 {
        type Raw = __m128;
        type Mask = Sse42Mask32;
        #[inline]
        fn width(self) -> usize { 4 }
        #[inline]
        fn floor(self: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_floor_ps(self.0)) }
        }
        #[inline]
        fn ceil(self: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_ceil_ps(self.0)) }
        }
        #[inline]
        fn round(self: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_round_nearest_ps(self.0)) }
        }
        // abs = clear the sign bit: andnot(-0.0, x) = !sign_mask & x.
        #[inline]
        fn abs(self: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_andnot_ps(sse42_set1_ps(-0.0), self.0)) }
        }
        #[inline]
        fn min(self: Sse42F32, b: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_min_ps(self.0, b.0)) }
        }
        #[inline]
        fn max(self: Sse42F32, b: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_max_ps(self.0, b.0)) }
        }
        // ~11-bit reciprocal straight from the hardware approximation.
        #[inline]
        fn recip11(self: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_rcp_ps(self.0)) }
        }
        // One Newton-Raphson refinement of the hardware estimate:
        // est' = 2*est - x*est^2, roughly doubling the bits of precision.
        #[inline]
        fn recip22(self: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32({
                    let est = sse42_rcp_ps(self.0);
                    let muls = sse42_mul_ps(self.0, sse42_mul_ps(est, est));
                    sse42_sub_ps(sse42_add_ps(est, est), muls)
                })
            }
        }
        // Exact reciprocal via full division.
        #[inline]
        fn recip(self: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_div_ps(sse42_set1_ps(1.0), self.0)) }
        }
        #[inline]
        fn rsqrt11(self: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_rsqrt_ps(self.0)) }
        }
        // One Newton-Raphson refinement: est' = 0.5*est*(3 - x*est^2).
        #[inline]
        fn rsqrt22(self: Sse42F32) -> Sse42F32 {
            unsafe {
                Sse42F32({
                    let est = sse42_rsqrt_ps(self.0);
                    let r_est = sse42_mul_ps(self.0, est);
                    let half_est = sse42_mul_ps(sse42_set1_ps(0.5), est);
                    let muls = sse42_mul_ps(r_est, est);
                    let three_minus_muls = sse42_sub_ps(sse42_set1_ps(3.0), muls);
                    sse42_mul_ps(half_est, three_minus_muls)
                })
            }
        }
        // Exact 1/sqrt(x) via full sqrt and division.
        #[inline]
        fn rsqrt(self: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_div_ps(sse42_set1_ps(1.0), sse42_sqrt_ps(self.0))) }
        }
        #[inline]
        fn splat(self, x: f32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_set1_ps(x)) }
        }
        // Lane i holds i as f32.
        #[inline]
        fn steps(self) -> Sse42F32 {
            unsafe { Sse42F32(sse42_setr_ps(0.0, 1.0, 2.0, 3.0)) }
        }
        #[inline]
        unsafe fn from_raw(raw: __m128) -> Sse42F32 {
            Sse42F32(raw)
        }
        #[inline]
        unsafe fn load(p: *const f32) -> Sse42F32 {
            Sse42F32(sse42_loadu_ps(p))
        }
        #[inline]
        unsafe fn store(self, p: *mut f32) {
            sse42_storeu_ps(p, self.0);
        }
        #[inline]
        unsafe fn create() -> Sse42F32 {
            Sse42F32(sse42_set1_ps(0.0))
        }
        #[inline]
        fn eq(self, other: Sse42F32) -> Sse42Mask32 {
            unsafe { Sse42Mask32(sse42_cmpeq_ps(self.0, other.0)) }
        }
    }
    impl From<Sse42Mask32> for __m128 {
        #[inline]
        fn from(x: Sse42Mask32) -> __m128 {
            x.0
        }
    }
    impl BitAnd for Sse42Mask32 {
        type Output = Sse42Mask32;
        #[inline]
        fn bitand(self, other: Sse42Mask32) -> Sse42Mask32 {
            unsafe { Sse42Mask32(sse42_and_ps(self.0, other.0)) }
        }
    }
    impl Not for Sse42Mask32 {
        type Output = Sse42Mask32;
        // andnot(x, all-ones) computes !x; set1_epi32(-1) is the all-ones
        // register, reinterpreted as packed floats.
        #[inline]
        fn not(self) -> Sse42Mask32 {
            unsafe { Sse42Mask32(sse42_andnot_ps(self.0, mem::transmute(sse42_set1_epi32(-1)))) }
        }
    }
    impl SimdMask32 for Sse42Mask32 {
        type Raw = __m128;
        type F32 = Sse42F32;
        // blendv takes from its second operand (`a` here) where the mask's
        // sign bit is set — matching the convention that "true" selects `a`.
        #[inline]
        fn select(self, a: Sse42F32, b: Sse42F32) -> Sse42F32 {
            unsafe { Sse42F32(sse42_blendv_ps(b.0, a.0, self.0)) }
        }
    }
    impl From<Sse42F32x4> for __m128 {
        #[inline]
        fn from(x: Sse42F32x4) -> __m128 {
            x.0
        }
    }
    impl From<Sse42F32x4> for [f32; 4] {
        #[inline]
        fn from(x: Sse42F32x4) -> [f32; 4] {
            x.as_vec()
        }
    }
    impl Deref for Sse42F32x4 {
        type Target = [f32; 4];
        // SAFETY-relevant: Sse42F32x4 is a newtype over __m128, which has
        // the same size (16 bytes) as [f32; 4], so reinterpreting the
        // reference is sound.
        #[inline]
        fn deref(&self) -> &[f32; 4] {
            unsafe { mem::transmute(self) }
        }
    }
    impl Add for Sse42F32x4 {
        type Output = Sse42F32x4;
        #[inline]
        fn add(self, other: Sse42F32x4) -> Sse42F32x4 {
            unsafe { Sse42F32x4(sse42_add_ps(self.0, other.0)) }
        }
    }
    impl Mul for Sse42F32x4 {
        type Output = Sse42F32x4;
        #[inline]
        fn mul(self, other: Sse42F32x4) -> Sse42F32x4 {
            unsafe { Sse42F32x4(sse42_mul_ps(self.0, other.0)) }
        }
    }
    impl Mul<f32> for Sse42F32x4 {
        type Output = Sse42F32x4;
        #[inline]
        fn mul(self, other: f32) -> Sse42F32x4 {
            unsafe { Sse42F32x4(sse42_mul_ps(self.0, sse42_set1_ps(other))) }
        }
    }
    impl F32x4 for Sse42F32x4 {
        type Raw = __m128;
        #[inline]
        unsafe fn from_raw(raw: __m128) -> Sse42F32x4 {
            Sse42F32x4(raw)
        }
        #[inline]
        unsafe fn create() -> Sse42F32x4 {
            Sse42F32x4(sse42_set1_ps(0.0))
        }
        // Reinterpret [f32; 4] as __m128 through a union (same size; lanes
        // in memory order).
        #[inline]
        fn new(self, array: [f32; 4]) -> Sse42F32x4 {
            union U {
                array: [f32; 4],
                xmm: __m128,
            }
            unsafe { Sse42F32x4(U { array }.xmm) }
        }
        // Inverse of `new`: copy the lanes out as an array.
        #[inline]
        fn as_vec(self) -> [f32; 4] {
            union U {
                array: [f32; 4],
                xmm: __m128,
            }
            unsafe { U { xmm: self.0 }.array }
        }
    }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// Runtime dispatch layer: detects CPU features and routes width-generic
/// computations to the AVX, SSE4.2 or scalar backends.
pub mod x86 {
    use crate::avx::AvxF32;
    use crate::sse42::{Sse42F32, Sse42F32x4};
    use crate::combinators::{SimdFnF32, ThunkF32, ThunkF32x4};
    use crate::traits::{SimdF32, F32x4};
    /// A width-polymorphic stream of f32 vectors.
    ///
    /// Implementors provide one iterator type per backend; `collect`
    /// dispatches to whichever the running CPU supports.
    pub trait GeneratorF32: Sized {
        type IterF32: Iterator<Item=f32>;
        type IterSse42: Iterator<Item=Sse42F32>;
        type IterAvx: Iterator<Item=AvxF32>;
        /// Instantiates the scalar stream; `cap` is the capability token.
        fn gen_f32(self, cap: f32) -> Self::IterF32;
        fn gen_sse42(self, cap: Sse42F32) -> Self::IterSse42;
        fn gen_avx(self, cap: AvxF32) -> Self::IterAvx;
        /// Lazily applies `f` to every vector produced by this generator.
        #[inline]
        fn map<F>(self, f: F) -> F32Map<Self, F>
            where Self: Sized, F: SimdFnF32
        {
            F32Map { inner: self, f }
        }
        /// Fills `obuf` using the widest available backend.
        ///
        /// On the SIMD paths, `obuf.len()` must be a multiple of the vector
        /// width (8 for AVX, 4 for SSE4.2): a final short chunk trips the
        /// length assert inside `write_to_slice`.
        #[inline]
        fn collect(self, obuf: &mut [f32]) {
            if is_x86_feature_detected!("avx") {
                unsafe { collect_avx(self, obuf); }
            } else if is_x86_feature_detected!("sse4.2") {
                unsafe { collect_sse42(self, obuf); }
            } else {
                // Scalar fallback, one element per iteration.
                // (Was `(0..obuf.len()).step_by(1)`; `step_by(1)` is a no-op.)
                let mut iter = self.gen_f32(0.0);
                for i in 0..obuf.len() {
                    let x = iter.next().unwrap();
                    x.write_to_slice(&mut obuf[i..]);
                }
            }
        }
    }
    /// AVX drain loop: 8 lanes per iteration.
    ///
    /// # Safety
    /// Caller must have verified AVX support at runtime.
    #[target_feature(enable = "avx")]
    unsafe fn collect_avx<G: GeneratorF32>(gen: G, obuf: &mut [f32]) {
        let mut iter = gen.gen_avx(AvxF32::create());
        for i in (0..obuf.len()).step_by(8) {
            let x = iter.next().unwrap();
            x.write_to_slice(&mut obuf[i..]);
        }
    }
    /// SSE4.2 drain loop: 4 lanes per iteration.
    ///
    /// # Safety
    /// Caller must have verified SSE4.2 support at runtime.
    #[target_feature(enable = "sse4.2")]
    unsafe fn collect_sse42<G: GeneratorF32>(gen: G, obuf: &mut [f32]) {
        let mut iter = gen.gen_sse42(Sse42F32::create());
        for i in (0..obuf.len()).step_by(4) {
            let x = iter.next().unwrap();
            x.write_to_slice(&mut obuf[i..]);
        }
    }
    /// Generator adapter produced by `GeneratorF32::map`.
    pub struct F32Map<G: GeneratorF32, F: SimdFnF32> {
        inner: G,
        f: F,
    }
    /// Iterator adapter that applies `f` to each item of `inner`.
    pub struct F32MapIter<S: SimdF32, I: Iterator<Item=S>, F: SimdFnF32> {
        inner: I,
        f: F,
    }
    impl<S, I, F> Iterator for F32MapIter<S, I, F>
        where S: SimdF32, I: Iterator<Item=S>, F: SimdFnF32
    {
        type Item = S;
        fn next(&mut self) -> Option<S> {
            self.inner.next().map(|x| self.f.call(x))
        }
    }
    impl<G: GeneratorF32, F: SimdFnF32> GeneratorF32 for F32Map<G, F> {
        type IterF32 = F32MapIter<f32, G::IterF32, F>;
        type IterSse42 = F32MapIter<Sse42F32, G::IterSse42, F>;
        type IterAvx = F32MapIter<AvxF32, G::IterAvx, F>;
        fn gen_f32(self, cap: f32) -> Self::IterF32 {
            F32MapIter { inner: self.inner.gen_f32(cap), f: self.f }
        }
        fn gen_sse42(self, cap: Sse42F32) -> Self::IterSse42 {
            F32MapIter { inner: self.inner.gen_sse42(cap), f: self.f }
        }
        fn gen_avx(self, cap: AvxF32) -> Self::IterAvx {
            F32MapIter { inner: self.inner.gen_avx(cap), f: self.f }
        }
    }
    /// Generator of an arithmetic progression: init, init+step, init+2·step, …
    pub struct CountGen {
        init: f32,
        step: f32,
    }
    /// Stream state for `CountGen`: `val` holds the next vector of values,
    /// `step` the per-vector advance (lane step × lane count).
    pub struct CountStream<S: SimdF32> {
        val: S,
        step: f32,
    }
    /// Creates a counting generator starting at `init`, advancing by `step`.
    #[inline]
    pub fn count(init: f32, step: f32) -> CountGen {
        CountGen { init, step }
    }
    impl CountGen {
        #[inline]
        fn gen<S: SimdF32>(self, cap: S) -> CountStream<S> {
            CountStream {
                // Lane i starts at init + i*step; each next() then advances
                // every lane by step*width.
                val: cap.steps() * self.step + self.init,
                step: self.step * (cap.width() as f32),
            }
        }
    }
    impl GeneratorF32 for CountGen {
        type IterF32 = CountStream<f32>;
        type IterSse42 = CountStream<Sse42F32>;
        type IterAvx = CountStream<AvxF32>;
        #[inline]
        fn gen_f32(self, cap: f32) -> CountStream<f32> {
            self.gen(cap)
        }
        #[inline]
        fn gen_sse42(self, cap: Sse42F32) -> CountStream<Sse42F32> {
            self.gen(cap)
        }
        #[inline]
        fn gen_avx(self, cap: AvxF32) -> CountStream<AvxF32> {
            self.gen(cap)
        }
    }
    impl<S: SimdF32> Iterator for CountStream<S> {
        type Item = S;
        // Infinite stream: never returns None; consumers take what they need.
        #[inline]
        fn next(&mut self) -> Option<S> {
            let val = self.val;
            self.val = self.val + self.step;
            Some(val)
        }
    }
    /// # Safety
    /// Caller must have verified AVX support at runtime.
    #[target_feature(enable = "avx")]
    unsafe fn run_f32_avx<S: ThunkF32>(thunk: S) {
        thunk.call(AvxF32::create());
    }
    /// # Safety
    /// Caller must have verified SSE4.2 support at runtime.
    #[target_feature(enable = "sse4.2")]
    unsafe fn run_f32_sse42<S: ThunkF32>(thunk: S) {
        thunk.call(Sse42F32::create());
    }
    /// Runs `thunk` once with the widest available `SimdF32` capability,
    /// falling back to the scalar `f32` backend.
    pub fn run_f32<S: ThunkF32>(thunk: S) {
        if is_x86_feature_detected!("avx") {
            unsafe { run_f32_avx(thunk); }
        } else if is_x86_feature_detected!("sse4.2") {
            unsafe { run_f32_sse42(thunk); }
        } else {
            thunk.call(0.0f32);
        }
    }
    /// # Safety
    /// Caller must have verified SSE4.2 support at runtime.
    #[target_feature(enable = "sse4.2")]
    unsafe fn run_f32x4_sse42<S: ThunkF32x4>(thunk: S) {
        thunk.call(Sse42F32x4::create());
    }
    /// Runs `thunk` with an SSE4.2 `F32x4` capability.
    ///
    /// NOTE(review): if SSE4.2 is not detected this silently does nothing —
    /// there is no scalar `F32x4` fallback. Confirm callers can tolerate
    /// that.
    pub fn run_f32x4<S: ThunkF32x4>(thunk: S) {
        if is_x86_feature_detected!("sse4.2") {
            unsafe { run_f32x4_sse42(thunk); }
        }
    }
}
pub use traits::{SimdF32, SimdMask32, F32x4};
pub use combinators::{SimdFnF32, ThunkF32, ThunkF32x4};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub use avx::{AvxF32, AvxMask32};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub use sse42::{Sse42F32, Sse42Mask32};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub use x86::{count, GeneratorF32, run_f32, run_f32x4};
/// Captured buffers and coefficients for the SIMD IIR-filter demo below.
struct Iir<'a> {
    // 32 input samples, consumed two at a time.
    ibuf: &'a [f32; 32],
    // Receives the filtered output; same length as `ibuf`.
    obuf: &'a mut [f32; 32],
    // Four coefficient rows: rows 0/1 scale the two current input samples,
    // rows 2/3 scale lanes 2 and 3 of the previous state vector.
    coefs: [[f32; 4]; 4],
}
impl<'a> ThunkF32x4 for Iir<'a> {
    /// Runs the 4-tap filter over `ibuf` two samples at a time, writing the
    /// first two lanes of each updated state vector into `obuf`.
    #[inline]
    fn call<S: F32x4>(self, cap: S) {
        // Lift each coefficient row into a SIMD register once, up front.
        let k0 = cap.new(self.coefs[0]);
        let k1 = cap.new(self.coefs[1]);
        let k2 = cap.new(self.coefs[2]);
        let k3 = cap.new(self.coefs[3]);
        let mut state = cap.new([0.0; 4]);
        for i in (0..self.ibuf.len()).step_by(2) {
            let (x0, x1) = (self.ibuf[i], self.ibuf[i + 1]);
            // New state mixes the two fresh samples with lanes 2 and 3 of
            // the previous state (the feedback terms).
            state = k0 * x0 + k1 * x1 + k2 * state[2] + k3 * state[3];
            let lanes = state.as_vec();
            self.obuf[i] = lanes[0];
            self.obuf[i + 1] = lanes[1];
        }
    }
}
/// Demo entry point: runs the IIR thunk over a zeroed input buffer with
/// identity coefficients and prints every output sample to stderr.
fn main() {
    // 4x4 identity matrix as the coefficient rows.
    let coefs = [
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ];
    let ibuf = [0.0; 32];
    let mut obuf = [0.0; 32];
    run_f32x4(Iir {
        ibuf: &ibuf,
        obuf: &mut obuf,
        coefs,
    });
    for sample in obuf.iter() {
        eprintln!("{}", sample);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment