Deep Q-learning example to play Doom with PyTorch
#!/usr/bin/env python3
"""
Deep Q-learning example to play Doom with PyTorch.

Brandon L Morris
Adapted from work by E. Culurciello
"""

import itertools
from random import sample, randint, random
from time import time, sleep

from absl import app, flags
import numpy as np
from skimage.transform import resize
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import trange
from vizdoom import DoomGame, Mode, ScreenFormat, ScreenResolution

FLAGS = flags.FLAGS

# Each chosen action is repeated for this many game tics
frame_repeat = 12
# Height x width of the downsampled grayscale frames fed to the network
resolution = (30, 45)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Default configuration file path
default_config_file_path = "/home/bmorris/code/shared-code/ViZDoom/scenarios/simpler_basic.cfg"

def preprocess(img):
    """Downsample a raw screen buffer to `resolution` and convert to float32."""
    return torch.from_numpy(resize(img, resolution).astype(np.float32))


def game_state(game):
    return preprocess(game.get_state().screen_buffer)

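# Shape note (assuming the GRAY8 / RES_640X480 settings applied in
# initialize_vizdoom() below): the raw screen buffer is a 480x640 uint8 array,
# and preprocess() returns a 30x45 float32 tensor; batch and channel
# dimensions are added later, so the network always sees (N, 1, 30, 45) input.
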
class ReplayMemory:
    """Fixed-size circular buffer of (s1, a, s2, terminal, reward) transitions."""

    def __init__(self, capacity):
        channels = 1
        state_shape = (capacity, channels, *resolution)
        self.s1 = torch.zeros(state_shape, dtype=torch.float32, device=device)
        self.s2 = torch.zeros(state_shape, dtype=torch.float32, device=device)
        self.a = torch.zeros(capacity, dtype=torch.long, device=device)
        self.r = torch.zeros(capacity, dtype=torch.float32, device=device)
        self.isterminal = torch.zeros(capacity, dtype=torch.float32, device=device)
        self.capacity = capacity
        self.size = 0
        self.pos = 0

    def add_transition(self, s1, action, s2, isterminal, reward):
        idx = self.pos
        self.s1[idx, 0, :, :] = s1
        self.a[idx] = action
        if not isterminal:
            self.s2[idx, 0, :, :] = s2
        self.isterminal[idx] = isterminal
        self.r[idx] = reward
        # Overwrite the oldest transition once the buffer is full
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def get_sample(self, size):
        idx = sample(range(0, self.size), size)
        return (self.s1[idx], self.a[idx], self.s2[idx], self.isterminal[idx],
                self.r[idx])

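# Minimal usage sketch for the buffer above (shapes assumed from `resolution`;
# not part of the training loop):
#   memory = ReplayMemory(capacity=100)
#   frame = torch.zeros(resolution)
#   memory.add_transition(frame, action=0, s2=frame, isterminal=0., reward=1.0)
#   s1, a, s2, isterminal, r = memory.get_sample(1)  # size must be <= memory.size
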
class QNet(nn.Module):
    def __init__(self, available_actions_count):
        super(QNet, self).__init__()
        # Input is (N, 1, 30, 45); the comments track the resulting feature map sizes
        self.conv1 = nn.Conv2d(1, 8, kernel_size=6, stride=3)  # 8x9x14
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)  # 8x4x6 = 192
        self.fc1 = nn.Linear(192, 128)
        self.fc2 = nn.Linear(128, available_actions_count)
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.parameters(), FLAGS.learning_rate)
        self.memory = ReplayMemory(capacity=FLAGS.replay_memory)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 192)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def get_best_action(self, state):
        q = self(state)
        _, index = torch.max(q, 1)
        return index

    def train_step(self, s1, target_q):
        output = self(s1)
        loss = self.criterion(output, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss

    def learn_from_memory(self):
        """Sample a batch from replay memory and take one Q-learning step."""
        if self.memory.size < FLAGS.batch_size:
            return
        s1, a, s2, isterminal, r = self.memory.get_sample(FLAGS.batch_size)
        # Bootstrapped value of the next state: max over actions of Q(s2, a')
        q = self(s2).detach()
        q2, _ = torch.max(q, dim=1)
        # Start from the network's own predictions and overwrite only the
        # entries for the actions that were actually taken
        target_q = self(s1).detach()
        idxs = (torch.arange(target_q.shape[0]), a)
        target_q[idxs] = r + FLAGS.discount * (1 - isterminal) * q2
        self.train_step(s1, target_q)

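# The update in learn_from_memory() is the standard one-step Q-learning target,
# written out here only for reference:
#   target_q[i, a_i] = r_i + discount * (1 - terminal_i) * max_a' Q(s2_i, a')
# Because every other entry of target_q equals the network's current output,
# the MSE loss only penalizes the action that was actually taken.
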
def find_eps(epoch):
    """Balance exploration and exploitation as we keep learning"""
    start, end = 1.0, 0.1
    const_epochs, decay_epochs = .1 * FLAGS.epochs, .6 * FLAGS.epochs
    if epoch < const_epochs:
        return start
    elif epoch > decay_epochs:
        return end
    # Linear decay
    progress = (epoch - const_epochs) / (decay_epochs - const_epochs)
    return start - progress * (start - end)

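# With the default --epochs=20, this schedule holds epsilon at 1.0 through
# epoch 2, decays it linearly to 0.1 by epoch 12, and keeps it at 0.1
# afterwards (for example, find_eps(7) == 0.55).
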
def perform_learning_step(epoch, game, model, actions):
    s1 = game_state(game)
    if random() <= find_eps(epoch):
        # Explore: pick a random action
        a = torch.tensor(randint(0, len(actions) - 1)).long()
    else:
        # Exploit: query the network. Add batch and channel dims on a copy so
        # s1 itself keeps the 2-D shape expected by the replay memory below.
        state = s1.reshape([1, 1, *resolution]).to(device)
        a = model.get_best_action(state)
    reward = game.make_action(actions[a], frame_repeat)
    if game.is_episode_finished():
        isterminal, s2 = 1., None
    else:
        isterminal = 0.
        s2 = game_state(game)
    model.memory.add_transition(s1, a, s2, isterminal, reward)
    model.learn_from_memory()

def initialize_vizdoom(config):
    game = DoomGame()
    game.load_config(config)
    game.set_window_visible(False)
    game.set_mode(Mode.PLAYER)
    game.set_screen_format(ScreenFormat.GRAY8)
    game.set_screen_resolution(ScreenResolution.RES_640X480)
    game.init()
    return game

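# Note: these setters run after load_config(), so they are expected to take
# precedence over the corresponding .cfg entries (window visibility, mode,
# screen format, and resolution).
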
def train(game, model, actions):
    time_start = time()
    print("Saving the network weights to:", FLAGS.save_path)
    for epoch in range(FLAGS.epochs):
        print(f'Epoch {epoch+1}')
        episodes_finished = 0
        scores = np.array([])
        game.new_episode()
        for _ in trange(FLAGS.iters, leave=False):
            perform_learning_step(epoch, game, model, actions)
            if game.is_episode_finished():
                score = game.get_total_reward()
                scores = np.append(scores, score)
                game.new_episode()
                episodes_finished += 1
        print(f'Completed {episodes_finished} episodes')
        if episodes_finished > 0:
            print(f'Mean: {scores.mean():.1f} +/- {scores.std():.1f}')
        print("Testing...")
        test(FLAGS.test_episodes, game, model, actions)
        # Checkpoint after every epoch
        torch.save(model, FLAGS.save_path)
    print(f'Total elapsed time: {(time() - time_start) / 60.0:.2f} minutes')

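# Note: saving the whole model object (rather than just model.state_dict())
# also pickles the attached ReplayMemory and optimizer, so the checkpoint is
# comparatively large and torch.load() needs these class definitions in scope.
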
def test(iters, game, model, actions):
    scores = np.array([])
    for _ in trange(iters, leave=False):
        game.new_episode()
        while not game.is_episode_finished():
            state = game_state(game)
            state = state.reshape([1, 1, resolution[0], resolution[1]])
            a_idx = model.get_best_action(state.to(device))
            game.make_action(actions[a_idx], frame_repeat)
        r = game.get_total_reward()
        scores = np.append(scores, r)
    print(f'Results: mean: {scores.mean():.1f} +/- {scores.std():.1f}')

def watch_episodes(game, model, actions):
    # Re-initialize the game with a visible window so we can watch the agent
    game.set_window_visible(True)
    game.set_mode(Mode.ASYNC_PLAYER)
    game.init()
    for episode in range(FLAGS.watch_episodes):
        # The string argument names the file the episode recording is written to
        game.new_episode(f'episode-{episode}')
        while not game.is_episode_finished():
            state = game_state(game)
            state = state.reshape([1, 1, resolution[0], resolution[1]])
            a_idx = model.get_best_action(state.to(device))
            # Instead of make_action, set the action and let the game advance
            # at its own pace so the playback is watchable
            game.set_action(actions[a_idx])
            for _ in range(frame_repeat):
                game.advance_action()
        sleep(1.0)
        score = game.get_total_reward()
        print(f'Total score: {score}')

def main(_):
    game = initialize_vizdoom(FLAGS.config)
    # One binary button per available control; enumerate every combination
    n = game.get_available_buttons_size()
    actions = [list(a) for a in itertools.product([0, 1], repeat=n)]
    if FLAGS.load_model:
        print(f'Loading model from: {FLAGS.save_path}')
        model = torch.load(FLAGS.save_path).to(device)
    else:
        model = QNet(len(actions)).to(device)
    if not FLAGS.skip_training:
        print("Starting the training!")
        train(game, model, actions)
    game.close()
    print("======================================")
    watch_episodes(game, model, actions)

if __name__ == '__main__':
    flags.DEFINE_integer('batch_size', 64, 'Batch size')
    flags.DEFINE_float('learning_rate', 0.00025, 'Learning rate')
    flags.DEFINE_float('discount', 0.99, 'Discount factor')
    flags.DEFINE_integer('replay_memory', 10000, 'Replay memory capacity')
    flags.DEFINE_integer('epochs', 20, 'Number of epochs')
    flags.DEFINE_integer('iters', 2000, 'Iterations per epoch')
    flags.DEFINE_integer('watch_episodes', 10, 'Trained episodes to watch')
    flags.DEFINE_integer('test_episodes', 100, 'Episodes to test with')
    flags.DEFINE_string('config', default_config_file_path,
                        'Path to the config file')
    flags.DEFINE_boolean('skip_training', False, 'Set to skip training')
    flags.DEFINE_boolean('load_model', False, 'Load the model from disk')
    flags.DEFINE_string('save_path', 'model-doom.pth',
                        'Path to save/load the model')
    app.run(main)