Last active
October 7, 2020 07:09
-
-
Save djbyrne/5a4a51b75afb130ffe2ef19a6e06c972 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LitRLModel(pl.LightningModule):
    """LightningModule skeleton that trains a value-based agent on a gym environment.

    Experience is generated on the fly by ``train_batch``, wrapped in an
    ``ExperienceSourceDataset`` and batched by a ``DataLoader``.

    NOTE(review): this is an illustrative snippet. Network construction and the
    remaining hyperparameters were elided ("...") in the original and are still
    marked as TODOs below.
    """

    def __init__(self, env: str, batch_size: int, **kwargs) -> None:
        """Create the environment and agent.

        Args:
            env: gym environment id passed to ``gym.make``.
            batch_size: number of items per batch produced by ``train_dataloader``.
            **kwargs: remaining hyperparameters. TODO: these were elided in the
                original snippet and are currently unused.
        """
        # nn.Module / LightningModule must be initialised before attributes
        # (in particular submodules such as the network) are assigned, and
        # ``self.device`` used in train_batch relies on this setup.
        super().__init__()
        self.batch_size = batch_size

        # Environment
        self.env = gym.make(env)
        self.env.seed(123)  # fixed seed so rollouts are reproducible
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        # train_batch reads self.state before its first step, so start an episode now.
        self.state = self.env.reset()

        # Agent
        # TODO: self.net must exist before this line. Its construction (presumably
        # from self.obs_shape / self.n_actions) was elided in the original snippet.
        self.agent = ValueAgent(self.net, self.n_actions)

    def train_dataloader(self) -> DataLoader:
        """Return a DataLoader that batches experience streamed from ``train_batch``."""
        self.dataset = ExperienceSourceDataset(self.train_batch)
        return DataLoader(dataset=self.dataset, batch_size=self.batch_size)

    def train_batch(self) -> "Iterator[Tuple]":
        """Step the environment indefinitely, yielding one transition per step.

        The DataLoader built in ``train_dataloader`` collates ``batch_size`` of
        these items into a training batch. NOTE(review): this assumes
        ExperienceSourceDataset iterates this generator item by item — confirm.

        Yields:
            ``(state, action, reward, done, next_state)`` for each environment step.
        """
        # keep taking steps during training
        while True:
            # take a step in the environment
            action = self.agent(self.state, self.device)
            next_state, reward, done, _ = self.env.step(action[0])
            transition = (self.state, action[0], reward, done, next_state)

            # Advance the rollout. Without this the agent would keep acting on
            # the same stale state. Start a new episode once the current one ends.
            self.state = next_state
            if done:
                self.state = self.env.reset()

            # hand the transition to the dataset; the DataLoader does the batching
            yield transition
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.