Skip to content

Instantly share code, notes, and snippets.

@rlee287
Last active December 21, 2023 01:21
Show Gist options
  • Save rlee287/68cfd829ade065e88171fd32de5eb7cb to your computer and use it in GitHub Desktop.
Proposed OnceStorage for `oncecell::race`
use core::sync::atomic::{Ordering, AtomicU32};
use core::mem::MaybeUninit;
use core::cell::UnsafeCell;
use core::hint::spin_loop;
use core::convert::Infallible;
/// A thread-safe cell which can only be written to once.
pub struct OnceStorage<T> {
    /// The actual storage of the stored object
    data_holder: UnsafeCell<MaybeUninit<T>>,
    /// Tracks whether the OnceStorage has been initialized
    /// 0 -> no
    /// 1 -> write in progress (because writing to data_holder is not atomic)
    /// 2 -> init
    /// This value can only ever increase in increments of 1
    is_init: AtomicU32,
    /// Helper counter to assert that the critical section is, in fact, only
    /// entered by one thread at a time.
    // Gated on `debug_assertions` (enabled by cargo's dev profile) rather than
    // the original `debug` cfg, which no standard tooling ever sets; the
    // `doc(hidden)` attribute was dropped because it has no effect on a
    // private field.
    #[cfg(debug_assertions)]
    critical_section_ctr: AtomicU32
}
impl<T> OnceStorage<T> {
    /// Creates a new empty cell.
    #[inline]
    pub const fn new() -> Self {
        Self {
            data_holder: UnsafeCell::new(MaybeUninit::uninit()),
            is_init: AtomicU32::new(0),
            // This initializer was missing in the original, so any build that
            // actually enabled the debug counter failed to compile.
            #[cfg(debug_assertions)]
            critical_section_ctr: AtomicU32::new(0)
        }
    }
    /// Returns a reference to the stored value without checking the state.
    ///
    /// SAFETY: The caller must guarantee that the cell is fully initialized
    /// (state 2 was observed with `Acquire` ordering) and that no thread can
    /// concurrently write to `data_holder`.
    unsafe fn get_data(&self) -> &T {
        // SAFETY: forwarded to the caller's contract above
        unsafe {
            let mut_ptr = self.data_holder.get();
            (&*mut_ptr as &MaybeUninit<T>).assume_init_ref()
        }
    }
    /// Gets a reference to the underlying value.
    ///
    /// Returns `None` if the cell is empty or a write is still in progress.
    pub fn get(&self) -> Option<&T> {
        let state_snapshot = self.is_init.load(Ordering::Acquire);
        if state_snapshot == 2 {
            #[cfg(debug_assertions)]
            assert_eq!(self.critical_section_ctr.load(Ordering::SeqCst), 0);
            // SAFETY: 2 -> value is init and nobody is trying to change it
            unsafe { Some(self.get_data()) }
        } else {
            debug_assert!(state_snapshot <= 1);
            None
        }
    }
    /// Forcibly sets the value of the cell and returns a mutable reference
    /// to the new value.
    ///
    /// SAFETY: The internal state must be set to the intermediate state (1)
    /// before this is called, so that no other thread touches `data_holder`.
    /// If the internal value was already set, then this function overwrites
    /// it without dropping it.
    unsafe fn force_set(&self, value: T) -> &mut T {
        #[cfg(debug_assertions)]
        assert_eq!(self.critical_section_ctr.fetch_add(1, Ordering::SeqCst), 0);
        // SAFETY: the caller guarantees we have exclusive access to the cell
        let value_ref: &mut T = unsafe {
            let mut_ptr = self.data_holder.get();
            (&mut *mut_ptr as &mut MaybeUninit<T>).write(value)
        };
        // `fetch_sub` returns the value *before* the subtraction, which is 1
        // here because of the increment above. The original asserted 0, which
        // would have panicked on every call once the counter was compiled in.
        #[cfg(debug_assertions)]
        assert_eq!(self.critical_section_ctr.fetch_sub(1, Ordering::SeqCst), 1);
        value_ref
    }
    /// Sets the contents of this cell to `value`.
    ///
    /// Returns `Ok(())` if the cell was empty and `Err(value)` if it was full
    /// (or if another thread's write was in progress).
    pub fn set(&self, value: T) -> Result<(), T> {
        // Indicate that we are now trying to set the value.
        // If someone else is also trying, back off and let them go through.
        // On success we wish to release the new value, already knowing it.
        // On failure we don't need an ordering, as the failure already forms
        // a happens-before relationship between their set and our check.
        if self.is_init.compare_exchange(0, 1, Ordering::Release, Ordering::Relaxed).is_err() {
            return Err(value);
        }
        // SAFETY: state==1 -> nobody else is touching the UnsafeCell -> we
        // can safely obtain &mut
        unsafe {
            self.force_set(value);
        }
        // Indicate that we have successfully written the value
        if self.is_init.swap(2, Ordering::AcqRel) != 1 {
            unreachable!("Concurrent modification to self.data_holder despite state signalling")
        }
        Ok(())
    }
    /// Gets the contents of the cell, initializing it with `f` if the cell was
    /// empty.
    ///
    /// If several threads concurrently run `get_or_init`, more than one `f` can
    /// be called. However, all threads will return the same value, produced by
    /// some `f`. If this instance of `f` finishes while another instance is
    /// writing its value, then this function will spinloop until that instance
    /// finishes writing the new value before returning a reference to the
    /// initialized value.
    pub fn get_or_init_spin<F>(&self, f: F) -> &T
    where
        F: FnOnce() -> T,
    {
        // Delegate to the fallible version with an error that cannot occur.
        let fn_wrap = || Ok::<T, Infallible>(f());
        self.get_or_try_init_spin(fn_wrap).unwrap()
    }
    /// Gets the contents of the cell, initializing it with `f` if
    /// the cell was empty. If the cell was empty and `f` failed, an
    /// error is returned.
    ///
    /// If several threads concurrently run `get_or_init`, more than one `f` can
    /// be called. However, all threads will return the same value, produced by
    /// some `f`. If this instance of `f` finishes while another instance is
    /// writing its value, then this function will spinloop until that instance
    /// finishes writing the new value before returning a reference to the
    /// initialized value.
    pub fn get_or_try_init_spin<F, E>(&self, f: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>
    {
        let mut state_snapshot = self.is_init.load(Ordering::Acquire);
        if state_snapshot == 0 {
            // Run `f` eagerly; if it fails we never touch the state machine.
            let f_value = f()?;
            // Indicate that we are now trying to set the value.
            // If someone else is also trying, break out and wait for the
            // other write to go through.
            // On success we wish to release the new value, without needing to
            // acquire it again.
            // On failure we need to acquire the actual value and loop again.
            match self.is_init.compare_exchange(0, 1, Ordering::Release, Ordering::Acquire) {
                Ok(_) => {
                    // SAFETY: state==1 -> nobody else is touching the
                    // UnsafeCell -> we can safely obtain &mut
                    let new_ref = unsafe { self.force_set(f_value) };
                    // Indicate that we have successfully written the value
                    if self.is_init.swap(2, Ordering::AcqRel) != 1 {
                        unreachable!("Concurrent modification to self.data_holder despite state signalling")
                    }
                    return Ok(new_ref as &T);
                },
                Err(new_state) => {
                    // Someone else won the race; fall through and wait for
                    // their write to finish. Our `f_value` is dropped here.
                    state_snapshot = new_state;
                    debug_assert!(state_snapshot == 1 || state_snapshot == 2);
                }
            }
        }
        while state_snapshot == 1 {
            // 1 -> someone else is currently writing
            // Writes (should be) fast so we won't be spinning for long
            state_snapshot = self.is_init.load(Ordering::Acquire);
            spin_loop();
        }
        debug_assert_eq!(state_snapshot, 2);
        // SAFETY: 2 -> value is init and nobody can change it anymore
        unsafe { Ok(self.get_data()) }
    }
    /// ``` compile_fail
    /// # use tiva_c_secure_bootloader_rustlib::OnceStorage;
    /// #
    /// // Ensure that OnceStorage<T> is invariant over T lifetime subtypes
    /// let heap_object = std::vec::Vec::from([1,2,3,4]);
    /// let once_storage = OnceStorage::new();
    /// once_storage.set(&heap_object).unwrap();
    /// drop(heap_object);
    /// // The stored reference is no longer live because vec is dropped
    /// // The following line should fail to compile
    /// let _ref = once_storage.get();
    /// ```
    fn _dummy() {}
}
impl<T> Drop for OnceStorage<T> {
fn drop(&mut self) {
let state = self.is_init.load(Ordering::Acquire);
// &mut self -> nobody else can try to init -> value can't be 1
// If we somehow do, then we leak the set value, which is safer than
// incorrectly freeing it
debug_assert_ne!(state, 1);
if state == 2 {
unsafe {
let mut_ptr = self.data_holder.get();
(&mut *mut_ptr as &mut MaybeUninit<T>).assume_init_drop();
}
}
}
}
// SAFETY: all interior mutation is coordinated through the atomic `is_init`
// state machine. `T: Send` is required because `set(&self, T)` lets any
// thread move a `T` into the cell; `T: Sync` is required because `get` hands
// out `&T` to multiple threads simultaneously.
unsafe impl<T: Send+Sync> Sync for OnceStorage<T> {}
#[cfg(test)]
mod tests {
    use super::*;
    extern crate std;

    /// Smoke test: a stored value can be read back out intact.
    #[test]
    fn test_should_compile_static() {
        let payload = std::vec::Vec::from([1, 2, 3, 4]);
        let storage = OnceStorage::new();
        storage.set(payload).unwrap();
        let retrieved = storage.get();
        assert_eq!(retrieved.unwrap(), &[1, 2, 3, 4]);
        drop(storage);
    }

    /// Hammer the cell from many threads and check exactly one `set` wins.
    #[test]
    fn test_init_only_once() {
        use std::sync::Barrier;
        use std::sync::atomic::AtomicU32;
        const THREAD_COUNT: usize = 20;

        let successful_sets = AtomicU32::new(0);
        let start_line = Barrier::new(THREAD_COUNT + 1);
        let storage = OnceStorage::new();
        std::thread::scope(|scope| {
            // Start the threads...
            for _ in 0..THREAD_COUNT {
                scope.spawn(|| {
                    start_line.wait();
                    let value = std::vec::Vec::from([std::string::String::from("abcd")]);
                    if storage.set(value).is_ok() {
                        successful_sets.fetch_add(1, Ordering::Relaxed);
                    }
                });
            }
            // ...and let them hammer the OnceStorage
            start_line.wait();
        });
        // Ensure that writes to successful_sets are now visible
        std::sync::atomic::fence(Ordering::Acquire);
        // Check that the object was only initialized once
        assert_eq!(successful_sets.load(Ordering::Acquire), 1);
        // Now read from the vec so that Miri can catch invalid accesses
        assert_eq!(storage.get().unwrap().len(), 1);
        assert_eq!(storage.get().unwrap()[0], "abcd");
    }

    /// A non-'static borrow may be stored as long as the cell dies first.
    #[test]
    fn test_should_compile_nonstatic() {
        let backing = std::vec::Vec::from([1, 2, 3, 4]);
        let storage = OnceStorage::new();
        storage.set(&backing).unwrap();
        let _borrow = storage.get();
        drop(storage);
        drop(backing);
    }
}
@jyn514
Copy link

jyn514 commented Dec 21, 2023

diff --git a/once_storage.rs b/once_storage.rs
index d501e4a..94d2d77 100644
--- a/once_storage.rs
+++ b/once_storage.rs
@@ -14,8 +14,9 @@ pub struct OnceStorage<T> {
     /// 2 -> init
     /// This value can only ever increase in increments of 1
     is_init: AtomicU32,
-    #[cfg(debug)]
-    #[cfg_attr(debug, doc(hidden))]
+    // and elsewhere in this file; in practice no one is ever going to use `--cfg debug` and this approach will integrate nicely with cargo's debug profile
+    #[cfg(debug_assertions)]
+    // doc(hidden) doesn't do anything for private fields; and with `--document-private-items` you probably *do* want to document it
     /// Helper counter to assert that the critical section is, in fact, only entered by one thread at a time
     critical_section_ctr: AtomicU32
 }
@@ -25,7 +26,9 @@ impl<T> OnceStorage<T> {
     pub const fn new() -> Self {
         Self {
             data_holder: UnsafeCell::new(MaybeUninit::uninit()),
-            is_init: AtomicU32::new(0)
+            is_init: AtomicU32::new(0),
+            #[cfg(debug_assertions)]
+            critical_section_ctr: AtomicU32::new(0),
         }
     }
     /// Gets a reference to the underlying value.
@@ -34,6 +37,8 @@ impl<T> OnceStorage<T> {
         if state_snapshot == 2 {
             #[cfg(debug)]
             assert_eq!(self.critical_section_ctr.load(Ordering::SeqCst), 0);
+            // you have these two lines copied a lot in this code. they are rather subtle unsafe code. 
+            // i would suggest adding an `unsafe fn get_data(&self) -> &T` helper.
             // SAFETY: 2 -> value is init and nobody is trying to change it
             unsafe {
                 let mut_ptr = self.data_holder.get();
@@ -53,17 +58,18 @@ impl<T> OnceStorage<T> {
     unsafe fn force_set(&self, value: T) -> &mut T {
         #[cfg(debug)]
         assert_eq!(self.critical_section_ctr.fetch_add(1, Ordering::SeqCst), 0);
-        let value_ref: &mut T;
-        unsafe {
+        // nit:
+        // (this will also let you refactor this unsafe block into an `unsafe fn` if desired)
+        let value_ref: &mut T = unsafe {
             let mut_ptr = self.data_holder.get();
-            value_ref = (&mut *mut_ptr as &mut MaybeUninit<T>).write(value);
-        }
+            (&mut *mut_ptr as &mut MaybeUninit<T>).write(value)
+        };
         #[cfg(debug)]
         assert_eq!(self.critical_section_ctr.fetch_sub(1, Ordering::SeqCst), 0);
         value_ref
     }
 
-    /// Sets the contents of this cell to value.
+    /// Sets the contents of this cell to `value`.
     ///
     /// Returns `Ok(())` if the cell was empty and `Err(value)` if it was full.
     pub fn set(&self, value: T) -> Result<(), T> {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment