# Build the shared global network and one Worker per CPU thread, then run
# every worker's training loop in its own thread until the coordinator stops.
# NOTE(review): s_size, a_size, trainer, saver, model_path, load_model,
# max_episode_length and gamma are not defined in this chunk; presumably they
# are set earlier in the file — confirm.
with tf.device("/cpu:0"):
    master_network = AC_Network(s_size, a_size, 'global', None)  # Generate global network
    num_workers = multiprocessing.cpu_count()  # Set workers to number of available CPU threads
    workers = []
    # Create worker classes, one per available CPU thread.
    for i in range(num_workers):
        workers.append(Worker(DoomGame(), i, s_size, a_size, trainer, saver, model_path))
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    if load_model:
        # print(...) with a single argument behaves the same on Python 2 and 3.
        print('Loading Model...')
        ckpt = tf.train.get_checkpoint_state(model_path)
        # get_checkpoint_state returns None when model_path holds no checkpoint;
        # fail with a clear message instead of an AttributeError on None.
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise IOError('No checkpoint found in %s' % model_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
    # This is where the asynchronous magic happens.
    # Start the "work" process for each worker in a separate thread.
    # Hand the bound method and its arguments to Thread directly: a lambda
    # built inside this loop would look up `worker` only when the thread
    # actually runs, so a late-starting thread could call work() on a later
    # worker (and some worker would never run).
    worker_threads = []
    for worker in workers:
        t = threading.Thread(
            target=worker.work,
            args=(max_episode_length, gamma, master_network, sess, coord),
        )
        t.start()
        worker_threads.append(t)
    # Block until every worker thread finishes or the coordinator requests a stop.
    coord.join(worker_threads)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment