Skip to content

Instantly share code, notes, and snippets.

@djbyrne
Last active October 7, 2020 07:24
Show Gist options
Save djbyrne/f6cb4d430740918ed518918600fe1d4c to your computer and use it in GitHub Desktop.
def training_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], _
) -> OrderedDict:
    """Compute the loss for one batch and bundle it with reward metrics.

    Args:
        batch: A 3-tuple ``(states, actions, scaled_rewards)``; the three
            items are passed, in that order, to ``self.loss``.
        _: Second positional argument from the caller (presumably the
            batch index); it is not used here.

    Returns:
        An ``OrderedDict`` with the keys ``"loss"`` (the value returned by
        ``self.loss``), ``"avg_reward"`` (``self.avg_rewards``), ``"log"``
        and ``"progress_bar"``. The last two refer to the same dict object,
        which holds the episode count (``self.done_episodes``), the most
        recent entry of ``self.total_rewards`` and ``self.avg_rewards``.
    """
    # The batch carries three tensors, so the annotation above is a 3-tuple.
    states, actions, scaled_rewards = batch
    loss = self.loss(states, actions, scaled_rewards)

    # An empty reward history would otherwise raise IndexError on [-1];
    # report 0.0 until at least one reward has been recorded.
    last_reward = self.total_rewards[-1] if self.total_rewards else 0.0
    avg_reward = self.avg_rewards

    log = {
        "episodes": self.done_episodes,
        "reward": last_reward,
        "avg_reward": avg_reward,
    }
    return OrderedDict(
        {
            "loss": loss,
            "avg_reward": avg_reward,
            "log": log,
            "progress_bar": log,
        }
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment