This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#VizDoom general import | |
from vizdoom import * | |
#Import de libs auxiliares | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from collections import deque | |
#Torch imports | |
import torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# In the tutorial we use Intel MKL (i.e. the CPU) rather than CUDA, because of its
# performance and because it is available on lower-end PCs.
# If you want to train on CUDA instead, uncomment the next line and delete
# the hard-coded `use_cuda = False` below it.
#use_cuda = torch.cuda.is_available()
use_cuda = False
device = torch.device('cuda' if use_cuda else 'cpu')
# Legacy tensor-type aliases chosen to match `device`.
# NOTE(review): presumably used elsewhere in the file to build tensors — confirm.
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Doom game assets: create and initialise the game environment.
# We use the config and scenario files from Thomas Simonini's tutorial;
# those files are available on my GitHub repo linked from the post.
game = DoomGame()
game.load_config("health_gathering.cfg")
game.set_doom_scenario_path("health_gathering.wad")
# Fix the game's RNG seed (42) so runs are repeatable.
game.set_seed(42)
game.init()
# One-hot action vectors: [[1, 0, 0], [0, 1, 0], [0, 0, 1]].
# NOTE(review): presumably one per button declared in health_gathering.cfg — confirm.
doom_actions = np.identity(3, dtype=int).tolist()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FrameStacker: | |
def __init__(self): | |
""" | |
We can set the memory size here. | |
Our memory is a deque and, on each stack, it concatenates the frames in memory along the axis 0 | |
We also have a transformer from torch that handles the resizing. | |
""" | |
self.memory_size = 4 | |
self.memory = deque(maxlen=self.memory_size) | |
self.reset() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class PolicyNetwork(nn.Module): | |
def __init__(self, lr): | |
""" | |
We've put Tanh as activation in order to introduce variance on the learning | |
by making the model more sensible. | |
I encourage you to try other architectures, optimizers and hyperparameters | |
""" | |
super(PolicyNetwork, self).__init__() | |
self.num_actions = 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update_policy(policy_network, rewards, log_probs): | |
discounted_rewards = [] | |
for t in range(len(rewards)): | |
Gt = 0 | |
pw = 0 | |
for r in rewards[t:]: | |
Gt = Gt + GAMMA**pw * r | |
pw += 1 | |
discounted_rewards.append(Gt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logger; pending events and summaries are flushed to disk every 40 seconds.
writer = SummaryWriter(flush_secs=40)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Global hyperparameters.
GAMMA = .95  # discount factor applied per step when computing discounted returns
EPISODES = 5000  # number of training episodes to run
learning_rate = 0.01
# The policy network (moved to `device`) and the frame stacker.
stacker = FrameStacker()
policy_net = PolicyNetwork(lr=learning_rate).to(device)
# some lists to write the values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
for episode in range(EPISODES): | |
game.new_episode() | |
curr_health = game.get_state().game_variables[0] | |
state = game.get_state().screen_buffer | |
state = stacker.stack(state) | |
log_probs = [] | |
rewards = [] | |
done = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
) | |
// hello prints a greeting line to standard output.
// NOTE(review): presumably started as a goroutine from main — confirm.
func hello() {
	fmt.Println("Hello world goroutine")
}
func main() { |
OlderNewer