Search is done as a BFS. The board is stored as a u32 with 3 bits per square.
After calculating each move, the states for the next depth are stored in a hashmap so equivalent states are merged together, keeping a count of how many times each state has been reached.
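As a minimal sketch of that outer loop (not the author's code - the two-successor rule below is a toy stand-in for the real move generation, just to show how states reached by different paths merge into one entry with a combined count):

```rust
use std::collections::HashMap;

// Depth-by-depth search sketch. Equivalent states reached by different
// paths collapse into one hashmap entry whose count tracks how many
// times that state was reached.
fn search(start: u32, depth: usize) -> HashMap<u32, u64> {
    let mut current: HashMap<u32, u64> = HashMap::new();
    current.insert(start, 1);
    for _ in 0..depth {
        let mut next: HashMap<u32, u64> = HashMap::new();
        for (&state, &count) in &current {
            // toy successors: two moves per state, wrapping in 0..4
            for succ in [(state + 1) % 4, (state + 2) % 4] {
                *next.entry(succ).or_insert(0) += count;
            }
        }
        current = next;
    }
    current
}
```

After two depths from state 0 the states 1→3 and 2→3 merge, so 3 carries a count of 2 while the total count across all states is still 4.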
Some of my code snippets are from different versions so may not entirely match up with each other. Some snippets are no longer separate functions in my code (lots of reorganising to reduce loads / stores and reuse data where possible), so it's possible some of the types / function signatures aren't correct. Almost everything is in unsafe blocks due to the use of x64 compiler intrinsics. I haven't included any of the code for generating the lookup tables - nothing complicated there, but it's an ugly mess of macros due to restrictions on what is allowed in const Rust code (e.g. no for loops).
The final version uses multiple lookup tables based on the state of up to 4 squares to speed up moves, rotations and final scoring. I make use of rotations and a vertical flip to combine the 8 equivalent states into a single canonical state (keeping 8 separate counts so the final hash can still be calculated). I also use AVX intrinsics to perform operations on 8 values at the same time.
Fairly early on I added hardcoded answers for the empty board (only possible as an input, as every move leaves at least one dice) so I didn't have to handle it anywhere in my code. This allowed me to use 0 as an invalid / unexplored value rather than needing any extra initialisation when setting up hashtables / lookup tables (this didn't end up being useful, but as I'd done the hardcoding it stayed in there). It also means that on submit, test case 8 was solved in 0 ms, so some of my scores look slightly better than they really are (my final version was getting around 22 on submit without the hardcodes).
Replace the default Rust hash function with xorshift* (https://en.wikipedia.org/wiki/Xorshift#xorshift*), as the default Rust hash is slow (for good reasons that aren't important here). We're turning a 27 bit value into a 64 bit hash, so collisions should be almost impossible and the data can't be crafted to cause excess collisions.
Having a more random bit layout is still desirable due to how the hashtables are implemented, so we can't just use the 27 bit value as the hash even though that would be unique. Using the 27 bit value as an array index is too slow (even if we convert it down to the minimal base 7 form, where there are only 40 million states).
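A sketch of what such a hash looks like, assuming the standard xorshift64* shift constants and multiplier from the linked article (the exact constants used in the solver aren't shown here):

```rust
// One xorshift64* step: the 27-bit board state seeds the generator,
// is stepped once, then multiplied by the xorshift* constant so the
// high bits (used for bucket selection) are well mixed.
// Note: an input of 0 maps to 0, which is fine here because the
// empty board is hardcoded and never hashed.
fn hash_state(state: u32) -> u64 {
    let mut x = state as u64;
    x ^= x >> 12;
    x ^= x << 25;
    x ^= x >> 27;
    x.wrapping_mul(0x2545F4914F6CDD1D)
}
```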
PEXT (https://www.felixcloutier.com/x86/pext) allows you to extract specific bits from an integer based on a mask and end up with those bits all adjacent in the new value. This allows me to quickly get a 12 bit value based on the state of any combination of 4 squares (for squares without 4 neighbours the higher bits will never be set).
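To illustrate what PEXT computes, here's a portable software equivalent (the real code uses the single `_pext_u32` intrinsic from `core::arch::x86_64`, which requires the BMI2 target feature):

```rust
// Software equivalent of PEXT: walk the mask's set bits from lowest
// to highest, copying the corresponding bits of `value` into the
// low bits of the result.
fn pext_u32(value: u32, mut mask: u32) -> u32 {
    let mut result = 0;
    let mut out_bit = 0;
    while mask != 0 {
        let lowest = mask & mask.wrapping_neg(); // isolate lowest set mask bit
        if value & lowest != 0 {
            result |= 1 << out_bit;
        }
        out_bit += 1;
        mask &= mask - 1; // clear that mask bit
    }
    result
}
```

For example, extracting squares 1 and 3 (3 bits each, at bit offsets 3 and 9) packs their values into a single 6 bit index: `pext_u32(board, (7 << 3) | (7 << 9))`.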
I use this to create a lookup table with 4096 entries for each square, with the move precalculated so a new state can be computed with a single xor (clearing any captured squares and adding the newly created value). I did eventually rewrite this to use the real size per square (64/512/4096 entries for 2/3/4 neighbours), but it didn't help very much and it's much easier to explain if everything just uses the 12 bit values.
For every combination of square (0-8) and lookup value (0-4095) there are 12 values stored - up to 11 possible moves, plus the number of possible moves.
const fn get_mask(pos: usize) -> u32 {
    7 << (3 * pos)
}
const POS_MASKS: [u32; 9] = [
    get_mask(1) | get_mask(3),
    get_mask(0) | get_mask(2) | get_mask(4),
    get_mask(1) | get_mask(5),
    get_mask(0) | get_mask(6) | get_mask(4),
    get_mask(1) | get_mask(3) | get_mask(5) | get_mask(7),
    get_mask(2) | get_mask(8) | get_mask(4),
    get_mask(3) | get_mask(7),
    get_mask(4) | get_mask(6) | get_mask(8),
    get_mask(5) | get_mask(7),
];
static PRE_MOVES: [[[u32; 12]; 4096]; 9] = precalc_moves();
fn move_lookup(pos: usize, adjacent_values: usize) -> &'static [u32] {
    // look up in the global array and return a list of possible moves
    // for this square given the supplied adjacent_values
    let num_moves = PRE_MOVES[pos][adjacent_values][0] as usize;
    &PRE_MOVES[pos][adjacent_values][1..=num_moves]
}

Then for every state the next states could be calculated quickly:
fn get_next_states(state: u32, count: u32, next: &mut HashMap<u32, u32>) {
    for pos in 0..9 {
        if (state & get_mask(pos)) == 0 {
            let possible_moves = move_lookup(pos, _pext_u32(state, POS_MASKS[pos]) as usize);
            for &st_mod in possible_moves {
                // clear captured squares and add the newly created value
                let new_state = state ^ st_mod;
                // add new_state + count to the hashmap, merging counts
                *next.entry(new_state).or_insert(0) += count;
            }
        }
    }
}

After each move, calculate all 8 possible equivalent boards (4 rotations, and the same 4 rotations after a vertical mirror). Find the minimum of these 8 states to use as the canonical state. Instead of storing a single count for each state, store 8 counts for each canonical state. The input counts array also needs reordering for each rotation / flip that occurs to get from the current state to the canonical version, before adding to the total counts.
Noticing there were going to be 8 states, I used AVX to calculate all 8 rotations at once.
// 0 1 2   Rotated: 2 5 8   Flipped: 2 1 0
// 3 x 5            1 x 7            5 x 3
// 6 7 8            0 3 6            6 7 0
const LAYOUT: [[u32; 8]; 8] = [
    [0, 6, 8, 2, 2, 0, 6, 8], // where digit 0 ends up in each possible combination
    [1, 3, 7, 5, 1, 3, 7, 5],
    [2, 0, 6, 8, 0, 6, 8, 2],
    [3, 7, 5, 1, 5, 1, 3, 7],
    [5, 1, 3, 7, 3, 7, 5, 1],
    [6, 8, 2, 0, 8, 2, 0, 6],
    [7, 5, 1, 3, 7, 5, 1, 3],
    [8, 2, 0, 6, 6, 8, 2, 0],
];
// everything in LAYOUT but multiplied by 3 for shifting purposes;
// this is generated at compile time from LAYOUT
const LAYOUT_TIMES_3: [[u32; 8]; 8] = ...;
// for scoring purposes we have everything from LAYOUT as 10^x
const LAYOUT_AS_10_POW: [[u32; 8]; 8] = ...;

Then to convert each state to an array of 8 equivalent states:
fn calc_rotations(new_state: u32) -> __m256i {
    let inpav = _mm256_set1_epi32(new_state as i32);
    // digit 4 is always fixed
    let mut rotations = _mm256_and_si256(inpav, _mm256_set1_epi32(get_mask(4) as i32));
    // for each of the remaining digits
    for i in 0..8 {
        let shifted = _mm256_srlv_epi32(inpav, _mm256_set1_epi32(LAYOUT_TIMES_3[i][0] as i32));
        let digits = _mm256_and_si256(shifted, _mm256_set1_epi32(7));
        let new_layout = _mm256_loadu_si256(LAYOUT_TIMES_3[i].as_ptr() as *const __m256i);
        let new_pos = _mm256_sllv_epi32(digits, new_layout);
        rotations = _mm256_or_si256(rotations, new_pos);
    }
    return rotations;
}
const ROTATE: [[u32; 8]; 8] = [
    [0, 1, 2, 3, 4, 5, 6, 7],
    [1, 2, 3, 0, 7, 4, 5, 6],
    [2, 3, 0, 1, 6, 7, 4, 5],
    [3, 0, 1, 2, 5, 6, 7, 4],
    [4, 5, 6, 7, 0, 1, 2, 3],
    [5, 6, 7, 4, 3, 0, 1, 2],
    [6, 7, 4, 5, 2, 3, 0, 1],
    [7, 4, 5, 6, 1, 2, 3, 0],
];
fn calc_canonical(state: u32, counts: [u32; 8]) -> (u32, [u32; 8]) {
    let mut states = [0u32; 8];
    let rotated = calc_rotations(state);
    _mm256_storeu_si256(states.as_mut_ptr() as *mut __m256i, rotated);
    // find the minimum of the 8 equivalent states to use as canonical
    let mut best = states[0];
    let mut bi = 0;
    for i in 1..8 {
        if states[i] < best {
            best = states[i];
            bi = i;
        }
    }
    // reorder the counts to match the rotation / flip that produced
    // the canonical state
    let avx_count = _mm256_loadu_si256(counts.as_ptr() as *const __m256i);
    let rotate_avx = _mm256_loadu_si256(ROTATE[bi].as_ptr() as *const __m256i);
    let rotated_counts = _mm256_permutevar8x32_epi32(avx_count, rotate_avx);
    _mm256_storeu_si256(states.as_mut_ptr() as *mut __m256i, rotated_counts);
    return (best, states);
}

Calculating the final hash given a single state and its 8 counts is similar, but unfortunately needs lots of multiplications (integer multiplication in AVX isn't as fast as you'd expect, though not as slow as not using AVX):
fn calc_score(st: u32, counts: [u32; 8]) -> __m256i {
    let avx_count = _mm256_loadu_si256(counts.as_ptr() as *const __m256i);
    let inpav = _mm256_set1_epi32(st as i32);
    // extract the fixed centre digit and score it
    let centre = _mm256_srli_epi32(
        _mm256_and_si256(inpav, _mm256_set1_epi32(get_mask(4) as i32)),
        4 * 3,
    );
    let mut avx_sum = _mm256_mullo_epi32(centre, _mm256_set1_epi32(SCORE_MULT[4] as i32));
    for i in 0..8 {
        let shifted = _mm256_srlv_epi32(inpav, _mm256_set1_epi32(LAYOUT_TIMES_3[i][0] as i32));
        let digits = _mm256_and_si256(shifted, _mm256_set1_epi32(7));
        let mul = _mm256_loadu_si256(LAYOUT_AS_10_POW[i].as_ptr() as *const __m256i);
        let val = _mm256_mullo_epi32(digits, mul);
        avx_sum = _mm256_add_epi32(avx_sum, val);
    }
    return _mm256_mullo_epi32(avx_sum, avx_count);
}

Looking at the layout table above, there are clearly two groups of four positions that don't interact with each other when rotating / mirroring (the 4 corners and the 4 other squares on the edges). Using PEXT in a similar way to the move lookup, I made some more lookup tables to precalculate xor state modifiers and final scores after rotations for these groups.
Each table has 4096 entries of 8 uint32 values.
Now my rotation code becomes:
static ROTATE_CORNERS: [[u32; 8]; 4096] = precalc_corners();
static ROTATE_EDGES: [[u32; 8]; 4096] = precalc_edges();
const CORNERS: u32 = get_mask(0) | get_mask(2) | get_mask(6) | get_mask(8);
const EDGES: u32 = get_mask(1) | get_mask(3) | get_mask(5) | get_mask(7);
fn calc_rotations(new_state: u32) -> __m256i {
    let corners = _pext_u32(new_state, CORNERS) as usize;
    let edges = _pext_u32(new_state, EDGES) as usize;
    let corner_xor = _mm256_loadu_si256(ROTATE_CORNERS[corners].as_ptr() as *const __m256i);
    let edges_xor = _mm256_loadu_si256(ROTATE_EDGES[edges].as_ptr() as *const __m256i);
    let state = _mm256_set1_epi32(new_state as i32);
    let rotations = _mm256_xor_si256(_mm256_xor_si256(state, corner_xor), edges_xor);
    return rotations;
}

And for scoring we can do something similar:
static ROTATE_CORNER_SCORES: [[u32; 8]; 4096] = precalc_corner_scores();
static ROTATE_EDGE_SCORES: [[u32; 8]; 4096] = precalc_edge_scores();
fn calc_score(st: u32, counts: [u32; 8]) -> __m256i {
    let avx_count = _mm256_loadu_si256(counts.as_ptr() as *const __m256i);
    let corners = _pext_u32(st, CORNERS) as usize;
    let edges = _pext_u32(st, EDGES) as usize;
    let corner_score = _mm256_loadu_si256(ROTATE_CORNER_SCORES[corners].as_ptr() as *const __m256i);
    let edges_score = _mm256_loadu_si256(ROTATE_EDGE_SCORES[edges].as_ptr() as *const __m256i);
    let center_score = _mm256_set1_epi32((_pext_u32(st, get_mask(4)) * SCORE_MULT[4]) as i32);
    let single_score = _mm256_add_epi32(edges_score, _mm256_add_epi32(corner_score, center_score));
    return _mm256_mullo_epi32(single_score, avx_count);
}

Up until this point I'd been using the standard Rust HashMap. I'm still not sure I can do better as a generic hashmap, but given there are some very specific optimisations possible for this use case I did manage to improve slightly. My version with the best submit performance came from a fixed size IndexMap that would overwrite values when a bucket was full. This means it calculates more states than would be needed for the worst case inputs, but it seemed to be the trade-off that worked best for me, as making it bigger was slower and my attempts at a resizable version didn't help. This is probably the part I'm least happy with, but I couldn't find a way to get better performance.
I ended up with a design based on the swisstables used by the standard Rust HashMap, but with some changes:
- No requirement to clear between depths.
- No iterating, as the data is stored in an external vector and the hashmap just stores indexes.
- Fixed size:
  - 8192 buckets, each containing 32 indexes (u32). The highest 13 bits of the calculated hash choose a bucket. Each bucket also has a 32 byte control structure and a byte recording the last inserted index.
  - The control structure holds one byte for each value in the bucket. This gets set to the current depth xored with 8 other bits from the hash - this ensures we can never get a control match for the same state at different depths, so nothing needs clearing between depths.
  - When looking up a state we can then use AVX (_mm256_cmpeq_epi8 / _mm256_movemask_epi8) against the control structure to check which of the 32 positions might have a match for the current state at the current depth.
  - If a match to the control hash is found, the full state is then checked for equality in the external vector.
  - If the value isn't found, we insert one step on from the previous insert on that bucket (wrapping round to 0 when we hit 32). As mentioned above, this can overwrite useful data if more than 32 hashes share the same upper 13 bits at this depth.

This theoretically allows up to 262144 different states per depth. From my testing, the worst case I could find is actually around 150000 states (a ~60% load factor), so we should very rarely get a full bucket that overwrites useful information, as long as our hashes are evenly distributed (which unfortunately isn't the case, and I do end up calculating some states multiple times).
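The probe step described above can be sketched like this, with a scalar loop standing in for the AVX compare and illustrative field names:

```rust
// One bucket of the fixed-size index map. Field names are illustrative.
struct Bucket {
    control: [u8; 32],  // per-slot control byte: depth ^ (8 bits of hash)
    indexes: [u32; 32], // indexes into the external state vector
    last_insert: u8,    // slot after this one gets overwritten next
}

// Control byte for a state at this depth. Which 8 hash bits are taken
// doesn't matter much, as long as they aren't the 13 bucket-selection bits.
fn control_byte(hash: u64, depth: u8) -> u8 {
    ((hash >> 32) as u8) ^ depth
}

// Bitmask of slots whose control byte matches; with AVX2 this whole loop
// collapses into one _mm256_cmpeq_epi8 followed by _mm256_movemask_epi8.
fn candidate_slots(bucket: &Bucket, ctrl: u8) -> u32 {
    let mut mask = 0u32;
    for (i, &c) in bucket.control.iter().enumerate() {
        if c == ctrl {
            mask |= 1 << i;
        }
    }
    mask
}
```

Each set bit in the returned mask is a slot whose full state then gets checked for equality in the external vector.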
Interestingly, after the final validators were run, it looks like my fastest version (which had half this cache size) was chosen over my final submit. That version was taking 3-4x longer on the worst case tests, but as my final score was 38ms it looks like all the test cases were fairly simple, similar in difficulty to the original submit test cases.