Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
# Build the shared ("global") network and one Worker per CPU thread, then run
# every worker's `work` loop in its own Python thread inside one tf.Session.
#
# NOTE(review): `trainer`, `saver`, `s_size`, `a_size`, `model_path`,
# `load_model`, `max_episode_length` and `gamma` are not defined in this
# fragment -- they are assumed to be defined earlier in the file; confirm.
with tf.device("/cpu:0"):
    # Generate global network
    master_network = AC_Network(s_size, a_size, 'global', None)
    # Set workers to number of available CPU threads
    num_workers = multiprocessing.cpu_count()
    # Create worker classes
    workers = [
        Worker(DoomGame(), worker_id, s_size, a_size, trainer, saver, model_path)
        for worker_id in range(num_workers)
    ]

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    if load_model == True:
        # Parenthesised single-argument form prints identically on Python 2
        # and is also valid Python 3 syntax.
        print('Loading Model...')
        ckpt = tf.train.get_checkpoint_state(model_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    # This is where the asynchronous magic happens.
    # Start the "work" process for each worker in a separate thread.
    worker_threads = []
    for worker in workers:
        # Pass the bound method and its arguments directly instead of wrapping
        # them in a lambda: a lambda would look up the loop variable `worker`
        # only when the thread actually runs, by which time the loop may have
        # moved on, so several threads could end up driving the same Worker.
        t = threading.Thread(
            target=worker.work,
            args=(max_episode_length, gamma, master_network, sess, coord),
        )
        t.start()
        worker_threads.append(t)
    # Block until all worker threads have finished.
    coord.join(worker_threads)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment