Skip to content

Instantly share code, notes, and snippets.

@skyone-wzw
Created May 18, 2024 20:18
Show Gist options
  • Save skyone-wzw/98de9ed6ab68ddf0d5381bd33bbb2857 to your computer and use it in GitHub Desktop.
Save skyone-wzw/98de9ed6ab68ddf0d5381bd33bbb2857 to your computer and use it in GitHub Desktop.
使用 PuLP 线性规划求解器求解数独问题。Solve Sudoku problems using the PuLP linear programming solver.
import pulp
import re
input_string = """
+-------+-------+-------+
| x x x | 6 5 x | x x x |
| x x 8 | 4 x x | 9 1 x |
| x x x | x x x | x 8 3 |
+-------+-------+-------+
| x 6 x | x x x | x x x |
| 3 5 x | x x x | x x 1 |
| x 1 9 | 7 x x | x 5 x |
+-------+-------+-------+
| x x 7 | x x x | x x 2 |
| 1 x x | x x x | x x x |
| x x x | x 2 9 | 7 6 x |
+-------+-------+-------+
"""
# 数字,行,列的范围都是1-9
VALS = ROWS = COLS = range(1, 10)
# 构造3*3的小方格
Boxes = [
[(3 * i + k + 1, 3 * j + l + 1) for k in range(3) for l in range(3)]
for i in range(3)
for j in range(3)
]
# 建立问题
prob = pulp.LpProblem("Sudoku Problem")
# 建立决策变量
choices = pulp.LpVariable.dicts("X", (VALS, ROWS, COLS), cat="Binary")
# 约束:每个格子只能填入一个数字
for r in ROWS:
for c in COLS:
prob += pulp.lpSum([choices[v][r][c] for v in VALS]) == 1
# 约束:每行、每列、每个3*3小方格1-9不重复
for v in VALS:
for r in ROWS:
prob += pulp.lpSum([choices[v][r][c] for c in COLS]) == 1
for c in COLS:
prob += pulp.lpSum([choices[v][r][c] for r in ROWS]) == 1
for b in Boxes:
prob += pulp.lpSum([choices[v][r][c] for (r, c) in b]) == 1
all_matches = re.finditer(r'(x|\d)', input_string)
input_data = [(int(match.group(1)), i // 9 + 1, i % 9 + 1)
for i, match in enumerate(all_matches) if match.group(1) != 'x']
# 约束:初始状态填入的数字
for (v, r, c) in input_data:
prob += choices[v][r][c] == 1
while True:
prob.solve()
# 输出是否有解的状态
print("Status:", pulp.LpStatus[prob.status])
# 有解就打印出来
if pulp.LpStatus[prob.status] == "Optimal":
text = ""
for r in ROWS:
if r in [1, 4, 7]:
text += "+-------+-------+-------+\n"
for c in COLS:
for v in VALS:
if pulp.value(choices[v][r][c]) == 1:
if c in [1, 4, 7]:
text += "| "
text += str(v) + " "
if c == 9:
text += "|\n"
text += "+-------+-------+-------+"
print(text)
# 添加约束,当前的解不能再次出现
# 因为一共有81个格子,一个解法所有的格子变量加起来应该等于81
# 所以这里的<= 80就是将这个解排除了。
prob += (
pulp.lpSum(
[
choices[v][r][c]
for v in VALS
for r in ROWS
for c in COLS
if pulp.value(choices[v][r][c]) == 1
]
)
<= 80
)
# 没有其他解法了,退出
else:
break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment