-
-
Save therealbnut/0dd65edac8e2212df51052b8620f8f0c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use core::{alloc::Allocator, mem::ManuallyDrop, num::NonZeroUsize, ptr}; | |
/// An exact-size iterator whose current element can be inspected in place
/// and then either stepped over with [`Iterator::next`] or moved out with
/// [`extract`](ExtractingIterator::extract).
pub trait ExtractingIterator: ExactSizeIterator {
    /// The by-value form of an element, as produced by
    /// [`extract`](ExtractingIterator::extract).
    type OwnedItem;

    /// Returns the current element without advancing the iterator.
    ///
    /// Implementations are expected to return `None` once no elements remain.
    fn peek_mut(&mut self) -> Option<Self::Item>;

    /// Removes the current element from the underlying source, returns it by
    /// value and advances past it.
    ///
    /// Callers are expected to invoke this only after
    /// [`peek_mut`](ExtractingIterator::peek_mut) has returned `Some`.
    fn extract(&mut self) -> Self::OwnedItem;
}
// This only exists to allow it to be called conveniently. | |
pub trait HasExtractIf { | |
type Iter<'a>: ExtractingIterator | |
where | |
Self: 'a; | |
fn extract_if2<'a, F: FnMut(<Self::Iter<'a> as Iterator>::Item) -> bool>( | |
&'a mut self, | |
predicate: F, | |
) -> ExtractIf<Self::Iter<'a>, F>; | |
} | |
impl<T, A: Allocator> HasExtractIf for Vec<T, A> { | |
type Iter<'a> = VecExtractIf<'a, T, A> where Self: 'a; | |
#[inline] | |
fn extract_if2<'a, F: FnMut(<Self::Iter<'a> as Iterator>::Item) -> bool>( | |
&'a mut self, | |
predicate: F, | |
) -> ExtractIf<Self::Iter<'a>, F> { | |
ExtractIf::new(VecExtractIf::new(self), predicate) | |
} | |
} | |
pub struct ExtractIf<I: ExtractingIterator, F: FnMut(I::Item) -> bool> { | |
iterator: ManuallyDrop<I>, | |
predicate: F, | |
} | |
impl<I: ExtractingIterator, F: FnMut(I::Item) -> bool> ExtractIf<I, F> { | |
pub fn new(iterator: I, predicate: F) -> Self { | |
ExtractIf { | |
iterator: ManuallyDrop::new(iterator), | |
predicate, | |
} | |
} | |
} | |
impl<I: ExtractingIterator, F: FnMut(I::Item) -> bool> Drop for ExtractIf<I, F> { | |
fn drop(&mut self) { | |
// TODO: Maybe this happens before the iterator... | |
let len = self.iterator.len(); | |
self.iterator.advance_by(len).unwrap(); | |
// Ensures that the iterator is advancaed before being dropped. | |
// SAFETY: This is the only place that drops the iterator. | |
unsafe { ManuallyDrop::drop(&mut self.iterator) } | |
} | |
} | |
impl<I: ExtractingIterator, F: FnMut(I::Item) -> bool> Iterator for ExtractIf<I, F> { | |
type Item = I::OwnedItem; | |
fn next(&mut self) -> Option<Self::Item> { | |
loop { | |
let item = self.iterator.peek_mut()?; | |
if (self.predicate)(item) { | |
return Some(self.iterator.extract()); | |
} | |
self.iterator.next().unwrap(); | |
} | |
} | |
fn advance_by(&mut self, n: usize) -> Result<(), std::num::NonZeroUsize> { | |
self.iterator.advance_by(n) | |
} | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
self.iterator.size_hint() | |
} | |
} | |
impl<I: ExtractingIterator, F: FnMut(I::Item) -> bool> ExactSizeIterator for ExtractIf<I, F> {} | |
pub struct VecExtractIf<'a, T, A: Allocator> { | |
vec: &'a mut Vec<T, A>, | |
read_index: usize, | |
write_index: usize, | |
original_len: usize, | |
} | |
impl<'a, T, A: Allocator> VecExtractIf<'a, T, A> { | |
pub fn new(vec: &'a mut Vec<T, A>) -> Self { | |
let original_len = vec.len(); | |
// SAFETY: I'm not sure if this is needed, but retain_mut does it... | |
unsafe { vec.set_len(0) }; | |
Self { | |
vec, | |
read_index: 0, | |
write_index: 0, | |
original_len, | |
} | |
} | |
#[inline] | |
unsafe fn unsafe_read(&mut self) -> &'a mut T { | |
// SAFETY: Elements at the read index has not been modified. | |
unsafe { &mut *self.vec.as_mut_ptr().add(self.read_index) } | |
} | |
} | |
impl<'a, T, A: Allocator> Drop for VecExtractIf<'a, T, A> { | |
fn drop(&mut self) { | |
// SAFETY: Elements begore the write index are valid. | |
unsafe { self.vec.set_len(self.write_index) }; | |
} | |
} | |
impl<'a, T, A: Allocator> ExactSizeIterator for VecExtractIf<'a, T, A> {} | |
impl<'a, T, A: Allocator> Iterator for VecExtractIf<'a, T, A> { | |
type Item = &'a mut T; | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
let len = (self.read_index..self.original_len).len(); | |
(len, Some(len)) | |
} | |
fn advance_by(&mut self, n: usize) -> Result<(), std::num::NonZeroUsize> { | |
let len = (self.read_index..self.original_len).len(); | |
if n > len { | |
return Err(NonZeroUsize::new(n - len).unwrap()); | |
} | |
unsafe { | |
ptr::copy( | |
self.vec.as_ptr().add(self.read_index), | |
self.vec.as_mut_ptr().add(self.write_index), | |
len, | |
); | |
} | |
self.read_index += n; | |
self.write_index += n; | |
Ok(()) | |
} | |
fn next(&mut self) -> Option<Self::Item> { | |
if self.read_index == self.original_len { | |
return None; | |
} | |
// SAFETY: The `read_index` is always valid. | |
let item = unsafe { self.unsafe_read() }; | |
if self.write_index < self.read_index { | |
// SAFETY: The read and write indices are always valid. | |
unsafe { | |
ptr::copy_nonoverlapping( | |
self.vec.as_ptr().add(self.read_index), | |
self.vec.as_mut_ptr().add(self.write_index), | |
1, | |
) | |
} | |
} | |
self.read_index += 1; | |
self.write_index += 1; | |
Some(item) | |
} | |
} | |
impl<'a, T, A: Allocator> ExtractingIterator for VecExtractIf<'a, T, A> { | |
type OwnedItem = T; | |
fn peek_mut(&mut self) -> Option<Self::Item> { | |
if self.read_index < self.original_len { | |
// SAFETY: The read index is always valid. | |
Some(unsafe { self.unsafe_read() }) | |
} else { | |
None | |
} | |
} | |
fn extract(&mut self) -> Self::OwnedItem { | |
debug_assert!(self.read_index < self.original_len); | |
// SAFETY: The read index is always valid. | |
// Incrementing only the read index ensures it will never be read again. | |
let item = unsafe { ptr::read(self.unsafe_read()) }; | |
self.read_index += 1; | |
item | |
} | |
} | |
#[cfg(test)] | |
mod tests { | |
use std::alloc::Global; | |
use super::*; | |
#[test] | |
fn test_extract_if_exhaustive() { | |
for len in 0usize..10 { | |
for take_count in 0..len { | |
for extract_mask in 0..(1u64 << len) { | |
let mut vec: Vec<usize> = (0..len).collect(); | |
let original_vec = vec.clone(); | |
let mut expect_extracted = vec![]; | |
let mut expect_vec = vec![]; | |
for i in 0..len { | |
let x = vec[i]; | |
if ((extract_mask >> i) & 1) != 0 && i < take_count { | |
expect_extracted.push(x); | |
} else { | |
expect_vec.push(x); | |
} | |
} | |
let iter = vec.extract_if2(|x| expect_extracted.contains(x)); | |
let extracted = iter.take(take_count).collect::<Vec<_>>(); | |
eprintln!("{original_vec:?}:"); | |
eprintln!(" extract: {:?} (expect {:?})", extracted, expect_extracted); | |
eprintln!(" retain: {:?} (expect {:?})", vec, expect_vec); | |
if extracted != expect_extracted || vec != expect_vec { | |
assert_eq!(extracted, expect_extracted); | |
assert_eq!(vec, expect_vec); | |
} | |
} | |
} | |
} | |
} | |
#[test] | |
fn test_pair_extract_if_exhaustive() { | |
for len in 0usize..10 { | |
for take_count in 0..len { | |
for extract_mask in 0..(1u64 << len) { | |
let mut lhs: Vec<usize> = (0..len).collect(); | |
let mut rhs: Vec<char> = "abcdefghij".chars().collect(); | |
let original_vecs = (lhs.clone(), rhs.clone()); | |
let mut expect_extracted = vec![]; | |
let mut expect_vec = vec![]; | |
for i in 0..len { | |
let xy = (lhs[i], rhs[i]); | |
if ((extract_mask >> i) & 1) != 0 && i < take_count { | |
expect_extracted.push(xy); | |
} else { | |
expect_vec.push(xy); | |
} | |
} | |
let pair_iter = ExtractPair { | |
lhs: VecExtractIf::new(&mut lhs), | |
rhs: VecExtractIf::new(&mut rhs), | |
}; | |
let iter = ExtractIf::new(pair_iter, |(&mut x, &mut y)| { | |
expect_extracted.contains(&(x, y)) | |
}); | |
let extracted = iter.take(take_count).collect::<Vec<_>>(); | |
let vec = lhs.into_iter().zip(rhs.into_iter()).collect::<Vec<_>>(); | |
eprintln!("{original_vecs:?}:"); | |
eprintln!(" extract: {:?} (expect {:?})", extracted, expect_extracted); | |
eprintln!(" retain: {:?} (expect {:?})", vec, expect_vec); | |
if extracted != expect_extracted || vec != expect_vec { | |
assert_eq!(extracted, expect_extracted); | |
assert_eq!(vec, expect_vec); | |
} | |
} | |
} | |
} | |
} | |
struct ExtractPair<'a, T, U, S: Allocator = Global> { | |
lhs: VecExtractIf<'a, T, S>, | |
rhs: VecExtractIf<'a, U, S>, | |
} | |
impl<'a, T, U, S: Allocator> ExactSizeIterator for ExtractPair<'a, T, U, S> {} | |
impl<'a, T, U, S: Allocator> Iterator for ExtractPair<'a, T, U, S> { | |
type Item = (&'a mut T, &'a mut U); | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
let len = self.lhs.len(); | |
(len, Some(len)) | |
} | |
fn next(&mut self) -> Option<Self::Item> { | |
Some((self.lhs.next()?, self.rhs.next()?)) | |
} | |
fn advance_by(&mut self, n: usize) -> Result<(), NonZeroUsize> { | |
self.lhs.advance_by(n)?; | |
self.rhs.advance_by(n)?; | |
Ok(()) | |
} | |
} | |
impl<'a, T, U, S: Allocator> ExtractingIterator for ExtractPair<'a, T, U, S> { | |
type OwnedItem = (T, U); | |
fn peek_mut(&mut self) -> Option<Self::Item> { | |
Some((self.lhs.peek_mut()?, self.rhs.peek_mut()?)) | |
} | |
fn extract(&mut self) -> Self::OwnedItem { | |
(self.lhs.extract(), self.rhs.extract()) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment