-
-
Save adlerd/08829dfc239f26b9c827756785cc7af8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#![feature(maybe_uninit_array_assume_init)] | |
#![feature(maybe_uninit_uninit_array)] | |
#![feature(maybe_uninit_extra)] | |
use std::iter::FromIterator; | |
use std::mem::MaybeUninit; | |
/// Low-level building block for obtaining several disjoint `&mut Value`s from
/// one container at once.
///
/// An implementor describes three things:
/// * how to look up a key (`get_entry`);
/// * how to tell whether two looked-up entries refer to overlapping storage
///   (`would_overlap`);
/// * how to turn an entry into a mutable reference (`entry_to_ref`).
///
/// Implementors of this trait receive the safe `GetMutRefs` API through a
/// blanket impl.
///
/// # Safety
///
/// Implementors must guarantee all of the following:
///
/// * `would_overlap(e1, e2)` returns `true` whenever the values designated by
///   `e1` and `e2` share any memory, regardless of argument order. Callers may
///   check each unordered pair only once.
/// * An entry returned by `get_entry` stays valid across later calls to
///   `get_entry` on the same object. This lets a whole batch of keys be
///   resolved before any reference is produced.
/// * For entries that pairwise do not overlap, `entry_to_ref` yields
///   references to disjoint values, so holding all of them at once does not
///   create aliasing `&mut`.
pub unsafe trait GetMutRefsRaw<Key, Value> {
    /// Handle identifying one value stored inside `Self`.
    type RawEntry;

    /// Returns `true` if the values designated by `e1` and `e2` overlap.
    fn would_overlap(&self, e1: &Self::RawEntry, e2: &Self::RawEntry) -> bool;

    /// Converts `entry` into a mutable reference to a value inside `*this`.
    ///
    /// NB. we provide `*mut Self` instead of `&'_ self` because we can't soundly
    /// hold a shared reference to self and a unique reference to some values at
    /// the same time, even though borrowck won't notice the relationship.
    ///
    /// # Safety
    ///
    /// * `this` must be valid for reads and writes.
    /// * `entry` must have been produced by `get_entry` on that same object,
    ///   with no intervening mutation other than further `get_entry` calls.
    /// * The caller chooses `'a`. It must not outlive the exclusive borrow
    ///   from which `this` was derived.
    /// * No other live reference may alias the returned value. In particular,
    ///   entries converted together must pairwise satisfy `!would_overlap`.
    unsafe fn entry_to_ref<'a>(this: *mut Self, entry: Self::RawEntry) -> &'a mut Value;

    /// Looks up `key`, returning `None` when it is not present.
    // NOTE(review): the lifetime parameter `'a` is unused. It is kept because
    // existing impls may repeat it, and trait and impl generics must match.
    fn get_entry<'a>(&mut self, key: Key) -> Option<Self::RawEntry>;
}
/// Safe interface for mutably borrowing several values of a container at once.
///
/// Every `GetMutRefsRaw` implementor gets this trait through a blanket impl.
pub trait GetMutRefs<Key, Value> {
    /// Borrows the value for each key in `keys` mutably, all tied to the
    /// exclusive borrow `'a` of `self`.
    ///
    /// `None` means the request could not be satisfied. The blanket impl for
    /// `GetMutRefsRaw` types returns `None` in two cases:
    /// * a key is missing;
    /// * two keys would produce overlapping references.
    ///
    /// Otherwise that impl returns the references in key order.
    fn get_mut_refs<'a, const N: usize>(
        &'a mut self,
        keys: [Key; N],
    ) -> Option<[&'a mut Value; N]>;
}
/// Fixed-capacity buffer of `N` slots, filled front to back.
///
/// Safety invariant: the first `filled` slots of `arr` hold initialized values
/// owned by this buffer. All later slots are uninitialized.
struct ArrayFill<T, const N: usize> {
    /// Count of initialized, owned slots at the front of `arr`.
    filled: usize,
    /// Backing storage. Only the prefix `arr[..filled]` holds live values.
    arr: [MaybeUninit<T>; N],
}
impl<T, const N: usize> ArrayFill<T, N> { | |
fn new() -> Self { | |
// Safety invariant: we own zero elements. | |
Self { | |
arr: MaybeUninit::uninit_array(), | |
filled: 0, | |
} | |
} | |
fn push(&mut self, value: T) { | |
if let Some(place) = self.arr.get_mut(self.filled) { | |
// Safety: It's always safe to ptr::write to a MaybeUninit. | |
unsafe { | |
place.as_mut_ptr().write(value); | |
} | |
// Safety invariant: we have just initilized self.arr[self.filled], | |
// and moreover ptr::write has suppressed dropping the existing copy | |
// of the value, so we now own it properly. | |
self.filled += 1; | |
} | |
} | |
fn is_full(&self) -> bool { | |
self.filled >= self.arr.len() | |
} | |
fn into_inner(mut self) -> Option<[T; N]> { | |
if self.is_full() { | |
// Safety invariant: By setting self.filled = 0, we avoid | |
// double-drop after std::ptr::read. | |
self.filled = 0; | |
// Safety: By our type invariant (after self.is_full()), the array | |
// is fully initialized and owned. | |
unsafe { Some(MaybeUninit::array_assume_init(std::ptr::read(&self.arr))) } | |
} else { | |
None | |
} | |
} | |
} | |
impl<T, const N: usize> Drop for ArrayFill<T, N> { | |
fn drop(&mut self) { | |
if !std::mem::needs_drop::<T>() { | |
return; | |
} | |
for place in self.arr.iter_mut().take(self.filled) { | |
// Safety: By our type invariant, we own an initialized value at | |
// place, which is one of the first self.filled entries in self.arr. | |
// Since Drop::drop will only be called once we will not double-drop. | |
unsafe { | |
place.assume_init_drop(); | |
} | |
} | |
} | |
} | |
impl<T, const N: usize> FromIterator<T> for ArrayFill<T, N> { | |
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self { | |
let mut iter = iter.into_iter(); | |
let mut this = Self::new(); | |
loop { | |
if this.is_full() { | |
break; | |
} else if let Some(value) = iter.next() { | |
this.push(value); | |
} else { | |
break; | |
} | |
} | |
this | |
} | |
} | |
impl<T, Key, Value> GetMutRefs<Key, Value> for T | |
where | |
T: GetMutRefsRaw<Key, Value>, | |
{ | |
fn get_mut_refs<'a, const N: usize>( | |
&'a mut self, | |
keys: [Key; N], | |
) -> Option<[&'a mut Value; N]> { | |
let all_entries = | |
Option::from_iter(std::array::IntoIter::new(keys).map(|key| self.get_entry(key)))?; | |
let arr: [T::RawEntry; N] = ArrayFill::into_inner(all_entries).unwrap(); | |
let mut iter = arr.iter(); | |
while let Some(x) = iter.next() { | |
if iter.clone().any(|y| self.would_overlap(x, y)) { | |
return None; | |
} | |
} | |
let self_ptr: *mut Self = self; | |
let all_refs = std::array::IntoIter::new(arr) | |
.map(|raw_entry| { | |
// Safety: We fulfilled the requirements of entry_to_ref by checking for overlaps. | |
// The returned lifetime is bound to a unique borrow of self. | |
unsafe { Self::entry_to_ref(self_ptr, raw_entry) } | |
}) | |
.collect(); | |
Some(ArrayFill::into_inner(all_refs).unwrap()) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment