Skip to content

Instantly share code, notes, and snippets.

@gavento
Created January 18, 2021 14:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gavento/504e268a1ecbb31174e40c91cc8995c5 to your computer and use it in GitHub Desktop.
Save gavento/504e268a1ecbb31174e40c91cc8995c5 to your computer and use it in GitHub Desktop.
import csv
import functools
import math
import os
import shutil
import sys
import tempfile
from collections import namedtuple
import gmpy2
import tqdm
# One way of producing a value.
#   val    - the value produced
#   cost   - cost of the derivation: the constant itself for literals, the
#            product of the operand costs for operations
#   expr   - literal text for constants; None when built by an operation
#   maxval - largest value occurring anywhere in the derivation (val included)
#   op     - operator used ("+", "-", "*", "//", "**", "fact", "isqrt");
#            None for constants
#   src1, src2 - values of the operands, used as keys into the solutions
#            table; src2 is None for unary operations, both None for constants
Sol = namedtuple("Sol", ["val", "cost", "expr", "maxval", "op", "src1", "src2"])
def sol_expr(sols_table, sol):
    """Render *sol* as a human-readable expression string.

    Constants carry their own text in ``sol.expr``. Derived solutions store
    only the operator and the *values* of their operands, so the operands are
    looked up in *sols_table* (value -> solution) and rendered recursively.

    Args:
        sols_table: Mapping from value to the solution currently stored for
            that value.
        sol: The solution to render.

    Returns:
        ``"(<left> <op> <right>)"`` for binary operations, ``"<op>(<arg>)"``
        for unary ones, or the stored literal text for constants.

    Raises:
        KeyError: If an operand value is missing from *sols_table*.
        AssertionError: If *sol* has neither literal text nor any operand.
    """
    if sol.expr is not None:
        return sol.expr
    if sol.src2 is not None:
        left = sol_expr(sols_table, sols_table[sol.src1])
        right = sol_expr(sols_table, sols_table[sol.src2])
        return f"({left} {sol.op} {right})"
    if sol.src1 is not None:
        return f"{sol.op}({sol_expr(sols_table, sols_table[sol.src1])})"
    # Raised explicitly (instead of ``assert 0``) so that the check survives
    # ``python -O`` and never silently yields None.
    raise AssertionError(
        f"solution for {sol.val!r} has neither an expression nor operands"
    )
@functools.lru_cache(maxsize=None)
def isqrt(x):
    """Return the exact integer square root of *x*, or None.

    *x* must be a ``gmpy2.mpz``. The root is returned only when *x* is a
    perfect square; otherwise the result is None. Results are memoised
    without an upper bound on the cache size.
    """
    assert isinstance(x, type(gmpy2.mpz(1)))
    root = gmpy2.isqrt(x)
    # gmpy2.isqrt floors, so squaring back tells us whether x was a square.
    return root if root * root == x else None
import click
def _apply_op(op, src1, src2, maxval):
    """Evaluate *op* on the values of the given solutions.

    Returns the resulting value, or None when the operation is not allowed
    for these operands (negative difference, inexact division, result that
    would clearly exceed *maxval*, non-square argument of isqrt, ...).
    *src2* is ignored by the unary operations "fact" and "isqrt".

    Raises:
        AssertionError: If *op* is not one of the known operations.
    """
    if op == "+":
        return src1.val + src2.val
    if op == "-":
        # Only non-negative results are kept.
        return src1.val - src2.val if src1.val >= src2.val else None
    if op == "*":
        return src1.val * src2.val
    if op == "//":
        # Exact division only.
        if src2.val > 0 and src1.val % src2.val == 0:
            return src1.val // src2.val
        return None
    if op == "**":
        # Bit-length estimate of the result size; avoids computing powers
        # that are far beyond maxval.
        if (
            src1.val > 1
            and (src1.val.bit_length() - 1) * src2.val < maxval.bit_length()
        ):
            return src1.val ** src2.val
        return None
    if op == "fact":
        # Arguments up to 2 are skipped; the upper bound keeps the argument
        # small relative to the bit length of maxval.
        if 2 < src1.val < maxval.bit_length() + 2:
            return gmpy2.fac(src1.val)
        return None
    if op == "isqrt":
        return isqrt(src1.val)
    # Explicit raise (instead of ``assert 0``) so the check survives -O.
    raise AssertionError(f"unknown operation {op!r}")


def _format_maxval(n):
    """Format *n* like ``f"{float(n):10.3e}"`` without failing on huge values.

    float() raises OverflowError for Python ints beyond the double range;
    NOTE(review): gmpy2.mpz is expected to behave the same way - confirm.
    In that case the mantissa is taken (truncated, not rounded) from the
    leading decimal digits.
    """
    try:
        return f"{float(n):10.3e}"
    except OverflowError:
        digits = str(n)
        return f"{digits[0]}.{digits[1:4]}e+{len(digits) - 1}"


def _write_atomically(path, write_body, newline=None):
    """Write a text file so that *path* is never observed half-written.

    The content is produced by ``write_body(file)`` into a temporary file in
    the same directory as *path* (so the final rename stays on one
    filesystem), which then replaces *path* in a single step. The temporary
    file is removed if anything fails.
    """
    directory = os.path.dirname(os.path.abspath(path))
    tmp = tempfile.NamedTemporaryFile(
        "wt", delete=False, dir=directory, newline=newline
    )
    try:
        with tmp:
            write_body(tmp)
        os.replace(tmp.name, path)
    except BaseException:
        # Do not leave stray temporary files behind on errors or Ctrl-C.
        try:
            os.unlink(tmp.name)
        except FileNotFoundError:
            pass
        raise


def _write_results(output, sols, msgs, maxcost, maxdigits, maxwrite):
    """Write the text report to *output* and a CSV table to *output* + ".csv".

    Both files list every solution whose value is at most *maxwrite*, ordered
    by value; the text report additionally starts with the accumulated log
    lines and the values below *maxwrite* that have no solution yet.
    """
    # Values are unique dict keys, so sorting by value alone is a total
    # order and never falls through to comparing the remaining fields.
    sols_sorted = sorted(
        (s for s in sols.values() if s.val <= maxwrite), key=lambda s: s.val
    )
    missing_under = [str(a) for a in range(maxwrite) if gmpy2.mpz(a) not in sols]

    def write_text(f):
        ms = "\n# ".join(msgs)
        f.write(
            f"# Log for maxcost={maxcost}, maxdigits={maxdigits}, maxwrite={maxwrite} \n# {ms}\n#\n# Missing under {maxwrite}: {', '.join(missing_under)}\n\n"
        )
        f.write(
            "\n".join(
                f"{s.val:8} (cost {s.cost:4}, maxval {_format_maxval(s.maxval)}) {sol_expr(sols, s)}"
                for s in sols_sorted
            )
        )

    def write_csv(f):
        writer = csv.writer(f)
        writer.writerow(["value", "cost", "maxval", "expr"])
        for s in sols_sorted:
            writer.writerow([s.val, s.cost, s.maxval, sol_expr(sols, s)])

    _write_atomically(output, write_text)
    # The csv module requires newline="" so it controls line endings itself.
    _write_atomically(output + ".csv", write_csv, newline="")


@click.command()
@click.argument("maxcost", default=200)
@click.option("--output", "-o", default="problem1-res.txt")
@click.option("--maxdigits", default=1000)
@click.option("--maxwrite", default=10000)
def find_solutions(maxcost, output, maxdigits, maxwrite):
    """Search for the cheapest expression producing each integer.

    A constant costs its own value, a binary operation (+, -, *, //, **)
    costs the product of its operand costs, and the unary operations
    (fact, isqrt) do not change the cost. Values above 10**MAXDIGITS are
    discarded. After every cost level from 2 to MAXCOST the solutions with
    value up to MAXWRITE are written to OUTPUT and OUTPUT.csv and a summary
    line is printed.
    """
    maxval = gmpy2.mpz(10 ** maxdigits)
    # Number -> Sol (best solution known for each value)
    SOLS = {}
    # cost -> Number -> Sol (solutions recorded at exactly that cost)
    SOLS_P = {p: {} for p in range(maxcost + 1)}
    MSGS = []

    def propose(cost, op, src1, src2=None):
        """Try to record ``op(src1[, src2])`` at *cost*.

        Returns the new Sol if it was recorded, or None if the operation is
        not applicable, the value is too large, or an existing solution for
        the same value is at least as good (cheaper, or same cost with a
        maxval that is not larger).
        """
        nval = _apply_op(op, src1, src2, maxval)
        if nval is None or nval > maxval:
            return None
        assert cost == src1.cost * (src2.cost if src2 is not None else 1)
        nmaxval = gmpy2.mpz(
            max(nval, src1.maxval, src2.maxval if src2 is not None else 0)
        )
        if nval in SOLS and (
            SOLS[nval].cost < cost or SOLS[nval].maxval <= nmaxval
        ):
            return None
        nsol = Sol(
            nval,
            cost,
            None,
            nmaxval,
            op,
            src1.val,
            src2.val if src2 is not None else None,
        )
        SOLS_P[cost][nval] = nsol
        SOLS[nval] = nsol
        return nsol

    for cost in range(2, maxcost + 1):
        # The literal constant `cost` is itself a candidate of cost `cost`.
        constsol = Sol(
            gmpy2.mpz(cost), cost, str(cost), gmpy2.mpz(cost), None, None, None
        )
        if constsol.val not in SOLS or (
            SOLS[constsol.val].cost
            >= constsol.cost  # This can't happen, just sanity-check remnant
            and SOLS[constsol.val].maxval > constsol.maxval
        ):
            SOLS_P[cost][constsol.val] = constsol
            SOLS[constsol.val] = constsol

        # Apply all binary ops to every pair whose costs multiply to `cost`.
        divisors = [i for i in range(2, cost) if cost % i == 0]
        tot = sum(len(SOLS_P[ca]) * len(SOLS_P[cost // ca]) for ca in divisors)
        with tqdm.trange(tot, leave=False) as bar:
            for ca in divisors:
                cb = cost // ca
                for src1 in SOLS_P[ca].values():
                    for src2 in SOLS_P[cb].values():
                        bar.update()
                        for op in ["+", "-", "*", "//", "**"]:
                            propose(cost, op, src1, src2)

        # Apply unary ops; each newly recorded solution is fed back in so
        # that chains of unary operations are explored as well.
        unary_ops = 0
        candidates = list(SOLS_P[cost].values())
        while candidates:
            src1 = candidates.pop()
            for op in ["fact", "isqrt"]:
                nsol = propose(cost, op, src1)
                if nsol is not None:
                    unary_ops += 1
                    candidates.append(nsol)

        sols_under = sum(gmpy2.mpz(a) in SOLS for a in range(maxwrite))
        MSGS.append(
            f"Cost {cost:4}: {len(SOLS_P[cost]):8} new exact-cost sols, {unary_ops:5} of that added by unary, {len(SOLS):10} all sols, of that {sols_under:6} under {maxwrite}"
        )
        if output is not None:
            _write_results(output, SOLS, MSGS, maxcost, maxdigits, maxwrite)
        print(MSGS[-1])
# Invoke the command-line entry point only when run as a script.
if __name__ == "__main__":
    find_solutions()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment