@TsuMakoto
Created December 24, 2019 00:52
A ReinforcementLearning.jl example
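# Train a BasicDQN agent on CartPole (CPU), plot the per-episode reward,
# then replay the buffered states with render.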
# Install the required packages
using Pkg; Pkg.add.(["ReinforcementLearning", "ReinforcementLearningEnvironments", "Flux", "StatsBase", "Plots"])
# import Random # hide
# Random.seed!(1) # hide
using ReinforcementLearning, ReinforcementLearningEnvironments, Flux
# using StatsBase:mean
# CartPole environment with Float32 observations
env = CartPoleEnv(;T=Float32)
ns, na = length(observation_space(env)), length(action_space(env)) # (4, 2)
device = :cpu
# Q-network: two hidden layers of 128 units
layer1 = Dense(ns, 128, relu)
layer2 = Dense(128, 128, relu)
layer3 = Dense(128, na)
neural_network_q = NeuralNetworkQ(model = Chain(layer1, layer2, layer3),
                                  optimizer = ADAM(),
                                  device = device)
# ε-greedy action selection with exponential decay
ϵ_selector = EpsilonGreedySelector{:exp}(ϵ_stable = 0.01, decay_steps = 500)
q_base_policy = QBasedPolicy(learner = BasicDQNLearner(approximator = neural_network_q,
                                                       batch_size = 32,
                                                       min_replay_history = 100,
                                                       loss_fun = huber_loss),
                             selector = ϵ_selector)
# Circular replay buffer storing (reward, terminal, state, action) tuples
circular_rtsa_buffer = circular_RTSA_buffer(capacity = 1000,
                                            state_eltype = Float32,
                                            state_size = (ns,))
agent = Agent(
    π = q_base_policy,
    buffer = circular_rtsa_buffer
)
# Record the total reward per episode and the time taken per step
hook = ComposedHook(
    TotalRewardPerEpisode(),
    TimePerStep()
)
# Train for 10,000 environment steps
run(agent, env, StopAfterStep(10000; is_show_progress=true); hook = hook)
# Plot the per-episode rewards recorded by the first hook
using Plots
gr()
plot(hook[1].rewards, xlabel="Episode", ylabel="Reward", label="")
savefig("a_quick_example_cartpole_cpu_basic_dqn.png");
# Replay the run: restore each buffered state and render it
for b ∈ agent.buffer
    env.state = b.state
    render(env)
end
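As a quick sanity check you could also print the average episode reward from the same hook data that is plotted above; a minimal sketch, assuming the StatsBase dependency installed at the top of the script:
using StatsBase: mean
# hook[1] is the TotalRewardPerEpisode hook whose rewards are plotted above
println("mean reward per episode: ", mean(hook[1].rewards))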