Skip to content

Instantly share code, notes, and snippets.

@djbyrne
Last active October 7, 2020 07:24
Show Gist options
Save djbyrne/9450a044275f7fc2c59e2567e554f24d to your computer and use it in GitHub Desktop.
train_batch
def train_batch(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Yield ``(state, action, return)`` samples from completed episodes, forever.

    Each loop iteration asks ``self.agent`` for an action in ``self.state``,
    steps ``self.env`` with ``action[0]`` and buffers the state, the full
    agent output and the reward.

    Nothing is yielded until the environment reports ``done``. At that point:

    * ``done_episodes`` is incremented;
    * the environment is reset;
    * the episode's summed reward is appended to ``total_rewards``;
    * ``avg_rewards`` is recomputed as the mean of the last
      ``avg_reward_len`` entries;
    * one ``(state, action, return)`` tuple is yielded per buffered step,
      with the returns taken from ``self.compute_returns``.

    The generator never terminates on its own; the consumer decides when to
    stop pulling samples.

    Note:
        Despite the return annotation (kept unchanged for compatibility),
        this is a generator that yields one ``(state, action, return)`` tuple
        at a time.
    """
    while True:
        action = self.agent(self.state, self.device)
        # The env is stepped with the first element of the agent's output,
        # while the full output is what gets buffered and yielded.
        next_state, reward, done, _ = self.env.step(action[0])
        self.episode_rewards.append(reward)
        self.batch_actions.append(action)
        self.batch_states.append(self.state)
        self.state = next_state

        if not done:
            continue

        self.done_episodes += 1
        self.state = self.env.reset()
        self.total_rewards.append(sum(self.episode_rewards))
        self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:]))

        # Presumably one return per buffered reward; verify against compute_returns.
        returns = self.compute_returns(self.episode_rewards)

        # Detach the finished episode and clear the buffers *before* yielding.
        # The consumer may close or abandon this endless generator at any
        # yield. If the clearing happened after the yields, it would never run,
        # and the next call would mix this episode's stale transitions and
        # rewards into the following episode.
        states, actions = self.batch_states, self.batch_actions
        self.batch_states = []
        self.batch_actions = []
        self.episode_rewards = []

        yield from zip(states, actions, returns)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment