// Sudoku solver
#include <intrin.h>
#include <assert.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
// Returns the zero-based position of the only set bit in `x` -- for a
// single-candidate square mask, that is the (zero-based) digit it holds.
// Precondition (asserted): exactly one bit of `x` is set.
int get_set_bit(uint16_t x) {
    assert(__popcnt16(x) == 1);
    unsigned long lowest_bit;
    _BitScanForward(&lowest_bit, x);
    return static_cast<int>(lowest_bit);
}
// Go through all 20 peers of the current square: every other square in the
// same row, the same column, or the same 3x3 block, each visited exactly once.
// `x` and `y` must name the current square in the enclosing scope. The body is
// run with `x` and `y` shadowed by each peer's coordinates in turn. After the
// loop has finished, the enclosing `x` and `y` have not been modified.
// The body is taken variadically so it may contain top-level commas
// (e.g. `int a, b;`) without being split into extra macro arguments, and it is
// braced identically in all three loops.
#define FOREACH_PEER_ELEM(...) { \
    int currentx = x; int currenty = y; \
    for (int x = 0; x < 9; ++x) { if (x != currentx) { __VA_ARGS__; } } \
    for (int y = 0; y < 9; ++y) { if (y != currenty) { __VA_ARGS__; } } \
    for (int x = (currentx/3)*3; x < (currentx/3 + 1)*3; ++x) { \
        for (int y = (currenty/3)*3; y < (currenty/3 + 1)*3; ++y) { \
            /* Squares sharing the row or column were already visited above. */ \
            if (x != currentx && y != currenty) { __VA_ARGS__; } \
        } \
    } \
}
// Run the body once for every set bit of `digit_mask_input`, lowest bit first,
// with `bit_index` (unsigned long) holding the zero-based position of that bit.
// `digit_mask_input` is evaluated exactly once, before the first iteration.
// The working copy has a deliberately unusual name so that a caller passing a
// variable called e.g. `digit_mask` does not end up with a self-initialized
// local (`uint16_t digit_mask = (digit_mask);` is undefined behavior).
// The body is variadic so it may contain top-level commas.
#define FOREACH_SET_BIT(digit_mask_input, ...) { \
    unsigned long bit_index; \
    uint16_t foreach_set_bit_remaining_ = (digit_mask_input); \
    while (_BitScanForward(&bit_index, foreach_set_bit_remaining_)) { \
        foreach_set_bit_remaining_ &= foreach_set_bit_remaining_ - 1; /* clear lowest set bit */ \
        { __VA_ARGS__; } \
    } \
}
// Returns true when exactly one bit of `x` is set, i.e. the mask names a
// single digit (or a single location, for a reverse-index mask).
// Callers are not expected to pass 0 (an empty mask is a contradiction they
// detect beforehand); the assert documents that. The explicit `x != 0` test
// keeps NDEBUG builds correct: `(0 & (0 - 1)) == 0` alone would report an
// empty mask as having one bit set.
bool is_exactly_one_bit_set(uint16_t x) {
    assert(x > 0);
    return x != 0 && (x & (x - 1)) == 0;
}
struct board { | |
// Contains a 9-bit mask per square, indicating possible digits in each square | |
uint16_t squares[9][9]; | |
// For each "unit" (rows, cols, blocks) and each digit (0..8), store a bit mask | |
// indicating where this digit is still possible. This could be reconstructed from | |
// the squares array above, but it's faster to maintain a redundant "reverse index". | |
uint16_t possible_digit_locations_blocks[3][3][9]; | |
uint16_t possible_digit_locations_rows[9][9]; | |
uint16_t possible_digit_locations_cols[9][9]; | |
board() { | |
std::fill_n((uint16_t*)squares, 81, (uint16_t)0x1ff); | |
std::fill_n((uint16_t*)possible_digit_locations_rows, 9 * 9, (uint16_t)0x1FF); | |
std::fill_n((uint16_t*)possible_digit_locations_cols, 9 * 9, (uint16_t)0x1FF); | |
std::fill_n((uint16_t*)possible_digit_locations_blocks, 9 * 9, (uint16_t)0x1FF); | |
} | |
uint8_t next_square_candidate_x = 0, next_square_candidate_y = 0; | |
bool eliminate(int x, int y, uint16_t digits_to_eliminate) { | |
digits_to_eliminate &= squares[x][y]; // Only eliminate digits that exists in square | |
if (digits_to_eliminate == 0) { | |
return true; // already eliminated. | |
} | |
squares[x][y] &= ~digits_to_eliminate; // clear digit | |
uint16_t remaining_digits = squares[x][y]; | |
if (remaining_digits == 0) { | |
return false; // contradiction found, no possible digits left. | |
} | |
int block_x = x / 3; | |
int block_y = y / 3; | |
int block_bit_index = (x%3) + 3*(y%3); | |
// Clear out the "reverse index" | |
FOREACH_SET_BIT(digits_to_eliminate, | |
assert(possible_digit_locations_cols[x][bit_index] & (1 << y)); | |
possible_digit_locations_cols[x][bit_index] &= ~(1 << y); | |
if (possible_digit_locations_cols[x][bit_index] == 0) { | |
return false; | |
} | |
assert(possible_digit_locations_rows[y][bit_index] & (1 << x)); | |
possible_digit_locations_rows[y][bit_index] &= ~(1 << x); | |
if (possible_digit_locations_rows[y][bit_index] == 0) { | |
return false; | |
} | |
assert(possible_digit_locations_blocks[block_x][block_y][bit_index] & (1 << block_bit_index)); | |
possible_digit_locations_blocks[block_x][block_y][bit_index] &= ~(1 << block_bit_index); | |
if (possible_digit_locations_blocks[block_x][block_y][bit_index] == 0) { | |
return false; | |
} | |
); | |
// If we've eliminated all but one digit, then we should eliminate that digit from all the peers. | |
if (is_exactly_one_bit_set(remaining_digits)) { | |
int remaining_digit_index = get_set_bit(remaining_digits); | |
// Get all the positions where this digit is set in the current row, column and block, | |
// and eliminate them from those squares. | |
// Start with the current row. | |
uint16_t remaining_pos_mask = possible_digit_locations_rows[y][remaining_digit_index]; | |
remaining_pos_mask &= ~(1 << x); // Don't eliminate from the current square. | |
FOREACH_SET_BIT(remaining_pos_mask, | |
if (!eliminate(bit_index,y, remaining_digits)) { | |
return false; | |
} | |
) | |
// Next eliminate it from the column. | |
remaining_pos_mask = possible_digit_locations_cols[x][remaining_digit_index]; | |
remaining_pos_mask &= ~(1 << y); // Don't eliminate from the current square | |
FOREACH_SET_BIT(remaining_pos_mask, | |
if (!eliminate(x, bit_index, remaining_digits)) { | |
return false; | |
} | |
) | |
// Next eliminate it from the current block | |
remaining_pos_mask = possible_digit_locations_blocks[block_x][block_y][remaining_digit_index]; | |
remaining_pos_mask &= ~(1 << block_bit_index); // Don't eliminate from the current square | |
FOREACH_SET_BIT(remaining_pos_mask, | |
int x_offset = bit_index % 3; | |
int y_offset = bit_index / 3; | |
if (!eliminate(block_x*3 + x_offset, block_y *3 + y_offset, remaining_digits)) { | |
return false; | |
} | |
) | |
} | |
// For each digit we just eliminated, find if it now only has one remaining posible location | |
// in either the row, column or block. If so, set the digit | |
FOREACH_SET_BIT(digits_to_eliminate, | |
// Check the row | |
if (is_exactly_one_bit_set(possible_digit_locations_rows[y][bit_index])) { | |
int digit_x = get_set_bit(possible_digit_locations_rows[y][bit_index]); | |
if (!set_digit(digit_x, y, bit_index)) { | |
return false; | |
} | |
} | |
// Column | |
if (is_exactly_one_bit_set(possible_digit_locations_cols[x][bit_index])) { | |
int digit_y = get_set_bit(possible_digit_locations_cols[x][bit_index]); | |
if (!set_digit(x, digit_y, bit_index)) { | |
return false; | |
} | |
} | |
// Block | |
if (is_exactly_one_bit_set(possible_digit_locations_blocks[block_x][block_y][bit_index])) { | |
int bit_with_digit = get_set_bit(possible_digit_locations_blocks[block_x][block_y][bit_index]); | |
int x_offset = bit_with_digit % 3; | |
int y_offset = bit_with_digit / 3; | |
if (!set_digit(block_x*3 + x_offset, block_y *3 + y_offset, bit_index)) { | |
return false; | |
} | |
} | |
) | |
// While we're here, check if the square has a popcnt of 2, this means it's has the minimum | |
// number of possibilities without being "done", making it an excellent candidate for the next | |
// trial assignment. | |
if (__popcnt16(squares[x][y]) == 2) { | |
next_square_candidate_x = x; | |
next_square_candidate_y = y; | |
} | |
return true; | |
} | |
bool set_digit(int x, int y, int d) { | |
int digit_mask = 1 << d; | |
assert(squares[x][y] & digit_mask); | |
if (!eliminate(x, y, ~digit_mask)) { | |
return false; | |
} | |
assert(squares[x][y] == digit_mask); | |
return true; | |
} | |
bool find_next_square_to_assign(int& x, int& y) const { | |
// Check most recent squares to have been found to have 2 digits, if it still | |
// does, just use that one since it would be a minimum. | |
if (__popcnt16(squares[next_square_candidate_x][next_square_candidate_y]) == 2) { | |
x = next_square_candidate_x; | |
y = next_square_candidate_y; | |
return true; | |
} | |
int min_count = 10; | |
for (int i = 0; i < 9; ++i) { | |
for (int j = 0; j < 9; ++j) { | |
int c = __popcnt16(squares[i][j]); | |
// Again, early out on popcnt == 2, since that would be a minimum | |
if (c == 2) { | |
x = i; | |
y = j; | |
return true; | |
} | |
// Else find smallest | |
if (c > 1 && c < min_count) { | |
x = i; | |
y = j; | |
min_count = c; | |
} | |
} | |
} | |
return min_count != 10; // Did we find a square to process? | |
} | |
void print() { | |
for (int y = 0; y < 9; ++y) { | |
for (int x = 0; x < 9; ++x) { | |
printf("["); | |
for (int i = 0; i < 9; ++i) { | |
if (squares[x][y] & (1 << i)) { | |
printf("%d", i + 1); | |
} | |
else { | |
printf(" "); | |
} | |
} | |
printf("] "); | |
} | |
printf("\n"); | |
} | |
} | |
bool is_solved() const { | |
for (int y = 0; y < 9; ++y) { | |
for (int x = 0; x < 9; ++x) { | |
if (!is_exactly_one_bit_set(squares[x][y])) | |
return false; | |
// Make sure this digit isn't in any of the peers | |
int d = get_set_bit(squares[x][y]); | |
FOREACH_PEER_ELEM({ | |
if (get_set_bit(squares[x][y]) == d) | |
return false; | |
}) | |
} | |
} | |
return true; | |
} | |
}; | |
bool search(const board& current_board, board& final_board) { | |
// Find the next square to do a trial assignment on | |
int x, y; | |
if (current_board.find_next_square_to_assign(x, y)) { | |
uint16_t digits = current_board.squares[x][y]; | |
// Then assign each possible digit to this square | |
FOREACH_SET_BIT(digits, { | |
board board_copy = current_board; | |
// If we can successfully set this digit, do search from here | |
if (board_copy.set_digit(x, y, bit_index)) { | |
if (search(board_copy, final_board)) { | |
return true; | |
} | |
} | |
}) | |
} | |
else { | |
// No more squares to assign, so we're done! | |
assert(current_board.is_solved()); | |
final_board = current_board; | |
return true; | |
} | |
return false; | |
} | |
int main() | |
{ | |
// Load sudoku grids | |
FILE* sudoku_file; | |
if (fopen_s(&sudoku_file, "sudoku17.txt", "r")) { | |
printf("Failed to open sudoku file"); | |
return -1; | |
} | |
std::vector<std::unique_ptr<const char[]>> lines; | |
while (!feof(sudoku_file)) { | |
char line_buf[82]; | |
fgets(line_buf, _countof(line_buf), sudoku_file); | |
size_t n = strlen(line_buf); | |
if (n < 81) | |
continue; | |
lines.emplace_back(_strdup(line_buf)); | |
} | |
fclose(sudoku_file); | |
for (int runs = 0; runs < 1; ++runs) { // Increase the loop count here for performance profiling | |
std::chrono::high_resolution_clock clock; | |
// Solve all the grids | |
auto start = clock.now(); | |
std::vector<board> boards; | |
for (size_t board_ix = 0; board_ix < lines.size(); ++board_ix) { | |
board b; | |
for (int row = 0; row < 9; ++row) { | |
for (int col = 0; col < 9; ++col) { | |
char c = lines[board_ix][row * 9 + col]; | |
if (c > '0' && c <= '9') { | |
bool res = b.set_digit(col, row, c - '1'); // we use zero-based digits, hence the '1' | |
assert(res); | |
} | |
} | |
} | |
boards.push_back(b); | |
} | |
for (int i = 0; i < boards.size(); ++i) { | |
board final_board; | |
if (!search(boards[i], final_board)) { | |
printf("Failed to solve grid %d\n", i); | |
boards[i].print(); | |
} | |
} | |
auto end = clock.now(); | |
uint64_t nanosecs = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count(); | |
printf("Took %.1f milliseconds for %d puzzles, or %.2f microseconds average\n", (float)nanosecs / 1000000, (int)boards.size(), (float)nanosecs / (1000.0f*boards.size())); | |
} | |
return 0; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment