Created June 30, 2019 12:16
-
-
Save lan496/f391b251151ee1f3690491f0bbe8176e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pulp | |
class SudokuSolver:
    """Solve 9x9 Sudoku puzzles as a 0/1 integer program with PuLP.

    ref: https://pythonhosted.org/PuLP/CaseStudies/a_sudoku_problem.html

    The model has one 0/1 variable per (row, column, value) triple:
    ``choices[r][c][v] == 1`` iff cell (r, c) holds value v. Rows, columns
    and values are all 1-based, while grids passed in and returned are
    0-based nested lists.
    """

    def __init__(self, cell_initial):
        """Store the puzzle and precompute the index sets.

        Args:
            cell_initial: 9x9 nested list of ints; ``cell_initial[r][c]`` is
                the given value (1-9) of the cell at 0-based row r, column c,
                or 0 when the cell is unknown.
        """
        self.cell_initial = cell_initial
        self.rows = list(range(1, 9 + 1))
        self.cols = list(range(1, 9 + 1))
        self.vals = list(range(1, 9 + 1))
        # Each box is the list of its nine (row, col) pairs; boxes are listed
        # in row-major order (top-left box first, bottom-right box last).
        self.boxes = []
        for band in range(3):
            for stack in range(3):
                self.boxes.append([(self.rows[3 * band + dr], self.cols[3 * stack + dc])
                                   for dr in range(3) for dc in range(3)])

    @staticmethod
    def _is_one(value):
        """Return True if a solved 0/1 variable value should be read as 1.

        Solvers may report integer variables with small floating-point error
        (e.g. 0.9999999), so an exact ``== 1`` test can miss a chosen value.
        ``None`` (a variable without a solution value) is treated as 0.
        """
        return value is not None and value > 0.5

    def _build_problem(self):
        """Create the Sudoku model with all constraints; return (prob, choices)."""
        prob = pulp.LpProblem("Sudoku", pulp.LpMinimize)
        # Pure feasibility problem: the objective is an arbitrary constant.
        prob += 0, "Arbitrary Objective Function"
        # (name, indices, lowBound, upBound, cat)
        # choices[r][c][v] == 1 iff cell (r, c) holds value v
        choices = pulp.LpVariable.dicts("Choice", (self.rows, self.cols, self.vals),
                                        0, 1, pulp.LpInteger)
        for const in self.get_constraints(choices):
            prob += const
        return prob, choices

    def solve(self):
        """Solve the puzzle once and return the resulting grid.

        The solver status is printed. The returned 9x9 grid is built from the
        solver's variable values even when the status is not "Optimal", in
        which case it is not a valid solution.
        """
        prob, choices = self._build_problem()
        prob.solve()
        print("Status:", pulp.LpStatus[prob.status])
        return self.choices_to_cell(choices)

    def enumerate(self):
        """Return every solution of the puzzle as a list of 9x9 grids.

        The model is solved repeatedly. After each "Optimal" solve the grid
        is recorded and a cut forbidding exactly that assignment is added;
        the loop stops at the first solve whose status is not "Optimal".
        The status of every solve is printed.
        """
        prob, choices = self._build_problem()
        all_cells = []
        while True:
            prob.solve()
            print("Status:", pulp.LpStatus[prob.status])
            if pulp.LpStatus[prob.status] != "Optimal":
                break
            all_cells.append(self.choices_to_cell(choices))
            # Variables equal to 1 in this solution. Allowing at most
            # len(selected) - 1 of them to be 1 excludes exactly this solution.
            selected = [choices[r][c][v]
                        for r in self.rows for c in self.cols for v in self.vals
                        if self._is_one(pulp.value(choices[r][c][v]))]
            prob += pulp.lpSum(selected) <= len(selected) - 1
        return all_cells

    def get_constraints(self, choices):
        """Build the Sudoku constraints for the given decision variables.

        Args:
            choices: nested dict ``choices[r][c][v]`` of PuLP variables over
                (self.rows, self.cols, self.vals).

        Returns:
            List of ``(constraint, "")`` tuples, ready for ``prob += ...``.
        """
        constraints = []
        # Each cell holds exactly one value.
        for r in self.rows:
            for c in self.cols:
                constraints.append((pulp.lpSum([choices[r][c][v] for v in self.vals]) == 1, ""))
        # Each value appears exactly once in every row, column and box.
        for v in self.vals:
            for r in self.rows:
                constraints.append((pulp.lpSum([choices[r][c][v] for c in self.cols]) == 1, ""))
            for c in self.cols:
                constraints.append((pulp.lpSum([choices[r][c][v] for r in self.rows]) == 1, ""))
            for b in self.boxes:
                constraints.append((pulp.lpSum([choices[r][c][v] for (r, c) in b]) == 1, ""))
        # The givens of this instance's puzzle are fixed to their values.
        for r in self.rows:
            for c in self.cols:
                v = self.cell_initial[r - 1][c - 1]
                if v == 0:
                    continue
                constraints.append((choices[r][c][v] == 1, ""))
        return constraints

    def choices_to_cell(self, choices):
        """Convert solved variables into a 9x9 grid with 0-based indexing.

        ``cell[r - 1][c - 1]`` is the value v whose variable
        ``choices[r][c][v]`` reads as 1, or 0 if none does (the largest such
        v wins if several do).
        """
        cell = [[0 for _ in self.cols] for _ in self.rows]
        for r in self.rows:
            for c in self.cols:
                for v in self.vals:
                    if self._is_one(pulp.value(choices[r][c][v])):
                        cell[r - 1][c - 1] = v
        return cell
def pretty_cell_str(cell):
    """Render a 9x9 Sudoku grid as a boxed ASCII string.

    Args:
        cell: 9x9 nested list of ints (0-based indexing); 0 marks an empty
            cell and is drawn as blank space.

    Returns:
        A multi-line string ending with a newline. Each grid row looks like
        ``| 1 2 3 | 4 5 6 | 7 8 9 |``, and a ``+-------+-------+-------+``
        border is drawn above rows 0, 3 and 6 and below the last row.
    """
    border = "+-------+-------+-------+"
    lines = []
    for i in range(9):
        if i % 3 == 0:
            lines.append(border)
        boxes = []
        for start in (0, 3, 6):
            # Every cell is exactly two characters ("d " or two blanks), so a
            # box is "| " plus 6 characters: the same 8 columns as "+-------".
            body = "".join("  " if cell[i][j] == 0 else str(cell[i][j]) + " "
                           for j in range(start, start + 3))
            boxes.append("| " + body)
        lines.append("".join(boxes) + "|")
    lines.append(border)
    return "\n".join(lines) + "\n"
if __name__ == '__main__':
    # Demo puzzle to solve; a 0 marks a cell whose value is not known yet.
    cell_initial = [
        [0, 3, 0, 0, 7, 0, 0, 0, 0],
        [0, 0, 0, 1, 9, 5, 0, 0, 0],
        [0, 9, 8, 0, 0, 0, 0, 6, 0],
        [0, 0, 0, 0, 6, 0, 0, 0, 3],
        [0, 0, 0, 8, 0, 3, 0, 0, 1],
        [7, 0, 0, 0, 2, 0, 0, 0, 6],
        [0, 6, 0, 0, 0, 0, 2, 8, 0],
        [0, 0, 0, 4, 1, 9, 0, 0, 5],
        [0, 0, 0, 0, 8, 0, 0, 7, 9],
    ]

    # Show the puzzle as given.
    print('initial')
    print(pretty_cell_str(cell_initial))

    # SudokuSolver.solve() would return a single grid; enumerate() collects
    # every solution so that non-unique puzzles are detected as well.
    solver = SudokuSolver(cell_initial)
    solutions = solver.enumerate()
    for solution in solutions:
        print(pretty_cell_str(solution))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.