import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal

# Custom model
#---------------------------------------------------------------------------------
class CustomActorCritic(nn.Module):
    """Actor-critic with a shared trunk and separate control and safety heads."""

    def __init__(self, input_dim, control_dim, hidden_dim, safety_dim,
                 Num_Hidden_Shared, Num_Hidden_Control, Num_Hidden_Safety, std=0.0):
        super(CustomActorCritic, self).__init__()
        layers_shared_actor = []
        layers_safety_actor = []
        layers_control_actor = []
        layers_shared_critic = []
        layers_safety_critic = []
        layers_control_critic = []
        out_dim = hidden_dim

        # ----- actor -----
        # shared trunk
        in_dim = input_dim
        for _ in range(Num_Hidden_Shared):
            layers_shared_actor.append(nn.Linear(in_dim, out_dim))
            layers_shared_actor.append(nn.Tanh())
            in_dim = out_dim
        self.base_actor = nn.Sequential(*layers_shared_actor)
        # safety head; each head tracks its own input size so its first layer
        # still matches the trunk output when Num_Hidden_Shared == 0
        head_in = in_dim
        for _ in range(Num_Hidden_Safety):
            layers_safety_actor.append(nn.Linear(head_in, out_dim))
            layers_safety_actor.append(nn.Tanh())
            head_in = out_dim
        self.safety_layer_actor = nn.Sequential(*layers_safety_actor,
                                                nn.Linear(head_in, safety_dim))
        # control (action) head
        head_in = in_dim
        for _ in range(Num_Hidden_Control):
            layers_control_actor.append(nn.Linear(head_in, out_dim))
            layers_control_actor.append(nn.Tanh())
            head_in = out_dim
        self.control_layer_actor = nn.Sequential(*layers_control_actor,
                                                 nn.Linear(head_in, control_dim))

        # ----- critic (same topology, one scalar value per head) -----
        in_dim = input_dim
        for _ in range(Num_Hidden_Shared):
            layers_shared_critic.append(nn.Linear(in_dim, out_dim))
            layers_shared_critic.append(nn.Tanh())
            in_dim = out_dim
        self.base_critic = nn.Sequential(*layers_shared_critic)
        head_in = in_dim
        for _ in range(Num_Hidden_Safety):
            layers_safety_critic.append(nn.Linear(head_in, out_dim))
            layers_safety_critic.append(nn.Tanh())
            head_in = out_dim
        self.safety_layer_critic = nn.Sequential(*layers_safety_critic,
                                                 nn.Linear(head_in, 1))
        head_in = in_dim
        for _ in range(Num_Hidden_Control):
            layers_control_critic.append(nn.Linear(head_in, out_dim))
            layers_control_critic.append(nn.Tanh())
            head_in = out_dim
        self.control_layer_critic = nn.Sequential(*layers_control_critic,
                                                  nn.Linear(head_in, 1))

        # state-independent learned log-stds for the two Gaussian policies
        self.log_std1 = nn.Parameter(torch.ones(1, control_dim) * std)
        self.log_std2 = nn.Parameter(torch.ones(1, safety_dim) * std)
        # self.apply(init_weights)

    def forward(self, state):
        # state = torch.from_numpy(state).float().to(device)
        base = self.base_actor(state)          # run the shared trunk once
        mu1 = self.control_layer_actor(base)   # control action mean
        mu2 = self.safety_layer_actor(base)    # safety action mean
        dist1 = Normal(mu1, self.log_std1.exp())
        dist2 = Normal(mu2, self.log_std2.exp())
        return dist1, dist2

    def act(self, state):
        # alias for forward(), used by the PPO update and rollout code below
        return self(state)

    def evaluate(self, state):
        base = self.base_critic(state)
        control_state_value = self.control_layer_critic(base)
        safety_state_value = self.safety_layer_critic(base)
        return control_state_value, safety_state_value
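
# Quick shape check for the model above. The dimensions and hidden-layer
# counts here are illustrative assumptions, not values from the training run.
if __name__ == "__main__":
    _m = CustomActorCritic(input_dim=8, control_dim=2, hidden_dim=64, safety_dim=1,
                           Num_Hidden_Shared=2, Num_Hidden_Control=1, Num_Hidden_Safety=1)
    _s = torch.randn(4, 8)              # batch of 4 states
    _dc, _ds = _m(_s)                   # control / safety action distributions
    _vc, _vs = _m.evaluate(_s)          # per-head state values
    print(_dc.sample().shape, _ds.sample().shape, _vc.shape, _vs.shape)
    # -> torch.Size([4, 2]) torch.Size([4, 1]) torch.Size([4, 1]) torch.Size([4, 1])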
# Modified loss function
#---------------------------------------------------------------------------------
def custom_ppo_iter(mini_batch_size, states, controls, safetys,
                    log_probs_c, log_probs_s, returns_c, returns_s,
                    advantages_c, advantages_s):
    """Yield random mini-batches (sampled with replacement) from one rollout."""
    batch_size = states.size(0)
    for _ in range(batch_size // mini_batch_size):
        rand_ids = np.random.randint(0, batch_size, mini_batch_size)
        yield (states[rand_ids, :], controls[rand_ids, :], safetys[rand_ids, :],
               log_probs_c[rand_ids, :], log_probs_s[rand_ids, :],
               returns_c[rand_ids, :], returns_s[rand_ids, :],
               advantages_c[rand_ids, :], advantages_s[rand_ids, :])

def custom_ppo_update(ppo_epochs, mini_batch_size, states, controls, safetys,
                      log_probs_c, log_probs_s, returns_c, returns_s,
                      advantages_c, advantages_s, clip_param=0.2):
    for _ in range(ppo_epochs):
        for (state, control, safety, old_log_probs_c, old_log_probs_s,
             return_c, return_s, advantage_c, advantage_s) in custom_ppo_iter(
                mini_batch_size, states, controls, safetys, log_probs_c,
                log_probs_s, returns_c, returns_s, advantages_c, advantages_s):
            dist_c, dist_s = model.act(state)
            value_c, value_s = model.evaluate(state)
            entropy_c = dist_c.entropy().mean()
            entropy_s = dist_s.entropy().mean()
            new_log_probs_c = dist_c.log_prob(control)
            new_log_probs_s = dist_s.log_prob(safety)
            # clipped PPO surrogate, computed independently per head
            ratio_c = (new_log_probs_c - old_log_probs_c).exp()
            ratio_s = (new_log_probs_s - old_log_probs_s).exp()
            surr1_c = ratio_c * advantage_c
            surr1_s = ratio_s * advantage_s
            surr2_c = torch.clamp(ratio_c, 1.0 - clip_param, 1.0 + clip_param) * advantage_c
            surr2_s = torch.clamp(ratio_s, 1.0 - clip_param, 1.0 + clip_param) * advantage_s
            actor_loss_c = -torch.min(surr1_c, surr2_c).mean()
            actor_loss_s = -torch.min(surr1_s, surr2_s).mean()
            critic_loss_c = (return_c - value_c).pow(2).mean()
            critic_loss_s = (return_s - value_s).pow(2).mean()
            loss_c = 0.5 * critic_loss_c + actor_loss_c - 0.001 * entropy_c
            loss_s = 0.5 * critic_loss_s + actor_loss_s - 0.001 * entropy_s
            # weighted sum of the control and safety losses; CTRL_W, SFTY_W,
            # model, and optimizer are globals defined elsewhere in the notebook
            loss = CTRL_W * loss_c + SFTY_W * loss_s
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
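
# GAE helper
#---------------------------------------------------------------------------------
# `compute_gae` is called by the training loop below but is not defined in this
# gist. A minimal sketch of the standard GAE recursion it presumably implements
# is given here for reference; the gamma/tau defaults are assumptions.
def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):
    values = values + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        # one-step TD error, masked at episode boundaries
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns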
# Training code
#---------------------------------------------------------------------------------
# Assumes the usual setup from the surrounding notebook: `envs` (vectorized
# environments), `model`, `optimizer`, `device`, `state = envs.reset()`,
# `frame_idx = 0`, `test_rewards = []`, `early_stop = False`, and the
# hyperparameters `max_frames`, `num_steps`, `ppo_epochs`, `mini_batch_size`,
# and `threshold_reward`.
while frame_idx < max_frames and not early_stop:
    # fresh rollout buffers for this update
    log_probs_c, log_probs_s = [], []
    values_c, values_s = [], []
    rewards_c, rewards_s = [], []
    states, controls, safetys, masks = [], [], [], []
    entropy_c = entropy_s = 0

    for _ in range(num_steps):
        state = torch.FloatTensor(state).to(device)
        dist_c, dist_s = model.act(state)
        value_c, value_s = model.evaluate(state)
        control = dist_c.sample()
        safety = dist_s.sample()
        # only the control action is sent to the environment
        next_state, reward_c, done, _ = envs.step(control.cpu().numpy())

        log_probs_c.append(dist_c.log_prob(control))
        log_probs_s.append(dist_s.log_prob(safety))
        entropy_c += dist_c.entropy().mean()
        entropy_s += dist_s.entropy().mean()
        values_c.append(value_c)
        values_s.append(value_s)

        # the safety reward is derived from the observed transition
        reward_s = safety_reward(state, next_state, reward_c)
        rewards_c.append(torch.FloatTensor(reward_c).unsqueeze(1).to(device))
        rewards_s.append(torch.FloatTensor(reward_s).unsqueeze(1).to(device))
        masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))
        states.append(state)
        controls.append(control)
        safetys.append(safety)

        state = next_state
        frame_idx += 1

        if frame_idx % 100 == 0:
            test_reward = np.mean([custom_test_env() for _ in range(10)])
            test_rewards.append(test_reward)
            plot(frame_idx, test_rewards)
            if test_reward > threshold_reward:
                early_stop = True

    # bootstrap from the final state and compute GAE returns for each head
    next_state = torch.FloatTensor(next_state).to(device)
    next_value_c, next_value_s = model.evaluate(next_state)
    returns_c = compute_gae(next_value_c, rewards_c, masks, values_c)
    returns_s = compute_gae(next_value_s, rewards_s, masks, values_s)

    returns_c = torch.cat(returns_c).detach()
    returns_s = torch.cat(returns_s).detach()
    log_probs_c = torch.cat(log_probs_c).detach()
    log_probs_s = torch.cat(log_probs_s).detach()
    values_c = torch.cat(values_c).detach()
    values_s = torch.cat(values_s).detach()
    states = torch.cat(states)
    controls = torch.cat(controls)
    safetys = torch.cat(safetys)
    advantages_c = returns_c - values_c
    advantages_s = returns_s - values_s

    custom_ppo_update(ppo_epochs, mini_batch_size, states, controls, safetys,
                      log_probs_c, log_probs_s, returns_c, returns_s,
                      advantages_c, advantages_s, clip_param=0.2)
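
# `custom_test_env`, `safety_reward`, and `plot` are helpers defined elsewhere
# in the notebook. For reference, a minimal sketch of what `custom_test_env`
# likely does is given below; the single evaluation env `env` is an assumption,
# and in the actual file this would need to be defined before the training loop.
def custom_test_env():
    state = env.reset()
    done = False
    total_reward = 0
    while not done:
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        dist_c, _ = model.act(state)
        # evaluate greedily with the mean control action
        action = dist_c.mean.detach().cpu().numpy()[0]
        next_state, reward, done, _ = env.step(action)
        state = next_state
        total_reward += reward
    return total_reward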