Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Deep Q-learning example to play Doom with PyTorch
#!/usr/bin/env python3
"""
Brandon L Morris
Adapted from work by E. Culurciello
"""
from absl import app, flags
import itertools
from random import sample, randint, random
import numpy as np
from skimage.transform import resize
from time import time, sleep
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import trange
from vizdoom import DoomGame, Mode, ScreenFormat, ScreenResolution
# Handle to absl's global flag registry; the flags themselves are defined in
# the ``__main__`` section and read elsewhere as ``FLAGS.<name>``.
FLAGS = flags.FLAGS
# Passed as the second argument of ``game.make_action`` and used as the number
# of ``game.advance_action()`` calls per chosen action when watching episodes.
frame_repeat = 12
# Target size handed to ``resize`` in ``preprocess``; also the trailing two
# dimensions of every state tensor kept in the replay memory.
resolution = (30, 45)
# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Default configuration file path
# NOTE(review): this is an absolute path on the original author's machine;
# pass --config to point at a scenario file that exists locally.
default_config_file_path = "/home/bmorris/code/shared-code/ViZDoom/scenarios/simpler_basic.cfg"
def preprocess(img):
    """Resize ``img`` to the module-level ``resolution`` and return a tensor.

    The resized array is cast to ``float32`` before being wrapped with
    ``torch.from_numpy``, so the result is a CPU tensor of that dtype.
    """
    downscaled = resize(img, resolution)
    as_float32 = downscaled.astype(np.float32)
    return torch.from_numpy(as_float32)
def game_state(game):
    """Return the preprocessed screen buffer of ``game``'s current state."""
    current = game.get_state()
    return preprocess(current.screen_buffer)
class ReplayMemory:
    """Fixed-capacity ring buffer of transitions kept in pre-allocated tensors.

    Each slot stores a state frame (``s1``), the action taken (``a``), the
    following frame (``s2``), a terminal marker (``isterminal``) and the
    reward (``r``). Once ``capacity`` transitions have been written, new
    ones overwrite the oldest slot.
    """

    def __init__(self, capacity):
        """Allocate zero-filled storage for ``capacity`` transitions.

        State tensors have shape ``(capacity, 1, *resolution)`` and all
        buffers live on the module-level ``device``.
        """
        channels = 1
        state_shape = (capacity, channels, *resolution)
        # Allocate straight on the target device instead of building each
        # buffer on the host first and copying it over with ``.to(device)``.
        self.s1 = torch.zeros(state_shape, dtype=torch.float32, device=device)
        self.s2 = torch.zeros(state_shape, dtype=torch.float32, device=device)
        self.a = torch.zeros(capacity, dtype=torch.long, device=device)
        self.r = torch.zeros(capacity, dtype=torch.float32, device=device)
        self.isterminal = torch.zeros(capacity, dtype=torch.float32,
                                      device=device)

        self.capacity = capacity
        self.size = 0  # number of valid slots, never exceeds ``capacity``
        self.pos = 0   # index of the slot the next transition is written to

    def add_transition(self, s1, action, s2, isterminal, reward):
        """Write one transition at the current position and advance it.

        Args:
            s1: frame observed before acting.
            action: index of the action that was taken.
            s2: frame observed after acting; not read when ``isterminal``
                is truthy (callers pass ``None`` in that case).
            isterminal: truthy when the episode ended with this transition.
            reward: reward received for the action.
        """
        idx = self.pos
        self.s1[idx, 0, :, :] = s1
        self.a[idx] = action
        if isterminal:
            # There is no successor frame. Clear the slot so it does not keep
            # the frame of whichever transition previously lived here.
            self.s2[idx, 0, :, :] = 0
        else:
            self.s2[idx, 0, :, :] = s2
        self.isterminal[idx] = isterminal
        self.r[idx] = reward

        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def get_sample(self, size):
        """Return ``size`` distinct stored transitions chosen at random.

        Returns:
            Tuple ``(s1, a, s2, isterminal, r)`` of tensors whose first
            dimension is ``size``.

        Raises:
            ValueError: from ``random.sample`` when ``size`` exceeds the
                number of stored transitions.
        """
        idx = sample(range(0, self.size), size)
        return (self.s1[idx], self.a[idx], self.s2[idx], self.isterminal[idx],
                self.r[idx])
class QNet(nn.Module):
    """Convolutional Q-network that also owns its optimizer and replay memory.

    The network maps a batch of single-channel frames to one Q-value per
    available action. The loss, the SGD optimizer and the ``ReplayMemory``
    are attributes of the module, so the whole learner travels as one object.
    """

    def __init__(self, available_actions_count):
        """Build the layers, loss, optimizer and replay memory.

        Args:
            available_actions_count: number of outputs of the last layer.

        Reads ``FLAGS.learning_rate`` and ``FLAGS.replay_memory``.
        """
        super().__init__()
        # Shape comments assume 1x30x45 inputs (the module-level resolution).
        self.conv1 = nn.Conv2d(1, 8, kernel_size=6, stride=3)  # 8x9x14
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)  # 8x4x6 = 192
        self.fc1 = nn.Linear(192, 128)
        self.fc2 = nn.Linear(128, available_actions_count)

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.parameters(), FLAGS.learning_rate)
        self.memory = ReplayMemory(capacity=FLAGS.replay_memory)

    def forward(self, x):
        """Return the Q-values for the batch of frames ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 192)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def get_best_action(self, state):
        """Return, per row of ``state``, the index of the largest Q-value."""
        # Action selection is pure inference, so skip autograd bookkeeping.
        with torch.no_grad():
            q = self(state)
        _, index = torch.max(q, 1)
        return index

    def train_step(self, s1, target_q):
        """Take one optimizer step pulling ``self(s1)`` towards ``target_q``.

        Returns:
            The loss tensor computed before the parameter update.
        """
        output = self(s1)
        loss = self.criterion(output, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss

    def learn_from_memory(self):
        """Sample a batch from the replay memory and train on it.

        Does nothing until the memory holds at least ``FLAGS.batch_size``
        transitions. The target equals the network's current prediction for
        every action except the one that was taken, which is replaced by
        ``r + FLAGS.discount * (1 - isterminal) * max_a Q(s2, a)``.
        """
        if self.memory.size < FLAGS.batch_size:
            return
        s1, a, s2, isterminal, r = self.memory.get_sample(FLAGS.batch_size)

        # Targets are constants for the loss, so build them without recording
        # a graph (this replaces the two ``.detach()`` calls).
        with torch.no_grad():
            q2, _ = torch.max(self(s2), dim=1)
            target_q = self(s1)
            # Create the row index on the same device as the action indices
            # so no index tensor has to be copied from the host each step.
            rows = torch.arange(target_q.shape[0], device=target_q.device)
            target_q[rows, a] = r + FLAGS.discount * (1 - isterminal) * q2
        self.train_step(s1, target_q)
def find_eps(epoch):
    """Return the exploration rate to use during ``epoch``.

    The rate stays at 1.0 for the first 10% of ``FLAGS.epochs``, is 0.1
    once past 60% of them, and is interpolated linearly in between.
    """
    eps_high, eps_low = 1.0, 0.1
    hold_until = .1*FLAGS.epochs
    decay_until = .6*FLAGS.epochs

    if epoch < hold_until:
        return eps_high
    if epoch > decay_until:
        return eps_low

    # Fraction of the decay window that has already elapsed.
    fraction = (epoch-hold_until)/(decay_until-hold_until)
    return eps_high - fraction * (eps_high - eps_low)
def perform_learning_step(epoch, game, model, actions):
    """Act once in ``game``, store the transition and train ``model`` on a batch.

    With probability ``find_eps(epoch)`` a uniformly random action index is
    drawn; otherwise the model's best action for the current frame is used.
    The chosen action is applied through ``game.make_action`` with
    ``frame_repeat`` as its second argument.
    """
    s1 = game_state(game)

    explore = random() <= find_eps(epoch)
    if explore:
        a = torch.tensor(randint(0, len(actions) - 1)).long()
    else:
        # The network expects a (batch, channel, height, width) input.
        s1 = s1.reshape([1, 1, *resolution])
        a = model.get_best_action(s1.to(device))

    reward = game.make_action(actions[a], frame_repeat)

    # No next frame is fetched once the episode has ended.
    episode_over = game.is_episode_finished()
    isterminal = 1. if episode_over else 0.
    s2 = None if episode_over else game_state(game)

    model.memory.add_transition(s1, a, s2, isterminal, reward)
    model.learn_from_memory()
def initialize_vizdoom(config):
    """Create a ``DoomGame`` from ``config``, apply overrides and initialise it.

    After loading the config file the game is set to a hidden window,
    ``Mode.PLAYER``, ``ScreenFormat.GRAY8`` and ``RES_640X480``; ``init()``
    is then called and the game object returned.
    """
    game = DoomGame()
    game.load_config(config)

    # Applied after load_config, in this order, and before init().
    overrides = (
        (game.set_window_visible, False),
        (game.set_mode, Mode.PLAYER),
        (game.set_screen_format, ScreenFormat.GRAY8),
        (game.set_screen_resolution, ScreenResolution.RES_640X480),
    )
    for setter, value in overrides:
        setter(value)

    game.init()
    return game
def train(game, model, actions):
    """Train ``model`` for ``FLAGS.epochs`` epochs of ``FLAGS.iters`` steps.

    Each epoch runs the learning steps (restarting the episode whenever it
    finishes), prints the scores of the episodes completed in that epoch,
    evaluates the model with ``test``, saves it to ``FLAGS.save_path`` and
    prints the elapsed wall-clock time.

    Args:
        game: initialised game instance to train in.
        model: learner exposing what ``perform_learning_step`` and ``test``
            need.
        actions: list of actions indexable by the model's action index.
    """
    time_start = time()
    print("Saving the network weigths to:", FLAGS.save_path)

    for epoch in range(FLAGS.epochs):
        print(f'Epoch {epoch+1}')
        # A plain list avoids re-allocating an array on every finished episode.
        scores = []

        game.new_episode()
        for _ in trange(FLAGS.iters, leave=False):
            perform_learning_step(epoch, game, model, actions)
            if game.is_episode_finished():
                scores.append(game.get_total_reward())
                game.new_episode()

        print(f'Completed {len(scores)} episodes')
        if scores:
            scores = np.asarray(scores)
            print(f'Mean: {scores.mean():.1f} +/- {scores.std():.1f}')
        else:
            # mean()/std() of an empty array would print nan with a warning.
            print('Mean: n/a (no episode finished this epoch)')

        print("Testing...")
        test(FLAGS.test_episodes, game, model, actions)

        # NOTE: this pickles the whole module, replay memory included;
        # main() relies on that format when loading, so it is kept as is.
        torch.save(model, FLAGS.save_path)

        # time() differences are in seconds; convert to match the label.
        elapsed_minutes = (time() - time_start) / 60.0
        print(f'Total elapsed time: {elapsed_minutes:.2f} minutes')
def test(iters, game, model, actions):
    """Play ``iters`` greedy episodes and print the mean and std of the scores.

    Args:
        iters: number of evaluation episodes to play.
        game: initialised game instance.
        model: object whose ``get_best_action`` picks the action index.
        actions: list of actions indexable by that index.
    """
    scores = []
    # Evaluation never back-propagates, so run it without autograd tracking.
    with torch.no_grad():
        # Honour the ``iters`` argument; it used to be ignored in favour of
        # FLAGS.test_episodes (the only value callers in this file pass).
        for _ in trange(iters, leave=False):
            game.new_episode()
            while not game.is_episode_finished():
                state = game_state(game)
                state = state.reshape([1, 1, resolution[0], resolution[1]])
                a_idx = model.get_best_action(state.to(device))
                game.make_action(actions[a_idx], frame_repeat)
            scores.append(game.get_total_reward())

    scores = np.asarray(scores, dtype=float)
    print(f'Results: mean: {scores.mean():.1f} +/- {scores.std():.1f}')
def watch_episodes(game, model, actions):
    """Replay ``FLAGS.watch_episodes`` greedy episodes in a visible window.

    The game is switched to a visible window and ``Mode.ASYNC_PLAYER`` and
    re-initialised. Each chosen action is set once and then advanced
    ``frame_repeat`` times. After every episode the function sleeps for one
    second and prints the total reward.
    """
    game.set_window_visible(True)
    game.set_mode(Mode.ASYNC_PLAYER)
    game.init()

    batch_shape = [1, 1, resolution[0], resolution[1]]
    for episode in range(FLAGS.watch_episodes):
        # NOTE(review): the string argument is presumably a recording file
        # path — confirm against the ViZDoom documentation.
        game.new_episode(f'episode-{episode}')
        while not game.is_episode_finished():
            frame = game_state(game).reshape(batch_shape)
            best = model.get_best_action(frame.to(device))
            game.set_action(actions[best])
            for _ in range(frame_repeat):
                game.advance_action()

        # Pause between episodes before reporting the result.
        sleep(1.0)
        print(f'Total score: {game.get_total_reward()}')
def main(_):
    """Entry point for ``app.run``: build or load the model, train, then watch.

    The action set is every 0/1 combination of the game's available buttons.
    With ``--load_model`` the model is unpickled from ``FLAGS.save_path``;
    otherwise a fresh ``QNet`` is created. Training is skipped when
    ``--skip_training`` is set. The game is closed before ``watch_episodes``
    re-initialises it.

    Args:
        _: leftover command-line arguments supplied by absl (unused).
    """
    game = initialize_vizdoom(FLAGS.config)
    n = game.get_available_buttons_size()
    actions = [list(a) for a in itertools.product([0, 1], repeat=n)]

    if FLAGS.load_model:
        print(f'Loading model from: {FLAGS.save_path}')
        # SECURITY: torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # load here, and places the replay-memory tensors on ``device`` too;
        # they are plain attributes, so ``.to(device)`` alone skips them.
        model = torch.load(FLAGS.save_path, map_location=device).to(device)
    else:
        model = QNet(len(actions)).to(device)

    print("Starting the training!")
    if not FLAGS.skip_training:
        train(game, model, actions)

    game.close()
    print("======================================")
    watch_episodes(game, model, actions)
if __name__ == '__main__':
    # Command-line flags; they are read throughout the module via FLAGS.
    # Learning hyper-parameters.
    flags.DEFINE_integer(name='batch_size', default=64, help='Batch size')
    flags.DEFINE_float(name='learning_rate', default=0.00025,
                       help='Learning rate')
    flags.DEFINE_float(name='discount', default=0.99, help='Discount factor')
    flags.DEFINE_integer(name='replay_memory', default=10000,
                         help='Replay memory capacity')
    # Run lengths.
    flags.DEFINE_integer(name='epochs', default=20, help='Number of epochs')
    flags.DEFINE_integer(name='iters', default=2000,
                         help='Iterations per epoch')
    flags.DEFINE_integer(name='watch_episodes', default=10,
                         help='Trained episodes to watch')
    flags.DEFINE_integer(name='test_episodes', default=100,
                         help='Episodes to test with')
    # Paths and run-mode switches.
    flags.DEFINE_string(name='config', default=default_config_file_path,
                        help='Path to the config file')
    flags.DEFINE_boolean(name='skip_training', default=False,
                         help='Set to skip training')
    flags.DEFINE_boolean(name='load_model', default=False,
                         help='Load the model from disk')
    flags.DEFINE_string(name='save_path', default='model-doom.pth',
                        help='Path to save/load the model')
    app.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.