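# Note: this excerpt is part of a larger A3C implementation; the imports and the
# Step record below are assumptions added so the snippet is self-contained.
# Step is assumed to be a namedtuple, based on how its fields are accessed.
from collections import namedtuple

import numpy as np
import tensorflow as tf

Step = namedtuple("Step", ["state", "action", "reward", "next_state", "done"])

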
class A3CAgent:

    """ (rest of the class omitted) """

    def play(self, coord):

        self.total_reward = 0
        self.state = self.env.reset()

        try:
            while not coord.should_stop():
                # Roll out up to MAX_TRAJECTORY steps with the local network
                trajectory = self.play_n_steps(N=self.MAX_TRAJECTORY)

                states = [step.state for step in trajectory]
                actions = [step.action for step in trajectory]

                # Bootstrap from V(next_state) unless the rollout ended in a terminal state
                if trajectory[-1].done:
                    R = 0
                else:
                    values, _ = self.local_ACNet(
                        tf.convert_to_tensor(np.atleast_2d(trajectory[-1].next_state),
                                             dtype=tf.float32))
                    R = values[0][0].numpy()

                # n-step discounted returns, accumulated backwards from R
                discounted_rewards = []
                for step in reversed(trajectory):
                    R = step.reward + self.gamma * R
                    discounted_rewards.append(R)
                discounted_rewards.reverse()

                with tf.GradientTape() as tape:
                    total_loss = self.compute_loss(states, actions, discounted_rewards)

                # Gradients are taken w.r.t. the local network ...
                grads = tape.gradient(
                    total_loss, self.local_ACNet.trainable_variables)

                # ... but applied to the shared global network
                self.optimizer.apply_gradients(
                    zip(grads, self.global_ACNet.trainable_variables))

                # Sync the local network with the updated global weights
                self.local_ACNet.set_weights(self.global_ACNet.get_weights())

                if self.global_counter.n >= self.global_steps_fin:
                    coord.request_stop()

        except tf.errors.CancelledError:
            return
    def play_n_steps(self, N):
        """Interact with the environment for up to N steps and return the trajectory."""
        trajectory = []
        for _ in range(N):
            self.global_counter.n += 1

            action = self.local_ACNet.sample_action(self.state)
            next_state, reward, done, info = self.env.step(action)

            step = Step(self.state, action, reward, next_state, done)
            trajectory.append(step)

            if done:
                # Episode finished: log the episode reward and reset the environment
                self.global_history.append(self.total_reward)
                self.total_reward = 0
                self.state = self.env.reset()
                break
            else:
                self.total_reward += reward
                self.state = next_state

        return trajectory
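

# Usage sketch (not part of the original gist): one way to run several workers
# against a shared global network with tf.train.Coordinator, matching the
# coord.should_stop() / coord.request_stop() calls in play(). The A3CAgent
# constructor arguments are hypothetical, since __init__ is omitted above.
#
#   import threading
#
#   coord = tf.train.Coordinator()
#   agents = [A3CAgent(...) for _ in range(num_workers)]
#   threads = [threading.Thread(target=agent.play, args=(coord,), daemon=True)
#              for agent in agents]
#   for t in threads:
#       t.start()
#   coord.join(threads)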