import gym
import numpy as np


def envfunc():
    # NOTE: the agent below expects (84, 84, 4) stacked frames, so Atari
    # preprocessing wrappers are assumed to be applied elsewhere in the project.
    env = gym.make("BreakoutDeterministic-v4")
    return env
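

# --- SubProcVecEnv is used below but not defined in this gist (it is
# presumably the subprocess-based vectorized env from OpenAI baselines).
# As an assumption, here is a minimal single-process stand-in with the same
# interface; note that this gist's step() returns
# (rewards, next_states, dones, info), in that order.
class SubProcVecEnv:

    def __init__(self, env_fns):
        # Instantiate one env per factory function (synchronously here,
        # not in subprocesses as the name suggests).
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return np.array([env.reset() for env in self.envs])

    def step(self, actions):
        rewards, next_states, dones = [], [], []
        for env, action in zip(self.envs, actions):
            next_state, reward, done, _ = env.step(action)
            if done:
                # Auto-reset finished episodes so rollouts never stall.
                next_state = env.reset()
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)
        return np.array(rewards), np.array(next_states), np.array(dones), {}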
class A2CAgent:

    TRAJECTORY_SIZE = 5  # n-step rollout length per update
    ACTION_SPACE = 4     # Breakout: NOOP, FIRE, RIGHT, LEFT

    def __init__(self, n_procs, gamma=0.99):
        self.n_procs = n_procs
        self.ACNet = ActorCriticNet(action_space=self.ACTION_SPACE)
        self.gamma = gamma
        self.vecenv = SubProcVecEnv([envfunc for i in range(self.n_procs)])
        self.states = None
        # Each update consumes one n-step trajectory from every env.
        self.batch_size = self.n_procs * self.TRAJECTORY_SIZE
    def run(self, total_steps, test_freq=10000):
        # NOTE: test_freq is currently unused.
        self.states = self.vecenv.reset()
        steps = 0
        for _ in range(total_steps // (self.n_procs * self.TRAJECTORY_SIZE)):
            mb_states, mb_actions, mb_discounted_rewards = self.run_Nsteps()
            # Flatten (n_procs, TRAJECTORY_SIZE, ...) into one minibatch.
            states = mb_states.reshape((self.batch_size, 84, 84, 4))
            selected_actions = mb_actions.reshape(self.batch_size, -1)
            discounted_rewards = mb_discounted_rewards.reshape(self.batch_size, -1)
            self.ACNet.update(states, selected_actions, discounted_rewards)
            steps += self.n_procs * self.TRAJECTORY_SIZE
            print("Step:", steps)
    def run_Nsteps(self):
        """Run every agent for TRAJECTORY_SIZE (5) steps."""
        mb_states, mb_actions, mb_rewards, mb_dones = [], [], [], []
        for _ in range(self.TRAJECTORY_SIZE):
            states = np.array(self.states)
            actions = self.ACNet.sample_action(states)
            rewards, next_states, dones, _ = self.vecenv.step(actions)
            mb_states.append(states)
            mb_actions.append(actions)
            mb_rewards.append(rewards)
            mb_dones.append(dones)
            self.states = next_states
        # Reorder each buffer from (time, env) to (env, time).
        mb_states = np.array(mb_states).swapaxes(0, 1)
        mb_actions = np.array(mb_actions).T
        mb_rewards = np.array(mb_rewards).T
        mb_dones = np.array(mb_dones).T
"""割引報酬和の計算"""
last_values, _ = self.ACNet.predict(self.states)
mb_discounted_rewards = np.zeros(mb_rewards.shape)
for n, (rewards, dones, last_value) in enumerate(zip(mb_rewards, mb_dones, last_values.flatten())):
rewards = rewards.tolist()
dones = dones.tolist()
discounted_rewards = self.discount_with_dones(rewards, dones, last_value)
mb_discounted_rewards[n] = discounted_rewards
return (mb_states, mb_actions, mb_discounted_rewards)
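
    # --- Hedged sketch: discount_with_dones is called above but not defined
    # in the gist. The standard n-step bootstrapped return (as in OpenAI
    # baselines) would look like this.
    def discount_with_dones(self, rewards, dones, last_value):
        """R_t = r_t + gamma * R_{t+1}, walking the trajectory backwards,
        starting from the critic's value estimate; a done flag zeroes the
        bootstrap so returns never leak across episode boundaries.

        Example: rewards=[1, 0, 0], dones=[0, 0, 0], last_value=2, gamma=0.99
        -> [1 + 0.99**3 * 2, 0.99**2 * 2, 0.99 * 2]
        """
        discounted = []
        r = last_value
        for reward, done in zip(rewards[::-1], dones[::-1]):
            r = reward + self.gamma * r * (1.0 - done)
            discounted.append(r)
        return discounted[::-1]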
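
# --- ActorCriticNet is referenced above but not defined in this gist.
# A minimal TF2/Keras sketch matching the interface the agent actually uses
# (sample_action / predict / update) follows. The layer sizes, learning rate,
# and loss coefficients are assumptions, not the author's original values.
import tensorflow as tf
import tensorflow.keras.layers as kl


class ActorCriticNet(tf.keras.Model):

    def __init__(self, action_space):
        super().__init__()
        self.action_space = action_space
        self.conv1 = kl.Conv2D(32, 8, strides=4, activation="relu")
        self.conv2 = kl.Conv2D(64, 4, strides=2, activation="relu")
        self.conv3 = kl.Conv2D(64, 3, strides=1, activation="relu")
        self.flatten = kl.Flatten()
        self.dense = kl.Dense(512, activation="relu")
        self.logits = kl.Dense(action_space)
        self.values = kl.Dense(1)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=2.5e-4)

    def call(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.dense(self.flatten(x))
        return self.values(x), self.logits(x)

    def predict(self, states):
        # Deliberately shadows keras.Model.predict to match the interface
        # used above: returns (state-values, policy logits) as numpy arrays.
        states = tf.convert_to_tensor(np.array(states), dtype=tf.float32)
        values, logits = self(states)
        return values.numpy(), logits.numpy()

    def sample_action(self, states):
        # Sample one action per env from the categorical policy.
        _, logits = self(tf.convert_to_tensor(states, dtype=tf.float32))
        return tf.random.categorical(logits, 1).numpy().flatten()

    def update(self, states, selected_actions, discounted_rewards):
        # One A2C update: policy gradient weighted by the advantage
        # (return - value), a value regression loss, and an entropy bonus.
        # Frame normalization is omitted here for brevity.
        states = tf.convert_to_tensor(states, dtype=tf.float32)
        returns = tf.convert_to_tensor(discounted_rewards, dtype=tf.float32)
        actions_onehot = tf.one_hot(
            tf.reshape(tf.cast(selected_actions, tf.int32), [-1]),
            self.action_space)
        with tf.GradientTape() as tape:
            values, logits = self(states)
            advantages = returns - values
            log_probs = tf.nn.log_softmax(logits)
            selected_log_probs = tf.reduce_sum(
                log_probs * actions_onehot, axis=1, keepdims=True)
            policy_loss = -tf.reduce_mean(
                selected_log_probs * tf.stop_gradient(advantages))
            value_loss = tf.reduce_mean(tf.square(advantages))
            entropy = -tf.reduce_mean(
                tf.reduce_sum(tf.nn.softmax(logits) * log_probs, axis=1))
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))


if __name__ == "__main__":
    # Usage sketch (assumed values): 8 parallel envs, ~1e5 env steps total.
    agent = A2CAgent(n_procs=8)
    agent.run(total_steps=100000)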