@adam-r-kowalski
Created May 3, 2019 18:53
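# Vanilla policy gradient (REINFORCE) on CartPole-v0, using Flux for the policy
# network and OpenAI Gym through PyCall for the environment.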
using Flux, PyCall, Distributions, Statistics, Plots
using Flux: params
using Flux.Tracker: gradient, update!
using Juno: @progress   # assumes the Juno IDE's `@progress` macro; use a plain `for` loop outside Juno
gym = pyimport("gym")
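# The agent bundles the policy network, its optimizer, the per-episode buffers of
# action log-probabilities and rewards, and the discount factor.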
struct PolicyGradient{P, O}
    policy::P
    optimizer::O
    log_probabilities::Vector{Tracker.TrackedReal{Float32}}
    rewards::Vector{Float32}
    discount_factor::Float32
end
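# Build an agent with a three-hidden-layer MLP (30-40-30 units, softmax output),
# an ADAM optimizer, empty episode buffers, and a discount factor of 0.9.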
function PolicyGradient(inputs, outputs)
    h1, h2, h3 = [30, 40, 30]
    model = Chain(
        Dense(inputs, h1, relu),
        Dense(h1, h2, relu),
        Dense(h2, h3, relu),
        Dense(h3, outputs),
        softmax)
    optimizer = ADAM()
    log_probabilities = Tracker.TrackedReal{Float32}[]
    rewards = Float32[]
    discount_factor = Float32(0.9)
    PolicyGradient(model, optimizer,
                   log_probabilities, rewards,
                   discount_factor)
end
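# Sample an action from the categorical distribution given by the policy's softmax
# output, and record its log-probability for the policy-gradient update.
# Returns a 1-based action index.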
function select_action!(agent, state)
    # Convert the observation to Float32 so it matches the network's Float32
    # parameters and the Float32 log-probability buffer.
    probabilities = agent.policy(Float32.(state))
    distribution = Categorical(probabilities)
    action = rand(distribution)
    push!(agent.log_probabilities, loglikelihood(distribution, [action]))
    action
end
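# Only the reward from a transition is needed for the REINFORCE update.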
remember!(agent, (_, _, reward, _)) = push!(agent.rewards, reward)
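# Standardize returns to zero mean and unit variance to reduce gradient variance.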
normalize(xs) = (xs .- mean(xs)) / (std(xs) + eps(eltype(xs)))
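# Compute the discounted return G_t for every step by scanning the episode backwards.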
function discounted_rewards(agent)
    rewards = agent.rewards
    discounted = similar(rewards)
    running_sum = Float32(0.0)
    for i in length(rewards):-1:1
        running_sum = agent.discount_factor * running_sum + rewards[i]
        discounted[i] = running_sum
    end
    discounted
end
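# REINFORCE update: descend the gradient of -Σ log π(a_t | s_t) · G_t with respect
# to the policy parameters, then clear the episode buffers.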
function improve_policy!(agent)
    returns = normalize(discounted_rewards(agent))
    θ = params(agent.policy)
    Δ = gradient(() -> sum(-agent.log_probabilities .* returns), θ)
    update!(agent.optimizer, θ, Δ)
    empty!(agent.log_probabilities)
    empty!(agent.rewards)
end
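# Run `episodes` episodes against the environment, updating the policy after each one.
# Returns the mean episode reward, or a plot of per-episode rewards when `graph=true`.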
function simulate!(agent, env; render=false, episodes=1, graph=false)
    rewards = Float64[]
    @progress for _ in 1:episodes
        episode_reward = 0.0
        done = false
        state = env.reset()
        while !done
            action = select_action!(agent, state)
            # gym expects 0-based actions, while Categorical samples 1-based indices
            next_state, reward, done, _ = env.step(action - 1)
            remember!(agent, (state, action, reward, next_state))
            state = next_state
            episode_reward += reward
            render && env.render()
        end
        improve_policy!(agent)
        push!(rewards, episode_reward)
    end
    graph ? plot(rewards) : sum(rewards) / episodes
end
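# CartPole-v0: 4-dimensional observations, 2 discrete actions.
# Train for 300 episodes, then watch 5 rendered episodes.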
env = gym.make("CartPole-v0")
agent = PolicyGradient(4, 2)
simulate!(agent, env, episodes=300, graph=true)
simulate!(agent, env, episodes=5, render=true)
env.close()