Skip to content

Instantly share code, notes, and snippets.

@horoiwa
Last active May 10, 2020 06:03
Show Gist options
  • Save horoiwa/f6965f28f637980c4e80d350432795cb to your computer and use it in GitHub Desktop.
Save horoiwa/f6965f28f637980c4e80d350432795cb to your computer and use it in GitHub Desktop.
DQNAgent
class DQNAgent:
    """Deep Q-Network agent (excerpt; the remaining members are omitted).

    The methods shown here rely on the omitted part of the class to provide:

    * ``env`` -- an object with ``reset()`` returning the initial state and
      ``step(action)`` returning a 4-tuple ``(next_state, reward, done, info)``;
    * ``experiences`` -- a container supporting ``append`` and ``len``
      (presumably the replay buffer -- confirm against the omitted code);
    * ``q_network`` (needs ``get_weights()``) and ``target_network``
      (needs ``set_weights(weights)``);
    * ``global_steps`` -- a numeric counter, incremented once per env step;
    * ``sample_action(state)`` and ``update_qnetwork()``.

    The class attributes below hold the exploration schedule and the
    target-network sync interval; override them on a subclass or an
    instance to change the schedule.
    """

    # Linear epsilon schedule: epsilon starts at 1.0 and decreases by
    # EPSILON_MAX_DECAY over the first EPSILON_DECAY_STEPS global steps,
    # then stays at 1.0 - EPSILON_MAX_DECAY.
    EPSILON_MAX_DECAY = 0.95
    EPSILON_DECAY_STEPS = 500
    # Copy the online Q-network's weights into the target network every
    # TARGET_UPDATE_INTERVAL global steps.
    TARGET_UPDATE_INTERVAL = 250

    def play(self, episodes):
        """Run ``episodes`` episodes and return their total rewards.

        Epsilon is recomputed from ``global_steps`` here, once before each
        episode, and a short progress report is printed after each episode.

        Args:
            episodes: number of episodes to run.

        Returns:
            List of per-episode total rewards, in episode order.
        """
        total_rewards = []
        for n in range(episodes):
            # Linear decay, clamped so epsilon never drops below
            # 1.0 - EPSILON_MAX_DECAY.
            self.epsilon = 1.0 - min(
                self.EPSILON_MAX_DECAY,
                self.global_steps * self.EPSILON_MAX_DECAY / self.EPSILON_DECAY_STEPS,
            )
            total_reward = self.play_episode()
            total_rewards.append(total_reward)
            print(f"Episode {n}: {total_reward}")
            print(f"Current experiences {len(self.experiences)}")
            print(f"Current epsilon {self.epsilon}")
            print()
        return total_rewards

    def play_episode(self):
        """Play a single episode while training, and return its total reward.

        Each environment step is recorded in ``experiences`` as an
        ``Experience(state, action, reward, next_state, done)``. It also
        increments ``global_steps`` and triggers one ``update_qnetwork()``
        call. Every TARGET_UPDATE_INTERVAL global steps, the target network
        is overwritten with the online Q-network's weights.

        Returns:
            Sum of the rewards received during the episode.
        """
        total_reward = 0
        done = False
        state = self.env.reset()
        while not done:
            action = self.sample_action(state)
            next_state, reward, done, _info = self.env.step(action)
            total_reward += reward
            # NOTE(review): `done` is stored as the terminal flag as-is; if the
            # env can also end episodes by truncation (e.g. a time limit), that
            # is recorded as terminal too -- confirm this is intended.
            self.experiences.append(
                Experience(state, action, reward, next_state, done)
            )
            state = next_state
            self.global_steps += 1
            self.update_qnetwork()
            # Periodically sync the target network with the online network.
            if self.global_steps % self.TARGET_UPDATE_INTERVAL == 0:
                self.target_network.set_weights(self.q_network.get_weights())
        return total_reward
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment