def get_average_return(environment, policy, episodes: int = 10):
    """Run ``policy`` in ``environment`` and return the mean episode return.

    For each episode the environment is reset, then stepped with the action
    chosen by the policy until the current time step reports ``is_last()``.
    The reward of every time step returned by ``environment.step`` is summed
    into that episode's return, and the per-episode returns are averaged.

    NOTE(review): the ``.numpy()[0]`` at the end presumes rewards are
    tensor-like objects with a leading batch axis of size 1 (this looks like
    a TF-Agents style evaluation loop) -- confirm against the callers.

    Args:
        environment: Object exposing ``reset()`` and ``step(action)``, both
            returning a time step with ``is_last()`` and ``reward``.
        policy: Object exposing ``action(time_step)`` whose result has an
            ``action`` attribute that is fed to ``environment.step``.
        episodes: Number of episodes to average over; must be positive.

    Returns:
        Element 0 of ``avg_return.numpy()`` when rewards were accumulated,
        or the plain float ``0.0`` when no step was taken in any episode.

    Raises:
        ValueError: If ``episodes`` is not positive.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (0) or
    # an AttributeError on a float (negative values skip the loop entirely).
    if episodes <= 0:
        raise ValueError(f"episodes must be a positive integer, got {episodes!r}")

    total_return = 0.0
    for _ in range(episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return

    avg_return = total_return / episodes
    if not hasattr(avg_return, "numpy"):
        # Every episode was terminal right after reset(), so no reward was
        # ever added and the accumulator is still the float it started as.
        return avg_return
    return avg_return.numpy()[0]
