Created November 6, 2019 05:13
-
-
Save jadechip/10c6561293ec6e1a096261ec2c04a3e0 to your computer and use it in GitHub Desktop.
Reinforcement learning - continuous control companion code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def hidden_init(layer):
    """Return symmetric uniform-init bounds ``(-lim, lim)`` with ``lim = 1 / sqrt(fan_in)``.

    Intended to be unpacked into ``tensor.uniform_(*hidden_init(layer))``.

    Params
    ======
        layer (nn.Module): layer with a ``weight`` tensor whose first dimension
            indexes output units, e.g. ``nn.Linear`` (weight shape
            ``(out_features, in_features)``) or ``nn.Conv1d/2d/3d``.

    Returns
    =======
        tuple: ``(-lim, lim)``
    """
    # Fan-in is the number of inputs feeding a single output unit, i.e. the
    # element count of one output row of the weight. For nn.Linear the weight
    # is (out_features, in_features), so ``size()[0]`` would be the fan-OUT.
    # ``weight[0].numel()`` gives in_features for Linear and
    # in_channels * kernel elements for Conv layers.
    fan_in = layer.weight.data[0].numel()
    lim = 1. / np.sqrt(fan_in)
    return (-lim, lim)
class Actor(nn.Module):
    """Actor (Policy) Model."""

    def __init__(self, state_size, action_size, seed, fc1_units=400, fc2_units=300):
        """Initialize parameters and build model.

        Params
        ======
            state_size (int): Dimension of each state
            action_size (int): Dimension of each action
            seed (int): Random seed
            fc1_units (int): Number of nodes in first hidden layer
            fc2_units (int): Number of nodes in second hidden layer
        """
        super().__init__()
        # torch.manual_seed seeds torch's global RNG (a process-wide side
        # effect) and returns the default generator, stored as self.seed.
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, action_size)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize hidden-layer weights and the output layer in place."""
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        # The output layer starts with tiny values so initial pre-tanh outputs
        # are near zero. The bias must be shrunk as well: nn.Linear's default
        # bias init is U(-1/sqrt(fc2_units), 1/sqrt(fc2_units)) (about +/-0.058
        # for 300 units), which would otherwise dwarf the +/-3e-3 weights.
        self.fc3.weight.data.uniform_(-3e-3, 3e-3)
        self.fc3.bias.data.uniform_(-3e-3, 3e-3)

    def forward(self, state):
        """Build an actor (policy) network that maps states -> actions.

        Actions are squashed into [-1, 1] by a final tanh.
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        # torch.tanh is the canonical op; older PyTorch releases warn that
        # F.tanh is deprecated.
        return torch.tanh(self.fc3(x))
class Critic(nn.Module):
    """Critic (Value) Model."""

    def __init__(self, state_size, action_size, seed, fcs1_units=400, fc2_units=300):
        """Initialize parameters and build model.

        Params
        ======
            state_size (int): Dimension of each state
            action_size (int): Dimension of each action
            seed (int): Random seed
            fcs1_units (int): Number of nodes in the first hidden layer
            fc2_units (int): Number of nodes in the second hidden layer
        """
        super().__init__()
        # torch.manual_seed seeds torch's global RNG (a process-wide side
        # effect) and returns the default generator, stored as self.seed.
        self.seed = torch.manual_seed(seed)
        self.fcs1 = nn.Linear(state_size, fcs1_units)
        # The action is injected after the first state layer, hence the
        # widened input of fc2.
        self.fc2 = nn.Linear(fcs1_units + action_size, fc2_units)
        self.fc3 = nn.Linear(fc2_units, 1)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize hidden-layer weights and the output layer in place."""
        self.fcs1.weight.data.uniform_(*hidden_init(self.fcs1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        # Keep initial Q-estimates near zero: shrink the output bias too, since
        # nn.Linear's default bias init, U(-1/sqrt(fc2_units), 1/sqrt(fc2_units)),
        # is far wider than the +/-3e-3 weight range.
        self.fc3.weight.data.uniform_(-3e-3, 3e-3)
        self.fc3.bias.data.uniform_(-3e-3, 3e-3)

    def forward(self, state, action):
        """Build a critic (value) network that maps (state, action) pairs -> Q-values.

        Returns a tensor whose last dimension has size 1.
        """
        xs = F.relu(self.fcs1(state))
        # Concatenate on the last (feature) axis: identical to dim=1 for
        # batched (batch, features) inputs, and also works for a single
        # unbatched (features,) sample.
        x = torch.cat((xs, action), dim=-1)
        x = F.relu(self.fc2(x))
        return self.fc3(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.