Skip to content

Instantly share code, notes, and snippets.

@ekzhang
Created July 1, 2024 20:30
Show Gist options
  • Save ekzhang/0db8693ea76dbdf80f693a3470d4ede9 to your computer and use it in GitHub Desktop.
Save ekzhang/0db8693ea76dbdf80f693a3470d4ede9 to your computer and use it in GitHub Desktop.
An incomplete interview question for performance-optimizing Hashlife in Rust
use std::{cell::RefCell, collections::HashMap, sync::Arc};
use md5::{Digest, Md5};
/// A quadtree node: either an interior subtree with four quadrants or a
/// 4x4 leaf block of cells.
///
/// Subtrees are held behind an `Arc` and leaves are `Copy`, so cloning a
/// `Node` never copies a whole region.
#[derive(Clone)]
pub enum Node {
    /// Interior node; its four quadrants are one order smaller than itself.
    Subtree(Arc<Subtree>),
    /// 4x4 block of cells, the base case of the tree (order 2).
    Leaf(Leaf),
}
/// An interior quadtree node: four equally sized quadrants plus metadata
/// computed once when the node is built.
pub struct Subtree {
    /// Top-left quadrant.
    pub nw: Node,
    /// Top-right quadrant.
    pub ne: Node,
    /// Bottom-left quadrant.
    pub sw: Node,
    /// Bottom-right quadrant.
    pub se: Node,
    /// Hash identifying this subtree, derived from the quadrants' hashes.
    pub hash: u128,
    /// Number of live cells in the whole region.
    pub count: u64,
    /// Stores the result of the central half-sized region after 2^k generations.
    /// This starts empty, then it is filled in on-demand and cached.
    ///
    /// NOTE(review): nothing in this file writes to this map, and it cannot be
    /// mutated through a shared `Arc<Subtree>` without interior mutability —
    /// confirm whether it is still meant to be used.
    pub result: HashMap<u8, Node>,
}
/// A 4x4 block of cells, the smallest node in the quadtree.
///
/// The default value is the all-dead block.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Leaf {
    /// Cells of the 4x4 region, packed into a 16-bit integer: cell `(x, y)`
    /// is alive when bit `4 * y + x` is set.
    pub value: u16,
}
impl Node {
    /// Hash identifying this node's contents.
    ///
    /// Subtrees return the hash stored when they were built; a leaf's hash is
    /// a fixed constant XOR-ed with its packed cells.
    pub fn hash(&self) -> u128 {
        match self {
            Self::Subtree(subtree) => subtree.hash,
            Self::Leaf(leaf) => 0xdeadbeefdeadbeef1234567812345678 ^ u128::from(leaf.value),
        }
    }

    /// Number of live cells in this node's region.
    pub fn count(&self) -> u64 {
        match self {
            Self::Subtree(subtree) => subtree.count,
            Self::Leaf(leaf) => u64::from(leaf.value.count_ones()),
        }
    }

    /// Return the log2() of the width and height of the region.
    ///
    /// Leaves are 4x4 (order 2) and each subtree level doubles the side. This
    /// walks down the `nw` spine, so it costs O(order).
    pub fn order(&self) -> u8 {
        match self {
            Self::Subtree(subtree) => 1 + subtree.nw.order(),
            Self::Leaf(_) => 2, // 4x4 base case
        }
    }

    /// Whether the cell at `(x, y)`, relative to this node's top-left corner,
    /// is alive.
    ///
    /// Coordinates are not range-checked: both must be below
    /// `1 << self.order()`, otherwise the final leaf lookup overflows its shift.
    pub fn get_cell(&self, x: u64, y: u64) -> bool {
        match self {
            Self::Subtree(subtree) => {
                let mid = 1 << subtree.nw.order();
                match (x >= mid, y >= mid) {
                    (false, false) => subtree.nw.get_cell(x, y),
                    (true, false) => subtree.ne.get_cell(x - mid, y),
                    (false, true) => subtree.sw.get_cell(x, y - mid),
                    (true, true) => subtree.se.get_cell(x - mid, y - mid),
                }
            }
            // Leaf layout: cell (x, y) is bit 4 * y + x.
            Self::Leaf(leaf) => (leaf.value >> (y * 4 + x)) & 1 != 0,
        }
    }

    /// Borrow the subtree behind this node.
    ///
    /// # Panics
    ///
    /// Panics if the node is a leaf.
    pub fn as_subtree(&self) -> &Subtree {
        match self {
            Self::Subtree(subtree) => subtree,
            _ => panic!("expected a subtree"),
        }
    }

    /// Pack an 8x8 node (a subtree of four leaves) into a `u64`.
    ///
    /// The layout extends the leaf layout used by `get_cell` to 8 columns:
    /// cell `(x, y)` ends up at bit `8 * y + x`.
    ///
    /// # Panics
    ///
    /// Panics if the node is not a subtree whose four children are leaves.
    pub fn as_8x8(&self) -> u64 {
        debug_assert!(self.order() == 3);
        match self.as_subtree() {
            Subtree {
                nw: Node::Leaf(nw),
                ne: Node::Leaf(ne),
                sw: Node::Leaf(sw),
                se: Node::Leaf(se),
                ..
            } => {
                // Spread a leaf's four 4-bit rows onto an 8-bit row stride:
                // row r moves from bits 4r..4r+4 to bits 8r..8r+4.
                let split = |x: u16| -> u64 {
                    u64::from(x & 0x000f)
                        | u64::from(x & 0x00f0) << 4
                        | u64::from(x & 0x0f00) << 8
                        | u64::from(x & 0xf000) << 12
                };
                // East quadrants move 4 columns right, south quadrants 4 rows down.
                split(nw.value) | split(ne.value) << 4 | split(sw.value) << 32 | split(se.value) << 36
            }
            _ => panic!("as_8x8 requires a subtree of four leaves"),
        }
    }
}
/// Given an 8x8 region, produce the next generation of the 6x6 in the center.
///
/// Cells are packed row-major with the least significant bit first: cell
/// `(x, y)` is bit `8 * y + x`. The next state of input cell `(x + 1, y + 1)`
/// is written to output bit `8 * y + x` for `x, y` in `0..6`, so the 6x6
/// result is shifted one cell up and to the left. Output columns 6..8 and
/// rows 6..8 are always zero. Rule: birth on 3 neighbours, survival on 2 or 3.
fn step_8x8(value: u64) -> u64 {
    // Moore neighbourhood of the cell at bit 9 (x = 1, y = 1), centre excluded:
    // row 0 -> bits 0..=2, row 1 -> bits 8 and 10, row 2 -> bits 16..=18.
    const NEIGHBORHOOD: u64 = 0x0007_0507;
    let mut result = 0;
    for row in 0..6 {
        for col in 0..6 {
            let offset = row * 8 + col;
            let neighbors = ((value >> offset) & NEIGHBORHOOD).count_ones();
            let alive = (value >> (offset + 9)) & 1 != 0;
            if neighbors == 3 || (neighbors == 2 && alive) {
                result |= 1 << offset;
            }
        }
    }
    result
}
/// Extract the north-west (top-left) 4x4 quadrant of a packed 8x8 region.
///
/// Input cell `(x, y)` lives at bit `8 * y + x`. For `x, y < 4` it is moved to
/// bit `4 * y + x` of the result, which is the leaf layout used by
/// `Node::get_cell`. All other cells are discarded.
fn extract_8x8_nw(value: u64) -> u16 {
    // Row r occupies bits 8r..8r+4 and must land on bits 4r..4r+4.
    (value & 0x000f
        | (value >> 4) & 0x00f0
        | (value >> 8) & 0x0f00
        | (value >> 12) & 0xf000) as u16
}
/// Produce the hash for a subtree node from the hashes of its children.
///
/// The four child hashes are fed to MD5 in `nw, ne, sw, se` order as
/// little-endian bytes; the 16-byte digest is read back little-endian.
fn combine_hashes(nw: u128, ne: u128, sw: u128, se: u128) -> u128 {
    let mut hasher = Md5::new();
    for child in [nw, ne, sw, se] {
        hasher.update(child.to_le_bytes());
    }
    let digest: [u8; 16] = hasher.finalize().into();
    u128::from_le_bytes(digest)
}
/// A Life pattern on the infinite plane: a quadtree root anchored at a
/// position. Every cell outside the root's square is dead.
#[derive(Clone)]
pub struct Board {
    /// Coordinates of the top-left of the root node's region.
    offset: (i64, i64),
    /// The cells of the grid as a quadtree.
    root: Node,
}
impl Board {
    /// Whether the cell at absolute coordinates `(x, y)` is alive.
    ///
    /// Coordinates outside the root's square always read as dead.
    pub fn get_cell(&self, x: i64, y: i64) -> bool {
        let side = 1_i64 << self.root.order();
        let (left, top) = self.offset;
        let (dx, dy) = (x - left, y - top);
        let inside = (0..side).contains(&dx) && (0..side).contains(&dy);
        // All cells outside of the root are dead.
        inside && self.root.get_cell(dx as u64, dy as u64)
    }
}
pub struct Engine {
    /// Hash-consing table: every subtree the engine has built, keyed by its
    /// hash, so that identical regions are represented by one shared node.
    cache: RefCell<HashMap<u128, Node>>,
    /// Memoized results of `node_step`, keyed by `(node hash, k)`.
    results: RefCell<HashMap<(u128, u8), Node>>,
}
impl Engine {
    /// Create a new simulation engine with empty caches.
    ///
    /// The caches only ever grow; drop the engine to release them.
    pub fn new() -> Self {
        Self {
            cache: RefCell::new(HashMap::new()),
            results: RefCell::new(HashMap::new()),
        }
    }

    /// Simulate 2^k steps of a pattern.
    ///
    /// The board is padded with dead cells until every live cell sits at least
    /// 2^k cells inside the region that `node_step` returns. Cells spread at
    /// most one cell per generation, so nothing the pattern produces is lost.
    ///
    /// # Panics
    ///
    /// Panics if the board's root is a single leaf. Boards built by
    /// `parse_rle` and `step` always have a subtree root.
    pub fn step(&self, board: &Board, k: u8) -> Board {
        let mut board = board.clone();
        let count = board.root.count();
        // Grow until the root is big enough for 2^k generations and the
        // whole population lies inside the root's central half.
        while board.root.order() < k + 2 || self.central(&board.root).count() < count {
            board = self.expand(board);
        }
        // One more expansion makes the old root the new central half. The
        // population is then at least 2^(order - 3) >= 2^k cells from the
        // border of the half that `node_step` returns.
        board = self.expand(board);
        let order = board.root.order();
        let root = self.node_step(&board.root, k);
        // The result is the central half, a quarter of the root's width in.
        let shift = 1_i64 << (order - 2);
        Board {
            offset: (board.offset.0 + shift, board.offset.1 + shift),
            root,
        }
    }

    /// Parse a pattern in run-length encoded (RLE) format.
    ///
    /// The top-left corner of the pattern's bounding box is placed at (0, 0).
    ///
    /// Skipped input:
    /// - lines starting with `#`;
    /// - the `x = ...` header line;
    /// - whitespace and any unrecognized characters in the body.
    ///
    /// In the body, each tag may be preceded by an optional decimal run count:
    /// - `b` (or `.`) is a run of dead cells;
    /// - `$` ends the current row;
    /// - `!` ends the pattern;
    /// - any other ASCII letter (normally `o`) is a run of live cells.
    ///
    /// The header's rule is not checked; the engine always simulates B3/S23.
    pub fn parse_rle(&self, pattern: &str) -> Board {
        let cells = Self::rle_cells(pattern);
        let extent = cells.iter().map(|&(x, y)| x.max(y) + 1).max().unwrap_or(0);
        // Start at order 3 so the root is always a subtree, which `expand`
        // and `central` require.
        let mut order: u8 = 3;
        while (1_u64 << order) < extent {
            order += 1;
        }
        Board {
            offset: (0, 0),
            root: self.build(order, &cells),
        }
    }

    /// Decode the body of an RLE pattern into the coordinates of its live cells.
    fn rle_cells(pattern: &str) -> Vec<(u64, u64)> {
        let mut cells = Vec::new();
        let (mut x, mut y, mut run) = (0_u64, 0_u64, 0_u64);
        'lines: for line in pattern.lines() {
            let line = line.trim();
            let is_header =
                matches!(line.strip_prefix('x'), Some(rest) if rest.trim_start().starts_with('='));
            if line.starts_with('#') || is_header {
                continue;
            }
            for ch in line.chars() {
                match ch {
                    '0'..='9' => {
                        run = run.saturating_mul(10).saturating_add(u64::from(ch as u8 - b'0'));
                    }
                    '!' => break 'lines,
                    '$' => {
                        y += run.max(1);
                        x = 0;
                        run = 0;
                    }
                    'b' | '.' => {
                        x += run.max(1);
                        run = 0;
                    }
                    c if c.is_ascii_alphabetic() => {
                        let len = run.max(1);
                        cells.extend((x..x + len).map(|cx| (cx, y)));
                        x += len;
                        run = 0;
                    }
                    // Whitespace and unrecognized characters are ignored.
                    _ => {}
                }
            }
        }
        cells
    }

    /// Build a node of the given order containing exactly `cells`.
    ///
    /// Coordinates are relative to the node's top-left corner and must be
    /// below `1 << order`.
    fn build(&self, order: u8, cells: &[(u64, u64)]) -> Node {
        if cells.is_empty() {
            return self.zeros(order);
        }
        if order == 2 {
            // Leaf layout: cell (x, y) is bit 4 * y + x.
            let value = cells.iter().fold(0_u16, |acc, &(x, y)| acc | 1 << (y * 4 + x));
            return Node::Leaf(Leaf { value });
        }
        let half = 1_u64 << (order - 1);
        let mut quadrants: [Vec<(u64, u64)>; 4] = Default::default();
        for &(x, y) in cells {
            // Index order matches [nw, ne, sw, se].
            let index = usize::from(x >= half) + 2 * usize::from(y >= half);
            quadrants[index].push((x % half, y % half));
        }
        let [nw, ne, sw, se] = quadrants.map(|quadrant| self.build(order - 1, &quadrant));
        self.subtree(&nw, &ne, &sw, &se)
    }

    /// Double the root's side length and keep the old root as the central half.
    ///
    /// # Panics
    ///
    /// Panics if the root is a leaf.
    fn expand(&self, board: Board) -> Board {
        let sub = board.root.as_subtree();
        let k = sub.nw.order();
        let zeros = self.zeros(k);
        // The old root moves one quadrant width (2^k) in from the new corner.
        let margin = 1_i64 << k;
        let root = self.subtree(
            &self.subtree(&zeros, &zeros, &zeros, &sub.nw),
            &self.subtree(&zeros, &zeros, &sub.ne, &zeros),
            &self.subtree(&zeros, &sub.sw, &zeros, &zeros),
            &self.subtree(&sub.se, &zeros, &zeros, &zeros),
        );
        Board {
            offset: (board.offset.0 - margin, board.offset.1 - margin),
            root,
        }
    }

    /// Create (or reuse from the cache) a subtree out of four child nodes.
    fn subtree(&self, nw: &Node, ne: &Node, sw: &Node, se: &Node) -> Node {
        let hash = combine_hashes(nw.hash(), ne.hash(), sw.hash(), se.hash());
        self.cache
            .borrow_mut()
            .entry(hash)
            .or_insert_with(|| {
                Node::Subtree(Arc::new(Subtree {
                    nw: nw.clone(),
                    ne: ne.clone(),
                    sw: sw.clone(),
                    se: se.clone(),
                    hash,
                    count: nw.count() + ne.count() + sw.count() + se.count(),
                    result: HashMap::new(),
                }))
            })
            .clone()
    }

    /// An all-dead node of the given order (`order >= 2`).
    fn zeros(&self, order: u8) -> Node {
        if order == 2 {
            Node::Leaf(Leaf { value: 0 })
        } else {
            let sub = self.zeros(order - 1);
            self.subtree(&sub, &sub, &sub, &sub)
        }
    }

    /// The node of the same order centred horizontally between two
    /// side-by-side nodes `w` and `e`.
    fn hstack(&self, w: &Node, e: &Node) -> Node {
        let w = w.as_subtree();
        let e = e.as_subtree();
        self.subtree(&w.ne, &e.nw, &w.se, &e.sw)
    }

    /// The node of the same order centred vertically between two stacked
    /// nodes `n` and `s`.
    fn vstack(&self, n: &Node, s: &Node) -> Node {
        let n = n.as_subtree();
        let s = s.as_subtree();
        self.subtree(&n.sw, &n.se, &s.nw, &s.ne)
    }

    /// The central half of `node`: the node one order smaller whose region
    /// sits in the middle of `node`'s region.
    ///
    /// For an 8x8 node, whose children are leaves, the 4x4 centre is
    /// assembled directly from the inner 2x2 corners of the four leaves.
    fn central(&self, node: &Node) -> Node {
        let sub = node.as_subtree();
        if let (Node::Leaf(nw), Node::Leaf(ne), Node::Leaf(sw), Node::Leaf(se)) =
            (&sub.nw, &sub.ne, &sub.sw, &sub.se)
        {
            // Leaf cell (x, y) is bit 4 * y + x. Each leaf's inner 2x2 corner
            // moves to the matching corner of the result.
            let value = (nw.value & 0xcc00) >> 10
                | (ne.value & 0x3300) >> 6
                | (sw.value & 0x00cc) << 6
                | (se.value & 0x0033) << 10;
            return Node::Leaf(Leaf { value });
        }
        self.subtree(
            &sub.nw.as_subtree().se,
            &sub.ne.as_subtree().sw,
            &sub.sw.as_subtree().ne,
            &sub.se.as_subtree().nw,
        )
    }

    /// Step through 2^k generations of a node in the quadtree, returning the central half.
    ///
    /// Results are memoized per `(node hash, k)`. This reuse of identical
    /// sub-computations is what makes Hashlife fast.
    fn node_step(&self, node: &Node, k: u8) -> Node {
        // The node must be at least 2^(k+2) x 2^(k+2) in size to simulate 2^k generations.
        debug_assert!(node.order() >= k + 2);
        let key = (node.hash(), k);
        // Release the borrow before recursing.
        let cached = self.results.borrow().get(&key).cloned();
        if let Some(result) = cached {
            return result;
        }
        let result = self.node_step_uncached(node, k);
        self.results.borrow_mut().insert(key, result.clone());
        result
    }

    /// Compute `node_step` without consulting the memo table.
    fn node_step_uncached(&self, node: &Node, k: u8) -> Node {
        // Base case: 8x8 region (k is 0 or 1 here).
        if node.order() == 3 {
            let value = node.as_8x8();
            return Node::Leaf(Leaf {
                value: if k == 0 {
                    // One generation; realign so the 4x4 centre starts at bit 0.
                    extract_8x8_nw(step_8x8(value) >> 9)
                } else {
                    debug_assert!(k == 1);
                    extract_8x8_nw(step_8x8(step_8x8(value)))
                },
            });
        }
        let sub = node.as_subtree();
        if node.order() == k + 2 {
            // Recursive case 1: the region is exactly 2^(k+2) x 2^(k+2). Advance nine
            // overlapping sub-squares by 2^(k-1), recombine them into four, then
            // advance those by another 2^(k-1).
            let s00 = self.node_step(&sub.nw, k - 1);
            let s10 = self.node_step(&self.hstack(&sub.nw, &sub.ne), k - 1);
            let s20 = self.node_step(&sub.ne, k - 1);
            let s01 = self.node_step(&self.vstack(&sub.nw, &sub.sw), k - 1);
            let s11 = self.node_step(&self.central(node), k - 1);
            let s21 = self.node_step(&self.vstack(&sub.ne, &sub.se), k - 1);
            let s02 = self.node_step(&sub.sw, k - 1);
            let s12 = self.node_step(&self.hstack(&sub.sw, &sub.se), k - 1);
            let s22 = self.node_step(&sub.se, k - 1);
            self.subtree(
                &self.node_step(&self.subtree(&s00, &s10, &s01, &s11), k - 1),
                &self.node_step(&self.subtree(&s10, &s20, &s11, &s21), k - 1),
                &self.node_step(&self.subtree(&s01, &s11, &s02, &s12), k - 1),
                &self.node_step(&self.subtree(&s11, &s21, &s12, &s22), k - 1),
            )
        } else {
            // Recursive case 2: the region is larger than 2^(k+2) x 2^(k+2). Take the
            // centres of nine overlapping sub-squares (no time passes), recombine
            // them into four, then advance each by the full 2^k.
            let s00 = self.central(&sub.nw);
            let s10 = self.central(&self.hstack(&sub.nw, &sub.ne));
            let s20 = self.central(&sub.ne);
            let s01 = self.central(&self.vstack(&sub.nw, &sub.sw));
            let s11 = self.central(&self.central(node));
            let s21 = self.central(&self.vstack(&sub.ne, &sub.se));
            let s02 = self.central(&sub.sw);
            let s12 = self.central(&self.hstack(&sub.sw, &sub.se));
            let s22 = self.central(&sub.se);
            self.subtree(
                &self.node_step(&self.subtree(&s00, &s10, &s01, &s11), k),
                &self.node_step(&self.subtree(&s10, &s20, &s11, &s21), k),
                &self.node_step(&self.subtree(&s01, &s11, &s02, &s12), k),
                &self.node_step(&self.subtree(&s11, &s21, &s12, &s22), k),
            )
        }
    }
}
impl Default for Engine {
fn default() -> Self {
Self::new()
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment