Last active
February 18, 2024 18:13
-
-
Save mildsunrise/4462033971bcf30b2112660edf06cfeb to your computer and use it in GitHub Desktop.
simple linked list allocator in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{mem, ptr}; | |
use crate::BufferAlloc; | |
/// Stores the link `ptr` in the word at `n`, encoding `None` as a
/// self-reference (the word holds its own address).
///
/// # Safety
/// `n` must be valid for writing a pointer-sized value and suitably aligned.
/// `ptr` must not be `Some(n)`: that value would read back as `None`
/// (debug-checked).
unsafe fn write_ptr(n: *mut usize, ptr: Option<*mut usize>) {
    debug_assert!(ptr != Some(n));
    let encoded = match ptr {
        Some(target) => target,
        None => n,
    };
    // SAFETY: validity and alignment of `n` are upheld by the caller.
    unsafe { n.cast::<*mut usize>().write(encoded) }
}
/// Loads the link stored in the word at `n`, decoding a self-reference
/// (the word holding its own address) as `None`.
///
/// # Safety
/// `n` must be valid for reading a pointer-sized value and suitably aligned.
unsafe fn read_ptr(n: *mut usize) -> Option<*mut usize> {
    // SAFETY: validity and alignment of `n` are upheld by the caller.
    let stored = unsafe { n.cast::<*mut usize>().read() };
    if stored == n {
        None
    } else {
        Some(stored)
    }
}
// This is the low level layer of the allocator. We format the | |
// buffer as a stream of blocks, with the following layout: | |
// | |
// next: *mut block (if empty, there is no block here (EOF)) | |
// prev: *mut block (if empty, this is the first block) | |
// [allocation (padded as necessary)] | |
// flags: u8 | |
// [next block...] | |
#[derive(Clone, Debug)] | |
pub struct Block<'a> { | |
pub buffer: BufferAlloc<'a>, | |
pub ptr: *mut usize, | |
pub end: *mut usize, | |
} | |
/// Minimum block size in words: the `next` and `prev` header words plus one
/// word that holds (at least) the trailing flags byte.
pub const BLOCK_SIZE: usize = 3;
/// Bytes of every block that the allocation cannot use: the two header words
/// plus the flags byte.
pub const BLOCK_OVERHEAD: usize = mem::size_of::<[usize; 2]>() + mem::size_of::<u8>();
impl<'a> Block<'a> { | |
// SIZE | |
// always call this after constructing a Self literal, and check both pointers are aligned beforehand | |
fn check_size(self) -> Self { | |
debug_assert!(self.ptr < self.end && self.valid_size()); | |
self | |
} | |
fn valid_size(&self) -> bool { | |
self.size() >= BLOCK_SIZE | |
} | |
pub fn size(&self) -> usize { | |
unsafe { self.end.sub_ptr(self.ptr) } | |
} | |
pub fn size_bytes(&self) -> usize { | |
self.size() * mem::size_of::<usize>() | |
} | |
pub fn data(&self) -> *mut usize { | |
unsafe { self.ptr.add(2) } | |
} | |
// MEMBERS | |
fn check_next(self) -> Self { | |
debug_assert!(unsafe { read_ptr(self.ptr) } == Some(self.end)); | |
self | |
} | |
fn init_next(&self) { | |
unsafe { write_ptr(self.ptr, Some(self.end)) } | |
} | |
fn prev_raw(&self) -> Option<*mut usize> { | |
unsafe { read_ptr(self.ptr.add(1)) } | |
} | |
fn init_prev(&self, prev: Option<&Self>) { | |
unsafe { write_ptr(self.ptr.add(1), prev.map(|x| x.ptr)) } | |
} | |
pub fn flags(&self) -> &mut u8 { | |
unsafe { &mut *self.end.cast::<u8>().sub(1) } | |
} | |
fn init_flags(&self, flags: u8) { | |
unsafe { self.end.cast::<u8>().sub(1).write(flags) } | |
} | |
// CONSTRUCTION | |
pub fn first(buffer: BufferAlloc<'a>) -> Option<Self> { | |
Self::from_ptr(buffer, buffer.0) | |
} | |
pub fn at_ptr(buffer: BufferAlloc<'a>, ptr: *mut usize) -> Option<Self> { | |
debug_assert!(ptr.is_aligned() && ptr >= buffer.0); | |
Self::from_ptr(buffer, ptr) | |
} | |
// ptr itself is assumed to be healthy (aligned, >= buffer) | |
fn from_ptr(buffer: BufferAlloc<'a>, ptr: *mut usize) -> Option<Self> { | |
let end = unsafe { read_ptr(ptr) }?; | |
debug_assert!(end.is_aligned()); | |
Some(Self{ buffer, ptr, end }.check_size()) | |
} | |
pub fn next(&self) -> Option<Self> { | |
let block = Self::from_ptr(self.buffer, self.end)?; | |
debug_assert!(block.prev_raw() == Some(self.ptr)); | |
Some(block) | |
} | |
pub fn prev(&self) -> Option<Self> { | |
let &Self{ buffer, ptr: end, .. } = self; | |
let ptr = self.prev_raw()?; | |
debug_assert!(ptr.is_aligned() && ptr >= buffer.0); | |
Some(Self{ buffer, ptr, end }.check_size().check_next()) | |
} | |
// CREATION | |
pub fn init(buffer: BufferAlloc<'a>, size: usize, flags: u8) -> Option<Self> { | |
assert!(size > 0); | |
let mut ptr = buffer.0; | |
let block = Self{ buffer, ptr, end: unsafe { ptr.add(size - 1) } }; | |
let block = block.valid_size().then_some(block); | |
if let Some(ref block) = block { | |
block.init_flags(flags); | |
block.init_next(); | |
block.init_prev(None); | |
ptr = block.end; | |
} | |
unsafe { write_ptr(ptr, None) }; | |
block | |
} | |
fn adjust_next(&self) { | |
if let Some(next) = self.next() { | |
next.init_prev(Some(self)); | |
} | |
} | |
pub fn split(&mut self, new_size: usize, flags: u8) -> Self { | |
let end = unsafe { self.ptr.add(new_size) }; | |
let next = Self{ ptr: end, ..*self }.check_size(); | |
*self = Self{ end, ..*self }.check_size(); | |
self.init_flags(flags); | |
self.init_next(); | |
next.init_next(); | |
next.init_prev(Some(self)); | |
next.adjust_next(); | |
next | |
} | |
pub fn join(&mut self, next: Block<'a>) { | |
debug_assert!(self.end == next.ptr); | |
self.end = next.end; | |
self.init_next(); | |
self.adjust_next(); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#![feature(allocator_api, slice_ptr_get, ptr_sub_ptr, layout_for_ptr, ptr_metadata, pointer_is_aligned, pointer_byte_offsets)] | |
mod block; | |
use block::{Block, BLOCK_OVERHEAD, BLOCK_SIZE}; | |
use std::{ | |
alloc::{Allocator, Layout, AllocError}, | |
marker::PhantomData, | |
mem::{self, MaybeUninit}, | |
ptr::{self, NonNull}, | |
}; | |
/// Allocator that carves allocations out of a caller-provided `usize` buffer.
///
/// It stores a raw pointer to the start of the buffer, and the `PhantomData`
/// ties it to the exclusive borrow of that buffer for `'a`. Copies share the
/// same buffer.
#[derive(Debug, Clone, Copy)]
pub struct BufferAlloc<'a>(*mut usize, PhantomData<&'a mut [usize]>);
impl<'a> BufferAlloc<'a> { | |
pub fn new<const N: usize>(buffer: &'a mut MaybeUninit<[usize; N]>) -> Self { | |
let buffer = <*mut [usize]>::as_mut_ptr(buffer.as_mut_ptr()); | |
let buffer = Self(buffer, PhantomData); | |
Block::init(buffer, N, 0); // flags are 1 if used | |
buffer | |
} | |
fn to_alloc(block: &Block<'a>) -> NonNull<[u8]> { | |
let size = block.size_bytes() - BLOCK_OVERHEAD; | |
let ptr = ptr::from_raw_parts_mut(block.data().cast(), size); | |
unsafe { NonNull::new_unchecked(ptr) } | |
} | |
fn from_alloc(&self, ptr: NonNull<u8>, layout: Layout) -> Block<'a> { | |
let ptr = unsafe { ptr.as_ptr().cast::<usize>().sub(2) }; | |
let block = Block::at_ptr(*self, ptr); | |
debug_assert!(block.is_some()); | |
let block = unsafe { block.unwrap_unchecked() }; | |
debug_assert!(*block.flags() != 0); | |
debug_assert!(block.size_bytes() - BLOCK_OVERHEAD >= layout.size()); | |
block | |
} | |
fn try_alloc(block: &mut Block<'a>, layout: Layout) -> Option<()> { | |
// check block is free | |
if *block.flags() != 0 { | |
return None | |
} | |
// determine padding | |
let mut padding = block.data().align_offset(layout.align()); | |
while padding != 0 && padding < BLOCK_SIZE { | |
padding += layout.align(); | |
} | |
// check padding + allocation fits inside block | |
let block_size = block.size().checked_sub(padding)?; | |
let min_size = layout.size() + BLOCK_OVERHEAD; | |
if block_size * mem::size_of::<usize>() < min_size { | |
return None | |
} | |
// if necessary, split block to add padding | |
if padding != 0 { | |
*block = block.split(padding, 0); | |
} | |
// if possible, shrink block | |
Self::trim(block, min_size); | |
Some(()) | |
} | |
fn try_grow(block: &mut Block<'a>, layout: Layout) -> Option<()> { | |
// see if next block is free | |
let next = block.next().filter(|next| *next.flags() == 0); | |
// check new allocation fits inside about-to-be-joined block | |
let block_size = block.size() + next.as_ref().map_or(0, Block::size); | |
let min_size = layout.size() + BLOCK_OVERHEAD; | |
if block_size * mem::size_of::<usize>() < min_size { | |
return None | |
} | |
// if we previously found a next block to join with, do it now | |
if let Some(next) = next { | |
block.join(next); | |
} | |
// if possible, shrink block | |
*block.flags() = 0; | |
Self::trim(block, min_size); | |
Some(()) | |
} | |
fn trim(block: &mut Block<'a>, min_size: usize) { | |
let min_end = unsafe { block.ptr.cast::<u8>().add(min_size) }; | |
let max_end = unsafe { block.end.sub(BLOCK_SIZE) }; | |
let padding = min_end.align_offset(mem::align_of::<usize>()); | |
let free_space = unsafe { max_end.byte_offset_from(min_end) }; | |
if free_space >= 0 && padding <= (free_space as usize) { | |
let size = unsafe { min_end.add(padding).cast::<usize>().sub_ptr(block.ptr) }; | |
block.split(size, 0); | |
} | |
*block.flags() = 1; | |
} | |
fn free(mut block: Block<'a>) { | |
if let Some(next) = block.next().filter(|next| *next.flags() == 0) { | |
block.join(next); | |
} else { | |
*block.flags() = 0; | |
} | |
if let Some(mut prev) = block.prev().filter(|prev| *prev.flags() == 0) { | |
prev.join(block); | |
} | |
} | |
} | |
unsafe impl<'a> Allocator for BufferAlloc<'a> { | |
fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> { | |
let mut cursor = Block::first(*self); | |
while let Some(mut block) = cursor { | |
if Self::try_alloc(&mut block, layout).is_some() { | |
return Ok(Self::to_alloc(&block)) | |
} | |
cursor = block.next(); | |
} | |
Err(AllocError) | |
} | |
unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) { | |
Self::free(self.from_alloc(ptr, layout)); | |
} | |
unsafe fn grow(&self, ptr: NonNull<u8>, layout: Layout, new_layout: Layout) -> Result<NonNull<[u8]>, AllocError> { | |
let mut block = self.from_alloc(ptr, layout); | |
let is_aligned = ptr.as_ptr().is_aligned_to(new_layout.align()); | |
if is_aligned && Self::try_grow(&mut block, new_layout).is_some() { | |
return Ok(Self::to_alloc(&block)) | |
} | |
// fallback | |
let alloc = Self::allocate(self, new_layout)?; | |
unsafe { ptr::copy_nonoverlapping(ptr.as_ptr(), alloc.as_mut_ptr(), layout.size()) }; | |
Self::free(block); | |
Ok(alloc) | |
} | |
unsafe fn shrink(&self, ptr: NonNull<u8>, layout: Layout, new_layout: Layout) -> Result<NonNull<[u8]>, AllocError> { | |
self.grow(ptr, layout, new_layout) | |
} | |
} | |
#[test]
fn allocator_test() {
    // construction
    let mut buffer = MaybeUninit::uninit();
    let allocator = BufferAlloc::new::<655>(&mut buffer);
    // must be copyable
    let allocator2 = allocator;
    let mut bl = Vec::new_in(allocator);
    bl.push(true);
    let mut bl2 = Vec::new_in(allocator2);
    bl2.push(false);
    // free + alloc again
    mem::drop(bl);
    let mut bl = Vec::new_in(allocator);
    bl.push(true);
    // resizing may need to grow (and move) the allocation; contents must survive
    bl.resize(10, false);
    assert_eq!(bl.len(), 10);
    assert!(bl[0]);
    assert!(bl[1..].iter().all(|&b| !b));
    // the neighbouring allocation must not have been clobbered
    assert_eq!(bl2.as_slice(), &[false]);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment