Skip to content

Instantly share code, notes, and snippets.

@hashlash
Last active April 27, 2020 16:02
Show Gist options
  • Save hashlash/966bbceadc9bfbf9642298711a9db37c to your computer and use it in GitHub Desktop.
Save hashlash/966bbceadc9bfbf9642298711a9db37c to your computer and use it in GitHub Desktop.
Battleship puzzle solver
from itertools import combinations, product

from pysat.solvers import Minisat22
from sympy.core.symbol import Symbol
from sympy.logic.boolalg import And, Equivalent, Not, Or, false, to_cnf, true
def simplify_op(op, args):
    """Combine ``args`` with the operator class ``op`` and return the operands.

    ``op(*args)`` lets the operator (in this file, sympy's ``Or``) apply its
    own automatic simplification. If the result is still an ``op`` node its
    ``.args`` tuple is returned; if it collapsed into something else (a single
    term or a constant) that result is returned wrapped in a one-item list.
    """
    combined = op(*args)
    if not isinstance(combined, op):
        return [combined]
    return combined.args
class Battleships:
    """Battleships puzzle encoded as propositional SAT and solved with MiniSat.

    Encoding after
    https://pdfs.semanticscholar.org/b623/82bcebee19a7cfeb1cfc18e6b1b05c680dfb.pdf
    The problem is the conjunction of:

    C1: all ships in the fleet are put in the grid;
    C2: the indications in the initial grid are respected;
    C3: no two ships occupy adjacent (orthogonally or diagonally) squares;
    C4: the number of ship segments in column (row) i is equal to the ith value
        of the column (row) tally.

    Each cell (i, j) is a sympy ``Symbol`` whose name is its 1-based row-major
    index ``var(i, j)``. This lets ``pysat_cnf`` turn sympy literals directly
    into the signed integers pysat expects.

    Args:
        row: number of grid rows.
        col: number of grid columns.
        row_seg: required number of ship segments in each row.
        col_seg: required number of ship segments in each column.
        ships: mapping ``ship length -> number of ships of that length``.
        initial_map: hint rows separated by whitespace; see ``map_problem``.
    """

    def __init__(self, row, col, row_seg, col_seg, ships, initial_map):
        self.row = row
        self.col = col
        self.row_seg = row_seg
        self.col_seg = col_seg
        self.ships = ships
        # Cached solver result. _solved separates "not solved yet" from
        # "solved and unsatisfiable"; both leave _solution as None.
        self._solution = None
        self._solved = False
        self._symbols = [[Symbol(str(self.var(i, j))) for j in range(col)]
                         for i in range(row)]

        iter_ij = [(i, j) for i in range(row) for j in range(col)]
        s = self.symbol

        # C1: for each ship length k, exactly ships[k] of the possible
        # placements of a length-k ship hold.
        self.c1 = And(*[
            self.cardinality(
                simplify_op(Or, [self.ship_pos(i, j, k) for i, j in iter_ij]),
                self.ships[k],
            )
            for k in ships
        ])
        # C2: hints given in the initial map.
        self.c2 = self.map_problem(initial_map)
        # C3: ships are straight and never touch, not even diagonally.
        self.c3 = And(*[self.surroundings(i, j) for i, j in iter_ij])
        # C4: row and column tallies.
        row_tallies = And(*[
            self.cardinality([s(i, j) for j in range(col)], self.row_seg[i])
            for i in range(row)
        ])
        col_tallies = And(*[
            self.cardinality([s(i, j) for i in range(row)], self.col_seg[j])
            for j in range(col)
        ])
        self.c4 = And(row_tallies, col_tallies)
        self.problem = And(self.c1, self.c2, self.c3, self.c4)

    def symbol(self, i, j, fallback=False):
        """Return the Symbol of cell (i, j), or ``fallback`` if it is off-grid.

        With the default ``False``, off-grid cells behave as water.
        """
        if not (0 <= i < self.row) or not (0 <= j < self.col):
            return fallback
        return self._symbols[i][j]

    def var(self, i, j, fallback=None):
        """Return the 1-based row-major variable number of cell (i, j).

        Returns ``fallback`` when (i, j) lies outside the grid.
        """
        if not (0 <= i < self.row) or not (0 <= j < self.col):
            return fallback
        return i * self.col + j + 1

    def map_problem(self, initial_map):
        """Build constraint C2 from the hint characters in ``initial_map``.

        ``initial_map`` is split on whitespace into rows. The character at
        column ``j`` of row ``i`` constrains cell (i, j):

        - ``.``: no hint.
        - ``x``: water.
        - ``o``: single-cell ship; its four orthogonal neighbours are water.
        - ``^`` / ``v``: top / bottom end of a vertical ship.
        - ``<`` / ``>``: left / right end of a horizontal ship.
        - ``|`` / ``-``: middle segment of a vertical / horizontal ship.

        Any other character is ignored. A neighbour that must be a ship
        segment but lies off-grid is looked up with fallback ``True``, so it
        adds no constraint. A neighbour that must be water and lies off-grid
        is trivially water.

        Returns:
            A sympy boolean expression, or Python ``True`` when there are no
            hints.
        """
        s = self.symbol
        expr = True
        for i, row_str in enumerate(initial_map.split()):
            for j, c in enumerate(row_str):
                # Use Not() rather than ~: for an off-grid cell s() returns the
                # Python bool False, and ~False is the int -1.
                if c == 'x':
                    expr &= Not(s(i, j))
                elif c == 'o':
                    expr &= And(s(i, j),
                                Not(s(i - 1, j)), Not(s(i + 1, j)),
                                Not(s(i, j - 1)), Not(s(i, j + 1)))
                elif c == '^':
                    expr &= And(s(i, j), s(i + 1, j, True), Not(s(i - 1, j)))
                elif c == '<':
                    expr &= And(s(i, j), s(i, j + 1, True), Not(s(i, j - 1)))
                elif c == '>':
                    expr &= And(s(i, j), s(i, j - 1, True), Not(s(i, j + 1)))
                elif c == 'v':
                    expr &= And(s(i, j), s(i - 1, j, True), Not(s(i + 1, j)))
                elif c == '|':
                    expr &= And(s(i, j), s(i - 1, j, True), s(i + 1, j, True))
                elif c == '-':
                    expr &= And(s(i, j), s(i, j - 1, True), s(i, j + 1, True))
        return expr

    def cardinality(self, literals, size):
        """Return a constraint that exactly ``size`` of ``literals`` are true.

        Naive encoding: for every ``size``-subset of the literals, "all
        selected are true" must be equivalent to "all unselected are false".
        This produces C(len(literals), size) ``Equivalent`` terms.

        If ``size`` is outside ``0..len(literals)``, no assignment can satisfy
        the constraint. In that case Python ``False`` is returned, not a
        vacuous ``True``.
        """
        literals = list(literals)
        if not 0 <= size <= len(literals):
            return False
        expr = True
        for idx in combinations(range(len(literals)), size):
            chosen = set(idx)
            selected = [x for k, x in enumerate(literals) if k in chosen]
            unselected = [x for k, x in enumerate(literals) if k not in chosen]
            expr &= Equivalent(And(*selected), Not(Or(*unselected)))
        return expr

    def ship_pos(self, i, j, l):
        """Return the condition "a ship of length ``l`` starts at (i, j)".

        - Vertical placement: occupies (i .. i+l-1, j), with water directly
          above and below.
        - Horizontal placement: occupies (i, j .. j+l-1), with water directly
          left and right.

        Off-grid cells count as absent. A placement that would leave the grid
        therefore evaluates to false, and off-grid end caps are trivially
        water.

        For ``l == 1`` both orientations describe the same cell, so they are
        combined with And (water on all four sides). For longer ships exactly
        one orientation holds (Xor).
        """
        s = self.symbol
        vertical = And(*[s(x, j) for x in range(i, i + l)],
                       Not(s(i - 1, j)), Not(s(i + l, j)))
        horizontal = And(*[s(i, x) for x in range(j, j + l)],
                         Not(s(i, j - 1)), Not(s(i, j + l)))
        return vertical & horizontal if l == 1 else vertical ^ horizontal

    def surroundings(self, i, j):
        """Constraint C3 at cell (i, j).

        If the cell is a ship segment:

        - its four diagonal neighbours are water;
        - if it continues vertically, its horizontal neighbours are water;
        - if it continues horizontally, its vertical neighbours are water.
        """
        s = self.symbol
        cell = s(i, j)
        diagonals = Or(s(i - 1, j - 1), s(i - 1, j + 1),
                       s(i + 1, j - 1), s(i + 1, j + 1))
        vertical_nbrs = Or(s(i - 1, j), s(i + 1, j))
        horizontal_nbrs = Or(s(i, j - 1), s(i, j + 1))
        return And(cell >> Not(diagonals),
                   And(cell, vertical_nbrs) >> Not(horizontal_nbrs),
                   And(cell, horizontal_nbrs) >> Not(vertical_nbrs))

    def solve(self):
        """Solve the puzzle once and return the cached pysat model.

        Returns:
            The list of signed variable numbers from
            ``Minisat22.get_model()`` (positive means the cell is a ship
            segment), or ``None`` if the puzzle is unsatisfiable. Both
            outcomes are cached.
        """
        if self._solved:
            return self._solution
        clauses = self.pysat_cnf(to_cnf(self.problem))
        with Minisat22(bootstrap_with=clauses) as solver:
            self._solution = solver.get_model() if solver.solve() else None
        self._solved = True
        return self._solution

    def map_solve(self):
        """Render the solution as text, one line per row.

        Each row ends with a newline. Cell characters:

        - ``#``: ship segment;
        - ``.``: water;
        - ``?``: the variable does not appear in the model.

        Returns ``None`` when the puzzle has no solution.
        """
        model = self.solve()
        if model is None:
            return None
        assigned = set(model)
        rows = []
        for i in range(self.row):
            cells = []
            for j in range(self.col):
                v = self.var(i, j)
                if v in assigned:
                    cells.append('#')
                elif -v in assigned:
                    cells.append('.')
                else:
                    cells.append('?')
            rows.append(''.join(cells) + '\n')
        return ''.join(rows)

    def pysat_cnf(self, sympy_cnf):
        """Translate a sympy CNF formula into pysat's list-of-int-clauses form.

        ``sympy_cnf`` is expected to be the output of ``to_cnf`` over this
        instance's cell symbols. Besides a general ``And`` of clauses, sympy
        may return a degenerate CNF:

        - a single clause (``Or``);
        - a single literal (``Symbol`` or ``Not``);
        - one of the constants ``true`` / ``false``.

        Each case is handled explicitly. For example, a lone ``Or`` is not
        misread as several unit clauses.

        Returns:
            A list of clauses, each a list of non-zero ints (negative for a
            negated cell). ``true`` gives ``[]``; ``false`` gives the
            contradiction ``[[1], [-1]]``.
        """
        if sympy_cnf is true:
            return []
        if sympy_cnf is false:
            # Unsatisfiable formula: x1 and not x1.
            return [[1], [-1]]
        clauses = sympy_cnf.args if isinstance(sympy_cnf, And) else (sympy_cnf,)
        cnf = []
        for clause in clauses:
            literals = clause.args if isinstance(clause, Or) else (clause,)
            cnf.append([self._pysat_literal(lit) for lit in literals])
        return cnf

    @staticmethod
    def _pysat_literal(lit):
        """Map a cell Symbol (or its Not) to its signed variable number."""
        if lit.is_Not:
            return -int(lit.args[0].name)
        return int(lit.name)
# Example puzzles, given as keyword arguments for ``Battleships``.
# NOTE: the cardinality encoding emits C(n, k) terms per constraint, so the
# 10x10 puzzle is likely impractically slow to encode; the small ones are
# quick sanity checks.
EXAMPLE_PUZZLES = {
    # https://puzzlemadness.co.uk/battleships/2020/4/13
    "puzzlemadness-2020-04-13": dict(
        row=10,
        col=10,
        row_seg=[4, 2, 4, 2, 1, 1, 1, 6, 0, 4],
        col_seg=[0, 3, 1, 2, 5, 2, 1, 4, 1, 6],
        ships={1: 4, 2: 3, 3: 2, 4: 1, 5: 1},
        initial_map="\n".join([
            ".....xx...",
            "..........",
            ".o.......v",
            "..........",
            "..........",
            "..........",
            "..........",
            "...-..x...",
            "..........",
            ".o..x.x...",
        ]),
    ),
    # Expected solution (hint notation):
    #   ..o.
    #   ^...
    #   #..^
    #   v..v
    "4x4": dict(
        row=4,
        col=4,
        row_seg=[1, 1, 2, 2],
        col_seg=[3, 0, 1, 2],
        ships={1: 1, 2: 1, 3: 1},
        initial_map="\n".join(["..o.", "^...", "....", "...v"]),
    ),
    "1x4": dict(
        row=1,
        col=4,
        row_seg=[3],
        col_seg=[1, 1, 1, 0],
        ships={1: 0, 2: 0, 3: 1},
        initial_map="....",
    ),
    "1x2": dict(
        row=1,
        col=2,
        row_seg=[1],
        col_seg=[1, 0],
        ships={1: 1, 2: 0, 3: 0},
        initial_map="..",
    ),
}

b = Battleships(**EXAMPLE_PUZZLES["1x2"])

if __name__ == "__main__":
    solved = b.map_solve()
    if solved is None:
        print("no solution")
    else:
        print(solved, end="")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment