def loss(self, states, actions, scaled_rewards) -> torch.Tensor:
    logits = self.net(states)
    # policy loss
    log_prob = log_softmax(logits, dim=1)
    log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions[0]]
    policy_loss = -log_prob_actions.mean()
    # entropy loss
    prob = softmax(logits, dim=1)
    entropy = -(prob * log_prob).sum(dim=1).mean()
    entropy_loss = -self.entropy_beta * entropy
    # total loss
    return policy_loss + entropy_loss
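Here range(self.batch_size) paired with the action indices picks out the log-probability of the action actually taken in each state, and scaled_rewards weights it, giving the REINFORCE objective. A quick standalone shape check with dummy tensors (all names below are mine, not from the gist):

import torch
from torch.nn.functional import log_softmax

batch_size, n_actions = 8, 2
logits = torch.randn(batch_size, n_actions)        # stand-in for self.net(states)
actions = torch.randint(n_actions, (batch_size,))  # one action index per state
scaled_rewards = torch.randn(batch_size)           # discounted returns

log_prob = log_softmax(logits, dim=1)
log_prob_actions = scaled_rewards * log_prob[range(batch_size), actions]
policy_loss = -log_prob_actions.mean()             # scalar, differentiable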
class LitRLModel(pl.LightningModule):
    def __init__(self, env, ...):
        super().__init__()
        # Environment
        self.env = gym.make(env)
        self.env.seed(123)
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
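For a concrete sense of what obs_shape and n_actions hold, here is what the same calls return for a classic-control environment (the environment id is my example, not from the gist):

import gym

env = gym.make("CartPole-v0")
env.seed(123)
print(env.observation_space.shape)  # (4,) -- cart position/velocity, pole angle/velocity
print(env.action_space.n)           # 2 discrete actions: push left or right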
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], _) -> OrderedDict:
    states, actions, scaled_rewards = batch
    loss = self.loss(states, actions, scaled_rewards)
    log = {
        "episodes": self.done_episodes,
        "reward": self.total_rewards[-1],
        "avg_reward": self.avg_rewards,
    }
    return OrderedDict({"loss": loss, "log": log, "progress_bar": log})
def train_dataloader(self) -> DataLoader:
    dataset = ExperienceSourceDataset(self.train_batch)
    return DataLoader(dataset=dataset, batch_size=self.batch_size)
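ExperienceSourceDataset itself is not shown in these snippets; in pl_bolts it is a thin IterableDataset wrapper around a generator function, which is what lets a DataLoader stream experience produced on the fly. A minimal sketch of that idea, assuming only the generator contract used above:

from typing import Callable, Iterator
from torch.utils.data import IterableDataset

class ExperienceSourceDataset(IterableDataset):
    """Wraps a generator function so a DataLoader can stream experience from it."""

    def __init__(self, generate_batch: Callable) -> None:
        self.generate_batch = generate_batch

    def __iter__(self) -> Iterator:
        # Defer to the environment-stepping generator (train_batch below)
        return iter(self.generate_batch())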
djbyrne / train_batch.py
def train_batch(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    while True:
        action = self.agent(self.state, self.device)
        next_state, reward, done, _ = self.env.step(action[0])

        self.episode_rewards.append(reward)
        self.batch_actions.append(action)
        self.batch_states.append(self.state)
        self.state = next_state
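The scaled_rewards consumed by the loss are typically the discounted returns of each step in the episode. A minimal helper showing that computation (the standalone function and its name are mine):

from typing import List

def discount_rewards(rewards: List[float], gamma: float = 0.99) -> List[float]:
    """Compute the discounted return G_t = r_t + gamma * G_{t+1} for each step."""
    returns: List[float] = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return list(reversed(returns))

# e.g. discount_rewards([1.0, 1.0, 1.0], gamma=0.9) -> [2.71, 1.9, 1.0]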
def __init__(self, env: str, gamma: float = 0.99, lr: float = 1e-2, batch_size: int = 8,
             n_steps: int = 10, avg_reward_len: int = 100, entropy_beta: float = 0.01,
             epoch_len: int = 1000, *args, **kwargs) -> None:
    super().__init__()

    # Model components
    self.env = gym.make(env)
    self.net = MLP(self.env.observation_space.shape, self.env.action_space.n)
    self.agent = PolicyAgent(self.net)
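A LightningModule also needs configure_optimizers to complete the training loop; a minimal version consistent with the lr hyperparameter above (the choice of Adam and storing lr on self are my assumptions):

def configure_optimizers(self) -> torch.optim.Optimizer:
    # Plain Adam over the policy network; lr comes from the constructor
    return torch.optim.Adam(self.net.parameters(), lr=self.lr)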
class VanillaPolicyGradient(pl.LightningModule):
    def __init__(
        self,
        env: str,
        gamma: float = 0.99,
        lr: float = 0.01,
        batch_size: int = 8,
        n_steps: int = 10,
        avg_reward_len: int = 100,
        entropy_beta: float = 0.01,
class LunarLanderDQN(DQN):
    def __init__(
        self,
        env: str,
        eps_last_frame: int = 10000,
        sync_rate: int = 10,
        learning_rate: float = 1e-2,
        batch_size: int = 16,
        replay_size: int = 10000,
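Subclassing DQN like this only overrides hyperparameter defaults for the new environment; training the variant would look something like the following (the environment id and step budget are assumptions, not from the gist):

model = LunarLanderDQN("LunarLander-v2")
trainer = pl.Trainer(max_steps=500_000)
trainer.fit(model)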
djbyrne / test_dqn.py
Test DQN Pong
# Evaluate a trained agent from a saved checkpoint
trainer = pl.Trainer.from_argparse_args(args, resume_from_checkpoint=CHECKPOINT_PATH)
trainer.test(model)

# Resume training from the checkpoint, then evaluate
trainer = pl.Trainer.from_argparse_args(args, resume_from_checkpoint=CHECKPOINT_PATH)
trainer.fit(model)
trainer.test(model)
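from_argparse_args expects an args namespace carrying the Trainer's own flags; a minimal setup for that (CHECKPOINT_PATH stays a placeholder, as in the snippet):

import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)  # adds the Trainer flags (gpus, max_steps, ...)
args = parser.parse_args()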