Skip to content

Instantly share code, notes, and snippets.

@elit8888
Created June 2, 2020 15:23
Show Gist options
  • Save elit8888/4f877b78d98b4dee6b124847d7a63942 to your computer and use it in GitHub Desktop.
Save elit8888/4f877b78d98b4dee6b124847d7a63942 to your computer and use it in GitHub Desktop.
Solve Sudoku (backtracking)
import unittest
from typing import List, Tuple, Set
class Solution:
    """Backtracking Sudoku solver.

    At every step it branches on the empty cell with the fewest legal digits
    (fewest remaining values), which keeps the search tree small.

    Working state, rebuilt on every call to ``solve``:

    * ``rows[i]`` / ``columns[j]`` -- digits already placed in row ``i`` / column ``j``.
    * ``squares[r][c]`` -- digits already placed in the 3x3 box at box-row ``r``,
      box-column ``c`` (i.e. ``squares[i // 3][j // 3]`` for cell ``(i, j)``).
    * ``res`` -- working copy of the grid, filled in place.
    * ``remains`` -- ``(row, col)`` cells that still need a digit.
    """

    def solve(self, puzzle: List[List[int]]) -> List[List[int]]:
        """Return a solved copy of *puzzle*; the input grid is not modified.

        Cells holding a value > 0 are treated as givens; every other cell is
        considered empty.

        Raises:
            RuntimeError: if the givens conflict with each other (the same digit
                twice in a row, column, or box) or no completion exists.
        """
        self.columns = [set() for _ in range(9)]
        self.rows = [set() for _ in range(9)]
        self.squares = [[set() for _ in range(3)] for _ in range(3)]
        self.res = [row[:] for row in puzzle]
        self.remains = []
        for i, row in enumerate(puzzle):
            for j, elem in enumerate(row):
                if elem > 0:
                    # Without this check a repeated given is silently absorbed by
                    # the sets, and an invalid grid would be returned as a "solution".
                    if elem in self._used_values((i, j)):
                        raise RuntimeError('The puzzle cannot be solved')
                    self.fill_in((i, j), elem)
                else:
                    self.remains.append((i, j))
        if self._solve():
            return self.res
        raise RuntimeError('The puzzle cannot be solved')

    def _used_values(self, point: Tuple[int, int]) -> Set[int]:
        """Return the digits already placed in the row, column, and box of *point*."""
        i, j = point
        return self.columns[j] | self.rows[i] | self.squares[i // 3][j // 3]

    def len_available(self, point: Tuple[int, int]) -> int:
        """Return how many digits 1-9 can still legally go in cell *point*."""
        return 9 - len(self._used_values(point))

    def available_val(self, point: Tuple[int, int]) -> Set[int]:
        """Return the set of digits 1-9 that can still legally go in cell *point*."""
        return set(range(1, 10)) - self._used_values(point)

    def fill_in(self, point: Tuple[int, int], val: int) -> None:
        """Record *val* as placed at *point* in its row, column, and box sets."""
        i, j = point
        self.columns[j].add(val)
        self.rows[i].add(val)
        self.squares[i // 3][j // 3].add(val)

    def discard(self, point: Tuple[int, int], val: int) -> None:
        """Remove *val* from the row, column, and box sets of *point* (undoes ``fill_in``)."""
        i, j = point
        self.columns[j].discard(val)
        self.rows[i].discard(val)
        self.squares[i // 3][j // 3].discard(val)

    def _solve(self) -> bool:
        """Fill every cell in ``remains``; return True on success.

        On failure, this call restores ``remains`` and the constraint sets to
        their state on entry and returns False, so the caller can try its next
        candidate.
        """
        if not self.remains:
            return True
        # Branch on the most-constrained cell. A single min() scan replaces the
        # full sort + list.pop(0) that was previously done on every step.
        point = min(self.remains, key=self.len_available)
        self.remains.remove(point)
        i, j = point
        for candidate in sorted(self.available_val(point)):
            self.res[i][j] = candidate
            self.fill_in(point, candidate)
            if self._solve():
                return True
            # The candidate was absent from all three sets before fill_in
            # (it came from available_val), so discarding restores them exactly.
            self.discard(point, candidate)
        # No candidate worked: clear the tentative value and hand the cell back.
        self.res[i][j] = 0
        self.remains.append(point)
        return False
class SudokuTest(unittest.TestCase):
def puzzle_transform(self, s: List[str]) -> List[List[int]]:
return [[int(c) for c in row] for row in s]
def test_example(self):
puzzle = ['530070000',
'600195000',
'098000060',
'800060003',
'400803001',
'700020006',
'060000280',
'000419005',
'000080079']
ans = ['534678912',
'672195348',
'198342567',
'859761423',
'426853791',
'713924856',
'961537284',
'287419635',
'345286179']
self.assertEqual(
Solution().solve(self.puzzle_transform(puzzle)),
self.puzzle_transform(ans))
if __name__ == '__main__':
    # Script entry point: discover and run the TestCase classes in this module.
    unittest.main(module='__main__')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment