Skip to content

Instantly share code, notes, and snippets.

@jsimpson
Forked from depp/problem.cpp
Created April 1, 2017 14:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jsimpson/205e6258c5371488b2fc116b79ae1e28 to your computer and use it in GitHub Desktop.
Save jsimpson/205e6258c5371488b2fc116b79ae1e28 to your computer and use it in GitHub Desktop.
A Faster Solution
// Faster solution for:
// http://www.boyter.org/2017/03/golang-solution-faster-equivalent-java-solution/
// With threading.
// g++ -std=c++11 -Wall -Wextra -O3 -pthread
// On my computer (i5-6600K, 3.50 GHz, 4 cores), this takes about 160 ms after
// the CPU has warmed up, or about 80 ms if the CPU is cold (due to Turbo Boost).
// How it works: Start by generating a list of losing states -- states where the
// game can end in one turn. Generate a new list of states by running the game
// backwards one turn, but remove any states we've seen before. Eventually we
// will stop seeing new states.
// The way we can parallelize it is by giving each thread a different amount of
// total money to use in the game. Since the total amount of money in the game
// is invariant, this means that different threads will never visit the same
// state.
// You can run the program without arguments to search for states with a minimum
// number of turns before they end, or you can pass three arguments to show a
// game with the minimum number of turns from that starting point.
#define NDEBUG 1
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <exception>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
// Money held by each of the three players. Canonical states (see Sort) keep
// the amounts in ascending order.
using State = std::array<unsigned, 3>;
void Show(const State &s) {
std::cout << s[0] << ' ' << s[1] << ' ' << s[2] << '\n';
}
void Show(const std::vector<State> &v) {
std::cout << "Turns: " << v.size() << '\n';
for (State s : v) {
Show(s);
}
}
// Number of turns we are searching for.
static const unsigned NumTurns = 12;
// Max starting money for each player. Don't change this without changing how
// the encoding works.
static const unsigned Max = 255;
// A bound on the number of different states for a fixed amount of money.
//
// Encode() packs a sorted state a <= b <= c as (b << 8) + a. With every
// holding at least 1 and a total of at most 3 * Max, a <= Max and
// b <= (3 * Max - 1) / 2 (e.g. {1, 382, 382} for a total of 765). Indices
// therefore run up to (((3 * Max - 1) / 2) << 8) + Max, so the table needs
// (3 * Max - 1) / 2 + 1 rows of 256 entries. The previous bound omitted the
// "+ 1" row and was indexed out of bounds for a total of 3 * Max.
static const unsigned NumStates = 256 * ((Max * 3 - 1) / 2 + 1);
static_assert(Max <= 0xff, "Encode() stores the smallest holding in 8 bits");
// Put a state into the canonical format, sorted ascending.
// This is a three-comparator sorting network: ordering the pairs (0,1), (1,2),
// then (0,1) again leaves p[0] <= p[1] <= p[2].
void Sort(State &p) {
  auto order = [&p](int lo, int hi) {
    using std::swap;
    if (p[hi] < p[lo])
      swap(p[lo], p[hi]);
  };
  order(0, 1);
  order(1, 2);
  order(0, 1);
}
unsigned Total(const State &s) { return s[0] + s[1] + s[2]; }
// Pack a sorted state into a table index: the middle amount selects a row of
// 256 entries and the smallest amount the column. The largest amount is
// implied by the total money, so it is not stored.
unsigned Encode(const State &p) { return p[0] + p[1] * 256u; }
// Inverse of Encode(): recover the smallest and middle amounts from the index
// and derive the largest from the total money.
State Decode(unsigned money, unsigned x) {
  const unsigned low = x % 256u;
  const unsigned middle = x / 256u;
  return State{{low, middle, money - low - middle}};
}
// Search for a minimum-length game with the given starting state.
//
// Runs a forward breadth-first search over sorted states sharing the start's
// total money. parent[pos] holds the encoded state from which pos was first
// reached; 0 means "not reached yet", and the start points at itself. As soon
// as a dequeued state has two equal holdings, the path back to the start is
// rebuilt and returned in playing order (start first). If no such state is
// reachable, prints "Failed!" and returns an empty vector.
std::vector<State> Play(State start) {
  Sort(start);
  const unsigned root = Encode(start);
  const unsigned money = Total(start);
  std::vector<unsigned> parent(NumStates, 0);
  parent[root] = root;
  std::vector<State> frontier{start};
  std::vector<State> next;
  int depth = 0;
  while (!frontier.empty()) {
    ++depth;
    next.clear();
    for (const State &cur : frontier) {
      const unsigned here = Encode(cur);
      for (int a = 0; a < 3; ++a) {
        unsigned mine = cur[a];
        unsigned theirs = cur[(a + 1) % 3];
        const unsigned other = cur[(a + 2) % 3];
        if (mine == theirs) {
          // Walk the parent links back to the start, then flip the order.
          std::vector<State> path;
          for (unsigned link = here; link != root; link = parent[link])
            path.push_back(Decode(money, link));
          path.push_back(start);
          assert(path.size() == static_cast<std::size_t>(depth));
          std::reverse(path.begin(), path.end());
          return path;
        }
        // The poorer of the pair has its money doubled by the richer one.
        if (mine < theirs) {
          theirs -= mine;
          mine *= 2;
        } else {
          mine -= theirs;
          theirs *= 2;
        }
        State succ{{mine, theirs, other}};
        Sort(succ);
        const unsigned pos = Encode(succ);
        if (parent[pos] != 0)
          continue;
        parent[pos] = here;
        next.push_back(succ);
      }
    }
    frontier.swap(next);
  }
  std::cout << "Failed!\n";
  return std::vector<State>{};
}
// Multithreaded search for a starting state whose shortest game lasts exactly
// NumTurns turns. Worker threads each claim a different total amount of
// money, so they never visit the same state.
class Solver {
  // Lock for all solver fields.
  std::mutex m_mutex;
  // Next total amount of money to hand to a worker thread. Workers stop
  // claiming new totals once this exceeds 3 * Max.
  unsigned m_money;
  // Whether any worker has found a solution.
  bool m_found;
  // The solution, if found. Value-initialized (all zeros) so it is never read
  // while indeterminate when no solution exists.
  State m_solution{};

private:
  Solver();
  void Work();

public:
  // Run the search and return the solution with the least total money.
  static State Solve();
};
Solver::Solver() : m_money(3), m_found(false) {}
void Solver::Work() {
std::vector<bool> visited;
visited.resize(NumStates, false);
std::vector<State> q1, q2;
while (true) {
unsigned money;
{
std::unique_lock<std::mutex> lock(m_mutex);
if (m_found || m_money > Max * 3)
return;
money = m_money;
m_money++;
}
std::fill(visited.begin(), visited.end(), false);
// Generate all of the final states with the given amount of total money.
// Put these states into q1. Mark the states as visited.
for (unsigned smallest = 1; smallest * 3 <= money; smallest++) {
unsigned other1 = money - 2 * smallest;
State s{smallest, smallest, other1};
visited[Encode(s)] = true;
q1.push_back(s);
unsigned other2 = (money - smallest) / 2;
if (((money - smallest) & 1) == 0 && smallest != other1) {
State s{smallest, other2, other2};
visited[Encode(s)] = true;
q1.push_back(s);
}
}
if (q1.empty())
continue;
int turn = 1;
// Loop invariants: "q1" contains all the possible states where the minimum
// number of turns before the game might end is equal to "turn".
do {
// If we've reached the target number of turns, we're done--but only if
// the starting state is legal.
if (turn == NumTurns) {
for (State s : q1) {
if (s[2] <= Max) {
std::unique_lock<std::mutex> lock(m_mutex);
if (m_found && Total(m_solution) < money)
return;
m_found = true;
m_solution = s;
return;
}
}
}
// Fill "q1" by running the game backwards, and discarding any state we've
// seen before.
std::swap(q1, q2);
q1.clear();
for (State s : q2) {
for (int i = 0; i < 3; i++) {
unsigned x = s[i];
if ((x & 1) != 0)
continue;
x /= 2;
for (int j = 0; j < 2; j++) {
unsigned y = s[(i + 1 + j) % 3];
unsigned z = s[(i + 2 - j) % 3];
y += x;
if (y <= x)
continue;
State s2{x, y, z};
Sort(s2);
unsigned pos = Encode(s2);
if (visited[pos])
continue;
visited[pos] = true;
q1.push_back(s2);
}
}
}
turn++;
} while (!q1.empty());
}
}
// Run the search with one worker thread per hardware thread and return the
// solution with the least total money.
//
// If no starting state has a shortest game of exactly NumTurns turns, returns
// a value-initialized State (all zeros) instead of an unset solution.
State Solver::Solve() {
  unsigned numthreads = std::thread::hardware_concurrency();
  if (numthreads == 0) {
    // hardware_concurrency() returns 0 when the value is not computable.
    numthreads = 1;
  }
  std::cout << "Work threads: " << numthreads << '\n';
  Solver s;
  std::vector<std::thread> threads;
  threads.reserve(numthreads);
  for (unsigned i = 0; i < numthreads; i++) {
    threads.emplace_back(&Solver::Work, &s);
  }
  for (std::thread &t : threads) {
    t.join();
  }
  // Every worker has been joined, so the results can be read without the lock.
  if (!s.m_found)
    return State{};
  return s.m_solution;
}
// Entry point.
//
// With no arguments, search for a starting state whose shortest game lasts
// NumTurns turns and print it. With three arguments, each a whole number in
// [1, Max], print a shortest game from that starting state. Any other argument
// count prints "Bad usage"; a non-numeric, partially numeric or out-of-range
// argument prints "Invalid". Both errors exit with status 1.
int main(int argc, char **argv) {
  if (argc == 1) {
    State s = Solver::Solve();
    std::cout << "Solution found: ";
    Show(s);
  } else if (argc == 4) {
    State s;
    for (int i = 0; i < 3; i++) {
      const std::string arg = argv[i + 1];
      unsigned long v = 0;
      std::size_t used = 0;
      try {
        v = std::stoul(arg, &used);
      } catch (const std::exception &) {
        // std::stoul throws std::invalid_argument for text with no number and
        // std::out_of_range on overflow. Report these as invalid input below
        // instead of terminating the program.
        v = 0;
        used = 0;
      }
      // Reject trailing characters such as "12abc" as well as out-of-range
      // values. Negative input wraps to a huge value and fails v > Max.
      if (used != arg.size() || v < 1 || v > Max) {
        std::cout << "Invalid\n";
        return 1;
      }
      s[i] = static_cast<unsigned>(v);
    }
    Show(Play(s));
  } else {
    std::cout << "Bad usage\n";
    return 1;
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment