Skip to content

Instantly share code, notes, and snippets.

@gavento
Created January 18, 2021 14:59
import csv
import functools
import math
import os
import shutil
import sys
import tempfile
from collections import namedtuple
import gmpy2
import tqdm
# One search result.
#   val    - the number produced
#   cost   - its cost
#   maxval - the largest value occurring anywhere in its expression tree
# Constants carry their literal text in `expr`. Derived entries leave `expr` as
# None and instead record the operator `op` plus the *values* (not Sol objects)
# of their operands in `src1` / `src2`; `src2` is None for unary operators.
Sol = namedtuple("Sol", "val cost expr maxval op src1 src2")
def sol_expr(sols_table, sol):
    """Render *sol* as a fully parenthesised expression string.

    Constants carry their own text in ``sol.expr``. Derived solutions store only
    the *values* of their operands (``src1`` / ``src2``). Those values are looked
    up in *sols_table* and rendered recursively: binary operations as
    ``"(a op b)"``, unary operations as ``"op(a)"``.

    :param sols_table: mapping from a value to the Sol recorded for it.
    :param sol: the solution to render.
    :returns: the expression string.
    :raises ValueError: if *sol* has neither ``expr`` nor any operand.
    """
    if sol.expr is not None:
        return sol.expr
    if sol.src2 is not None:
        left = sol_expr(sols_table, sols_table[sol.src1])
        right = sol_expr(sols_table, sols_table[sol.src2])
        return f"({left} {sol.op} {right})"
    if sol.src1 is not None:
        return f"{sol.op}({sol_expr(sols_table, sols_table[sol.src1])})"
    # An explicit raise: a bare `assert` is stripped under `python -O`, which
    # would make this return None and embed the text "None" in the output.
    raise ValueError(
        f"Sol for value {sol.val!r} has neither an expression nor operands"
    )
# Bounded cache: keys can be very large integers, and an unbounded cache
# (`maxsize=None`) would keep every value ever queried alive for the whole run.
@functools.lru_cache(maxsize=1 << 16)
def isqrt(x):
    """Return the integer square root of *x* if *x* is a perfect square, else None.

    :param x: a ``gmpy2.mpz`` value.
    :returns: ``gmpy2.isqrt(x)`` when its square equals *x*, otherwise None.
    :raises TypeError: if *x* is not an ``mpz`` (ints are rejected so that cached
        results are always produced from, and returned as, ``mpz`` values).
    """
    # `type(gmpy2.mpz(1))` is evaluated lazily rather than at import time, and
    # works whether `gmpy2.mpz` is itself the type or a factory function.
    mpz_type = type(gmpy2.mpz(1))
    if not isinstance(x, mpz_type):
        raise TypeError(f"isqrt expects a gmpy2.mpz, got {type(x).__name__}")
    sq = gmpy2.isqrt(x)
    if sq * sq == x:
        return sq
    return None
import click
def _format_magnitude(value):
    """Format *value* exactly like ``f"{float(value):10.3e}"``, without overflowing.

    ``float()`` raises ``OverflowError`` for Python ints beyond the double range
    (~1e308), and presumably for gmpy2 ``mpz`` as well. Intermediate maxima can
    be that large because ``maxdigits`` defaults to 1000. Such values fall back
    to exact ``decimal.Decimal`` formatting with the same format spec.
    """
    try:
        return f"{float(value):10.3e}"
    except OverflowError:
        import decimal

        return f"{decimal.Decimal(int(value)):10.3e}"


def _atomic_write_text(path, write, newline=None):
    """Atomically (re)write the text file at *path*.

    ``write(f)`` is called with a text file object on a temporary file created in
    the same directory as *path*. Afterwards the temporary file is renamed over
    *path* with ``os.replace``. Staying on the same filesystem keeps the rename
    atomic, so an interrupted run never leaves a half-written result. On any
    failure the temporary file is removed and the error re-raised.

    :param path: destination file path.
    :param write: callable that receives the open text file object.
    :param newline: forwarded to the file object; use ``""`` for csv output.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_name = tempfile.mkstemp(
        dir=directory, prefix=os.path.basename(path) + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8", newline=newline) as f:
            write(f)
        os.replace(tmp_name, path)
    except BaseException:
        # Best-effort cleanup of the temporary file; the original error wins.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise


def _write_results(output, sols, msgs, maxcost, maxdigits, maxwrite):
    """Write the log and solution list to *output*, plus a CSV copy to ``output + ".csv"``.

    Only solutions with ``val <= maxwrite`` are listed. The text header contains
    every progress message and every integer in ``range(maxwrite)`` that has no
    known solution yet.
    """
    sols_sorted = sorted(s for s in sols.values() if s.val <= maxwrite)
    missing_under = [str(a) for a in range(maxwrite) if gmpy2.mpz(a) not in sols]
    ms = "\n# ".join(msgs)

    def write_text(f):
        f.write(
            f"# Log for maxcost={maxcost}, maxdigits={maxdigits}, maxwrite={maxwrite} \n# {ms}\n#\n# Missing under {maxwrite}: {', '.join(missing_under)}\n\n"
        )
        f.write(
            "\n".join(
                f"{s.val:8} (cost {s.cost:4}, maxval {_format_magnitude(s.maxval)}) {sol_expr(sols, s)}"
                for s in sols_sorted
            )
        )

    def write_csv(f):
        writer = csv.writer(f)
        writer.writerow(["value", "cost", "maxval", "expr"])
        for s in sols_sorted:
            writer.writerow([s.val, s.cost, s.maxval, sol_expr(sols, s)])

    _atomic_write_text(output, write_text)
    # The csv module requires newline="" or rows get "\r\r\n" on Windows.
    _atomic_write_text(output + ".csv", write_csv, newline="")


@click.command()
@click.argument("maxcost", default=200)
@click.option("--output", "-o", default="problem1-res.txt")
@click.option("--maxdigits", default=1000)
@click.option("--maxwrite", default=10000)
def find_solutions(maxcost, output, maxdigits, maxwrite):
    """Search cheapest expressions for integers, logging and writing results.

    For every cost from 2 to *maxcost* (inclusive) the search does three things:
      1. It adds the literal ``cost`` as a solution of that cost.
      2. It combines every pair of known solutions whose costs multiply to
         ``cost``, using ``+ - * // **``.
      3. It repeatedly applies ``fact`` and ``isqrt`` to solutions of exactly
         that cost until no new solution appears.

    Results larger than ``10**maxdigits`` are discarded. A value keeps its
    cheapest solution. Among equal-cost solutions, the one whose largest
    intermediate value (``maxval``) is smaller wins.

    After each cost, a progress line is printed. Unless *output* is None, the
    results so far are also (re)written to *output* and ``output + ".csv"``.
    """
    maxval = gmpy2.mpz(10 ** maxdigits)
    # Number -> best Sol found so far for that number.
    SOLS = {}
    # cost -> Number -> Sol, holding only solutions of exactly that cost.
    SOLS_P = {p: {} for p in range(maxcost + 1)}
    # One progress line per processed cost.
    MSGS = []

    def propose(cost, op, src1, src2=None):
        """Apply *op* to *src1* (and *src2*), recording the result if new or better.

        Returns the recorded Sol. Returns None in three cases: the op is not
        applicable (negative difference, inexact division, out-of-range
        power/factorial, non-square); the result exceeds ``maxval``; or an
        existing solution is at least as good.
        """
        nval = None
        if op == "+":
            nval = src1.val + src2.val
        elif op == "-":
            # Only non-negative differences are kept.
            if src1.val >= src2.val:
                nval = src1.val - src2.val
        elif op == "*":
            nval = src1.val * src2.val
        elif op == "//":
            # Only exact divisions are kept.
            if src2.val > 0 and src1.val % src2.val == 0:
                nval = src1.val // src2.val
        elif op == "**":
            # src1 >= 2**(bit_length-1), so this bounds the result from below.
            # It skips powers that are certainly > maxval without computing
            # them; borderline results are rejected by the check further down.
            if (
                src1.val > 1
                and (src1.val.bit_length() - 1) * src2.val < maxval.bit_length()
            ):
                nval = src1.val ** src2.val
        elif op == "fact":
            # n! >= 2**n for n >= 4, so larger arguments certainly exceed maxval.
            if src1.val > 2 and src1.val < maxval.bit_length() + 2:
                nval = gmpy2.fac(src1.val)
        elif op == "isqrt":
            nval = isqrt(src1.val)
        else:
            raise ValueError(f"unknown op {op!r}")
        if nval is None or nval > maxval:
            return None
        # Costs compose multiplicatively; unary ops keep the operand's cost.
        assert cost == src1.cost * (src2.cost if src2 is not None else 1)
        # Largest value occurring anywhere in the new expression tree.
        nmaxval = gmpy2.mpz(
            max(nval, src1.maxval, src2.maxval if src2 is not None else 0)
        )
        # Keep an existing solution unless the new one is no more expensive
        # AND has a strictly smaller maxval.
        if nval in SOLS and (SOLS[nval].cost < cost or SOLS[nval].maxval <= nmaxval):
            return None
        nsol = Sol(
            nval,
            cost,
            None,
            nmaxval,
            op,
            src1.val,
            src2.val if src2 is not None else None,
        )
        SOLS_P[cost][nval] = nsol
        SOLS[nval] = nsol
        return nsol

    for cost in range(2, maxcost + 1):
        # The literal `cost` is itself a solution of that cost.
        constsol = Sol(
            gmpy2.mpz(cost), cost, str(cost), gmpy2.mpz(cost), None, None, None
        )
        if constsol.val not in SOLS or (
            SOLS[constsol.val].cost
            >= constsol.cost  # This can't happen, just sanity-check remnant
            and SOLS[constsol.val].maxval > constsol.maxval
        ):
            SOLS_P[cost][constsol.val] = constsol
            SOLS[constsol.val] = constsol

        # Apply all binary ops to pairs whose costs multiply to `cost`. Both
        # operand orders are covered because `divisors` holds both ca and cb.
        divisors = [i for i in range(2, cost) if cost % i == 0]
        tot = sum(len(SOLS_P[ca]) * len(SOLS_P[cost // ca]) for ca in divisors)
        with tqdm.trange(tot, leave=False) as bar:
            for ca in divisors:
                cb = cost // ca
                for src1 in SOLS_P[ca].values():
                    for src2 in SOLS_P[cb].values():
                        bar.update()
                        for op in ["+", "-", "*", "//", "**"]:
                            propose(cost, op, src1, src2)

        # Apply unary ops to closure: every newly recorded sol is fed back in.
        unary_ops = 0
        candidates = list(SOLS_P[cost].values())
        while candidates:
            src1 = candidates.pop()
            for op in ["fact", "isqrt"]:
                nsol = propose(cost, op, src1)
                if nsol is not None:
                    unary_ops += 1
                    candidates.append(nsol)

        sols_under = sum(gmpy2.mpz(a) in SOLS for a in range(maxwrite))
        MSGS.append(
            f"Cost {cost:4}: {len(SOLS_P[cost]):8} new exact-cost sols, {unary_ops:5} of that added by unary, {len(SOLS):10} all sols, of that {sols_under:6} under {maxwrite}"
        )
        # NOTE(review): the source's indentation was lost. Checkpointing after
        # every cost is inferred from the atomic temp-file write pattern.
        if output is not None:
            _write_results(output, SOLS, MSGS, maxcost, maxdigits, maxwrite)
        print(MSGS[-1])
# Run the click command-line entry point when executed as a script.
if __name__ == "__main__":
    find_solutions()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment