Created
December 21, 2023 05:12
-
-
Save 270ajay/153a12627ac020a111f8a5855f88fcb3 to your computer and use it in GitHub Desktop.
Constraint Programming solver module 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Course: https://www.edx.org/learn/computer-programming/universite-catholique-de-louvain-constraint-programming
# Backtracking search algorithm for solving the n-queens problem
import time | |
from typing import Callable | |
class NQueensFilter:
    """Generate-and-test solver for the n-queens problem.

    ``_queens[i]`` stores the row assigned to the queen at index ``i``.
    Every complete assignment is enumerated first and the constraints are
    only checked once all queens have been given a value.
    """

    def __init__(self, number_of_queens: int):
        self._number_of_queens: int = number_of_queens
        self._num_solutions: int = 0
        self._queens: list[int] = [-1] * number_of_queens
        self._num_nodes: int = 0
        # Replaced by solve() with the duration of the most recent search.
        self._time: float = time.perf_counter()

    def solve(self, solution_callback: Callable[[list[int]], None] = None) -> None:
        """Run the search and print solution count, node count and duration.

        Args:
            solution_callback: Optional callable invoked once per solution
                with a fresh list of the queens' rows.
        """
        # Reset the statistics so that repeated calls report one search each.
        self._num_solutions = 0
        self._num_nodes = 0
        started_at = time.perf_counter()
        self._depth_first_search(0, solution_callback)
        print("Number of solutions found:", self._num_solutions)
        print("Number of nodes:", self._num_nodes)
        self._time = time.perf_counter() - started_at
        print("Time taken:", self._time)

    def _depth_first_search(
        self, index: int, solution_callback: Callable[[list[int]], None]
    ) -> None:
        """Assign every possible value to queen ``index`` and recurse."""
        self._num_nodes += 1
        if index != self._number_of_queens:
            for value in range(self._number_of_queens):
                self._queens[index] = value
                self._depth_first_search(index + 1, solution_callback)
            return
        # Leaf: all queens are placed, now filter the candidate.
        if not self._are_constraints_satisfied():
            return
        self._num_solutions += 1
        if solution_callback is not None:
            # Hand out a copy: the working list keeps being overwritten
            # while the search continues.
            solution_callback(list(self._queens))

    def _are_constraints_satisfied(self) -> bool:
        """Return True when no pair of queens shares a row or a diagonal."""
        queens = self._queens
        for first in range(self._number_of_queens):
            for second in range(first + 1, self._number_of_queens):
                row_gap = abs(queens[first] - queens[second])
                # Same row (gap 0) or same diagonal (gap equals index distance).
                if row_gap == 0 or row_gap == second - first:
                    return False
        return True
class NQueensPrune:
    """Backtracking solver for the n-queens problem with early pruning.

    ``_queens[i]`` stores the row assigned to the queen at index ``i``.
    Each new assignment is checked against the queens placed before it, and
    the search only descends when that check passes.
    """

    def __init__(self, number_of_queens: int):
        self._number_of_queens: int = number_of_queens
        self._num_solutions: int = 0
        self._queens: list[int] = [-1] * number_of_queens
        self._num_nodes: int = 0
        # Replaced by solve() with the duration of the most recent search.
        self._time: float = time.perf_counter()

    def solve(self, solution_callback: Callable[[list[int]], None] = None) -> None:
        """Run the search and print solution count, node count and duration.

        Args:
            solution_callback: Optional callable invoked once per solution
                with a fresh list of the queens' rows.
        """
        # Reset the statistics so that repeated calls report one search each.
        self._num_solutions = 0
        self._num_nodes = 0
        started_at = time.perf_counter()
        self._depth_first_search(0, solution_callback)
        print("Number of solutions found:", self._num_solutions)
        print("Number of nodes:", self._num_nodes)
        self._time = time.perf_counter() - started_at
        print("Time taken:", self._time)

    def _depth_first_search(
        self, index: int, solution_callback: Callable[[list[int]], None]
    ) -> None:
        """Try each value for queen ``index``; recurse only if it is consistent."""
        self._num_nodes += 1
        if index != self._number_of_queens:
            for value in range(self._number_of_queens):
                self._queens[index] = value
                if self._are_constraints_satisfied(index):
                    self._depth_first_search(index + 1, solution_callback)
            return
        # Every queen was checked when it was placed, so this is a solution.
        self._num_solutions += 1
        if solution_callback is not None:
            # Hand out a copy: the working list keeps being overwritten
            # while the search continues.
            solution_callback(list(self._queens))

    def _are_constraints_satisfied(self, index2: int) -> bool:
        """Return True when queen ``index2`` conflicts with no earlier queen."""
        placed_row = self._queens[index2]
        for earlier in range(index2):
            row_gap = abs(self._queens[earlier] - placed_row)
            # Same row (gap 0) or same diagonal (gap equals index distance).
            if row_gap == 0 or row_gap == index2 - earlier:
                return False
        return True
if __name__ == "__main__":
    # Solve 8-queens with generate-and-test, printing each solution found.
    n_queens_filter = NQueensFilter(8)
    n_queens_filter.solve(print)
    print("==========================")
    # Solve it again with pruning; only the statistics are printed.
    n_queens_prune = NQueensPrune(8)
    n_queens_prune.solve()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Course: https://www.edx.org/learn/computer-programming/universite-catholique-de-louvain-constraint-programming
# Tiny CP library
import time | |
from abc import ABC, abstractmethod | |
from typing import Callable | |
import bitarray | |
class SolverInconsistency(Exception):
    """Raised when a domain operation leaves a variable without any value."""
class Constraint(ABC):
    """Interface every constraint posted to the solver must implement."""

    @abstractmethod
    def propagate(self) -> bool:
        """Filter the domains of the constrained variables.

        Returns:
            True when the call changed a domain, False otherwise.
        """
class NotEqual(Constraint):
    """Constraint enforcing ``var_x != var_y + offset``."""

    def __init__(self, var_x: "Variable", var_y: "Variable", offset: int = 0):
        self._var_x = var_x
        self._var_y = var_y
        self._offset = offset

    def propagate(self) -> bool:
        """Remove the forbidden value once one of the two variables is fixed.

        Returns:
            True when a value was removed from a domain, False otherwise.

        Raises:
            SolverInconsistency: Propagated from the domain when the removal
                empties it.
        """
        x_domain = self._var_x.domain
        y_domain = self._var_y.domain
        if x_domain.is_fixed():
            # x is known, so y may not take the value x - offset.
            return y_domain.remove(x_domain.min() - self._offset)
        if y_domain.is_fixed():
            # y is known, so x may not take the value y + offset.
            return x_domain.remove(y_domain.min() + self._offset)
        return False
class Variable:
    """Named decision variable whose candidate values live in a ``Domain``."""

    def __init__(self, name: str, domain_size: int):
        self.domain: Domain = Domain(domain_size)
        self.name = name

    def __repr__(self) -> str:
        return self.name
class Domain:
    """Set of candidate values ``0 .. size-1`` stored as a bit array.

    Bit ``i`` is set while value ``i`` is still a candidate.
    """

    def __init__(self, domain_size: int = 0, domain: bitarray.bitarray = None):
        if domain is None:
            # Fresh domain: every value from 0 to domain_size - 1 is allowed.
            self._values: bitarray.bitarray = bitarray.bitarray(domain_size)
            self._values.setall(1)
        else:
            # Wrap the given bit array as-is (used by clone()).
            self._values = domain

    def is_fixed(self) -> bool:
        """Return True when exactly one value remains."""
        return self.size() == 1

    def size(self) -> int:
        """Return the number of values still in the domain."""
        return self._values.count(1)

    def min(self) -> int:
        """Return the smallest value still in the domain.

        Raises:
            SolverInconsistency: If the domain holds no value.
        """
        for value, present in enumerate(self._values):
            if present:
                return value
        raise SolverInconsistency()

    def remove(self, index: int) -> bool:
        """Remove ``index`` from the domain.

        Returns:
            True when the value was present and has been removed, False when
            it was out of range or already absent.

        Raises:
            SolverInconsistency: If the removal empties the domain.
        """
        if not 0 <= index < len(self._values):
            return False
        if not self._values[index]:
            return False
        self._values[index] = False
        if self.size() == 0:
            raise SolverInconsistency()
        return True

    def fix(self, index) -> None:
        """Reduce the domain to the single value ``index``.

        Raises:
            SolverInconsistency: If ``index`` is outside the domain range or
                no longer a candidate.
        """
        # Range-check first, like remove(): a negative index would otherwise
        # wrap around and silently fix an unrelated value.
        if not 0 <= index < len(self._values) or not self._values[index]:
            raise SolverInconsistency()
        self._values.setall(0)
        self._values[index] = True

    def clone(self) -> "Domain":
        """Return an independent copy of this domain."""
        return Domain(domain=self._values.copy())
class TinyCSP:
    """Minimal constraint solver: fix-point propagation plus binary DFS."""

    def __init__(self):
        self._constraints: list[Constraint] = []
        self._variables: list[Variable] = []
        self._num_solutions: int = 0
        # Replaced by solve() with the duration of the most recent search.
        self._time: float = time.perf_counter()
        self._num_nodes: int = 0

    def make_var(self, name: str, domain_size: int) -> Variable:
        """Create a variable with values ``0 .. domain_size-1`` and register it."""
        var = Variable(name, domain_size)
        self._variables.append(var)
        return var

    def make_non_equality(
        self, var_x: Variable, var_y: Variable, offset: int = 0
    ) -> None:
        """Post ``var_x != var_y + offset`` and propagate immediately.

        Raises:
            SolverInconsistency: If propagation empties a domain.
        """
        self._constraints.append(NotEqual(var_x, var_y, offset))
        self._fix_point()

    def _fix_point(self) -> None:
        """Propagate all constraints until a full pass changes no domain."""
        changed = True
        while changed:
            changed = False
            for constraint in self._constraints:
                if constraint.propagate():
                    changed = True

    def _get_first_not_fixed_var(self) -> Variable:
        """Return the first registered variable that is not fixed, or None."""
        for var in self._variables:
            if not var.domain.is_fixed():
                return var
        return None

    def _back_up_domains(self) -> list[Domain]:
        """Return a copy of every variable's domain, in registration order."""
        return [var.domain.clone() for var in self._variables]

    def _restore_domains(self, backed_up_domains: list[Domain]) -> None:
        """Install the backed-up domains as the variables' current domains."""
        for var, saved_domain in zip(self._variables, backed_up_domains):
            var.domain = saved_domain

    def solve(self, solution_callback: Callable[[list[int]], None] = None) -> None:
        """Run the search and print solution count, node count and duration.

        The variables' domains are put back to their pre-search state when
        the search ends, so the solver can be solved again.

        Args:
            solution_callback: Optional callable invoked once per solution
                with the list of values, in variable registration order.
        """
        # Reset the statistics so that repeated calls report one search each.
        self._num_solutions = 0
        self._num_nodes = 0
        started_at = time.perf_counter()
        initial_domains = self._back_up_domains()
        try:
            self._depth_first_search(solution_callback)
        finally:
            # The search leaves the domains in the state of its last branch.
            self._restore_domains(initial_domains)
        print("Number of solutions found:", self._num_solutions)
        print("Number of nodes:", self._num_nodes)
        self._time = time.perf_counter() - started_at
        print("Time taken:", self._time)

    def _depth_first_search(self, solution_callback: Callable[[list[int]], None]):
        """Branch on the first unfixed variable: ``var == min`` then ``var != min``."""
        self._num_nodes += 1
        not_fixed_var = self._get_first_not_fixed_var()
        if not_fixed_var is None:
            # all variables fixed, a solution is found
            self._num_solutions += 1
            if solution_callback is not None:
                solution_callback([var.domain.min() for var in self._variables])
            return

        value = not_fixed_var.domain.min()
        backed_up_domains = self._back_up_domains()

        # left branch not_fixed_var = value
        try:
            not_fixed_var.domain.fix(value)
            self._fix_point()
            self._depth_first_search(solution_callback)
        except SolverInconsistency:
            # Dead end: this branch holds no solution.
            pass

        # Restoring replaces each var.domain object, so the right branch
        # below works on the restored domain of the same variable.
        self._restore_domains(backed_up_domains)

        # right branch not_fixed_var != value
        try:
            not_fixed_var.domain.remove(value)
            self._fix_point()
            self._depth_first_search(solution_callback)
        except SolverInconsistency:
            # Dead end: this branch holds no solution.
            pass
if __name__ == "__main__":
    NUMBER_OF_QUEENS = 8
    solver = TinyCSP()
    # One variable per queen; its value is the row the queen is placed on.
    queen_vars = [
        solver.make_var(f"Queen{i}", NUMBER_OF_QUEENS)
        for i in range(NUMBER_OF_QUEENS)
    ]
    for i in range(NUMBER_OF_QUEENS):
        for j in range(i + 1, NUMBER_OF_QUEENS):
            # Queens not on same row
            solver.make_non_equality(queen_vars[i], queen_vars[j], 0)
            # Queens not on same left diagonal
            solver.make_non_equality(queen_vars[i], queen_vars[j], i - j)
            # Queens not on same right diagonal
            solver.make_non_equality(queen_vars[i], queen_vars[j], j - i)
    solver.solve(print)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment