Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
import time
time_start = time.time()
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import sys
import copy
import commands
class MyChain(Chain):
    """Board-evaluation MLP: 65 inputs -> 256 -> 256 -> 1.

    Both hidden layers use a sigmoid activation; the output layer is linear.
    """

    def __init__(self):
        super(MyChain, self).__init__(
            l1=L.Linear(65, 256),
            l2=L.Linear(256, 256),
            l3=L.Linear(256, 1),
        )

    def __call__(self, x, y):
        """Return the mean-squared-error loss of ``fwd(x)`` against target ``y``."""
        prediction = self.fwd(x)
        return F.mean_squared_error(prediction, y)

    def fwd(self, x):
        """Run the forward pass and return the raw (linear) output of ``l3``."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.sigmoid(layer(hidden))
        return self.l3(hidden)
class Valid:
    """Position evaluator backed by a ``MyChain`` network.

    The weights are restored from ``sigMyChain0.model`` (HDF5) at construction.
    """

    def __init__(self):
        net = MyChain()
        serializers.load_hdf5('sigMyChain0.model', net)
        self.model = net

    def validate(self, nowboard, now):
        """Score ``nowboard`` with side ``now`` to move.

        The network input is a single 65-element float32 row:
        the 64 squares in row-major order (1 -> 1.0, 0 -> -1.0, anything
        else -> 0.0), followed by 1.0 if ``now == 1`` else -1.0.

        Returns:
            ``.data`` of the network output for that single row.
        """
        def encode(cell):
            if cell == 1:
                return 1.0
            if cell == 0:
                return -1.0
            return 0.0

        features = [encode(nowboard[r][c]) for r in range(8) for c in range(8)]
        features.append(1.0 if now == 1 else -1.0)
        batch = np.array([features]).astype(np.float32)
        # volatile='on' is the Chainer v1 inference-only flag.
        return self.model.fwd(Variable(batch, volatile='on')).data
# Row/column offsets of the eight neighbouring directions; dx[k] pairs with dy[k].
dx = [0, 1, 1, 1, 0, -1, -1, -1]
dy = [1, 1, 0, -1, -1, -1, 0, 1]
# Shared 8x8 game board (rows are independent lists), filled by init_board().
board = [[0] * 8 for _ in range(8)]
# Two independent evaluators, each loading its own copy of the saved network.
validator = [Valid() for _ in range(2)]
def _flips_for_move(board, i, j, turn):
    """Return the (x, y) opponent stones flipped if ``turn`` plays at (i, j).

    An empty list means the move captures nothing, i.e. it is illegal.
    """
    enemy = turn ^ 1
    flips = []
    for step_i, step_j in zip(dx, dy):
        x, y = i + step_i, j + step_j
        run = []
        # Collect the contiguous line of opponent stones in this direction.
        while 0 <= x < 8 and 0 <= y < 8 and board[x][y] == enemy:
            run.append((x, y))
            x += step_i
            y += step_j
        # The line is captured only when it is closed by one of our own stones.
        if run and 0 <= x < 8 and 0 <= y < 8 and board[x][y] == turn:
            flips.extend(run)
    return flips


def search(board, turn, index):
    """Choose a move for ``turn`` using evaluator ``validator[index]``.

    Every legal move is applied to a copy of ``board`` and scored with
    ``validator[index].validate``. Turn 1 prefers the highest score and
    turn 0 the lowest. Ties are broken by (row, col).

    While fewer than 15 stones are on the board and at least two moves are
    legal, the runner-up move is played with probability 3/10 for variety.
    Otherwise the best move is always played.

    Args:
        board: 8x8 list of lists (2 = empty, 0/1 = stones); not modified.
        turn: side to move (0 or 1).
        index: which entry of the global ``validator`` list to use.

    Returns:
        ``[row, col]`` of the chosen move.

    Raises:
        ValueError: if ``turn`` has no legal move.
    """
    candidates = []
    for i in range(8):
        for j in range(8):
            if board[i][j] != 2:
                continue
            flips = _flips_for_move(board, i, j, turn)
            if not flips:
                continue
            # Copy the board only for legal moves.
            newboard = copy.deepcopy(board)
            newboard[i][j] = turn
            for x, y in flips:
                newboard[x][y] = turn
            score = validator[index].validate(newboard, turn)
            # Sort on a plain float rather than relying on truth-testing the
            # numpy array returned by validate().
            candidates.append((float(np.ravel(score)[0]), i, j))

    if not candidates:
        raise ValueError("search: no legal move for turn %d" % turn)

    stone_count = sum(1 for row in board for cell in row if cell != 2)
    # (row, col) makes every key unique, so reverse sorting equals sort+reverse.
    candidates.sort(reverse=(turn == 1))

    if len(candidates) == 1 or stone_count >= 15:
        return [candidates[0][1], candidates[0][2]]
    if np.random.randint(10) <= 6:  # 7 out of 10
        return [candidates[0][1], candidates[0][2]]
    return [candidates[1][1], candidates[1][2]]
def make_net(model=None, rel_std=0.1):
    """Randomly perturb a network's weights in place (the "mutation" step).

    Each weight and bias ``w`` of links ``l1``, ``l2`` and ``l3`` becomes
    ``w + eps * w * rel_std`` with ``eps ~ N(0, 1)`` drawn independently per
    element. This is multiplicative Gaussian noise, so zero entries stay
    zero. The default ``rel_std=0.1`` matches the original ``/10.0``.

    Args:
        model: chain whose ``l1``/``l2``/``l3`` links are perturbed.
            Defaults to ``validator[1].model``, as before.
        rel_std: relative standard deviation of the noise.
    """
    if model is None:
        model = validator[1].model
    for layer in (model.l1, model.l2, model.l3):
        for param in (layer.W, layer.b):
            data = param.data
            # One vectorized draw per parameter array instead of a Python
            # loop with a scalar randn() call per element; shapes come from
            # the arrays themselves rather than hard-coded layer sizes.
            data += np.random.randn(*data.shape) * data * rel_std
def init_board():
    """Reset the global ``board`` in place to the opening position.

    Every square becomes 2 (empty). The centre four squares then get
    0 on (3,3)/(4,4) and 1 on (3,4)/(4,3).
    """
    for row in board:
        for col in range(8):
            row[col] = 2
    board[3][3] = board[4][4] = 0
    board[3][4] = board[4][3] = 1
def change_board(x, y, turn):
    """Play ``turn`` at (x, y) on the global ``board`` and flip captured stones.

    The stone is placed unconditionally; legality is not checked here. In
    each of the eight directions, a contiguous run of opponent stones that
    is closed by a ``turn`` stone is flipped to ``turn``.
    """
    board[x][y] = turn
    enemy = turn ^ 1
    for step_x, step_y in zip(dx, dy):
        px, py = x + step_x, y + step_y
        run = []
        while 0 <= px < 8 and 0 <= py < 8 and board[px][py] == enemy:
            run.append((px, py))
            px += step_x
            py += step_y
        if run and 0 <= px < 8 and 0 <= py < 8 and board[px][py] == turn:
            for fx, fy in run:
                board[fx][fy] = turn
    return
def can_put(x, y, turn):
    """Return True if ``turn`` may legally play at (x, y) on the global ``board``.

    The square must be empty (2). At least one direction must contain a
    non-empty run of opponent stones immediately followed by a ``turn`` stone.
    """
    if board[x][y] != 2:
        return False
    enemy = turn ^ 1
    for step_x, step_y in zip(dx, dy):
        px, py = x + step_x, y + step_y
        seen_enemy = False
        while 0 <= px < 8 and 0 <= py < 8 and board[px][py] == enemy:
            seen_enemy = True
            px += step_x
            py += step_y
        if seen_enemy and 0 <= px < 8 and 0 <= py < 8 and board[px][py] == turn:
            return True
    return False
def change_turn(turn):
    """Decide who moves after ``turn`` has just played.

    Returns:
        1 if the opponent (``turn ^ 1``) has a legal move,
        2 if only ``turn`` can move (the opponent must pass),
        0 if neither side can move.
    """
    squares = [(r, c) for r in range(8) for c in range(8)]
    if any(can_put(r, c, turn ^ 1) for r, c in squares):
        return 1
    if any(can_put(r, c, turn) for r, c in squares):
        return 2
    return 0
if __name__ == '__main__':
    # `subprocess` replaces the `commands` module, which is deprecated since
    # Python 2.6 and ran the solver through a shell string.
    import subprocess

    make_net()
    MATCH_TIMES = 100      # games played between validator[0] and validator[1]
    UPDATE_WIN = 80        # challenger wins needed to overwrite the saved model
    ENDGAME_EMPTIES = 15   # hand the position to the exact solver at this many empties
    first_win = 0          # games won by validator[0] (the unmodified network)
    second_win = 0         # games won by validator[1] (the network perturbed by make_net)
    for train_time in range(MATCH_TIMES):
        init_board()
        turn = 1
        while True:
            # validator[0] plays colour train_time % 2, so colours alternate
            # between games; validator[1] plays the other colour.
            index = 0 if turn == train_time % 2 else 1
            hand = search(board, turn, index)
            change_board(hand[0], hand[1], turn)
            status = change_turn(turn)
            if status == 0:      # neither side can move: game over
                break
            elif status == 1:    # opponent moves next; 2 means the opponent passes
                turn ^= 1
            empty_area = sum(row.count(2) for row in board)
            if empty_area <= ENDGAME_EMPTIES:
                break

        # Encode the position as two bitboards (bit i*8+j per square) for the solver.
        black = 0
        white = 0
        for i in range(8):
            for j in range(8):
                if board[i][j] == 1:
                    black |= 1 << (i * 8 + j)
                elif board[i][j] == 0:
                    white |= 1 << (i * 8 + j)
        with open('input.txt', 'w') as f:
            f.write(str(black) + ' ' + str(white) + ' ' + str(turn) + '\n')

        # Ask the endgame solver who wins. The position is fed on stdin as
        # before. A missing solver binary now raises instead of silently
        # counting every game as a draw.
        with open('input.txt') as solver_in:
            proc = subprocess.Popen(['./complete_read/a.out'], stdin=solver_in,
                                    stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
            output = proc.communicate()[0]
        result = output.decode('utf-8', 'replace').strip()

        if result == "black":
            winner = 1
        elif result == "white":
            winner = 0
        else:
            winner = None  # draw or unrecognised output: nobody scores
        if winner is not None:
            if winner == train_time % 2:
                first_win += 1
            else:
                second_win += 1
        # Once first_win exceeds MATCH_TIMES - UPDATE_WIN, second_win can no
        # longer reach UPDATE_WIN, so stop early.
        if first_win > MATCH_TIMES - UPDATE_WIN:
            break
    print("final " + str(first_win) + " " + str(second_win))
    if second_win >= UPDATE_WIN:
        serializers.save_hdf5('sigMyChain0.model', validator[1].model)
        print("updated!")
    time_end = time.time()
    sys.stderr.write("train.py : " + str(time_end - time_start) + "\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment