import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal

# Custom model
#---------------------------------------------------------------------------------
class CustomActorCritic(nn.Module):
    """Actor-critic with a shared trunk and separate control / safety heads."""

    def __init__(self, input_dim, control_dim, hidden_dim, safety_dim,
                 Num_Hidden_Shared, Num_Hidden_Control, Num_Hidden_Safety, std=0.0):
        super(CustomActorCritic, self).__init__()
        layers_shared_actor = []
        layers_safety_actor = []
        layers_control_actor = []
        layers_shared_critic = []
        layers_safety_critic = []
        layers_control_critic = []

        # Actor network
        in_dim = input_dim
        out_dim = hidden_dim
        # shared trunk
        for i in range(Num_Hidden_Shared):
            layers_shared_actor.append(nn.Linear(in_dim, out_dim))
            layers_shared_actor.append(nn.Tanh())
            in_dim = out_dim
        # safety head
        for i in range(Num_Hidden_Safety):
            layers_safety_actor.append(nn.Linear(in_dim, out_dim))
            layers_safety_actor.append(nn.Tanh())
        # control (action) head
        for i in range(Num_Hidden_Control):
            layers_control_actor.append(nn.Linear(in_dim, out_dim))
            layers_control_actor.append(nn.Tanh())
        self.base_actor = nn.Sequential(*layers_shared_actor)
        self.safety_layer_actor = nn.Sequential(*layers_safety_actor,
                                                nn.Linear(out_dim, safety_dim))
        self.control_layer_actor = nn.Sequential(*layers_control_actor,
                                                 nn.Linear(out_dim, control_dim))

        # Critic network (mirrors the actor; each head outputs a scalar state value)
        in_dim = input_dim
        out_dim = hidden_dim
        # shared trunk
        for i in range(Num_Hidden_Shared):
            layers_shared_critic.append(nn.Linear(in_dim, out_dim))
            layers_shared_critic.append(nn.Tanh())
            in_dim = out_dim
        # safety head
        for i in range(Num_Hidden_Safety):
            layers_safety_critic.append(nn.Linear(in_dim, out_dim))
            layers_safety_critic.append(nn.Tanh())
        # control (action) head
        for i in range(Num_Hidden_Control):
            layers_control_critic.append(nn.Linear(in_dim, out_dim))
            layers_control_critic.append(nn.Tanh())
        self.base_critic = nn.Sequential(*layers_shared_critic)
        self.safety_layer_critic = nn.Sequential(*layers_safety_critic,
                                                 nn.Linear(out_dim, 1))
        self.control_layer_critic = nn.Sequential(*layers_control_critic,
                                                  nn.Linear(out_dim, 1))

        # Learnable log standard deviations of the two Gaussian policies
        self.log_std1 = nn.Parameter(torch.ones(1, control_dim) * std)
        self.log_std2 = nn.Parameter(torch.ones(1, safety_dim) * std)
        # self.apply(init_weights)

    def forward(self, state):
        # state = torch.from_numpy(state).float().to(device)
        base_out = self.base_actor(state)
        mu1 = self.control_layer_actor(base_out)  # control action mean
        mu2 = self.safety_layer_actor(base_out)   # safety action mean
        std1 = self.log_std1.exp()
        dist1 = Normal(mu1, std1)
        std2 = self.log_std2.exp()
        dist2 = Normal(mu2, std2)
        return dist1, dist2

    def evaluate(self, state):
        base_out = self.base_critic(state)
        control_state_value = self.control_layer_critic(base_out)
        safety_state_value = self.safety_layer_critic(base_out)
        return control_state_value, safety_state_value
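
# The update and training code below use `model`, `optimizer`, `device`,
# CTRL_W and SFTY_W as globals, which this gist does not set up. A minimal
# sketch, with illustrative (assumed) dimensions, learning rate and loss weights:
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CustomActorCritic(input_dim=24, control_dim=4, hidden_dim=64, safety_dim=1,
                          Num_Hidden_Shared=2, Num_Hidden_Control=1,
                          Num_Hidden_Safety=1).to(device)   # placeholder dimensions
optimizer = optim.Adam(model.parameters(), lr=3e-4)          # assumed learning rate
CTRL_W, SFTY_W = 1.0, 1.0                                    # assumed control/safety loss weights
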
# Modified loss function
#---------------------------------------------------------------------------------
def custom_ppo_iter(mini_batch_size, states, controls, safetys, log_probs_c, log_probs_s,
                    returns_c, returns_s, advantages_c, advantages_s):
    # yield random mini-batches drawn from the collected rollout
    batch_size = states.size(0)
    for _ in range(batch_size // mini_batch_size):
        rand_ids = np.random.randint(0, batch_size, mini_batch_size)
        yield states[rand_ids, :], controls[rand_ids, :], safetys[rand_ids, :], \
              log_probs_c[rand_ids, :], log_probs_s[rand_ids, :], \
              returns_c[rand_ids, :], returns_s[rand_ids, :], \
              advantages_c[rand_ids, :], advantages_s[rand_ids, :]

def custom_ppo_update(ppo_epochs, mini_batch_size, states, controls, safetys, log_probs_c, log_probs_s,
                      returns_c, returns_s, advantages_c, advantages_s, clip_param=0.2):
    for _ in range(ppo_epochs):
        for state, control, safety, old_log_probs_c, old_log_probs_s, return_c, return_s, \
                advantage_c, advantage_s in custom_ppo_iter(mini_batch_size, states, controls, safetys,
                                                            log_probs_c, log_probs_s,
                                                            returns_c, returns_s,
                                                            advantages_c, advantages_s):
            dist_c, dist_s = model(state)
            value_c, value_s = model.evaluate(state)

            entropy_c = dist_c.entropy().mean()
            entropy_s = dist_s.entropy().mean()
            new_log_probs_c = dist_c.log_prob(control)
            new_log_probs_s = dist_s.log_prob(safety)

            # clipped surrogate objective, computed separately for each head
            ratio_c = (new_log_probs_c - old_log_probs_c).exp()
            ratio_s = (new_log_probs_s - old_log_probs_s).exp()
            surr1_c = ratio_c * advantage_c
            surr1_s = ratio_s * advantage_s
            surr2_c = torch.clamp(ratio_c, 1.0 - clip_param, 1.0 + clip_param) * advantage_c
            surr2_s = torch.clamp(ratio_s, 1.0 - clip_param, 1.0 + clip_param) * advantage_s

            actor_loss_c = -torch.min(surr1_c, surr2_c).mean()
            actor_loss_s = -torch.min(surr1_s, surr2_s).mean()
            critic_loss_c = (return_c - value_c).pow(2).mean()
            critic_loss_s = (return_s - value_s).pow(2).mean()

            loss_c = 0.5 * critic_loss_c + actor_loss_c - 0.001 * entropy_c
            loss_s = 0.5 * critic_loss_s + actor_loss_s - 0.001 * entropy_s
            # weighted sum of the control and safety losses
            loss = CTRL_W * loss_c + SFTY_W * loss_s

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
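
# Two helpers used by the training code below are not included in this gist.
# compute_gae is presumably the standard Generalized Advantage Estimation routine;
# a minimal sketch under that assumption (the gamma and tau values are assumptions):
def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):
    values = values + [next_value]
    gae = 0
    returns = []
    # walk the rollout backwards, accumulating discounted TD residuals
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns

# custom_test_env is presumably a single-environment evaluation rollout; a sketch
# under that assumption (a separate, non-vectorised `env` is assumed to exist):
def custom_test_env():
    state = env.reset()
    done = False
    total_reward = 0
    while not done:
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        dist_c, _ = model(state)
        # evaluate with the mean control action rather than a sample
        next_state, reward, done, _ = env.step(dist_c.mean.detach().cpu().numpy()[0])
        state = next_state
        total_reward += reward
    return total_reward
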
# Training code
#---------------------------------------------------------------------------------
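# The loop below also relies on globals that this gist does not define: `envs`
# (a vectorised training environment), `plot` (a plotting helper), and
# `safety_reward` (the task-specific safety signal). The counters and
# hyperparameters are sketched here with assumed values.
max_frames = 100_000             # assumption
num_steps = 128                  # rollout length per update (assumption)
ppo_epochs = 4                   # assumption
mini_batch_size = 64             # assumption
threshold_reward = float("inf")  # replace with the task's target test reward
frame_idx = 0
test_rewards = []
early_stop = False
state = envs.reset()
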
while frame_idx < max_frames and not early_stop:
    # fresh rollout buffers for this update (assumed; the appends below require them)
    log_probs_c, log_probs_s = [], []
    values_c, values_s = [], []
    rewards_c, rewards_s = [], []
    states, controls, safetys, masks = [], [], [], []
    entropy_c, entropy_s = 0, 0

    for _ in range(num_steps):
        state = torch.FloatTensor(state).to(device)
        dist_c, dist_s = model(state)
        value_c, value_s = model.evaluate(state)

        control = dist_c.sample()
        safety = dist_s.sample()
        next_state, reward_c, done, _ = envs.step(control.cpu().numpy())

        log_prob_c = dist_c.log_prob(control)
        log_prob_s = dist_s.log_prob(safety)
        entropy_c += dist_c.entropy().mean()
        entropy_s += dist_s.entropy().mean()

        log_probs_c.append(log_prob_c)
        log_probs_s.append(log_prob_s)
        values_c.append(value_c)
        values_s.append(value_s)
        reward_s = safety_reward(state, next_state, reward_c)
        rewards_c.append(torch.FloatTensor(reward_c).unsqueeze(1).to(device))
        rewards_s.append(torch.FloatTensor(reward_s).unsqueeze(1).to(device))
        masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))
        states.append(state)
        controls.append(control)
        safetys.append(safety)

        state = next_state
        frame_idx += 1

        # periodic evaluation
        if frame_idx % 100 == 0:
            test_reward = np.mean([custom_test_env() for _ in range(10)])
            test_rewards.append(test_reward)
            plot(frame_idx, test_rewards)
            if test_reward > threshold_reward:
                early_stop = True

    # bootstrap the value of the last state and compute GAE returns for each head
    next_state = torch.FloatTensor(next_state).to(device)
    next_value_c, next_value_s = model.evaluate(next_state)
    returns_c = compute_gae(next_value_c, rewards_c, masks, values_c)
    returns_s = compute_gae(next_value_s, rewards_s, masks, values_s)

    returns_c = torch.cat(returns_c).detach()
    returns_s = torch.cat(returns_s).detach()
    log_probs_c = torch.cat(log_probs_c).detach()
    log_probs_s = torch.cat(log_probs_s).detach()
    values_c = torch.cat(values_c).detach()
    values_s = torch.cat(values_s).detach()
    states = torch.cat(states)
    controls = torch.cat(controls)
    safetys = torch.cat(safetys)
    advantages_c = returns_c - values_c
    advantages_s = returns_s - values_s

    custom_ppo_update(ppo_epochs, mini_batch_size, states, controls, safetys, log_probs_c, log_probs_s,
                      returns_c, returns_s, advantages_c, advantages_s, clip_param=0.2)