Created
April 28, 2024 21:02
-
-
Save 0xqd/4a1f1b3355ecccc4e4b5debe2bd1bb5b to your computer and use it in GitHub Desktop.
Rust concurrent hashmap
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Lesson: the methods take `&self`, so the map can be shared between threads (through `Arc` or scoped-thread borrows) without an outer lock.
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{PoisonError, RwLock};
/// A single stored value, wrapped in its own reader/writer lock.
struct Entry(
    /// The guarded value.
    RwLock<i32>,
);
// This example skips hashing: an `i32` key is cast straight to `usize` to pick its shard. A real implementation would hash the key first.
struct ConcurrentHashMap { | |
inner: Box<Vec<RwLock<HashMap<i32, Entry>>>>, | |
shard_count: usize, | |
} | |
impl ConcurrentHashMap { | |
pub fn new() -> Self { | |
let shard_count = 16; | |
let shards = (0..shard_count) | |
.map(|_| RwLock::new(HashMap::new())) | |
.collect(); | |
ConcurrentHashMap { | |
shard_count, | |
inner: Box::new(shards) // default to 16 shards | |
} | |
} | |
/// insert | |
pub fn insert(&self, key: i32, value: i32) { | |
let shard_idx = key as usize % self.shard_count; | |
let shard = &self.inner[shard_idx]; | |
if let Ok(mut shard) = shard.try_write() { | |
// we can just use shard.insert to replace, but this is better since it avoids memory allocator | |
// by reusing the memory. | |
if !shard.contains_key(&key) { | |
shard.insert(key, Entry(RwLock::new(value))); | |
} else { | |
if let Some(entry) = shard.get(&key) { | |
if let Ok(mut entry) = entry.0.try_write() { | |
*entry = value; | |
} | |
} | |
} | |
} | |
} | |
/// get | |
pub fn get(&self, key: i32) -> Option<i32> { | |
// simplify shard finder, in realworld, we use hash function cast to usize to find the shard | |
let shard_idx = key as usize % self.shard_count; | |
let shard = &self.inner[shard_idx]; | |
if let Ok(shard) = shard.try_read() { | |
if let Some(entry) = shard.get(&key) { | |
if let Ok(entry) = entry.0.try_read() { | |
return Some(*entry); | |
} | |
} | |
} | |
None | |
} | |
} | |
// Non-goal: implementing the standard collection traits.
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// Values inserted from a single thread can be read back by key.
    #[test]
    fn test_single_thread() {
        // `insert` takes `&self`, so the binding does not need `mut`.
        let chm = ConcurrentHashMap::new();
        chm.insert(1, 2);
        chm.insert(2, 3);
        chm.insert(3, 4);
        assert_eq!(chm.get(1), Some(2));
        assert_eq!(chm.get(2), Some(3));
        assert_eq!(chm.get(3), Some(4));
    }

    /// Inserts from two scoped threads (disjoint keys) are all visible once both
    /// threads have been joined at the end of the scope.
    #[test]
    fn test_two_threads() {
        let chm = ConcurrentHashMap::new();
        thread::scope(|s| {
            s.spawn(|| {
                chm.insert(1, 2);
                chm.insert(2, 3);
                chm.insert(3, 4);
            });
            s.spawn(|| {
                chm.insert(4, 5);
                chm.insert(5, 6);
                chm.insert(6, 7);
            });
        });
        assert_eq!(chm.get(1), Some(2));
        assert_eq!(chm.get(2), Some(3));
        assert_eq!(chm.get(3), Some(4));
        assert_eq!(chm.get(4), Some(5));
        assert_eq!(chm.get(5), Some(6));
        assert_eq!(chm.get(6), Some(7));
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment