Skip to content

Instantly share code, notes, and snippets.

@dylandjian
Created April 26, 2018 21:55
Show Gist options
  • Save dylandjian/87c824f5455356887a5245654e1c1b2b to your computer and use it in GitHub Desktop.
Save dylandjian/87c824f5455356887a5245654e1c1b2b to your computer and use it in GitHub Desktop.
help plz
import numpy as np
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
# Side length of the square game board (cells per row and per column).
L = 10
# (width, height) of each boat before random orientation: one 2x4 block,
# one 1x5 boat and three 1x3 boats (22 cells in total).
boat_shapes = [(2, 4), (1, 5)] + [(1, 3)] * 3
def get_board(board_size=None, shapes=None):
    """Randomly place every boat on an empty square board without overlap.

    Each boat is rotated by 90 degrees with probability 0.5 and dropped at a
    uniformly random position where it fits entirely on the board; layouts in
    which two boats overlap are rejected and re-drawn.

    Args:
        board_size: side length of the board; defaults to the module-level ``L``.
        shapes: list of ``(width, height)`` boat sizes; defaults to the
            module-level ``boat_shapes``.

    Returns:
        ``(board, occupied, poses)``: ``board`` is a ``(board_size, board_size)``
        float array indexed ``board[y][x]`` with 1 on boat cells and 0
        elsewhere; ``occupied`` is the set of ``(x, y)`` boat cells; ``poses``
        holds one ``(left, top, upright)`` tuple per boat, where ``upright`` is
        False when the boat's width and height were swapped.

    Raises:
        ValueError: if a boat cannot fit on the board at all, or the boats'
            total area exceeds the board area (sampling could never succeed).
    """
    size = L if board_size is None else board_size
    shapes = boat_shapes if shapes is None else shapes
    for wid, hei in shapes:
        if max(wid, hei) > size:
            raise ValueError("boat %dx%d does not fit on a %dx%d board"
                             % (wid, hei, size, size))
    if sum(wid * hei for wid, hei in shapes) > size * size:
        raise ValueError("boats cover more cells than the board has")

    # Rejection sampling as a loop (not recursion), so a long streak of
    # overlapping layouts cannot exhaust the interpreter's recursion limit.
    while True:
        board, poses = _sample_layout(size, shapes)
        if np.all(board <= 1):  # no cell is covered by two boats
            break

    ys, xs = np.nonzero(board)
    occupied = {(int(x), int(y)) for y, x in zip(ys, xs)}
    return board, occupied, poses


def _sample_layout(size, shapes):
    """Drop each boat once; return per-cell boat counts and the boat poses."""
    counts = np.zeros([size, size])
    poses = []
    for shape in shapes:
        upright = np.random.random() < 0.5
        wid, hei = shape if upright else shape[::-1]
        # Draw the top-left corner only from positions where the whole boat
        # fits, so every row/column (including the last) can hold a boat.
        left = np.random.randint(0, size - wid + 1)
        top = np.random.randint(0, size - hei + 1)
        counts[top:top + hei, left:left + wid] += 1
        poses.append((left, top, upright))
    return counts, poses
def rand_orient(w, h):
    """Randomly keep or swap a boat's (width, height) with equal probability.

    Returns ``((w, h), True)`` when the original orientation is kept and
    ``((h, w), False)`` when the dimensions are swapped.  Consumes exactly one
    ``np.random.random()`` draw.
    """
    keep = np.random.random() < 0.5
    return ((w, h), True) if keep else ((h, w), False)
def rect_constr(left_top, wid_hei):
    """Build a membership predicate for an axis-aligned rectangle of cells.

    The rectangle's top-left cell is ``left_top = (left, top)`` and it spans
    ``wid_hei = (width, height)`` cells.  The returned function takes an
    ``(x, y)`` pair and is truthy iff ``left <= x < left + width`` and
    ``top <= y < top + height`` (right and bottom edges are exclusive).
    """
    x_lo, y_lo = left_top
    span_x, span_y = wid_hei
    x_hi = x_lo + span_x
    y_hi = y_lo + span_y

    def constr(crd):
        px, py = crd
        return x_lo <= px < x_hi and y_lo <= py < y_hi

    return constr
def or_constr(crs):
    """Combine predicates with logical OR.

    The returned function yields True iff at least one predicate in *crs* is
    truthy for its argument; predicates are tried in order and evaluation
    stops at the first hit.  With no predicates it always returns False.
    """
    def constr(crd):
        return any(pred(crd) for pred in crs)

    return constr
def mask_board(board, made_moves):
    """Return a copy of *board* with every unprobed cell hidden as 2.

    Args:
        board: 2-D array indexed ``board[y][x]``; it is not modified.
        made_moves: container of probed ``(x, y)`` cells (only membership is
            tested).

    Returns:
        A new array with the same shape and dtype as *board*: cells listed in
        *made_moves* keep their value, every other cell is set to 2.
    """
    masked = np.copy(board)
    # Take the extent from the board itself rather than the module-level L,
    # so boards of any size (including non-square ones) are masked fully.
    height, width = masked.shape[:2]
    for y in range(height):
        for x in range(width):
            if (x, y) not in made_moves:
                masked[y][x] = 2
    return masked
class GameEnv(object):
    """Battleship-style environment: find every boat cell on a hidden L x L board.

    Action ``a`` in ``range(L * L)`` probes cell ``(x, y) = (a // L, a % L)``.
    Observations are ``mask_board`` views of the board: probed cells show
    their true value (1 boat, 0 water), unprobed cells show 2.
    """

    def __init__(self):
        self.board, self.occupied, _ = get_board()
        # Start with a clean episode so step()/win() work even before reset().
        self.made_moves = set()
        self.possible_actions = list(range(L * L))

    def win(self):
        """Return True once every boat cell has been probed."""
        return self.occupied.issubset(self.made_moves)

    def reset(self):
        """Start a new episode on the same board; return the fully masked state."""
        self.made_moves = set()
        # Restore the action list too: otherwise a reused environment starts
        # the new episode with the previous episode's actions already removed.
        self.possible_actions = list(range(L * L))
        return mask_board(self.board, self.made_moves)

    def get_reward(self, x, y):
        """Return 1.0 for the first probe of a boat cell, -0.01 otherwise."""
        if self.board[y][x] == 1 and (x, y) not in self.made_moves:
            return 1.0
        return -0.01

    def step(self, action):
        """Probe the cell encoded by *action*.

        Returns ``(state, reward, done)``: the masked board after the probe,
        the reward from ``get_reward``, and whether all boats are now found.
        """
        x, y = action // L, action % L
        reward = self.get_reward(x, y)
        self.made_moves.add((x, y))
        if action in self.possible_actions:
            self.possible_actions.remove(action)
        done = self.win()
        state = mask_board(self.board, self.made_moves)
        return state, reward, done
class Agent(nn.Module):
    """Convolutional policy giving a probability for every board cell.

    Input: a ``(N, history_len, board_size, board_size)`` float tensor.
    Output: softmax probabilities over the ``board_size ** 2`` cells; a 1-D
    tensor when ``N == 1`` (the shape ``play_game`` and the training loop
    use), otherwise an ``(N, board_size ** 2)`` tensor.
    """

    def __init__(self, board_size=None, history_len=20):
        super(Agent, self).__init__()
        size = L if board_size is None else board_size
        self.conv1 = nn.Conv2d(history_len, 10, stride=1, kernel_size=2,
                               bias=False, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        # kernel 2 with padding 1 grows each spatial side by one, so the conv
        # output has 10 * (size + 1) ** 2 features (1210 for a 10x10 board).
        self.fc = nn.Linear(10 * (size + 1) ** 2, size ** 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        # Flatten per sample; flattening the whole batch into one vector only
        # worked for N == 1.
        logits = self.fc(x.view(x.size(0), -1))
        probs = F.softmax(logits, dim=1)
        return probs.squeeze(0) if probs.size(0) == 1 else probs
def play_game(agent, device=None):
    """Play one game greedily with *agent* and collect its hits as training data.

    Each step feeds the agent the 20 most recent masked boards (newest first,
    initially padded with 2), zeroes the probability of cells that were
    already probed, and plays the most probable remaining cell, breaking ties
    at random.

    Args:
        agent: model mapping a ``(1, 20, L, L)`` float tensor to ``L * L``
            cell probabilities.
        device: device for the input tensors; defaults to the device of the
            agent's first parameter.

    Returns:
        ``(states, labels)``: the input tensor and chosen action of every move
        that hit a boat (reward 1).
    """
    if device is None:
        device = next(agent.parameters()).device
    game = GameEnv()
    state = game.reset()
    done = False
    states = []
    labels = []
    history = np.full((20, L, L), 2)
    n_steps = 0
    while not done:
        history = np.roll(history, 1, axis=0)
        history[0] = state
        old_state = torch.tensor(history[np.newaxis], dtype=torch.float,
                                 device=device)
        # Inference only: no autograd graph is needed here (training re-runs
        # the agent on the stored states).
        with torch.no_grad():
            probas = agent(old_state).cpu().numpy()
        legal = np.zeros(L * L, dtype=bool)
        legal[game.possible_actions] = True
        probas[~legal] = 0
        total = np.sum(probas)
        if total > 0:
            probas /= total
        else:
            # Every legal cell underflowed to 0: choose uniformly among them
            # instead of dividing by zero and producing NaNs.
            probas[legal] = 1.0 / legal.sum()
        move = np.random.choice(np.flatnonzero(probas == np.max(probas)))
        state, r, done = game.step(move)
        if r == 1:
            states.append(old_state)
            labels.append(move)
        n_steps += 1
    print("game done after %d steps" % n_steps)
    return states, labels
# Device shared by the training loop and any code that reads the module-level
# ``cuda0``; falls back to the CPU when no GPU is available (name kept for
# compatibility).
cuda0 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if __name__ == "__main__":
    agent = Agent().to(cuda0)
    # Agent.forward already returns softmax probabilities, so train with NLL
    # on their log; CrossEntropyLoss would apply a second softmax on top.
    criterion = nn.NLLLoss()
    opt = torch.optim.Adam(agent.parameters(), lr=0.01, weight_decay=0.0001)
    for i in range(100):
        states, labels = play_game(agent)
        losses = []
        for state, label in zip(states, labels):
            opt.zero_grad()
            # Keep the output attached to the autograd graph: re-wrapping it
            # with torch.tensor() copies it into a detached tensor, so
            # backward() has nothing to differentiate.
            probas = agent(state).view(1, -1)
            log_probas = torch.log(torch.clamp(probas, min=1e-12))
            # NLLLoss expects a (batch,) target for a (batch, classes) input.
            target = torch.tensor([int(label)], dtype=torch.long, device=cuda0)
            loss = criterion(log_probas, target)
            loss.backward()
            opt.step()
            losses.append(loss.item())
        if losses:
            print("game %d: mean loss %.4f over %d hits"
                  % (i, sum(losses) / len(losses), len(losses)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment