This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Normalizes the frames to make the network's job easier
def state_reshape(state):
    """Reorder frame axes from channels-last to channels-first and divide by 255.

    The two axis swaps turn an array laid out as (..., H, W, C) into
    (..., C, H, W):
        (..., H, W, C) -> swap -3/-1 -> (..., C, W, H) -> swap -1/-2 -> (..., C, H, W)

    Args:
        state: a frame or batch of frames with at least three axes, the last
            three being (height, width, channels). Presumably raw pixel values
            in [0, 255] (so the result lands in [0, 1]) -- confirm against the
            environment wrapper.

    Returns:
        The reordered array divided by 255 (a floating-point result).
    """
    # Move channels to the front of the last three axes (width/height end up swapped).
    frames = np.swapaxes(state, -3, -1)
    # Restore height-before-width ordering.
    frames = np.swapaxes(frames, -1, -2)
    pixel_max = 255.
    return frames / pixel_max
# Função que treina o nosso agente por uma quantidade de instantes (timesteps) | |
def train(agent, env, total_timesteps): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(self, batch_size=32): | |
# Para começarmos a treinar só depois de já ter uma grande quantidade de experiências | |
if 10000 > len(self.memory._storage): | |
return | |
# Atualizando o beta da memória | |
self.beta = self.beta + self.update_count/1000000 * (1.0 - self.beta) | |
# Pegando as nossas experiências da memória | |
(states, actions, rewards, next_states, dones, weights, batch_indexes) = self.memory.sample(batch_size, self.beta) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def remember(self, state, action, reward, new_state, done):
    """Record one transition in the agent's memory.

    Forwards the five values, in this exact order, to ``self.memory.add``.
    Nothing is returned. The memory type is not visible here; it is presumably
    the PrioritizedReplayBuffer imported elsewhere in this project (verify).
    """
    transition = (state, action, reward, new_state, done)
    self.memory.add(*transition)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def act(self, state): | |
# Decrescemos o nosso epsilon | |
self.epsilon *= self.epsilon_decay | |
self.epsilon = max(self.epsilon, self.min_epsilon) | |
# Se o número aleatório tirarmos for menor que epsilon, escolhemos uma ação aleatória | |
if np.random.random() < self.epsilon: | |
# Escolhendo uma ação aleatória | |
action = self.action_space.sample() | |
return action |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.optim as optim | |
from agents.dqn.prioritized_replay import PrioritizedReplayBuffer | |
class Agent: | |
def __init__(self, observation_space, action_space, lr=2.5e-4, gamma=0.99, tau=0.005): | |
# O Device nos diz que vamos treinar com uma GPU | |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# Gamma (para cálculo do Q) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.optim as optim | |
from agents.dqn.prioritized_replay import PrioritizedReplayBuffer | |
class Agent: | |
def __init__(self, observation_space, action_space, lr=2.5e-4, gamma=0.99): | |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
self.gamma = gamma | |
self.action_space = action_space |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch.nn.functional as F | |
class Network(nn.Module): | |
def __init__(self, in_dim: int, out_dim: int): | |
super(Network, self).__init__() | |
self.out_dim = out_dim | |
self.convs = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), | |
nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Newer Older