@lukicdarkoo
Created March 25, 2021 21:57
Actor-Critic
import gym
import numpy as np
import torch
import torch.nn
from torch.nn.functional import smooth_l1_loss, relu, softmax


class ActorCriticNet(torch.nn.Module):
    def __init__(self, n_states=4, n_hidden=128, n_actions=2):
        super().__init__()
        self.__input = torch.nn.Linear(n_states, n_hidden)
        self.__action = torch.nn.Linear(n_hidden, n_actions)
        self.__value = torch.nn.Linear(n_hidden, 1)

    def forward(self, x):
        # Shared hidden layer feeding two heads: action probabilities (actor) and state value (critic)
        x = relu(self.__input(x))
        action = softmax(self.__action(x), dim=-1)
        value = self.__value(x)
        return action, value


def main():
    # Config
    timesteps_per_episode = 1_000
    gamma = 0.99

    # Models
    env = gym.make("CartPole-v0").env
    net = ActorCriticNet(
        n_states=env.observation_space.shape[0],
        n_actions=env.action_space.n
    )
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

    # State
    action_probabilities_history = []
    critic_value_history = []
    reward_history = []
    total_loss = None
    last_reward = 0

    while True:
        state = env.reset()

        # Play one episode
        for _ in range(1, timesteps_per_episode):
            # Render only once the policy is close to solving the task
            if last_reward > 0.95 * timesteps_per_episode:
                env.render()

            state = torch.tensor(state).float()
            action_probabilities, critic_value = net(state)

            # Value
            critic_value_history.append(critic_value)

            # Policy: sample an action from the actor's distribution
            action = np.random.choice(action_probabilities.shape[0], p=action_probabilities.detach().numpy())
            action_probabilities_history.append(action_probabilities[action])

            state, reward, done, _ = env.step(action)
            reward_history.append(reward)
            if done:
                break

        # Discounted returns (the "actual" values the critic should predict)
        discounted_sum = 0
        last_reward = sum(reward_history)
        actual_values = []
        for reward in reward_history[::-1]:
            discounted_sum = reward + gamma * discounted_sum
            actual_values.insert(0, discounted_sum)

        # Normalize
        actual_values = torch.tensor(actual_values)
        actual_values = (actual_values - actual_values.mean()) / (actual_values.std() + 1e-12)

        # Losses: policy gradient weighted by the advantage, plus smooth L1 for the critic
        actor_loss_sum = 0
        critic_loss_sum = 0
        for action_probability, critic_value, actual_value in zip(action_probabilities_history, critic_value_history, actual_values):
            advantage_estimate = actual_value - critic_value.item()
            actor_loss_sum = actor_loss_sum - torch.log(action_probability) * advantage_estimate
            critic_loss_sum = critic_loss_sum + smooth_l1_loss(actual_value, critic_value)
        total_loss = actor_loss_sum + critic_loss_sum

        # Optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Clear history
        action_probabilities_history.clear()
        critic_value_history.clear()
        reward_history.clear()

        print(last_reward)


if __name__ == "__main__":
    main()
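
For reference, the reward_history loop above implements the discounted-return recursion G_t = r_t + gamma * G_{t+1}; these returns serve as targets for the critic and, after subtracting the critic's estimate, as the advantage that weights the policy gradient. Below is a minimal, standalone sketch of that recursion for a toy reward list; the discounted_returns helper is illustrative only and not part of the gist.

# Illustrative only: mirrors the reward_history loop in main(), assuming the same gamma.
def discounted_returns(rewards, gamma=0.99):
    returns = []
    discounted_sum = 0.0
    for reward in reversed(rewards):
        discounted_sum = reward + gamma * discounted_sum
        returns.insert(0, discounted_sum)
    return returns

# For rewards [1, 1, 1] and gamma = 0.99:
# G_2 = 1, G_1 = 1 + 0.99 * 1 = 1.99, G_0 = 1 + 0.99 * 1.99 = 2.9701
print(discounted_returns([1, 1, 1]))  # approximately [2.9701, 1.99, 1.0]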