@tcfuji
Last active July 8, 2021 19:11
Deep Q-Learning in Julia with OpenSpiel
"""
Deep Q-Learning for turn-based board games in Julia using the OpenSpiel package. Just install packages and run 'julia dql.jl'.
Trains against a random agent (usually enough to train a decently strong agent for games like breakthrough). Based on work I did
in pytorch at PNNL. Learned Flux, OpenSpiel.jl for this over a weekend.
"""
using OpenSpiel     # game environments and state/action API
using StatsBase     # sample (replay batches), mean (win rate)
using Flux          # neural network layers, loss, and training
using BSON          # model checkpointing
using ProgressBars  # ProgressBar over training episodes
abstract type Learner end
# Online Q-network plus a periodically synced target network.
mutable struct QLearner <: Learner
    model
    target_model
end
# Fixed-capacity experience replay; once full, new transitions overwrite the
# oldest entries via the 1-based circular index next_entry_index.
mutable struct ReplayBuffer
    buffer::Array
    capacity::Int
    next_entry_index::Int
end
Base.length(rb::ReplayBuffer) = length(rb.buffer)
Base.iterate(rb::ReplayBuffer, s...) = iterate(rb.buffer, s...)
# ε-greedy action selection for the learner (player 0). ε anneals linearly from
# 1.0 down to 0.1 over the training episodes; greedy=true forces ε = 0.
function epsilon_greedy(learner::QLearner, state, id::Int32, current_episode::Int,
                        train_eps::Int, greedy::Bool = false)
    epsilon = greedy ? 0.0 : max(1 - (current_episode / train_eps), 0.1)
    if rand() < epsilon
        action = rand(legal_actions(state, id))
    else
        state_vec = observation_tensor(state, id)
        q_vals = learner.model(Flux.unsqueeze(state_vec, 2))
        # Push illegal actions' Q-values far below any legal value before argmax.
        illegal_weights = -10000 * (1 .- legal_actions_mask(state, id))
        legal_q_vals = illegal_weights .+ q_vals
        legal_q_vals = dropdims(legal_q_vals; dims=2)
        action = argmax(legal_q_vals) - 1  # OpenSpiel action ids are 0-based
    end
    return action
end
# Add a (single-element vector of) transition to the replay buffer; once the
# buffer is at capacity, overwrite the oldest entry (circular, 1-based wrap).
function Base.append!(rb::ReplayBuffer, t)
    if length(rb) < rb.capacity
        append!(rb.buffer, t)
    else
        rb.buffer[rb.next_entry_index] = t[1]
        rb.next_entry_index = rb.next_entry_index % rb.capacity + 1
    end
end
# Play one full game. Player 0 acts ε-greedily from the learner's Q-network; the
# opponent plays uniformly at random. Unless greedy (evaluation mode), every
# transition is stored in the replay buffer. Returns the final returns vector.
function rollout(game, learner::QLearner, rb::ReplayBuffer, current_ep::Int,
                 train_ep::Int, greedy::Bool = false)
    state = new_initial_state(game)
    prev_state = nothing
    while !is_terminal(state)
        current_id = current_player(state)
        if current_id == 0
            actions = epsilon_greedy(learner, state, current_id, current_ep, train_ep, greedy)
        else
            actions = rand(legal_actions(state))
        end
        prev_state = deepcopy(state)
        apply_action(state, actions)
        rts = returns(state)
        if !greedy
            state_vec = observation_tensor(prev_state, current_id)
            next_state_vec = observation_tensor(state, current_id)
            t = (state=state_vec, next_state=next_state_vec, action=actions, reward=rts[1], is_done=is_terminal(state))
            append!(rb, [t])
        end
    end
    return returns(state)
end
# One DQN update on a sampled minibatch: Bellman targets come from the (frozen)
# target network; the loss compares the online network's Q-values for the actions
# actually taken against those targets.
function update(game, learner::QLearner, batch::Array, γ::Float64, opt, loss_func = Flux.mse)
    states = Flux.batch([t.state for t in batch])
    next_states = Flux.batch([t.next_state for t in batch])
    actions = Flux.batch([t.action for t in batch])
    rewards = Flux.batch([t.reward for t in batch])
    is_dones = Flux.batch([t.is_done for t in batch])
    target_q_vals = learner.target_model(next_states)
    max_next_q = dropdims(maximum(target_q_vals, dims=1); dims=1)
    # Bellman target: r + γ max_a' Q_target(s', a'), with the bootstrap term zeroed for terminal states.
    target = rewards .+ (1 .- is_dones) .* γ .* max_next_q
    # One-hot mask selecting each sample's taken action from the Q-value matrix.
    action_indices = Flux.onehotbatch(actions, 0:(num_distinct_actions(game) - 1))
    ps = Flux.params(learner.model)
    loss(s, t) = loss_func(learner.model(s)[action_indices], t)
    Flux.train!(loss, ps, [(states, target)], opt)
end
# Main training loop: roll out episodes against the random opponent, run a DQN
# update every train_interval episodes, and every eval_interval episodes play 100
# greedy evaluation games, checkpoint the model, and sync the target network.
function train_dqn(game, model, train_eps=5000, train_interval=10, eval_interval=500, γ=0.95, opt=ADAM(1e-4))
    rb = ReplayBuffer([], Int(1e4), 1)
    learner = QLearner(model, deepcopy(model))
    for i in ProgressBar(1:train_eps)
        rts = rollout(game, learner, rb, i, train_eps)
        if i % train_interval == 0
            batch = sample(rb.buffer, 64)
            update(game, learner, batch, γ, opt)
        end
        if i % eval_interval == 0
            eval_wins = Float32[]
            for j in 1:100
                # Greedy rollouts: no exploration, and nothing is added to the buffer.
                rts = rollout(game, learner, rb, j, train_eps, true)
                push!(eval_wins, rts[1] > 0.0 ? 1.0f0 : 0.0f0)
            end
            println("Win rate: $(mean(eval_wins))")
            bson("model.bson", model=learner.model)
            learner.target_model = deepcopy(learner.model)
        end
    end
    println("Training done!")
end
name = "breakthrough"
game = load_game(name)
info_state_length = prod(observation_tensor_shape(game))
model = Chain(Dense(info_state_length, 128),
Dense(128, 128),
BatchNorm(128, relu),
Dense(128, 128),
BatchNorm(128, relu),
Dense(128, 128),
BatchNorm(128, relu),
Dense(128, 128),
BatchNorm(128, relu),
Dense(128, num_distinct_actions(game)))
train_dqn(game, model, Int(1e5))
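# Usage sketch (an illustrative example, not something the script runs on its own):
# once "model.bson" has been written by train_dqn above, the checkpoint can be
# reloaded and a single greedy evaluation game played like this:
#
#   using BSON: @load
#   @load "model.bson" model                   # restores the Chain saved above
#   trained = QLearner(model, deepcopy(model))
#   dummy_rb = ReplayBuffer([], 1, 1)          # unused: greedy rollouts store nothing
#   final_returns = rollout(game, trained, dummy_rb, 1, 1, true)
#   println("Greedy game returns: $final_returns")  # returns[1] > 0 means player 0 won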