TD3.ipynb
Last active February 5, 2023 02:02
def loss(self, states, actions, scaled_rewards) -> torch.Tensor:
logits =
# policy loss
log_prob = log_softmax(logits, dim=1)
log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions[0]]
policy_loss = -log_prob_actions.mean()
# entropy loss
class VanillaPolicyGradient(pl.LightningModule):
def __init__(
env: str,
gamma: float = 0.99,
lr: float = 0.01,
batch_size: int = 8,
n_steps: int = 10,
avg_reward_len: int = 100,
entropy_beta: float = 0.01,
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
states, actions, scaled_rewards = batch
loss = self.loss(states, actions, scaled_rewards)
log = {
"episodes": self.done_episodes,
"reward": self.total_rewards[-1],
"avg_reward": self.avg_rewards,

Last active October 7, 2020 07:24
def train_batch(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
while True:
action = self.agent(self.state, self.device)
next_state, reward, done, _ = self.env.step(action[0])
def __init__(self, env: str, gamma: float = 0.99, lr: float = 1e-2, batch_size: int = 8,
n_steps: int = 10, avg_reward_len: int = 100, entropy_beta: float = 0.01,
epoch_len: int = 1000, *args, **kwargs) -> None:
# Model components
self.env = gym.make(env) = MLP(self.env.observation_space.shape, self.env.action_space.n)
self.agent = PolicyAgent(
def __init__(self, env, ...):
# Environment
self.env = gym.make(env)
self.obs_shape = self.env.observation_space.shape
self.n_actions = self.env.action_space.n
djbyrne / DQN Pong Bolts
Last active October 7, 2020 06:37
DQN Pong Bolts Example
from pl_bolts.models.rl.common import wrappers, cli
from pl_bolts.models.rl.dqn_model import DQN
# Assemble the command-line parser in a single expression.
# The base parser is created with add_help=False, so argparse does not
# register its automatic -h/--help option on it. That parser is handed first
# to the Lightning Trainer and then to the DQN model; each call is expected
# to return a parser carrying its own options (exact arguments are defined
# by those libraries, not here).
parser = DQN.add_model_specific_args(
    pl.Trainer.add_argparse_args(
        argparse.ArgumentParser(add_help=False),
    ),
)