-
-
Save torchlight/ace4f5fee81104ff693d to your computer and use it in GitHub Desktop.
Sudoku
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Sudoku.""" | |
import random | |
from itertools import combinations | |
def complement(A, B):
    """Return A\\B: the items of A that do not occur in B, in A's order."""
    return [item for item in A if item not in B]
def intersect(A, B):
    """Return A \\cap B: the items of B that also occur in A, in B's order."""
    return [item for item in B if item in A]
def dot(A, B):
    """
    Return the dot product of two sequences, i.e. the sum of A[i]*B[i] over
    every index i of A. Assumes len(A) == len(B).
    """
    return sum(A[k] * B[k] for k in range(len(A)))
def shuffle(A):
    """Return a new list with the items of A in random order; A is untouched."""
    items = [*A]
    random.shuffle(items)
    return items
def box(xpitch, ypitch, start):
    """
    Return the nine indices start + x*xpitch + y*ypitch for x, y in 0..2,
    ordered with y as the outer loop and x as the inner one.

    For example, box(1,9,0) yields the indices of the top-left 3x3 box.
    """
    indices = []
    for row in range(3):
        for col in range(3):
            indices.append(start + col * xpitch + row * ypitch)
    return indices
def gen_groups(standard=True, diagonals=False, hyper=False, disjoint=False):
    """
    Generate a list of groups for which the one rule should be satisfied.

    Each group is a list of nine cell indices (0..80, row-major).

    standard  -- the 9 rows, 9 columns and 9 3x3 boxes
    diagonals -- the two long diagonals (0,10,...,80 and 8,16,...,72)
    hyper     -- four extra 3x3 windows with top-left cells at
                 (row 1, col 1), (1, 5), (5, 1) and (5, 5), plus the five
                 groups those windows imply together with the standard groups
    disjoint  -- for each position within a 3x3 box, the nine cells at that
                 same position in every box
    """
    groups = []
    if standard:
        groups += [box(1, 3, x) for x in box(9, 27, 0)]  # rows,
        groups += [box(9, 27, x) for x in box(1, 3, 0)]  # columns,
        groups += [box(1, 9, x) for x in box(3, 27, 0)]  # and boxes
    if diagonals:
        groups += [box(10, 30, 0), box(8, 24, 8)]
    if hyper:
        groups += [box(1, 9, x) for x in (10, 14, 46, 50)]
        # The next five groups are implied by the standard groups along with
        # the four windows above, but we include them anyway to make the
        # autosolvers' jobs easier:
        #   box(1,36,1): rows 0,4,8 x columns 1-3
        #   box(1,36,5): rows 0,4,8 x columns 5-7
        #   box(4,9,9):  rows 1-3   x columns 0,4,8
        #   box(4,9,45): rows 5-7   x columns 0,4,8
        #   box(4,36,0): rows 0,4,8 x columns 0,4,8
        groups += [box(1, 36, 1), box(1, 36, 5), box(4, 9, 9),
                   box(4, 9, 45), box(4, 36, 0)]
    if disjoint:
        # sudokuwiki calls this colour sudoku; some other sites call it
        # "disjoint groups"
        groups += [box(3, 27, x) for x in box(1, 9, 0)]
    return groups
def convert(g):
    """
    Convert a grid from a flat list of 81 values to the internal format.

    Cell i becomes [i, c1, ..., c9], where c_v is True while v is still a
    candidate. A nonzero g[i] leaves only g[i] as a candidate; a zero entry
    leaves all nine candidates open.
    """
    cells = []
    for idx in range(81):
        value = g[idx]
        if not value:
            cells.append([idx] + [True] * 9)
            continue
        cell = [idx] + [False] * 9
        cell[value] = True
        cells.append(cell)
    return cells
def convert_str(s):
    """
    Convert a grid from a string representation to the internal format.

    Digits fill the cells in row-major order; '0', '.' and '_' mark an empty
    cell, and every other character is ignored. Cells after the last digit
    read stay empty.
    """
    values = [0] * 81
    pos = 0
    for ch in s:
        if ch in '._':
            ch = '0'
        if not ('0' <= ch <= '9'):
            continue
        values[pos] = int(ch)
        pos += 1
    return convert(values)
def unconvert(grid, hints=False):
    """
    Convert a grid from the internal format to a flat list of 81 ints.

    A cell with exactly one candidate becomes that digit; other cells become
    0. With hints=True, every cell becomes its candidates concatenated as
    decimal digits in ascending order (e.g. {2, 7} -> 27; no candidates -> 0).
    """
    flat = []
    for idx in range(81):
        cell = grid[idx]
        single = sum(cell[1:]) == 1
        if not (single or hints):
            flat.append(0)
            continue
        packed = 0
        for value in range(1, 10):
            if cell[value]:
                packed = packed * 10 + value
        flat.append(packed)
    return flat
def print_grid(grid):
    """
    Print the grid in a fixed-width format: one line per row, each cell shown
    as its remaining candidates left-justified in a 10-character field.
    """
    lines = []
    for r in range(9):
        fields = []
        for c in range(9):
            cell = grid[r * 9 + c]
            digits = ''.join(str(v) for v in range(1, 10) if cell[v])
            fields.append(digits.ljust(10))
        lines.append(''.join(fields) + '\n')
    print(''.join(lines))
def copy_grid(grid):
    """Return a copy of grid in which each cell list is copied too."""
    duplicate = []
    for cell in grid:
        duplicate.append(cell[:])
    return duplicate
def count_candidates(grid):
    """Return the total number of candidate flags still set in all cells."""
    total = 0
    for cell in grid:
        total += sum(cell[1:])
    return total
def elim(grid, groups):
    """
    Eliminate conflicting candidates (naked singles).

    Whenever a cell has exactly one candidate left, that value is removed from
    every other cell of each group containing it. Repeats until a full pass
    changes nothing. Returns a new grid; the input grid is not modified.
    """
    result = copy_grid(grid)
    progress = True
    while progress:
        progress = False
        for group in groups:
            for src in group:
                if sum(result[src][1:]) != 1:
                    continue
                value = dot(result[src], range(10))
                for other in group:
                    if other != src and result[other][value]:
                        result[other][value] = False
                        progress = True
    return result
def elim_multi(grid, groups, n=4):
    """
    Eliminate candidates conflicting with naked n-tuples.

    Tuple sizes below n are handled first by recursion (size 1 is elim). Then,
    for every group and every combination of n of its cells that each hold
    between 2 and n candidates, if the union of their candidates has exactly n
    values, those values are removed from all other cells of the group. If the
    union has fewer than n values, the grid is contradictory and is returned
    immediately. Repeats until nothing changes. The input grid is not modified.
    """
    if n == 1:
        return elim(grid, groups)
    # The recursive call hands back a fresh grid, so it is safe to mutate.
    work = elim_multi(grid, groups, n - 1)
    again = True
    while again:
        again = False
        eligible = [2 <= sum(work[idx][1:]) <= n for idx in range(81)]
        for group in groups:
            pool = [idx for idx in group if eligible[idx]]
            for subset in combinations(pool, n):
                union = [False] * 10
                for idx in subset:
                    for val in range(1, 10):
                        union[val] = union[val] or work[idx][val]
                size = sum(union)
                if size < n:
                    return work  # not solvable
                if size > n:
                    continue  # not a naked n-tuple
                for idx in group:
                    if idx in subset:
                        continue
                    for val in range(1, 10):
                        if union[val] and work[idx][val]:
                            work[idx][val] = False
                            again = True
    return work
def scan(grid, groups):
    """
    Eliminate candidates by hidden singles.

    When a value is a candidate in exactly one cell of a group, that cell is
    reduced to that single value. Repeats until a full pass changes nothing.
    Returns a new grid; the input grid is not modified.
    """
    result = [cell[:] for cell in grid]
    progress = True
    while progress:
        progress = False
        for group in groups:
            for value in range(1, 10):
                if sum(result[c][value] for c in group) != 1:
                    continue
                target = next(c for c in group if result[c][value])
                if sum(result[target][1:]) != 1:
                    result[target][1:] = [False] * 9
                    result[target][value] = True
                    progress = True
    return result
def scan_multi(grid, groups, n=4):
    """
    Eliminate candidates by hidden n-tuples.

    Tuple sizes below n are handled first by recursion (size 1 is scan). Then,
    for every group and every set V of n values, the cells of the group that
    still admit any value of V are collected:
      * fewer than n such cells: the grid is contradictory and is returned
        immediately (this assumes every group must hold each of 1-9);
      * exactly n cells: they must hold the values of V between them, so
        every candidate outside V is removed from each of those cells;
      * more than n cells: nothing can be concluded.
    Repeats until nothing changes. The input grid is not modified.
    """
    if n == 1:
        return scan(grid, groups)
    # The recursive call hands back a fresh grid, so it is safe to mutate.
    grid = scan_multi(grid, groups, n - 1)
    changed = True
    while changed:
        changed = False
        for group in groups:
            for V in combinations(range(1, 10), n):
                cells = [cell for cell in group
                         if any(grid[cell][v] for v in V)]
                if len(cells) < n:
                    return grid  # not solvable
                if len(cells) > n:
                    continue  # not a hidden n-tuple
                for cell in cells:
                    if sum(grid[cell][1:]) == sum(grid[cell][v] for v in V):
                        # This cell has no candidates outside V. Skip only this
                        # cell: the other tuple cells may still need pruning.
                        # (Breaking here would leave them unpruned.)
                        continue
                    changed = True
                    for v in range(1, 10):
                        if v not in V:
                            grid[cell][v] = False
    return grid
def ir(grid, groups):
    """
    Eliminate candidates by intersection removal. This includes pointing
    n-tuples and box/line reduction under the standard rules, but is
    generalised to arbitrary pairs of regions.

    For every pair of groups sharing at least two cells, consider each value
    that is a candidate somewhere in the shared cells. If it is not a
    candidate anywhere else in one group, it must lie in the shared cells, so
    it is removed from the rest of the other group. Both directions are
    checked. Repeats until nothing changes. Returns a new grid; the input grid
    is not modified.
    """
    result = [cell[:] for cell in grid]
    progress = True
    while progress:
        progress = False
        for first, second in combinations(groups, 2):
            shared = [c for c in second if c in first]
            if len(shared) < 2:
                continue  # we need at least two common cells
            present = [v for v in range(1, 10)
                       if any(result[c][v] for c in shared)]
            for home, other in ((first, second), (second, first)):
                # Values that, within `home`, only appear in the shared cells.
                confined = [v for v in present
                            if not any(result[c][v] for c in home
                                       if c not in shared)]
                for v in confined:
                    for c in other:
                        if c not in shared and result[c][v]:
                            result[c][v] = False
                            progress = True
    return result
def solve_logic(grid, groups, methods=(elim, scan), verbose=False):
    """
    Attempt to solve a puzzle using a specified sequence of methods,
    defaulting to naked/hidden singles (elim, scan).

    Each method is called as method(grid, groups) and must return a grid.
    Full passes over all methods are repeated until a pass leaves the total
    candidate count unchanged. If verbose, the candidate count is printed
    before the first pass and after every pass.

    The default is a tuple rather than a list, so the shared default object
    cannot be mutated between calls.
    """
    oldcount = 0
    count = count_candidates(grid)
    if verbose:
        print(count)
    while count != oldcount:
        for method in methods:
            grid = method(grid, groups)
        oldcount = count
        count = count_candidates(grid)
        if verbose:
            print(count)
    return grid
def solve_weak(grid, groups):
    """
    Attempt to solve a puzzle using only naked/hidden singles.

    Runs elim then scan repeatedly until a full pass no longer changes the
    candidate count; the loop itself is delegated to solve_logic.
    """
    return solve_logic(grid, groups, [elim, scan])
def solve_basic(grid, groups):
    """
    Attempt to solve a puzzle using the basic techniques: naked n-tuples,
    hidden n-tuples (both up to their default tuple size) and intersection
    removal, repeated until they stop making progress.
    """
    techniques = [elim_multi, scan_multi, ir]
    return solve_logic(grid, groups, methods=techniques)
def check_validity(grid, groups):
    """
    Check that the grid is sensible without solving it.

    Returns False if any cell has no candidates left, or if two cells sharing
    a group are both reduced to the same single value; True otherwise.
    """
    for cell in grid:
        n_cand = sum(cell[1:])
        if n_cand == 0:
            return False  # cell has no candidates
        if n_cand != 1:
            continue
        # Cell is decided; compare only against higher-indexed peers so each
        # pair of cells is examined once per shared group.
        value = dot(cell, range(10))
        me = cell[0]
        for group in groups:
            if me not in group:
                continue
            for peer in group:
                if peer > me and sum(grid[peer][1:]) == 1 and grid[peer][value]:
                    return False
    return True
def solve_full(grid, groups, deterministic=False, order=None, maxdepth=0):
    """
    Solve a puzzle, falling back on brute force if necessary.

    The grid is first simplified with naked/hidden singles (solve_weak).
    A contradictory result gives None; a solved one is returned. Otherwise the
    first undecided cell in `order` is branched on: each of its candidates is
    tried in turn, recursing on the rest of `order`.

    order         -- cell indices to branch on, in order. When None, all 81
                     cells are used, shuffled unless deterministic is true.
    maxdepth      -- maximum search depth plus one; 0 means unlimited. When
                     the limit is reached the current grid is returned as is.
    deterministic -- if true, no randomness is used and candidates are tried
                     in ascending order.

    Returns a solved grid or None if there is no solution. It returns a
    partially solved grid if the depth limit or the end of `order` is reached
    first, e.g. with a caller-supplied order that omits cells. Such a partial
    grid counts as a success for the enclosing search.
    """
    grid = solve_weak(grid, groups)
    if not check_validity(grid, groups):
        return None  # invalid
    if count_candidates(grid) == 81:
        return grid  # solved
    if maxdepth == 1:
        return grid  # max depth reached
    if order is None:
        order = list(range(81))
        if not deterministic:
            random.shuffle(order)
    # Advance past cells that are already decided.
    pos = 0
    while pos < len(order) and sum(grid[order[pos]][1:]) == 1:
        pos += 1
    if pos == len(order):
        return grid  # no unsolved cells left to check
    cell = order[pos]
    remaining = order[pos + 1:]
    values = list(range(1, 10))
    if not deterministic:
        random.shuffle(values)
    child_depth = maxdepth - 1 if maxdepth else maxdepth
    for value in values:
        if not grid[cell][value]:
            continue
        trial = copy_grid(grid)
        trial[cell][1:] = [False] * 9
        trial[cell][value] = True
        outcome = solve_full(trial, groups, deterministic, remaining, child_depth)
        if outcome:
            return outcome
    return None  # no solution
def check_unique_solution(grid, groups, refgrid=None, deterministic=False):
    """
    Check that a puzzle has a unique solution.

    If refgrid is given, it is assumed to be a solution, and the function
    looks for a solution that differs from it. Otherwise one solution is first
    found with solve_full; if none exists, False is returned. Then, for every
    cell and every remaining candidate that disagrees with refgrid, that value
    is fixed and solving is attempted. Any success means a second solution
    exists, and False is returned.
    """
    grid = solve_weak(grid, groups)
    if not check_validity(grid, groups):
        return False
    if count_candidates(grid) == 81:
        return True
    if refgrid is None:
        refgrid = solve_full(grid, groups, deterministic)
        if refgrid is None:
            return False
    for cell in range(81):
        values = list(range(1, 10))
        if not deterministic:
            random.shuffle(values)
        for value in values:
            # Only candidates that disagree with the reference are of interest.
            if not refgrid[cell][value] and grid[cell][value]:
                trial = copy_grid(grid)
                trial[cell][1:] = [False] * 9
                trial[cell][value] = True
                if solve_full(trial, groups, deterministic) is not None:
                    return False
    return True
def gen_standard():
    """
    Generate a random filled sudoku grid.

    The three 3x3 boxes on the main diagonal (top-left cells 0, 30 and 60)
    are filled with independent random permutations of 1-9. The rest of the
    grid is then completed with solve_full under the standard rules, retrying
    from scratch if that fails. This is moderately efficient (slightly faster
    than a row-by-row method) and has pretty much negligible bias. It can
    occasionally take much longer than usual if the RNG is unlucky.
    """
    groups = gen_groups()
    while True:
        seed = [0] * 81
        for corner in (0, 30, 60):
            digits = shuffle(range(1, 10))
            # Cells receive the shuffled digits from the back of the list.
            for idx, digit in zip(box(1, 9, corner), reversed(digits)):
                seed[idx] = digit
        grid = solve_full(convert(seed), groups)
        if grid is not None:
            return grid
def add_holes(grid, groups, hard=True, order=None, verbose=False):
    """
    Add blanks to a grid to make it a puzzle.

    Cells are visited in `order`, or in a random permutation of 0..80 when
    order is None. A visited cell is blanked if, with it blank, solve_basic
    still reduces it to a single candidate. When hard is true, the cell is
    also blanked if check_unique_solution confirms the puzzle stays uniquely
    solvable. Otherwise the cell keeps its value.

    grid    -- used as the reference solution for the uniqueness check, so it
               is expected to be a solved grid; it is not modified
    groups  -- the groups the puzzle must satisfy
    hard    -- also accept blanks that need the uniqueness check
    order   -- cell indices to try, in order; an empty sequence blanks nothing
    verbose -- print the visiting order and every accept/reject decision

    Returns the resulting puzzle grid.
    """
    hgrid = copy_grid(grid)
    # Test against None explicitly: `order or ...` would silently replace an
    # explicitly empty order with a random one.
    h = shuffle(range(81)) if order is None else order
    if verbose:
        print(h)
    n = 0
    for i in h:
        ngrid = copy_grid(hgrid)
        ngrid[i][1:] = [True] * 9
        sgrid = solve_basic(ngrid, groups)
        if sum(sgrid[i][1:]) == 1 or (hard and check_unique_solution(sgrid, groups, grid)):
            n += 1
            hgrid = ngrid
            if verbose:
                print('added at %d (total %d)' % (i, n))
        elif verbose:
            print('rejected add at %d' % i)
    return hgrid
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment