sudoku solver
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Competitive programming has spoiled me: I'm too lazy to write out the individual headers.
// At least I'm not lazy enough to use `using namespace std` :P
#include <bits/stdc++.h> | |
// Sudoku board state.
//   data  - cell values in row-major order (index y*9 + x): 0 = empty, 1-9 = digit.
//           The board:: helpers only ever touch the first 81 entries.
//   gs    - per-(digit, 3x3 box) counters, index digit*9 + box (boxes numbered 0-8 row-major).
//   rs    - per-(digit, row) counters, index digit*9 + y.
//   cs    - per-(digit, column) counters, index digit*9 + x.
//   ns    - how many times each digit is currently placed on the board.
//   n_set - number of filled cells.
// A default-constructed board is completely zeroed, i.e. empty.
struct board_t {
    uint8_t data[324]{};
    uint8_t gs[90]{};
    uint8_t rs[90]{};
    uint8_t cs[90]{};
    uint8_t ns[10]{};
    uint8_t n_set{};
    board_t() = default;
};
namespace board { | |
// x, y measured from top left corner, zero indexed | |
// returns 0 if unfilled, else an integer from 1-9 | |
// which is the number at that point | |
inline uint8_t at(const board_t& b, uint8_t x, uint8_t y) { | |
// TODO optimize? | |
return b.data[y*9+x]; | |
} | |
inline void set(board_t& b, uint8_t x, uint8_t y, uint8_t v) { | |
b.data[y*9+x] = v; | |
b.gs[v*9 + ((y/3)*3 + x/3)]++; | |
b.rs[v*9 + y]++; | |
b.cs[v*9 + x]++; | |
b.ns[v]++; | |
b.n_set++; | |
} | |
inline void unset(board_t& b, uint8_t x, uint8_t y) { | |
int decr = (int)((bool)b.data[y*9+x]); // expensive? | |
b.gs[b.data[y*9+x]*9 + ((y/3)*3 + x/3)] -= decr; | |
b.rs[b.data[y*9+x]*9 + y] -= decr; | |
b.cs[b.data[y*9+x]*9 + x] -= decr; | |
b.ns[b.data[y*9+x]] -= decr; | |
b.n_set -= decr; | |
b.data[y*9+x] = 0; | |
} | |
inline bool check(board_t& b, int x, int y) { | |
return (b.data[y*9+x] == 0) || | |
(b.gs[b.data[y*9+x]*9 + ((y/3)*3 + x/3)] <= 1 && | |
b.rs[b.data[y*9+x]*9 + y] <= 1 && | |
b.cs[b.data[y*9+x]*9 + x] <= 1); | |
} | |
inline board_t copy(board_t& b) { | |
board_t cpy = b; | |
return cpy; | |
} | |
std::string to_string(const board_t& b) { | |
std::ostringstream s; | |
s << "+---+---+---+\n"; | |
for (int i=0; i<3; i++) { | |
for (int j=0; j<3; j++) { | |
s << "|"; | |
for (int k=0; k<3; k++) { | |
for (int l=0; l<3; l++) { | |
s << (int)at(b,k*3+l,i*3+j); | |
} | |
s << "|"; | |
} | |
s << "\n"; | |
} | |
s << "+---+---+---+\n"; | |
} | |
return s.str(); | |
} | |
board_t to_board(const std::string& s) { | |
board_t b; | |
int i=0,j=0; | |
for (char c : s) { | |
if (c != '.') set(b,j,i,c-48); | |
j++; | |
if (j == 9) { | |
j = 0; | |
i++; | |
} | |
if (i == 9) break; | |
} | |
return b; | |
} | |
} | |
// cool. Now to make the solver. | |
// Compares (index, count) pairs by count, descending: returns true when
// `first` has the strictly larger count. Used as a std::priority_queue
// comparator, this makes top() the pair with the smallest count (a min-heap
// on .second).
bool pair_comp(const std::pair<int,int>& first, const std::pair<int,int>& second) {
    const int lhs_count = first.second;
    const int rhs_count = second.second;
    return rhs_count < lhs_count;
}
int get_mcc(board_t& b) { | |
std::priority_queue<std::pair<int,int>, std::vector<std::pair<int,int>>, | |
decltype(&pair_comp)> Q(pair_comp); | |
for (uint8_t i=0; i<9; i++) { | |
for (uint8_t j=0; j<9; j++) { | |
uint8_t n = board::at(b,j,i); | |
if (n != 0) continue; | |
int cnt = 0; | |
for (uint8_t k=1; k<=9; k++) { | |
board::set(b,j,i,k); | |
if (board::check(b,j,i)) cnt++; | |
board::unset(b,j,i); | |
} | |
Q.push({(int)(i*9+j),cnt}); | |
} | |
} | |
return Q.top().first; | |
} | |
// Backtracking solver. Fills `b` in place and returns true once every cell is set.
// On failure it returns false, and every placement made during the search has
// been undone.
// Each step branches on the most constrained empty cell (get_mcc). Candidate
// digits are tried in order of how often they already appear on the board
// (most-placed first).
bool solve(board_t& b) {
    if (b.n_set == 81) return true;
    const int mcc = get_mcc(b);
    // Defensive: counters and cells are out of sync, so there is nothing to branch on.
    if (mcc < 0) return false;
    const int i = mcc / 9;  // row
    const int j = mcc % 9;  // column
    std::vector<int> moves;
    moves.reserve(9);
    for (int k = 1; k <= 9; k++) {
        board::set(b, j, i, k);
        if (board::check(b, j, i)) moves.push_back(k);
        board::unset(b, j, i);
    }
    // Capture b by reference. The old by-value capture copied the whole board
    // (~600 bytes) on every call, and b is not modified while sorting.
    std::sort(moves.begin(), moves.end(), [&b](int m1, int m2) {
        return b.ns[m1] > b.ns[m2];
    });
    for (int m : moves) {
        board::set(b, j, i, m);
        if (solve(b)) return true;
        board::unset(b, j, i);
    }
    return false;
}
void benchmark(std::string dataset) { | |
std::ifstream infile(dataset); | |
std::string line; | |
long long usec_count = 0; | |
long long n_puzzles = 0; | |
while (std::getline(infile, line) && n_puzzles < 100) { | |
if (line[0] == '#') continue; | |
// std::cout << line << std::endl; | |
board_t b = board::to_board(line); | |
auto begin = std::chrono::high_resolution_clock::now(); | |
bool solved = solve(b); | |
auto end = std::chrono::high_resolution_clock::now(); | |
auto time = std::chrono::duration_cast<std::chrono::microseconds>(end - begin); | |
if (solved) { | |
n_puzzles++; | |
} | |
else { | |
std::cout << "Could not solve Puzzle #" << (n_puzzles+1) << std::endl; | |
} | |
usec_count += time.count(); | |
if (solved%100 == 0) std::cout << "."; | |
if (solved%1000 == 0) std::cout << std::endl; | |
} | |
std::cout << "Num solved:" << n_puzzles << std::endl; | |
std::cout << "puzzles/sec:" << (double)(n_puzzles*1000000)/usec_count << std::endl; | |
std::cout << "usec/puzzle:" << (double)(usec_count)/n_puzzles << std::endl; | |
} | |
int main() { | |
// std::string s; | |
// std::cin >> s; | |
// board_t b = board::to_board(s); | |
// std::cout << board::to_string(b) << std::endl; | |
// //get_mcc(b); | |
// bool solved = solve(b); | |
// if (!solved) std::cout << "Could not solve" << std::endl; | |
// else std::cout << board::to_string(b); | |
// fill this with a path to run the benchmark! | |
// note that it's quite slow :( | |
// benchmark(); | |
return 0; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment