Skip to content

Instantly share code, notes, and snippets.

@Parcly-Taxel
Last active September 24, 2021 05:33
Show Gist options
  • Save Parcly-Taxel/224bae32a3dca7ab0c85ade5a045a923 to your computer and use it in GitHub Desktop.
Save Parcly-Taxel/224bae32a3dca7ab0c85ade5a045a923 to your computer and use it in GitHub Desktop.
SAT-based solver for Puzzling Stack Exchange (PSE) question #67058 and related problems
#!/usr/bin/env python3
# https://puzzling.stackexchange.com/questions/52527/seven-matches-digits (part a)
import re
from subprocess import run
import numpy as np
lit_re = re.compile(r"-?\d+")  # pulls signed integer literals out of solver output
# Two stacked 3x3 layers of vertices, numbered 0..17 in C order.
verts = np.arange(18, dtype=int).reshape(2,3,3)
# Grid coordinate of every vertex id.
coord = {int(verts[idx]): np.array(idx) for idx in np.ndindex(verts.shape)}
# Candidate edges join vertices exactly one grid step apart, listed in (i, j)
# order; the edge joining vertices 4 and 13 (the centres of the two layers)
# is left out. Edge k (0-based) is SAT base variable k + 1.
edges = tuple(
    {i, j}
    for i in range(verts.size)
    for j in range(i + 1, verts.size)
    if (i, j) != (4, 13) and np.abs(coord[i] - coord[j]).sum() == 1
)
print(edges)
bvars = len(edges)  # base variables
# Each placement is a 3x2 array of vertex ids: rows run top/middle/bottom of
# a digit, columns its left/right side.
layer0, layer1 = verts[0], verts[1]
base = [
    layer0[:, :2], layer0[:, 1:],
    verts[:, 0, :].T,
    np.rot90(layer0[:2, :]), np.rot90(layer0[1:, :]),
    layer1[:2, :].T, layer1[1:, :].T,
    np.fliplr(layer1[:, :2]), np.fliplr(layer1[:, 1:]),
    np.rot90(verts[:, 2, :]), np.rot90(verts[:, :, 0]), verts[:, :, 2].T,
]
# Append the half-turned copies, then the mirror images of everything so far.
turned = base + [np.rot90(p, 2) for p in base]
positions = turned + [np.fliplr(p) for p in turned]
# One auxiliary variable per (placement, digit) follows the base variables.
# Each pair contributes 8 clauses, plus one "digit appears" clause per digit.
nvars = bvars + 10 * len(positions)
nclauses = 10 + 80 * len(positions)
# Seven-segment sign vectors for 0-9. Bit order: top, upper-left, upper-right,
# middle, lower-left, lower-right, bottom; a '1' becomes +1 (segment lit) and
# a '0' becomes -1 (segment unlit).
# Must use the "Casio 7" with four segments, otherwise there is no solution at all
digits = "1110111 0010010 1011101 1011011 0111010 1101011 1101111 1110010 1111111 1111011"
digits = [np.array([2 * int(bit) - 1 for bit in pattern]) for pattern in digits.split()]
# Placement constraints. For placement pnum and digit d, the auxiliary
# variable bvars + 1 + 10*pnum + d is made equivalent to "the seven segment
# edges of this placement show exactly d": seven clauses give aux -> literal
# and one clause gives (all seven literals) -> aux.
# Edge k (1-based) is base variable k; build the lookup once instead of
# scanning the edges tuple for every segment of every digit.
edge_var = {frozenset(e): k for k, e in enumerate(edges, start=1)}
clauses = []
for (pnum, pos) in enumerate(positions):
    # Segment edges in digit bit-string order: top, upper-left, upper-right,
    # middle, lower-left, lower-right, bottom. They depend only on the
    # placement, so they are resolved once and reused for all ten digits.
    segments = (pos[0,:], pos[:2,0], pos[:2,1], pos[1,:], pos[1:,0], pos[1:,1], pos[2,:])
    seg_vars = np.array([edge_var[frozenset(s.tolist())] for s in segments])
    for (d, bp) in enumerate(digits):
        pvarnum = bvars+1 + 10*pnum+d
        plits = seg_vars * bp  # +var where the segment must be lit, -var where unlit
        for l in plits:
            clauses.append([l, -pvarnum])
        clauses.append(list(-plits) + [pvarnum])
# At-least-one constraints: every digit must be shown by some placement.
for d in range(10):
    dvars = list(range(bvars+1+d, nvars+1+d, 10))
    if d == 4: # to factor out symmetries
        dvars = dvars[:3]  # only the first three placements may show the 4
    if d in (7,6): # even with the Casio 7 it's impossible to show the 2 and 5 differently
        # Keep the first half of the placements, i.e. those listed before the
        # fliplr mirror copies were appended.
        dvars = dvars[:len(dvars)//2]
    clauses.append(dvars)
# Write the DIMACS CNF file: header "p cnf <vars> <clauses>", then one
# zero-terminated clause per line.
# The header count (and the solution loop's later header rewrites) come from
# nclauses, so it must agree with the clauses actually generated.
if len(clauses) != nclauses:
    raise ValueError(f"header expects {nclauses} clauses but {len(clauses)} were built")
cnfn = "ssdcnf"
with open(cnfn, 'w') as f:
    print(f"p cnf {nvars} {nclauses}", file=f)
    for clause in clauses:
        print(" ".join(map(str, clause + [0])), file=f)
# Enumerate every solution: solve, print the base-variable assignment, then
# append a clause forbidding exactly that assignment (updating the header's
# clause count) and solve again until the formula becomes unsatisfiable.
# CaDiCaL exits with 10 on SAT and 20 on UNSAT; any other exit code means the
# solver did not finish normally and must not be mistaken for "no more
# solutions".
nsols = 0
while True:
    proc = run(["./cadical", "-q", cnfn], capture_output=True, encoding="utf-8")
    if proc.returncode == 20:
        break
    if proc.returncode != 10:
        raise RuntimeError(f"cadical exited with code {proc.returncode}: {proc.stderr.strip()}")
    # The model is printed on "v" lines; ignore every other output line.
    model = " ".join(line[1:] for line in proc.stdout.splitlines() if line.startswith("v"))
    sol = list(map(int, lit_re.findall(model)[:bvars]))
    if len(sol) != bvars:
        # A short model would yield a blocking clause that excludes too much.
        raise RuntimeError(f"expected {bvars} model literals, got {len(sol)}")
    nsols += 1
    print(sol)
    with open(cnfn, 'r') as f:
        contents = f.read().partition("\n")[2]
    with open(cnfn, 'w') as f:
        print(f"p cnf {nvars} {nclauses+nsols}", file=f)
        f.write(contents)
        print(" ".join(str(-x) for x in sol), "0", file=f)
#!/usr/bin/env python3
import re
from subprocess import run
import numpy as np
lit_re = re.compile(r"-?\d+")  # pulls signed integer literals out of solver output
# Three stacked 2x2 layers of vertices, numbered 0..11 in C order.
verts = np.arange(12, dtype=int).reshape(3,2,2)
# Grid coordinate of every vertex id.
coord = {int(verts[idx]): np.array(idx) for idx in np.ndindex(verts.shape)}
# Edges join vertices exactly one grid step apart, listed in (i, j) order;
# edge k (0-based) is SAT base variable k + 1.
edges = tuple(
    {i, j}
    for i in range(verts.size)
    for j in range(i + 1, verts.size)
    if np.abs(coord[i] - coord[j]).sum() == 1
)
print(edges)
bvars = len(edges)  # base variables
positions = []
for quarter_turns in range(4):
    # Rotate every 2x2 layer by quarter_turns * 90 degrees.
    turned = np.rot90(verts, quarter_turns, (1, 2))
    # 5x3 vertex layout from which the five placements are sliced; the two 0
    # entries are padding that none of the slices below reaches.
    net = np.array([
        [turned[0, 0, 0], turned[0, 0, 1], 0],
        [turned[0, 1, 0], turned[0, 1, 1], turned[0, 0, 1]],
        [turned[1, 1, 0], turned[1, 1, 1], turned[1, 0, 1]],
        [turned[2, 1, 0], turned[2, 1, 1], turned[2, 0, 1]],
        [turned[2, 0, 0], turned[2, 0, 1], 0],
    ])
    upright = [net[top:top + 3, :2] for top in range(3)]
    sideways = [np.rot90(net[top:top + 2, :]) for top in (1, 2)]
    positions.extend(upright + sideways)
# Append the half-turned copies, then the mirror images of everything so far.
positions = positions + [np.rot90(p, 2) for p in positions]
positions = positions + [np.fliplr(p) for p in positions]
# One auxiliary variable per (placement, digit) follows the base variables.
# Each pair contributes 8 clauses, plus one "digit appears" clause per digit.
nvars = bvars + 10 * len(positions)
nclauses = 10 + 80 * len(positions)
# Seven-segment sign vectors for 0-9 (the 7 lights four segments). Bit order:
# top, upper-left, upper-right, middle, lower-left, lower-right, bottom; each
# '1' becomes +1 (segment lit) and each '0' becomes -1 (segment unlit).
digits = "1110111 0010010 1011101 1011011 0111010 1101011 1101111 1110010 1111111 1111011"
patterns = digits.split()
digits = []
for pattern in patterns:
    signs = [1 if bit == '1' else -1 for bit in pattern]
    digits.append(np.array(signs))
# Placement constraints. For placement pnum and digit d, the auxiliary
# variable bvars + 1 + 10*pnum + d is made equivalent to "the seven segment
# edges of this placement show exactly d": seven clauses give aux -> literal
# and one clause gives (all seven literals) -> aux.
# Edge k (1-based) is base variable k; build the lookup once instead of
# scanning the edges tuple for every segment of every digit.
edge_var = {frozenset(e): k for k, e in enumerate(edges, start=1)}
clauses = []
for (pnum, pos) in enumerate(positions):
    # Segment edges in digit bit-string order: top, upper-left, upper-right,
    # middle, lower-left, lower-right, bottom. They depend only on the
    # placement, so they are resolved once and reused for all ten digits.
    segments = (pos[0,:], pos[:2,0], pos[:2,1], pos[1,:], pos[1:,0], pos[1:,1], pos[2,:])
    seg_vars = np.array([edge_var[frozenset(s.tolist())] for s in segments])
    for (d, bp) in enumerate(digits):
        pvarnum = bvars+1 + 10*pnum+d
        plits = seg_vars * bp  # +var where the segment must be lit, -var where unlit
        for l in plits:
            clauses.append([l, -pvarnum])
        clauses.append(list(-plits) + [pvarnum])
# At-least-one constraints: every digit must be shown by some placement.
for d in range(10):
    dvars = list(range(bvars+1+d, nvars+1+d, 10))
    if d == 4:
        # Symmetry breaking: only the first five placements may show the 4.
        dvars = dvars[:5]
    clauses.append(dvars)
# Write the DIMACS CNF file: header "p cnf <vars> <clauses>", then one
# zero-terminated clause per line.
# The header count (and the solution loop's later header rewrites) come from
# nclauses, so it must agree with the clauses actually generated.
if len(clauses) != nclauses:
    raise ValueError(f"header expects {nclauses} clauses but {len(clauses)} were built")
cnfn = "ssdcnf"
with open(cnfn, 'w') as f:
    print(f"p cnf {nvars} {nclauses}", file=f)
    for clause in clauses:
        print(" ".join(map(str, clause + [0])), file=f)
# Enumerate every solution: solve, print the base-variable assignment, then
# append a clause forbidding exactly that assignment (updating the header's
# clause count) and solve again until the formula becomes unsatisfiable.
# CaDiCaL exits with 10 on SAT and 20 on UNSAT; any other exit code means the
# solver did not finish normally and must not be mistaken for "no more
# solutions".
nsols = 0
while True:
    proc = run(["./cadical", "-q", cnfn], capture_output=True, encoding="utf-8")
    if proc.returncode == 20:
        break
    if proc.returncode != 10:
        raise RuntimeError(f"cadical exited with code {proc.returncode}: {proc.stderr.strip()}")
    # The model is printed on "v" lines; ignore every other output line.
    model = " ".join(line[1:] for line in proc.stdout.splitlines() if line.startswith("v"))
    sol = list(map(int, lit_re.findall(model)[:bvars]))
    if len(sol) != bvars:
        # A short model would yield a blocking clause that excludes too much.
        raise RuntimeError(f"expected {bvars} model literals, got {len(sol)}")
    nsols += 1
    print(sol)
    with open(cnfn, 'r') as f:
        contents = f.read().partition("\n")[2]
    with open(cnfn, 'w') as f:
        print(f"p cnf {nvars} {nclauses+nsols}", file=f)
        f.write(contents)
        print(" ".join(str(-x) for x in sol), "0", file=f)
#!/usr/bin/env python3
# https://puzzling.stackexchange.com/questions/67058/fitting-7-segments-digits-on-smallest-rectangular-grid
import re
from subprocess import run
import numpy as np
lit_re = re.compile(r"-?\d+")  # pulls signed integer literals out of solver output
# Grid size in unit cells: W columns, H rows.
W = 3
H = 4
# The number of variables and clauses are determined by W and H:
# one base variable per unit edge ((H+1) rows of W horizontal edges plus
# H rows of W+1 vertical edges), and one placement per upright digit slot
# ((H-1) x W) or sideways digit slot (H x (W-1)).
numbasevars = W*(H+1) + (W+1)*H
numpositions = W*(H-1) + (W-1)*H
# Ten auxiliary variables (one per digit) per placement; each contributes
# 8 clauses, plus one "digit appears somewhere" clause per digit.
nvars = numbasevars + 10*numpositions
nclauses = 10 + 80*numpositions
# The CNF is streamed into this file, header first.
cnfn = "ssdcnf"
cnfile = open(cnfn, 'w')
cnfile.write(f"p cnf {nvars} {nclauses}\n")
# Make the base variable grid.
# IND places the grid's unit edges on a diagonal lattice inside a
# (W+H) x (W+H) matrix: edge (i, j) of the (H+1) x W horizontal family goes to
# cell (i+j, H-i+j), and edge (i, j) of the H x (W+1) vertical family goes to
# cell (i+j, H-1-i+j). Edges are numbered from 1, horizontal family first;
# cells that hold no edge stay 0.
IND = np.zeros((W+H, W+H), dtype=int)
horizontal_cells = [(i + j, H - i + j) for i in range(H + 1) for j in range(W)]
vertical_cells = [(i + j, H - i - 1 + j) for i in range(H) for j in range(W + 1)]
n = 0
for cell in horizontal_cells + vertical_cells:
    n += 1
    IND[cell] = n
# Digit patterns for an upright placement. Each 3x3 array lines up with a
# 3x3 window of IND: +1 = edge must be lit, -1 = edge must be unlit,
# 0 = corner cell that is not part of the placement. The cell layout implied
# by the patterns is
#     [[ -,           upper-left,  top         ],
#      [ lower-left,  middle,      upper-right ],
#      [ bottom,      lower-right, -           ]]
# Cells in segment bit-string order: top, upper-left, upper-right, middle,
# lower-left, lower-right, bottom.
_SEGMENT_CELLS = ((0, 2), (0, 1), (1, 2), (1, 1), (1, 0), (2, 1), (2, 0))

def _segment_grid(bits):
    """Return the 3x3 sign pattern for a seven-character '0'/'1' segment string."""
    grid = np.zeros((3, 3), dtype=int)
    for (row, col), bit in zip(_SEGMENT_CELLS, bits):
        grid[row, col] = 1 if bit == "1" else -1
    return grid

digits = tuple(
    _segment_grid(bits)
    for bits in "1110111 0010010 1011101 1011011 0111010 1101011 1101111 1110010 1111111 1111011".split()
)
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9 = digits
def _emit_placement(window, patterns, pnum):
    """Write the 80 clauses linking placement `pnum` to the ten digits.

    `window` is the 3x3 slice of IND covering the placement, and `patterns`
    holds the ten 3x3 digit patterns aligned with that window. For each digit,
    the auxiliary variable numbasevars + 1 + 10*pnum + digit is made
    equivalent to "the seven edges show exactly this digit": seven clauses
    aux -> literal, then one clause (all seven literals) -> aux.

    Raises:
        ValueError: if a non-corner cell of the window holds no edge variable.
    """
    for digit, pattern in enumerate(patterns):
        checknum = numbasevars+1+10*pnum+digit
        lits = (window * pattern)[pattern != 0]
        # A 0 here would silently drop a literal (and a clause), making the
        # header's clause count wrong.
        if lits.size != 7 or not lits.all():
            raise ValueError(f"placement {pnum} does not cover seven edges")
        for l in lits:
            print(f"{l} {-checknum} 0", file=cnfile)
        t = " ".join(str(-l) for l in lits)
        print(f"{t} {checknum} 0", file=cnfile)

# Sideways placements use the digit patterns rotated a quarter turn; compute
# them once rather than once per placement.
rotated_digits = tuple(np.rot90(d) for d in digits)
pnum = 0
# Upright placements: (H-1) x W windows.
for i in range(H-1):
    for j in range(W):
        _emit_placement(IND[i+j:3+i+j, H-2-i+j:H+1-i+j], digits, pnum)
        pnum += 1
# Sideways placements: H x (W-1) windows, shifted one column from the upright ones.
for i in range(H):
    for j in range(W-1):
        _emit_placement(IND[i+j:3+i+j, H-1-i+j:H+2-i+j], rotated_digits, pnum)
        pnum += 1
if pnum != numpositions:
    raise ValueError(f"generated {pnum} placements, header expects {numpositions}")
# At-least-one constraints: every digit 0-9 must be shown by some placement,
# i.e. one of its auxiliary variables (spaced 10 apart) is true.
first_aux = numbasevars + 1
for d in range(10):
    aux_vars = range(first_aux + d, first_aux + d + 10*numpositions, 10)
    print(" ".join(map(str, aux_vars)), "0", file=cnfile)
cnfile.close()
# Find all solutions using your favourite SAT solver – I used CaDiCaL
# Each round: solve, print the base-variable assignment, then append a clause
# forbidding exactly that assignment (updating the header's clause count).
# CaDiCaL exits with 10 on SAT and 20 on UNSAT; any other exit code means the
# solver did not finish normally and must not be mistaken for "no more
# solutions".
nsols = 0
while True:
    proc = run(["./cadical", "-q", cnfn], capture_output=True, encoding="utf-8")
    if proc.returncode == 20:
        break
    if proc.returncode != 10:
        raise RuntimeError(f"cadical exited with code {proc.returncode}: {proc.stderr.strip()}")
    # The model is printed on "v" lines; ignore every other output line.
    model = " ".join(line[1:] for line in proc.stdout.splitlines() if line.startswith("v"))
    sol = list(map(int, lit_re.findall(model)[:numbasevars]))
    if len(sol) != numbasevars:
        # A short model would yield a blocking clause that excludes too much.
        raise RuntimeError(f"expected {numbasevars} model literals, got {len(sol)}")
    nsols += 1
    print(sol)
    with open(cnfn, 'r') as f:
        contents = f.read().partition("\n")[2]
    with open(cnfn, 'w') as f:
        print(f"p cnf {nvars} {nclauses+nsols}", file=f)
        f.write(contents)
        print(" ".join(str(-x) for x in sol), "0", file=f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment