This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pl_bolts.models.rl.common import wrappers, cli | |
from pl_bolts.models.rl.dqn_model import DQN | |
parser = argparse.ArgumentParser(add_help=False) | |
# Trainer args | |
parser = pl.Trainer.add_argparse_args(parser) | |
# Model args | |
parser = DQN.add_model_specific_args(parser) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DQNLightning(pl.LightningModule): | |
""" Basic DQN Model """ | |
def __init__(self, hparams: argparse.Namespace) -> None: | |
super().__init__() | |
self.hparams = hparams | |
self.env = gym.make(self.hparams.env) | |
obs_size = self.env.observation_space.shape[0] | |
n_actions = self.env.action_space.n |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Agent: | |
""" | |
Base Agent class handeling the interaction with the environment | |
Args: | |
env: training environment | |
replay_buffer: replay buffer storing experiences | |
""" | |
def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class RLDataset(IterableDataset): | |
""" | |
Iterable Dataset containing the ReplayBuffer | |
which will be updated with new experiences during training | |
Args: | |
buffer: replay buffer | |
sample_size: number of experiences to sample at a time | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Named tuple for storing experience steps gathered in training | |
Experience = collections.namedtuple( | |
'Experience', field_names=['state', 'action', 'reward', | |
'done', 'new_state']) | |
class ReplayBuffer: | |
""" | |
Replay Buffer for storing past experiences allowing the agent to learn from them | |
Args: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DQN(nn.Module): | |
""" | |
Simple MLP network | |
Args: | |
obs_size: observation/state size of the environment | |
n_actions: number of discrete actions available in the environment | |
hidden_size: size of hidden layers | |
""" | |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
NewerOlder