Skip to content

Instantly share code, notes, and snippets.

@djbyrne
Last active October 7, 2020 07:09
Show Gist options
  • Save djbyrne/5a4a51b75afb130ffe2ef19a6e06c972 to your computer and use it in GitHub Desktop.
Save djbyrne/5a4a51b75afb130ffe2ef19a6e06c972 to your computer and use it in GitHub Desktop.
class LitRLModel(pl.LightningModule):
    """Lightning module that produces its own training data from a gym environment.

    ``train_batch`` is a generator that steps the environment with the agent
    and yields batches; ``train_dataloader`` hands that generator to an
    ``ExperienceSourceDataset`` and wraps the dataset in a ``DataLoader``.

    NOTE(review): this is an abridged snippet. ``self.net``, ``self.state``,
    ``self.batch_size`` and the local name ``batch`` are used below but are
    never assigned in the visible code; presumably they are set in the parts
    that were elided -- confirm before running.
    """

    def __init__(self, env, *args, **kwargs):
        """Create the environment and the agent.

        Args:
            env: Environment id; passed straight to ``gym.make``.
            *args: Placeholder for the positional parameters that the
                original snippet elided with ``...``.
            **kwargs: Placeholder for the keyword parameters that the
                original snippet elided with ``...``.
        """
        # torch.nn.Module needs its own __init__ to run before sub-modules
        # are assigned as attributes.
        super().__init__()

        # Environment
        self.env = gym.make(env)
        self.env.seed(123)
        self.obs_shape = self.env.observation_space.shape
        # Reads ``.n`` from the action space, so this presumably expects a
        # discrete action space -- TODO confirm.
        self.n_actions = self.env.action_space.n

        # Agent
        # NOTE(review): ``self.net`` is not defined in the visible code.
        self.agent = ValueAgent(self.net, self.n_actions)

    # Dataset
    def train_dataloader(self) -> DataLoader:
        """Build the dataset from ``train_batch`` and return its ``DataLoader``.

        The bound method ``self.train_batch`` itself (not its result) is given
        to ``ExperienceSourceDataset``; the dataset is also kept on
        ``self.dataset``.
        """
        self.dataset = ExperienceSourceDataset(self.train_batch)
        return DataLoader(dataset=self.dataset, batch_size=self.batch_size)

    # Train Batch
    def train_batch(self) -> Tuple:
        """Step the environment indefinitely, yielding a batch each time one is ready.

        This is a generator: the ``while True`` loop never returns on its own,
        so iteration ends only when the consumer stops pulling from it.
        """
        # keep taking steps during training
        while True:
            # take a step in the environment
            action = self.agent(self.state, self.device)
            next_state, reward, done, _ = self.env.step(action[0])

            # add results to the batch
            # (elided in the original snippet; ``batch`` is presumably built
            # here, and ``self.state`` presumably advanced -- TODO confirm)
            ...

            # when the batch is ready, yield to the dataset
            yield batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment