This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
time_start = time.time() | |
import numpy as np | |
import chainer | |
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils | |
from chainer import Link, Chain, ChainList | |
import chainer.functions as F | |
import chainer.links as L | |
import sys | |
import copy | |
import commands | |
class MyChain(Chain):
    """Fully connected value network used to score board positions.

    A 65-element input goes through two 256-unit sigmoid layers and then a
    linear layer with a single output.
    """

    def __init__(self):
        # Child links are registered as keyword arguments to Chain.__init__
        # (the older Chainer API this file is written against).
        super(MyChain, self).__init__(
            l1=L.Linear(65, 256),
            l2=L.Linear(256, 256),
            l3=L.Linear(256, 1),
        )

    def __call__(self, x, y):
        """Return the mean-squared error between ``fwd(x)`` and targets ``y``."""
        prediction = self.fwd(x)
        return F.mean_squared_error(prediction, y)

    def fwd(self, x):
        """Run the forward pass; the output of ``l3`` is returned un-squashed."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.sigmoid(layer(hidden))
        return self.l3(hidden)
class Valid:
    """Board evaluator backed by a MyChain loaded from 'sigMyChain0.model'."""

    def __init__(self):
        self.model = MyChain()
        # Every instance reads the same checkpoint file, so all start identical.
        serializers.load_hdf5('sigMyChain0.model', self.model)

    def validate(self, nowboard, now):
        """Score ``nowboard`` with the network.

        The 8x8 grid is read row-major. A cell equal to 1 becomes 1.0, a cell
        equal to 0 becomes -1.0, and anything else becomes 0.0. A 65th feature
        is 1.0 when ``now == 1`` and -1.0 otherwise.

        Returns:
            The ``.data`` of the network output for a batch of one position.
        """
        def encode(cell):
            if cell == 1:
                return 1.0
            if cell == 0:
                return -1.0
            return 0.0

        features = [encode(nowboard[r][c]) for r in range(8) for c in range(8)]
        features.append(1.0 if now == 1 else -1.0)
        batch = np.array([features]).astype(np.float32)
        # Chainer v1 volatile flag: inference only, no backward graph is kept.
        output = self.model.fwd(Variable(batch, volatile='on'))
        return output.data
# The eight neighbour offsets; direction k is (dx[k], dy[k]) as (row, column) deltas.
_DIRECTIONS = [(0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1), (-1, 0), (-1, 1)]
dx = [d[0] for d in _DIRECTIONS]
dy = [d[1] for d in _DIRECTIONS]
# Shared 8x8 board, filled with zeros until init_board() resets it.
board = [[0] * 8 for _ in range(8)]
# Two evaluators loaded from the same checkpoint; make_net() perturbs validator[1].
validator = [Valid() for _ in range(2)]
def _flipped_board(board, i, j, turn):
    """Return a copy of ``board`` after ``turn`` plays (i, j), or None if illegal.

    A direction captures when the adjacent square holds an opponent stone
    (``turn ^ 1``) and that run of opponent stones ends on a ``turn`` stone
    inside the 8x8 grid. ``board`` itself is never modified.
    """
    opponent = turn ^ 1
    newboard = None
    for k in range(8):
        x, y = i + dx[k], j + dy[k]
        run = []
        while 0 <= x < 8 and 0 <= y < 8 and board[x][y] == opponent:
            run.append((x, y))
            x += dx[k]
            y += dy[k]
        if run and 0 <= x < 8 and 0 <= y < 8 and board[x][y] == turn:
            if newboard is None:
                # Copy lazily: most empty squares are illegal and need no copy.
                newboard = [list(row) for row in board]
                newboard[i][j] = turn
            for fx, fy in run:
                newboard[fx][fy] = turn
    return newboard


def search(board, turn, index):
    """Choose a move for ``turn`` by scoring every legal reply with ``validator[index]``.

    Each empty square (value 2) that captures at least one opponent stone is
    played on a copy of the board. The copy is scored with
    ``validator[index].validate(newboard, turn)``. Higher scores are preferred
    for turn 1 and lower scores for turn 0.

    While fewer than 15 stones are on the board and more than one move exists,
    the second-best move is returned with probability 3/10 (presumably for
    exploration).

    Returns:
        [row, col] of the chosen move.

    Raises:
        IndexError: if ``turn`` has no legal move.
    """
    candidates = []
    for i in range(8):
        for j in range(8):
            if board[i][j] != 2:
                continue
            newboard = _flipped_board(board, i, j, turn)
            if newboard is None:
                continue
            # validate() yields a single-element array; keep a plain float so the
            # sort below never compares NumPy arrays.
            score = np.asarray(validator[index].validate(newboard, turn)).item()
            candidates.append((score, i, j))
    if not candidates:
        raise IndexError("search: no legal move for turn %d" % turn)

    stone_count = sum(1 for row in board for cell in row if cell != 2)
    # Stable sort on the score alone. Candidates were appended in row-major
    # order, so ties resolve exactly as sorting on (score, i, j) would.
    candidates.sort(key=lambda c: c[0])
    if turn == 1:
        candidates.reverse()

    best = [candidates[0][1], candidates[0][2]]
    if len(candidates) == 1 or stone_count >= 15:
        return best
    # randint(10) <= 6 keeps the best move 7 times out of 10.
    if np.random.randint(10) <= 6:
        return best
    return [candidates[1][1], candidates[1][2]]
def make_net(model=None):
    """Randomly perturb every weight and bias of a network in place.

    Each parameter value ``w`` becomes ``w + w * N(0, 1) / 10``: multiplicative
    Gaussian noise with a standard deviation of ``|w| / 10``. Entries that are
    exactly zero therefore stay zero. Each weight and each bias element is
    perturbed once.

    Args:
        model: object with links ``l1``, ``l2`` and ``l3``, each having ``W``
            and ``b`` parameters whose ``.data`` is a NumPy array. Defaults to
            ``validator[1].model``.
    """
    if model is None:
        model = validator[1].model
    for link in (model.l1, model.l2, model.l3):
        for param in (link.W, link.b):
            data = param.data
            # One vectorised draw per parameter replaces tens of thousands of
            # scalar randn() calls. It also works for any layer sizes.
            data += np.random.randn(*data.shape) * data / 10.0
def init_board():
    """Reset the shared ``board`` to the Othello starting position.

    Every square becomes 2 (empty). The centre then gets 0 at (3, 3) and
    (4, 4), and 1 at (3, 4) and (4, 3).
    """
    for r in range(8):
        board[r][0:8] = [2] * 8
    for (r, c), colour in (((3, 3), 0), ((3, 4), 1), ((4, 3), 1), ((4, 4), 0)):
        board[r][c] = colour
def change_board(x, y, turn):
    """Place a ``turn`` stone at (x, y) on the shared board and flip captures.

    Legality is not checked; the stone is placed unconditionally. In each of
    the eight directions, a run of opponent stones (``turn ^ 1``) that ends on
    a ``turn`` stone is converted to ``turn``.
    """
    board[x][y] = turn
    enemy = turn ^ 1
    for step_x, step_y in zip(dx, dy):
        cx, cy = x + step_x, y + step_y
        captured = []
        while 0 <= cx < 8 and 0 <= cy < 8 and board[cx][cy] == enemy:
            captured.append((cx, cy))
            cx += step_x
            cy += step_y
        if captured and 0 <= cx < 8 and 0 <= cy < 8 and board[cx][cy] == turn:
            for px, py in captured:
                board[px][py] = turn
def can_put(x, y, turn):
    """Return True if ``turn`` may legally play the square (x, y) of ``board``.

    The square must be empty (value 2). In at least one of the eight
    directions it must also border a run of one or more opponent stones
    (``turn ^ 1``) that is closed off by a ``turn`` stone.
    """
    if board[x][y] != 2:
        return False
    enemy = turn ^ 1
    for step_x, step_y in zip(dx, dy):
        cx, cy = x + step_x, y + step_y
        run_length = 0
        while 0 <= cx < 8 and 0 <= cy < 8 and board[cx][cy] == enemy:
            run_length += 1
            cx += step_x
            cy += step_y
        if run_length and 0 <= cx < 8 and 0 <= cy < 8 and board[cx][cy] == turn:
            return True
    return False
def change_turn(turn):
    """Decide who moves next after ``turn`` has played.

    Returns:
        1 if the opponent (``turn ^ 1``) has a legal move,
        2 if only ``turn`` can move (the opponent must pass),
        0 if neither side can move.
    """
    squares = [(r, c) for r in range(8) for c in range(8)]
    if any(can_put(r, c, turn ^ 1) for r, c in squares):
        return 1
    if any(can_put(r, c, turn) for r, c in squares):
        return 2
    return 0
if __name__ == '__main__':
    # getoutput() lives in the deprecated `commands` module on Python 2 and in
    # `subprocess` on Python 3. Both run the string through the shell and
    # return combined stdout+stderr with one trailing newline stripped.
    try:
        from subprocess import getoutput
    except ImportError:  # Python 2
        from commands import getoutput

    make_net()
    MATCH_TIMES = 100
    UPDATE_WIN = 80
    first_win = 0
    second_win = 0
    for train_time in range(MATCH_TIMES):
        init_board()
        turn = 1
        while True:
            # validator[0] plays the colour equal to train_time % 2 and
            # validator[1] plays the other, so the networks swap colours each game.
            index = 0 if turn == train_time % 2 else 1
            hand = search(board, turn, index)
            change_board(hand[0], hand[1], turn)
            tmp = change_turn(turn)
            if tmp == 0:
                break
            elif tmp == 1:
                turn ^= 1
            # Stop self-play once 15 or fewer squares are empty.
            empty_area = sum(1 for row in board for cell in row if cell == 2)
            if empty_area <= 15:
                break

        # Encode the position for ./complete_read/a.out: two bitmasks (bit
        # i*8+j set for a stone at row i, column j) plus the side to move.
        # Python ints grow as needed, so the 64-bit masks need no `L` suffix.
        black = 0
        white = 0
        for i in range(8):
            for j in range(8):
                bit = 1 << (i * 8 + j)
                if board[i][j] == 1:
                    black |= bit
                elif board[i][j] == 0:
                    white |= bit
        with open('input.txt', 'w') as f:
            f.write(str(black) + ' ' + str(white) + ' ' + str(turn) + '\n')
        result = getoutput("./complete_read/a.out < input.txt")

        # validator[0] is black (1) in odd games and white (0) in even games.
        # Any output other than "black"/"white" counts for neither side.
        first_is_black = train_time % 2 == 1
        if result == "black":
            if first_is_black:
                first_win += 1
            else:
                second_win += 1
        elif result == "white":
            if first_is_black:
                second_win += 1
            else:
                first_win += 1
        if first_win > MATCH_TIMES - UPDATE_WIN:
            # validator[1] can no longer reach UPDATE_WIN wins; skip remaining games.
            break
    print("final " + str(first_win) + " " + str(second_win))
    if second_win >= UPDATE_WIN:
        # The perturbed network won often enough: overwrite the stored model.
        serializers.save_hdf5('sigMyChain0.model', validator[1].model)
        print("updated!")
# Report the script's total wall-clock time (since time_start) on stderr.
time_end = time.time()
sys.stderr.write("".join(["train.py : ", str(time_end - time_start), "\n"]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment