djbyrne / TD3.ipynb
Last active February 5, 2023 02:02
djbyrne / dqn_mlp.py
Last active March 26, 2020 08:12
Simple network to be used with a DQN
import torch.nn as nn

class DQN(nn.Module):
    """
    Simple MLP network
    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of hidden layers
    """
    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super().__init__()
        # two-layer MLP mapping an observation to one Q-value per action
        self.net = nn.Sequential(nn.Linear(obs_size, hidden_size), nn.ReLU(),
                                 nn.Linear(hidden_size, n_actions))

    def forward(self, x):
        return self.net(x.float())
djbyrne / replay_buffer.py
Created March 26, 2020 08:14
Basic Replay Buffer
import argparse
import collections

import gym
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

# Named tuple for storing experience steps gathered in training
Experience = collections.namedtuple(
    'Experience', field_names=['state', 'action', 'reward',
                               'done', 'new_state'])

class ReplayBuffer:
    """
    Replay Buffer for storing past experiences allowing the agent to learn from them
    Args:
        capacity: size of the buffer
    """
    def __init__(self, capacity: int) -> None:
        self.buffer = collections.deque(maxlen=capacity)
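    # NOTE: the gist preview cuts off here. The methods below are an assumed,
    # minimal sketch of the usual buffer API (length, append, random sample),
    # not the verbatim contents of replay_buffer.py.
    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        self.buffer.append(experience)

    def sample(self, batch_size: int):
        # draw a random batch and stack each experience field into its own array
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, new_states = zip(*(self.buffer[idx] for idx in indices))
        return (np.array(states), np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=np.bool_), np.array(new_states))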
class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ReplayBuffer
    which will be updated with new experiences during training
    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """
class Agent:
    """
    Base Agent class handling the interaction with the environment
    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """
    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        self.state = self.env.reset()
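    # NOTE: the remaining Agent methods are cut off in the preview. The sketch
    # below is an assumed epsilon-greedy interaction step, not the verbatim gist.
    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        # explore with probability epsilon, otherwise act greedily on the Q-values
        if np.random.random() < epsilon:
            return self.env.action_space.sample()
        state = torch.tensor([self.state], dtype=torch.float32, device=device)
        return int(net(state).argmax(dim=1).item())

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu'):
        # take one environment step and store the resulting transition in the buffer
        action = self.get_action(net, epsilon, device)
        new_state, reward, done, _ = self.env.step(action)
        self.replay_buffer.append(Experience(self.state, action, reward, done, new_state))
        self.state = self.env.reset() if done else new_state
        return reward, done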
class DQNLightning(pl.LightningModule):
    """ Basic DQN Model """
    def __init__(self, hparams: argparse.Namespace) -> None:
        super().__init__()
        self.hparams = hparams
        self.env = gym.make(self.hparams.env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n
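        # NOTE: the preview ends here. The remainder below is an assumed sketch of
        # how this kind of Lightning DQN is usually completed; hyperparameters such
        # as replay_size, gamma, sync_rate, lr and batch_size are assumptions about
        # hparams, not confirmed fields from the gist.
        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)
        self.buffer = ReplayBuffer(self.hparams.replay_size)
        self.agent = Agent(self.env, self.buffer)
        # warm start: fill the buffer with random-policy experience before training begins
        for _ in range(1000):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Q-values for a batch of states
        return self.net(x)

    def dqn_mse_loss(self, batch) -> torch.Tensor:
        # one-step TD target: r + gamma * max_a' Q_target(s', a'), zeroed at episode end
        states, actions, rewards, dones, next_states = batch
        state_action_values = self.net(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)
        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
        expected_values = rewards.float() + self.hparams.gamma * next_state_values
        return nn.MSELoss()(state_action_values, expected_values)

    def training_step(self, batch, batch_idx):
        # play one step in the environment, then learn from a sampled batch;
        # a fixed exploration rate is used here purely to keep the sketch short
        self.agent.play_step(self.net, epsilon=0.1, device=self.device)
        loss = self.dqn_mse_loss(batch)
        if self.global_step % self.hparams.sync_rate == 0:
            # periodically copy the online network into the target network
            self.target_net.load_state_dict(self.net.state_dict())
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.net.parameters(), lr=self.hparams.lr)

    def train_dataloader(self) -> DataLoader:
        # stream experiences from the replay buffer via the RLDataset defined above
        dataset = RLDataset(self.buffer)
        return DataLoader(dataset=dataset, batch_size=self.hparams.batch_size)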
djbyrne / DQN Pong Bolts Example.py
Last active October 7, 2020 06:37
DQN Pong Bolts Example
import argparse
import pytorch_lightning as pl
from pl_bolts.models.rl.common import wrappers, cli
from pl_bolts.models.rl.dqn_model import DQN

parser = argparse.ArgumentParser(add_help=False)
# Trainer args
parser = pl.Trainer.add_argparse_args(parser)
# Model args
parser = DQN.add_model_specific_args(parser)
args = parser.parse_args()

# build the model from the parsed arguments (assumed; the gist preview omits this line)
model = DQN(**vars(args))

# CHECKPOINT_PATH should point to a previously saved .ckpt file
trainer = pl.Trainer.from_argparse_args(args, resume_from_checkpoint=CHECKPOINT_PATH)
trainer.fit(model)
trainer.test(model)
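Because the Trainer and the model register their arguments on the same parser, both the standard Trainer options added by pl.Trainer.add_argparse_args and the DQN's own hyperparameters can be adjusted from the command line without changing the script.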