using ReinforcementLearningBase
using Random
using ReinforcementLearning
using Flux

const MAX_MOVES = 2000

"""
ReinforcementLearning environment for the 15 Puzzle game (the board size is configurable;
the default constructor below builds the 3x3 variant).
Because Julia stores arrays in column-major order, the internal grid is the transpose of
the human-readable board.
"""
mutable struct GameEnv <: AbstractEnv
    size::Int                 # side length of the board
    grid::Array{Int, 2}       # size x size grid, 0 marks the blank tile
    steps::Int                # number of moves taken in the current episode
    action::CartesianIndex    # last action, as an index into the grid
    changed::Bool             # whether the last action actually moved a tile
end
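# Illustration (not part of the original gist): for the solved 3x3 board the
# human-readable layout and the internally stored grid are transposes of each
# other, so the column-major scan of the grid reads the human board row by row:
#
#   human = [1 2 3; 4 5 6; 7 8 0]       # goal position as a person would read it
#   grid  = collect(human')             # what GameEnv stores internally
#   vec(grid) == [1, 2, 3, 4, 5, 6, 7, 8, 0]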
# every cell of the grid is a possible action: the linear index of the tile to push into the blank
function RLBase.get_actions(env::GameEnv)
    collect(1:env.size * env.size)
end
# count the inversions among the non-blank tiles (the standard solvability measure)
function count_inversions(env::GameEnv)::Int
    inv_count = 0
    n = env.size * env.size
    for i in 1:n
        for j in (i + 1):n
            if env.grid[i] > 0 && env.grid[j] > 0 && env.grid[i] > env.grid[j]
                inv_count += 1
            end
        end
    end
    inv_count
end
# column of the blank tile, counted from the right; since the grid is transposed,
# this is the blank's row counted from the bottom of the human-readable board
x_pos(env::GameEnv) = env.size - findfirst(iszero, env.grid)[2] + 1
# standard 15-puzzle solvability test: on odd-sized boards the permutation must
# have an even number of inversions; on even-sized boards the parity of the
# inversion count and of the blank's row (counted from the bottom) must differ
function issolvable(env::GameEnv)::Bool
    inv_count = count_inversions(env)
    if isodd(env.size)
        return iseven(inv_count)
    else
        pos = x_pos(env)
        if isodd(pos)
            return iseven(inv_count)
        else
            return isodd(inv_count)
        end
    end
end
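# Quick sanity check (illustrative, not part of the original gist): the solved
# 3x3 board has zero inversions, which is even, so it is reported as solvable,
# while swapping two adjacent tiles flips the parity and makes it unsolvable:
#
#   env = GameEnv(3)
#   env.grid = collect([1 2 3; 4 5 6; 7 8 0]')   # solved board, 0 inversions
#   issolvable(env)                              # true
#   env.grid = collect([2 1 3; 4 5 6; 7 8 0]')   # one inversion
#   issolvable(env)                              # false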
function RLBase.get_terminal(env::GameEnv)
    if env.steps > MAX_MOVES
        return true
    end
    # solved: blank in the last cell and the remaining tiles in ascending order
    if env.grid[end] != 0
        return false
    end
    issorted(env.grid[1:end-1])
end
# +1 for solving the puzzle, -1 for running out of moves,
# a small penalty for an action that did not move any tile, 0 otherwise
function RLBase.get_reward(env::GameEnv)
    if env.steps > MAX_MOVES
        return -1
    end
    if get_terminal(env)
        return 1
    end
    if env.changed
        return 0
    else
        return -1 / MAX_MOVES
    end
end
# flatten the grid for the agent; transposing first gives a row-major view of the stored grid
RLBase.get_state(env::GameEnv) = vec(env.grid')
# reshuffle the tiles until a solvable configuration comes up
function RLBase.reset!(env::GameEnv)
    s = env.size
    env.grid = reshape(shuffle(0:s*s-1), s, s)
    while !issolvable(env)
        env.grid = reshape(shuffle(0:s*s-1), s, s)
    end
    env.steps = 0
    env.action = CartesianIndex(1, 1)
    env.changed = false
    nothing
end
# convenience constructor; the initial grid is not guaranteed to be solvable
# until reset! has been called
function GameEnv(s::Int=3)
    grid = reshape(shuffle(0:s*s-1), s, s)
    GameEnv(s, grid, 0, CartesianIndex(1, 1), false)
end
# apply action `a`: interpret it as the linear index of a tile and slide that
# tile into the blank if the two are adjacent; otherwise nothing happens
function (env::GameEnv)(a)
    env.changed = false
    if env.steps > MAX_MOVES
        return
    end
    env.steps += 1
    action = CartesianIndices(env.grid)[a]
    env.action = action
    cell = env.grid[action]
    if cell == 0
        # the blank itself was selected, nothing to move
        return
    end
    # look for the blank among the four neighbours and swap if found
    for offset in ((-1, 0), (1, 0), (0, -1), (0, 1))
        field = action + CartesianIndex(offset)
        if checkbounds(Bool, env.grid, field) && env.grid[field] == 0
            env.grid[field] = cell
            env.grid[action] = 0
            env.changed = true
            return
        end
    end
end
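# Manual interaction sketch (illustrative, not part of the original gist):
# stepping the environment by hand shows the effect of a single action.
#
#   env = GameEnv(3)
#   reset!(env)
#   env(1)                      # try to slide the tile at linear index 1
#   env.changed                 # true only if that tile was next to the blank
#   get_state(env)              # flattened view handed to the agent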
rng = MersenneTwister(123);
env = GameEnv(3)
ns, na = length(get_state(env)), length(get_actions(env))

# simple debugging hook that prints every action the agent is about to take
struct MyHook <: AbstractHook end

function (hook::MyHook)(::PreActStage, agent, env, action)
    @show(action)
end
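# MyHook is not used in the training call below; as a sketch, it could be passed
# in place of the TotalRewardPerEpisode hook to trace the chosen actions, e.g.:
#
#   run(DynamicStyle(env), NumAgentStyle(env), agent, env, StopAfterEpisode(1), MyHook())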
agent = Agent(
    policy = QBasedPolicy(
        learner = BasicDQNLearner(
            approximator = NeuralNetworkApproximator(
                model = Chain(
                    Dense(ns, 128*2, relu; initW = glorot_uniform(rng)),
                    Dense(128*2, 128*2, relu; initW = glorot_uniform(rng)),
                    Dense(128*2, na; initW = glorot_uniform(rng)),
                ) |> cpu,
                optimizer = ADAM(),
            ),
            batch_size = 32,
            min_replay_history = 100,
            loss_func = huber_loss,
            rng = rng,
        ),
        explorer = EpsilonGreedyExplorer(
            kind = :exp,
            ϵ_stable = 0.01,
            decay_steps = 1000,
            rng = rng,
            is_break_tie = true,
        ),
    ),
    trajectory = CircularCompactSARTSATrajectory(
        capacity = 1000,
        state_type = Int,
        state_size = ns,
    ),
)
# train for 1000 episodes, recording the total reward of each episode
hook = TotalRewardPerEpisode()
run(DynamicStyle(env), NumAgentStyle(env), agent, env, StopAfterEpisode(1000), hook)
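# After training, the per-episode returns can be inspected from the hook
# (illustrative; assumes the `rewards` field of TotalRewardPerEpisode):
#
#   using Statistics
#   mean(hook.rewards[max(1, end - 99):end])   # average return over the last 100 episodes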