Skip to content

Instantly share code, notes, and snippets.

@b-ma
Last active January 20, 2023 16:41
Show Gist options
  • Save b-ma/a0909191089037b9cbebc2f7bd1c8117 to your computer and use it in GitHub Desktop.
Save b-ma/a0909191089037b9cbebc2f7bd1c8117 to your computer and use it in GitHub Desktop.
use std::pin::Pin;
use std::boxed::Box;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Debugging helper: prints the `RING_BUFFER_SIZE` floats starting at `ptr` to stdout.
///
/// The function is not marked `unsafe`, yet it reads through the raw pointer: callers must
/// pass a pointer that is valid for reading `RING_BUFFER_SIZE` consecutive `f32` values.
#[allow(dead_code)] // only called ad hoc while debugging
fn print_ptr(ptr: *mut f32) {
    // Copy into a local array first so the values can be formatted through a safe type.
    let mut snapshot = [0.; RING_BUFFER_SIZE];
    // SAFETY: relies on the caller-side requirement documented above; `snapshot` is a fresh
    // local array of exactly `RING_BUFFER_SIZE` elements, so it cannot overlap the source.
    unsafe {
        ptr::copy_nonoverlapping(ptr, snapshot.as_mut_ptr(), RING_BUFFER_SIZE);
    }
    println!("{:?}", snapshot);
}
/// Capacity of the ring buffer, in `f32` samples.
///
/// The intended value appears to be `MAX_FFT_SIZE * 2` (65536); a smaller size of 1024 is
/// currently used instead.
const RING_BUFFER_SIZE: usize = 1024;

/// Fixed-size ring buffer of audio samples.
///
/// Samples are pushed with [`Analyser::add_input`] (meant to run in the audio thread) and the
/// most recent ones are read back with [`Analyser::get_float_time_domain_data`] (meant to run
/// in the control thread). Both methods take `&self`, so writes go through a raw pointer into
/// the heap allocation owned by `buffer`.
pub struct Analyser {
    /// Owning handle to the heap-allocated sample storage.
    ///
    /// `[f32; N]` is `Unpin`, so `Pin` adds no guarantee here; the address stays stable
    /// because moving a `Box` moves only the pointer, never the heap data it points to.
    buffer: Pin<Box<[f32; RING_BUFFER_SIZE]>>,
    /// Non-owning pointer to the first sample of `buffer`, used for writing through `&self`.
    /// The allocation is released by `buffer`, so no manual `Drop` is required.
    buffer_ptr: *mut f32,
    /// Position of the next sample to write; kept strictly below `RING_BUFFER_SIZE`.
    index: AtomicUsize,
}

// The raw pointer field removes the automatic `Send`/`Sync` impls, so they are asserted by hand.
//
// SAFETY (partial): `buffer_ptr` always points into the allocation owned by `buffer`, which
// lives exactly as long as the struct, and `index` is only ever accessed atomically.
// NOTE(review): nothing prevents a reader from copying a region while the writer is
// overwriting it, and nothing prevents two threads from calling `add_input` at once. The type
// is only meant to be used with a single writer; confirm the callers respect that.
unsafe impl Sync for Analyser {}
unsafe impl Send for Analyser {}

impl Analyser {
    /// Creates an analyser whose ring buffer is filled with zeros and whose write position
    /// is 0. Meant to run in the control thread.
    pub fn new() -> Self {
        // Approach inspired by
        // https://github.com/utaal/spsc-bip-buffer/blob/master/src/lib.rs#L89
        let mut buffer = Box::pin([0.0_f32; RING_BUFFER_SIZE]);
        // Taken before `buffer` is moved into the struct: the heap data does not move with
        // the box, so the pointer stays coherent.
        let buffer_ptr = buffer.as_mut_ptr();

        Self {
            buffer,
            buffer_ptr,
            index: AtomicUsize::new(0),
        }
    }

    /// Appends `src` to the ring buffer, wrapping around its end when needed, and advances
    /// the write position by `src.len()` modulo `RING_BUFFER_SIZE`. Meant to run in the
    /// audio thread.
    ///
    /// When `src` holds more than `RING_BUFFER_SIZE` samples only the last
    /// `RING_BUFFER_SIZE` of them are stored, at the positions they would occupy had every
    /// sample been written one by one. Without this, a large `src` made the copies run past
    /// the end of the allocation.
    pub fn add_input(&self, src: &[f32]) {
        let index = self.index.load(Ordering::SeqCst);

        // Older samples of an oversized block would be overwritten anyway: skip them and
        // move the starting position as if they had been written.
        let (src, start) = if src.len() > RING_BUFFER_SIZE {
            let skipped = src.len() - RING_BUFFER_SIZE;
            (&src[skipped..], (index + skipped) % RING_BUFFER_SIZE)
        } else {
            (src, index)
        };

        let len = src.len();
        // Number of samples that fit between `start` and the end of the buffer.
        let head = len.min(RING_BUFFER_SIZE - start);

        // SAFETY: `start < RING_BUFFER_SIZE` and `len <= RING_BUFFER_SIZE`, hence
        // `start + head <= RING_BUFFER_SIZE` and `len - head <= start`: both destination
        // ranges lie inside the allocation behind `buffer_ptr`. `src` is a caller-provided
        // shared slice, `head <= len`, so both source ranges lie inside `src`.
        unsafe {
            // Fill up to the end of the ring buffer...
            ptr::copy_nonoverlapping(src.as_ptr(), self.buffer_ptr.add(start), head);
            // ...then restart from the beginning with whatever is left (possibly nothing).
            ptr::copy_nonoverlapping(src.as_ptr().add(head), self.buffer_ptr, len - head);
        }

        self.index
            .store((start + len) % RING_BUFFER_SIZE, Ordering::SeqCst);
    }

    /// Copies the most recently written samples into `dest`, oldest first. Meant to run in
    /// the control thread.
    ///
    /// At most `RING_BUFFER_SIZE` samples are copied: if `dest` is longer, only its first
    /// `RING_BUFFER_SIZE` elements are written and the excess is left untouched (the spec
    /// says excess elements are ignored). Previously such a `dest` either panicked or was
    /// filled with duplicated stale samples.
    ///
    /// @todo - clamp to `fft_size` once it exists, i.e. `dest.len().min(self.fft_size)`.
    pub fn get_float_time_domain_data(&self, dest: &mut [f32]) {
        let len = dest.len().min(RING_BUFFER_SIZE);
        let dest = &mut dest[..len];
        let index = self.index.load(Ordering::SeqCst);

        if len <= index {
            // The requested window sits entirely below the write position: no unwrap needed.
            dest.copy_from_slice(&self.buffer[(index - len)..index]);
        } else {
            // The window starts near the end of the buffer and continues at its beginning.
            let wrapped = len - index;
            let (older, newer) = dest.split_at_mut(wrapped);
            older.copy_from_slice(&self.buffer[(RING_BUFFER_SIZE - wrapped)..]);
            newer.copy_from_slice(&self.buffer[..index]);
        }
    }
}
/// Unit tests for the ring buffer: write position bookkeeping, wrap-around on write and on
/// read, and a smoke test with one writer thread and one reader thread.
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use float_eq::assert_float_eq;
    use rand::Rng;

    use super::*;

    // for now just consider the render quantum size should be a power of two
    const RENDER_QUANTUM_SIZE: usize = 128;

    // just make sure it builds when put in an `Arc`
    #[test]
    fn test_arc() {
        let _analyser = Arc::new(Analyser::new());
    }

    // blocks whose size divides the buffer size: the index and the content must stay aligned
    #[test]
    fn test_add_input_aligned() {
        let analyser = Analyser::new();

        // check index update
        {
            // fill the buffer twice so we check the buffer wrap
            for i in 1..3 {
                for j in 0..(RING_BUFFER_SIZE / RENDER_QUANTUM_SIZE) {
                    let data = [i as f32; RENDER_QUANTUM_SIZE];
                    analyser.add_input(&data);

                    // check write index is properly updated
                    let write_index = analyser.index.load(Ordering::SeqCst);
                    let expected =
                        (j * RENDER_QUANTUM_SIZE + RENDER_QUANTUM_SIZE) % RING_BUFFER_SIZE;

                    assert_eq!(write_index, expected);
                }

                // for each loop check the ring buffer is properly filled
                let expected = [i as f32; RING_BUFFER_SIZE];

                assert_float_eq!(&analyser.buffer[..], &expected[..], abs_all <= 1e-12);
            }
        }
    }

    // a block written across the end of the buffer must continue at its beginning
    #[test]
    fn test_add_input_wrap() {
        // check values are written in right place
        {
            let analyser = Analyser::new();
            let offset = 10;
            analyser
                .index
                .store(RING_BUFFER_SIZE - offset, Ordering::SeqCst);

            let data = [1.; RENDER_QUANTUM_SIZE];
            analyser.add_input(&data);

            let mut expected = [0.; RING_BUFFER_SIZE];
            expected.iter_mut().enumerate().for_each(|(index, v)| {
                if index < RENDER_QUANTUM_SIZE - offset || index >= RING_BUFFER_SIZE - offset {
                    *v = 1.
                } else {
                    *v = 0.
                }
            });

            assert_float_eq!(&analyser.buffer[..], &expected[..], abs_all <= 1e-12);
        }

        // check values are written in right order
        {
            let analyser = Analyser::new();
            let offset = 2;
            analyser
                .index
                .store(RING_BUFFER_SIZE - offset, Ordering::SeqCst);

            let data = [1., 2., 3., 4.];
            analyser.add_input(&data);

            let mut expected = [0.; RING_BUFFER_SIZE];
            expected[RING_BUFFER_SIZE - 2] = 1.;
            expected[RING_BUFFER_SIZE - 1] = 2.;
            expected[0] = 3.;
            expected[1] = 4.;

            assert_float_eq!(&analyser.buffer[..], &expected[..], abs_all <= 1e-12);
        }
    }

    // reading right after writing returns the block that was just written
    #[test]
    fn test_get_float_time_domain_data_simple() {
        let analyser = Arc::new(Analyser::new());

        // first pass
        let data = [1.; RENDER_QUANTUM_SIZE];
        analyser.add_input(&data);

        // index is where it should be
        let index = analyser.index.load(Ordering::SeqCst);
        assert_eq!(index, RENDER_QUANTUM_SIZE);

        let mut read_buffer = [0.; RENDER_QUANTUM_SIZE];
        analyser.get_float_time_domain_data(&mut read_buffer);

        // data is good
        let expected = [1.; RENDER_QUANTUM_SIZE];
        assert_float_eq!(&expected, &read_buffer, abs_all <= 1e-12);

        // second pass
        let data = [2.; RENDER_QUANTUM_SIZE];
        analyser.add_input(&data);

        // index is where it should be
        let index = analyser.index.load(Ordering::SeqCst);
        assert_eq!(index, RENDER_QUANTUM_SIZE * 2);

        let mut read_buffer = [0.; RENDER_QUANTUM_SIZE];
        analyser.get_float_time_domain_data(&mut read_buffer);

        let expected = [2.; RENDER_QUANTUM_SIZE];
        assert_float_eq!(&expected, &read_buffer, abs_all <= 1e-12);

        // both blocks sit next to each other, the rest of the buffer is untouched
        let mut full_buffer_expected = [0.; RING_BUFFER_SIZE];
        full_buffer_expected[0..RENDER_QUANTUM_SIZE]
            .copy_from_slice(&[1.; RENDER_QUANTUM_SIZE]);
        full_buffer_expected[RENDER_QUANTUM_SIZE..(RENDER_QUANTUM_SIZE * 2)]
            .copy_from_slice(&[2.; RENDER_QUANTUM_SIZE]);

        assert_float_eq!(
            &analyser.buffer[..],
            &full_buffer_expected[..],
            abs_all <= 1e-12
        );
    }

    // reading a window that spans the end of the buffer must "unwrap" it in the right order
    #[test]
    fn test_get_float_time_domain_data_unwrap() {
        // check values are read from right place
        {
            let analyser = Analyser::new();
            let offset = 10;
            analyser
                .index
                .store(RING_BUFFER_SIZE - offset, Ordering::SeqCst);

            let data = [1.; RENDER_QUANTUM_SIZE];
            analyser.add_input(&data);

            let mut read_buffer = [0.; RENDER_QUANTUM_SIZE];
            analyser.get_float_time_domain_data(&mut read_buffer);

            assert_float_eq!(&read_buffer, &data, abs_all <= 1e-12);
        }

        // check values are read from right place and written in right order
        {
            let analyser = Analyser::new();
            let offset = 2;
            analyser
                .index
                .store(RING_BUFFER_SIZE - offset, Ordering::SeqCst);

            let data = [1., 2., 3., 4.];
            analyser.add_input(&data);

            let mut read_buffer = [0.; 4];
            analyser.get_float_time_domain_data(&mut read_buffer);

            assert_float_eq!(&read_buffer, &[1., 2., 3., 4.], abs_all <= 1e-12);
        }
    }

    // this mostly shows that it works concurrently and that we don't fall into
    // SEGFAULT traps or something, but it is difficult to really test something
    // in an accurate way here, other tests are there for such things
    #[test]
    fn test_concurrency() {
        let analyser = Arc::new(Analyser::new());
        let audio_thread_analyser = analyser.clone();
        let num_loops = 100_000;

        // writer, standing in for the audio thread; the handle is kept so that a panic
        // in this thread fails the test instead of going unnoticed
        let audio_thread = thread::spawn(move || {
            let mut rng = rand::thread_rng();
            let mut counter = 0;

            loop {
                let rand = rng.gen::<f32>();
                let data = [rand; RENDER_QUANTUM_SIZE];
                audio_thread_analyser.add_input(&data);

                counter += 1;

                if counter == num_loops {
                    break;
                }

                std::thread::sleep(std::time::Duration::from_nanos(30));
            }
        });

        // wait
        std::thread::sleep(std::time::Duration::from_millis(1));

        // reader, standing in for the control thread
        let mut counter = 0;

        loop {
            let mut read_buffer = [0.; RENDER_QUANTUM_SIZE];
            analyser.get_float_time_domain_data(&mut read_buffer);

            counter += 1;

            if counter == num_loops {
                break;
            }

            std::thread::sleep(std::time::Duration::from_nanos(25));
        }

        // make sure the writer ran to completion without panicking
        audio_thread.join().expect("audio thread panicked");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment