Skip to content

Instantly share code, notes, and snippets.

@CrabNejonas
Last active December 14, 2023 13:11
Show Gist options
  • Save CrabNejonas/32e7bf4bed2a5644ff0f23b4840d0234 to your computer and use it in GitHub Desktop.
Save CrabNejonas/32e7bf4bed2a5644ff0f23b4840d0234 to your computer and use it in GitHub Desktop.
Ports the `intersection` algorithm defined between two `BTreeSet<T>`s to a `BTreeMap<K, V>` and a `BTreeSet<K>`, effectively yielding only the key-value pairs whose keys are also part of the set (neither collection is modified).
use std::cmp::Ordering;
use std::collections::{btree_map, btree_set, BTreeMap, BTreeSet};
use std::iter::FusedIterator;
/// Extension trait for intersecting the keys of a map-like collection with a
/// `BTreeSet` of keys.
trait IntersectionExt<K, V> {
    /// Returns a lazy iterator over the key-value pairs of `self` whose keys
    /// are also contained in `other`.
    ///
    /// Neither `self` nor `other` is modified.
    fn pick<'a>(&'a self, other: &'a BTreeSet<K>) -> Pick<'a, K, V>
    where
        K: Ord;
}
/// Iterator over the entries of a `BTreeMap` whose keys are also in a
/// `BTreeSet`, created by [`IntersectionExt::pick`].
///
/// Nothing happens until it is iterated, so dropping it unused is a bug.
#[must_use = "this returns the picked entries as an iterator, without modifying the map or the set"]
struct Pick<'a, K, V> {
    /// The strategy chosen when the iterator was created.
    inner: IntersectionInner<'a, K, V>,
}
/// The strategy a [`Pick`] uses to produce its entries. It is chosen once,
/// when the `Pick` is created.
enum IntersectionInner<'a, K, V> {
    /// The result has at most one entry and is already known; it is handed
    /// out (via `Option::take`) on the first call to `next`.
    Answer(Option<(&'a K, &'a V)>),
    /// The map is much smaller than the set: each map entry's key is looked
    /// up in the set.
    SearchMap {
        small_iter: btree_map::Iter<'a, K, V>,
        large_set: &'a BTreeSet<K>,
    },
    /// The set is much smaller than the map: each set key is looked up in
    /// the map.
    SearchSet {
        small_iter: btree_set::Iter<'a, K>,
        large_map: &'a BTreeMap<K, V>,
    },
    /// Neither collection is much smaller than the other: both are walked in
    /// ascending key order side by side.
    Stitch {
        map: btree_map::Iter<'a, K, V>,
        set: btree_set::Iter<'a, K>,
    },
}
/// Size ratio beyond which looking up each element of the smaller collection
/// in the larger one beats walking both collections side by side.
///
/// `pick` uses it to compare the map's length against the set's length. The
/// value comes from the `BTreeSet` intersection benchmarks in
/// <https://github.com/ssomers/rust_bench_btreeset_intersection>.
///
/// Sizes are divided by it rather than multiplied, to rule out overflow. It is
/// a power of two so that the division is cheap.
const ITER_PERFORMANCE_TIPPING_SIZE_DIFF: usize = 16;
impl<K, V> IntersectionExt<K, V> for BTreeMap<K, V> {
    /// Picks the entries of `self` whose keys are contained in `set`.
    ///
    /// The strategy is chosen up front from the key ranges and sizes of both
    /// collections:
    /// - disjoint key ranges, or ranges that meet at exactly one key: the
    ///   answer is known immediately;
    /// - a map much smaller than the set: map keys are looked up in the set;
    /// - a set much smaller than the map: set keys are looked up in the map;
    /// - otherwise: both are walked side by side.
    fn pick<'a>(&'a self, set: &'a BTreeSet<K>) -> Pick<'a, K, V>
    where
        K: Ord,
    {
        let nothing = || Pick {
            inner: IntersectionInner::Answer(None),
        };

        let (Some(map_first), Some(map_last)) = (self.first_key_value(), self.last_key_value())
        else {
            return nothing();
        };
        let (Some(set_first), Some(set_last)) = (set.first(), set.last()) else {
            return nothing();
        };

        let inner = match (map_first.0.cmp(set_last), map_last.0.cmp(set_first)) {
            // All map keys lie above the set, or all of them lie below it.
            (Ordering::Greater, _) | (_, Ordering::Less) => IntersectionInner::Answer(None),
            // The smallest map key equals the largest set key, so it is the only shared key.
            (Ordering::Equal, _) => IntersectionInner::Answer(Some(map_first)),
            // The largest map key equals the smallest set key, so it is the only shared key.
            (_, Ordering::Equal) => IntersectionInner::Answer(Some(map_last)),
            _ if self.len() <= set.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
                IntersectionInner::SearchMap {
                    small_iter: self.iter(),
                    large_set: set,
                }
            }
            _ if set.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
                IntersectionInner::SearchSet {
                    small_iter: set.iter(),
                    large_map: self,
                }
            }
            _ => IntersectionInner::Stitch {
                map: self.iter(),
                set: set.iter(),
            },
        };

        Pick { inner }
    }
}
impl<'a, K: Ord, V> Iterator for Pick<'a, K, V> {
type Item = (&'a K, &'a V);
fn next(&mut self) -> Option<Self::Item> {
match &mut self.inner {
IntersectionInner::Answer(answer) => answer.take(),
IntersectionInner::SearchMap {
small_iter,
large_set,
} => loop {
let small_next = small_iter.next()?;
if large_set.contains(&small_next.0) {
return Some(small_next);
}
},
IntersectionInner::SearchSet {
small_iter,
large_map,
} => loop {
let small_next = small_iter.next()?;
if let Some(v) = large_map.get(small_next) {
return Some((small_next, v));
}
},
IntersectionInner::Stitch { map, set } => {
let mut map_next = map.next()?;
let mut set_next = set.next()?;
loop {
match map_next.0.cmp(set_next) {
Ordering::Less => map_next = map.next()?,
Ordering::Greater => set_next = set.next()?,
Ordering::Equal => return Some(map_next),
}
}
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
match &self.inner {
IntersectionInner::Stitch { map, set } => {
(0, Some(std::cmp::min(map.len(), set.len())))
}
IntersectionInner::SearchMap { small_iter, .. } => (0, Some(small_iter.len())),
IntersectionInner::SearchSet { small_iter, .. } => (0, Some(small_iter.len())),
IntersectionInner::Answer(None) => (0, Some(0)),
IntersectionInner::Answer(Some(_)) => (1, Some(1)),
}
}
fn min(mut self) -> Option<(&'a K, &'a V)> {
self.next()
}
}
// Once `next` returns `None` it keeps returning `None`. `Answer` stays empty
// after `take`. Every other strategy returns `None` only when a
// `BTreeMap`/`BTreeSet` iterator is exhausted, and those iterators are
// themselves fused.
impl<K, V> FusedIterator for Pick<'_, K, V> where K: Ord {}
#[cfg(test)]
mod test {
    use super::*;

    /// Map and set of similar size: the entries are stitched together.
    #[test]
    fn pick() {
        let map = BTreeMap::from([("a", "a"), ("b", "b"), ("c", "c"), ("d", "d")]);
        let set = BTreeSet::from(["a", "b"]);
        assert_eq!(map.pick(&set).size_hint(), (0, Some(2)));
        let pick: Vec<_> = map.pick(&set).collect();
        assert_eq!(pick.len(), 2);
        assert_eq!(pick[0], (&"a", &"a"));
        assert_eq!(pick[1], (&"b", &"b"));
    }

    /// Empty inputs and non-overlapping key ranges yield nothing.
    #[test]
    fn pick_empty_or_disjoint() {
        let empty_map: BTreeMap<i32, i32> = BTreeMap::new();
        let empty_set: BTreeSet<i32> = BTreeSet::new();
        let map = BTreeMap::from([(1, 10), (2, 20)]);
        let set = BTreeSet::from([3, 4]);
        assert_eq!(empty_map.pick(&set).count(), 0);
        assert_eq!(map.pick(&empty_set).count(), 0);
        assert_eq!(map.pick(&set).count(), 0);
    }

    /// Key ranges that meet at exactly one key yield exactly that entry.
    #[test]
    fn pick_touching_ranges() {
        let map = BTreeMap::from([(1, 10), (2, 20), (3, 30)]);
        let above = BTreeSet::from([3, 4, 5]);
        let below = BTreeSet::from([0, 1]);
        assert_eq!(map.pick(&above).size_hint(), (1, Some(1)));
        assert_eq!(map.pick(&above).collect::<Vec<_>>(), vec![(&3, &30)]);
        assert_eq!(map.pick(&below).collect::<Vec<_>>(), vec![(&1, &10)]);
    }

    /// A map much smaller than the set: map keys are looked up in the set.
    #[test]
    fn pick_small_map_large_set() {
        let map = BTreeMap::from([(5, 50), (42, 420), (500, 5000)]);
        let set: BTreeSet<i32> = (0..100).collect();
        assert_eq!(
            map.pick(&set).collect::<Vec<_>>(),
            vec![(&5, &50), (&42, &420)]
        );
    }

    /// A set much smaller than the map: set keys are looked up in the map.
    #[test]
    fn pick_large_map_small_set() {
        let map: BTreeMap<i32, i32> = (0..100).map(|k| (k, k * 2)).collect();
        let set = BTreeSet::from([3, 50, 200]);
        assert_eq!(map.pick(&set).size_hint(), (0, Some(3)));
        assert_eq!(
            map.pick(&set).collect::<Vec<_>>(),
            vec![(&3, &6), (&50, &100)]
        );
        assert_eq!(map.pick(&set).min(), Some((&3, &6)));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment