import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


class PolicyNetwork(nn.Module):
    def __init__(self, lr):
        """
        We use Tanh as the activation in the linear head in order to introduce
        variance into the learning by making the model more sensitive.
        I encourage you to try other architectures, optimizers and hyperparameters.
        """
        super(PolicyNetwork, self).__init__()
        self.num_actions = 3

        # Convolutional feature extractor over a stack of 4 input frames
        self.conv_net = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
                                      nn.BatchNorm2d(32),
                                      nn.ELU(inplace=True),
                                      nn.Conv2d(32, 64, kernel_size=4, stride=2),
                                      nn.BatchNorm2d(64),
                                      nn.ELU(inplace=True),
                                      nn.Conv2d(64, 128, kernel_size=4, stride=2),
                                      nn.BatchNorm2d(128),
                                      nn.ReLU(inplace=True))

        # Fully connected head; 1152 is the flattened conv output
        # (128 channels x 3 x 3 when the input frames are 84x84)
        self.linear = nn.Sequential(nn.Linear(1152, 512),
                                    nn.Tanh(),
                                    nn.Linear(512, 512),
                                    nn.Tanh(),
                                    nn.Linear(512, self.num_actions),
                                    nn.Tanh())

        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)

    def forward(self, state_stack):
        """
        Simple feedforward pass: conv features, flatten, softmax over the action scores.
        """
        x = self.conv_net(state_stack)
        x = x.view(x.size(0), -1)
        x = F.softmax(self.linear(x), dim=1)
        return x

    def get_action(self, state):
        state = state.float().unsqueeze(0)
        probs = self.forward(Variable(state))

        # We sample the action stochastically in order to introduce variance into the learning
        distribution = torch.distributions.categorical.Categorical(probs=probs.detach())
        highest_prob_action = distribution.sample()
        log_prob = torch.log(probs.squeeze(0)[highest_prob_action])

        # Return the values needed for acting and for optimizing the policy
        return highest_prob_action, log_prob
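
The gist only defines the policy network itself. As a minimal sketch of how it can be plugged into a REINFORCE-style training loop, the snippet below rolls out one episode, collects the log-probabilities and rewards coming out of get_action, computes discounted returns, and applies the policy-gradient update with the network's own RMSprop optimizer. The env object (assumed to follow the classic Gym reset/step API and to emit observations shaped as stacks of four 84x84 frames) and the gamma value are assumptions for illustration, not part of the original gist.

# Hedged usage sketch: REINFORCE-style update with the PolicyNetwork above.
# `env` is a hypothetical environment whose observations have shape [4, 84, 84];
# it is not part of the original gist.

def run_episode_and_update(policy, env, gamma=0.99):
    log_probs, rewards = [], []

    state, done = env.reset(), False
    while not done:
        action, log_prob = policy.get_action(torch.as_tensor(state))
        state, reward, done, _ = env.step(action.item())
        log_probs.append(log_prob)
        rewards.append(reward)

    # Discounted returns, accumulated backwards through the episode
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.insert(0, running)
    returns = torch.tensor(returns)

    # REINFORCE loss: minimize -log pi(a_t | s_t) * G_t summed over the episode
    loss = -(torch.stack(log_probs).squeeze() * returns).sum()

    policy.optimizer.zero_grad()
    loss.backward()
    policy.optimizer.step()
    return sum(rewards)

In this sketch each call to get_action builds its own graph for the log-probability, so a single backward pass at the end of the episode is enough; variants that normalize the returns or subtract a baseline are common refinements.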