Last active
December 23, 2015 18:29
-
-
Save stdk/6676151 to your computer and use it in GitHub Desktop.
Knuth's Algorithm X for Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from numpy import matrix,array | |
| import numpy as np | |
| from time import clock | |
def has_overlap(m):
    """Return True when any element of ``m`` is greater than 1.

    The board stores occupancy counts, so a value above 1 means two pieces
    were added onto the same cell.

    :param m: NumPy array (or anything ``np.asarray`` accepts) of counts.
    :returns: plain Python ``bool``; ``False`` for an empty array.
    """
    # One vectorized comparison replaces the per-element Python double loop.
    return bool((np.asarray(m) > 1).any())
def place_figure(field, figure):
    """Yield every ``(i, j, rot)`` at which ``figure`` fits on ``field``.

    ``i, j`` is the top-left cell of the placement and ``rot`` the number of
    quarter turns passed to ``np.rot90``. A placement fits when the rotated
    figure lies completely inside the board and no cell ends up above 1.

    The figure is added to ``field`` *in place* for the duration of each
    yield: while the generator is suspended the caller sees the board with
    the figure placed on it, and the figure is removed again as soon as the
    generator is resumed, closed or garbage-collected.

    :param field: 2-D integer NumPy array, modified temporarily in place.
    :param figure: 2-D NumPy array; it does not have to be square.
    """
    n_rows, n_cols = field.shape
    # The four orientations do not depend on the position, so build them once.
    rotations = [np.rot90(figure, rot) for rot in range(4)]
    for i in range(n_rows):
        for j in range(n_cols):
            for rot, current in enumerate(rotations):
                # Size the window from the *rotated* figure so that
                # non-square figures work in every orientation.
                height, width = current.shape
                window = field[i:i + height, j:j + width]
                if window.shape != current.shape:
                    continue  # sticks out over the edge of the board
                window += current
                try:
                    # Any count above 1 means the figure overlaps a cell
                    # that is already occupied.
                    if not (window > 1).any():
                        yield (i, j, rot)
                finally:
                    # Always restore the board, even when the consumer
                    # abandons the generator at the yield above.
                    window -= current
# Set to True to trace the search on stdout; solve_base() then also waits
# for input before descending into each sub-matrix.
DEBUG = False


def solve_base_row(base, row):
    """Yield every exact cover of ``base`` that contains row ``row``.

    ``base`` is a 2-D NumPy array in which a non-zero entry means "this row
    covers this column". Selecting ``row`` covers its non-zero columns; only
    rows whose entries in those columns sum to zero stay available, and the
    search continues on the remaining rows and columns via ``solve_base``.

    :param base: 2-D NumPy array describing the exact-cover problem.
    :param row: index of the row to select; it must have a non-zero entry.
    :returns: generator of solutions. A solution is a list with one entry
        per selected row, beginning with ``row``; each entry is the list of
        column indices of ``base`` covered by that row.
    """
    if DEBUG:
        print('solve_base_row %s' % (row,))
    active_cols = [col for col, value in enumerate(base[row]) if value]
    # Element-wise sum of the selected columns: a zero means the row does
    # not touch any column that ``row`` already covers.
    covered = sum(base[:, col] for col in active_cols)
    rows_left = [k for k, hits in enumerate(covered) if hits == 0]
    active = set(active_cols)  # O(1) membership instead of scanning a list
    cols_left = [col for col in range(base.shape[1]) if col not in active]
    if DEBUG:
        print('rows_left %s' % (rows_left,))
        print('cols_left %s' % (cols_left,))
    if not cols_left:
        # Every column is covered, so this row completes a solution. Stop
        # here: recursing on a matrix without columns (possible when
        # all-zero rows remain) has nothing left to cover.
        yield [active_cols]
        return
    if not rows_left:
        return  # columns remain but no row can cover them: dead end
    for partial in solve_base(base[:, cols_left][rows_left, :]):
        # Translate the sub-matrix column indices back to those of ``base``.
        yield [active_cols] + [[cols_left[c] for c in s] for s in partial]


def solve_base(base):
    """Yield every exact cover of the 0/1 matrix ``base`` (Algorithm X).

    The column with the smallest sum (lowest index on ties) is chosen, and
    every row with a non-zero entry in that column is tried in turn through
    ``solve_base_row``.

    :param base: 2-D NumPy array; non-zero marks "row covers column".
    :returns: generator of solutions in the format of ``solve_base_row``.
        A matrix without columns has exactly one cover, the empty one.
    """
    if DEBUG:
        # Pause so the trace can be followed step by step.
        # NOTE(review): on Python 2 ``input()`` evaluates what is typed;
        # ``raw_input()`` is the plain pause there.
        input()
        print(base)
    if base.shape[1] == 0:
        yield []
        return
    # argmin returns the first minimum, i.e. the lowest index on ties.
    selected_col = int(np.argmin(base.sum(axis=0)))
    if DEBUG:
        print('selected_col %s' % (selected_col,))
    selected_rows = np.flatnonzero(base[:, selected_col]).tolist()
    for row in selected_rows:
        for solution in solve_base_row(base, row):
            yield solution
def show_solution(field, base_solution):
    """Render a solution as a board whose cells hold figure numbers.

    Every entry of ``base_solution`` is a list of column indices. Indices
    below the number of board cells are flattened cell positions; an index
    at or above it identifies a figure, numbered from 1. All cells named in
    an entry receive that entry's figure number; cells no entry mentions
    stay 0.

    :param field: 2-D NumPy array; only its shape is used.
    :param base_solution: list of lists of column indices.
    :returns: float NumPy array with the shape of ``field``.
    """
    n_cells = field.shape[0] * field.shape[1]
    solution = np.zeros(n_cells)
    for line in base_solution:
        numbers = [col - n_cells + 1 for col in line if col >= n_cells]
        cells = [col for col in line if col < n_cells]
        # Plain loops instead of a comprehension run only for side effects.
        for number in numbers:
            for cell in cells:
                solution[cell] = number
    return solution.reshape(*field.shape)
def solve(field, figures, show=False):
    """Enumerate all ways to fill ``field`` with ``figures`` and time it.

    The puzzle is turned into an exact-cover matrix: one row per distinct
    placement of a figure, one column per free board cell plus one column
    per figure. The matrix is handed to ``solve_base`` and the elapsed time
    is printed once the search is exhausted.

    :param field: 2-D integer NumPy array; non-zero cells are occupied.
    :param figures: sequence of 2-D NumPy arrays, one per piece.
    :param show: when true, print every solution as a board of figure
        numbers (via ``show_solution``) while searching.
    :returns: None.
    """
    # NOTE(review): ``time.clock`` was removed in Python 3.8;
    # ``time.perf_counter`` is its replacement there.
    started = clock()
    n_figures = len(figures)
    # Row i is the indicator "this placement uses figure i"; built once
    # instead of once per placement.
    one_hot = np.eye(n_figures, dtype=int)

    rows = []
    for i, figure in enumerate(figures):
        for _placement in place_figure(field, figure):
            # place_figure leaves the figure on the board while it is
            # suspended, so this snapshot is the board *with* the figure.
            rows.append(np.append(field.flatten(), one_hot[i]))
    if not rows:
        # No figure fits anywhere, hence there is nothing to search.
        print(clock() - started)
        return
    full = np.array(rows)

    # Drop duplicate rows (e.g. the rotations of a rotation-symmetric
    # figure) by viewing each row as one opaque fixed-size record.
    row_type = np.dtype((np.void, full.dtype.itemsize * full.shape[1]))
    keep = np.unique(full.view(row_type), return_index=True)[1]
    unique_full = full[keep]

    # Only cells that are free on the original board need covering,
    # followed by one column per figure.
    n_cells = field.shape[0] * field.shape[1]
    columns = [i for i, value in enumerate(field.flatten()) if not value]
    columns += [n_cells + k for k in range(n_figures)]

    for cover in solve_base(unique_full[:, columns]):
        if show:
            # Map the reduced column indices back to full-matrix indices.
            base_solution = [[columns[c] for c in s] for s in cover]
            print(show_solution(field, base_solution))
    print(clock() - started)
# Demo puzzle: the pieces cover 9 + 4 + 4 * 3 = 25 cells, exactly the
# number of cells on the 5x5 board below.
figures = [
    # 9-cell corner: a full top row plus the left column underneath it.
    array([[1, 1, 1, 1, 1]] + [[1, 0, 0, 0, 0]] * 4),
    # 2x2 square.
    array([[1, 1]] * 2),
] + 4 * [
    # 3-cell corner; the list refers to this one array four times.
    array([[0, 1],
           [1, 1]]),
]

# Empty 5x5 board; 0 marks a free cell.
field = array([[0] * 5] * 5)

# Runs the search at import time and prints the elapsed time.
solve(field, figures)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment