use std::alloc::Layout;
use std::marker::PhantomData;
use std::mem::{ManuallyDrop, MaybeUninit};
use std::ptr::NonNull;
#[cfg(not(loom))]
use std::sync::atomic::AtomicUsize;
#[cfg(not(loom))]
use std::sync::atomic::Ordering::{Acquire, Release};
use crossbeam_utils::CachePadded;
#[cfg(loom)]
use loom::sync::atomic::AtomicUsize;
#[cfg(loom)]
use loom::sync::atomic::Ordering::{Acquire, Release};
/// Creates a lock-free SPSC ring buffer.
///
/// `capacity` must be a power of 2 between `1` and `usize::MAX / 2`.
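///
/// # Example
///
/// A minimal single-threaded sketch:
///
/// ```ignore
/// let (mut producer, mut consumer) = buffer::<u32>(4);
/// producer.push(1).unwrap();
/// assert_eq!(consumer.pop(), Ok(1));
/// ```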
pub fn buffer<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
    let buffer = ManuallyDrop::new(ProcessLocalBuffer::new(capacity));
    // SAFETY: the pointer is valid since we got it from a reference; `ManuallyDrop` will prevent a double drop.
    // Inside, `ProcessLocalBuffer` is just a pointer, so it's fine to copy it.
    let buffer_copy = unsafe { std::ptr::read(&buffer) };
    // SAFETY: a newly created buffer will satisfy `read_state == write_state == 0`.
    let producer = unsafe { Producer::new(buffer_copy) };
    let consumer = unsafe { Consumer::new(buffer) };
    (producer, consumer)
}

/// Underlying storage of the ring buffer.
///
/// # Safety
///
/// `data_ptr()` must return a valid pointer to an array of at least `capacity` of `T`.
///
/// `capacity()` must return a power of two in the range of `[1, usize::MAX / 2]`.
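///
/// A hypothetical implementor backed by inline storage could look like this
/// (a sketch only; `StaticBuffer` is not part of this file):
///
/// ```ignore
/// use std::cell::UnsafeCell;
/// use std::mem::MaybeUninit;
/// use std::sync::atomic::AtomicUsize;
///
/// struct StaticBuffer<T, const N: usize> {
///     data: UnsafeCell<[MaybeUninit<T>; N]>,
///     read_state: AtomicUsize,
///     write_state: AtomicUsize,
/// }
///
/// // SAFETY: `data_ptr` points at `N` slots of `T`; `N` must be a power of two
/// // in `[1, usize::MAX / 2]` for this to be sound.
/// unsafe impl<T, const N: usize> Buffer<T> for StaticBuffer<T, N> {
///     fn data_ptr(&self) -> *mut MaybeUninit<T> {
///         self.data.get() as *mut MaybeUninit<T>
///     }
///     fn capacity(&self) -> usize { N }
///     fn read_state(&self) -> &AtomicUsize { &self.read_state }
///     fn write_state(&self) -> &AtomicUsize { &self.write_state }
/// }
/// ```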
pub unsafe trait Buffer<T> {
    /// Must be a valid pointer to an array of at least `capacity` of `T` (possibly uninitialized).
    fn data_ptr(&self) -> *mut MaybeUninit<T>;

    /// Capacity is in the range of `[1, usize::MAX / 2]`.
    fn capacity(&self) -> usize;

    /// Read state consists of the read index (lowest N-1 bits), and a flag (highest bit)
    /// signifying that the consumer is closed.
    fn read_state(&self) -> &AtomicUsize;

    /// Write state consists of the write index (lowest N-1 bits), and a flag (highest bit)
    /// signifying that the producer is closed.
    fn write_state(&self) -> &AtomicUsize;
}
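
// Bit layout of a read/write state word (N = `usize::BITS`):
//
//   [ closed flag: 1 bit | wrapping index: N - 1 bits ]
//
// `INDEX_MASK` extracts the index, `CLOSED_MASK` extracts the closed flag.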
const INDEX_MASK: usize = usize::MAX / 2;
const CLOSED_MASK: usize = 1 << (usize::BITS - 1);

struct Header {
    capacity: usize,
    read_state: CachePadded<AtomicUsize>,
    write_state: CachePadded<AtomicUsize>,
}

/// Ring buffer stored in a normal process-local allocation.
pub struct ProcessLocalBuffer<T> {
    inner: NonNull<Header>,
    marker: PhantomData<T>,
}

unsafe impl<T: Send + Sync> Send for ProcessLocalBuffer<T> {}
unsafe impl<T: Send + Sync> Sync for ProcessLocalBuffer<T> {}

impl<T> ProcessLocalBuffer<T> {
    fn new(capacity: usize) -> ProcessLocalBuffer<T> {
        assert!(capacity > 0 && capacity <= INDEX_MASK && capacity.is_power_of_two());
        let layout = Self::layout(capacity);
        // SAFETY: size cannot be zero because we're storing at least `Header`.
        let ptr = unsafe { std::alloc::alloc(layout) as *mut Header };
        let Some(inner) = NonNull::new(ptr) else {
            std::alloc::handle_alloc_error(layout);
        };
        let header = Header {
            capacity,
            read_state: CachePadded::new(AtomicUsize::new(0)),
            write_state: CachePadded::new(AtomicUsize::new(0)),
        };
        // SAFETY: pointer is valid, because we've just allocated it and checked the result.
        unsafe { std::ptr::write(inner.as_ptr(), header) };
        ProcessLocalBuffer {
            inner,
            marker: PhantomData,
        }
    }

    fn header(&self) -> &Header {
        // SAFETY: pointer is valid until self is dropped
        unsafe { self.inner.as_ref() }
    }
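
    // Byte offset of the element array inside the allocation: the `Header` comes first,
    // followed by padding (if any) so that the `T` array is properly aligned.
    // `Layout::extend` returns the combined layout together with the offset of the second part.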
    fn offset() -> usize {
        Layout::new::<Header>()
            .extend(Layout::new::<T>())
            .unwrap()
            .1
    }

    fn layout(capacity: usize) -> Layout {
        Layout::new::<Header>()
            .extend(Layout::array::<T>(capacity).unwrap())
            .unwrap()
            .0
    }
}

unsafe impl<T> Buffer<T> for ProcessLocalBuffer<T> {
    fn data_ptr(&self) -> *mut MaybeUninit<T> {
        let offset = Self::offset();
        // SAFETY: after adding `offset`, pointer is still in bounds (since capacity is > 0)
        // and properly aligned (enforced by `Layout::extend`).
        unsafe { (self.inner.as_ptr() as *mut u8).add(offset) as *mut MaybeUninit<T> }
    }

    fn capacity(&self) -> usize {
        self.header().capacity
    }

    fn read_state(&self) -> &AtomicUsize {
        &self.header().read_state
    }

    fn write_state(&self) -> &AtomicUsize {
        &self.header().write_state
    }
}

impl<T> Drop for ProcessLocalBuffer<T> {
    fn drop(&mut self) {
        let layout = Self::layout(self.capacity());
        // SAFETY: block of memory was allocated with `std::alloc::alloc` using the same `layout`.
        unsafe { std::alloc::dealloc(self.inner.as_ptr() as *mut u8, layout) };
    }
}

/// Producing side of the SPSC ring buffer.
pub struct Producer<T, B: Buffer<T> = ProcessLocalBuffer<T>> {
    buffer: ManuallyDrop<B>,
    read_state: usize,
    write_idx: usize,
    marker: PhantomData<T>,
}

impl<T, B: Buffer<T>> Producer<T, B> {
    /// SAFETY: assuming `read_state == write_state == 0`.
    unsafe fn new(buffer: ManuallyDrop<B>) -> Producer<T, B> {
        Producer {
            buffer,
            read_state: 0,
            write_idx: 0,
            marker: PhantomData,
        }
    }

    /// Updates cached state, synchronizing with the consumer.
    ///
    /// Will affect the following methods:
    /// - [`Self::len()`].
    /// - [`Self::is_closed()`].
    /// - [`Self::is_empty()`].
    /// - [`Self::is_full()`].
    ///
    /// If the buffer isn't full but the consumer has been dropped, and the producer isn't yet aware
    /// of that, [`Self::push()`] will still allow pushing new values without any errors. This is done
    /// for performance reasons.
    ///
    /// If you want to make sure this doesn't happen, call [`Self::refresh()`] before [`Self::push()`].
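    ///
    /// # Example
    ///
    /// A sketch of bailing out when the consumer is gone:
    ///
    /// ```ignore
    /// producer.refresh();
    /// if producer.is_closed() {
    ///     return; // the consumer has been dropped, stop producing
    /// }
    /// ```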
    pub fn refresh(&mut self) {
        // Using `Acquire` here to establish a happens-after relationship with `Consumer::pop()` and `Consumer::drop()`.
        self.read_state = self.buffer.read_state().load(Acquire);
    }

    fn read_idx(&self) -> usize {
        self.read_state & INDEX_MASK
    }

    /// Returns `true` if the consumer has been dropped.
    ///
    /// Uses the cached state. To update, call [`Self::refresh()`].
    pub fn is_closed(&self) -> bool {
        self.read_state & CLOSED_MASK == CLOSED_MASK
    }

    /// Returns the number of elements currently stored in the buffer.
    ///
    /// Uses the cached state. To update, call [`Self::refresh()`].
    pub fn len(&self) -> usize {
        self.write_idx.wrapping_sub(self.read_idx()) & INDEX_MASK
    }

    /// Returns the capacity of the buffer.
    pub fn capacity(&self) -> usize {
        self.buffer.capacity()
    }

    fn data_ptr(&self) -> *mut MaybeUninit<T> {
        self.buffer.data_ptr()
    }

    /// Returns `true` if the buffer is empty.
    ///
    /// Uses the cached state. To update, call [`Self::refresh()`].
    pub fn is_empty(&self) -> bool {
        self.read_idx() == self.write_idx
    }

    /// Returns `true` if the buffer is at full capacity.
    ///
    /// Uses the cached state. To update, call [`Self::refresh()`].
    pub fn is_full(&self) -> bool {
        self.len() == self.capacity()
    }

    /// Pushes an element into the ring buffer in a FIFO manner. The consumer will see elements in the same order.
    ///
    /// If the buffer is full, will update the cached state and check again.
    ///
    /// Then, if the buffer is closed, will return [`PushError::Closed`], and if it's full -- [`PushError::Full`].
    ///
    /// Otherwise will store the element inside the ring buffer, returning `Ok(())`.
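    ///
    /// # Example
    ///
    /// A sketch of handling both error cases:
    ///
    /// ```ignore
    /// match producer.push(value) {
    ///     Ok(()) => {}
    ///     Err(PushError::Full(value)) => { /* buffer is full, retry later with `value` */ }
    ///     Err(PushError::Closed(value)) => { /* the consumer is gone, `value` is handed back */ }
    /// }
    /// ```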
    pub fn push(&mut self, value: T) -> Result<(), PushError<T>> {
        if self.is_full() {
            self.refresh();
            if self.is_closed() {
                return Err(PushError::Closed(value));
            } else if self.is_full() {
                return Err(PushError::Full(value));
            }
        }
        // Now we are certain that there is space in the buffer:
        // - We are the sole producer, so no other thread can push more elements here (enforced by &mut).
        // - The consumer can `pop()` some elements, but this would only free space at the start of the buffer,
        //   not affecting the end of the buffer, where we'll be writing.

        // `capacity` is a power of two, using bitwise and instead of modulo.
        let idx = self.write_idx & (self.capacity() - 1);

        // SAFETY: `data_ptr` is valid, `idx` is in bounds
        // The value is assumed to be uninitialized, so we make sure not to read it.
        unsafe { self.data_ptr().add(idx).write(MaybeUninit::new(value)) };

        // Move the write index, marking this slot used, manually handling overflow
        if self.write_idx == INDEX_MASK {
            self.write_idx = 0;
        } else {
            self.write_idx = self.write_idx.wrapping_add(1);
        }

        // Update the write index in the shared buffer, notifying the consumer.
        // Using `Release` ordering to establish a happens-before relationship with `Consumer::refresh()`.
        self.buffer.write_state().store(self.write_idx, Release);

        Ok(())
    }
}

impl<T, B: Buffer<T>> Drop for Producer<T, B> {
    fn drop(&mut self) {
        // It's very important to refresh before setting the closed flag here:
        // - in `Consumer::drop`, `refresh()` (loading `write_state`) happens before closing `read_state`;
        // - in `Producer::drop`, `refresh()` (loading `read_state`) happens before closing `write_state`.
        //
        // If there is a race between `Consumer::drop` and `Producer::drop`, the possible outcomes are:
        //
        // - `Consumer::refresh` sees the updated `write_state`,
        //   `Producer::refresh` doesn't see the updated `read_state`:
        //   `drop_buffer` is called by the consumer.
        //
        // - `Consumer::refresh` doesn't see the updated `write_state`,
        //   `Producer::refresh` sees the updated `read_state`:
        //   `drop_buffer` is called by the producer.
        //
        // - Neither refresh sees the other side's flag (both refresh before either `fetch_or`):
        //   `drop_buffer` is called by neither side, and the buffer is leaked.
        //
        // Both refreshes seeing each other's flag is impossible, because each refresh is ordered
        // before its own `fetch_or`. So `drop_buffer` is called at most once and a double free
        // cannot happen. If we swapped the following two lines, it could be called twice.
        self.refresh();
        // Using `Release` ordering to establish a happens-before relationship with `Consumer::refresh()`.
        self.buffer.write_state().fetch_or(CLOSED_MASK, Release);
        if self.is_closed() {
            let read_idx = self.read_idx();
            let write_idx = self.write_idx;
            // SAFETY: Consumer is also closed, so we can drop the buffer.
            unsafe { drop_buffer(&mut self.buffer, read_idx, write_idx) };
        }
    }
}

/// Error returned from [`Producer::push()`].
#[derive(Debug, Clone, Copy, Eq, PartialEq, thiserror::Error)]
pub enum PushError<T> {
    /// The buffer is full, try again later.
    #[error("full")]
    Full(T),
    /// [`Consumer`] has been dropped.
    #[error("closed")]
    Closed(T),
}

/// Consuming side of the SPSC ring buffer.
pub struct Consumer<T, B: Buffer<T> = ProcessLocalBuffer<T>> {
    buffer: ManuallyDrop<B>,
    read_idx: usize,
    write_state: usize,
    marker: PhantomData<T>,
}

impl<T, B: Buffer<T>> Consumer<T, B> {
    /// SAFETY: assuming `read_state == write_state == 0`.
    unsafe fn new(buffer: ManuallyDrop<B>) -> Consumer<T, B> {
        Consumer {
            buffer,
            read_idx: 0,
            write_state: 0,
            marker: PhantomData,
        }
    }

    /// Updates cached state, synchronizing with the producer.
    ///
    /// Will affect the following methods:
    /// - [`Self::len()`].
    /// - [`Self::is_closed()`].
    /// - [`Self::is_empty()`].
    /// - [`Self::is_full()`].
    ///
    /// There is no need to call this before [`Self::pop()`].
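    ///
    /// # Example
    ///
    /// A sketch of waiting for a batch before draining it (`BATCH_SIZE` is a placeholder,
    /// not defined in this file):
    ///
    /// ```ignore
    /// consumer.refresh();
    /// if consumer.len() >= BATCH_SIZE {
    ///     for _ in 0..BATCH_SIZE {
    ///         let value = consumer.pop().unwrap();
    ///         // process `value`...
    ///     }
    /// }
    /// ```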
    pub fn refresh(&mut self) {
        // Using `Acquire` here to establish a happens-after relationship with `Producer::push()` and `Producer::drop()`.
        self.write_state = self.buffer.write_state().load(Acquire);
    }

    fn write_idx(&self) -> usize {
        self.write_state & INDEX_MASK
    }

    /// Returns `true` if the producer has been dropped.
    ///
    /// Uses the cached state. To update, call [`Self::refresh()`].
    pub fn is_closed(&self) -> bool {
        self.write_state & CLOSED_MASK == CLOSED_MASK
    }

    /// Returns the number of elements currently stored in the buffer.
    ///
    /// Uses the cached state. To update, call [`Self::refresh()`].
    pub fn len(&self) -> usize {
        self.write_idx().wrapping_sub(self.read_idx) & INDEX_MASK
    }

    /// Returns the capacity of the buffer.
    pub fn capacity(&self) -> usize {
        self.buffer.capacity()
    }

    fn data_ptr(&self) -> *mut MaybeUninit<T> {
        self.buffer.data_ptr()
    }

    /// Returns `true` if the buffer is empty.
    ///
    /// Uses the cached state. To update, call [`Self::refresh()`].
    pub fn is_empty(&self) -> bool {
        self.write_idx() == self.read_idx
    }

    /// Returns `true` if the buffer is at full capacity.
    ///
    /// Uses the cached state. To update, call [`Self::refresh()`].
    pub fn is_full(&self) -> bool {
        self.len() == self.capacity()
    }

    /// Pops an element from the buffer in a FIFO manner, i.e. in the same order in which elements have been pushed.
    ///
    /// If the buffer is empty, will update the cached state and check again.
    ///
    /// Then, if the buffer is still empty, will return either [`PopError::Empty`] or [`PopError::Closed`],
    /// depending on whether or not the producer still exists and more values can be pushed.
    ///
    /// If the buffer isn't empty, will pop the first value and return it.
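    ///
    /// # Example
    ///
    /// A sketch of a typical receive loop:
    ///
    /// ```ignore
    /// loop {
    ///     match consumer.pop() {
    ///         Ok(value) => { /* process `value` */ }
    ///         Err(PopError::Empty) => std::thread::yield_now(),
    ///         Err(PopError::Closed) => break, // the producer is gone and the buffer is drained
    ///     }
    /// }
    /// ```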
    pub fn pop(&mut self) -> Result<T, PopError> {
        if self.is_empty() {
            self.refresh();
            if self.is_empty() {
                if self.is_closed() {
                    return Err(PopError::Closed);
                } else {
                    return Err(PopError::Empty);
                }
            }
        }
        // Now we are certain that there's an element in the buffer:
        // - We are the sole consumer, and no other thread can `pop()` here (enforced by &mut).
        // - The producer can `push()`, affecting the end of the buffer, but not the start (we never overwrite elements).

        // `capacity` is a power of two, using bitwise and instead of modulo.
        let idx = self.read_idx & (self.capacity() - 1);

        // SAFETY: `data_ptr` is valid, `idx` is in its bounds, the value has been initialized by the producer.
        let value = unsafe { self.data_ptr().add(idx).read().assume_init() };

        // Move the read index, marking this slot free, manually handling overflow.
        if self.read_idx == INDEX_MASK {
            self.read_idx = 0;
        } else {
            self.read_idx = self.read_idx.wrapping_add(1);
        }

        // Update the read index in the shared buffer, notifying the producer.
        // Using `Release` ordering to establish a happens-before relationship with `Producer::refresh()`.
        self.buffer.read_state().store(self.read_idx, Release);

        Ok(value)
    }
}

impl<T, B: Buffer<T>> Drop for Consumer<T, B> {
    fn drop(&mut self) {
        // It's very important to refresh before setting the closed flag here:
        // - in `Consumer::drop`, `refresh()` (loading `write_state`) happens before closing `read_state`;
        // - in `Producer::drop`, `refresh()` (loading `read_state`) happens before closing `write_state`.
        //
        // If there is a race between `Consumer::drop` and `Producer::drop`, the possible outcomes are:
        //
        // - `Consumer::refresh` sees the updated `write_state`,
        //   `Producer::refresh` doesn't see the updated `read_state`:
        //   `drop_buffer` is called by the consumer.
        //
        // - `Consumer::refresh` doesn't see the updated `write_state`,
        //   `Producer::refresh` sees the updated `read_state`:
        //   `drop_buffer` is called by the producer.
        //
        // - Neither refresh sees the other side's flag (both refresh before either `fetch_or`):
        //   `drop_buffer` is called by neither side, and the buffer is leaked.
        //
        // Both refreshes seeing each other's flag is impossible, because each refresh is ordered
        // before its own `fetch_or`. So `drop_buffer` is called at most once and a double free
        // cannot happen. If we swapped the following two lines, it could be called twice.
        self.refresh();
        // Using `Release` ordering to establish a happens-before relationship with `Producer::refresh()`.
        self.buffer.read_state().fetch_or(CLOSED_MASK, Release);
        if self.is_closed() {
            let read_idx = self.read_idx;
            let write_idx = self.write_idx();
            // SAFETY: Producer is also closed, so we can drop the buffer.
            unsafe { drop_buffer(&mut self.buffer, read_idx, write_idx) };
        }
    }
}

/// SAFETY: both producer and consumer must be closed, and only one of them may call `drop_buffer`.
unsafe fn drop_buffer<T, B: Buffer<T>>(
    buffer: &mut ManuallyDrop<B>,
    mut read_idx: usize,
    write_idx: usize,
) {
    // At this point, both sides of the buffer are closed.
    if std::mem::needs_drop::<T>() {
        while read_idx != write_idx {
            // `capacity` is a power of two, using bitwise and instead of modulo.
            let idx = read_idx & (buffer.capacity() - 1);
            // SAFETY: `data_ptr` is valid, `idx` is in its bounds, the value has been initialized by
            // the producer. The cast to `*mut T` is required: `drop_in_place` on a
            // `*mut MaybeUninit<T>` would be a no-op and leak the element.
            unsafe { std::ptr::drop_in_place(buffer.data_ptr().add(idx) as *mut T) };
            // Wrap the index the same way `push()` and `pop()` do, so the loop terminates
            // even when the indices have wrapped around `INDEX_MASK`.
            if read_idx == INDEX_MASK {
                read_idx = 0;
            } else {
                read_idx = read_idx.wrapping_add(1);
            }
        }
    }
    // SAFETY: buffer can't be accessed afterwards.
    unsafe { ManuallyDrop::drop(buffer) };
}

/// Error returned from [`Consumer::pop()`].
#[derive(Debug, Clone, Copy, Eq, PartialEq, thiserror::Error)]
pub enum PopError {
    /// The buffer is empty, try again later.
    #[error("empty")]
    Empty,
    /// [`Producer`] has been dropped.
    #[error("closed")]
    Closed,
}

#[cfg(test)]
mod tests {
    #[cfg(not(loom))]
    use std::thread;
    #[cfg(loom)]
    use loom::thread;

    use super::*;

    #[test]
    #[cfg(not(loom))]
    fn test_sequential_copy() {
        let (mut producer, mut consumer) = buffer(4);
        assert_eq!(producer.push(0), Ok(()));
        assert_eq!(consumer.pop(), Ok(0));
        assert_eq!(consumer.pop(), Err(PopError::Empty));
        assert_eq!(producer.push(1), Ok(()));
        assert_eq!(producer.push(2), Ok(()));
        assert_eq!(producer.push(3), Ok(()));
        assert_eq!(producer.push(4), Ok(()));
        assert_eq!(producer.push(5), Err(PushError::Full(5)));
        assert_eq!(consumer.pop(), Ok(1));
        assert_eq!(consumer.pop(), Ok(2));
        assert_eq!(consumer.pop(), Ok(3));
        assert_eq!(consumer.pop(), Ok(4));
        assert_eq!(consumer.pop(), Err(PopError::Empty));
        drop(producer);
        assert_eq!(consumer.pop(), Err(PopError::Closed));
    }

    #[test]
    #[cfg(not(loom))]
    fn test_sequential_zst() {
        let (mut producer, mut consumer) = buffer(4);
        assert_eq!(producer.push(()), Ok(()));
        assert_eq!(consumer.pop(), Ok(()));
        assert_eq!(consumer.pop(), Err(PopError::Empty));
        assert_eq!(producer.push(()), Ok(()));
        assert_eq!(producer.push(()), Ok(()));
        assert_eq!(producer.push(()), Ok(()));
        assert_eq!(producer.push(()), Ok(()));
        assert_eq!(producer.push(()), Err(PushError::Full(())));
        assert_eq!(consumer.pop(), Ok(()));
        assert_eq!(consumer.pop(), Ok(()));
        assert_eq!(consumer.pop(), Ok(()));
        assert_eq!(consumer.pop(), Ok(()));
        assert_eq!(consumer.pop(), Err(PopError::Empty));
        drop(producer);
        assert_eq!(consumer.pop(), Err(PopError::Closed));
    }

    #[test]
    #[cfg(not(loom))]
    fn test_sequential_drop() {
        let (mut producer, mut consumer) = buffer::<String>(4);
        assert_eq!(producer.push("0".into()), Ok(()));
        assert_eq!(consumer.pop(), Ok("0".into()));
        assert_eq!(consumer.pop(), Err(PopError::Empty));
        assert_eq!(producer.push("1".into()), Ok(()));
        assert_eq!(producer.push("2".into()), Ok(()));
        assert_eq!(producer.push("3".into()), Ok(()));
        assert_eq!(producer.push("4".into()), Ok(()));
        assert_eq!(producer.push("5".into()), Err(PushError::Full("5".into())));
        assert_eq!(consumer.pop(), Ok("1".into()));
        assert_eq!(consumer.pop(), Ok("2".into()));
        assert_eq!(consumer.pop(), Ok("3".into()));
        assert_eq!(consumer.pop(), Ok("4".into()));
        assert_eq!(consumer.pop(), Err(PopError::Empty));
        drop(producer);
        assert_eq!(consumer.pop(), Err(PopError::Closed));
    }
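
    // Shared body for the std-thread test and the loom model below: one thread pushes
    // all values (spinning on `Full`), the other pops until `Closed`, then the received
    // order is compared with the expected order.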
    fn parallel_drop() {
        let (mut producer, mut consumer) = buffer::<String>(2);

        let expected_vec = vec!["1".to_string(), "2".to_string(), "3".to_string()];
        let vec = expected_vec.clone();

        let t1 = thread::spawn(move || {
            for value in vec {
                let mut value = Some(value);
                loop {
                    match producer.push(value.take().unwrap()) {
                        Ok(()) => break,
                        Err(PushError::Full(v)) => {
                            value = Some(v);
                            thread::yield_now();
                        }
                        Err(PushError::Closed(_)) => panic!("couldn't send all values"),
                    }
                }
            }
        });

        let t2 = thread::spawn(move || {
            let mut vec = Vec::new();
            loop {
                match consumer.pop() {
                    Ok(v) => vec.push(v),
                    Err(PopError::Empty) => {
                        thread::yield_now();
                    }
                    Err(PopError::Closed) => break,
                }
            }
            assert_eq!(vec, expected_vec);
        });

        t1.join().unwrap();
        t2.join().unwrap();
    }

    #[test]
    #[cfg(not(loom))]
    fn test_parallel_drop() {
        parallel_drop();
    }

    #[test]
    #[cfg(loom)]
    fn loom_test_parallel_drop() {
        loom::model(|| parallel_drop());
    }
}