import argparse
from collections import OrderedDict
from typing import List, Tuple

import gym
import numpy as np
import pytorch_lightning as pl
import torch
from torch import optim
from torch.nn.functional import log_softmax, softmax
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

# MLP, PolicyAgent and ExperienceSourceDataset come from
# pytorch-lightning-bolts (pl_bolts); the import paths below assume the
# 2020-era package layout
from pl_bolts.datamodules.experience_source import ExperienceSourceDataset
from pl_bolts.models.rl.common.agents import PolicyAgent
from pl_bolts.models.rl.common.networks import MLP


class VanillaPolicyGradient(pl.LightningModule):
    def __init__(
        self,
        env: str,
        gamma: float = 0.99,
        lr: float = 0.01,
        batch_size: int = 8,
        n_steps: int = 10,
        avg_reward_len: int = 100,
        entropy_beta: float = 0.01,
        epoch_len: int = 1000,
        **kwargs
    ) -> None:
        super().__init__()

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps

        self.save_hyperparameters()

        # Model components: environment, policy network and agent
        self.env = gym.make(env)
        self.net = MLP(self.env.observation_space.shape, self.env.action_space.n)
        self.agent = PolicyAgent(self.net)

        # Tracking metrics
        self.total_rewards = []
        self.episode_rewards = []
        self.done_episodes = 0
        self.avg_rewards = 0
        self.avg_reward_len = avg_reward_len
        self.eps = np.finfo(np.float32).eps.item()
        self.batch_states = []
        self.batch_actions = []

        self.state = self.env.reset()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes a state x through the network and returns the logits over actions

        Args:
            x: environment state

        Returns:
            action logits
        """
        output = self.net(x)
        return output

    def train_batch(
        self,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Contains the logic for generating a new batch of data to be passed to the DataLoader

        Returns:
            yields a tuple of Lists containing tensors for the states, actions and
            discounted returns of the batch
        """
        while True:
            action = self.agent(self.state, self.device)

            # The agent returns a list of actions; this setup runs a single
            # environment, so only the first entry is used
            next_state, reward, done, _ = self.env.step(action[0])

            self.episode_rewards.append(reward)
            self.batch_actions.append(action)
            self.batch_states.append(self.state)
            self.state = next_state

            if done:
                self.done_episodes += 1
                self.state = self.env.reset()
                self.total_rewards.append(sum(self.episode_rewards))
                self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:]))

                returns = self.compute_returns(self.episode_rewards)

                for idx in range(len(self.batch_actions)):
                    yield self.batch_states[idx], self.batch_actions[idx], returns[idx]

                self.batch_states = []
                self.batch_actions = []
                self.episode_rewards = []

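    # Note (assumption about pl_bolts behaviour, not stated in the gist):
    # ExperienceSourceDataset wraps a generator function such as train_batch in
    # an iterable dataset, so the DataLoader in _dataloader below simply
    # collates batch_size consecutive yields into one (states, actions,
    # returns) minibatch.
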
    def compute_returns(self, rewards):
        """
        Calculates the discounted returns for an episode's rewards

        Args:
            rewards: list of rewards from the finished episode

        Returns:
            tensor of normalized discounted returns
        """
        reward = 0
        returns = []

        # Walk the rewards backwards, accumulating the discounted sum
        for r in rewards[::-1]:
            reward = r + self.gamma * reward
            returns.insert(0, reward)

        returns = torch.tensor(returns)

        # Normalize to zero mean and unit variance for more stable updates
        returns = (returns - returns.mean()) / (returns.std() + self.eps)

        return returns

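    # Worked example (illustrative, not from the original gist): with episode
    # rewards [1.0, 1.0, 1.0] and gamma = 0.99, the backwards pass produces raw
    # discounted returns [2.9701, 1.99, 1.0]; the normalization above then
    # rescales them to zero mean and unit standard deviation.
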
    def loss(self, states, actions, scaled_rewards) -> torch.Tensor:
        """
        Calculates the loss for VPG

        Args:
            states: batched states
            actions: batched actions
            scaled_rewards: batched discounted returns

        Returns:
            loss for the current batch
        """
        logits = self.net(states)

        # Policy loss: REINFORCE objective, weighting the log-probability of
        # each taken action by its discounted return. The default collate
        # stacks the per-step action lists into a single tensor, hence the
        # actions[0] index.
        log_prob = log_softmax(logits, dim=1)
        log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions[0]]
        policy_loss = -log_prob_actions.mean()

        # Entropy bonus: subtracting beta-scaled entropy encourages exploration
        prob = softmax(logits, dim=1)
        entropy = -(prob * log_prob).sum(dim=1).mean()
        entropy_loss = -self.entropy_beta * entropy

        # Total loss
        loss = policy_loss + entropy_loss

        return loss

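    # Shape sketch (illustrative, assuming CartPole-v0 with batch_size=8): its
    # observation space is (4,) and it has 2 actions, so states is [8, 4],
    # logits and log_prob are [8, 2], actions[0] is [8], and
    # log_prob[range(8), actions[0]] picks one log-probability per row, giving
    # an [8]-shaped vector that is weighted elementwise by the [8] returns.
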
    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], _) -> OrderedDict:
        """
        Calculates the loss for the current minibatch of experience received
        from the DataLoader

        Args:
            batch: current mini batch of experience data
            _: batch number, not used

        Returns:
            Training loss and log metrics
        """
        states, actions, scaled_rewards = batch

        loss = self.loss(states, actions, scaled_rewards)

        log = {
            "episodes": self.done_episodes,
            "reward": self.total_rewards[-1],
            "avg_reward": self.avg_rewards,
        }
        return OrderedDict(
            {
                "loss": loss,
                "avg_reward": self.avg_rewards,
                "log": log,
                "progress_bar": log,
            }
        )

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def _dataloader(self) -> DataLoader:
        """Initialize the experience-source dataset used for retrieving experiences"""
        dataset = ExperienceSourceDataset(self.train_batch)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self._dataloader()

    @staticmethod
    def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
        """
        Adds arguments for the VPG model

        Args:
            arg_parser: the current argument parser to add to

        Returns:
            arg_parser with model-specific args added
        """
        arg_parser.add_argument(
            "--entropy_beta", type=float, default=0.01, help="entropy coefficient",
        )
        arg_parser.add_argument(
            "--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch"
        )
        arg_parser.add_argument(
            "--batch_size", type=int, default=32, help="size of the batches"
        )
        arg_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
        arg_parser.add_argument(
            "--env", type=str, required=True, help="gym environment tag"
        )
        arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
        arg_parser.add_argument(
            "--seed", type=int, default=123, help="seed for training run"
        )
        arg_parser.add_argument(
            "--avg_reward_len",
            type=int,
            default=100,
            help="how many episodes to include in avg reward",
        )
        return arg_parser
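

# Minimal usage sketch (not part of the original gist): trains the model on
# CartPole-v0 with default hyperparameters. Assumes a 2020-era
# pytorch-lightning release where returning "log"/"progress_bar" dicts from
# training_step is still supported; the Trainer arguments are illustrative,
# not tuned values.
if __name__ == "__main__":
    model = VanillaPolicyGradient(env="CartPole-v0")
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model)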