Skip to content

Instantly share code, notes, and snippets.

@neale
Created March 21, 2017 19:46
Show Gist options
  • Save neale/26130d64591f6b57d7473cf99573e8eb to your computer and use it in GitHub Desktop.
Save neale/26130d64591f6b57d7473cf99573e8eb to your computer and use it in GitHub Desktop.
stacked A4C
class A4C(torch.nn.Module):
    """Recurrent actor-critic network with an additive attention stage.

    The module owns four stride-2 convolutions, two attention linear layers,
    a two-layer ``StackedLSTM`` and linear critic/actor heads.  ``forward``
    builds an attention-weighted 256-d summary vector, feeds it through the
    recurrent cell and returns the value estimate, the action scores and the
    new recurrent state.
    """

    def __init__(self, input, n_actions):
        """Create all layers and put the module in training mode.

        Args:
            input: Number of channels expected by the first convolution.
            n_actions: Output width of the actor head.
        """
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the first ``self.conv1 = ...`` raises.
        super(A4C, self).__init__()

        # Convolutional trunk.
        self.conv1 = nn.Conv2d(input, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)

        # Recurrent core.
        # NOTE(review): StackedLSTM is defined outside this block; the
        # argument order is presumably (num_layers, input_size, hidden_size)
        # -- confirm against its definition.
        self.lstm = StackedLSTM(2, 256, 256, dropout=.5)

        # Attention network linear layers.
        self.Att_fc1 = nn.Linear(256, 256)
        self.Att_fc2 = nn.Linear(256, 256)

        # Actor / critic heads.
        self.critic = nn.Linear(256, 1)
        self.actor = nn.Linear(256, n_actions)

        # TODO: weight initialisation for linear, conv and lstm layers was
        # announced by the original comment but never implemented.
        self.train()

    def _attend(self, v_t, h_in):
        """Return one 256-wide vector scaled by its attention weights."""
        v_t = v_t.view(1, 256)
        hidden = torch.tanh(self.Att_fc1(v_t) + h_in)
        # Normalise over the 256 features; this equals the legacy implicit
        # softmax dimension when ``hidden`` is 2-D.
        weights = F.softmax(self.Att_fc2(hidden), dim=-1)
        # NOTE(review): the original shifts the softmax output by its own
        # maximum through ``.data``.  That edit is not tracked by autograd
        # and makes every weight <= 0 -- kept as written; confirm intent.
        weights.data.sub_(weights.data.max())
        return torch.mul(weights, v_t)

    def forward(self, inputs):
        """Run one step of the network.

        Args:
            inputs: ``(frames, (h_in, c_in))`` -- the image batch fed to
                ``conv1`` and the recurrent state handed to the LSTM.

        Returns:
            ``(critic_output, actor_output, (h_t, c_t))`` where ``h_t`` and
            ``c_t`` are whatever ``self.lstm`` returns.
        """
        inputs, (h_in, c_in) = inputs

        x = F.elu(self.conv1(inputs))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))
        # NOTE(review): ``x`` is not consumed below.  The attention runs over
        # the convolution *weights* reshaped into 256-wide rows, not over the
        # activations -- kept as in the original; confirm this is intended.
        fmaps = [
            self.conv2.weight.view(-1, 1, 256),
            self.conv3.weight.view(-1, 1, 256),
            self.conv4.weight.view(-1, 1, 256),
        ]

        # The original appended to ``outputs`` without ever creating it.
        outputs = []
        for fmap in fmaps:
            for v_t in fmap:
                outputs.append(self._attend(v_t, h_in))

        # Sum every attended vector into a single summary.
        g_t = torch.stack(outputs).sum(0).view(-1, 256)

        # NOTE(review): assumes self.lstm returns an (h, c) pair -- confirm
        # against StackedLSTM's return convention.
        h_t, c_t = self.lstm(g_t, (h_in, c_in))
        # Heads are registered as ``critic`` / ``actor`` in __init__.
        return self.critic(h_t), self.actor(h_t), (h_t, c_t)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment