Skip to content

Instantly share code, notes, and snippets.

@pstephens
Forked from rust-play/playground.rs
Last active February 15, 2023 17:16
Show Gist options
  • Save pstephens/f6ff4ac333fa95e53b69c85941504f15 to your computer and use it in GitHub Desktop.
Save pstephens/f6ff4ac333fa95e53b69c85941504f15 to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
//! Proof of concept for a reference counted [str].
//! The primary innovation over a plain Rc<str> is that there is one less level
//! of indirection and heap allocation.
#![feature(const_alloc_layout)]
use core::ptr::NonNull;
use std::alloc::Layout;
use std::alloc::{GlobalAlloc, System};
use std::fmt::Formatter;
use std::fmt::Error;
/// Header fields for a reference counted [str].
///
/// The struct itself declares no fields; it only serves as the pointee type of
/// the `NonNull` pointers that refer to the heap block. The [str] payload is
/// not included so as to avoid the fat pointer in the handle and instead use
/// the length from the header.
struct ArcStrInner {
// Layout and offsets are computed manually (see compute_offsets and
// compute_base_layout). Conceptually the block holds, in this order:
// ref_cnt: usize,        (the RefCnt alias)
// str_len: u16,          (the StrLen alias)
// buff: [u8; str_len]    (the UTF-8 bytes of the string)
}
/// Integer type of the reference count stored at the very start of the block.
type RefCnt = usize;
/// Integer type of the stored string length.
/// Being a u16, it only supports string sizes of < 65536 utf8 bytes.
type StrLen = u16;
/// Byte offsets of the three parts of the block, as `(ref_cnt, str_len, buff)`,
/// evaluated at compile time by `compute_offsets`.
///
/// NIGHTLY!: uses the const API for Layout (the `const_alloc_layout` feature
/// enabled at the top of the file) for performance.
/// This could be reworked into lazy_static so it builds on a stable toolchain,
/// or we could implement our own const layout API on top of size_of* and
/// align_of*.
const OFFSETS: (usize, usize, usize) = compute_offsets();
/// Byte offset of the `StrLen` length field from the start of the block.
const OFFSET_STR_LEN: usize = OFFSETS.1;
/// Byte offset of the first string byte from the start of the block.
const OFFSET_BUFF: usize = OFFSETS.2;
/// Layout of the header fields only (ref count and length). It does not
/// include the str buff; the layout helpers extend it with the payload size.
const BASE_LAYOUT: Layout = compute_base_layout();
/// Computes the byte offsets of the block's parts as `(ref_cnt, str_len, buff)`.
///
/// The reference count always sits at offset 0; the other two offsets come
/// from extending the layout one field at a time.
///
/// # Panics
/// Panics (at compile time when used in a const) with "Invalid layout" if a
/// layout extension fails.
const fn compute_offsets() -> (usize, usize, usize) {
    // Header so far: the reference count, followed by the length field.
    let counted = Layout::new::<RefCnt>();
    let Ok((with_len, len_at)) = counted.extend(Layout::new::<StrLen>()) else {
        panic!("Invalid layout")
    };
    // The empty string stands in for the payload: only where it starts matters.
    let Ok((_, bytes_at)) = with_len.extend(Layout::for_value("")) else {
        panic!("Invalid layout")
    };
    (0, len_at, bytes_at)
}
/// Reads the string length stored in the header of the block at `inner`.
///
/// # Safety
/// `inner` must point to a live block whose length field has been written.
#[inline(always)]
unsafe fn read_str_len(inner: NonNull<ArcStrInner>) -> StrLen {
    // Step to the length field in bytes, then reinterpret it as `StrLen`.
    let base = inner.as_ptr().cast::<u8>();
    base.add(OFFSET_STR_LEN).cast::<StrLen>().read()
}
/// Stores `len` into the length field of the block at `inner`.
///
/// # Safety
/// `inner` must point to a block large enough to contain the header.
#[inline(always)]
unsafe fn write_str_len(inner: NonNull<ArcStrInner>, len: StrLen) {
    // Step to the length field in bytes, then reinterpret it as `StrLen`.
    let field = inner.as_ptr().cast::<u8>().add(OFFSET_STR_LEN).cast::<StrLen>();
    field.write(len);
}
/// Increments the reference count stored at the start of the block.
///
/// The count is updated with a plain (non-atomic) read and write.
///
/// If the count is already `RefCnt::MAX` the process is aborted. Letting the
/// count wrap to zero (which a plain `+= 1` does in release builds) would
/// allow the block to be freed while handles still point at it.
///
/// # Safety
/// `inner` must point to a live block whose reference count has been written.
#[inline(always)]
unsafe fn incr_ref_cnt(inner: NonNull<ArcStrInner>) {
    let ref_cnt: *mut RefCnt = inner.cast::<RefCnt>().as_ptr();
    match (*ref_cnt).checked_add(1) {
        Some(next) => *ref_cnt = next,
        // Overflow can only happen if handles were leaked; there is no safe
        // way to continue, so stop the process instead of wrapping.
        None => std::process::abort(),
    }
}
/// Decrements the reference count stored at the start of the block and
/// reports whether it reached zero.
///
/// Returns `true` when the count is zero after the decrement.
///
/// # Safety
/// `inner` must point to a live block whose reference count has been written.
#[inline(always)]
unsafe fn decr_ref_cnt(inner: NonNull<ArcStrInner>) -> bool {
    let slot = inner.as_ptr().cast::<RefCnt>();
    let remaining = slot.read() - 1;
    slot.write(remaining);
    remaining == 0
}
/// Reads the reference count stored at the start of the block.
///
/// # Safety
/// `inner` must point to a live block whose reference count has been written.
#[inline(always)]
unsafe fn read_ref_cnt(inner: NonNull<ArcStrInner>) -> RefCnt {
    // The count is the first field, so the block pointer is also its address.
    inner.as_ptr().cast::<RefCnt>().read()
}
/// Stores `v` as the reference count at the start of the block.
///
/// # Safety
/// `inner` must point to a block large enough to contain the header.
#[inline(always)]
unsafe fn write_ref_cnt(inner: NonNull<ArcStrInner>, v: RefCnt) {
    // The count is the first field, so the block pointer is also its address.
    inner.as_ptr().cast::<RefCnt>().write(v);
}
/// Copies the bytes of `s` into the payload area of the block at `inner`.
///
/// The copy is done through raw pointers. The destination is freshly
/// allocated and therefore uninitialized, so no `&mut [u8]` is formed over it:
/// `slice::from_raw_parts_mut` requires its memory to be initialized already.
///
/// # Safety
/// `inner` must point to a block with at least `OFFSET_BUFF + s.len()`
/// writable bytes, and that block must not overlap `s`.
#[inline(always)]
unsafe fn write_buff(inner: NonNull<ArcStrInner>, s: &str) {
    let buff: *mut u8 = inner.cast::<u8>().as_ptr().add(OFFSET_BUFF);
    std::ptr::copy_nonoverlapping(s.as_ptr(), buff, s.len());
}
/// Borrows the string payload of the block at `inner`.
///
/// The length comes from the header, and the bytes are taken to be valid
/// UTF-8 without re-checking.
///
/// # Safety
/// `inner` must point to a live, fully initialized block whose payload is
/// valid UTF-8, and the block must outlive the caller-chosen lifetime `'a`.
#[inline(always)]
unsafe fn as_str<'a>(inner: NonNull<ArcStrInner>) -> &'a str {
    let byte_count = usize::from(read_str_len(inner));
    let first = inner.as_ptr().cast::<u8>().add(OFFSET_BUFF) as *const u8;
    std::str::from_utf8_unchecked(std::slice::from_raw_parts(first, byte_count))
}
/// Borrows the whole block at `inner` as raw bytes: the header fields
/// followed by the string bytes.
///
/// # Safety
/// `inner` must point to a live, fully initialized block that outlives the
/// caller-chosen lifetime `'a`.
#[inline(always)]
unsafe fn raw_payload<'a>(inner: NonNull<ArcStrInner>) -> &'a [u8] {
    let start = inner.as_ptr().cast::<u8>() as *const u8;
    let payload_len: usize = read_str_len(inner).into();
    // Header bytes up to the buffer offset, then the string itself.
    std::slice::from_raw_parts(start, OFFSET_BUFF + payload_len)
}
/// Computes the layout of the header alone: the reference count followed by
/// the length field, without any string bytes.
///
/// # Panics
/// Panics (at compile time when used in a const) with "Invalid layout" if the
/// layout extension fails.
const fn compute_base_layout() -> Layout {
    let ref_cnt = Layout::new::<RefCnt>();
    let str_len = Layout::new::<StrLen>();
    let Ok((header, _)) = ref_cnt.extend(str_len) else {
        panic!("Invalid layout")
    };
    header
}
/// Computes the layout of a full block able to hold `s`: the header layout
/// extended by the layout of the string's bytes.
///
/// # Panics
/// Panics if extending the header layout fails.
#[inline(always)]
fn compute_layout_from_str(s: &str) -> Layout {
    // `extend` also returns the payload offset, which is not needed here.
    BASE_LAYOUT.extend(Layout::for_value(s)).unwrap().0
}
/// Recomputes the layout of an existing block from the length stored in its
/// header, so that it can be handed back to the allocator.
///
/// # Safety
/// `ptr` must point to a live block whose length field has been written.
///
/// # Panics
/// Panics if extending the header layout fails.
#[inline(always)]
unsafe fn compute_layout_for_ptr(ptr: NonNull<ArcStrInner>) -> Layout {
    let payload_len = usize::from(read_str_len(ptr));
    // The payload is plain bytes: alignment 1 and a size that fits in a u16.
    let payload = Layout::from_size_align_unchecked(payload_len, 1);
    BASE_LAYOUT.extend(payload).unwrap().0
}
/// Allocates a new block from the `System` allocator and initializes it with
/// a reference count of 1, the length of `s`, and a copy of the bytes of `s`.
///
/// Allocation failure is reported through `std::alloc::handle_alloc_error`,
/// the standard hook for out-of-memory conditions, rather than through an
/// ad-hoc panic.
///
/// # Panics
/// Panics if `s` is too long for its length to fit in `StrLen`.
///
/// # Safety
/// The returned block must eventually be released with `dealloc`, exactly
/// once, and must not be used afterwards.
#[inline(always)]
unsafe fn alloc(s: &str) -> NonNull<ArcStrInner> {
    let layout = compute_layout_from_str(s);
    // Validate the length before allocating so that a failure leaks nothing.
    let len: StrLen = s.len().try_into().expect("s must sized less than 65536 bytes.");
    let raw = System::alloc(&System, layout);
    let inner = match NonNull::new(raw) {
        Some(block) => block.cast::<ArcStrInner>(),
        // A null return means the allocator could not satisfy the request.
        None => std::alloc::handle_alloc_error(layout),
    };
    write_ref_cnt(inner, 1);
    write_str_len(inner, len);
    write_buff(inner, s);
    inner
}
/// Returns the block at `inner` to the `System` allocator.
///
/// # Safety
/// `inner` must have come from `alloc`, must still be live, and must not be
/// used again after this call.
#[inline(always)]
unsafe fn dealloc(inner: NonNull<ArcStrInner>) {
    // The layout is derived from the header, so read it before freeing.
    let layout = compute_layout_for_ptr(inner);
    let raw = inner.as_ptr().cast::<u8>();
    System.dealloc(raw, layout);
}
/// A handle to a reference counted [str] buffer.
///
/// The handle consists of a single pointer to the shared heap block; the
/// string length lives in that block's header rather than in the handle.
///
/// NOTE(review): despite the `Arc` prefix, the counter helpers in this file
/// update the reference count with plain, non-atomic reads and writes.
struct ArcStr {
// Pointer to the shared block: ref count, length, then the string bytes.
p: NonNull<ArcStrInner>
}
impl ArcStr {
    /// Creates a new handle whose freshly allocated buffer holds a copy of `s`.
    ///
    /// # Panics
    /// Panics if `s` is 65536 bytes or longer.
    pub fn from_str(s: &str) -> ArcStr {
        let p = unsafe { alloc(s) };
        ArcStr { p }
    }

    /// Returns the current reference count of the shared buffer.
    pub fn ref_cnt(&self) -> usize {
        let count = unsafe { read_ref_cnt(self.p) };
        usize::from(count)
    }

    /// Returns the length of the string in bytes, as stored in the header.
    pub fn len(&self) -> usize {
        let stored = unsafe { read_str_len(self.p) };
        usize::from(stored)
    }

    /// Borrows the string for as long as this handle is borrowed.
    pub fn as_str(&self) -> &str {
        unsafe { as_str(self.p) }
    }

    /// Borrows the entire heap block (header bytes followed by the string
    /// bytes) for inspection.
    ///
    /// NOTE(review): the returned slice covers the reference-count bytes,
    /// which `clone` and `drop` modify through other handles; confirm callers
    /// never hold this slice across such calls.
    pub fn raw_payload(&self) -> &[u8] {
        unsafe { raw_payload(self.p) }
    }
}
impl Clone for ArcStr {
fn clone(&self) -> Self {
// bump the ref_cnt and return a new handle
unsafe {
incr_ref_cnt(self.p);
}
let new_handle = Self {
p: self.p
};
println!("Cloned the ArcStr");
new_handle
}
}
impl std::fmt::Debug for ArcStr {
    /// Formats the handle as a struct showing the reference count, the byte
    /// length and the string contents.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        let mut out = f.debug_struct("ArcStr");
        out.field("ref_cnt", &self.ref_cnt());
        out.field("len", &self.len());
        out.field("payload", &self.as_str());
        out.finish()
    }
}
impl Drop for ArcStr {
    /// Releases this handle's share of the buffer, freeing the buffer (and
    /// printing a trace line for the demo) when the last handle goes away.
    fn drop(&mut self) {
        let was_last = unsafe { decr_ref_cnt(self.p) };
        if was_last {
            unsafe { dealloc(self.p) };
            println!("Dealloc'd the ArcStr");
        }
    }
}
/// Demo: creates a string, clones it, prints both handles and the raw block,
/// releases the clone, and finally prints the size of the handle type.
fn main() {
    let x = ArcStr::from_str("Hello, World!");
    println!("x: {x:?}");
    {
        // The clone lives only inside this scope and is released at its end.
        let cloned_x = x.clone();
        println!("cloned_x: {cloned_x:?}");
        println!("raw: {:x?}", x.raw_payload());
    }
    println!("x: {x:?}");
    let handle_size = std::mem::size_of::<ArcStr>();
    println!("size of ArcStr: {}", handle_size);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment