Skip to content

Instantly share code, notes, and snippets.

@SiLiKhon
Created March 8, 2020 13:11
Show Gist options
  • Save SiLiKhon/ac2d3565db8a2bd31dbd9ff3d0c8eab7 to your computer and use it in GitHub Desktop.
Save SiLiKhon/ac2d3565db8a2bd31dbd9ff3d0c8eab7 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.widgets as w
from scipy.signal import convolve2d
def check(state):
    """Evaluate a board: count completed five-in-a-rows and pick the computer's move.

    Parameters
    ----------
    state : 2-D array-like of ints
        Board cells. In this script 0 is empty, 1 a human stone and -1 a
        computer stone (that is what ``onclick`` writes). Must be at least
        5 x 5 for the 'valid' convolutions below.

    Returns
    -------
    (wins, losses, (best_y, best_x))
        ``wins`` / ``losses``: number of 5-cell windows (vertical, horizontal,
        diagonal, anti-diagonal) whose cells sum to +5 / -5, i.e. are all 1 /
        all -1 when cells hold only -1, 0 or 1.
        ``best_y, best_x``: row and column (numpy integer scalars) of the cell
        with the highest "potential"; ties go to the first cell in row-major
        order. If the board has no empty cell this is an occupied cell.
    """
    state = np.asarray(state)
    length = 5  # stones in a row needed to win
    last = length - 1
    base = [1] * length
    filters = [
        np.array(base).reshape(length, 1),  # vertical
        np.array(base).reshape(1, length),  # horizontal
        np.diag(base),                      # main diagonal
        np.diag(base)[::-1],                # anti-diagonal
    ]
    # One entry per `length`-cell window in each direction: the sum of its cells.
    checks = [convolve2d(state, f, mode='valid') for f in filters]

    def nonlin(num):
        # Non-negative integer score growing like |num| ** 4. The 4.1 exponent for
        # non-positive sums weighs computer lines (-1 stones) slightly more than
        # equally long human lines when |num| > 3, and slightly less when |num| < 3.
        return ((np.abs(num) / 3) ** np.where(num > 0, 4, 4.1) * 100).astype(int)

    # NOTE(review): a window holding both colours is scored by its net sum, so
    # e.g. 4 human + 1 computer stone scores like an open 3 — confirm intended.
    scores = [nonlin(c) for c in checks]  # loop-invariant: compute once

    potential = np.zeros(state.shape, dtype=int)
    none = (0, 0)
    for i in range(length):
        # Padding that shifts each window's score onto the i-th cell it covers,
        # so every window adds its score to all five of its cells.
        lead, trail = (i, last - i), (last - i, i)
        pads = [(lead, none), (none, lead), (lead, lead), (trail, lead)]
        for score, pad in zip(scores, pads):
            potential += np.pad(score, pad, mode='constant')

    # Scores are >= 0, so occupied cells are never chosen while an empty one exists.
    potential[state != 0] = -1000
    # argmax: O(n) and deterministic (first maximum), unlike taking the tail of argsort.
    best_y, best_x = np.unravel_index(np.argmax(potential), potential.shape)

    wins = sum(int((c == length).sum()) for c in checks)
    losses = sum(int((c == -length).sum()) for c in checks)
    return wins, losses, (best_y, best_x)
fig = plt.figure(figsize=(5, 5))
# One axes covering the whole figure; the board is drawn in data coords [0, 1] x [0, 1].
ax = plt.axes([0., 0., 1., 1.])

# NOTE(review): rebinding `w` shadows the `matplotlib.widgets as w` import; that
# module is not used anywhere visible in this script.
h, w = 20, 20

# Board contents: onclick writes 1 for the human's stones and -1 for the computer's.
state = np.zeros((h, w), dtype=int)

# One white patch per board cell, kept in row-major order so that
# rects[iy * w + ix] is the patch for state[iy, ix].
rects = []
for iy in range(h):
    for ix in range(w):
        rects.append(plt.Rectangle((ix / w, iy / h), 1. / w, 1 / h, facecolor='white'))
        ax.add_patch(rects[-1])

# Grid: w + 1 vertical lines, then h + 1 horizontal lines, on the current axes.
for col in range(w + 1):
    ax.plot([col / w, col / w], [0., 1.], lw=1, c='black')
for row in range(h + 1):
    ax.plot([0., 1.], [row / h, row / h], lw=1, c='black')
def onclick(event):
    """Handle a mouse click: place the human's stone, then the computer's reply.

    The clicked cell (from ``event.xdata`` / ``event.ydata`` in the [0, 1) x [0, 1)
    data coordinates the board is drawn in) receives a 1. If the human has not
    just won and an empty cell remains, the cell chosen by ``check`` receives
    a -1. The win/loss counts are then printed and every cell patch is
    recoloured (white empty, red > 0, blue < 0).

    Clicks are ignored when they fall outside the board, land on an occupied
    cell, or arrive after either side has already completed five in a row.
    """
    if event.xdata is None or event.ydata is None:
        return
    x, y = event.xdata, event.ydata
    # Half-open range: x == 0 is the first column; x == 1 is past the last one.
    if not (0 <= x < 1 and 0 <= y < 1):
        return

    # Stop accepting moves once the game has been decided.
    win, lose, _ = check(state)
    if win or lose:
        return

    # min() guards against x * w rounding up to w for x just below 1.
    ix = min(int(np.floor(x * w)), w - 1)
    iy = min(int(np.floor(y * h)), h - 1)
    if state[iy, ix] != 0:
        return
    state[iy, ix] = 1

    win, lose, (best_y, best_x) = check(state)
    # Only reply if the human has not just won and a free cell is left;
    # on a full board check() points at an occupied cell.
    if not win and (state == 0).any():
        if state[best_y, best_x] != 0:
            raise RuntimeError('check() proposed an occupied cell')
        state[best_y, best_x] = -1
        # Re-evaluate so a winning computer move is reported immediately.
        win, lose, _ = check(state)
    print(win, lose)

    for (cy, cx), val in np.ndenumerate(state):
        rects[cy * w + cx].set_facecolor('white' if val == 0 else ('red' if val > 0 else 'blue'))
    plt.draw()
# Route mouse-button presses on the figure to onclick, then enter the GUI loop.
click_cid = fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment