Last active
August 20, 2020 08:29
-
-
Save SkiFire13/12a61302e35ffe52c0337e0959fb15f6 to your computer and use it in GitHub Desktop.
Possible optimization for the GroupingMap::aggregate method using unsafe
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::{HashMap, hash_map::Entry}; | |
use std::hash::Hash; | |
use std::panic::{self, AssertUnwindSafe}; | |
use std::mem::MaybeUninit; | |
/// Wrapper around an iterator of `(key, value)` pairs whose values can be
/// aggregated per key with [`GroupingMap::aggregate`].
#[derive(Debug, Clone)]
pub struct GroupingMap<I> {
    iter: I,
}

impl<I, K, V> GroupingMap<I>
where
    I: Iterator<Item = (K, V)>,
    K: Hash + Eq,
{
    /// Groups the values by key and folds each group with `operation`,
    /// returning the surviving accumulators in a [`HashMap`].
    ///
    /// For every `(key, value)` pair, `operation` is called with:
    ///
    /// * the current accumulator for that key (`None` if the key currently has
    ///   no entry),
    /// * a reference to the key,
    /// * the value.
    ///
    /// If `operation` returns `Some(acc)`, `acc` becomes the new accumulator
    /// for the key. If it returns `None`, the key's entry is removed (or never
    /// created); a later pair with the same key starts again from `None`.
    ///
    /// The accumulator of an existing key is moved out of the map and the
    /// result is written back in place, so each pair costs a single map lookup.
    ///
    /// # Panics
    ///
    /// Panics if `operation` (or the underlying iterator) panics. A panic
    /// raised by `operation` is caught only long enough to restore the map to
    /// a droppable state and is then re-raised with its original payload.
    pub fn aggregate<FO, R>(self, mut operation: FO) -> HashMap<K, R>
    where
        FO: FnMut(Option<R>, &K, V) -> Option<R>,
    {
        let mut destination_map = HashMap::new();

        // Loop invariant: at the start of every iteration the map and every
        // value stored in it are valid and owned by the map.
        for (key, val) in self.iter {
            match destination_map.entry(key) {
                Entry::Occupied(mut occupied_entry) => {
                    let key = occupied_entry.key();

                    // SAFETY: `occupied_entry.get()` is a shared reference, so it
                    // is valid and aligned for reads. The read bitwise-moves the
                    // value out of the map: from here on `acc` owns it, and the
                    // slot holds a stale duplicate that must be overwritten or
                    // removed WITHOUT being dropped. Every path below does exactly
                    // that before the map can be dropped or observed.
                    let acc = unsafe { std::ptr::read(occupied_entry.get()) };

                    // `AssertUnwindSafe` is acceptable here: after a panic
                    // `operation` is never called again by this function and the
                    // payload is re-raised right after the slot is repaired, so no
                    // state left inconsistent by the panic is observed here.
                    let outcome = panic::catch_unwind(AssertUnwindSafe(|| {
                        operation(Some(acc), key, val)
                    }));

                    match outcome {
                        Ok(Some(op_res)) => {
                            // SAFETY: `occupied_entry.get_mut()` is a unique
                            // reference, so it is valid and aligned for writes.
                            // `ptr::write` overwrites the stale duplicate without
                            // dropping it; the map is now valid again.
                            unsafe {
                                std::ptr::write(occupied_entry.get_mut(), op_res);
                            }
                        }
                        Ok(None) => {
                            // Take the entry out and forget the stale duplicate
                            // *before* the stored key is dropped, so that a panic
                            // in the stored key's destructor cannot lead to the
                            // value being dropped a second time. The map is now
                            // valid again.
                            let (stale_key, stale_value) = occupied_entry.remove_entry();
                            std::mem::forget(stale_value);
                            drop(stale_key);
                        }
                        Err(payload) => {
                            // `acc` was moved into `operation` and has already
                            // been dropped by the unwinding, so the slot only
                            // holds the stale duplicate: remove and forget it.
                            let (stale_key, stale_value) = occupied_entry.remove_entry();
                            std::mem::forget(stale_value);
                            drop(stale_key);
                            // The map is valid again and can be dropped safely
                            // while the original panic continues to propagate.
                            panic::resume_unwind(payload);
                        }
                    }
                }
                Entry::Vacant(vacant_entry) => {
                    // No accumulator yet: nothing is moved out of the map, so a
                    // panic in `operation` leaves the map untouched and valid.
                    let key = vacant_entry.key();
                    if let Some(new_value) = operation(None, key, val) {
                        vacant_entry.insert(new_value);
                    }
                }
            }
        }

        destination_map
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment