Deep Q-Learning in Julia with OpenSpiel
""" | |
Deep Q-Learning for turn-based board games in Julia using the OpenSpiel package. Just install packages and run 'julia dql.jl'. | |
Trains against a random agent (usually enough to train a decently strong agent for games like breakthrough). Based on work I did | |
in pytorch at PNNL. Learned Flux, OpenSpiel.jl for this over a weekend. | |
""" | |
using OpenSpiel
using StatsBase
using Flux
using BSON
using ProgressBars
abstract type Learner end

# Holds the online network and a periodically synced target network.
mutable struct QLearner <: Learner
    model
    target_model
end

# Fixed-capacity circular buffer of transitions.
mutable struct ReplayBuffer
    buffer::Array
    capacity::Int
    next_entry_index::Int
end

Base.length(rb::ReplayBuffer) = length(rb.buffer)
Base.iterate(rb::ReplayBuffer) = iterate(rb.buffer)
function epsilon_greedy(learner::QLearner, state, id::Int32, current_episode::Int,
                        train_eps::Int, greedy::Bool = false)
    # Anneal epsilon linearly from 1.0 down to a floor of 0.1; act greedily when evaluating.
    epsilon = greedy ? 0 : max(1 - (current_episode / train_eps), 0.1)
    if rand() < epsilon
        action = rand(legal_actions(state, id))
    else
        state_vec = observation_tensor(state, id)
        q_vals = learner.model(Flux.unsqueeze(state_vec, 2))
        # Push illegal actions far below any legal Q-value before taking the argmax.
        illegal_weights = -10000 * (1 .- legal_actions_mask(state, id))
        legal_q_vals = illegal_weights .+ q_vals
        legal_q_vals = dropdims(legal_q_vals; dims=2)
        # OpenSpiel action ids are 0-based; Julia's argmax is 1-based.
        action = argmax(legal_q_vals) - 1
    end
    return action
end
function append!(rb::ReplayBuffer, t)
    if length(rb) < rb.capacity
        Base.append!(rb.buffer, t)
    else
        # Buffer is full: overwrite the oldest entry, cycling through slots 1:capacity.
        rb.buffer[rb.next_entry_index] = t[1]
        rb.next_entry_index = mod1(rb.next_entry_index + 1, rb.capacity)
    end
end
function rollout(game, learner::QLearner, rb::ReplayBuffer, current_ep::Int,
                 train_eps::Int, greedy::Bool = false)
    state = new_initial_state(game)
    prev_state = nothing
    while !is_terminal(state)
        current_id = current_player(state)
        if current_id == 0
            # Player 0 is the learning agent.
            action = epsilon_greedy(learner, state, current_id, current_ep, train_eps, greedy)
        else
            # The opponent plays uniformly at random.
            action = rand(legal_actions(state))
        end
        prev_state = deepcopy(state)
        apply_action(state, action)
        rts = returns(state)
        if !greedy
            # Record the transition from the mover's perspective; the reward is
            # always player 0's return, since that is the agent being trained.
            state_vec = observation_tensor(prev_state, current_id)
            next_state_vec = observation_tensor(state, current_id)
            t = (state=state_vec, next_state=next_state_vec, action=action,
                 reward=rts[1], is_done=is_terminal(state))
            append!(rb, [t])
        end
    end
    return returns(state)
end
function update(game, learner::QLearner, batch::Array, γ::Float64, opt, loss_func = Flux.mse)
    states = Flux.batch([t.state for t in batch])
    next_states = Flux.batch([t.next_state for t in batch])
    actions = Flux.batch([t.action for t in batch])
    rewards = Flux.batch([t.reward for t in batch])
    is_dones = Flux.batch([t.is_done for t in batch])
    # One-step TD target: r + γ * max_a′ Q_target(s′, a′), zeroed at terminal states.
    target_q_vals = learner.target_model(next_states)
    max_next_q = dropdims(maximum(target_q_vals, dims=1); dims=1)
    target = rewards .+ (1 .- is_dones) .* γ .* max_next_q
    # A one-hot boolean mask picks out the Q-value of the action actually taken.
    action_indices = Flux.onehotbatch(actions, 0:(num_distinct_actions(game) - 1))
    ps = Flux.params(learner.model)
    loss(s, t) = loss_func(learner.model(s)[action_indices], t)
    Flux.train!(loss, ps, [(states, target)], opt)
end
function train_dqn(game, model, train_eps=5000, train_interval=10, eval_interval=500,
                   γ=0.95, opt=ADAM(1e-4))
    rb = ReplayBuffer([], Int(1e4), 1)
    learner = QLearner(model, deepcopy(model))
    for i in ProgressBar(1:train_eps)
        rollout(game, learner, rb, i, train_eps)
        if i % train_interval == 0
            batch = sample(rb.buffer, 64)
            update(game, learner, batch, γ, opt)
        end
        if i % eval_interval == 0
            # Evaluate greedily over 100 games against the random opponent.
            eval_wins = Float32[]
            for j in 1:100
                rts = rollout(game, learner, rb, j, train_eps, true)
                push!(eval_wins, rts[1] > 0.0 ? 1.0 : 0.0)
            end
            println("Win rate: $(mean(eval_wins))")
            # Checkpoint the online network and sync the target network.
            bson("model.bson", model=learner.model)
            learner.target_model = deepcopy(learner.model)
        end
    end
    println("Training done!")
end
name = "breakthrough" | |
game = load_game(name) | |
info_state_length = prod(observation_tensor_shape(game)) | |
model = Chain(Dense(info_state_length, 128), | |
Dense(128, 128), | |
BatchNorm(128, relu), | |
Dense(128, 128), | |
BatchNorm(128, relu), | |
Dense(128, 128), | |
BatchNorm(128, relu), | |
Dense(128, 128), | |
BatchNorm(128, relu), | |
Dense(128, num_distinct_actions(game))) | |
train_dqn(game, model, Int(1e5)) |
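
# A minimal sketch of reloading the checkpoint for one greedy evaluation game
# (assumption: the script above has been run, so "model.bson" exists and holds the
# trained Chain under the key `model`; this snippet is illustrative, not part of
# the training run):
#
#   using BSON: @load
#   @load "model.bson" model
#   state = new_initial_state(game)
#   while !is_terminal(state)
#       id = current_player(state)
#       action = id == 0 ?
#           epsilon_greedy(QLearner(model, model), state, id, 1, 1, true) :
#           rand(legal_actions(state))
#       apply_action(state, action)
#   end
#   println(returns(state))  # a positive first entry means player 0 (the agent) won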