Skip to content

Instantly share code, notes, and snippets.

@SkiFire13
Last active August 20, 2020 08:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SkiFire13/12a61302e35ffe52c0337e0959fb15f6 to your computer and use it in GitHub Desktop.
Save SkiFire13/12a61302e35ffe52c0337e0959fb15f6 to your computer and use it in GitHub Desktop.
Possible optimization for the GroupingMap::aggregate method using unsafe
use std::collections::{HashMap, hash_map::Entry};
use std::hash::Hash;
use std::panic::{self, AssertUnwindSafe};
use std::mem::MaybeUninit;
pub struct GroupingMap<I> {
iter: I,
}
impl<I, K, V> GroupingMap<I>
where
I: Iterator<Item = (K, V)>,
K: Hash + Eq,
{
pub fn aggregate<FO, R>(self, mut operation: FO) -> HashMap<K, R>
where FO: FnMut(Option<R>, &K, V) -> Option<R>,
{
let mut destination_map = HashMap::new();
// Loop invariant: The map and every element inside it is in a valid state.
for (key, val) in self.iter {
match destination_map.entry(key) {
Entry::Occupied(mut occupied_entry) => {
let key = occupied_entry.key();
// Safety: `occupied_entry.get()` is a reference so it's valid for reads.
// This is intended to move the value out of the entry. From now on `acc` is responsible
// for dropping the value. We're responsible for removing/overwriting the old value without
// dropping it (which could lead to double frees). Until that moment the map is in an invalid state.
let acc = unsafe { std::ptr::read(occupied_entry.get()) };
// TODO: Not sure if it is ok to use `AssertUnwindSafe`.
match panic::catch_unwind(AssertUnwindSafe(|| operation(Some(acc), key, val))) {
// `ptr::write` overwrites the old value without dropping it. The map is not valid again.
Ok(Some(op_res)) => {
// Safety: `occupied_entry.get_mut()` is a mutable reference so it's valid for writes.
unsafe { std::ptr::write(occupied_entry.get_mut(), op_res); }
}
// `remove` the old value but doesn't drop it using `mem::forget`. The map is not valid again.
Ok(None) => {
std::mem::forget(occupied_entry.remove());
}
// `remove` the old value but doesn't drop it using `mem::forget`. The map is not valid again.
Err(payload) => {
std::mem::forget(occupied_entry.remove());
// This runs after the removal and forget of the old value,
// so the map is valid and can be safely dropped.
panic::resume_unwind(payload);
}
}
}
Entry::Vacant(vacant_entry) => {
let key = vacant_entry.key();
if let Some(new_value) = operation(None, key, val) {
vacant_entry.insert(new_value);
}
}
}
}
destination_map
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment