Last active
August 19, 2020 08:56
-
-
Save PaulDebus/dcffd3d256fef2170cd328a8c9867d47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using ReinforcementLearningBase | |
using Random | |
using ReinforcementLearning | |
using Flux | |
""" | |
ReinforcementLearning environment for the 15 Puzzle game. | |
As Julia has its arrays in column major form, the game is transposed from the human version | |
""" | |
# Episode cap: an episode is forcibly terminated after this many steps.
const MAX_MOVES = 2000
"""
    GameEnv <: AbstractEnv

Mutable state of the sliding-puzzle environment.

Fields:
- `size`:    side length of the square grid (3 for the 8-puzzle, 4 for the 15-puzzle)
- `grid`:    `size × size` matrix holding tiles `1:size^2-1` and `0` for the blank
- `steps`:   number of actions taken in the current episode
- `action`:  grid position of the most recent action
- `changed`: whether the last action actually moved a tile
"""
mutable struct GameEnv <: AbstractEnv
    size::Int
    # Matrix{Int} is the concrete alias of Array{Int,2}; kept concrete for specialization.
    grid::Matrix{Int}
    steps::Int
    # Concretely parameterized (was bare CartesianIndex, an abstract type that
    # would box the field and defeat specialization).
    action::CartesianIndex{2}
    changed::Bool
end
# The action space: one action per cell, addressed by linear index into the grid.
RLBase.get_actions(env::GameEnv) = collect(1:env.size^2)
# Count inversions in the linearized grid: pairs (i, j), i < j, where tile i is
# larger than tile j. The blank (0) is excluded. Used for the solvability test.
function count_inversions(env::GameEnv)::Int
    n = env.size * env.size
    total = 0
    for i in 1:n, j in i:n
        a, b = env.grid[i], env.grid[j]
        if a > 0 && b > 0 && a > b
            total += 1
        end
    end
    return total
end
# Row of the blank tile counted from the bottom (1-based), as required by the
# even-width solvability rule. Uses the predicate form of `findfirst`, which
# returns the same CartesianIndex as the original `findfirst(iszero.(grid))`
# but without allocating a temporary BitMatrix.
x_pos(env::GameEnv) = env.size - findfirst(iszero, env.grid)[2] + 1
# Standard sliding-puzzle solvability criterion:
# - odd grid width: solvable iff the inversion count is even;
# - even grid width: parity of inversions must be opposite to the parity of the
#   blank's row counted from the bottom.
function issolvable(env::GameEnv)::Bool
    inversions = count_inversions(env)
    isodd(env.size) && return iseven(inversions)
    return isodd(x_pos(env)) ? iseven(inversions) : isodd(inversions)
end
# Episode ends when the move budget is exhausted or the puzzle is solved
# (tiles 1..n²-1 in ascending linear order with the blank in the last slot).
function RLBase.get_terminal(env::GameEnv)
    if env.steps > MAX_MOVES
        return true
    end
    # Cheap early-out: in the solved state the blank (0) occupies the LAST
    # linear slot. The original checked `env.grid[3]`, which can never be 0 in
    # a state where `issorted(grid[1:end-1])` holds, so the solved state was
    # never detected.
    if env.grid[end] != 0
        return false
    end
    # @view avoids copying the grid just to test sortedness.
    return issorted(@view env.grid[1:end-1])
end
# Reward shaping:
#   -1            when the move budget is exhausted,
#   +1            on solving the puzzle,
#    0            for a legal (tile-moving) step,
#   -1/MAX_MOVES  small penalty for a no-op action.
# All branches now return Float64 (the original mixed Int and Float64,
# making the function type-unstable).
function RLBase.get_reward(env::GameEnv)
    if env.steps > MAX_MOVES
        return -1.0
    end
    if get_terminal(env)
        return 1.0
    end
    return env.changed ? 0.0 : -1.0 / MAX_MOVES
end
# Flatten the grid for the agent. The adjoint transposes back to the human
# (row-major) ordering before vectorizing, per the file docstring's note that
# the stored grid is the transpose of the human layout.
RLBase.get_state(env::GameEnv) = vec(env.grid')
# Re-randomize the environment in place: rejection-sample shuffled grids until
# a solvable configuration comes up, then clear the per-episode bookkeeping.
# NOTE(review): uses the global RNG via `shuffle` — resets are not reproducible
# from the `rng` seeded below; confirm if intended.
function RLBase.reset!(env::GameEnv)
    s = env.size
    while true
        env.grid = reshape(shuffle(0:s*s-1), s, s)
        issolvable(env) && break
    end
    env.steps = 0
    env.action = CartesianIndex(1, 1)
    env.changed = false
    nothing
end
"""
    GameEnv(s::Int = 3)

Construct an `s × s` puzzle environment with a random **solvable** starting
grid. The original constructor skipped the solvability check that `reset!`
performs, so roughly half of freshly constructed games could not be won.
"""
function GameEnv(s::Int=3)
    env = GameEnv(s, reshape(shuffle(0:s*s-1), s, s), 0, CartesianIndex(1, 1), false)
    # Same rejection sampling as reset!: redraw until the grid is solvable.
    while !issolvable(env)
        env.grid = reshape(shuffle(0:s*s-1), s, s)
    end
    return env
end
# Apply action `a` (a linear grid index): if the addressed tile has the blank
# as an orthogonal neighbour, slide it there and set `changed`; otherwise the
# step is a no-op (but still counts against the move budget).
function (env::GameEnv)(a)
    env.changed = false
    env.steps > MAX_MOVES && return
    env.steps += 1
    pos = CartesianIndices(env.grid)[a]
    env.action = pos
    tile = env.grid[pos]
    tile == 0 && return  # selecting the blank itself never moves anything
    neighbours = (CartesianIndex(-1, 0), CartesianIndex(1, 0),
                  CartesianIndex(0, -1), CartesianIndex(0, 1))
    for δ in neighbours
        target = pos + δ
        if checkbounds(Bool, env.grid, target) && iszero(env.grid[target])
            env.grid[target] = tile
            env.grid[pos] = 0
            env.changed = true
            return
        end
    end
end
# Seed a dedicated RNG so the network initialisation below is reproducible.
# NOTE(review): the environment's `shuffle` calls use the global RNG, so
# episode layouts are NOT covered by this seed — confirm if intended.
rng = MersenneTwister(123);
env = GameEnv(3)
# State and action dimensions, used to size the Q-network input/output layers.
ns, na = length(get_state(env)), length(get_actions(env))
# Debugging hook: prints each action the agent is about to take (see the
# PreActStage method below). Not attached to the training run by default.
struct MyHook <: AbstractHook
end
# Invoked by the RL run loop just before each action is applied; logs the
# chosen action index. (`@show` prints the variable name, so `action` must
# keep its name for identical output.)
function (hook::MyHook)(::PreActStage, agent, env, action)
    @show(action)
end
# DQN agent: a 3-layer MLP Q-network with an exponentially decaying
# epsilon-greedy explorer and a circular SARTSA replay buffer.
agent = Agent(
    policy = QBasedPolicy(
        learner = BasicDQNLearner(
            approximator = NeuralNetworkApproximator(
                # ns inputs (flattened grid) -> two hidden layers of 256 -> na Q-values.
                model = Chain(
                    Dense(ns, 128*2, relu; initW = glorot_uniform(rng)),
                    Dense(128*2, 128*2, relu; initW = glorot_uniform(rng)),
                    Dense(128*2, na; initW = glorot_uniform(rng)),
                ) |> cpu,
                optimizer = ADAM(),
            ),
            batch_size = 32,
            # Learning starts only after this many transitions are stored.
            min_replay_history = 100,
            # Huber loss is less sensitive to outlier TD errors than MSE.
            loss_func = huber_loss,
            rng = rng,
        ),
        explorer = EpsilonGreedyExplorer(
            kind = :exp,          # exponential decay of epsilon
            ϵ_stable = 0.01,      # long-run exploration rate
            decay_steps = 1000,
            rng = rng,
            is_break_tie = true   # random choice among equal-valued actions
        ),
    ),
    trajectory = CircularCompactSARTSATrajectory(
        capacity = 1000,
        state_type = Int,
        state_size = ns,
    ),
)
# Track the return of every episode, then train for 1000 episodes.
hook = TotalRewardPerEpisode()
run(DynamicStyle(env), NumAgentStyle(env), agent, env, StopAfterEpisode(1000), hook)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment