Skip to content

Instantly share code, notes, and snippets.

@Lucretiel
Last active May 23, 2022 18:56
Show Gist options
  • Save Lucretiel/65a5929c4de56613cc501dde8ab540d1 to your computer and use it in GitHub Desktop.
Save Lucretiel/65a5929c4de56613cc501dde8ab540d1 to your computer and use it in GitHub Desktop.
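A paired smart pointer. The value owned by a pair of `Joint` objects is dropped as soon as *either* of the `Joint`s is dropped (or, if the surviving `Joint` still has outstanding `JointLock`s, as soon as the last of those locks is released).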
use std::{
marker::PhantomData,
mem::MaybeUninit,
ops::Deref,
process::abort,
ptr::NonNull,
sync::atomic::{AtomicU32, Ordering},
};
/// Upper bound on the handle count (`Joint`s plus `JointLock`s).
///
/// Reaching a count above this aborts the process. The headroom between this
/// limit and `u32::MAX` lets `JointLock::clone` increment first and check
/// afterwards without the counter wrapping around.
const MAX_REFCOUNT: u32 = u32::MAX / 2;
/// Heap allocation shared by a pair of `Joint`s and the `JointLock`s taken
/// from them.
struct JointContainer<T> {
    /// Number of live handles, doubling as the drop state machine:
    ///
    /// - `0`: the value has been dropped. No new handles should be created;
    ///   only one handle still exists, and when it is dropped the container
    ///   can be freed.
    /// - `1`: the value is being dropped. The dropping handle needs to check
    ///   the state again once it is done.
    /// - `2+`: there are `N` handles in existence. When the count drops to 1,
    ///   the value is dropped.
    count: AtomicU32,
    /// The shared value. Initialized at construction; dropped by the handle
    /// whose release moves the count to 1, and uninitialized after that.
    value: MaybeUninit<T>,
}
/// One half of a pair of handles that jointly own a shared value.
///
/// Created in pairs by `Joint::new`. The shared value is dropped once either
/// handle is dropped (deferred until any outstanding `JointLock`s are
/// released); the allocation is freed once both handles are gone.
#[repr(transparent)]
pub struct Joint<T> {
    // Heap allocation shared with the partner handle and any locks.
    container: NonNull<JointContainer<T>>,
    // Marks that a `Joint` logically owns a `JointContainer<T>` (and so a `T`).
    phantom: PhantomData<JointContainer<T>>,
}
// SAFETY: whichever handle is released last may drop the `T` on its own
// thread (requires `T: Send`), and both handles can hand out `&T` through
// locks at the same time (requires `T: Sync`).
unsafe impl<T> Send for Joint<T> where T: Send + Sync {}
// SAFETY: through `&Joint<T>` another thread can create locks, read `&T`, and
// (by releasing the last lock) drop the `T`, which needs the same bounds.
unsafe impl<T> Sync for Joint<T> where T: Send + Sync {}
impl<T> Joint<T> {
    /// Borrow the shared container.
    ///
    /// The container always exists while this handle does, but its `value`
    /// may already have been dropped; consult `count` before touching it.
    // NOTE(review): this forms a `&JointContainer<T>` that also covers `value`
    // while another handle may be dropping it; projecting to `count` through
    // the raw pointer would sidestep that aliasing question — confirm under Miri.
    fn container(&self) -> &JointContainer<T> {
        // SAFETY: the allocation is only freed by a `Joint`'s `Drop`, at or
        // after this handle's own `Drop`, which cannot run while `&self` is
        // borrowed.
        unsafe { &*self.container.as_ptr() }
    }

    /// Move `value` to the heap and return the two handles that jointly own it.
    pub fn new(value: T) -> (Self, Self) {
        let boxed = Box::new(JointContainer {
            value: MaybeUninit::new(value),
            // One count for each of the two handles returned below.
            count: AtomicU32::new(2),
        });
        let raw = Box::into_raw(boxed);
        let container = NonNull::new(raw).expect("box is definitely non null");
        let handle = || Joint {
            container,
            phantom: PhantomData,
        };
        (handle(), handle())
    }

    /// Try to take a lock on the shared value.
    ///
    /// Returns `None` when the count is 0 or 1, i.e. the value has been
    /// dropped or is being dropped. Aborts the process if the count already
    /// exceeds `MAX_REFCOUNT`.
    pub fn lock(&self) -> Option<JointLock<'_, T>> {
        // Relaxed suffices for taking a new reference: it is derived from this
        // existing handle, and whatever moved that handle to this thread has
        // already provided the needed synchronization.
        //
        // TODO: prevent the distribution of new locks after the other handle
        // has dropped (currently, if this handle has some outstanding locks,
        // it may create more). In general this is not a concern because the
        // typical usage pattern is that each joint only ever makes one lock at
        // a time.
        let acquired = self.container().count.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |handles| match handles {
                0 | 1 => None,
                n if n > MAX_REFCOUNT => abort(),
                n => Some(n + 1),
            },
        );
        acquired.ok().map(|_| JointLock {
            container: self.container,
            lifetime: PhantomData,
        })
    }
}
impl<T> Drop for Joint<T> {
    /// Release this handle.
    ///
    /// Depending on the shared count this either just decrements it, drops the
    /// shared value (when this was the second-to-last handle), or frees the
    /// container (when the value is already gone and this is the last handle).
    fn drop(&mut self) {
        // Memory ordering follows `std::sync::Arc`: every decrement is a
        // Release, so this handle's last use of the container happens-before
        // whichever handle later drops the value or frees the allocation, and
        // every observation that can lead to dropping or freeing is an
        // Acquire, so the other handles' uses are visible first. Acquire alone
        // on the decrement is not enough: the thread that frees the container
        // must be ordered after this thread's own accesses to `count`.
        let mut current = self.container().count.load(Ordering::Acquire);
        loop {
            current = match current {
                // The value has already been dropped and this is the last
                // remaining handle: free the container.
                0 => {
                    // SAFETY: state 0 means no other handle will touch the
                    // container again, and the Acquire load that observed the
                    // 0 synchronizes with the Release write that produced it.
                    drop(unsafe { Box::from_raw(self.container.as_ptr()) });
                    return;
                }
                n => match self.container().count.compare_exchange_weak(
                    n,
                    n - 1,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    // All failures, spurious or otherwise, are retried with the
                    // freshly observed count; n - 1 must always get stored.
                    Err(observed) => observed,
                    // Another handle is in the middle of dropping the value.
                    // We stored a 0, so it will free the container when done.
                    Ok(1) => return,
                    // This was the second-to-last handle: time to drop the
                    // value. Other handles stay away from it while the count
                    // is 1, so nothing needs releasing until the drop is done.
                    Ok(2) => {
                        // SAFETY: the count was 2 until our decrement, so the
                        // value is still initialized, and no other handle
                        // accesses it while the count is 1.
                        unsafe { (*self.container.as_ptr()).value.assume_init_drop() };
                        loop {
                            // Publish the finished drop by moving 1 -> 0 with
                            // Release. If a 0 is already there, the other
                            // handle was dropped meanwhile and freeing is our
                            // job; Acquire on failure orders its last access
                            // to the container before our deallocation.
                            match self.container().count.compare_exchange_weak(
                                1,
                                0,
                                Ordering::Release,
                                Ordering::Acquire,
                            ) {
                                // We stored the 0; the other handle will see it
                                // and free the container.
                                Ok(_) => return,
                                // The other handle was dropped while we were
                                // dropping the value. It is the only other
                                // handle (it cannot create locks at count 1),
                                // so nobody else will observe this 0: free.
                                Err(0) => {
                                    // SAFETY: the other handle is gone and we
                                    // acquired its final write to `count`.
                                    drop(unsafe { Box::from_raw(self.container.as_ptr()) });
                                    return;
                                }
                                // Spurious failure; retry.
                                Err(1) => continue,
                                // From 1 the count can only move to 0.
                                Err(_) => unreachable!(),
                            }
                        }
                    }
                    // Other handles remain; the decrement was all that was
                    // needed.
                    Ok(_) => return,
                },
            }
        }
    }
}
/// A guard that keeps the shared value alive while it exists.
///
/// Obtained from `Joint::lock` and dereferences to the shared value. The
/// lifetime borrows the `Joint` it came from, so that handle cannot be dropped
/// while the lock is outstanding.
#[repr(transparent)]
pub struct JointLock<'a, T> {
    // Same allocation as the parent `Joint`.
    container: NonNull<JointContainer<T>>,
    // Borrows the parent `Joint` for `'a` without storing a reference.
    lifetime: PhantomData<&'a Joint<T>>,
}
// SAFETY: a lock hands out `&T` (requires `T: Sync`), and releasing the last
// lock drops the `T` on the releasing thread (requires `T: Send`).
unsafe impl<T> Send for JointLock<'_, T> where T: Send + Sync {}
// SAFETY: `&JointLock<T>` allows reading `&T` and cloning the lock (whose
// release may drop the `T`), so the same bounds apply.
unsafe impl<T> Sync for JointLock<'_, T> where T: Send + Sync {}
impl<T> JointLock<'_, T> {
    /// Borrow the container this lock points into.
    fn container(&self) -> &JointContainer<T> {
        // SAFETY: the parent `Joint` stays borrowed for the lock's lifetime and
        // the container is only freed from a `Joint`'s `Drop`, so the
        // allocation outlives `self`.
        unsafe { &*self.container.as_ptr() }
    }
}
impl<T> Deref for JointLock<'_, T> {
    type Target = T;

    /// Borrow the shared value.
    fn deref(&self) -> &T {
        let slot = &self.container().value;
        // SAFETY: while this lock exists the count is at least 2 (this lock
        // plus its parent), and the value is only dropped once a release
        // leaves the count at 1, so `slot` is still initialized.
        unsafe { slot.assume_init_ref() }
    }
}
impl<T> Clone for JointLock<'_, T> {
    /// Take another lock on the same value.
    ///
    /// An existing lock already guarantees the value stays alive, so this is a
    /// plain increment. The overflow check runs after the increment; the
    /// headroom above `MAX_REFCOUNT` keeps the counter from wrapping first.
    fn clone(&self) -> Self {
        let previous = self.container().count.fetch_add(1, Ordering::Relaxed);
        if previous <= MAX_REFCOUNT {
            Self {
                container: self.container,
                lifetime: PhantomData,
            }
        } else {
            abort()
        }
    }

    /// Re-point `self` at `source`'s value; nothing to do when both already
    /// share a container.
    fn clone_from(&mut self, source: &Self) {
        if self.container == source.container {
            return;
        }
        *self = source.clone();
    }
}
impl<T> Drop for JointLock<'_, T> {
    /// Release this lock, dropping the shared value if this lock was the last
    /// thing keeping it alive.
    ///
    /// Simpler than `Joint`'s drop: the parent `Joint` is borrowed for the
    /// lock's lifetime, so it is still alive and cannot start dropping until
    /// this function returns.
    fn drop(&mut self) {
        // AcqRel: acquire other handles' uses of the value before we might
        // drop it, and release our own uses so whoever drops it later sees
        // them.
        match self.container().count.fetch_sub(1, Ordering::AcqRel) {
            // The count must be at least 2: one for us and one for our parent.
            0 | 1 => unreachable!(),
            // Only this lock and its parent remained, so the partner `Joint`
            // is already gone and this was the last lock. Our decrement left
            // the count at 1 ("being dropped"), so further `lock()` calls on
            // the parent now fail.
            2 => {
                // SAFETY: the count was 2 until our decrement, so the value is
                // still initialized, and no other handle touches it at 1.
                unsafe { (*self.container.as_ptr()).value.assume_init_drop() };
                // Publish "value dropped" (state 0). Without this the count
                // would stay at 1: the parent's `Drop` would then store 0 and
                // return, expecting some other dropper to free the container,
                // and the allocation would leak. Nothing else can change the
                // count while it is 1 (the parent is borrowed and its `lock()`
                // refuses at 1, and no other lock exists), so a plain store is
                // enough. Release pairs with the Acquire load in the parent's
                // `Drop` that observes the 0 before freeing the container.
                self.container().count.store(0, Ordering::Release);
            }
            // Other handles remain, so the value is still alive.
            _ => {}
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment