Skip to content

Instantly share code, notes, and snippets.

@therealbnut
Last active September 29, 2023 08:08
Show Gist options
  • Save therealbnut/0dd65edac8e2212df51052b8620f8f0c to your computer and use it in GitHub Desktop.
Save therealbnut/0dd65edac8e2212df51052b8620f8f0c to your computer and use it in GitHub Desktop.
use core::{alloc::Allocator, mem::ManuallyDrop, num::NonZeroUsize, ptr};
/// An exact-size iterator that can also look at its current element without
/// advancing and move that element out by value.
///
/// In this file `ExtractIf` drives implementors as follows: it calls
/// [`peek_mut`](Self::peek_mut), hands the result to the predicate, and then
/// calls either [`extract`](Self::extract) (predicate accepted) or
/// [`Iterator::next`] (predicate rejected).
pub trait ExtractingIterator: ExactSizeIterator {
    /// Owned form of an element, returned by [`extract`](Self::extract).
    type OwnedItem;

    /// Returns the current element without advancing, or `None` when the
    /// iterator is exhausted.
    fn peek_mut(&mut self) -> Option<<Self as Iterator>::Item>;

    /// Takes the current element out of the iterator and returns it by value.
    ///
    /// `ExtractIf` only calls this right after `peek_mut` returned `Some`.
    fn extract(&mut self) -> Self::OwnedItem;
}
/// Extension trait that exists purely so `extract_if2` can be invoked with
/// method-call syntax on a collection.
pub trait HasExtractIf {
    /// The extracting iterator produced for a mutable borrow of lifetime `'a`.
    type Iter<'a>: ExtractingIterator
    where
        Self: 'a;

    /// Wraps `self` in an [`ExtractIf`] that is driven by `predicate`.
    ///
    /// The predicate receives each item yielded by `Self::Iter` and returns
    /// `true` for the items the resulting iterator should yield by value.
    fn extract_if2<'a, F>(&'a mut self, predicate: F) -> ExtractIf<Self::Iter<'a>, F>
    where
        F: FnMut(<Self::Iter<'a> as Iterator>::Item) -> bool;
}
/// `Vec` support for `extract_if2`, backed by [`VecExtractIf`].
impl<T, A: Allocator> HasExtractIf for Vec<T, A> {
    type Iter<'a> = VecExtractIf<'a, T, A> where Self: 'a;

    /// Builds a [`VecExtractIf`] over `self` and pairs it with `predicate`.
    #[inline]
    fn extract_if2<'a, F>(&'a mut self, predicate: F) -> ExtractIf<Self::Iter<'a>, F>
    where
        F: FnMut(<Self::Iter<'a> as Iterator>::Item) -> bool,
    {
        let source = VecExtractIf::new(self);
        ExtractIf::new(source, predicate)
    }
}
/// Iterator adaptor that yields, by value, the elements of an
/// [`ExtractingIterator`] for which `predicate` returns `true`.
pub struct ExtractIf<I, F>
where
    I: ExtractingIterator,
    F: FnMut(I::Item) -> bool,
{
    /// The wrapped iterator. Held in `ManuallyDrop` so that this type's `Drop`
    /// impl decides exactly when it is dropped.
    iterator: ManuallyDrop<I>,
    /// Decides, per peeked item, whether that item is extracted.
    predicate: F,
}
impl<I, F> ExtractIf<I, F>
where
    I: ExtractingIterator,
    F: FnMut(I::Item) -> bool,
{
    /// Creates an adaptor over `iterator` that extracts the items accepted by
    /// `predicate`.
    ///
    /// The iterator is stored inside a `ManuallyDrop`; it is dropped
    /// explicitly by this type's `Drop` impl.
    pub fn new(iterator: I, predicate: F) -> Self {
        let iterator = ManuallyDrop::new(iterator);
        Self { iterator, predicate }
    }
}
impl<I, F> Drop for ExtractIf<I, F>
where
    I: ExtractingIterator,
    F: FnMut(I::Item) -> bool,
{
    /// Steps the wrapped iterator past everything it has not visited yet and
    /// only then drops it.
    ///
    /// The remaining elements are skipped with `advance_by`; the predicate is
    /// not consulted for them.
    fn drop(&mut self) {
        // The wrapped iterator lives in a `ManuallyDrop`, so nothing drops it
        // before this point; it is guaranteed to be advanced first.
        let remaining = self.iterator.len();
        self.iterator.advance_by(remaining).unwrap();
        // SAFETY: this is the only place the wrapped iterator is dropped, and
        // it is never touched again afterwards.
        unsafe { ManuallyDrop::drop(&mut self.iterator) }
    }
}
impl<I: ExtractingIterator, F: FnMut(I::Item) -> bool> Iterator for ExtractIf<I, F> {
    type Item = I::OwnedItem;

    /// Returns the next element accepted by the predicate, moved out of the
    /// wrapped iterator, or `None` once the wrapped iterator is exhausted.
    ///
    /// Elements the predicate rejects are stepped over with the wrapped
    /// iterator's `next`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let item = self.iterator.peek_mut()?;
            if (self.predicate)(item) {
                return Some(self.iterator.extract());
            }
            // `peek_mut` just returned `Some`, so there is an element to step over.
            self.iterator.next().unwrap();
        }
    }

    // `advance_by` is intentionally NOT overridden. Forwarding it to the
    // wrapped iterator skipped `n` raw elements without running the predicate,
    // so `nth`/`skip` disagreed with calling `next` repeatedly. The default
    // implementation is built on `next` and therefore counts extracted items.

    /// Forwards the wrapped iterator's size hint.
    ///
    /// NOTE(review): that hint counts every remaining element, while `next`
    /// only yields the ones the predicate accepts, so the lower bound can
    /// over-report. It is left as-is because this type also implements
    /// `ExactSizeIterator`, whose `len()` requires both bounds to agree.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iterator.size_hint()
    }
}
// NOTE(review): `size_hint` on this type forwards the wrapped iterator's count
// of all remaining elements, whereas `next` yields only those the predicate
// accepts, so `len()` can report more items than will actually be produced.
// Confirm whether this impl is intended before relying on `len()`.
impl<I, F> ExactSizeIterator for ExtractIf<I, F>
where
    I: ExtractingIterator,
    F: FnMut(I::Item) -> bool,
{
}
/// [`ExtractingIterator`] over the elements of a borrowed `Vec`.
///
/// Elements that are stepped over are compacted towards the front of the
/// vector's buffer; elements that are extracted are moved out of it.
pub struct VecExtractIf<'a, T, A>
where
    A: Allocator,
{
    /// The vector being processed. Its length is set to 0 by `new` and
    /// restored when this value is dropped.
    vec: &'a mut Vec<T, A>,
    /// Index of the next element to inspect.
    read_index: usize,
    /// Index at which the next retained element is stored.
    write_index: usize,
    /// Length of the vector at the time `new` was called.
    original_len: usize,
}
impl<'a, T, A: Allocator> VecExtractIf<'a, T, A> {
    /// Starts an extraction pass over `vec`.
    ///
    /// The vector's length is set to 0 for the duration of the pass and is
    /// written back when the returned value is dropped; until then the vector
    /// itself exposes no elements.
    pub fn new(vec: &'a mut Vec<T, A>) -> Self {
        let original_len = vec.len();
        // SAFETY: a length of 0 never exceeds the capacity and exposes no
        // elements, initialised or otherwise.
        unsafe { vec.set_len(0) };
        Self {
            read_index: 0,
            write_index: 0,
            original_len,
            vec,
        }
    }

    /// Returns a mutable reference to the slot at `read_index`.
    ///
    /// # Safety
    ///
    /// `read_index` must be less than `original_len`, so that the slot lies
    /// inside the buffer and still holds an element. The returned reference
    /// carries the lifetime `'a` rather than a borrow of `self`, so the caller
    /// must not keep it alive across an operation that moves or overwrites
    /// that slot.
    #[inline]
    unsafe fn unsafe_read(&mut self) -> &'a mut T {
        let base = self.vec.as_mut_ptr();
        // SAFETY: guaranteed by the caller, see the `# Safety` section above.
        unsafe { &mut *base.add(self.read_index) }
    }
}
impl<'a, T, A: Allocator> Drop for VecExtractIf<'a, T, A> {
    /// Writes the final length back to the vector.
    ///
    /// Elements that were never visited are kept: they are shifted down so
    /// they directly follow the retained prefix, instead of being cut off.
    fn drop(&mut self) {
        let tail = self.original_len.saturating_sub(self.read_index);
        // SAFETY: `write_index <= read_index <= original_len`, and
        // `original_len` was the vector's length, so both ranges lie inside the
        // buffer. Slots `[0, write_index)` hold retained elements and slots
        // `[read_index, original_len)` hold elements that were never touched;
        // after the (possibly overlapping) move the first `write_index + tail`
        // slots are all initialised.
        unsafe {
            if tail > 0 && self.write_index < self.read_index {
                let base = self.vec.as_mut_ptr();
                ptr::copy(base.add(self.read_index), base.add(self.write_index), tail);
            }
            self.vec.set_len(self.write_index + tail);
        }
    }
}

impl<'a, T, A: Allocator> ExactSizeIterator for VecExtractIf<'a, T, A> {}

impl<'a, T, A: Allocator> Iterator for VecExtractIf<'a, T, A> {
    type Item = &'a mut T;

    /// Exact number of elements that have not been visited yet.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.original_len.saturating_sub(self.read_index);
        (remaining, Some(remaining))
    }

    /// Retains the next `n` elements in one step.
    ///
    /// If fewer than `n` elements remain, all of them are retained and the
    /// shortfall is returned as `Err`, as the `Iterator::advance_by` contract
    /// requires.
    fn advance_by(&mut self, n: usize) -> Result<(), NonZeroUsize> {
        let remaining = self.original_len.saturating_sub(self.read_index);
        let step = n.min(remaining);
        if step > 0 && self.write_index < self.read_index {
            let base = self.vec.as_mut_ptr();
            // SAFETY: `read_index + step <= original_len` and
            // `write_index < read_index`, so both ranges are in bounds.
            // `ptr::copy` permits overlap. Only the `step` elements being
            // retained are moved; the slots after them are left untouched so
            // later reads still see the original elements.
            unsafe {
                ptr::copy(base.add(self.read_index), base.add(self.write_index), step);
            }
        }
        self.read_index += step;
        self.write_index += step;
        match NonZeroUsize::new(n - step) {
            Some(shortfall) => Err(shortfall),
            None => Ok(()),
        }
    }

    /// Retains the current element and returns a reference to it.
    ///
    /// The reference points at the element's final position in the vector, so
    /// changes made through it are visible in the vector afterwards.
    fn next(&mut self) -> Option<Self::Item> {
        if self.read_index == self.original_len {
            return None;
        }
        let base = self.vec.as_mut_ptr();
        if self.write_index < self.read_index {
            // SAFETY: both indices are below `original_len` and differ, so the
            // two single-element ranges are in bounds and do not overlap.
            unsafe {
                ptr::copy_nonoverlapping(
                    base.add(self.read_index),
                    base.add(self.write_index),
                    1,
                );
            }
        }
        // SAFETY: the slot at `write_index` now holds the retained element and
        // no later operation of this iterator writes to it again.
        let item = unsafe { &mut *base.add(self.write_index) };
        self.read_index += 1;
        self.write_index += 1;
        Some(item)
    }
}
impl<'a, T, A: Allocator> ExtractingIterator for VecExtractIf<'a, T, A> {
    type OwnedItem = T;

    /// Returns a reference to the element at the read index without
    /// advancing, or `None` when every element has been visited.
    ///
    /// NOTE(review): the reference has lifetime `'a`, not a borrow of `self`,
    /// so the compiler does not stop a caller from holding it across a later
    /// `extract`/`next`. Callers should let it go before calling either.
    fn peek_mut(&mut self) -> Option<Self::Item> {
        if self.read_index < self.original_len {
            // SAFETY: the read index was just checked against `original_len`.
            Some(unsafe { self.unsafe_read() })
        } else {
            None
        }
    }

    /// Moves the element at the read index out of the vector.
    ///
    /// # Panics
    ///
    /// Panics if every element has already been visited. This is a hard
    /// `assert!` rather than a `debug_assert!` because this is a safe method
    /// and reading past `original_len` would be undefined behaviour.
    fn extract(&mut self) -> Self::OwnedItem {
        assert!(
            self.read_index < self.original_len,
            "`extract` called on an exhausted `VecExtractIf`"
        );
        // SAFETY: the read index was just checked against `original_len`.
        // Advancing only the read index below guarantees the slot is never
        // read again, so the value is not duplicated.
        let item = unsafe { ptr::read(self.unsafe_read()) };
        self.read_index += 1;
        item
    }
}
#[cfg(test)]
mod tests {
    use std::alloc::Global;

    use super::*;

    /// Exhaustively checks `extract_if2` on vectors of up to 9 elements, for
    /// every extraction mask and every `take` count from 0 up to and including
    /// the vector length.
    #[test]
    fn test_extract_if_exhaustive() {
        for len in 0usize..10 {
            // Inclusive upper bound: also covers consuming the whole iterator,
            // and gives the empty vector (`len == 0`) one test case.
            for take_count in 0..=len {
                for extract_mask in 0..(1u64 << len) {
                    let mut vec: Vec<usize> = (0..len).collect();
                    let original_vec = vec.clone();
                    let mut expect_extracted = vec![];
                    let mut expect_vec = vec![];
                    for (i, &x) in vec.iter().enumerate() {
                        if ((extract_mask >> i) & 1) != 0 && i < take_count {
                            expect_extracted.push(x);
                        } else {
                            expect_vec.push(x);
                        }
                    }
                    let iter = vec.extract_if2(|x| expect_extracted.contains(x));
                    let extracted = iter.take(take_count).collect::<Vec<_>>();
                    // Captured by the test harness; only shown when a case fails.
                    eprintln!("{original_vec:?}:");
                    eprintln!(" extract: {:?} (expect {:?})", extracted, expect_extracted);
                    eprintln!(" retain: {:?} (expect {:?})", vec, expect_vec);
                    assert_eq!(extracted, expect_extracted);
                    assert_eq!(vec, expect_vec);
                }
            }
        }
    }

    /// Same exhaustive check, but driving two vectors in lockstep through
    /// `ExtractPair`.
    #[test]
    fn test_pair_extract_if_exhaustive() {
        for len in 0usize..10 {
            for take_count in 0..=len {
                for extract_mask in 0..(1u64 << len) {
                    let mut lhs: Vec<usize> = (0..len).collect();
                    // Keep both sides the same length so the pair iterator's
                    // reported length is accurate for both of them.
                    let mut rhs: Vec<char> = "abcdefghij".chars().take(len).collect();
                    let original_vecs = (lhs.clone(), rhs.clone());
                    let mut expect_extracted = vec![];
                    let mut expect_vec = vec![];
                    for i in 0..len {
                        let xy = (lhs[i], rhs[i]);
                        if ((extract_mask >> i) & 1) != 0 && i < take_count {
                            expect_extracted.push(xy);
                        } else {
                            expect_vec.push(xy);
                        }
                    }
                    let pair_iter = ExtractPair {
                        lhs: VecExtractIf::new(&mut lhs),
                        rhs: VecExtractIf::new(&mut rhs),
                    };
                    let iter = ExtractIf::new(pair_iter, |(&mut x, &mut y)| {
                        expect_extracted.contains(&(x, y))
                    });
                    let extracted = iter.take(take_count).collect::<Vec<_>>();
                    let vec = lhs.into_iter().zip(rhs).collect::<Vec<_>>();
                    eprintln!("{original_vecs:?}:");
                    eprintln!(" extract: {:?} (expect {:?})", extracted, expect_extracted);
                    eprintln!(" retain: {:?} (expect {:?})", vec, expect_vec);
                    assert_eq!(extracted, expect_extracted);
                    assert_eq!(vec, expect_vec);
                }
            }
        }
    }

    /// Test-only extracting iterator that walks two vectors in lockstep and
    /// yields/extracts their elements as pairs.
    struct ExtractPair<'a, T, U, S: Allocator = Global> {
        lhs: VecExtractIf<'a, T, S>,
        rhs: VecExtractIf<'a, U, S>,
    }

    impl<'a, T, U, S: Allocator> ExactSizeIterator for ExtractPair<'a, T, U, S> {}

    impl<'a, T, U, S: Allocator> Iterator for ExtractPair<'a, T, U, S> {
        type Item = (&'a mut T, &'a mut U);

        /// Reports the left-hand side's remaining length.
        fn size_hint(&self) -> (usize, Option<usize>) {
            let len = self.lhs.len();
            (len, Some(len))
        }

        fn next(&mut self) -> Option<Self::Item> {
            Some((self.lhs.next()?, self.rhs.next()?))
        }

        fn advance_by(&mut self, n: usize) -> Result<(), NonZeroUsize> {
            self.lhs.advance_by(n)?;
            self.rhs.advance_by(n)?;
            Ok(())
        }
    }

    impl<'a, T, U, S: Allocator> ExtractingIterator for ExtractPair<'a, T, U, S> {
        type OwnedItem = (T, U);

        fn peek_mut(&mut self) -> Option<Self::Item> {
            Some((self.lhs.peek_mut()?, self.rhs.peek_mut()?))
        }

        fn extract(&mut self) -> Self::OwnedItem {
            (self.lhs.extract(), self.rhs.extract())
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment