// Source: GitHub gist raphlinus/d0666fcd1c454d13ab5a9aec1244b4b3 by @raphlinus,
// created October 10, 2018 22:45.
#![feature(test)]
extern crate test;
// Note: this dependency is nightly only
extern crate packed_simd;
#[macro_use]
extern crate cfg_if;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use std::str::FromStr;
use packed_simd::{f32x4, f32x8, IntoBits};
/// Lane-wise rounding for SIMD float vectors.
trait SimdRounding {
/// Returns `self` with each lane rounded to an integral value; how ties are
/// broken is up to the implementation.
fn round(self) -> Self;
}
cfg_if! {
    // `_mm_round_ps` is an SSE4.1 instruction. rustc names that feature
    // "sse4.1"; "sse4" is not a recognised feature name, so a predicate using
    // it can never be true and the intrinsic path would never be compiled.
    if #[cfg(all(target_arch = "x86_64", target_feature = "sse4.1"))] {
        impl SimdRounding for f32x4 {
            /// Rounds all four lanes with a single SSE4.1 instruction.
            ///
            /// The immediate `8` is `_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC`:
            /// round to nearest (ties to even) without raising FP exceptions.
            #[inline]
            fn round(self) -> Self {
                // SAFETY: this impl is only compiled when SSE4.1 is statically
                // enabled for the whole crate, so the instruction is available.
                unsafe {
                    _mm_round_ps(self.into_bits(), 8).into_bits()
                }
            }
        }
    } else {
        impl SimdRounding for f32x4 {
            /// Portable fallback: rounds each lane with scalar `f32::round`.
            ///
            /// Mirrors the `f32x8` fallback. Note that `f32::round` breaks
            /// ties away from zero, whereas the intrinsic path rounds ties
            /// to even.
            #[inline]
            fn round(self) -> Self {
                f32x4::new(
                    self.extract(0).round(),
                    self.extract(1).round(),
                    self.extract(2).round(),
                    self.extract(3).round(),
                )
            }
        }
    }
}
cfg_if! {
    // `cfg(target_feature = ...)` is decided at compile time from the features
    // enabled for the whole crate, so the intrinsic path is only selected when
    // the crate is built with AVX turned on.
    if #[cfg(all(target_arch = "x86_64", target_feature = "avx"))] {
        impl SimdRounding for f32x8 {
            /// Rounds all eight lanes with a single AVX instruction.
            ///
            /// The immediate `8` is `_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC`:
            /// round to nearest (ties to even) without raising FP exceptions.
            #[inline(always)]
            fn round(self) -> Self {
                let raw: __m256 = self.into_bits();
                // SAFETY: this impl is only compiled when AVX is statically
                // enabled, so the instruction is available.
                let rounded: __m256 = unsafe { _mm256_round_ps(raw, 8) };
                rounded.into_bits()
            }
        }
    } else {
        impl SimdRounding for f32x8 {
            /// Portable fallback: rounds each lane with scalar `f32::round`
            /// (ties away from zero).
            #[inline(always)]
            fn round(self) -> Self {
                let lane = |idx: usize| self.extract(idx).round();
                f32x8::new(
                    lane(0),
                    lane(1),
                    lane(2),
                    lane(3),
                    lane(4),
                    lane(5),
                    lane(6),
                    lane(7),
                )
            }
        }
    }
}
/// Constructor for a vector whose lanes count up from zero.
trait Steps {
/// Returns the vector `[0.0, 1.0, 2.0, ...]`, one step per lane.
fn steps() -> Self;
}
impl Steps for f32x4 {
    /// Lane `i` holds the value `i`, for `i` in `0..4`.
    fn steps() -> Self {
        Self::new(0.0, 1.0, 2.0, 3.0)
    }
}
impl Steps for f32x8 {
    /// Lane `i` holds the value `i`, for `i` in `0..8`.
    fn steps() -> Self {
        Self::new(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
    }
}
/// Fills `obuf` with a piecewise-parabolic approximation of a sine wave.
///
/// `freq` is the phase increment per sample, in cycles (1.0 is one full
/// period). The phase starts at zero and is never wrapped; instead the
/// distance to the nearest integer is taken on every sample.
fn quadwave_scalar(freq: f32, obuf: &mut [f32]) {
    let mut phase = 0.0f32;
    for sample in obuf.iter_mut() {
        // Signed offset from the nearest whole cycle, in [-0.5, 0.5].
        let frac = phase - phase.round();
        // Parabola through 0 at frac = 0 and +/-0.5, peaking at +/-1 for frac = +/-0.25.
        let shaped = 16.0 * frac * (0.5 - frac.abs());
        *sample = shaped;
        phase += freq;
    }
}
/// Fills `obuf` with `sin(2*pi*phase)` using the standard library's `sin`.
///
/// `freq` is the phase increment per sample, in cycles. The phase starts at
/// zero and, after every sample, is advanced and reduced to its fractional
/// part.
fn sin_scalar(freq: f32, obuf: &mut [f32]) {
    const TWO_PI: f32 = 2.0 * ::std::f32::consts::PI;
    let mut phase = 0.0f32;
    for sample in obuf.iter_mut() {
        *sample = (phase * TWO_PI).sin();
        // Advance, then drop the integer part so the phase stays in [0, 1).
        let advanced = phase + freq;
        phase = advanced - advanced.floor();
    }
}
/// Lane-wise absolute value of eight packed `f32`s.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn _mm256_abs(x: __m256) -> __m256 {
    // -0.0 has only the sign bit set, so `!sign_bit & x` clears the sign of
    // every lane and leaves exponent and mantissa untouched.
    let sign_bit = _mm256_set1_ps(-0.0);
    _mm256_andnot_ps(sign_bit, x)
}
/// AVX version of the piecewise-parabolic sine approximation, eight samples
/// per iteration.
///
/// `freq` is the phase increment per sample, in cycles.
///
/// # Panics
///
/// Panics if `obuf.len()` is not a multiple of 8.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn quadwave_avx(freq: f32, obuf: &mut [f32]) {
    assert!(obuf.len() % 8 == 0);
    let half = _mm256_set1_ps(0.5);
    let gain = _mm256_set1_ps(16.0);
    // Each iteration consumes eight consecutive samples.
    let stride = _mm256_set1_ps(8.0 * freq);
    // Lane k starts at the phase of sample k.
    let lane_index = _mm256_setr_ps(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
    let mut phase = _mm256_mul_ps(lane_index, _mm256_set1_ps(freq));
    // The assertion above guarantees every chunk is exactly eight floats.
    for block in obuf.chunks_exact_mut(8) {
        // Signed offset from the nearest whole cycle (8 = nearest, no exceptions).
        let frac = _mm256_sub_ps(phase, _mm256_round_ps(phase, 8));
        let parabola = _mm256_mul_ps(frac, _mm256_sub_ps(half, _mm256_abs(frac)));
        let scaled = _mm256_mul_ps(parabola, gain);
        _mm256_storeu_ps(block.as_mut_ptr(), scaled);
        // Advance and keep only the fractional part of the phase.
        let advanced = _mm256_add_ps(phase, stride);
        phase = _mm256_sub_ps(advanced, _mm256_floor_ps(advanced));
    }
}
/// AVX sine generator using a degree-7 odd polynomial, eight samples per
/// iteration.
///
/// `freq` is the phase increment per sample, in cycles. The phase is offset
/// by a quarter cycle so that `a = |phase - round(phase)| - 0.25` lies in
/// `[-0.25, 0.25]` and `sin(2*pi*a)` equals the sine of the un-offset phase.
///
/// # Panics
///
/// Panics if `obuf.len()` is not a multiple of 8.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn sin7_avx(freq: f32, obuf: &mut [f32]) {
    assert!(obuf.len() % 8 == 0);
    let quarter = _mm256_set1_ps(0.25);
    // Polynomial coefficients for a^7, a^5, a^3 and a^1 respectively.
    let k7 = _mm256_set1_ps(-57.09913892);
    let k5 = _mm256_set1_ps(78.3211512);
    let k3 = _mm256_set1_ps(-41.13568647);
    let k1 = _mm256_set1_ps(6.27971764);
    let stride = _mm256_set1_ps(8.0 * freq);
    let lane_index = _mm256_setr_ps(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
    let start = _mm256_mul_ps(lane_index, _mm256_set1_ps(freq));
    let mut phase = _mm256_add_ps(start, quarter);
    // The assertion above guarantees every chunk is exactly eight floats.
    for block in obuf.chunks_exact_mut(8) {
        let frac = _mm256_sub_ps(phase, _mm256_round_ps(phase, 8));
        let a = _mm256_sub_ps(_mm256_abs(frac), quarter);
        let a2 = _mm256_mul_ps(a, a);
        // Horner evaluation in a^2, highest-order term first.
        let acc = _mm256_mul_ps(a2, k7);
        let acc = _mm256_mul_ps(a2, _mm256_add_ps(acc, k5));
        let acc = _mm256_mul_ps(a2, _mm256_add_ps(acc, k3));
        let y = _mm256_mul_ps(a, _mm256_add_ps(acc, k1));
        _mm256_storeu_ps(block.as_mut_ptr(), y);
        // Advance and keep only the fractional part of the phase.
        let advanced = _mm256_add_ps(phase, stride);
        phase = _mm256_sub_ps(advanced, _mm256_floor_ps(advanced));
    }
}
/// AVX sine generator using a degree-9 odd polynomial, eight samples per
/// iteration.
///
/// `freq` is the phase increment per sample, in cycles. The phase is offset
/// by a quarter cycle so that `a = |phase - round(phase)| - 0.25` lies in
/// `[-0.25, 0.25]` and `sin(2*pi*a)` equals the sine of the un-offset phase.
///
/// # Panics
///
/// Panics if `obuf.len()` is not a multiple of 8.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn sin9_avx(freq: f32, obuf: &mut [f32]) {
    assert!(obuf.len() % 8 == 0);
    let quarter = _mm256_set1_ps(0.25);
    // coefs:
    // basis = linspace(-.5, .5, 1000)
    // print polyfit(basis, sin(2*pi*basis), 9)[-2::-2]
    // Listed here for a^9, a^7, a^5, a^3 and a^1 respectively.
    let k9 = _mm256_set1_ps(33.15324345);
    let k7 = _mm256_set1_ps(-74.66884436);
    let k5 = _mm256_set1_ps(81.39900205);
    let k3 = _mm256_set1_ps(-41.33318707);
    let k1 = _mm256_set1_ps(6.28308759);
    let phaseinc = _mm256_set1_ps(8.0 * freq);
    let mut phase = _mm256_mul_ps(_mm256_setr_ps(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0),
                                  _mm256_set1_ps(freq));
    phase = _mm256_add_ps(phase, quarter);
    // Iterating by chunk replaces the manual index and the redundant
    // end-of-buffer check that the loop condition already covered. The
    // assertion above guarantees every chunk is exactly eight floats.
    for block in obuf.chunks_exact_mut(8) {
        // 8 = round to nearest without raising FP exceptions.
        let a = _mm256_sub_ps(phase, _mm256_round_ps(phase, 8));
        let a = _mm256_sub_ps(_mm256_abs(a), quarter);
        let a2 = _mm256_mul_ps(a, a);
        // Horner evaluation in a^2, highest-order term first.
        let y = _mm256_mul_ps(a2, k9);
        let y = _mm256_mul_ps(a2, _mm256_add_ps(y, k7));
        let y = _mm256_mul_ps(a2, _mm256_add_ps(y, k5));
        let y = _mm256_mul_ps(a2, _mm256_add_ps(y, k3));
        let y = _mm256_mul_ps(a, _mm256_add_ps(y, k1));
        _mm256_storeu_ps(block.as_mut_ptr(), y);
        // Advance and keep only the fractional part of the phase.
        phase = _mm256_add_ps(phase, phaseinc);
        phase = _mm256_sub_ps(phase, _mm256_floor_ps(phase));
    }
}
/// Degree-9 polynomial sine generator written with `packed_simd` vector
/// types, eight samples per iteration.
///
/// `freq` is the phase increment per sample, in cycles. Unlike the
/// hand-written AVX kernels, the phase is not wrapped between iterations;
/// only the distance to the nearest integer is taken each time.
///
/// # Panics
///
/// Panics if `obuf.len()` is not a multiple of 8.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn packed_sin9(freq: f32, obuf: &mut [f32]) {
    // Odd-power polynomial coefficients, lowest order (a^1) first.
    const C0: f32 = 6.28308759;
    const C1: f32 = -41.33318707;
    const C2: f32 = 81.39900205;
    const C3: f32 = -74.66884436;
    const C4: f32 = 33.15324345;
    assert!(obuf.len() % 8 == 0);
    // Each iteration consumes eight consecutive samples.
    let stride = freq * 8.0;
    let mut phase = f32x8::steps() * freq;
    phase += 0.25;
    // The assertion above guarantees every chunk is exactly eight floats.
    for block in obuf.chunks_exact_mut(8) {
        let a = (phase - phase.round()).abs() - 0.25;
        let a2 = a * a;
        // Horner evaluation in a^2.
        let poly = C0 + a2 * (C1 + a2 * (C2 + a2 * (C3 + a2 * C4)));
        let y = a * poly;
        y.write_to_slice_unaligned(block);
        phase += stride;
    }
}
/// Renders 32 samples with both `sin9_avx` and `packed_sin9` and prints the
/// two outputs side by side, one sample per line.
///
/// Expects one command-line argument: the phase increment per sample, in
/// cycles (for example `0.1`). Exits with status 1 and a message on stderr
/// if the argument is missing or malformed, or if the CPU lacks AVX.
fn main() {
    let arg = match ::std::env::args().nth(1) {
        Some(arg) => arg,
        None => {
            eprintln!("usage: <program> FREQ   (phase increment per sample, e.g. 0.1)");
            ::std::process::exit(1);
        }
    };
    let freq = match f32::from_str(&arg) {
        Ok(freq) => freq,
        Err(err) => {
            eprintln!("invalid frequency {:?}: {}", arg, err);
            ::std::process::exit(1);
        }
    };
    // Both kernels are compiled with AVX enabled; executing them on a CPU
    // without AVX would be undefined behaviour, so check at runtime first.
    if !is_x86_feature_detected!("avx") {
        eprintln!("this program requires a CPU with AVX support");
        ::std::process::exit(1);
    }
    let mut obuf = [0.0f32; 32];
    // SAFETY: AVX support was verified above.
    unsafe { sin9_avx(freq, &mut obuf); }
    let mut obuf2 = [0.0f32; 32];
    // SAFETY: AVX support was verified above.
    unsafe { packed_sin9(freq, &mut obuf2); }
    for (x, y) in obuf.iter().zip(obuf2.iter()) {
        println!("{} {}", x, y);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use test::Bencher;

    // Samples rendered per benchmark iteration; a multiple of 8, as the
    // AVX kernels require.
    const BENCH_LEN: usize = 64;
    // Phase increment per sample shared by every benchmark.
    const BENCH_FREQ: f32 = 0.1;

    /// Scalar parabolic approximation.
    #[bench]
    fn bench_quadwave_scalar(b: &mut Bencher) {
        let mut obuf = [0.0f32; BENCH_LEN];
        b.iter(|| quadwave_scalar(test::black_box(BENCH_FREQ), &mut obuf));
    }

    /// AVX parabolic approximation.
    #[bench]
    fn bench_quadwave_avx(b: &mut Bencher) {
        let mut obuf = [0.0f32; BENCH_LEN];
        b.iter(|| unsafe { quadwave_avx(test::black_box(BENCH_FREQ), &mut obuf) });
    }

    /// Scalar reference using the standard library's `sin`.
    #[bench]
    fn bench_sin_scalar(b: &mut Bencher) {
        let mut obuf = [0.0f32; BENCH_LEN];
        b.iter(|| sin_scalar(test::black_box(BENCH_FREQ), &mut obuf));
    }

    /// AVX degree-7 polynomial.
    #[bench]
    fn bench_sin7_avx(b: &mut Bencher) {
        let mut obuf = [0.0f32; BENCH_LEN];
        b.iter(|| unsafe { sin7_avx(test::black_box(BENCH_FREQ), &mut obuf) });
    }

    /// AVX degree-9 polynomial.
    #[bench]
    fn bench_sin9_avx(b: &mut Bencher) {
        let mut obuf = [0.0f32; BENCH_LEN];
        b.iter(|| unsafe { sin9_avx(test::black_box(BENCH_FREQ), &mut obuf) });
    }

    /// `packed_simd` degree-9 polynomial.
    #[bench]
    fn bench_packed_sin9(b: &mut Bencher) {
        let mut obuf = [0.0f32; BENCH_LEN];
        b.iter(|| unsafe { packed_sin9(test::black_box(BENCH_FREQ), &mut obuf) });
    }
}
// End of gist source (the GitHub sign-up / comment prompt from the page is not part of the code).