Skip to content

Instantly share code, notes, and snippets.

@WillyPillow
Created January 5, 2022 16:50
Show Gist options
  • Save WillyPillow/febaa44eb2c1b15c2d1357c0df7b762e to your computer and use it in GitHub Desktop.
Save WillyPillow/febaa44eb2c1b15c2d1357c0df7b762e to your computer and use it in GitHub Desktop.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-function-docstring,missing-class-docstring
# pylint: disable=invalid-name
import itertools
import pprint

from pycryptosat import Solver
# Input: each row is a list of item labels in 1..m. With m == n^2 + n + 1 rows
# of n + 1 items each, this looks like the line set of a finite projective
# plane of order n (here n = 3, m = 13); the incidence property itself is not
# checked here.
data = [
    [2, 5, 8, 11],
    [1, 5, 6, 7],
    [4, 5, 10, 12],
    [3, 5, 9, 13],
    [1, 2, 3, 4],
    [2, 7, 10, 13],
    [2, 6, 9, 12],
    [1, 11, 12, 13],
    [4, 7, 9, 11],
    [3, 6, 10, 11],
    [1, 8, 9, 10],
    [3, 7, 8, 12],
    [4, 6, 8, 13],
]
n = len(data[0]) - 1  # each row has n + 1 items; the grid has n + 1 columns
m = len(data)  # number of rows; item labels must lie in 1..m (checked below)

# Validate with real exceptions: `assert` is stripped under `python -O`, and an
# item outside 1..m would make get_id() alias a variable of a neighbouring cell.
if m != n * n + n + 1:
    raise ValueError(
        f"expected {n * n + n + 1} rows of {n + 1} items each, got {m} rows"
    )
for row in data:
    if len(row) != n + 1 or len(set(row)) != n + 1:
        raise ValueError(f"row {row} must contain {n + 1} distinct items")
    if not all(1 <= item <= m for item in row):
        raise ValueError(f"row {row} contains an item outside 1..{m}")
if max(max(sub) for sub in data) != m:
    raise ValueError(f"largest item label must be {m}")
def get_id(r, c, item):
    """Return the SAT variable meaning "row r, column c holds item".

    r indexes the rows of `data` (0..m-1), c is a column index (0..n), and
    item is a 0-based item label (0..m-1). Ids are assigned row-major: every
    row owns a contiguous run of (n + 1) * m ids, and every column within a
    row owns a run of m ids. SAT literals must be non-zero, so ids start at 1.
    """
    return (r * (n + 1) + c) * m + item + 1


s = Solver()
# Row constraints: every item listed in row i is placed in exactly one of that
# row's n + 1 columns.
for i, row in enumerate(data):
    for label in row:
        # Item labels are 1-based in `data`; get_id takes a 0-based item.
        lits = [get_id(i, j, label - 1) for j in range(n + 1)]
        # At least one column holds the item.
        s.add_clause(lits)
        # At most one column holds it: pairwise encoding, each unordered
        # pair emitted once.
        for a, b in itertools.combinations(lits, 2):
            s.add_clause([-a, -b])
# Column constraints: within any column j, a given item is placed in at most
# one row.
for x in range(m):
    for j in range(n + 1):
        column_lits = [get_id(i, j, x) for i in range(m)]
        # Pairwise at-most-one, each unordered pair of rows emitted once.
        for a, b in itertools.combinations(column_lits, 2):
            s.add_clause([-a, -b])
# Cell constraints: no two items occupy the same (row, column) cell.
for i in range(m):
    for j in range(n + 1):
        cell_lits = [get_id(i, j, x) for x in range(m)]
        # Pairwise at-most-one, each unordered pair of items emitted once.
        for a, b in itertools.combinations(cell_lits, 2):
            s.add_clause([-a, -b])
# Solve, then decode the model: ans[i] lists the items of row i ordered by the
# column they were assigned to (items reported with their 1-based labels).
sat, solution = s.solve()
if not sat:
    # When unsatisfiable, pycryptosat gives no model (solution is None), so
    # indexing it would fail with an opaque TypeError; exit clearly instead.
    raise SystemExit("no arrangement satisfies the constraints (UNSAT)")
ans = []
for i in range(m):
    row_ans = []
    for j in range(n + 1):
        # The model is indexed by variable id; collect the item(s) set true
        # for this cell, converting back to 1-based labels.
        row_ans.extend(x + 1 for x in range(m) if solution[get_id(i, j, x)])
    ans.append(row_ans)
pprint.pprint(ans)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment