/problem1.py Secret
Created January 18, 2021 14:59
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import functools | |
import math | |
import os | |
import shutil | |
import sys | |
import tempfile | |
from collections import namedtuple | |
import gmpy2 | |
import tqdm | |
# Immutable record describing one way to obtain a value.
#   val    -- the value produced
#   cost   -- cost of the expression
#   expr   -- literal text for constant solutions, None for derived ones
#   maxval -- largest value appearing anywhere in the expression
#   op     -- operator name for derived solutions ("+", "fact", ...)
#   src1   -- operand value (a key into the solution table), or None
#   src2   -- second operand value, None for unary operators and constants
Sol = namedtuple("Sol", "val cost expr maxval op src1 src2")
def sol_expr(sols_table, sol):
    """Render ``sol`` as a fully parenthesized expression string.

    Constant solutions return their stored ``expr``. Binary solutions render
    as ``"(<src1> <op> <src2>)"`` and unary ones as ``"<op>(<src1>)"``. The
    operands are looked up *by value* in ``sols_table`` and rendered
    recursively.

    Args:
        sols_table: mapping from value to solution record.
        sol: the solution record to render.

    Returns:
        The expression text.

    Raises:
        ValueError: if ``sol`` has neither an ``expr`` nor any operand.
    """
    if sol.expr is not None:
        return sol.expr
    if sol.src2 is not None:
        left = sol_expr(sols_table, sols_table[sol.src1])
        right = sol_expr(sols_table, sols_table[sol.src2])
        return f"({left} {sol.op} {right})"
    if sol.src1 is not None:
        return f"{sol.op}({sol_expr(sols_table, sols_table[sol.src1])})"
    # An explicit raise (not `assert`) so this check survives `python -O`
    # instead of silently returning None.
    raise ValueError(f"solution for {sol.val!r} has no expression or operands")
@functools.lru_cache(maxsize=None)
def isqrt(x):
    """Return the exact square root of ``x`` if it is a perfect square, else None.

    ``x`` must be a gmpy2 ``mpz``. Results are memoized in an unbounded cache.
    """
    assert isinstance(x, type(gmpy2.mpz(1)))
    # isqrt_rem yields (floor(sqrt(x)), x - floor(sqrt(x))**2) in one call.
    root, remainder = gmpy2.isqrt_rem(x)
    return root if remainder == 0 else None
import click | |
def _atomic_write_text(path, fill, newline=None):
    """Write the text file ``path`` atomically.

    ``fill`` is called with an open text-mode temp file created in the same
    directory as ``path``. When it returns, the temp file is renamed over
    ``path`` with ``os.replace``. The temp file lives next to the target, so
    the rename never crosses a filesystem boundary: readers see either the
    old file or the complete new one. On any failure the temp file is
    removed and the exception is re-raised.

    ``newline`` is forwarded to the file object (pass ``""`` for csv output).
    """
    directory = os.path.dirname(os.path.abspath(path))
    tmp = tempfile.NamedTemporaryFile(
        "wt", dir=directory, delete=False, newline=newline
    )
    try:
        with tmp:
            fill(tmp)
        os.replace(tmp.name, path)
    except BaseException:
        # delete=False above, so clean up the half-written file ourselves.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        raise


def _format_maxval(x):
    """Return ``f"{float(x):10.3e}"``, even when ``x`` exceeds the float range.

    ``float()`` raises OverflowError for integers beyond roughly 1.8e308.
    ``maxval`` can reach ``10**maxdigits``, so in that case the mantissa and
    exponent are derived from the decimal digits of ``x`` instead.
    """
    try:
        return f"{float(x):10.3e}"
    except OverflowError:
        digits = str(x)
        exponent = len(digits) - 1
        head = digits[:17]
        # Round the leading digits to 4 significant figures. Rounding may
        # carry into the exponent (e.g. 9.9996 -> "1.000e+01").
        mantissa, carry = f"{float(head) / 10 ** (len(head) - 1):.3e}".split("e")
        return f"{mantissa}e+{exponent + int(carry)}".rjust(10)


def _write_results(output, sols_table, msgs, maxcost, maxdigits, maxwrite):
    """Write the log and solution list to ``output`` and a CSV copy to ``output + ".csv"``.

    Only solutions with value ``<= maxwrite`` are listed. The header records
    the run parameters, every log line in ``msgs``, and the values below
    ``maxwrite`` that have no solution yet.
    """
    sols_sorted = sorted(s for s in sols_table.values() if s.val <= maxwrite)
    missing_under = [
        str(a) for a in range(maxwrite) if gmpy2.mpz(a) not in sols_table
    ]
    ms = "\n# ".join(msgs)

    def fill_text(f):
        f.write(
            f"# Log for maxcost={maxcost}, maxdigits={maxdigits}, maxwrite={maxwrite} \n# {ms}\n#\n# Missing under {maxwrite}: {', '.join(missing_under)}\n\n"
        )
        f.write(
            "\n".join(
                f"{s.val:8} (cost {s.cost:4}, maxval {_format_maxval(s.maxval)}) {sol_expr(sols_table, s)}"
                for s in sols_sorted
            )
        )

    def fill_csv(fcsv):
        writer = csv.writer(fcsv)
        writer.writerow(["value", "cost", "maxval", "expr"])
        for s in sols_sorted:
            writer.writerow([s.val, s.cost, s.maxval, sol_expr(sols_table, s)])

    _atomic_write_text(output, fill_text)
    # The csv module needs newline="" so it controls line endings itself.
    _atomic_write_text(output + ".csv", fill_csv, newline="")


@click.command()
@click.argument("maxcost", default=200)
@click.option("--output", "-o", default="problem1-res.txt")
@click.option("--maxdigits", default=1000)
@click.option("--maxwrite", default=10000)
def find_solutions(maxcost, output, maxdigits, maxwrite):
    """Search for cheap expressions producing integers, level by level in cost.

    For every cost from 2 to ``maxcost``:
      * the literal ``cost`` is a solution of that cost;
      * binary ops (+, -, *, exact //, **) combine a solution of cost ``ca``
        with one of cost ``cost // ca`` for every divisor ``2 <= ca < cost``,
        so a binary expression costs the product of its operands' costs;
      * unary ops (``fact``, exact ``isqrt``) keep their operand's cost and
        are re-applied to new results until nothing new is found.

    Values above ``10**maxdigits`` are discarded. A value keeps the solution
    first found at its lowest cost; a later solution of that same cost
    replaces it only if its ``maxval`` (largest intermediate) is strictly
    smaller.

    After each cost level a summary line is logged, the results are written
    to ``output`` (plus ``output + ".csv"``), and the line is printed.
    """
    # NOTE(review): SOURCE lost its indentation. This layout assumes the
    # output is rewritten and the summary printed after every cost level (the
    # temp-file + rename pattern suggests a rolling checkpoint); confirm.
    maxval = gmpy2.mpz(10 ** maxdigits)
    # Number -> Sol
    SOLS = {}
    # cost -> Number -> Sol
    SOLS_P = {p: {} for p in range(maxcost + 1)}
    MSGS = []

    def propose(cost, op, src1, src2=None):
        """Try to record ``op(src1[, src2])`` as a solution of cost ``cost``.

        Returns the stored Sol, or None if the op is not applicable, the result
        exceeds ``maxval``, or an existing solution for the value is cheaper
        or has an equal-or-smaller ``maxval``.
        """
        nval = None
        if op == "+":
            nval = src1.val + src2.val
        elif op == "-":
            # Stay within the non-negative integers.
            if src1.val >= src2.val:
                nval = src1.val - src2.val
        elif op == "*":
            nval = src1.val * src2.val
        elif op == "//":
            # Exact division only.
            if src2.val > 0 and src1.val % src2.val == 0:
                nval = src1.val // src2.val
        elif op == "**":
            # src1 >= 2**(bit_length-1), so a rejected power would certainly
            # exceed maxval; this avoids computing huge numbers.
            if (
                src1.val > 1
                and (src1.val.bit_length() - 1) * src2.val < maxval.bit_length()
            ):
                nval = src1.val ** src2.val
        elif op == "fact":
            # Skip n <= 2 and cap n; the result is still range-checked below.
            if src1.val > 2 and src1.val < maxval.bit_length() + 2:
                nval = gmpy2.fac(src1.val)
        elif op == "isqrt":
            # None unless src1 is a perfect square.
            nval = isqrt(src1.val)
        else:
            raise ValueError(f"unknown op {op!r}")
        if nval is None or nval > maxval:
            return None
        assert cost == src1.cost * (src2.cost if src2 is not None else 1)
        nmaxval = gmpy2.mpz(
            max(nval, src1.maxval, src2.maxval if src2 is not None else 0)
        )
        if nval in SOLS and (SOLS[nval].cost < cost or SOLS[nval].maxval <= nmaxval):
            return None
        nsol = Sol(
            nval,
            cost,
            None,
            nmaxval,
            op,
            src1.val,
            src2.val if src2 is not None else None,
        )
        SOLS_P[cost][nval] = nsol
        SOLS[nval] = nsol
        return nsol

    for cost in range(2, maxcost + 1):
        # The literal `cost` is itself a solution of that cost.
        constsol = Sol(
            gmpy2.mpz(cost), cost, str(cost), gmpy2.mpz(cost), None, None, None
        )
        if constsol.val not in SOLS or (
            SOLS[constsol.val].cost
            >= constsol.cost  # This can't happen, just sanity-check remnant
            and SOLS[constsol.val].maxval > constsol.maxval
        ):
            SOLS_P[cost][constsol.val] = constsol
            SOLS[constsol.val] = constsol

        # Apply all binary ops
        divisors = [i for i in range(2, cost) if cost % i == 0]
        tot = sum(len(SOLS_P[ca]) * len(SOLS_P[cost // ca]) for ca in divisors)
        with tqdm.trange(tot, leave=False) as bar:
            for ca in divisors:
                cb = cost // ca
                for src1 in SOLS_P[ca].values():
                    for src2 in SOLS_P[cb].values():
                        bar.update()
                        for op in ["+", "-", "*", "//", "**"]:
                            propose(cost, op, src1, src2)

        # Apply unary ops repeatedly to everything found at this cost
        unary_ops = 0
        candidates = list(SOLS_P[cost].values())
        while candidates:
            src1 = candidates.pop()
            for op in ["fact", "isqrt"]:
                nsol = propose(cost, op, src1)
                if nsol is not None:
                    unary_ops += 1
                    candidates.append(nsol)

        sols_under = sum(gmpy2.mpz(a) in SOLS for a in range(maxwrite))
        MSGS.append(
            f"Cost {cost:4}: {len(SOLS_P[cost]):8} new exact-cost sols, {unary_ops:5} of that added by unary, {len(SOLS):10} all sols, of that {sols_under:6} under {maxwrite}"
        )
        if output is not None:
            _write_results(output, SOLS, MSGS, maxcost, maxdigits, maxwrite)
        print(MSGS[-1])
# Script entry point: click parses the command line and invokes
# find_solutions with the parsed argument and options.
if __name__ == "__main__":
    find_solutions()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.