Last active
March 22, 2024 20:14
-
-
Save iBug/ea958ca7f1270128d58b5176858d71cb to your computer and use it in GitHub Desktop.
Even more advanced 24 game solver. https://ibug.io/p/51
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <vector>
using std::string;
using std::vector;
/// Approximate equality for doubles: true when |a - b| is strictly below
/// `epsilon` (an absolute, not relative, tolerance). NaN never compares equal.
inline bool is_equal(double a, double b, double epsilon = 1e-6) {
    const double distance = std::fabs(a - b);
    return distance < epsilon;
}
/// Abstract node of an arithmetic expression tree.
///
/// Every node reports a sort key, renders itself as text (`operator string`)
/// and evaluates itself (`operator double`). The remaining hooks have
/// permissive defaults that subclasses override as needed.
struct Expression {
    virtual ~Expression() = default;

    /// Primary key used when ordering sibling expressions (lower sorts first).
    virtual int sort_order() const = 0;

    /// Rewrites the node in place into a canonical shape; no-op by default.
    virtual void normalize() {}

    virtual operator string() const = 0;
    virtual operator double() const = 0;

    /// Hook used during normalization to flip the node's sign in place;
    /// does nothing by default.
    virtual void invert() {}

    /// Whether invert() may be applied; false by default.
    virtual bool is_invertible() const { return false; }

    /// Whether the node is in canonical form; leaves are always canonical.
    virtual bool is_canonical(bool is_top_level = true) const { return true; }
};
const int SORT_ORDER_NUMBER = 3; | |
const int SORT_ORDER_ADDITION = 1; | |
const int SORT_ORDER_MULTIPLICATION = 2; | |
/// Ordering used to sort the children of expression groups.
///
/// Keys, in priority order: sort_order(), numeric value, rendered text.
/// Each operand is evaluated once. NaN values (e.g. from 0/0) are placed after
/// every non-NaN value with the same sort_order(). Without this, a NaN would
/// fall through to the string comparison while other values compared
/// numerically, which breaks the strict weak ordering that std::sort requires.
bool operator<(const Expression& a, const Expression& b) {
    const int order_a = a.sort_order();
    const int order_b = b.sort_order();
    if (order_a != order_b)
        return order_a < order_b;

    const double value_a = double(a);
    const double value_b = double(b);
    const bool nan_a = std::isnan(value_a);
    const bool nan_b = std::isnan(value_b);
    if (nan_a != nan_b)
        return nan_b;  // the non-NaN operand sorts first
    if (!nan_a) {
        if (value_a < value_b)
            return true;
        if (value_b < value_a)
            return false;
    }
    return string(a) < string(b);
}
bool compare_expression(const Expression* a, const Expression* b) { | |
return *a < *b; | |
} | |
/// Sorts expression pointers in place by the ordering of the expressions they
/// point to (see compare_expression).
void sort_expressions(vector<Expression*>& e) {
    auto first = e.begin();
    auto last = e.end();
    std::sort(first, last, &compare_expression);
}
struct Number : Expression { | |
double v; | |
string s; | |
Number() : Number(0.0) {} | |
Number(double v_) : v(v_), s(std::to_string(v_)) {} | |
Number(const string& s_) : s(s_), v(std::stod(s_)) { | |
if (!std::isdigit(s[0])) | |
s = "(" + s + ")"; | |
} | |
int sort_order() const override { return SORT_ORDER_NUMBER; } | |
operator string() const override { return s; } | |
operator double() const override { return v; } | |
}; | |
struct ExpressionGroup : Expression { | |
vector<Expression*> positive; | |
vector<Expression*> negative; | |
void normalize() override { | |
for (auto& e : positive) | |
e->normalize(); | |
for (auto& e : negative) | |
e->normalize(); | |
sort_expressions(positive); | |
sort_expressions(negative); | |
} | |
void invert() override { | |
std::swap(positive, negative); | |
} | |
bool is_invertible() const override { | |
return !negative.empty(); | |
} | |
bool has_negative_pairs() const { | |
for (auto& e1 : positive) | |
for (auto& e2 : negative) | |
if (is_equal(double(*e1), double(*e2))) | |
return true; | |
return false; | |
} | |
bool is_canonical(bool is_top_level = true) const override { | |
for (auto& e : positive) | |
if (!e->is_canonical(false)) | |
return false; | |
for (auto& e : negative) | |
if (!e->is_canonical(false)) | |
return false; | |
return true; | |
} | |
}; | |
struct AdditiveGroup : ExpressionGroup { | |
int sort_order() const override { return SORT_ORDER_ADDITION; } | |
operator double() const override { | |
double result = 0.0; | |
for (auto& e : positive) | |
result += double(*e); | |
for (auto& e : negative) | |
result -= double(*e); | |
return result; | |
} | |
operator string() const override { | |
std::stringstream ss; | |
bool first = true; | |
for (auto& e : positive) { | |
if (first) { | |
first = false; | |
} else { | |
ss << "+"; | |
} | |
ss << string(*e); | |
} | |
for (auto& e : negative) { | |
if (first) { | |
first = false; | |
} else { | |
ss << "-"; | |
} | |
ss << string(*e); | |
} | |
return ss.str(); | |
} | |
void normalize() override { | |
for (auto it = negative.begin(); it != negative.end();) { | |
auto& e = *it; | |
if (e->is_invertible()) { | |
e->invert(); | |
positive.push_back(e); | |
it = negative.erase(it); | |
} else if (is_equal(*e, 0.0)) { | |
positive.push_back(e); | |
it = negative.erase(it); | |
} else { | |
++it; | |
} | |
} | |
this->ExpressionGroup::normalize(); | |
} | |
bool is_invertible() const override { | |
return double(*this) < 0 && this->ExpressionGroup::is_invertible(); | |
} | |
bool is_canonical(bool is_top_level = true) const override; | |
}; | |
struct MultiplicativeGroup : ExpressionGroup { | |
int sort_order() const override { return SORT_ORDER_MULTIPLICATION; } | |
operator double() const override { | |
double result = 1.0; | |
for (auto& e : positive) | |
result *= double(*e); | |
for (auto& e : negative) | |
result /= double(*e); | |
return result; | |
} | |
operator string() const override { | |
std::stringstream ss; | |
bool first = true; | |
for (auto& e : positive) { | |
if (first) { | |
first = false; | |
} else { | |
ss << "*"; | |
} | |
if (dynamic_cast<Number*>(e)) | |
ss << string(*e); | |
else | |
ss << "(" << string(*e) << ")"; | |
} | |
for (auto& e : negative) { | |
if (first) { | |
first = false; | |
} else { | |
ss << "/"; | |
} | |
if (dynamic_cast<Number*>(e)) | |
ss << string(*e); | |
else | |
ss << "(" << string(*e) << ")"; | |
} | |
return ss.str(); | |
} | |
void normalize() override { | |
int neg_count = 0; | |
for (const auto& e : positive) | |
neg_count += e->is_invertible(); | |
for (const auto& e : negative) | |
neg_count += e->is_invertible(); | |
neg_count -= neg_count % 2; | |
for (const auto& e : negative) | |
if (neg_count == 0) | |
break; | |
else if (e->is_invertible()) { | |
e->invert(); | |
neg_count--; | |
} | |
for (const auto& e : positive) | |
if (neg_count == 0) | |
break; | |
else if (e->is_invertible()) { | |
e->invert(); | |
neg_count--; | |
} | |
for (auto it = negative.begin(); it != negative.end();) { | |
auto& e = *it; | |
if (is_equal(*e, 1.0) || is_equal(*e, -1.0)) { | |
positive.push_back(e); | |
it = negative.erase(it); | |
} else { | |
++it; | |
} | |
} | |
this->ExpressionGroup::normalize(); | |
} | |
bool is_invertible() const override { | |
if (double(*this) >= 0) | |
return false; | |
for (const auto& e : positive) | |
if (e->is_invertible()) | |
return true; | |
for (const auto& e : negative) | |
if (e->is_invertible()) | |
return true; | |
return false; | |
} | |
void invert() override { | |
for (const auto& e : negative) | |
if (e->is_invertible()) { | |
e->invert(); | |
return; | |
} | |
for (const auto& e : positive) | |
if (e->is_invertible()) { | |
e->invert(); | |
return; | |
} | |
} | |
// the only boolean argument is interpreted a bit differently here | |
bool is_canonical(bool allow_ones = true) const override { | |
if (!ExpressionGroup::is_canonical(allow_ones)) | |
return false; | |
if (positive.size() > 1) { | |
int ones = 0; | |
for (const auto& e : positive) | |
ones += is_equal(*e, 1.0); | |
if (!allow_ones && ones >= 1) | |
return false; | |
if (ones >= 2) | |
return false; | |
} | |
return !(positive.size() > 1 && has_negative_pairs()); | |
} | |
}; | |
bool AdditiveGroup::is_canonical(bool is_top_level) const { | |
if (is_top_level) { | |
// with only a single MulGroup child and everything else zero, it's allowed to have a one | |
int mg_count = 0; | |
MultiplicativeGroup *mg, *t; | |
for (auto& e : positive) | |
if ((t = dynamic_cast<MultiplicativeGroup*>(e)) != nullptr) { | |
mg_count++; | |
mg = t; | |
} | |
for (auto& e : negative) | |
if ((t = dynamic_cast<MultiplicativeGroup*>(e)) != nullptr) { | |
mg_count++; | |
mg = t; | |
} | |
if (mg_count == 1 && is_equal(*this, *mg)) { | |
for (auto& e : positive) | |
if (!e->is_canonical(dynamic_cast<MultiplicativeGroup*>(e) != nullptr)) | |
return false; | |
for (auto& e : negative) | |
if (!e->is_canonical(dynamic_cast<MultiplicativeGroup*>(e) != nullptr)) | |
return false; | |
} else if (!ExpressionGroup::is_canonical(is_top_level)) { | |
return false; | |
} | |
return true; | |
} | |
if (!ExpressionGroup::is_canonical(is_top_level)) | |
return false; | |
return !(positive.size() > 1 && has_negative_pairs()); | |
} | |
template <typename T> | |
T* join_group(Expression* a, Expression* b, bool negative) { | |
static_assert(std::is_base_of<ExpressionGroup, T>::value, "T must be derived from ExpressionGroup"); | |
auto e = new T(); | |
if (auto a_group = dynamic_cast<T*>(a)) { | |
e->positive = a_group->positive; | |
e->negative = a_group->negative; | |
} else { | |
e->positive.push_back(a); | |
} | |
if (auto b_group = dynamic_cast<T*>(b)) { | |
if (negative) { | |
e->positive.insert(e->positive.end(), b_group->negative.begin(), b_group->negative.end()); | |
e->negative.insert(e->negative.end(), b_group->positive.begin(), b_group->positive.end()); | |
} else { | |
e->positive.insert(e->positive.end(), b_group->positive.begin(), b_group->positive.end()); | |
e->negative.insert(e->negative.end(), b_group->negative.begin(), b_group->negative.end()); | |
} | |
} else { | |
if (negative) { | |
e->negative.push_back(b); | |
} else { | |
e->positive.push_back(b); | |
} | |
} | |
return e; | |
} | |
/// Builds `a + b`, or `a - b` when `negative` is true, flattening nested sums.
inline AdditiveGroup* join_additive_group(Expression* a, Expression* b, bool negative) {
    AdditiveGroup* sum = join_group<AdditiveGroup>(a, b, negative);
    return sum;
}
/// Builds `a * b`, or `a / b` when `negative` is true, flattening nested products.
inline MultiplicativeGroup* join_multiplicative_group(Expression* a, Expression* b, bool negative) {
    MultiplicativeGroup* product = join_group<MultiplicativeGroup>(a, b, negative);
    return product;
}
struct Solver { | |
double target; // @param target value | |
bool all_answers; // @param find all answers | |
bool count_only; // @param only show number of solutions | |
bool use_states; // switch: intermediate states are dedup'd only for large inputs | |
std::unordered_set<string> states; | |
std::unordered_set<string> answers; | |
Solver() : target(24), all_answers(false) {} | |
bool dedup_state(const vector<Expression*>& nodes) { | |
if (!use_states) | |
return false; | |
auto n = nodes; | |
sort_expressions(n); | |
std::stringstream ss; | |
for (auto& e : n) { | |
e->normalize(); | |
ss << ":" << string(*e); | |
} | |
return !states.insert(ss.str()).second; | |
} | |
bool eval_result(Expression* node) { | |
bool result = is_equal(*node, target); | |
if (result) { | |
node->normalize(); | |
if (!node->is_canonical(true)) | |
return false; | |
auto expr = string(*node); | |
auto is_new_answer = answers.insert(expr).second; | |
if (is_new_answer && !count_only) | |
std::cout << expr << " = " << target << std::endl; | |
} | |
return result; | |
} | |
bool search(const vector<Expression*>& nodes) { | |
if (nodes.size() == 1) | |
return eval_result(nodes[0]) && !all_answers; | |
if (dedup_state(nodes)) | |
return false; | |
bool result = false; | |
for (size_t i = 0; i < nodes.size(); i++) { | |
for (size_t j = 0; j < nodes.size(); j++) { | |
if (i == j) | |
continue; | |
vector<Expression*> new_nodes; | |
new_nodes.reserve(nodes.size() - 1); | |
for (size_t k = 0; k < nodes.size(); k++) { | |
if (k == i || k == j) | |
continue; | |
new_nodes.push_back(nodes[k]); | |
} | |
new_nodes.push_back(nullptr); | |
if (i < j) { | |
new_nodes.back() = join_additive_group(nodes[i], nodes[j], false); | |
result = result || search(new_nodes); | |
delete new_nodes.back(); | |
new_nodes.back() = join_multiplicative_group(nodes[i], nodes[j], false); | |
result = result || search(new_nodes); | |
delete new_nodes.back(); | |
} | |
new_nodes.back() = join_additive_group(nodes[i], nodes[j], true); | |
result = result || search(new_nodes); | |
delete new_nodes.back(); | |
new_nodes.back() = join_multiplicative_group(nodes[i], nodes[j], true); | |
result = result || search(new_nodes); | |
delete new_nodes.back(); | |
} | |
} | |
return result && !all_answers; | |
} | |
int solve(const vector<Expression*>& nums) { | |
states.clear(); | |
answers.clear(); | |
if (count_only) | |
all_answers = true; | |
use_states = nums.size() >= 5; | |
search(nums); | |
return answers.size(); | |
} | |
}; | |
int main(int argc, char** argv) { | |
vector<string> args(argv, argv + argc); | |
vector<Expression*> nums; | |
Solver solver; | |
for (int i = 1; i < args.size(); i++) { | |
if (args[i] == "-a") { | |
solver.all_answers = true; | |
} else if (args[i] == "-n") { | |
solver.count_only = true; | |
} else if (args[i] == "-t") { | |
solver.target = std::stod(args[++i]); | |
} else { | |
nums.emplace_back(new Number(args[i])); | |
} | |
} | |
std::random_device rd{}; | |
std::default_random_engine rng{rd()}; | |
std::shuffle(nums.begin(), nums.end(), rng); | |
auto n = solver.solve(nums); | |
if (n == 0) { | |
std::cout << "No solution" << std::endl; | |
} else if (n == 1) { | |
if (solver.all_answers) | |
std::cout << "1 solution" << std::endl; | |
} else { | |
std::cout << n << " solutions" << std::endl; | |
} | |
for (auto& num : nums) | |
delete num; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment