Skip to content

Instantly share code, notes, and snippets.

Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@djbyrne
djbyrne / TD3.ipynb
Last active February 5, 2023 02:02
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
def loss(self, states, actions, scaled_rewards) -> torch.Tensor:
# Policy-gradient loss: negative mean of the reward-weighted log-probabilities
# of the actions that were taken.
# NOTE(review): this snippet is truncated after the "# entropy loss" comment; the
# entropy term and the return statement are not shown here.
logits = self.net(states)
# policy loss
# Log-probabilities over dim 1 of the network output.
log_prob = log_softmax(logits, dim=1)
# Select, for each row, the log-prob of the taken action and weight it by the scaled reward.
# NOTE(review): indexes rows with range(self.batch_size) and columns with actions[0] — this
# assumes logits has exactly self.batch_size rows and actions[0] holds one action index per
# row. A final batch smaller than self.batch_size would raise a shape mismatch here;
# range(logits.shape[0]) would be more robust — confirm against the data pipeline.
log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions[0]]
# Negated mean: minimising this value maximises the reward-weighted log-likelihood.
policy_loss = -log_prob_actions.mean()
# entropy loss
class VanillaPolicyGradient(pl.LightningModule):
# Lightning module for a vanilla policy-gradient agent.
# NOTE(review): only the start of the __init__ signature is visible in this snippet; the
# remaining parameters, the closing parenthesis and the class body are truncated.
# Parameter meanings below are inferred from names and should be confirmed:
#   env            - presumably an environment id string
#   gamma          - presumably the reward discount factor
#   lr             - presumably the optimizer learning rate
#   batch_size     - presumably the number of samples per training batch
#   n_steps        - meaning not visible here (possibly steps per rollout) — TODO confirm
#   avg_reward_len - presumably the window length for the running average reward
#   entropy_beta   - presumably the weight of an entropy term in the loss
def __init__(
self,
env: str,
gamma: float = 0.99,
lr: float = 0.01,
batch_size: int = 8,
n_steps: int = 10,
avg_reward_len: int = 100,
entropy_beta: float = 0.01,
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], _) -> OrderedDict:
# Lightning training step: unpack the batch, compute the loss and gather logging metrics.
# The batch is annotated as a 3-tuple because it is unpacked into
# (states, actions, scaled_rewards) below; the second positional argument is ignored.
# NOTE(review): snippet is truncated after the log dict; the returned value is not shown.
states, actions, scaled_rewards = batch
loss = self.loss(states, actions, scaled_rewards)
# Episode statistics to report; "reward" is the latest entry of self.total_rewards.
# NOTE(review): self.total_rewards[-1] raises IndexError if the list is empty (e.g. before
# any episode finishes) — confirm it is pre-populated.
log = {
"episodes": self.done_episodes,
"reward": self.total_rewards[-1],
"avg_reward": self.avg_rewards,
}
@djbyrne
djbyrne / train_batch.py
Last active October 7, 2020 07:24
train_batch
def train_batch(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
# Roll the policy agent through the environment, recording per-step rewards and actions.
# NOTE(review): snippet is truncated inside the `while True:` loop; the state update, the
# handling of `done` and how the annotated tuple of lists is produced are not shown.
while True:
# Query the agent for an action given the current state.
action = self.agent(self.state, self.device)
# Only action[0] is passed to the env, so the agent appears to return a sequence of actions.
# NOTE(review): unpacks the 4-value gym step API; gym >= 0.26 / gymnasium returns 5 values
# (obs, reward, terminated, truncated, info) — confirm the pinned gym version.
next_state, reward, done, _ = self.env.step(action[0])
self.episode_rewards.append(reward)
# NOTE(review): stores the whole `action` rather than action[0] — confirm this is intended
# (downstream code may index actions[0]).
self.batch_actions.append(action)
def __init__(self, env: str, gamma: float = 0.99, lr: float = 1e-2, batch_size: int = 8,
n_steps: int = 10, avg_reward_len: int = 100, entropy_beta: float = 0.01,
epoch_len: int = 1000, *args, **kwargs) -> None:
# Construct the environment, the policy network and the agent wrapping it.
# NOTE(review): only the model-component setup is visible; gamma, lr, batch_size, n_steps,
# avg_reward_len, entropy_beta and epoch_len are not used in the lines shown and are
# presumably stored further down in the full source.
super().__init__()
# Model components
# Environment created from the id string passed as `env`.
self.env = gym.make(env)
# Network sized from the observation shape and the number of discrete actions
# (action_space.n); presumably outputs one value per action — confirm against MLP.
self.net = MLP(self.env.observation_space.shape, self.env.action_space.n)
# Agent that wraps the network.
self.agent = PolicyAgent(self.net)
LitRLModel(pl.LightningModule):
# NOTE(review): illustrative pseudocode, not valid Python as written — the `class` keyword is
# missing from the line above, and `...` in the signature below stands for elided parameters.
def __init__(self, env, ...):
# NOTE(review): super().__init__() is not called before attributes are assigned; a
# LightningModule subclass normally must call it first — confirm in the full code.
# Environment
self.env = gym.make(env)
# Seed the environment with a fixed value (123).
# NOTE(review): Env.seed() was removed in gym >= 0.26 (use env.reset(seed=...)) — confirm version.
self.env.seed(123)
# Cache the observation shape and number of discrete actions (presumably used to size the net).
self.obs_shape = self.env.observation_space.shape
self.n_actions = self.env.action_space.n
@djbyrne
djbyrne / DQN Pong Bolts Example.py
Last active October 7, 2020 06:37
DQN Pong Bolts Example
from pl_bolts.models.rl.common import wrappers, cli
from pl_bolts.models.rl.dqn_model import DQN
# Assemble the command-line parser in one expression. Start from a bare parser
# (add_help=False: no automatic -h/--help option). Extend it with the Lightning
# Trainer arguments, then with the DQN model-specific arguments.
parser = DQN.add_model_specific_args(
    pl.Trainer.add_argparse_args(
        argparse.ArgumentParser(add_help=False)
    )
)