Created
March 26, 2020 12:50
-
-
Save djbyrne/caff38f59400128b4b1377f59bb6ddde to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Agent:
    """
    Base Agent class handling the interaction with the environment.

    The agent keeps track of the current environment state, chooses actions
    with an epsilon-greedy policy and stores every transition it observes in
    the replay buffer.

    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        # reset() already assigns self.state, so the environment is reset
        # exactly once on construction.
        self.reset()

    def reset(self) -> None:
        """Resets the environment and updates the state."""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """
        Using the given network, decide what action to carry out
        using an epsilon-greedy policy.

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: device the state tensor is moved to before the forward pass

        Returns:
            action
        """
        # Explore: with probability epsilon sample a random action.
        if np.random.random() < epsilon:
            return self.env.action_space.sample()

        # Exploit: wrap the state in a list so the tensor gets a leading
        # batch dimension of size 1, then move it to the requested device.
        # `.to(device)` is a no-op when the tensor already lives there and,
        # unlike `.cuda(device)`, works for any device string.
        state = torch.tensor([self.state]).to(device)

        q_values = net(state)
        # Index of the largest Q-value along dim 1 is the greedy action.
        _, action = torch.max(q_values, dim=1)
        return int(action.item())

    @torch.no_grad()
    def play_step(
        self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu'
    ) -> Tuple[float, bool]:
        """
        Carries out a single interaction step between the agent and the
        environment, and stores the resulting transition in the replay buffer.

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            reward, done
        """
        action = self.get_action(net, epsilon, device)

        # do step in the environment (expects the 4-tuple step() return)
        new_state, reward, done, _ = self.env.step(action)

        exp = Experience(self.state, action, reward, done, new_state)
        self.replay_buffer.append(exp)

        self.state = new_state
        # Start a fresh episode so the next call begins from a valid state.
        if done:
            self.reset()
        return reward, done
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment