Skip to content

Instantly share code, notes, and snippets.

@yi-jiayu
Created March 11, 2019 08:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yi-jiayu/ef2d83e26db5fd22fb3a86356df8076d to your computer and use it in GitHub Desktop.
Save yi-jiayu/ef2d83e26db5fd22fb3a86356df8076d to your computer and use it in GitHub Desktop.
A simple backtracking sudoku solver
def rows(puzzle):
    """Split a flat puzzle into its nine rows.

    The puzzle is read as nine consecutive runs of nine cells; each run is
    returned as one slice of ``puzzle``, in order of appearance.
    """
    row_starts = range(0, 81, 9)
    return [puzzle[start:start + 9] for start in row_starts]
def columns(puzzle):
    """Split a flat puzzle into its nine columns.

    Column ``c`` is made of every ninth cell starting at index ``c``; the
    columns are returned in order from index 0 to index 8.
    """
    stride = 9
    return [puzzle[first::stride] for first in range(stride)]
def boxes(puzzle):
    """Split a flat puzzle into its nine 3x3 boxes.

    Each box is built by concatenating three 3-cell slices taken from three
    consecutive rows. Boxes are emitted column of boxes by column of boxes:
    the three boxes of the left stack first (top to bottom), then the middle
    stack, then the right stack.
    """
    groups = []
    for stack_index in range(3):
        for band_index in range(3):
            # Index of the box's top-left cell in the flat puzzle.
            corner = 3 * stack_index + 27 * band_index
            top = puzzle[corner:corner + 3]
            middle = puzzle[corner + 9:corner + 12]
            bottom = puzzle[corner + 18:corner + 21]
            groups.append(top + middle + bottom)
    return groups
def valid_group(group):
    """Return True when no filled cell value occurs twice in ``group``.

    Cells equal to ``'.'`` are treated as empty and ignored.
    """
    seen = set()
    for cell in group:
        if cell == '.':
            continue
        if cell in seen:
            return False
        seen.add(cell)
    return True
def valid(puzzle):
    """Return True when no row, column or box contains a repeated digit.

    Rows are checked first, then columns, then boxes; checking stops at the
    first group that fails.
    """
    for split in (rows, columns, boxes):
        for group in split(puzzle):
            if not valid_group(group):
                return False
    return True
def solved(puzzle):
    """Return True when the puzzle is valid and has no empty (``'.'``) cell."""
    if not valid(puzzle):
        return False
    return '.' not in puzzle
def permutations(puzzle):
    """Return the boards obtained by filling the first empty cell.

    The first ``'.'`` in ``puzzle`` is replaced by each digit ``'1'`` to
    ``'9'`` in turn, giving nine candidate boards in ascending digit order.
    The candidates are not checked for validity.

    If the puzzle has no empty cell, an empty list is returned: there is
    nothing to fill, and returning unchanged copies of the board would only
    make a caller re-examine the same state.
    """
    blank = puzzle.find('.')
    if blank == -1:
        return []
    prefix, suffix = puzzle[:blank], puzzle[blank + 1:]
    return [f'{prefix}{digit}{suffix}' for digit in '123456789']
def solve(puzzle):
    """Solve a sudoku puzzle by iterative depth-first backtracking.

    Starting from ``puzzle``, the first empty cell of the current board is
    filled with every digit that keeps the board valid, and the resulting
    boards are pushed on a stack. Because the stack is last-in first-out,
    the highest surviving digit is explored first.

    Returns the first completed board found, or ``None`` when the puzzle is
    contradictory or has no solution.
    """
    # Validate the starting board once. A board that already contains a
    # duplicate can never be completed, since filling cells does not remove
    # the duplicate.
    if not valid(puzzle):
        return None

    stack = [puzzle]
    while stack:
        curr = stack.pop()
        # Every board on the stack has already passed valid(), so a board
        # with no empty cell is a solution; no need to re-validate it.
        if '.' not in curr:
            return curr
        stack.extend(perm for perm in permutations(curr) if valid(perm))
    return None
def valid_input(inp):
    """Return True when ``inp`` has exactly 81 cells, each ``'.'`` or ``'1'``-``'9'``."""
    if len(inp) != 81:
        return False
    allowed = set('.123456789')
    return set(inp) <= allowed
def main():
    """Read one puzzle line from standard input and print the outcome.

    Prints ``Invalid input!`` when the line fails ``valid_input``,
    ``No solution!`` when ``solve`` returns nothing, and the solved board
    otherwise.
    """
    puzzle = input()
    if valid_input(puzzle):
        answer = solve(puzzle)
        if answer:
            print(answer)
        else:
            print('No solution!')
    else:
        print('Invalid input!')
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment