Skip to content

Instantly share code, notes, and snippets.

@piEsposito
Last active January 21, 2020 17:21
Show Gist options
  • Save piEsposito/558e9ca1fb2c2abba275781b64623a3b to your computer and use it in GitHub Desktop.
def update_policy(policy_network, rewards, log_probs):
    """Perform one REINFORCE policy-gradient update.

    Computes the discounted return G_t for every timestep, normalizes the
    returns to reduce gradient variance, forms the score-function loss
    ``-sum_t log_prob_t * G_t`` and takes one optimizer step.

    Args:
        policy_network: object exposing an ``optimizer`` attribute
            (a ``torch.optim`` optimizer over the policy parameters).
        rewards: sequence of per-step scalar rewards for one episode.
        log_probs: sequence of per-step log-probability tensors (must be
            attached to the autograd graph of the policy network).

    Relies on module-level globals ``GAMMA`` (discount factor) and
    ``device`` (torch device), as in the original code.
    """
    # Discounted returns in a single reverse pass: G_t = r_t + GAMMA * G_{t+1}.
    # (The original recomputed the tail sum per timestep — O(T^2); this is O(T).)
    discounted_rewards = []
    Gt = 0.0
    for r in reversed(rewards):
        Gt = r + GAMMA * Gt
        discounted_rewards.append(Gt)
    discounted_rewards.reverse()

    discounted_rewards = torch.tensor(discounted_rewards, device=device)

    # Normalize returns (variance reduction). Guard the divisor: torch's
    # unbiased std is NaN for a single-step episode and 0 for constant
    # returns — either would poison the gradients with NaN/inf.
    std = discounted_rewards.std()
    if torch.isnan(std) or std == 0:
        std = torch.tensor(1.0, device=device)
    discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / std

    # Score-function (REINFORCE) loss: minimize -E[log pi(a|s) * G_t].
    policy_gradient = torch.stack(
        [-log_prob * Gt for log_prob, Gt in zip(log_probs, discounted_rewards)]
    ).sum()

    policy_network.optimizer.zero_grad()
    # retain_graph=True kept from the original — presumably log_probs share
    # graph state reused across calls; TODO confirm it cannot be dropped.
    policy_gradient.backward(retain_graph=True)
    policy_network.optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment