Skip to content

Instantly share code, notes, and snippets.

@tamuhey
Last active February 21, 2020 12:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tamuhey/80cbf8b74a9fc565a3bf76092b7c4052 to your computer and use it in GitHub Desktop.
Save tamuhey/80cbf8b74a9fc565a3bf76092b7c4052 to your computer and use it in GitHub Desktop.
量子アニーリングで数独を解く (Solving Sudoku with quantum annealing) — ref: https://qiita.com/tamurahey/items/105f1dc9ee9a3bc01f15
pip install dwave-qbsolv
pip install networkx
# Peer graph: one node per cell and one edge between every pair of peers.
g = nx.Graph()
g.add_edges_from((i, j) for i in idxs for j in peers[i])
# Place the n**2 cells evenly on the unit circle and draw the graph.
theta = np.arange(n**2)
x, y = (f(2 * np.pi * theta / n**2) for f in (np.cos, np.sin))
pos = dict(enumerate(zip(x, y)))
nx.draw_networkx(g, pos=pos)
# Hand-picked weight of the "exactly one digit per cell" penalty.
alpha = 1.
# QUBO coefficients indexed as qubo[cell_a, digit_a, cell_b, digit_b].
qubo = np.zeros((n**2, n, n**2, n))
# For binary x, (sum_k x[c, k] - 1)**2 equals, up to a constant,
# sum_{k != l} x[c, k] x[c, l] - sum_k x[c, k], i.e. the block ones - 2*eye.
cell_penalty = alpha * (np.ones((n, n)) - 2 * np.eye(n))
for c in range(n**2):
    qubo[c, :, c, :] += cell_penalty
# Use the peer graph g: two peer cells must not share a digit, so add +1 to
# qubo[i, k, j, k] for every digit k.
for i, j in g.edges:
    qubo[i, np.arange(n), j, np.arange(n)] += 1
# Initial values (clues) of puzzle `a`: digit index per cell, -1 where empty.
init_values = parse(a)
# Reward for keeping each clue's digit switched on.
beta = 10.
# Vectorised over all clue cells (the original line had a stray trailing
# colon, a syntax error): i holds the clue cells, j their digit indices.
i = np.where(init_values >= 0)[0]
j = init_values[i]
# Overwrites (does not add to) the diagonal term qubo[i, j, i, j].
qubo[i, j, i, j] = -beta
# Flat variable v = cell * n + digit; idx[0][v] and idx[1][v] are the cell
# and the digit of flat variable v.
idx = np.unravel_index(np.arange(n**3), (n**2, n))
# Collapse the 4-D tensor to a square matrix over flat variables, shape
# (729, 729) for n = 9; copied so later edits to qubo do not leak into _Q.
_Q = qubo.reshape(n**3, n**3).copy()
# Pass only the nonzero coefficients as a sparse {(u, v): bias} mapping.
Q = {uv: _Q[uv] for uv in zip(*np.nonzero(_Q))}
res = QBSolv().sample_qubo(Q)
def get_answer(init_values, dct):
    """Decode a QUBO sample into one digit index per Sudoku cell.

    Parameters
    ----------
    init_values : array-like of int, shape (n**2,)
        Clue digit index (0 .. n-1) per cell, -1 for an empty cell.
        Clue cells always keep their clue value.
    dct : mapping {int: 0 or 1}
        Solver sample. Key ``cell * n + digit`` is 1 when that digit is
        selected for that cell; missing keys count as 0.

    Returns
    -------
    numpy.ndarray of int, shape (n**2,)
        Digit index per cell. A non-clue cell whose sample does not select
        exactly one digit decodes to -1 (unfilled), instead of silently
        reading as digit 0 (no bit set) or as a sum of digits (several set).

    Raises
    ------
    ValueError
        If the number of cells is not a perfect square.
    """
    init_values = np.asarray(init_values)
    num_cells = init_values.size
    # An n x n grid has n**2 cells and n candidate digits per cell.
    num_digits = int(round(np.sqrt(num_cells)))
    if num_digits * num_digits != num_cells:
        raise ValueError(f"expected n**2 cells, got {num_cells}")

    # bits[cell, digit] = value of flat variable cell * num_digits + digit.
    bits = np.zeros((num_cells, num_digits), dtype=int)
    for var, bit in dct.items():
        bits[divmod(var, num_digits)] = bit

    # Only a one-hot row is a valid assignment; everything else is unfilled.
    values = np.where(bits.sum(axis=1) == 1, bits.argmax(axis=1), -1)
    clues = init_values >= 0
    values[clues] = init_values[clues]
    return values
# Decode the first sample in the solver's result into one digit index per cell.
# NOTE(review): assumes the first sample is the best (lowest-energy) one —
# confirm against the ordering guaranteed by the sampler's response object.
answer=get_answer(init_values,list(res.samples())[0])
from dwave_qbsolv import QBSolv
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
n=9 # grid side length: number of rows, columns and candidate digits; the grid has n**2 cells
def check_answer(answer, peer_table=None):
    """Return True if *answer* is a complete, conflict-free Sudoku grid.

    Parameters
    ----------
    answer : array-like of int, shape (n**2,)
        Digit index (0 .. n-1) of every cell; -1 marks an unfilled cell.
    peer_table : array-like of int, shape (n**2, num_peers), optional
        ``peer_table[i]`` lists the cells that share a row, column or block
        with cell ``i``. Defaults to the module-level ``peers`` table.
    """
    answer = np.asarray(answer)
    table = peers if peer_table is None else np.asarray(peer_table)
    # An n x n grid has n**2 cells and digit indices 0 .. n-1.
    num_digits = int(round(np.sqrt(answer.size)))
    # Unfilled cells (-1) are wrong. Out-of-range values (e.g. a cell decoded
    # as the sum of several selected digits) are never a valid solution either.
    if np.any((answer < 0) | (answer >= num_digits)):
        return False
    # Correct iff every cell differs from all of its peers.
    return not np.any(answer[table] == answer[:, None])
# Print whether the decoded grid is complete and free of peer conflicts.
print(check_answer(answer))
def display(values):
    """Print the grid as 1-based digits, with '|' and a dashed rule between 3x3 blocks."""
    board = values.reshape(n, n)
    rule = "-" * 19
    for r, row in enumerate(board):
        cells = [
            str(int(v) + 1) + ("|" if c in (2, 5) else "")
            for c, v in enumerate(row)
        ]
        print(*cells)
        if r in (2, 5):
            print(rule)
# Pretty-print the decoded grid.
display(answer)
4 8 3| 9 2 1| 6 5 7
9 6 7| 3 4 5| 8 2 1
2 5 1| 8 7 6| 4 9 3
-------------------
5 4 8| 1 3 2| 9 7 6
7 2 9| 5 6 4| 1 3 8
1 3 6| 7 9 8| 2 4 5
-------------------
3 7 2| 6 8 9| 5 1 4
8 1 4| 2 5 3| 7 6 9
6 9 5| 4 1 7| 3 8 2
# Cell indices 0 .. n**2-1, numbered row by row.
idxs = np.arange(n**2)
grid = np.arange(n**2).reshape(n, n)
# Side length of one block, derived from n instead of hard-coding 3.
box = int(round(np.sqrt(n)))
if box * box != n:
    raise ValueError(f"n must be a perfect square, got {n}")
# Indices of the cells of each block; row b*box + c lists block (b, c) row by row.
block = np.stack([
    grid[b * box:(b + 1) * box, c * box:(c + 1) * box].flatten()
    for b in range(box)
    for c in range(box)
])
# Indices of every unit: n rows, then n columns, then n blocks; shape (3*n, n).
unitlist = np.concatenate([grid, grid.T, block])
def get_units(s):
    """Return the rows of unitlist (row, column, block units) that contain cell(s) s."""
    contains_s = np.isin(unitlist, s).any(axis=-1)
    return unitlist[contains_s]


# Units of every cell (its row, its column and its block); shape (81, 3, 9).
units = np.array([get_units(cell) for cell in range(n**2)])
def get_peers(s):
    """Return, in ascending order, the cells sharing a unit with cell s (excluding s)."""
    neighbourhood = np.unique(get_units(s))
    return neighbourhood[neighbourhood != s].tolist()


# peers[i] lists the peers of the i-th cell; shape (81, 20) for a 9x9 grid.
peers = np.array([get_peers(cell) for cell in idxs])
# Sample puzzle: 81 characters in row-major order, '0' for an empty cell.
a= '003020600900305001001806400008102900700000008006708200002609500800203009005010300'
# Zero-based digit indices 0 .. n-1 (digit d is written as d + 1).
digits = np.arange(n)
def parse(grid):
    """Convert an 81-character puzzle string into zero-based digit indices.

    Characters '1'..'9' become digit indices 0..8. Any other character
    (e.g. '0' or '.') marks an empty cell and becomes -1.

    Raises
    ------
    ValueError
        If *grid* does not have exactly 81 characters (a 9x9 grid).
    """
    if len(grid) != 81:
        raise ValueError(f"expected 81 cells, got {len(grid)}")
    # Build the set of valid digit characters once, not once per cell.
    valid = {str(d) for d in range(1, 10)}
    return np.array([int(c) - 1 if c in valid else -1 for c in grid])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment