Skip to content

Instantly share code, notes, and snippets.

@123jimin
Created February 15, 2023 09:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 123jimin/f0c20d53814dc7e25de3d578f67fa579 to your computer and use it in GitHub Desktop.
Save 123jimin/f0c20d53814dc7e25de3d578f67fa579 to your computer and use it in GitHub Desktop.
Snake Puzzle Solver
from ortools.sat.python import cp_model
class SolverCallback(cp_model.CpSolverSolutionCallback):
    """CP-SAT solution callback that counts, keeps and prints solutions.

    Every solution increments ``num_solutions``. Up to ``max_keep_solutions``
    of them are stored as text in ``example_solutions``, and up to
    ``max_show_solutions`` are printed. A limit of 0 (or less) means
    "unlimited".
    """

    def __init__(self, vars):
        """Create the callback.

        ``vars`` is the object whose ``var_grid`` (a list of rows of CP-SAT
        boolean cell variables) is rendered for each solution; in this file
        that is a ``Solver`` instance. (The parameter name shadows the builtin
        ``vars``; it is kept for compatibility with existing callers.)
        """
        super().__init__()
        self.example_solutions = list()
        self.num_solutions = 0
        self.max_keep_solutions = 100
        self.max_show_solutions = 100
        self.vars = vars

    def _render(self):
        """Return the current solution as rows of ``#`` (filled) / ``_`` (empty)."""
        return "\n".join(
            " ".join("_" if self.Value(v) == 0 else "#" for v in row)
            for row in self.vars.var_grid
        )

    def _display_limit(self):
        """Return how many solutions get printed at most, or 0 if unlimited.

        A solution is only printed if it was also kept, so the effective cap
        is the smaller of the two positive limits.
        """
        limits = [n for n in (self.max_keep_solutions, self.max_show_solutions) if n > 0]
        return min(limits) if limits else 0

    def OnSolutionCallback(self):
        """Record (and possibly print) the solution CP-SAT just found."""
        self.num_solutions += 1
        if 0 < self.max_keep_solutions <= len(self.example_solutions):
            # Keep limit reached: further solutions are only counted.
            return
        solution_str = self._render()
        self.example_solutions.append(solution_str)
        kept = len(self.example_solutions)
        if 0 < self.max_show_solutions < kept:
            return
        print(f"Solution #{kept} after {self.NumBranches()} branches:")
        print(solution_str)
        # Announce the cut-off exactly once, on the last solution that will be
        # printed -- whichever of the keep/show limits is hit first.
        limit = self._display_limit()
        if limit and kept >= limit:
            print("(further solutions will not be shown)")
class Solver:
    """Build and solve a CP-SAT model of a snake puzzle.

    ``rows[r]`` / ``cols[c]`` give the number of filled cells required in row
    ``r`` / column ``c``. A value of 0 (or any non-positive value) leaves that
    row/column unconstrained.

    The filled cells must form a single path. Its cells are numbered 1, 2, ...
    along the path. The two end cells have exactly one filled orthogonal
    neighbour, and every other filled cell has exactly two.
    """

    def __init__(self, rows, cols):
        self.model = cp_model.CpModel()
        self.rows = rows
        self.cols = cols

    def __repr__(self):
        return f"Solver(rows={self.rows!r}, cols={self.cols!r})"

    @staticmethod
    def _cell_name(prefix, r, c):
        """Return the CP-SAT variable name for cell (r, c), e.g. ``"S[2,3]"``."""
        return f"{prefix}[{r},{c}]"

    def _init_vars(self):
        """Add all puzzle variables and constraints to ``self.model``.

        Sets these attributes:

        - ``var_grid[r][c]``: 1 iff the cell belongs to the snake.
        - ``var_ind[r][c]``: the cell's 1-based position along the snake
          (0 for empty cells).
        - ``var_start`` / ``var_end``: row-major indices ``r*C + c`` of the
          first and last snake cells.
        """
        R = len(self.rows)
        C = len(self.cols)
        model = self.model

        # Cell variables
        var_grid = self.var_grid = [
            [model.NewBoolVar(self._cell_name("C", r, c)) for c in range(C)]
            for r in range(R)
        ]
        var_ind = self.var_ind = [
            [model.NewIntVar(0, R * C, self._cell_name("I", r, c)) for c in range(C)]
            for r in range(R)
        ]
        var_start = self.var_start = model.NewIntVar(0, R * C - 1, "START")
        var_end = self.var_end = model.NewIntVar(0, R * C - 1, "END")

        # Condition: snake
        # Symmetry breaking: a path can be numbered from either end, so require
        # the start cell to precede the end cell in row-major order.
        model.Add(var_start < var_end)
        for r in range(R):
            for c in range(C):
                ind = r * C + c
                grid_cell = var_grid[r][c]
                grid_ind = var_ind[r][c]
                is_start = model.NewBoolVar(self._cell_name("S", r, c))
                is_middle = model.NewBoolVar(self._cell_name("M", r, c))
                is_end = model.NewBoolVar(self._cell_name("E", r, c))
                is_none = grid_cell.Not()
                # Exactly one of four can be true
                model.Add(is_start + is_middle + is_end + is_none == 1)
                # Position along the snake: empty -> 0, start -> 1, others -> > 1.
                model.Add(grid_ind == 0).OnlyEnforceIf(is_none)
                model.Add(grid_ind == 1).OnlyEnforceIf(is_start)
                model.Add(grid_ind > 1).OnlyEnforceIf(is_middle)
                model.Add(grid_ind > 1).OnlyEnforceIf(is_end)
                # is_start <=> this is the START cell
                model.Add(var_start == ind).OnlyEnforceIf(is_start)
                model.Add(var_start != ind).OnlyEnforceIf(is_start.Not())
                # is_end <=> this is the END cell
                model.Add(var_end == ind).OnlyEnforceIf(is_end)
                model.Add(var_end != ind).OnlyEnforceIf(is_end.Not())

                neighbor_cells = []
                neighbor_inds = []
                neighbor_prev_inds = []
                for dr, dc in ((-1, 0), (+1, 0), (0, -1), (0, +1)):
                    nr, nc = r + dr, c + dc
                    if not (0 <= nr < R and 0 <= nc < C):
                        continue
                    neighbor_cells.append(var_grid[nr][nc])
                    neighbor_ind = var_ind[nr][nc]
                    neighbor_inds.append(neighbor_ind)
                    # is_prev_ind <=> the neighbour's index is exactly one less
                    # than this cell's index.
                    is_prev_ind = model.NewBoolVar("")
                    model.Add(neighbor_ind + 1 == grid_ind).OnlyEnforceIf(is_prev_ind)
                    model.Add(neighbor_ind + 1 != grid_ind).OnlyEnforceIf(is_prev_ind.Not())
                    neighbor_prev_inds.append(is_prev_ind)

                # Ends have one filled orthogonal neighbour, middle cells two.
                model.Add(sum(neighbor_cells) == 1).OnlyEnforceIf(is_start)
                model.Add(sum(neighbor_cells) == 2).OnlyEnforceIf(is_middle)
                model.Add(sum(neighbor_cells) == 1).OnlyEnforceIf(is_end)
                # A middle cell at index i has a neighbour at i - 1 (BoolOr).
                # Empty neighbours contribute 0 to the sum, so its other filled
                # neighbour must then be at i + 1.
                model.Add(sum(neighbor_inds) == 2 * grid_ind).OnlyEnforceIf(is_middle)
                model.AddBoolOr(neighbor_prev_inds).OnlyEnforceIf(is_middle)
                model.AddBoolOr(neighbor_prev_inds).OnlyEnforceIf(is_end)

        # Condition: sum (a non-positive clue means "no clue")
        for r, clue in enumerate(self.rows):
            if clue > 0:
                model.Add(sum(var_grid[r]) == clue)
        for c, clue in enumerate(self.cols):
            if clue > 0:
                model.Add(sum(var_grid[r][c] for r in range(R)) == clue)

    def solve(self):
        """Enumerate all solutions, printing them and summary statistics.

        The model is built on the first call only. Later calls reuse it instead
        of adding a second, independent copy of every variable and constraint.
        """
        if getattr(self, "var_grid", None) is None:
            self._init_vars()
        solver = cp_model.CpSolver()
        callback = SolverCallback(self)
        # Documented replacement for the deprecated SearchForAllSolutions().
        solver.parameters.enumerate_all_solutions = True
        status = solver.Solve(self.model, callback)
        print(f"{solver.NumBooleans()} booleans, {solver.NumBranches()} branches, {solver.NumConflicts()} conflicts")
        print(f"{callback.num_solutions} solutions (time taken: wall {solver.WallTime():.3f}s, user {solver.UserTime():.3f}s)")
        if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
            print(solver.StatusName(status))
if __name__ == "__main__":
    def _read_clues(prompt):
        """Ask once and parse the reply as whitespace-separated integer clues."""
        return [int(token) for token in input(prompt).split()]

    # 0 stands for "no clue" in either direction.
    rows = _read_clues("Clues on left, from top to bottom, with 0 for empty: ")
    cols = _read_clues("Clues on top, from left to right, with 0 for empty: ")
    solver = Solver(rows, cols)
    print(solver)
    solver.solve()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment