@NMZivkovic
Created December 23, 2019 11:16
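def get_average_return(environment, policy, episodes=10):
    """Return the average undiscounted episode return of `policy` over `episodes` runs.

    The final `.numpy()[0]` assumes rewards are batched tensors with batch
    size 1, as produced by a TF-Agents TF environment wrapper.
    """
    total_return = 0.0
    for _ in range(episodes):
        # Start a fresh episode.
        time_step = environment.reset()
        episode_return = 0.0
        # Act until the environment signals the last step of the episode.
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / episodes
    # Unwrap the single-element batch into a plain scalar.
    return avg_return.numpy()[0]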
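
The evaluation loop can be checked without installing any RL library. The sketch below is an illustration, not part of the original gist: `StubTimeStep`, `StubActionStep`, `CountdownEnv`, `ConstantPolicy`, and `average_return` are all hypothetical names that only mimic the `reset()` / `step()` / `is_last()` / `action()` interface the function relies on. Rewards are plain floats here, so the final `.numpy()[0]` unwrapping is dropped.

```python
from dataclasses import dataclass


@dataclass
class StubTimeStep:
    """Hypothetical stand-in for a time step: a reward plus an end-of-episode flag."""
    reward: float
    last: bool

    def is_last(self) -> bool:
        return self.last


@dataclass
class StubActionStep:
    """Hypothetical stand-in for a policy step; only `.action` is read by the loop."""
    action: int


class CountdownEnv:
    """Hypothetical environment: each step pays 1.0 and the episode ends after `length` steps."""

    def __init__(self, length: int):
        self.length = length
        self.remaining = length

    def reset(self) -> StubTimeStep:
        self.remaining = self.length
        return StubTimeStep(reward=0.0, last=False)

    def step(self, action: int) -> StubTimeStep:
        self.remaining -= 1
        return StubTimeStep(reward=1.0, last=self.remaining <= 0)


class ConstantPolicy:
    """Hypothetical policy that always picks action 0."""

    def action(self, time_step: StubTimeStep) -> StubActionStep:
        return StubActionStep(action=0)


def average_return(environment, policy, episodes=10):
    # Same loop as get_average_return, but with float rewards,
    # so no tensor unwrapping is needed at the end.
    total_return = 0.0
    for _ in range(episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    return total_return / episodes


# Three episodes of five steps each, 1.0 reward per step.
print(average_return(CountdownEnv(length=5), ConstantPolicy(), episodes=3))  # 5.0
```

Because every episode of `CountdownEnv(length=5)` yields exactly five rewards of 1.0, the average is 5.0 regardless of the number of episodes, which makes the stub a convenient sanity check for the accumulation logic.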