Skip to content

Instantly share code, notes, and snippets.

@JosiahParry
Created April 2, 2024 16:16
Show Gist options
  • Save JosiahParry/7886dc2d57c70a52cc76bd7d9d77ab64 to your computer and use it in GitHub Desktop.
Save JosiahParry/7886dc2d57c70a52cc76bd7d9d77ab64 to your computer and use it in GitHub Desktop.
Recreating the `vacc` function from Advanced R with SIMD. This isn't very well done: the result length is always a multiple of 4, because `array_chunks::<4>()` silently drops any trailing elements that do not fill a complete chunk. It's also surprisingly slower than anticipated.
#![feature(portable_simd)]
#![feature(array_chunks)]
#[extendr]
fn vacc(age: &[f64], female: &[u8], ily: &[f64]) -> Vec<f64> {
age.array_chunks::<4>()
.map(|&a| f64x4::from_array(a))
.zip(female.array_chunks::<4>().map(|&f| u8x4::from_array(f)))
.zip(ily.array_chunks::<4>().map(|&i| f64x4::from_array(i)))
.map(|((a, f), i)| {
// 0.25 + 0.3 * 1.0
let num = f64x4::splat(0.24) + f64x4::splat(0.3) * f64x4::splat(1.0);
let denom =
(f64x4::splat(1.0) - (f64x4::splat(0.04) * a).exp()) + f64x4::splat(0.1) * i;
let coef = if f == u8x4::splat(0) {
f64x4::splat(1.25)
} else {
f64x4::splat(0.75)
};
let p = num / denom;
(p * coef)
.simd_max(f64x4::splat(0.0))
.simd_min(f64x4::splat(1.0))
})
.flat_map(|x| x.to_array())
.collect::<Vec<_>>()
}
#[extendr]
fn vacc_(age: &[f64], female: &[f64], ily: &[f64]) -> Vec<f64> {
let num = f64x4::splat(0.55);
let one_quarter = f64x4::splat(0.25);
let three_tenths = f64x4::splat(0.3);
let one = f64x4::splat(1.0);
let a_coef = f64x4::splat(0.04);
let i_coef = f64x4::splat(0.1);
let male_coef = f64x4::splat(1.25);
let female_coef = f64x4::splat(0.75);
let min_val = f64x4::splat(0.0);
let zero = u8x4::splat(0);
let max_val = f64x4::splat(1.0);
age.array_chunks::<4>()
.map(|&a| f64x4::from_array(a))
.zip(female.array_chunks::<4>().map(|&f| f64x4::from_array(f)))
.zip(ily.array_chunks::<4>().map(|&i| f64x4::from_array(i)))
.map(|((a, f), i)| {
let p = one_quarter + three_tenths * one / (one - (a_coef * a).exp()) + (i_coef * i);
let coef = f * male_coef + (f - one).abs() * female_coef;
(p * coef).simd_max(min_val).simd_min(max_val)
})
.flat_map(|x| x.to_array())
.collect::<Vec<_>>()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment