// GitHub Gist by @0xqd, created April 28, 2024 21:02
// Gist: 0xqd/4a1f1b3355ecccc4e4b5debe2bd1bb5b
// Title: Rust concurrent hashmap
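// Lesson: every method takes `&self` (not `&mut self`), so the map can be shared across threads
// through an `Arc` or a plain shared reference without any extra locking around it.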
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{PoisonError, RwLock};
/// A single stored value.
///
/// Within this file the inner lock is never contended. `get` reads it only while holding the
/// shard's read lock. `insert` holds the shard's write lock and bypasses it via `RwLock::get_mut`.
#[derive(Debug)]
struct Entry(RwLock<i32>);

/// A sharded concurrent map from `i32` keys to `i32` values.
///
/// Keys are spread over a fixed number of shards, each guarded by its own `RwLock`. Operations on
/// keys in different shards therefore never contend. Every method takes `&self`, so the map can
/// be shared across threads by reference (e.g. `std::thread::scope`) or through an `Arc`.
///
/// The shard is chosen by reducing the key itself modulo the shard count instead of hashing it.
/// A generic version would hash the key and reduce the hash instead.
#[derive(Debug)]
struct ConcurrentHashMap {
    /// One lock-protected map per shard. The length is fixed at construction and is never zero.
    inner: Box<[RwLock<HashMap<i32, Entry>>]>,
}

impl ConcurrentHashMap {
    /// Number of shards used by [`ConcurrentHashMap::new`].
    const DEFAULT_SHARDS: usize = 16;

    /// Creates an empty map with 16 shards.
    pub fn new() -> Self {
        Self::with_shard_count(Self::DEFAULT_SHARDS)
    }

    /// Creates an empty map split into `shard_count` independently locked shards.
    ///
    /// # Panics
    ///
    /// Panics if `shard_count` is zero, since every key must map to some shard.
    pub fn with_shard_count(shard_count: usize) -> Self {
        assert!(shard_count > 0, "shard_count must be at least 1");
        let inner = (0..shard_count)
            .map(|_| RwLock::new(HashMap::new()))
            .collect();
        Self { inner }
    }

    /// Returns the shard responsible for `key`.
    fn shard(&self, key: i32) -> &RwLock<HashMap<i32, Entry>> {
        // Reinterpret the key's bits as unsigned so negative keys land in a well-defined,
        // platform-independent shard. `u32 -> usize` is lossless on 32- and 64-bit targets.
        let idx = key as u32 as usize % self.inner.len();
        &self.inner[idx]
    }

    /// Inserts `value` under `key`, overwriting any previous value.
    ///
    /// Blocks until the key's shard can be locked for writing, so a concurrent insert is never
    /// silently dropped. A poisoned lock is recovered with `PoisonError::into_inner` rather than
    /// skipping the write; this type's own methods do not panic while holding a lock.
    pub fn insert(&self, key: i32, value: i32) {
        let mut shard = self.shard(key).write().unwrap_or_else(PoisonError::into_inner);
        shard
            .entry(key)
            // The shard write lock gives us `&mut Entry`, so the existing value is updated in
            // place through `RwLock::get_mut` (no inner locking, no reallocation). This is a
            // single hash lookup in both the present and absent cases.
            .and_modify(|entry| {
                *entry.0.get_mut().unwrap_or_else(PoisonError::into_inner) = value;
            })
            .or_insert_with(|| Entry(RwLock::new(value)));
    }

    /// Returns the value stored under `key`, or `None` if the key is absent.
    ///
    /// Blocks while a writer holds the key's shard. A present key is therefore never reported
    /// as missing just because the shard was busy. Poisoned locks are recovered as in `insert`.
    pub fn get(&self, key: i32) -> Option<i32> {
        let shard = self.shard(key).read().unwrap_or_else(PoisonError::into_inner);
        shard
            .get(&key)
            .map(|entry| *entry.0.read().unwrap_or_else(PoisonError::into_inner))
    }
}

impl Default for ConcurrentHashMap {
    /// Equivalent to [`ConcurrentHashMap::new`].
    fn default() -> Self {
        Self::new()
    }
}
// Non-goal: implementing the standard collection traits (`Extend`, `FromIterator`, `IntoIterator`, ...).
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_single_thread() {
        // `insert` takes `&self`, so the binding does not need to be `mut`.
        let chm = ConcurrentHashMap::new();
        chm.insert(1, 2);
        chm.insert(2, 3);
        chm.insert(3, 4);
        assert_eq!(chm.get(1), Some(2));
        assert_eq!(chm.get(2), Some(3));
        assert_eq!(chm.get(3), Some(4));
    }

    #[test]
    fn test_missing_and_overwrite() {
        let chm = ConcurrentHashMap::new();
        assert_eq!(chm.get(42), None);
        chm.insert(42, 1);
        chm.insert(42, 2);
        assert_eq!(chm.get(42), Some(2));
    }

    #[test]
    fn test_two_threads() {
        let chm = ConcurrentHashMap::new();
        thread::scope(|s| {
            s.spawn(|| {
                chm.insert(1, 2);
                chm.insert(2, 3);
                chm.insert(3, 4);
            });
            s.spawn(|| {
                chm.insert(4, 5);
                chm.insert(5, 6);
                chm.insert(6, 7);
            });
        });
        assert_eq!(chm.get(1), Some(2));
        assert_eq!(chm.get(2), Some(3));
        assert_eq!(chm.get(3), Some(4));
        assert_eq!(chm.get(4), Some(5));
        assert_eq!(chm.get(5), Some(6));
        assert_eq!(chm.get(6), Some(7));
    }

    /// Many threads writing keys that share one shard: every insert must land.
    /// This guards against writes being skipped when the shard lock is contended.
    #[test]
    fn test_contended_shard() {
        let chm = ConcurrentHashMap::new();
        thread::scope(|s| {
            for t in 0..8 {
                let chm = &chm;
                s.spawn(move || {
                    for i in 0..500 {
                        // Non-negative multiples of 16 share a shard when the map uses 16
                        // shards and picks the shard as key modulo the shard count.
                        let key = (t * 500 + i) * 16;
                        chm.insert(key, key + 1);
                    }
                });
            }
        });
        for key in (0..8 * 500).map(|k| k * 16) {
            assert_eq!(chm.get(key), Some(key + 1));
        }
    }
}
// Commenting on this gist requires a GitHub account (sign up, or sign in if you already have one).