Skip to content

Instantly share code, notes, and snippets.

@zsunberg
Created October 18, 2020 04:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zsunberg/a6fdebed95c518ba4e06e1d5d782a30a to your computer and use it in GitHub Desktop.
Save zsunberg/a6fdebed95c518ba4e06e1d5d782a30a to your computer and use it in GitHub Desktop.
JuliaReinforcementLearning script that produces a bounds error
using ReinforcementLearningZoo
using ReinforcementLearningBase
using ReinforcementLearningCore: NeuralNetworkApproximator, EpsilonGreedyExplorer, QBasedPolicy, CircularCompactSARTSATrajectory
using ReinforcementLearning
using Flux
using Flux: glorot_uniform, huber_loss
import Random
import BSON
# Short alias for the interface module whose methods are extended for `MyEnv`.
# Declared `const` so the global binding has a fixed type and cannot be rebound by accident.
const RL = ReinforcementLearningBase
# Random number generator passed to the weight initializers, the learner and the explorer.
const rng = Random.GLOBAL_RNG
"""
    MyEnv(s)

Mutable environment whose whole state is the single integer field `s`.
"""
mutable struct MyEnv <: AbstractEnv
    s::Int
end
# Interface methods of `RL` specialised for `MyEnv`.

# The two actions the environment offers.
function RL.get_actions(env::MyEnv)
    return [-1, 1]
end

# The state is exposed as a one-element vector holding `env.s`.
function RL.get_state(env::MyEnv)
    return [env.s]
end

# The reward is the current value of `env.s`.
function RL.get_reward(env::MyEnv)
    return env.s
end

# The environment reports terminal once `env.s` is 3 or larger.
function RL.get_terminal(env::MyEnv)
    return env.s >= 3
end

# Put the environment back to `s = 1`; the value of the assignment (1) is returned.
function RL.reset!(env::MyEnv)
    env.s = 1
end
# Apply `a` to the environment: the state moves by `a` plus a random
# offset drawn from -1, 0 or 1. The assigned value is returned.
function (env::MyEnv)(a)
    noise = rand([-1, 0, 1])
    env.s = env.s + (a + noise)
end
# Environment instance starting at s = 1.
env = MyEnv(1)
# Length of the state vector and number of available actions; these size the network.
ns = length(get_state(env))
na = length(get_actions(env))
# Build the agent step by step. The pieces are created in the same order in
# which the nested constructor call would evaluate them, and the `let` block
# keeps the intermediate names out of global scope.
agent = let
    # Three dense layers: `ns` inputs, two hidden layers of 128 units, `na` outputs.
    q_network = cpu(
        Chain(
            Dense(ns, 128, relu; initW = glorot_uniform(rng)),
            Dense(128, 128, relu; initW = glorot_uniform(rng)),
            Dense(128, na; initW = glorot_uniform(rng)),
        ),
    )

    approximator = NeuralNetworkApproximator(model = q_network, optimizer = ADAM())

    learner = BasicDQNLearner(
        approximator = approximator,
        batch_size = 32,
        min_replay_history = 100,
        loss_func = huber_loss,
        rng = rng,
    )

    explorer = EpsilonGreedyExplorer(
        kind = :exp,
        ϵ_stable = 0.01,
        decay_steps = 500,
        rng = rng,
    )

    policy = QBasedPolicy(learner = learner, explorer = explorer)

    # Replay storage with room for 1000 entries; states are stored as Float32
    # arrays of size `(ns,)`.
    trajectory = CircularCompactSARTSATrajectory(
        capacity = 1000,
        state_type = Float32,
        state_size = (ns,),
    )

    Agent(policy = policy, trajectory = trajectory)
end
# Stop condition and statistics hooks for the run.
stop_condition = StopAfterStep(10000)
total_reward_per_episode = TotalRewardPerEpisode()
time_per_step = TimePerStep()

# Every 10000 steps, save the agent and the collected statistics under /tmp/.
# `BSON.@save` uses the variable names as keys, so the two statistics hooks
# are referenced by their global names here.
checkpoint_hook = DoEveryNStep(10000) do _t, current_agent, _env
    out_dir = "/tmp/"
    RLCore.save(out_dir, current_agent)
    BSON.@save joinpath(out_dir, "stats.bson") total_reward_per_episode time_per_step
end

hook = ComposedHook(total_reward_per_episode, time_per_step, checkpoint_hook)

# NOTE(review): the name `exp` is the same as `Base.exp`; presumably harmless in
# this script, but confirm nothing later in the session needs the math function.
exp = Experiment(agent, env, stop_condition, hook, "jrl_dqn")
run(exp)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment