Skip to content

Instantly share code, notes, and snippets.

@djbyrne
djbyrne / dqn_mlp.py
Last active March 26, 2020 08:12
Simple network to be used with a DQN
class DQN(nn.Module):
"""
Simple MLP network
Args:
obs_size: observation/state size of the environment
n_actions: number of discrete actions available in the environment
hidden_size: size of hidden layers
"""
@djbyrne
djbyrne / replay_buffer.py
Created March 26, 2020 08:14
Basic Replay Buffer
# Immutable record for one experience step collected while training.
# Fields, in order: state, action, reward, done, new_state.
Experience = collections.namedtuple(
    "Experience",
    ["state", "action", "reward", "done", "new_state"],
)
class ReplayBuffer:
"""
Replay Buffer for storing past experiences allowing the agent to learn from them
Args:
class RLDataset(IterableDataset):
"""
Iterable Dataset containing the ReplayBuffer
which will be updated with new experiences during training
Args:
buffer: replay buffer
sample_size: number of experiences to sample at a time
"""
class Agent:
"""
Base Agent class handeling the interaction with the environment
Args:
env: training environment
replay_buffer: replay buffer storing experiences
"""
def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
class DQNLightning(pl.LightningModule):
""" Basic DQN Model """
def __init__(self, hparams: argparse.Namespace) -> None:
super().__init__()
self.hparams = hparams
self.env = gym.make(self.hparams.env)
obs_size = self.env.observation_space.shape[0]
n_actions = self.env.action_space.n
# Build the Lightning trainer from the parsed CLI args, passing
# CHECKPOINT_PATH as resume_from_checkpoint, then train and evaluate `model`.
trainer = pl.Trainer.from_argparse_args(
    args,
    resume_from_checkpoint=CHECKPOINT_PATH,
)
trainer.fit(model)
trainer.test(model)
@djbyrne
djbyrne / test_dqn.py
Last active July 9, 2020 20:26
Test DQN Pong
# Build the Lightning trainer from the parsed CLI args (with
# resume_from_checkpoint=CHECKPOINT_PATH) and run only the test loop.
# NOTE(review): whether test() restores weights from resume_from_checkpoint
# depends on the Lightning version. Confirm the checkpoint is actually loaded
# into `model` before trusting these test results.
trainer = pl.Trainer.from_argparse_args(
    args, resume_from_checkpoint=CHECKPOINT_PATH
)
trainer.test(model)
class LunarLanderDQN(DQN):
def __init__(
self,
env: str,
eps_last_frame: int = 10000,
sync_rate: int = 10,
learning_rate: float = 1e-2,
batch_size: int = 16,
replay_size: int = 10000,
def train_dataloader(self) -> DataLoader:
    """Return the DataLoader used for training.

    Wraps ``self.train_batch`` in an ``ExperienceSourceDataset`` and batches
    the resulting items with ``self.batch_size``.
    """
    return DataLoader(
        dataset=ExperienceSourceDataset(self.train_batch),
        batch_size=self.batch_size,
    )
@djbyrne
djbyrne / DQN Pong Bolts Example.py
Last active October 7, 2020 06:37
DQN Pong Bolts Example
from pl_bolts.models.rl.common import wrappers, cli
from pl_bolts.models.rl.dqn_model import DQN
parser = argparse.ArgumentParser(add_help=False)
# Register flags on the same parser in two layers: first the Lightning
# Trainer's own CLI arguments, then the DQN model's hyperparameters.
parser = pl.Trainer.add_argparse_args(parser)
parser = DQN.add_model_specific_args(parser)