Last active
December 21, 2021 19:04
-
-
Save depp/3a6f0377284fbb9b33984063856051b1 to your computer and use it in GitHub Desktop.
A Faster Solution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Faster solution for:
// http://www.boyter.org/2017/03/golang-solution-faster-equivalent-java-solution/
// With threading.
// g++ -std=c++11 -Wall -Wextra -O3 -pthread
// On my computer (i5-6600K 3.50 GHz 4 cores), takes about 160 ms after the CPU
// has warmed up, or ~80 ms if the CPU is cold (due to Turbo Boost).
// How it works: Start by generating a list of losing states -- states where the
// game can end in one turn. Generate a new list of states by running the game
// backwards one turn, but remove any states we've seen before. Eventually we
// will stop seeing new states.
// The way we can parallelize it is by giving each thread a different amount of
// total money to use in the game. Since the total amount of money in the game
// is invariant, this means that different threads will never visit the same
// state.
// You can run the program without arguments to search for states with a minimum
// number of turns before they end, or you can pass three arguments to show a
// game with the minimum number of turns from that starting point.
#define NDEBUG 1
#include <algorithm>
#include <array>
#include <cassert>
#include <exception>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
// A game state: the amount of money held by each of the three players.
using State = std::array<unsigned, 3>;
void Show(const State &s) { | |
std::cout << s[0] << ' ' << s[1] << ' ' << s[2] << '\n'; | |
} | |
void Show(const std::vector<State> &v) { | |
std::cout << "Turns: " << v.size() << '\n'; | |
for (State s : v) { | |
Show(s); | |
} | |
} | |
// Number of turns we are searching for.
static const unsigned NumTurns = 12;
// Max starting money for each player. Don't change this without changing how
// the encoding works.
static const unsigned Max = 255;
// Size of the lookup tables indexed by Encode() for a fixed amount of money.
// Encode() stores the smallest amount in the low 8 bits and the middle amount
// above them. With at most 3 * Max money in play and every player holding at
// least 1, the middle amount can be as large as (Max * 3 - 1) / 2 *inclusive*
// (e.g. {1, 382, 382} for a total of 765), so the table needs one more
// 256-entry row than that value; without the "+ 1" that state indexes past the
// end of the table.
static const unsigned NumStates = 256 * ((Max * 3 - 1) / 2 + 1);
// Rearrange a state in place into canonical form: amounts in ascending order.
// Implemented as a fixed three-step compare-and-swap sequence.
void Sort(State &p) {
  // Swap the two slots if they are out of order.
  auto order = [&p](int lo, int hi) {
    if (p[hi] < p[lo]) {
      std::swap(p[lo], p[hi]);
    }
  };
  order(0, 1);
  order(1, 2);
  order(0, 1);
}
unsigned Total(const State &s) { return s[0] + s[1] + s[2]; } | |
// Pack the first two amounts of a state into one table index: p[0] is added to
// p[1] shifted left by 8 bits. The third amount is not stored; Decode()
// recovers it from the total money.
unsigned Encode(const State &p) {
  const unsigned low = p[0];
  const unsigned high = p[1];
  return low + (high << 8);
}
// Inverse of Encode(): rebuild a state from its index "x" and the total amount
// of money in the game. The low 8 bits give the first amount, the remaining
// bits the second, and the third is whatever money is left over.
State Decode(unsigned money, unsigned x) {
  State s;
  s[0] = x & 0xff;
  s[1] = x >> 8;
  s[2] = money - s[0] - s[1];
  return s;
}
// Search for a minimum-length game with the given starting state.
//
// Runs a breadth-first search forward from the (sorted) start state. Returns
// the sequence of states from the start up to and including the first state
// found in which two adjacent players (in the order tried) hold equal amounts.
// If the search runs out of new states, prints "Failed!" and returns an empty
// vector.
std::vector<State> Play(State start) {
  Sort(start);
  const unsigned money = Total(start);
  const unsigned startpos = Encode(start);

  // parent[pos] is the encoded state from which "pos" was first reached;
  // 0 means "not reached yet". The start state is its own parent.
  std::vector<unsigned> parent(NumStates, 0);
  parent[startpos] = startpos;

  std::vector<State> frontier, current;
  frontier.push_back(start);
  int turns = 0;
  while (!frontier.empty()) {
    ++turns;
    current.swap(frontier);
    frontier.clear();
    for (const State &s : current) {
      const unsigned oldpos = Encode(s);
      // Try each of the three pairings (i, i + 1), wrapping around.
      for (int i = 0; i < 3; ++i) {
        unsigned x = s[i];
        unsigned y = s[(i + 1) % 3];
        const unsigned z = s[(i + 2) % 3];

        if (x == y) {
          // Game over: walk the parent links back to the start, then flip the
          // path so it reads start-to-finish.
          std::vector<State> path;
          unsigned p = oldpos;
          while (p != startpos) {
            path.push_back(Decode(money, p));
            p = parent[p];
          }
          path.push_back(start);
          assert(path.size() == static_cast<std::size_t>(turns));
          std::reverse(path.begin(), path.end());
          return path;
        }

        // The poorer player of the pair doubles their money at the expense of
        // the richer one.
        if (x < y) {
          y -= x;
          x *= 2;
        } else {
          x -= y;
          y *= 2;
        }

        State next{x, y, z};
        Sort(next);
        const unsigned pos = Encode(next);
        if (parent[pos] == 0) {
          parent[pos] = oldpos;
          frontier.push_back(next);
        }
      }
    }
  }
  std::cout << "Failed!\n";
  return {};
}
class Solver { | |
// Lock for all solver fields. | |
std::mutex m_mutex; | |
// Amount of money we are currently using in our games. This counts from 1 | |
// up to 3 * Max. | |
unsigned m_money; | |
// Whether we found the solution. | |
bool m_found; | |
// The solution, if found. | |
State m_solution; | |
private: | |
Solver(); | |
void Work(); | |
public: | |
static State Solve(); | |
}; | |
// Begin the search at a total of 3 units of money, with no solution found yet.
Solver::Solver()
    : m_money(3),
      m_found(false) {
}
void Solver::Work() { | |
std::vector<bool> visited; | |
visited.resize(NumStates, false); | |
std::vector<State> q1, q2; | |
while (true) { | |
unsigned money; | |
{ | |
std::unique_lock<std::mutex> lock(m_mutex); | |
if (m_found || m_money > Max * 3) | |
return; | |
money = m_money; | |
m_money++; | |
} | |
std::fill(visited.begin(), visited.end(), false); | |
// Generate all of the final states with the given amount of total money. | |
// Put these states into q1. Mark the states as visited. | |
for (unsigned smallest = 1; smallest * 3 <= money; smallest++) { | |
unsigned other1 = money - 2 * smallest; | |
State s{smallest, smallest, other1}; | |
visited[Encode(s)] = true; | |
q1.push_back(s); | |
unsigned other2 = (money - smallest) / 2; | |
if (((money - smallest) & 1) == 0 && smallest != other1) { | |
State s{smallest, other2, other2}; | |
visited[Encode(s)] = true; | |
q1.push_back(s); | |
} | |
} | |
if (q1.empty()) | |
continue; | |
int turn = 1; | |
// Loop invariants: "q1" contains all the possible states where the minimum | |
// number of turns before the game might end is equal to "turn". | |
do { | |
// If we've reached the target number of turns, we're done--but only if | |
// the starting state is legal. | |
if (turn == NumTurns) { | |
for (State s : q1) { | |
if (s[2] <= Max) { | |
std::unique_lock<std::mutex> lock(m_mutex); | |
if (m_found && Total(m_solution) < money) | |
return; | |
m_found = true; | |
m_solution = s; | |
return; | |
} | |
} | |
} | |
// Fill "q1" by running the game backwards, and discarding any state we've | |
// seen before. | |
std::swap(q1, q2); | |
q1.clear(); | |
for (State s : q2) { | |
for (int i = 0; i < 3; i++) { | |
unsigned x = s[i]; | |
if ((x & 1) != 0) | |
continue; | |
x /= 2; | |
for (int j = 0; j < 2; j++) { | |
unsigned y = s[(i + 1 + j) % 3]; | |
unsigned z = s[(i + 2 - j) % 3]; | |
y += x; | |
if (y <= x) | |
continue; | |
State s2{x, y, z}; | |
Sort(s2); | |
unsigned pos = Encode(s2); | |
if (visited[pos]) | |
continue; | |
visited[pos] = true; | |
q1.push_back(s2); | |
} | |
} | |
} | |
turn++; | |
} while (!q1.empty()); | |
} | |
} | |
// Run the search on one worker thread per hardware thread, wait for all of
// them to finish, and return the solver's recorded solution. Prints the number
// of worker threads to stdout before starting them.
State Solver::Solve() {
  // hardware_concurrency() may report 0; fall back to a single worker then.
  const unsigned numthreads = std::max(1u, std::thread::hardware_concurrency());
  std::cout << "Work threads: " << numthreads << '\n';

  Solver solver;
  std::vector<std::thread> workers;
  for (unsigned remaining = numthreads; remaining > 0; --remaining) {
    workers.emplace_back(&Solver::Work, &solver);
  }
  for (auto &worker : workers) {
    worker.join();
  }
  return solver.m_solution;
}
int main(int argc, char **argv) { | |
if (argc == 1) { | |
State s = Solver::Solve(); | |
std::cout << "Solution found: "; | |
Show(s); | |
} else if (argc == 4) { | |
State s; | |
for (int i = 0; i < 3; i++) { | |
unsigned long v = std::stoul(argv[i + 1]); | |
if (v < 1 || v > Max) { | |
std::cout << "Invalid\n"; | |
return 1; | |
} | |
s[i] = v; | |
} | |
Show(Play(s)); | |
} else { | |
std::cout << "Bad usage\n"; | |
return 1; | |
} | |
return 0; | |
} |
@jthemphill: 256 because of the `<< 8` in `Encode`. The `(Max * 3 - 1) / 2` is the maximum cash that the second player has.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
NumStates = 256 * ((Max * 3 - 1) / 2)
Could I get an explanation on how you calculated the total number of states? Did 256 come from
Max + 1
?