Skip to content

Instantly share code, notes, and snippets.

@270ajay
Created December 21, 2023 05:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 270ajay/153a12627ac020a111f8a5855f88fcb3 to your computer and use it in GitHub Desktop.
Save 270ajay/153a12627ac020a111f8a5855f88fcb3 to your computer and use it in GitHub Desktop.
Constraint Programming solver module 1
# Course: https://www.edx.org/learn/computer-programming/universite-catholique-de-louvain-constraint-programming
# Backtracking search algorithm for solving n-queens problem
import time
from typing import Callable
class NQueensFilter:
    """Generate-and-test N-queens solver.

    Every column receives every possible row (``n**n`` complete assignments);
    the constraints are only checked once all queens are placed.

    ``self._queens[c]`` is the row of the queen in column ``c`` (``-1`` means
    not placed yet).
    """

    def __init__(self, number_of_queens: int):
        self._number_of_queens: int = number_of_queens
        self._num_solutions: int = 0
        self._queens: list[int] = [-1] * number_of_queens
        self._num_nodes: int = 0
        # Elapsed wall-clock seconds of the most recent solve() call.
        self._time: float = 0.0

    def solve(self, solution_callback: Callable[[list[int]], None] = None) -> None:
        """Enumerate all solutions and print search statistics.

        Args:
            solution_callback: Optional (may be None). Called once per
                solution with a fresh list of rows (one per column), so the
                callback may keep the list safely.
        """
        # Reset per-run state so repeated calls report per-call statistics.
        self._num_solutions = 0
        self._num_nodes = 0
        self._queens = [-1] * self._number_of_queens
        start = time.perf_counter()
        self._depth_first_search(0, solution_callback)
        print("Number of solutions found:", self._num_solutions)
        print("Number of nodes:", self._num_nodes)
        self._time = time.perf_counter() - start
        print("Time taken:", self._time)

    def _depth_first_search(
        self, index: int, solution_callback: Callable[[list[int]], None]
    ) -> None:
        """Try every row for column ``index``; check constraints only at leaves."""
        self._num_nodes += 1
        if index == self._number_of_queens:
            if self._are_constraints_satisfied():
                self._num_solutions += 1
                if solution_callback is not None:
                    # Copy: self._queens keeps being overwritten by the search.
                    solution_callback(list(self._queens))
            return
        for value in range(self._number_of_queens):
            self._queens[index] = value
            self._depth_first_search(index + 1, solution_callback)

    def _are_constraints_satisfied(self) -> bool:
        """Return True if no two placed queens attack each other."""
        for index in range(self._number_of_queens):
            for index2 in range(index + 1, self._number_of_queens):
                # No two queens on same row
                if self._queens[index] == self._queens[index2]:
                    return False
                # No two queens on diagonal
                if abs(self._queens[index] - self._queens[index2]) == index2 - index:
                    return False
        return True
class NQueensPrune:
    """Backtracking N-queens solver that prunes on the first conflict.

    Queens are placed column by column; ``self._queens[c]`` is the row of the
    queen in column ``c`` (``-1`` means not placed yet). Each new queen is
    checked against the earlier ones before descending, so every leaf reached
    is a solution.
    """

    def __init__(self, number_of_queens: int):
        self._number_of_queens: int = number_of_queens
        self._num_solutions: int = 0
        self._queens: list[int] = [-1] * number_of_queens
        self._num_nodes: int = 0
        # Elapsed wall-clock seconds of the most recent solve() call.
        self._time: float = 0.0

    def solve(self, solution_callback: Callable[[list[int]], None] = None) -> None:
        """Enumerate all solutions and print search statistics.

        Args:
            solution_callback: Optional (may be None). Called once per
                solution with a fresh list of rows (one per column), so the
                callback may keep the list safely.
        """
        # Reset per-run state so repeated calls report per-call statistics.
        self._num_solutions = 0
        self._num_nodes = 0
        self._queens = [-1] * self._number_of_queens
        start = time.perf_counter()
        self._depth_first_search(0, solution_callback)
        print("Number of solutions found:", self._num_solutions)
        print("Number of nodes:", self._num_nodes)
        self._time = time.perf_counter() - start
        print("Time taken:", self._time)

    def _depth_first_search(
        self, index: int, solution_callback: Callable[[list[int]], None]
    ) -> None:
        """Place a queen in column ``index``, descending only on consistent rows."""
        self._num_nodes += 1
        if index == self._number_of_queens:
            self._num_solutions += 1
            if solution_callback is not None:
                # Copy: self._queens keeps being overwritten by the search.
                solution_callback(list(self._queens))
            return
        for value in range(self._number_of_queens):
            self._queens[index] = value
            if self._are_constraints_satisfied(index):
                self._depth_first_search(index + 1, solution_callback)

    def _are_constraints_satisfied(self, index2: int) -> bool:
        """Return True if the queen in column ``index2`` attacks no earlier queen."""
        row2 = self._queens[index2]
        for index in range(index2):
            row = self._queens[index]
            # No two queens on same row
            if row == row2:
                return False
            # No two queens on diagonal
            if abs(row - row2) == index2 - index:
                return False
        return True
if __name__ == "__main__":
    # Naive generate-and-test search on 8 queens, printing every solution.
    n_queens_filter = NQueensFilter(8)
    n_queens_filter.solve(print)
    print("==========================")
    # Same problem with early pruning; only the statistics are printed.
    n_queens_prune = NQueensPrune(8)
    n_queens_prune.solve()
# Course: https://www.edx.org/learn/computer-programming/universite-catholique-de-louvain-constraint-programming
# Tiny CP library
import time
from abc import ABC, abstractmethod
from typing import Callable
import bitarray
class SolverInconsistency(Exception):
    """Signals a dead end: a domain is, or would become, empty."""
class Constraint(ABC):
    """Interface implemented by every constraint posted to the solver."""

    @abstractmethod
    def propagate(self) -> bool:
        """Filter the domains of the constrained variables.

        Implementations return True when they reduced some domain and False
        otherwise; they may raise SolverInconsistency when a domain empties.
        """
class NotEqual(Constraint):
    """Binary constraint ``var_x != var_y + offset``.

    Nothing is filtered until one side is fixed; then the single value it
    forbids is removed from the other side's domain.
    """

    def __init__(self, var_x: "Variable", var_y: "Variable", offset: int = 0):
        self._var_x = var_x
        self._var_y = var_y
        self._offset = offset

    def propagate(self) -> bool:
        """Return True if a value was removed from either domain."""
        dom_x = self._var_x.domain
        dom_y = self._var_y.domain
        if dom_x.is_fixed():
            # x == v forbids y == v - offset.
            forbidden_for_y = dom_x.min() - self._offset
            return dom_y.remove(forbidden_for_y)
        if dom_y.is_fixed():
            # y == w forbids x == w + offset.
            forbidden_for_x = dom_y.min() + self._offset
            return dom_x.remove(forbidden_for_x)
        return False
class Variable:
    """Decision variable: a display name plus a mutable Domain of candidate values."""

    def __init__(self, name: str, domain_size: int):
        # Starts with every value 0..domain_size-1 available.
        self.domain: Domain = Domain(domain_size)
        self.name = name

    def __repr__(self) -> str:
        # Variables are shown by name only.
        return self.name
class Domain:
    """Finite set of integer values ``0..len-1`` stored as a bitarray.

    Bit ``i`` is set while value ``i`` is still possible.
    """

    def __init__(self, domain_size: int = 0, domain: bitarray.bitarray = None):
        if domain is not None:
            # Adopt the supplied bits as-is (clone() passes a copy).
            self._values: bitarray.bitarray = domain
            return
        self._values = bitarray.bitarray(domain_size)
        self._values.setall(1)

    def is_fixed(self) -> bool:
        """Return True when exactly one value remains."""
        return self.size() == 1

    def size(self) -> int:
        """Return how many values are still possible."""
        return self._values.count(1)

    def min(self) -> int:
        """Return the smallest remaining value.

        Raises:
            SolverInconsistency: if no value remains.
        """
        for value, present in enumerate(self._values):
            if present:
                return value
        raise SolverInconsistency()

    def remove(self, index: int) -> bool:
        """Remove value ``index``; return True iff the domain changed.

        Out-of-range or already-removed values are ignored (returns False).

        Raises:
            SolverInconsistency: if this removal empties the domain (the bit
                has already been cleared when the exception is raised).
        """
        removable = 0 <= index < len(self._values) and self._values[index]
        if not removable:
            return False
        self._values[index] = False
        if self.size() == 0:
            raise SolverInconsistency()
        return True

    def fix(self, index) -> None:
        """Reduce the domain to the single value ``index``.

        Raises:
            SolverInconsistency: if ``index`` is not currently in the domain.
        """
        if not self._values[index]:
            raise SolverInconsistency()
        self._values.setall(0)
        self._values[index] = True

    def clone(self) -> "Domain":
        """Return an independent copy of this domain."""
        return Domain(domain=self._values.copy())
class TinyCSP:
    """A tiny constraint-satisfaction solver.

    Constraints are propagated to a fix point whenever one is posted and
    after every branching decision. Search branches on the first unfixed
    variable (creation order) with a binary split: ``x == min(x)`` on the left
    and ``x != min(x)`` on the right.
    """

    def __init__(self):
        self._constraints: list[Constraint] = []
        self._variables: list[Variable] = []
        self._num_solutions: int = 0
        # Elapsed wall-clock seconds of the most recent solve() call.
        self._time: float = 0.0
        self._num_nodes: int = 0

    def make_var(self, name: str, domain_size: int) -> Variable:
        """Create and register a variable with values ``0..domain_size-1``."""
        var = Variable(name, domain_size)
        self._variables.append(var)
        return var

    def make_non_equality(
        self, var_x: Variable, var_y: Variable, offset: int = 0
    ) -> None:
        """Post ``var_x != var_y + offset`` and propagate immediately.

        Raises:
            SolverInconsistency: if propagation empties some domain.
        """
        self._constraints.append(NotEqual(var_x, var_y, offset))
        self._fix_point()

    def _fix_point(self) -> None:
        """Propagate every constraint until a full pass changes no domain."""
        changed = True
        while changed:
            changed = False
            for constraint in self._constraints:
                # propagate() must run for every constraint on every pass.
                if constraint.propagate():
                    changed = True

    def _get_first_not_fixed_var(self) -> Variable:
        """Return the first variable (creation order) that is not fixed, or None."""
        return next(
            (var for var in self._variables if not var.domain.is_fixed()), None
        )

    def _back_up_domains(self) -> list[Domain]:
        """Return independent copies of all variables' domains."""
        return [var.domain.clone() for var in self._variables]

    def _restore_domains(self, backed_up_domains: list[Domain]) -> None:
        """Make each variable point at its backed-up domain object."""
        for var, domain in zip(self._variables, backed_up_domains):
            var.domain = domain

    def solve(self, solution_callback: Callable[[list[int]], None] = None) -> None:
        """Enumerate all solutions and print search statistics.

        Statistics are reset and domains are restored after the search, so
        solve() can be called repeatedly with the same result.

        Args:
            solution_callback: Optional (may be None). Called once per
                solution with a fresh list holding each variable's value in
                creation order.
        """
        self._num_solutions = 0
        self._num_nodes = 0
        start = time.perf_counter()
        root_domains = self._back_up_domains()
        try:
            self._depth_first_search(solution_callback)
        except SolverInconsistency:
            # Only reachable at the root, e.g. a variable created with an
            # empty domain: the model simply has no solution.
            pass
        finally:
            # The search leaves domains in an arbitrary mutated state.
            self._restore_domains(root_domains)
        print("Number of solutions found:", self._num_solutions)
        print("Number of nodes:", self._num_nodes)
        self._time = time.perf_counter() - start
        print("Time taken:", self._time)

    def _depth_first_search(self, solution_callback: Callable[[list[int]], None]):
        """Explore the subtree below the current domains (binary branching)."""
        self._num_nodes += 1
        not_fixed_var = self._get_first_not_fixed_var()
        if not_fixed_var is None:
            # all variables fixed, a solution is found
            self._num_solutions += 1
            if solution_callback is not None:
                solution_callback([var.domain.min() for var in self._variables])
            return
        value = not_fixed_var.domain.min()
        backed_up_domains = self._back_up_domains()
        # left branch not_fixed_var = value
        try:
            not_fixed_var.domain.fix(value)
            self._fix_point()
            self._depth_first_search(solution_callback)
        except SolverInconsistency:
            pass  # dead end: the left branch has no solution
        # After restoring, not_fixed_var.domain is the backup object itself;
        # the right branch mutates it and the caller's backup undoes that.
        self._restore_domains(backed_up_domains)
        # right branch not_fixed_var != value
        try:
            not_fixed_var.domain.remove(value)
            self._fix_point()
            self._depth_first_search(solution_callback)
        except SolverInconsistency:
            pass  # dead end: the right branch has no solution
if __name__ == "__main__":
    NUMBER_OF_QUEENS = 8
    solver = TinyCSP()
    # queen_vars[c] holds the row of the queen placed in column c.
    queen_vars = [
        solver.make_var(f"Queen{i}", NUMBER_OF_QUEENS)
        for i in range(NUMBER_OF_QUEENS)
    ]
    for i in range(NUMBER_OF_QUEENS):
        for j in range(i + 1, NUMBER_OF_QUEENS):
            # Offset 0: different rows; i - j and j - i: different diagonals.
            for offset in (0, i - j, j - i):
                solver.make_non_equality(queen_vars[i], queen_vars[j], offset)
    solver.solve(print)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment