Skip to content

Instantly share code, notes, and snippets.

@horoiwa
Created May 24, 2020 06:37
Show Gist options
  • Save horoiwa/a8ccecd1e754dc5ed212a94245995e82 to your computer and use it in GitHub Desktop.
Save horoiwa/a8ccecd1e754dc5ed212a94245995e82 to your computer and use it in GitHub Desktop.
import threading
import tensorflow as tf
import gym
class GlobalCounter:
    """Step counter shared by every agent created in ``main``.

    A single instance is constructed once and handed to each ``A3CAgent``
    as ``global_counter``, so all worker threads see the same ``n``.
    NOTE(review): presumably agents increment ``n`` and compare it against
    ``global_steps_fin``; ``n += 1`` from several threads is not atomic,
    so confirm in ``A3CAgent`` whether a lock is needed.
    """

    def __init__(self):
        # Must be an instance attribute: a bare ``n = 0`` only created a
        # local variable, leaving the object without any ``n`` at all.
        self.n = 0
def main():
    """Launch multi-threaded A3C training on CartPole-v1.

    Builds one shared ``ActorCriticNet`` and ``NUM_AGENTS`` ``A3CAgent``
    workers, each with its own CartPole-v1 environment but sharing the
    global counter, history list and network. It then runs each agent's
    ``play(coord)`` in its own thread and blocks until the threads finish.
    """
    ACTION_SPACE = 2   # number of discrete actions in CartPole-v1
    NUM_AGENTS = 8     # one worker thread per agent
    N_STEPS = 50000    # handed to every agent as ``global_steps_fin``

    with tf.device("/cpu:0"):
        # State shared by all agents: one counter, one history list and
        # one global actor-critic network.
        global_counter = GlobalCounter()
        global_history = []
        global_ACNet = ActorCriticNet(ACTION_SPACE)
        # Create the weights up front; 4 is the CartPole-v1 observation size.
        global_ACNet.build(input_shape=(None, 4))

        agents = []
        for agent_id in range(NUM_AGENTS):
            agent = A3CAgent(agent_id=f"agent_{agent_id}",
                             env=gym.envs.make("CartPole-v1"),
                             global_counter=global_counter,
                             action_space=ACTION_SPACE,
                             global_ACNet=global_ACNet,
                             gamma=0.99,
                             global_history=global_history,
                             global_steps_fin=N_STEPS)
            agents.append(agent)

    coord = tf.train.Coordinator()
    agent_threads = []
    for agent in agents:
        # Bind this agent now via target/args. The former
        # ``lambda: agent.play(coord)`` read ``agent`` only when the thread
        # ran, so a late-starting thread could pick up a later loop value:
        # one agent would be played twice and another never.
        thread = threading.Thread(target=agent.play, args=(coord,))
        thread.start()
        agent_threads.append(thread)

    # Block until every worker exits; once a stop is requested, threads get
    # up to 300 s to finish (see tf.train.Coordinator.join).
    coord.join(agent_threads, stop_grace_period_secs=300)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment