Skip to content

Instantly share code, notes, and snippets.

@stdk
Last active December 23, 2015 18:29
Show Gist options
  • Save stdk/6676151 to your computer and use it in GitHub Desktop.
Save stdk/6676151 to your computer and use it in GitHub Desktop.
Knuth's Algorithm X for Python
from numpy import matrix,array
import numpy as np
from time import clock
def has_overlap(m):
    """Return True if any cell of `m` is greater than 1.

    place_figure() adds a figure onto a slice of the board and calls this to
    detect a collision: a value above 1 means two pieces (or a piece and an
    already-occupied cell holding 1) landed on the same square.

    :param m: numpy array; in practice a 2-D board slice, but any shape works.
    :returns: a plain Python bool.
    """
    # One vectorized comparison instead of a Python-level double loop.
    return bool(np.any(np.asarray(m) > 1))
def place_figure(field, figure):
    """Yield every (row, col, rot) where `figure` fits on `field` without overlap.

    `rot` is the number of 90-degree rotations applied with np.rot90, and
    (row, col) is the top-left corner of the rotated figure on the board.
    Placements are tried in row, col, rot order.

    NOTE: while the generator is suspended at a yield, the placement is applied
    to `field` *in place*. solve() relies on this to snapshot the occupied board
    with field.flatten(). The placement is removed again when iteration resumes,
    and also when the generator is closed early (the `finally` below).

    :param field: 2-D integer numpy array; mutated temporarily as described.
    :param figure: 2-D 0/1 numpy array; may be non-square.
    """
    for i in range(field.shape[0]):
        for j in range(field.shape[1]):
            for rot in range(4):
                current = np.rot90(figure, rot)
                # Slice per rotation: rotating a non-square figure swaps its shape.
                subfield = field[i:i + current.shape[0], j:j + current.shape[1]]
                if subfield.shape != current.shape:
                    continue  # the rotated figure would hang off the board here
                subfield += current  # basic slice is a view: this writes into `field`
                try:
                    if not has_overlap(subfield):
                        yield (i, j, rot)
                finally:
                    # Always undo, even if the consumer stops iterating at the yield.
                    subfield -= current
# Debug switch. When True, solve_base() waits for a line of input and prints
# each matrix it receives plus the selected column, and solve_base_row()
# prints the chosen row and the remaining rows and columns.
DEBUG = False
def solve_base_row(base, row):
    """Yield every exact cover of `base` that includes row `row`.

    One branch of Algorithm X. It commits to `row` and drops the columns that
    row covers. It also drops every row sharing one of those columns, then
    recurses on the reduced matrix with solve_base().

    :param base: 2-D 0/1 numpy array (rows = candidates, columns = constraints).
    :param row: index of the row to commit to; expected to contain at least one 1.
    :returns: generator of solutions. Each solution is a list with one entry per
        chosen row, each entry being the column indices (into `base`) it covers.
    """
    if DEBUG: print('solve_base_row %s' % (row,))
    active_cols = [col for col, value in enumerate(base[row]) if value]
    active_set = set(active_cols)  # O(1) membership for the column filter below
    # Rows with a 0 in every active column don't clash with `row` and stay eligible.
    # (Per the original author's note, summing column slices beat np.sum(..., axis=1).)
    rows_left = [k for k, hits in enumerate(sum([base[:, col] for col in active_cols]))
                 if hits == 0]
    cols_left = [col for col in range(base.shape[1]) if col not in active_set]
    if DEBUG:
        print('rows_left %s' % (rows_left,))
        print('cols_left %s' % (cols_left,))
    if not cols_left:
        # `row` completes the cover. Stop here: recursing on a zero-column matrix
        # could only re-yield this solution or fail.
        yield [active_cols]
        return
    if not rows_left:
        return  # columns remain but no compatible row is left to cover them
    for partial in solve_base(base[:, cols_left][rows_left, :]):
        # Map the sub-matrix's column indices back to indices into `base`.
        yield [active_cols] + [[cols_left[c] for c in s] for s in partial]
def solve_base(base):
    """Yield every exact cover of the 0/1 matrix `base` (Knuth's Algorithm X).

    Branches on the column with the fewest 1s, with ties going to the lowest
    index. For each row that has a 1 in that column, it delegates to
    solve_base_row().

    :param base: 2-D 0/1 numpy array.
    :returns: generator of solutions, each a list of chosen rows expressed as
        the column indices they cover.
    """
    if DEBUG:
        import sys
        # Pause until Enter. Python 2's input() evals the typed text, so a bare
        # Enter raised SyntaxError; reading the raw line avoids that.
        sys.stdin.readline()
        print(base)
    if base.shape[1] == 0:
        # No columns left to cover: the empty selection is a valid cover
        # (Algorithm X's success case). Taking min() over no columns would raise.
        yield []
        return
    # Vectorized column sums; argmin returns the first minimum, matching the
    # original (count, index) tie-break.
    selected_col = int(np.argmin(base.sum(axis=0)))
    if DEBUG: print('selected_col %s' % (selected_col,))
    for row in np.flatnonzero(base[:, selected_col]):
        for solution in solve_base_row(base, int(row)):
            yield solution
def show_solution(field, base_solution):
    """Render one solution from solve() as a board of figure numbers.

    :param field: the board that was solved; only its 2-D shape is used.
    :param base_solution: list of placements, each a list of column indices as
        produced by solve(). Indices below the cell count are flattened board
        cells; index ``cell_count + k`` marks figure ``k``.
    :returns: float array shaped like `field`. Each covered cell holds the
        1-based number of the figure covering it; every other cell is 0.
    """
    cell_count = field.shape[0] * field.shape[1]
    solution = np.zeros(cell_count)
    for placement in base_solution:
        for col in placement:
            if col < cell_count:
                continue
            number = col - cell_count + 1  # marker column -> 1-based figure number
            for cell in placement:
                if cell < cell_count:
                    solution[cell] = number
    return solution.reshape(*field.shape)
def solve(field, figures):
    """Find every way to cover the empty cells of `field`, using each figure once.

    Builds the exact-cover matrix: one row per distinct placement (position and
    rotation) of a figure, with one column per board cell followed by one
    "figure i used" column per figure. It then solves that matrix with
    solve_base(), prints the elapsed time and returns the solutions.

    :param field: 2-D integer numpy array. Cells equal to 0 must be covered;
        non-zero cells are left out of the constraints (treated as occupied).
    :param figures: list of 2-D 0/1 numpy arrays, each to be placed exactly once.
    :returns: list of solutions. The original returned None, so callers that
        ignore the result are unaffected; note that every solution is kept in
        memory. Each solution is a list of placements, and each placement is the
        list of column indices it covers: flattened cell indices, plus
        ``field.shape[0] * field.shape[1] + figure_index``. This is the format
        show_solution() consumes.
    """
    start = clock()
    field_flat_len = field.shape[0] * field.shape[1]
    figure_count = len(figures)
    # place_figure() leaves each placement applied to `field` while suspended at
    # its yield. field.flatten() here therefore snapshots the board WITH the
    # figure on it; the one-hot tail records which figure it was.
    rows = [np.append(field.flatten(), [1 if k == i else 0 for k in range(figure_count)])
            for i, figure in enumerate(figures)
            for _ in place_figure(field, figure)]
    solutions = []
    if rows:  # with no legal placement, np.array([]) would be 1-D (no shape[1])
        full = np.array(rows)
        # Deduplicate rows (e.g. a square's identical rotations) by viewing each
        # row as one opaque void scalar that np.unique can compare.
        row_keys = full.view(np.dtype((np.void, full.dtype.itemsize * full.shape[1])))
        unique_full = full[np.unique(row_keys, return_index=True)[1]]
        # Constraints to cover: every empty cell, plus every figure marker column.
        columns = ([c for c, value in enumerate(field.flatten()) if not value]
                   + [field_flat_len + k for k in range(figure_count)])
        for partial in solve_base(unique_full[:, columns]):
            # Translate sub-matrix column indices back to full-matrix indices.
            solutions.append([[columns[c] for c in s] for s in partial])
    print(clock() - start)
    return solutions
# Demo puzzle: tile an empty 5x5 board with one large corner piece (9 cells),
# one 2x2 square and four L-trominoes (9 + 4 + 4 * 3 = 25 cells).
big_corner = array([[1] * 5] + [[1, 0, 0, 0, 0]] * 4)
square = array([[1, 1],
                [1, 1]])
tromino = array([[0, 1],
                 [1, 1]])
# The same tromino array object fills all four slots, as list repetition implies.
figures = [big_corner, square] + [tromino] * 4
field = array([[0] * 5] * 5)
solve(field, figures)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment