import time | |
time_start = time.time() | |
import numpy as np | |
import chainer | |
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils | |
from chainer import Link, Chain, ChainList | |
import chainer.functions as F | |
import chainer.links as L | |
import sys | |
import copy | |
import commands | |
class MyChain(Chain):
    """Fully connected network 65 -> 256 -> 256 -> 1.

    Both hidden layers use a sigmoid activation; the single output unit is
    linear (no activation applied to ``l3``).
    """

    def __init__(self):
        # Links are registered by passing them as keyword arguments to the
        # Chain constructor (older Chainer registration style).
        super(MyChain, self).__init__(
            l1=L.Linear(65, 256),
            l2=L.Linear(256, 256),
            l3=L.Linear(256, 1),
        )

    def __call__(self, x, y):
        """Return the mean squared error between ``fwd(x)`` and target ``y``."""
        prediction = self.fwd(x)
        return F.mean_squared_error(prediction, y)

    def fwd(self, x):
        """Run the forward pass and return the raw output of ``l3``."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.sigmoid(layer(hidden))
        return self.l3(hidden)
class Valid:
    """Scores board positions with a MyChain loaded from 'sigMyChain0.model'."""

    def __init__(self):
        # Weights are read from the HDF5 file in the current working directory.
        net = MyChain()
        serializers.load_hdf5('sigMyChain0.model', net)
        self.model = net

    def validate(self, nowboard, now):
        """Return the network output for position *nowboard* with *now* to move.

        The 8x8 board is flattened row by row into 64 features: a square
        holding 1 becomes 1.0, a square holding 0 becomes -1.0, and any other
        value (elsewhere in this file 2 marks an empty square) becomes 0.0.
        A 65th feature is 1.0 when ``now == 1`` and -1.0 otherwise.

        Returns the ``.data`` of the network output for a batch of one
        (input shape (1, 65)).
        """
        def encode(cell):
            if cell == 1:
                return 1.0
            if cell == 0:
                return -1.0
            return 0.0

        features = [encode(nowboard[r][c]) for r in range(8) for c in range(8)]
        features.append(1.0 if now == 1 else -1.0)
        batch = np.array([features]).astype(np.float32)
        # volatile='on' keeps Chainer (v1 API) from recording the graph.
        output = self.model.fwd(Variable(batch, volatile='on'))
        return output.data
# Unit steps for the eight neighbouring directions: direction k moves the
# first board index by dx[k] and the second by dy[k].
dx, dy = map(list, zip((0, 1), (1, 1), (1, 0), (1, -1),
                       (0, -1), (-1, -1), (-1, 0), (-1, 1)))
# Module-level 8x8 game board shared by the helpers below; every square is 0
# until init_board() resets it.
board = [[0] * 8 for _ in range(8)]
# Two independently loaded copies of the saved network. search() picks one by
# index; make_net() perturbs entry 1 and the __main__ block may save it back.
validator = [Valid() for _ in range(2)]
def search(board, turn, index):
    """Pick a move for player *turn* on *board* using ``validator[index]``.

    Every empty square (value 2) that would flip at least one opponent stone
    is played on a deep copy of *board*. The resulting position is scored
    with ``validator[index].validate(newboard, turn)``. Candidates
    ``[score, row, col]`` are sorted ascending and then reversed when
    ``turn == 1``, so turn 1 prefers high scores and turn 0 low ones.

    Move selection:
    - If only one move is legal, or at least 15 stones are already on the
      board, the top-ranked move is returned.
    - Otherwise the top-ranked move is returned with probability 7/10
      (``np.random.randint(10) <= 6``) and the runner-up otherwise.

    Args:
        board: 8x8 nested list; ``turn`` / ``turn ^ 1`` mark the two players'
            stones and 2 marks an empty square. It is not modified.
        turn: player to move (0 or 1).
        index: position of the scoring model in the module-level
            ``validator`` list.

    Returns:
        ``[row, col]`` of the chosen square.

    Raises:
        ValueError: if *turn* has no legal move on *board*.
    """
    opponent = turn ^ 1
    ret = []
    for i in range(8):
        for j in range(8):
            if board[i][j] != 2:
                continue
            # Collect every opponent stone this placement would capture.
            flips = []
            for step_x, step_y in zip(dx, dy):
                run = []
                x, y = i + step_x, j + step_y
                while 0 <= x < 8 and 0 <= y < 8 and board[x][y] == opponent:
                    run.append((x, y))
                    x += step_x
                    y += step_y
                # A run is captured only when our own stone closes it off.
                if run and 0 <= x < 8 and 0 <= y < 8 and board[x][y] == turn:
                    flips.extend(run)
            if not flips:
                continue  # illegal square: nothing would be flipped
            # Only legal squares pay for the board copy.
            newboard = copy.deepcopy(board)
            newboard[i][j] = turn
            for x, y in flips:
                newboard[x][y] = turn
            now_val = validator[index].validate(newboard, turn)
            ret.append([now_val, i, j])
    if not ret:
        raise ValueError('search: no legal move for turn %r' % (turn,))
    stone_count = sum(1 for i in range(8) for j in range(8) if board[i][j] != 2)
    ret.sort()
    if turn == 1:
        ret.reverse()
    # Greedy when there is no alternative or once 15+ stones are down;
    # before that, occasionally explore by taking the runner-up.
    if len(ret) == 1 or stone_count >= 15:
        return [ret[0][1], ret[0][2]]
    if np.random.randint(10) <= 6:
        return [ret[0][1], ret[0][2]]
    return [ret[1][1], ret[1][2]]
def make_net(model=None, ratio=0.1):
    """Perturb a network's weights in place with multiplicative Gaussian noise.

    Every weight and bias value ``w`` of layers ``l1``, ``l2`` and ``l3``
    becomes ``w + N(0, 1) * w * ratio``. In other words, each value is scaled
    by an independent random factor ``1 + ratio * N(0, 1)``. Values that are
    exactly zero therefore stay zero.

    Args:
        model: object exposing ``l1``, ``l2``, ``l3`` links whose ``W.data``
            and ``b.data`` are numpy arrays. Defaults to
            ``validator[1].model``, the target this function always used.
        ratio: relative noise scale; the default 0.1 matches the original
            ``randn() * w / 10.0`` update.
    """
    if model is None:
        model = validator[1].model
    for layer in (model.l1, model.l2, model.l3):
        for param in (layer.W, layer.b):
            values = param.data
            # One vectorised draw per array (in place, so the link keeps the
            # same array object) instead of a Python loop over every scalar;
            # the shape comes from the array rather than hard-coded sizes.
            values += np.random.randn(*values.shape) * values * ratio
def init_board():
    """Reset the module-level board in place to the four-stone opening.

    All 64 squares are set to 2 (empty). Then 0 is placed on (3, 3) and
    (4, 4), and 1 on (3, 4) and (4, 3). The row lists themselves are reused,
    not replaced.
    """
    for r in range(8):
        board[r][0:8] = [2] * 8
    for r, c, stone in ((3, 3, 0), (3, 4, 1), (4, 3, 1), (4, 4, 0)):
        board[r][c] = stone
def change_board(x, y, turn):
    """Play *turn* at (x, y) on the module-level board and flip captures.

    The stone is written unconditionally; legality is not checked here.
    In each of the eight directions, a run of one or more ``turn ^ 1``
    stones that ends on a *turn* stone is flipped to *turn*.
    """
    board[x][y] = turn
    enemy = turn ^ 1
    for step_x, step_y in zip(dx, dy):
        cx, cy = x + step_x, y + step_y
        if not (0 <= cx < 8 and 0 <= cy < 8) or board[cx][cy] != enemy:
            continue
        # Walk across the run of enemy stones.
        while 0 <= cx < 8 and 0 <= cy < 8 and board[cx][cy] == enemy:
            cx += step_x
            cy += step_y
        if 0 <= cx < 8 and 0 <= cy < 8 and board[cx][cy] == turn:
            # Closed by our own stone: walk back and flip up to (x, y).
            cx -= step_x
            cy -= step_y
            while (cx, cy) != (x, y):
                board[cx][cy] = turn
                cx -= step_x
                cy -= step_y
def can_put(x, y, turn):
    """Return True if *turn* may legally play at (x, y) on the module-level board.

    The square must hold 2 (empty). In addition, in at least one of the eight
    directions a run of one or more ``turn ^ 1`` stones must end on a *turn*
    stone.
    """
    if board[x][y] != 2:
        return False
    enemy = turn ^ 1
    for step_x, step_y in zip(dx, dy):
        cx, cy = x + step_x, y + step_y
        if not (0 <= cx < 8 and 0 <= cy < 8) or board[cx][cy] != enemy:
            continue
        while 0 <= cx < 8 and 0 <= cy < 8 and board[cx][cy] == enemy:
            cx += step_x
            cy += step_y
        if 0 <= cx < 8 and 0 <= cy < 8 and board[cx][cy] == turn:
            return True
    return False
def change_turn(turn):
    """Report who moves after *turn* has played.

    Returns:
        1 if the opponent (``turn ^ 1``) has a legal move,
        2 if the opponent has none but *turn* still does,
        0 if neither side can move.
    """
    def has_move(player):
        # Row-major scan that stops at the first legal square.
        return any(can_put(r, c, player) for r in range(8) for c in range(8))

    if has_move(turn ^ 1):
        return 1
    return 2 if has_move(turn) else 0
if __name__ == '__main__':
    # Evaluation run: validator[1] is perturbed by make_net(), then the two
    # networks play up to MATCH_TIMES games with alternating colours. Each
    # game is played until neither side can move or at most 15 squares are
    # empty. The position is then handed to ./complete_read/a.out, whose
    # output decides the winner. If validator[1] wins at least UPDATE_WIN
    # games, its weights overwrite sigMyChain0.model.
    make_net()
    MATCH_TIMES = 100
    UPDATE_WIN = 80
    SHOW_BOARD = False  # set True to print each position sent to the solver
    first_win = 0   # games won by validator[0]
    second_win = 0  # games won by validator[1]
    for train_time in range(MATCH_TIMES):
        init_board()
        turn = 1  # player 1 moves first
        # validator[0] plays the colour equal to train_time % 2, so the two
        # networks swap colours every game.
        first_side = train_time % 2
        while True:
            player = 0 if turn == first_side else 1
            hand = search(board, turn, player)
            change_board(hand[0], hand[1], turn)
            tmp = change_turn(turn)
            if tmp == 0:  # neither side can move: game over
                break
            elif tmp == 1:  # opponent moves next; on 2 the same side moves again
                turn ^= 1
            empty_area = sum(1 for i in range(8) for j in range(8)
                             if board[i][j] == 2)
            if empty_area <= 15:
                break
        if SHOW_BOARD:
            for i in range(8):
                for j in range(8):
                    cell = board[i][j]
                    sys.stdout.write('*' if cell == 1 else ('_' if cell == 2 else 'o'))
                print('')
            print('')
        # Encode the position as two bitboards: bit i*8+j is set when square
        # (i, j) holds a 1 (black) or a 0 (white) stone respectively.
        black = 0
        white = 0
        for i in range(8):
            for j in range(8):
                if board[i][j] == 1:
                    black |= 1 << (i * 8 + j)
                elif board[i][j] == 0:
                    white |= 1 << (i * 8 + j)
        with open('input.txt', 'w') as f:
            f.write('%d %d %d\n' % (black, white, turn))
        # NOTE(review): assumes the solver prints exactly "black" or "white";
        # any other output (e.g. a draw) is credited to neither network.
        result = commands.getoutput("./complete_read/a.out < input.txt")
        winner_side = {"black": 1, "white": 0}.get(result)
        if winner_side is not None:
            if winner_side == first_side:
                first_win += 1
            else:
                second_win += 1
        # Stop early once validator[1] can no longer reach UPDATE_WIN wins.
        if first_win > MATCH_TIMES - UPDATE_WIN:
            break
    print("final " + str(first_win) + " " + str(second_win))
    if second_win >= UPDATE_WIN:
        serializers.save_hdf5('sigMyChain0.model', validator[1].model)
        print("updated!")
    time_end = time.time()
    sys.stderr.write("train.py : " + str(time_end - time_start) + "\n")
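# NOTE: the lines below are a GitHub page footer captured along with this
# script; they are kept as comments so the file still parses.
# Sign up for free
# to join this conversation on GitHub.
# Already have an account?
# Sign in to comment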