Instantly share code, notes, and snippets.

@ssylvan /main.cpp
Last active Aug 17, 2018

Embed
What would you like to do?
Sudoku solver
#include <intrin.h>

#include <assert.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <algorithm>
#include <chrono>
#include <memory>
#include <string>
#include <vector>
// Returns the zero-based index of the single set bit in x.
// Precondition (asserted in debug builds): exactly one bit of x is set.
// If the precondition is violated in a release build, the result is still well
// defined: the index of the lowest set bit, or 0 when x == 0. _BitScanForward leaves
// its output undefined when no bit is set, so its return value is checked instead of
// reading `index` unconditionally.
int get_set_bit(uint16_t x) {
    assert(__popcnt16(x) == 1);
    unsigned long index = 0;
    if (!_BitScanForward(&index, x)) {
        return 0; // x == 0: no set bit to report
    }
    return static_cast<int>(index);
}
// Go through all peers of the current square: every other square in the same row,
// the same column, or the same 3x3 block, each visited exactly once (8 + 8 + 4 = 20).
// x and y are implied to contain the current square (in this file x is the column and
// y the row), and the body will be run with x and y shadowed by each peer's
// coordinates in turn. After the loop has finished, x and y will not have been modified.
// Usage notes:
//  * The first loop varies x with y fixed, and the second varies y with x fixed. The
//    block loop skips squares that share the current x or y, since those were already
//    visited by the first two loops.
//  * Pass the body as a brace-enclosed block; it is pasted as the statement of an `if`.
//  * The macro declares locals `currentx` / `currenty`; the body must not redeclare them.
//  * `return` inside the body returns from the enclosing function.
#define FOREACH_PEER_ELEM(body) {\
int currentx = x; int currenty = y; \
for (int x = 0; x < 9; ++x) { if (x != currentx) body; } \
for (int y = 0; y < 9; ++y) { if (y != currenty) body; } \
for(int x = (currentx/3)*3; x < (currentx/3 + 1)*3; ++x) {\
for(int y = (currenty/3)*3; y < (currenty/3 + 1)*3; ++y) {\
if (x != currentx && y != currenty) body\
} \
} \
}
// Run `body` once for every set bit of `digit_mask_input`, lowest bit first.
// Inside the body, `bit_index` (unsigned long) holds the zero-based index of the
// current bit. The input expression is evaluated once and copied into a local
// `digit_mask`, so the caller's variable is never modified. The current bit is
// cleared from that copy before the body runs.
// Notes:
//  * The body is pasted verbatim as a macro argument. A top-level comma in it (one not
//    inside parentheses) would split the argument, so wrap such code in parentheses.
//  * Nesting two FOREACH_SET_BIT loops shadows the outer `bit_index` / `digit_mask`.
//  * `return` inside the body returns from the enclosing function.
#define FOREACH_SET_BIT(digit_mask_input, body) { \
unsigned long bit_index; uint16_t digit_mask = (digit_mask_input); \
while (_BitScanForward(&bit_index, digit_mask)) { \
digit_mask &= digit_mask-1; body\
}\
}
// Returns true if exactly one bit of x is set (x is a power of two).
// x == 0 has no bits set and yields false. The explicit check matters: after integer
// promotion, 0 & (0 - 1) == 0, so the bit trick alone would report 0 as a single bit.
bool is_exactly_one_bit_set(uint16_t x) {
    return x != 0 && (x & (x - 1)) == 0;
}
// A Sudoku grid together with the candidate bookkeeping used by the solver.
//
// Coordinates: squares[x][y] with x = column (0..8) and y = row (0..8).
// Digits are zero-based: bit d of a mask stands for the digit d + 1.
struct board {
    // Contains a 9-bit mask per square, indicating possible digits in each square.
    uint16_t squares[9][9];
    // For each "unit" (rows, cols, blocks) and each digit (0..8), store a bit mask
    // indicating where this digit is still possible. This could be reconstructed from
    // the squares array above, but it's faster to maintain a redundant "reverse index":
    //   possible_digit_locations_rows[y][d]        - bit x set if d is possible at (x, y)
    //   possible_digit_locations_cols[x][d]        - bit y set if d is possible at (x, y)
    //   possible_digit_locations_blocks[bx][by][d] - bit (x%3) + 3*(y%3) set if d is
    //                                                possible at (x, y) in block (bx, by)
    uint16_t possible_digit_locations_blocks[3][3][9];
    uint16_t possible_digit_locations_rows[9][9];
    uint16_t possible_digit_locations_cols[9][9];

    // Starts with every digit possible in every square, and therefore every location
    // possible for every digit in every unit.
    board() {
        std::fill_n((uint16_t*)squares, 81, (uint16_t)0x1ff);
        std::fill_n((uint16_t*)possible_digit_locations_rows, 9 * 9, (uint16_t)0x1FF);
        std::fill_n((uint16_t*)possible_digit_locations_cols, 9 * 9, (uint16_t)0x1FF);
        std::fill_n((uint16_t*)possible_digit_locations_blocks, 9 * 9, (uint16_t)0x1FF);
    }

    // Square most recently left with exactly two candidates by eliminate(). It is a
    // cheap first guess for find_next_square_to_assign().
    uint8_t next_square_candidate_x = 0, next_square_candidate_y = 0;

    // Removes the digits in `digits_to_eliminate` from the candidates of square (x, y)
    // and propagates the consequences:
    //  * keeps the row/column/block reverse index in sync;
    //  * if a single candidate remains, removes it from every peer that still allows it;
    //  * for each removed digit, if a row/column/block now has only one place left
    //    for it, assigns the digit there.
    // Returns false as soon as a contradiction is found (a square or a unit runs out of
    // options). The board may then be only partially updated and should be discarded;
    // search() always works on a copy.
    bool eliminate(int x, int y, uint16_t digits_to_eliminate) {
        digits_to_eliminate &= squares[x][y]; // Only eliminate digits that exist in square
        if (digits_to_eliminate == 0) {
            return true; // already eliminated.
        }
        squares[x][y] &= ~digits_to_eliminate; // clear digit
        uint16_t remaining_digits = squares[x][y];
        if (remaining_digits == 0) {
            return false; // contradiction found, no possible digits left.
        }
        int block_x = x / 3;
        int block_y = y / 3;
        int block_bit_index = (x%3) + 3*(y%3);
        // Clear out the "reverse index". A unit with no place left for a digit is a
        // contradiction.
        FOREACH_SET_BIT(digits_to_eliminate,
            assert(possible_digit_locations_cols[x][bit_index] & (1 << y));
            possible_digit_locations_cols[x][bit_index] &= ~(1 << y);
            if (possible_digit_locations_cols[x][bit_index] == 0) {
                return false;
            }
            assert(possible_digit_locations_rows[y][bit_index] & (1 << x));
            possible_digit_locations_rows[y][bit_index] &= ~(1 << x);
            if (possible_digit_locations_rows[y][bit_index] == 0) {
                return false;
            }
            assert(possible_digit_locations_blocks[block_x][block_y][bit_index] & (1 << block_bit_index));
            possible_digit_locations_blocks[block_x][block_y][bit_index] &= ~(1 << block_bit_index);
            if (possible_digit_locations_blocks[block_x][block_y][bit_index] == 0) {
                return false;
            }
        );
        // If we've eliminated all but one digit, then we should eliminate that digit from all the peers.
        if (is_exactly_one_bit_set(remaining_digits)) {
            int remaining_digit_index = get_set_bit(remaining_digits);
            // Get all the positions where this digit is set in the current row, column and block,
            // and eliminate them from those squares.
            // Start with the current row.
            uint16_t remaining_pos_mask = possible_digit_locations_rows[y][remaining_digit_index];
            remaining_pos_mask &= ~(1 << x); // Don't eliminate from the current square.
            FOREACH_SET_BIT(remaining_pos_mask,
                if (!eliminate(bit_index,y, remaining_digits)) {
                    return false;
                }
            )
            // Next eliminate it from the column.
            remaining_pos_mask = possible_digit_locations_cols[x][remaining_digit_index];
            remaining_pos_mask &= ~(1 << y); // Don't eliminate from the current square
            FOREACH_SET_BIT(remaining_pos_mask,
                if (!eliminate(x, bit_index, remaining_digits)) {
                    return false;
                }
            )
            // Next eliminate it from the current block
            remaining_pos_mask = possible_digit_locations_blocks[block_x][block_y][remaining_digit_index];
            remaining_pos_mask &= ~(1 << block_bit_index); // Don't eliminate from the current square
            FOREACH_SET_BIT(remaining_pos_mask,
                int x_offset = bit_index % 3;
                int y_offset = bit_index / 3;
                if (!eliminate(block_x*3 + x_offset, block_y *3 + y_offset, remaining_digits)) {
                    return false;
                }
            )
        }
        // For each digit we just eliminated, find if it now only has one remaining possible location
        // in either the row, column or block. If so, set the digit
        FOREACH_SET_BIT(digits_to_eliminate,
            // Check the row
            if (is_exactly_one_bit_set(possible_digit_locations_rows[y][bit_index])) {
                int digit_x = get_set_bit(possible_digit_locations_rows[y][bit_index]);
                if (!set_digit(digit_x, y, bit_index)) {
                    return false;
                }
            }
            // Column
            if (is_exactly_one_bit_set(possible_digit_locations_cols[x][bit_index])) {
                int digit_y = get_set_bit(possible_digit_locations_cols[x][bit_index]);
                if (!set_digit(x, digit_y, bit_index)) {
                    return false;
                }
            }
            // Block
            if (is_exactly_one_bit_set(possible_digit_locations_blocks[block_x][block_y][bit_index])) {
                int bit_with_digit = get_set_bit(possible_digit_locations_blocks[block_x][block_y][bit_index]);
                int x_offset = bit_with_digit % 3;
                int y_offset = bit_with_digit / 3;
                if (!set_digit(block_x*3 + x_offset, block_y *3 + y_offset, bit_index)) {
                    return false;
                }
            }
        )
        // While we're here, check if the square has a popcnt of 2, this means it has the minimum
        // number of possibilities without being "done", making it an excellent candidate for the next
        // trial assignment.
        if (__popcnt16(squares[x][y]) == 2) {
            next_square_candidate_x = x;
            next_square_candidate_y = y;
        }
        return true;
    }

    // Assigns zero-based digit d to square (x, y) by eliminating every other candidate,
    // with full propagation. d must still be a candidate (asserted).
    // Returns false if the assignment leads to a contradiction.
    bool set_digit(int x, int y, int d) {
        int digit_mask = 1 << d;
        assert(squares[x][y] & digit_mask);
        if (!eliminate(x, y, ~digit_mask)) {
            return false;
        }
        assert(squares[x][y] == digit_mask);
        return true;
    }

    // Chooses the next square for a trial assignment: an unsolved square (more than one
    // candidate) with as few candidates as possible. A square with exactly two
    // candidates is taken immediately. Writes the coordinates to x/y and returns true,
    // or returns false, leaving x/y untouched, when no square has more than one candidate.
    bool find_next_square_to_assign(int& x, int& y) const {
        // Check most recent squares to have been found to have 2 digits, if it still
        // does, just use that one since it would be a minimum.
        if (__popcnt16(squares[next_square_candidate_x][next_square_candidate_y]) == 2) {
            x = next_square_candidate_x;
            y = next_square_candidate_y;
            return true;
        }
        int min_count = 10;
        for (int i = 0; i < 9; ++i) {
            for (int j = 0; j < 9; ++j) {
                int c = __popcnt16(squares[i][j]);
                // Again, early out on popcnt == 2, since that would be a minimum
                if (c == 2) {
                    x = i;
                    y = j;
                    return true;
                }
                // Else find smallest
                if (c > 1 && c < min_count) {
                    x = i;
                    y = j;
                    min_count = c;
                }
            }
        }
        return min_count != 10; // Did we find a square to process?
    }

    // Prints the board row by row. Each square shows its remaining candidate digits
    // (1-9) inside brackets, with a space in place of each eliminated digit.
    void print() const {
        for (int y = 0; y < 9; ++y) {
            for (int x = 0; x < 9; ++x) {
                printf("[");
                for (int i = 0; i < 9; ++i) {
                    if (squares[x][y] & (1 << i)) {
                        printf("%d", i + 1);
                    }
                    else {
                        printf(" ");
                    }
                }
                printf("] ");
            }
            printf("\n");
        }
    }

    // Returns true if every square holds exactly one digit and no square shares its
    // digit with any of its peers (same row, column or 3x3 block).
    // Safe to call on any board, including unsolved ones.
    bool is_solved() const {
        for (int y = 0; y < 9; ++y) {
            for (int x = 0; x < 9; ++x) {
                const uint16_t square_mask = squares[x][y];
                // Reject an emptied square here, so this check never depends on
                // is_exactly_one_bit_set's handling of 0.
                if (square_mask == 0 || !is_exactly_one_bit_set(square_mask))
                    return false;
                // Make sure this digit isn't in any of the peers. Testing the bit directly,
                // rather than calling get_set_bit on the peer, stays valid when a peer not
                // yet visited by the outer loops still has zero or several candidates.
                FOREACH_PEER_ELEM({
                    if (squares[x][y] & square_mask)
                        return false;
                })
            }
        }
        return true;
    }
};
bool search(const board& current_board, board& final_board) {
// Find the next square to do a trial assignment on
int x, y;
if (current_board.find_next_square_to_assign(x, y)) {
uint16_t digits = current_board.squares[x][y];
// Then assign each possible digit to this square
FOREACH_SET_BIT(digits, {
board board_copy = current_board;
// If we can successfully set this digit, do search from here
if (board_copy.set_digit(x, y, bit_index)) {
if (search(board_copy, final_board)) {
return true;
}
}
})
}
else {
// No more squares to assign, so we're done!
assert(current_board.is_solved());
final_board = current_board;
return true;
}
return false;
}
int main()
{
// Load sudoku grids
FILE* sudoku_file;
if (fopen_s(&sudoku_file, "sudoku17.txt", "r")) {
printf("Failed to open sudoku file");
return -1;
}
std::vector<std::unique_ptr<const char[]>> lines;
while (!feof(sudoku_file)) {
char line_buf[82];
fgets(line_buf, _countof(line_buf), sudoku_file);
size_t n = strlen(line_buf);
if (n < 81)
continue;
lines.emplace_back(_strdup(line_buf));
}
fclose(sudoku_file);
for (int runs = 0; runs < 1; ++runs) { // Increase the loop count here for performance profiling
std::chrono::high_resolution_clock clock;
// Solve all the grids
auto start = clock.now();
std::vector<board> boards;
for (size_t board_ix = 0; board_ix < lines.size(); ++board_ix) {
board b;
for (int row = 0; row < 9; ++row) {
for (int col = 0; col < 9; ++col) {
char c = lines[board_ix][row * 9 + col];
if (c > '0' && c <= '9') {
bool res = b.set_digit(col, row, c - '1'); // we use zero-based digits, hence the '1'
assert(res);
}
}
}
boards.push_back(b);
}
for (int i = 0; i < boards.size(); ++i) {
board final_board;
if (!search(boards[i], final_board)) {
printf("Failed to solve grid %d\n", i);
boards[i].print();
}
}
auto end = clock.now();
uint64_t nanosecs = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
printf("Took %.1f milliseconds for %d puzzles, or %.2f microseconds average\n", (float)nanosecs / 1000000, (int)boards.size(), (float)nanosecs / (1000.0f*boards.size()));
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment