Skip to content

Instantly share code, notes, and snippets.

@keveman
Created August 14, 2012 01:55
Show Gist options
  • Save keveman/3345676 to your computer and use it in GitHub Desktop.
Save keveman/3345676 to your computer and use it in GitHub Desktop.
Sudoku solver in C++11
#include <cassert>
#include <algorithm>
#include <chrono>
#include <future>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
// maybe<T>: a shared, nullable handle to a T, used by the solver to return
// either a value (valid) or "contradiction / no result" (invalid).
// A valid maybe holds a shared_ptr; copies share the same underlying T, so
// mutations through one handle are visible through every copy.
// `created` / `destroyed` count valid handles constructed and destroyed, so
// (created - destroyed) is the number of valid handles currently alive.
template<typename T>
struct maybe {
  bool valid;
  static unsigned created, destroyed;
  // Plain member (not a union): it is null while !valid, which keeps copying,
  // assigning and destroying an invalid maybe well-defined.
  std::shared_ptr<T> value;

  // Invalid (empty) maybe.
  maybe() : valid(false) {}
  // Valid maybe that takes ownership of a raw pointer.
  maybe(T *_v) : valid(true), value(_v) { ++created; }
  // Valid maybe sharing ownership of an existing object.
  maybe(const std::shared_ptr<T>& _v) : valid(true), value(_v) { ++created; }
  maybe(const maybe& other) : valid(other.valid), value(other.value) {
    if (valid) ++created;
  }
  maybe& operator=(const maybe& other) {
    if (this != &other) {
      if (valid) ++destroyed;  // the handle we held is released
      valid = other.valid;
      value = other.value;
      if (valid) ++created;    // and a new valid handle is formed
    }
    return *this;
  }
  // Access the held value. Throws std::logic_error when invalid instead of
  // returning garbage (the old assert was compiled out under NDEBUG).
  operator T& () {
    if (!valid) throw std::logic_error("maybe: access to invalid value");
    return *value;
  }
  ~maybe() {
    if (valid) ++destroyed;
  }
  bool is_valid() const {
    return valid;
  }
};
template<typename T> unsigned maybe<T>::created = 0;
template<typename T> unsigned maybe<T>::destroyed = 0;
// Cartesian "sum" of two sequences: returns a[i] + b[j] for every pair, in
// row-major order (all of b combined with a[0], then with a[1], ...).
// For string vectors this concatenates, e.g.
//   cross({"A","B"}, {"1","2"}) == {"A1","A2","B1","B2"}.
template<typename E1, typename E2>
auto cross(const std::vector<E1>& a, const std::vector<E2>& b) ->
    std::vector<decltype(a[0]+b[0])> {
  std::vector<decltype(a[0]+b[0])> output;
  output.reserve(a.size() * b.size());  // exact final size, one allocation
  for (const auto& x : a) {             // by reference: no per-element copy
    for (const auto& y : b) {
      output.push_back(x + y);
    }
  }
  return output;
}
// User-defined literal: "ABC"_vec splits a string literal into a vector of
// one-character strings, {"A", "B", "C"}. N is the literal's length as
// passed by the compiler, so every character in [str, str+N) is kept.
std::vector<std::string> operator"" _vec(const char* str, std::size_t N) {
  std::vector<std::string> pieces;
  pieces.reserve(N);
  for (std::size_t i = 0; i < N; ++i)
    pieces.emplace_back(1, str[i]);
  return pieces;
}
// Board vocabulary: digit labels, row letters, column labels, and the 81
// square names ("A1".."I9") formed by pairing every row with every column.
vector<string> digits{"1", "2", "3", "4", "5", "6", "7", "8", "9"};
vector<string> rows{"A", "B", "C", "D", "E", "F", "G", "H", "I"};
vector<string> cols(digits);
vector<string> squares(cross(rows, cols));
// Populated by main(): every unit (columns, rows, 3x3 boxes), the units
// each square belongs to, and each square's peers (squares sharing a unit).
vector<vector<string>> unitlist;
map<string, vector<vector<string>>> units;
map<string, set<string>> peers;
// Candidate digits still possible for each square, keyed by square name.
using value_t = map<string, string>;
maybe<value_t> eliminate(maybe<value_t> _values, string s, string d);
// Assign digit d to square s by eliminating every other candidate from s;
// eliminate() propagates the consequences to peers and units.
// Returns the same maybe on success, or an invalid maybe on contradiction.
// A digit that is no longer a candidate for s (e.g. a clue already ruled
// out by an equal clue in the same unit) is also a contradiction; it used
// to fail an assert, or throw std::out_of_range under NDEBUG.
maybe<value_t> assign(maybe<value_t> _values, string s, string d) {
  value_t& values = _values;
  std::string other_values = values[s];
  const std::size_t pos = other_values.find(d);
  if (pos == std::string::npos)
    return maybe<value_t>();
  other_values.erase(pos, 1);
  // Iterate a copy: eliminate() mutates values[s] as we go.
  for (char d2 : other_values) {
    if (!eliminate(_values, s, std::string(1, d2)).is_valid())
      return maybe<value_t>();
  }
  return _values;
}
// Remove candidate digit d from square s and propagate the consequences:
//  1. if s has no candidates left -> contradiction;
//  2. if s has exactly one candidate left, eliminate it from all of s's peers;
//  3. for each unit of s: if no square can still hold d -> contradiction;
//     if exactly one square can, assign d there.
// Returns the same maybe on success (the map is modified in place through
// the shared handle) or an invalid maybe on contradiction. Removing a digit
// that is already absent is a no-op.
maybe<value_t> eliminate(maybe<value_t> _values, string s, string d) {
  value_t& values = _values;
  std::string& cands = values[s];
  const std::size_t at = cands.find(d);
  if (at == std::string::npos)
    return _values;  // already eliminated
  cands.erase(at, 1);
  if (cands.empty())
    return maybe<value_t>();
  if (cands.size() == 1) {
    const std::string last = cands;  // copy: recursion may change values[s]
    for (const auto& peer : peers[s]) {
      if (!eliminate(_values, peer, last).is_valid())
        return maybe<value_t>();
    }
  }
  for (const auto& unit : units[s]) {
    std::size_t count = 0;
    std::string first;
    for (const auto& sq : unit) {
      if (values[sq].find(d) != std::string::npos) {
        if (count == 0)
          first = sq;
        ++count;
      }
    }
    if (count == 0)
      return maybe<value_t>();
    if (count == 1 && !assign(_values, first, d).is_valid())
      return maybe<value_t>();
  }
  return _values;
}
void display(maybe<value_t> values);
// Parse a puzzle string into a map from square name to its given character:
// '1'-'9' for a clue, '0' or '.' for an empty square. Every other character
// (spaces, newlines, separators) is ignored. Cells are matched to squares in
// the order of the global `squares` list (A1..A9, B1.., ..., I9).
// Exactly 81 cells are required: debug builds fail the assert; release
// builds return an invalid maybe instead of reading past the end of `chars`.
maybe<value_t> grid_values(std::string grid) {
  std::vector<std::string> chars;
  for (char c : grid) {
    // '0'..'9' are contiguous in every C++ execution character set.
    if (c == '.' || (c >= '0' && c <= '9'))
      chars.push_back(std::string(1, c));
  }
  assert(chars.size() == 81);
  if (chars.size() != 81)
    return maybe<value_t>();
  maybe<value_t> output(std::make_shared<value_t>());
  value_t& temp = output;
  for (unsigned i = 0; i < 81; ++i)
    temp[squares[i]] = chars[i];
  return output;
}
// Build the candidate map for a puzzle: start every square at "123456789",
// then assign() each clue ('1'-'9') from grid_values(), letting constraint
// propagation prune the other squares. Returns an invalid maybe if
// grid_values() rejects the grid or if any clue's assign() is invalid.
maybe<value_t> parse_grid(string grid) {
  maybe<value_t> _values(std::make_shared<value_t>());
  value_t& values = _values;
  for (const auto& s : squares)
    values[s] = "123456789";
  maybe<value_t> _g = grid_values(grid);
  if (!_g.is_valid())
    return maybe<value_t>();  // malformed grid: nothing to convert
  const value_t& g = _g;
  for (const auto& entry : g) {
    const std::string& s = entry.first;
    const std::string& d = entry.second;
    // '0' and '.' mark empty squares and are not found in "123456789".
    if (std::string("123456789").find(d) != std::string::npos &&
        !assign(_values, s, d).is_valid())
      return maybe<value_t>();
  }
  return _values;
}
// Center `str` in a field of `width` characters, padding with spaces.
// Strings already at least `width` long are returned unchanged. When the
// total padding is odd, the extra space goes on the left only if width is
// also odd (the same rule as Python's str.center).
std::string center(const std::string & str, int width) {
  const int len = static_cast<int>(str.size());
  if (width <= len)
    return str;
  const int pad = width - len;
  const int left = pad / 2 + (pad & width & 1);
  const int right = pad - left;
  return std::string(left, ' ') + str + std::string(right, ' ');
}
// Sanity-check a solved grid: for every unit in the global unitlist, the
// values of its squares, concatenated and sorted, must be "123456789".
// Violations fail an assert (so nothing is checked under NDEBUG).
// An invalid maybe (no solution found) has nothing to verify and is skipped
// rather than dereferenced.
void check(maybe<value_t> _values) {
  if (!_values.is_valid())
    return;
  value_t& values = _values;
  for (const auto& unit : unitlist) {
    std::string unitvals;
    for (const auto& sq : unit)
      unitvals += values[sq];
    std::sort(unitvals.begin(), unitvals.end());
    assert(unitvals == "123456789");
  }
}
// Print the grid to stdout as 9 rows of centered cells, with '|' after
// columns 3 and 6 and a dashed separator line after rows C and F.
// Prints "Invalid grid" for an invalid maybe.
void display(maybe<value_t> values) {
  if (!values.is_valid()) {
    std::cout << "Invalid grid\n";
    return;
  }
  value_t& _values = values;
  // Cell width: at least 3, and one wider than the longest candidate string
  // so adjacent unsolved cells never run together (solved grids: width 3).
  std::size_t longest = 0;
  for (const auto& s : squares)
    longest = std::max(longest, _values[s].length());
  const unsigned width = std::max(3u, static_cast<unsigned>(longest + 1));
  const std::string box(width * 3, '-');
  const std::string line = box + "+" + box + "+" + box;
  for (const auto& r : rows) {
    for (const auto& c : cols) {
      std::cout << center(_values[r + c], static_cast<int>(width))
                << ((c == "3" || c == "6") ? "|" : "");
    }
    std::cout << "\n";
    if (r == "C" || r == "F")
      std::cout << line << "\n";
  }
}
// Depth-first search with constraint propagation.
// If every square has a single candidate, the puzzle is solved and the map
// is returned. Otherwise pick the unsolved square with the fewest candidates
// (ties: first in key order), try each candidate on an independent copy of
// the map via assign(), and recurse. Returns the first solution found, or an
// invalid maybe if none exists or if _values was already invalid.
maybe<value_t> search(maybe<value_t> _values) {
  if (!_values.is_valid())
    return maybe<value_t>();
  value_t& values = _values;
  // Single pass over the map: no per-node copy of the unsolved squares.
  auto best = values.end();
  for (auto it = values.begin(); it != values.end(); ++it) {
    const std::size_t n = it->second.length();
    if (n > 1 && (best == values.end() || n < best->second.length()))
      best = it;  // strict '<' keeps the first minimum, as min_element did
  }
  if (best == values.end())
    return _values;  // solved
  const std::string& s = best->first;
  // `values` is never modified below (each attempt works on its own copy),
  // so iterating best->second is safe.
  for (char d : best->second) {
    maybe<value_t> v_copy(std::make_shared<value_t>(values));
    maybe<value_t> attempt = search(assign(v_copy, s, std::string(1, d)));
    if (attempt.is_valid())
      return attempt;
  }
  return maybe<value_t>();
}
int main(int argc, char* argv[])
{
assert (argc == 2);
for (auto _c : cols) {
vector<string> c = {_c};
unitlist.push_back(cross(rows, c));
}
for (auto _r : rows) {
vector<string> r = {_r};
unitlist.push_back(cross(r, cols));
}
for (vector<string> rs : {"ABC"_vec, "DEF"_vec, "GHI"_vec}) {
for (vector<string> cs : {"123"_vec, "456"_vec, "789"_vec}) {
unitlist.push_back(cross(rs, cs));
}
}
for (string s : squares) {
for (vector<string> u : unitlist) {
if (find(u.begin(), u.end(), s) != u.end())
units[s].push_back(u);
}
}
for (auto s : squares) {
vector<string> temp =
accumulate(units[s].begin(), units[s].end(), vector<string>(),
[=](const vector<string>& a1, const vector<string>& a2) ->
vector<string> {
vector<string> ret;
for (auto e : a1) ret.push_back(e);
for (auto e : a2) ret.push_back(e);
return ret;
});
set<string>& output = peers[s];
copy_if(temp.begin(), temp.end(), inserter(output, output.end()),
[=] (string t) { return !(t == s); });
}
string grid(argv[1]);
auto start = chrono::high_resolution_clock::now();
maybe<value_t> ans = search(parse_grid(grid));
auto end = chrono::high_resolution_clock::now();
auto elapsed = chrono::duration_cast<chrono::microseconds>(end-start).count();
cout << elapsed << " us\n";
display(ans);
check(ans);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment