Skip to content

Instantly share code, notes, and snippets.

@iBug
Last active March 22, 2024 20:14
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save iBug/ea958ca7f1270128d58b5176858d71cb to your computer and use it in GitHub Desktop.
Save iBug/ea958ca7f1270128d58b5176858d71cb to your computer and use it in GitHub Desktop.
Even more advanced 24 game solver. https://ibug.io/p/51
#include <algorithm>
#include <cctype>
#include <cmath>
#include <iostream>
#include <memory>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>
using std::string;
using std::vector;
// Approximate floating-point equality: true when a and b differ by strictly less than
// `epsilon` (so NaN never compares equal to anything).
inline bool is_equal(double a, double b, double epsilon = 1e-6) {
    const double gap = std::abs(b - a);
    return gap < epsilon;
}
// Abstract node of an arithmetic expression tree.
//
// Every concrete node reports a sort bucket, a printable form and a numeric value.
// The remaining hooks default to "nothing to do", so leaf nodes need not override them.
struct Expression {
    virtual ~Expression() = default;

    // Primary key used when ordering expressions (compare the SORT_ORDER_* constants).
    virtual int sort_order() const = 0;

    // Hook for rewriting the node in place into a canonical shape; no-op by default.
    virtual void normalize() {}

    // Printable rendering of the expression.
    virtual operator string() const = 0;

    // Numeric value of the expression.
    virtual operator double() const = 0;

    // Hook for an in-place rewrite that subclasses implement; no-op by default.
    virtual void invert() {}

    // Subclasses report whether invert() applies to them; the base answers no.
    virtual bool is_invertible() const { return false; }

    // Whether the node is in canonical form; the base (leaf) case always is.
    virtual bool is_canonical(bool is_top_level = true) const { return true; }
};
// Primary sort keys used by operator< (ascending): additive groups come first,
// then multiplicative groups, then plain numbers.
constexpr int SORT_ORDER_ADDITION = 1;
constexpr int SORT_ORDER_MULTIPLICATION = 2;
constexpr int SORT_ORDER_NUMBER = 3;
// Strict weak ordering over expressions, used to sort group operands into a canonical order.
//
// Keys, in priority order: sort_order(), then numeric value, then printed form.
// NaN values (e.g. a 0/0 inside a multiplicative group) are placed after every non-NaN
// value and ordered among themselves by printed form. Comparing NaN with `<` directly would
// let the string tie-break contradict the value order of other operands. That produces
// cycles (a < c < b < a), which breaks the strict weak ordering std::sort requires.
bool operator<(const Expression& a, const Expression& b) {
    const int order_a = a.sort_order();
    const int order_b = b.sort_order();
    if (order_a != order_b)
        return order_a < order_b;

    const double value_a = double(a);
    const double value_b = double(b);
    const bool nan_a = std::isnan(value_a);
    const bool nan_b = std::isnan(value_b);
    if (nan_a != nan_b)
        return nan_b;  // the non-NaN side sorts first
    if (!nan_a) {
        if (value_a < value_b)
            return true;
        if (value_b < value_a)
            return false;
    }
    return string(a) < string(b);
}
bool compare_expression(const Expression* a, const Expression* b) {
return *a < *b;
}
// Sorts a list of expression pointers in place, ordering by operator< on the pointees.
void sort_expressions(vector<Expression*>& e) {
    std::sort(std::begin(e), std::end(e), &compare_expression);
}
// Leaf node holding a literal number.
//
// `v` is the numeric value and `s` the text used when the expression is printed.
struct Number : Expression {
    double v;
    string s;

    Number() : Number(0.0) {}

    Number(double v_) : v(v_), s(std::to_string(v_)) {}

    // Parses a command-line token such as "7", "2.5" or "-3".
    // Throws std::invalid_argument or std::out_of_range (from std::stod) when the token is
    // not a number, and std::invalid_argument when it has trailing characters ("5abc").
    // Tokens whose first character is not a digit (e.g. a leading sign) are wrapped in
    // parentheses so they stay unambiguous inside a larger printed expression.
    // Initializers are listed in declaration order (v, then s), which is the order the
    // compiler runs them in anyway.
    Number(const string& s_) : v(parse_number(s_)), s(s_) {
        // parse_number succeeded, so s is non-empty; cast avoids UB for negative chars.
        if (!std::isdigit(static_cast<unsigned char>(s[0])))
            s = "(" + s + ")";
    }

    int sort_order() const override { return SORT_ORDER_NUMBER; }
    operator string() const override { return s; }
    operator double() const override { return v; }

private:
    // Converts the whole of `text` to a double, rejecting unparsed trailing characters.
    static double parse_number(const string& text) {
        std::size_t consumed = 0;
        const double value = std::stod(text, &consumed);
        if (consumed != text.size())
            throw std::invalid_argument("not a number: " + text);
        return value;
    }
};
struct ExpressionGroup : Expression {
vector<Expression*> positive;
vector<Expression*> negative;
void normalize() override {
for (auto& e : positive)
e->normalize();
for (auto& e : negative)
e->normalize();
sort_expressions(positive);
sort_expressions(negative);
}
void invert() override {
std::swap(positive, negative);
}
bool is_invertible() const override {
return !negative.empty();
}
bool has_negative_pairs() const {
for (auto& e1 : positive)
for (auto& e2 : negative)
if (is_equal(double(*e1), double(*e2)))
return true;
return false;
}
bool is_canonical(bool is_top_level = true) const override {
for (auto& e : positive)
if (!e->is_canonical(false))
return false;
for (auto& e : negative)
if (!e->is_canonical(false))
return false;
return true;
}
};
// Sum/difference node: adds the `positive` terms and subtracts the `negative` ones.
struct AdditiveGroup : ExpressionGroup {
    int sort_order() const override { return SORT_ORDER_ADDITION; }

    operator double() const override {
        double result = 0.0;
        for (auto& e : positive)
            result += double(*e);
        for (auto& e : negative)
            result -= double(*e);
        return result;
    }

    // Renders the added terms joined by '+', then every subtracted term prefixed with '-'.
    // The '-' is written even when no added term precedes it, so a group made only of
    // subtracted terms prints as "-a-b" instead of the wrong "a-b".
    operator string() const override {
        std::stringstream ss;
        for (std::size_t i = 0; i < positive.size(); ++i) {
            if (i > 0)
                ss << "+";
            ss << string(*positive[i]);
        }
        for (auto& e : negative)
            ss << "-" << string(*e);
        return ss.str();
    }

    // Moves subtracted terms to the added side where that keeps the value unchanged:
    // an invertible term is inverted (its value negated) before moving, and a term equal
    // to zero can switch sides as-is. Then normalizes the children and sorts both lists.
    void normalize() override {
        for (auto it = negative.begin(); it != negative.end();) {
            auto& e = *it;
            if (e->is_invertible()) {
                e->invert();
                positive.push_back(e);
                it = negative.erase(it);
            } else if (is_equal(*e, 0.0)) {
                positive.push_back(e);
                it = negative.erase(it);
            } else {
                ++it;
            }
        }
        this->ExpressionGroup::normalize();
    }

    // Swapping the term lists (the inherited invert()) is only offered when the sum is
    // negative and there is at least one subtracted term.
    bool is_invertible() const override {
        return double(*this) < 0 && this->ExpressionGroup::is_invertible();
    }

    bool is_canonical(bool is_top_level = true) const override;
};
struct MultiplicativeGroup : ExpressionGroup {
int sort_order() const override { return SORT_ORDER_MULTIPLICATION; }
operator double() const override {
double result = 1.0;
for (auto& e : positive)
result *= double(*e);
for (auto& e : negative)
result /= double(*e);
return result;
}
operator string() const override {
std::stringstream ss;
bool first = true;
for (auto& e : positive) {
if (first) {
first = false;
} else {
ss << "*";
}
if (dynamic_cast<Number*>(e))
ss << string(*e);
else
ss << "(" << string(*e) << ")";
}
for (auto& e : negative) {
if (first) {
first = false;
} else {
ss << "/";
}
if (dynamic_cast<Number*>(e))
ss << string(*e);
else
ss << "(" << string(*e) << ")";
}
return ss.str();
}
void normalize() override {
int neg_count = 0;
for (const auto& e : positive)
neg_count += e->is_invertible();
for (const auto& e : negative)
neg_count += e->is_invertible();
neg_count -= neg_count % 2;
for (const auto& e : negative)
if (neg_count == 0)
break;
else if (e->is_invertible()) {
e->invert();
neg_count--;
}
for (const auto& e : positive)
if (neg_count == 0)
break;
else if (e->is_invertible()) {
e->invert();
neg_count--;
}
for (auto it = negative.begin(); it != negative.end();) {
auto& e = *it;
if (is_equal(*e, 1.0) || is_equal(*e, -1.0)) {
positive.push_back(e);
it = negative.erase(it);
} else {
++it;
}
}
this->ExpressionGroup::normalize();
}
bool is_invertible() const override {
if (double(*this) >= 0)
return false;
for (const auto& e : positive)
if (e->is_invertible())
return true;
for (const auto& e : negative)
if (e->is_invertible())
return true;
return false;
}
void invert() override {
for (const auto& e : negative)
if (e->is_invertible()) {
e->invert();
return;
}
for (const auto& e : positive)
if (e->is_invertible()) {
e->invert();
return;
}
}
// the only boolean argument is interpreted a bit differently here
bool is_canonical(bool allow_ones = true) const override {
if (!ExpressionGroup::is_canonical(allow_ones))
return false;
if (positive.size() > 1) {
int ones = 0;
for (const auto& e : positive)
ones += is_equal(*e, 1.0);
if (!allow_ones && ones >= 1)
return false;
if (ones >= 2)
return false;
}
return !(positive.size() > 1 && has_negative_pairs());
}
};
bool AdditiveGroup::is_canonical(bool is_top_level) const {
if (is_top_level) {
// with only a single MulGroup child and everything else zero, it's allowed to have a one
int mg_count = 0;
MultiplicativeGroup *mg, *t;
for (auto& e : positive)
if ((t = dynamic_cast<MultiplicativeGroup*>(e)) != nullptr) {
mg_count++;
mg = t;
}
for (auto& e : negative)
if ((t = dynamic_cast<MultiplicativeGroup*>(e)) != nullptr) {
mg_count++;
mg = t;
}
if (mg_count == 1 && is_equal(*this, *mg)) {
for (auto& e : positive)
if (!e->is_canonical(dynamic_cast<MultiplicativeGroup*>(e) != nullptr))
return false;
for (auto& e : negative)
if (!e->is_canonical(dynamic_cast<MultiplicativeGroup*>(e) != nullptr))
return false;
} else if (!ExpressionGroup::is_canonical(is_top_level)) {
return false;
}
return true;
}
if (!ExpressionGroup::is_canonical(is_top_level))
return false;
return !(positive.size() > 1 && has_negative_pairs());
}
// Combines `a` and `b` into a newly allocated group of type T (an ExpressionGroup subclass).
//
// An operand that already is a T is flattened: its operand lists are copied instead of
// nesting the group. When `negative` is true, `b` enters on the inverse side, so a
// flattened `b` contributes its positive list to the negative side and vice versa.
// Operands are shared, not copied. The caller owns the returned group; deleting it does
// not delete the operands.
template <typename T>
T* join_group(Expression* a, Expression* b, bool negative) {
    static_assert(std::is_base_of<ExpressionGroup, T>::value, "T must be derived from ExpressionGroup");
    // Hold the group in a unique_ptr while it is being filled so it is freed if a vector
    // insertion throws; ownership passes to the caller on return.
    std::unique_ptr<T> e(new T());
    if (auto a_group = dynamic_cast<T*>(a)) {
        e->positive = a_group->positive;
        e->negative = a_group->negative;
    } else {
        e->positive.push_back(a);
    }
    if (auto b_group = dynamic_cast<T*>(b)) {
        if (negative) {
            e->positive.insert(e->positive.end(), b_group->negative.begin(), b_group->negative.end());
            e->negative.insert(e->negative.end(), b_group->positive.begin(), b_group->positive.end());
        } else {
            e->positive.insert(e->positive.end(), b_group->positive.begin(), b_group->positive.end());
            e->negative.insert(e->negative.end(), b_group->negative.begin(), b_group->negative.end());
        }
    } else {
        if (negative) {
            e->negative.push_back(b);
        } else {
            e->positive.push_back(b);
        }
    }
    return e.release();
}
// Builds `a + b` (or `a - b` when `negative` is true) as a flattened AdditiveGroup.
// The caller owns the returned group.
inline AdditiveGroup* join_additive_group(Expression* a, Expression* b, bool negative) {
    AdditiveGroup* const sum = join_group<AdditiveGroup>(a, b, negative);
    return sum;
}
// Builds `a * b` (or `a / b` when `negative` is true) as a flattened MultiplicativeGroup.
// The caller owns the returned group.
inline MultiplicativeGroup* join_multiplicative_group(Expression* a, Expression* b, bool negative) {
    MultiplicativeGroup* const product = join_group<MultiplicativeGroup>(a, b, negative);
    return product;
}
// Exhaustive search for expressions over the input numbers that evaluate to `target`.
//
// search() repeatedly replaces two nodes with their sum, product, difference or quotient
// until one node remains, which eval_result() checks against the target. Accepted answers
// are deduplicated by their normalized printed form.
struct Solver {
    double target;     // @param target value
    bool all_answers;  // @param find all answers
    bool count_only;   // @param only show number of solutions
    bool use_states;   // switch: intermediate states are dedup'd only for large inputs
    std::unordered_set<string> states;   // keys of node lists already explored
    std::unordered_set<string> answers;  // printed form of every accepted answer

    // solve() reads count_only even when the caller never sets it, so every flag must
    // start out with a defined value.
    Solver() : target(24), all_answers(false), count_only(false), use_states(false) {}

    // Returns true if an equivalent list of nodes was already explored (tracked only when
    // use_states is set). Nodes are normalized before sorting so the key depends only on
    // their normalized forms, not on the order in which they were produced.
    // NOTE(review): normalize() rewrites the shared node objects in place, so the caller's
    // node list and sibling branches see the rewritten forms too. Presumably this is
    // value-preserving; confirm.
    bool dedup_state(const vector<Expression*>& nodes) {
        if (!use_states)
            return false;
        for (auto* e : nodes)
            e->normalize();
        auto n = nodes;
        sort_expressions(n);
        std::stringstream ss;
        for (auto* e : n)
            ss << ":" << string(*e);
        return !states.insert(ss.str()).second;
    }

    // Checks a fully combined expression. Returns true when it equals the target and its
    // normalized form is canonical. A canonical answer not seen before is recorded and,
    // unless count_only, printed.
    bool eval_result(Expression* node) {
        const bool matches = is_equal(*node, target);
        if (!matches)
            return false;
        node->normalize();
        if (!node->is_canonical(true))
            return false;
        const string expr = string(*node);
        const bool is_new_answer = answers.insert(expr).second;
        if (is_new_answer && !count_only)
            std::cout << expr << " = " << target << '\n';
        return true;
    }

    // Returns true to stop the search (an answer was found and all_answers is off).
    bool search(const vector<Expression*>& nodes) {
        if (nodes.size() == 1)
            return eval_result(nodes[0]) && !all_answers;
        if (dedup_state(nodes))
            return false;
        // Places `joined` in the last slot of `new_nodes`, recurses, and frees the joined
        // node afterwards (also if the recursion throws).
        auto explore = [this](vector<Expression*>& new_nodes, Expression* joined) {
            std::unique_ptr<Expression> owner(joined);
            new_nodes.back() = owner.get();
            return search(new_nodes);
        };
        for (size_t i = 0; i < nodes.size(); i++) {
            for (size_t j = 0; j < nodes.size(); j++) {
                if (i == j)
                    continue;
                vector<Expression*> new_nodes;
                new_nodes.reserve(nodes.size() - 1);
                for (size_t k = 0; k < nodes.size(); k++)
                    if (k != i && k != j)
                        new_nodes.push_back(nodes[k]);
                new_nodes.push_back(nullptr);  // slot for the combined node
                // + and * commute, so they are tried once per pair (i < j);
                // - and / depend on operand order, so both orders are tried.
                if (i < j) {
                    if (explore(new_nodes, join_additive_group(nodes[i], nodes[j], false)))
                        return true;
                    if (explore(new_nodes, join_multiplicative_group(nodes[i], nodes[j], false)))
                        return true;
                }
                if (explore(new_nodes, join_additive_group(nodes[i], nodes[j], true)))
                    return true;
                if (explore(new_nodes, join_multiplicative_group(nodes[i], nodes[j], true)))
                    return true;
            }
        }
        // A recursive call only returns true when all_answers is off, so reaching here
        // means "keep searching".
        return false;
    }

    // Runs a fresh search over `nums` and returns the number of distinct answers found.
    int solve(const vector<Expression*>& nums) {
        states.clear();
        answers.clear();
        if (count_only)
            all_answers = true;  // counting requires the exhaustive search
        use_states = nums.size() >= 5;
        search(nums);
        return static_cast<int>(answers.size());
    }
};
int main(int argc, char** argv) {
vector<string> args(argv, argv + argc);
vector<Expression*> nums;
Solver solver;
for (int i = 1; i < args.size(); i++) {
if (args[i] == "-a") {
solver.all_answers = true;
} else if (args[i] == "-n") {
solver.count_only = true;
} else if (args[i] == "-t") {
solver.target = std::stod(args[++i]);
} else {
nums.emplace_back(new Number(args[i]));
}
}
std::random_device rd{};
std::default_random_engine rng{rd()};
std::shuffle(nums.begin(), nums.end(), rng);
auto n = solver.solve(nums);
if (n == 0) {
std::cout << "No solution" << std::endl;
} else if (n == 1) {
if (solver.all_answers)
std::cout << "1 solution" << std::endl;
} else {
std::cout << n << " solutions" << std::endl;
}
for (auto& num : nums)
delete num;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment