-
-
Save b-ma/a0909191089037b9cbebc2f7bd1c8117 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::pin::Pin; | |
use std::boxed::Box; | |
use std::ptr; | |
use std::sync::atomic::{AtomicUsize, Ordering}; | |
// for debugging | |
#[allow(dead_code)] | |
fn print_ptr(ptr: *mut f32) { | |
unsafe { | |
let mut dst = [0.; RING_BUFFER_SIZE]; | |
let dst_ptr = dst.as_mut_ptr(); | |
ptr::copy_nonoverlapping(ptr, dst_ptr, RING_BUFFER_SIZE); | |
println!("{:?}", dst); | |
} | |
} | |
/// Capacity of the ring buffer, in samples (`MAX_FFT_SIZE * 2`).
///
/// A commented-out alternative value of 65536 was kept next to this one.
/// NOTE(review): 1024 is presumably a reduced size for experimentation and
/// tests — confirm which value is intended before shipping.
const RING_BUFFER_SIZE: usize = 1 << 10;

/// Ring buffer of audio samples, written through `add_input` (commented as
/// running in the audio thread) and read through `get_float_time_domain_data`
/// (commented as running in the control thread).
pub struct Analyser {
    /// Heap storage for the samples. Owning it here keeps the allocation alive
    /// for as long as `buffer_ptr` may be used.
    buffer: Pin<Box<[f32; RING_BUFFER_SIZE]>>,
    /// Raw pointer to the first sample of `buffer`; `add_input` writes through
    /// it while only holding `&self`.
    buffer_ptr: *mut f32,
    /// Write position: the slot where the next incoming sample is stored.
    index: AtomicUsize,
}

// SAFETY (NOTE(review)): the raw pointer field prevents `Analyser` from being
// `Send`/`Sync` automatically. These impls assert that sharing it across
// threads is acceptable, which at best holds under a single-writer discipline
// (one thread calling `add_input`). Even then, sample copies are plain
// non-atomic accesses that can overlap with reads on another thread, which
// Rust's memory model treats as a data race — this structure is really not
// safe from the compiler's point of view.
unsafe impl Send for Analyser {}
unsafe impl Sync for Analyser {}

// No manual `Drop` is needed: the `Box` inside `buffer` frees the allocation
// when the analyser is dropped, and `buffer_ptr` merely aliases that memory.
impl Analyser { | |
// runs in control thread | |
pub fn new() -> Self { | |
// inspired from https://github.com/utaal/spsc-bip-buffer/blob/master/src/lib.rs#L89 | |
// | |
// from https://doc.rust-lang.org/std/boxed/struct.Box.html#method.pin | |
// [doc] Constructs a new Pin<Box<T>>. If T does not implement Unpin, then | |
// x will be pinned in memory and unable to be moved. | |
// -> as array does not implement Unpin, the pointer should stay coherent | |
let mut buffer = Box::pin([0.; RING_BUFFER_SIZE]); | |
let buffer_ptr = buffer.as_mut_ptr(); | |
Self { | |
buffer, | |
buffer_ptr, | |
index: AtomicUsize::new(0), | |
} | |
} | |
// this runs in the audio thread | |
pub fn add_input(&self, src: &[f32]) { | |
let mut index = self.index.load(Ordering::SeqCst); | |
let len = src.len(); | |
// push src data in ring bufer | |
if index + len > RING_BUFFER_SIZE { | |
let offset = RING_BUFFER_SIZE - index; | |
unsafe { | |
// fill end of ring buffer | |
let src_ptr = src.as_ptr(); | |
let dst_ptr = self.buffer_ptr.add(index); | |
ptr::copy_nonoverlapping(src_ptr, dst_ptr, offset); | |
// restart to beginning | |
let src_ptr = src_ptr.add(offset); | |
let dst_ptr = self.buffer_ptr; | |
ptr::copy_nonoverlapping(src_ptr, dst_ptr, len - offset); | |
} | |
// in our conditions we can't be there yet | |
} else { | |
// we have enough room to copy src in one shot | |
unsafe { | |
let src_ptr = src.as_ptr(); | |
let dst_ptr = self.buffer_ptr.add(index); | |
ptr::copy_nonoverlapping(src_ptr, dst_ptr, len); | |
} | |
} | |
index += len; | |
if index >= RING_BUFFER_SIZE { | |
index -= RING_BUFFER_SIZE; | |
} | |
self.index.store(index, Ordering::SeqCst); | |
} | |
// if we read only below index in control thread we are sure the memory is clean | |
pub fn get_float_time_domain_data(&self, dest: &mut [f32]) { | |
// @todo | |
// [spec] If array has fewer elements than the value of fftSize, the excess | |
// elements will be dropped. If array has more elements than the value of | |
// fftSize, the excess elements will be ignored. | |
// i.e. let len = dest.len().min(self.fft_size); | |
let len = dest.len(); | |
let index = self.index.load(Ordering::SeqCst); | |
if len <= index { | |
// no need to unwrap buffer | |
dest[0..len].copy_from_slice(&self.buffer[(index - len)..index]); | |
} else { | |
let diff = len - index; | |
dest[0..diff].copy_from_slice(&self.buffer[(RING_BUFFER_SIZE - diff)..]); | |
dest[diff..len].copy_from_slice(&self.buffer[0..index]); | |
} | |
} | |
} | |
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    use float_eq::assert_float_eq;
    use rand::Rng;

    use super::*;

    // For now, assume the render quantum size is a power of two.
    const RENDER_QUANTUM_SIZE: usize = 128;

    // Only checks that an `Analyser` can be put in an `Arc`.
    #[test]
    fn test_arc() {
        let _analyser = Arc::new(Analyser::new());
    }

    #[test]
    fn test_add_input_aligned() {
        let analyser = Analyser::new();

        // Fill the buffer twice so the write index wraps around.
        for i in 1..3 {
            for j in 0..(RING_BUFFER_SIZE / RENDER_QUANTUM_SIZE) {
                let data = [i as f32; RENDER_QUANTUM_SIZE];
                analyser.add_input(&data);

                // The write index advances by one quantum, modulo the buffer size.
                let write_index = analyser.index.load(Ordering::SeqCst);
                let expected = ((j + 1) * RENDER_QUANTUM_SIZE) % RING_BUFFER_SIZE;
                assert_eq!(write_index, expected);
            }

            // After each full pass, the whole ring buffer holds the pass value.
            let expected = [i as f32; RING_BUFFER_SIZE];
            assert_float_eq!(&analyser.buffer[..], &expected[..], abs_all <= 1e-12);
        }
    }

    #[test]
    fn test_add_input_wrap() {
        // Check values are written in the right place.
        {
            let analyser = Analyser::new();
            let offset = 10;
            analyser.index.store(RING_BUFFER_SIZE - offset, Ordering::SeqCst);

            let data = [1.; RENDER_QUANTUM_SIZE];
            analyser.add_input(&data);

            // Ones in the last `offset` slots and in the wrapped-around head.
            let expected: Vec<f32> = (0..RING_BUFFER_SIZE)
                .map(|index| {
                    if index < RENDER_QUANTUM_SIZE - offset || index >= RING_BUFFER_SIZE - offset {
                        1.
                    } else {
                        0.
                    }
                })
                .collect();
            assert_float_eq!(&analyser.buffer[..], &expected[..], abs_all <= 1e-12);
        }

        // Check values are written in the right order.
        {
            let analyser = Analyser::new();
            let offset = 2;
            analyser.index.store(RING_BUFFER_SIZE - offset, Ordering::SeqCst);

            let data = [1., 2., 3., 4.];
            analyser.add_input(&data);

            let mut expected = [0.; RING_BUFFER_SIZE];
            expected[RING_BUFFER_SIZE - 2] = 1.;
            expected[RING_BUFFER_SIZE - 1] = 2.;
            expected[0] = 3.;
            expected[1] = 4.;
            assert_float_eq!(&analyser.buffer[..], &expected[..], abs_all <= 1e-12);
        }
    }

    #[test]
    fn test_get_float_time_domain_data_simple() {
        let analyser = Arc::new(Analyser::new());

        // First pass.
        let data = [1.; RENDER_QUANTUM_SIZE];
        analyser.add_input(&data);
        // The index is where it should be.
        let index = analyser.index.load(Ordering::SeqCst);
        assert_eq!(index, RENDER_QUANTUM_SIZE);

        let mut read_buffer = [0.; RENDER_QUANTUM_SIZE];
        analyser.get_float_time_domain_data(&mut read_buffer);
        // The data read back is what was written.
        let expected = [1.; RENDER_QUANTUM_SIZE];
        assert_float_eq!(&expected, &read_buffer, abs_all <= 1e-12);

        // Second pass.
        let data = [2.; RENDER_QUANTUM_SIZE];
        analyser.add_input(&data);
        // The index is where it should be.
        let index = analyser.index.load(Ordering::SeqCst);
        assert_eq!(index, RENDER_QUANTUM_SIZE * 2);

        let mut read_buffer = [0.; RENDER_QUANTUM_SIZE];
        analyser.get_float_time_domain_data(&mut read_buffer);
        let expected = [2.; RENDER_QUANTUM_SIZE];
        assert_float_eq!(&expected, &read_buffer, abs_all <= 1e-12);

        // Both passes are laid out back to back, the rest is still zeroed.
        let mut full_buffer_expected = [0.; RING_BUFFER_SIZE];
        full_buffer_expected[0..RENDER_QUANTUM_SIZE].copy_from_slice(&[1.; RENDER_QUANTUM_SIZE]);
        full_buffer_expected[RENDER_QUANTUM_SIZE..(RENDER_QUANTUM_SIZE * 2)]
            .copy_from_slice(&[2.; RENDER_QUANTUM_SIZE]);
        assert_float_eq!(&analyser.buffer[..], &full_buffer_expected[..], abs_all <= 1e-12);
    }

    #[test]
    fn test_get_float_time_domain_data_unwrap() {
        // Check values are read from the right place.
        {
            let analyser = Analyser::new();
            let offset = 10;
            analyser.index.store(RING_BUFFER_SIZE - offset, Ordering::SeqCst);

            let data = [1.; RENDER_QUANTUM_SIZE];
            analyser.add_input(&data);

            let mut read_buffer = [0.; RENDER_QUANTUM_SIZE];
            analyser.get_float_time_domain_data(&mut read_buffer);
            assert_float_eq!(&read_buffer, &data, abs_all <= 1e-12);
        }

        // Check values are read from the right place and written in the right order.
        {
            let analyser = Analyser::new();
            let offset = 2;
            analyser.index.store(RING_BUFFER_SIZE - offset, Ordering::SeqCst);

            let data = [1., 2., 3., 4.];
            analyser.add_input(&data);

            let mut read_buffer = [0.; 4];
            analyser.get_float_time_domain_data(&mut read_buffer);
            assert_float_eq!(&read_buffer, &[1., 2., 3., 4.], abs_all <= 1e-12);
        }
    }

    // This mostly shows that concurrent use does not crash (no segfault or
    // out-of-bounds access). Asserting precise results under concurrency is
    // hard; the other tests cover correctness.
    #[test]
    fn test_concurrency() {
        let analyser = Arc::new(Analyser::new());
        let audio_thread_analyser = Arc::clone(&analyser);
        let num_loops = 100_000;

        let audio_thread = thread::spawn(move || {
            let mut rng = rand::thread_rng();
            for _ in 0..num_loops {
                let data = [rng.gen::<f32>(); RENDER_QUANTUM_SIZE];
                audio_thread_analyser.add_input(&data);
                thread::sleep(Duration::from_nanos(30));
            }
        });

        // Give the audio thread a head start.
        thread::sleep(Duration::from_millis(1));

        for _ in 0..num_loops {
            let mut read_buffer = [0.; RENDER_QUANTUM_SIZE];
            analyser.get_float_time_domain_data(&mut read_buffer);
            thread::sleep(Duration::from_nanos(25));
        }

        // Join so that a panic in the audio thread fails this test instead of
        // being silently discarded along with the handle.
        audio_thread.join().expect("audio thread panicked");
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment