Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Upgradable Reference Counting
use std::{
marker::PhantomData,
ops::Deref,
ptr::NonNull,
sync::atomic::{AtomicI32, Ordering},
};
/// The shared heap allocation behind every `Rc`/`Arc` handle for one value.
///
/// `strong` stores the total number of live handles (the handle that created the
/// allocation included) in one of two encodings, selected by its sign:
///
/// * negative — non-atomic encoding: the count is `strong - i32::MIN`, so one handle
///   is stored as `i32::MIN + 1` (what `Inner::new` writes). `Rc` updates it with a
///   plain load followed by a store.
/// * zero or positive — atomic encoding: the count is `strong` itself, and it is
///   updated with atomic read-modify-write operations.
///
/// `Inner::upgrade` converts a count from the first encoding to the second.
struct Inner<T: ?Sized> {
    strong: AtomicI32,
    // Must remain the last field: `T` may be unsized.
    value: T,
}
impl<T> Inner<T> {
#[inline]
fn new(value: T) -> Self {
Self {
strong: AtomicI32::new(i32::MIN + 1),
value,
}
}
}
impl<T: ?Sized> Inner<T> {
    /// Upper bound on live handles per allocation. Staying far below `i32::MAX` in either
    /// encoding means a burst of concurrent `fetch_add`s cannot wrap the count into the
    /// other encoding before the process aborts.
    const MAX_HANDLES: i32 = 1 << 30;

    /// Decodes a raw `strong` value (either encoding) into a number of live handles.
    #[inline]
    fn handles(strong: i32) -> i32 {
        if strong < 0 {
            strong.wrapping_sub(i32::MIN)
        } else {
            strong
        }
    }

    /// Aborts the process. A wrapped count could free the value while handles remain,
    /// so, like `std::rc::Rc` and `std::sync::Arc`, overflow is treated as fatal.
    #[cold]
    #[inline(never)]
    fn refcount_overflow() -> ! {
        std::process::abort()
    }

    /// Adds one handle on the thread that owns the `Rc` handles (used by `Rc::clone`
    /// and `Arc::from_rc`).
    ///
    /// Uses a plain load + store while the count is non-atomic (negative) and an atomic
    /// `fetch_add` once it has been upgraded.
    #[inline]
    fn incr_strong(&self) {
        let strong = self.strong.load(Ordering::Relaxed);
        if strong < 0 {
            if Self::handles(strong) >= Self::MAX_HANDLES {
                Self::refcount_overflow();
            }
            // Cannot overflow: the check above keeps `strong` well below `-1`.
            self.strong.store(strong + 1, Ordering::Relaxed);
        } else {
            self.incr_strong_atomic();
        }
    }

    /// Removes one handle on the thread that owns the `Rc` handles.
    ///
    /// Returns `true` while other handles remain, and `false` when the released handle
    /// was the last one, meaning the caller must free the allocation.
    #[inline]
    fn decr_strong(&self) -> bool {
        let strong = self.strong.load(Ordering::Relaxed);
        if strong < 0 {
            self.strong.store(strong.wrapping_sub(1), Ordering::Release);
            Self::handles(strong) != 1
        } else {
            self.decr_strong_atomic()
        }
    }

    /// Adds one handle with an atomic `fetch_add`.
    #[inline]
    fn incr_strong_atomic(&self) {
        // Relaxed is enough: a new handle is always made from an existing one, so the
        // allocation is already known to be alive.
        let old = self.strong.fetch_add(1, Ordering::Relaxed);
        if Self::handles(old) >= Self::MAX_HANDLES {
            Self::refcount_overflow();
        }
    }

    /// Removes one handle with an atomic `fetch_sub`; returns `false` for the last handle.
    ///
    /// The last handle is recognised in either encoding. A count created by `Inner::new`
    /// that was never upgraded is still stored relative to `i32::MIN`, so its last
    /// handle reads back as `i32::MIN + 1` rather than `1`.
    #[inline]
    fn decr_strong_atomic(&self) -> bool {
        // Release pairs with the Acquire fence run in `drop_slow` before the value is freed.
        let old = self.strong.fetch_sub(1, Ordering::Release);
        Self::handles(old) != 1
    }

    /// Converts a non-atomic (negative) count into the same number of handles in the
    /// atomic encoding; a count that is already atomic is left unchanged.
    ///
    /// Callers must ensure no other thread can access the count while it is still
    /// non-atomic. `Arc::from_rc` calls it through an `&Rc` before creating the `Arc`.
    #[inline]
    fn upgrade(&self) {
        let strong = self.strong.load(Ordering::Relaxed);
        if strong < 0 {
            self.strong.store(Self::handles(strong), Ordering::Relaxed);
        }
    }
}
/// Single-threaded reference-counted pointer whose count can later be switched to the
/// atomic encoding, so the same allocation can also be shared through `Arc`.
///
/// Holding a raw `NonNull` makes `Rc` neither `Send` nor `Sync`.
pub struct Rc<T: ?Sized> {
    // Tells the drop checker that an `Rc<T>` owns an `Inner<T>`.
    phantom: PhantomData<Inner<T>>,
    ptr: NonNull<Inner<T>>,
}
impl<T> Rc<T> {
pub fn new(value: T) -> Self {
Self::from_inner(Box::leak(Box::new(Inner::new(value))).into())
}
}
impl<T: Default> Default for Rc<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T: ?Sized> Rc<T> {
    /// Returns the number of live handles (`Rc` and `Arc` together) for this allocation.
    ///
    /// Once the allocation is shared with an `Arc`, other threads may change the count
    /// concurrently, so the result is only a snapshot.
    #[inline]
    pub fn strong_count(&self) -> i32 {
        let strong = self.inner().strong.load(Ordering::Relaxed);
        if strong < 0 {
            // Non-atomic encoding: the count is stored relative to `i32::MIN`.
            strong.wrapping_add(i32::MIN)
        } else {
            strong
        }
    }

    /// Returns `true` if this is the only handle (of either kind) to the allocation.
    #[inline]
    pub fn is_unique(&mut self) -> bool {
        // Acquire pairs with the Release decrement of an `Arc` dropped on another thread,
        // so that thread's last accesses to the value happen-before whatever the caller
        // does next (e.g. mutating through `get_mut`). Relaxed would allow a data race.
        let strong = self.inner().strong.load(Ordering::Acquire);
        strong == 1 || strong == i32::MIN + 1
    }

    /// Returns `true` if both handles point at the same allocation.
    #[inline]
    pub fn ptr_eq(&self, other: &Self) -> bool {
        self.ptr.as_ptr() == other.ptr.as_ptr()
    }

    /// Returns a mutable reference to the value if this is the only handle, else `None`.
    #[inline]
    pub fn get_mut(&mut self) -> Option<&mut T> {
        if self.is_unique() {
            // SAFETY: `is_unique` observed a count of one, and that one handle is `self`,
            // which is borrowed mutably; no other `Rc` or `Arc` can reach the value while
            // the returned reference is alive.
            Some(unsafe { self.get_mut_unchecked() })
        } else {
            None
        }
    }

    /// Returns a mutable reference to the value without checking the count.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that no other reference to the value, through this or
    /// any other `Rc`/`Arc` handle on any thread, is used while the returned reference
    /// is alive.
    #[inline]
    pub unsafe fn get_mut_unchecked(&mut self) -> &mut T {
        // Project to `value` through the raw pointer so no `&mut` covers `strong`,
        // which other handles may still be updating.
        &mut (*self.ptr.as_ptr()).value
    }

    /// Wraps a pointer to an allocation whose count already includes this new handle.
    #[inline]
    fn from_inner(ptr: NonNull<Inner<T>>) -> Self {
        Self {
            ptr,
            phantom: PhantomData,
        }
    }

    /// Borrows the shared allocation.
    #[inline]
    fn inner(&self) -> &Inner<T> {
        // SAFETY: the allocation stays alive while `self` holds one of its counted handles.
        unsafe { self.ptr.as_ref() }
    }

    /// Frees the allocation; called from `drop` once the last handle has been released.
    #[cold]
    #[inline(never)]
    fn drop_slow(&self) {
        // Pairs with the Release decrements of other handles so all their accesses to the
        // value happen-before it is dropped.
        std::sync::atomic::fence(Ordering::Acquire);
        // SAFETY: the pointer came from `Box::leak` and no handle remains to use it.
        unsafe {
            drop(Box::from_raw(self.ptr.as_ptr()));
        }
    }
}
impl<T: ?Sized> Clone for Rc<T> {
fn clone(&self) -> Self {
self.inner().incr_strong();
Self::from_inner(self.inner().into())
}
}
impl<T: ?Sized> Drop for Rc<T> {
    /// Releases this handle and frees the allocation if it was the last one.
    fn drop(&mut self) {
        let others_remain = self.inner().decr_strong();
        if others_remain {
            return;
        }
        self.drop_slow();
    }
}
impl<T: ?Sized> Deref for Rc<T> {
type Target = T;
// Inner is valid whilever we have a valid Rc.
fn deref(&self) -> &Self::Target {
&self.inner().value
}
}
/// Thread-safe handle to a shared value, created with `Arc::new` or, sharing an existing
/// `Rc`'s allocation, with `Arc::from_rc`.
pub struct Arc<T: ?Sized> {
    // Tells the drop checker that an `Arc<T>` owns an `Inner<T>`.
    phantom: PhantomData<Inner<T>>,
    ptr: NonNull<Inner<T>>,
}

// SAFETY: `Arc` handles only change the count with atomic read-modify-write operations.
// Moving one to another thread is sound when the value may be shared (`Sync`) and may be
// dropped on whichever thread releases the last handle (`Send`); these are the same
// bounds `std::sync::Arc` uses.
unsafe impl<T> Send for Arc<T> where T: ?Sized + Sync + Send {}
// SAFETY: through `&Arc<T>` one can only clone (an atomic increment) or obtain `&T`,
// which is sound under the same bounds.
unsafe impl<T> Sync for Arc<T> where T: ?Sized + Sync + Send {}
impl<T> Arc<T> {
pub fn new(value: T) -> Self {
Self::from_inner(Box::leak(Box::new(Inner::new(value))).into())
}
}
impl<T: Default> Default for Arc<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T: ?Sized> Arc<T> {
pub fn from_rc(rc: &Rc<T>) -> Self {
let inner = rc.inner();
inner.upgrade();
inner.incr_strong();
Self::from_inner(inner.into())
}
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
self.ptr.as_ptr() == other.ptr.as_ptr()
}
#[inline]
pub fn strong_count(&self) -> i32 {
let strong = self.inner().strong.load(Ordering::Relaxed);
if strong < 0 {
strong.wrapping_add(i32::MIN)
} else {
strong
}
}
#[inline]
pub fn is_unique(&self) -> bool {
let strong = self.inner().strong.load(Ordering::Acquire);
strong == 1 || strong == i32::MIN + 1
}
pub fn get_mut(&mut self) -> Option<&mut T> {
if self.is_unique() {
// This unsafety is ok because we're guaranteed that the pointer
// returned is the *only* pointer that will ever be returned to T. Our
// reference count is guaranteed to be 1 at this point, and we required
// the Arc itself to be `mut`, so we're returning the only possible
// reference to the inner data.
Some(unsafe { self.get_mut_unchecked() })
} else {
None
}
}
pub unsafe fn get_mut_unchecked(&mut self) -> &mut T {
// We are careful to *not* create a reference covering the "count" fields, as
// this would alias with concurrent access to the reference counts.
&mut (*self.ptr.as_ptr()).value
}
fn from_inner(ptr: NonNull<Inner<T>>) -> Self {
Self {
ptr,
phantom: PhantomData,
}
}
#[inline]
fn inner(&self) -> &Inner<T> {
unsafe { self.ptr.as_ref() }
}
#[cold]
#[inline(never)]
fn drop_slow(&self) {
std::sync::atomic::fence(Ordering::Acquire);
unsafe {
drop(Box::from_raw(self.ptr.as_ptr()));
}
}
}
impl<T: ?Sized> Clone for Arc<T> {
fn clone(&self) -> Self {
self.inner().incr_strong_atomic();
Self::from_inner(self.inner().into())
}
}
impl<T: ?Sized> Drop for Arc<T> {
    /// Releases this handle; whichever handle is released last frees the allocation.
    fn drop(&mut self) {
        let was_last = !self.inner().decr_strong_atomic();
        if was_last {
            self.drop_slow();
        }
    }
}
impl<T: ?Sized> Deref for Arc<T> {
type Target = T;
// Inner is value whilever we have a valid Arc.
fn deref(&self) -> &Self::Target {
&self.inner().value
}
}
#[cfg(test)]
mod tests {
    // `super::*` resolves wherever this file sits in the crate; `crate::*` only works
    // when the file is the crate root.
    use super::*;
    use std::cell::Cell;

    /// Increments the shared counter when dropped, so tests can observe when the value is freed.
    struct DropCounter<'a>(&'a Cell<usize>);

    impl Drop for DropCounter<'_> {
        fn drop(&mut self) {
            self.0.set(self.0.get() + 1);
        }
    }

    #[test]
    fn rc_double_upgrade() {
        let rc1 = Rc::new(());
        assert_eq!(rc1.strong_count(), 1);
        let _rc2 = rc1.clone();
        assert_eq!(rc1.strong_count(), 2);
        let _arc1 = Arc::from_rc(&rc1);
        assert_eq!(rc1.strong_count(), 3);
        let _arc2 = Arc::from_rc(&rc1);
        assert_eq!(rc1.strong_count(), 4);
    }

    #[test]
    fn rc_drops_value_after_last_clone() {
        let drops = Cell::new(0);
        let rc1 = Rc::new(DropCounter(&drops));
        let rc2 = rc1.clone();
        drop(rc1);
        assert_eq!(drops.get(), 0);
        drop(rc2);
        assert_eq!(drops.get(), 1);
    }

    #[test]
    fn upgraded_value_dropped_once_by_last_handle() {
        let drops = Cell::new(0);
        let rc = Rc::new(DropCounter(&drops));
        let arc1 = Arc::from_rc(&rc);
        let arc2 = arc1.clone();
        assert_eq!(arc2.strong_count(), 3);
        drop(rc);
        drop(arc1);
        assert_eq!(drops.get(), 0);
        drop(arc2);
        assert_eq!(drops.get(), 1);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment