Skip to content

Instantly share code, notes, and snippets.

@arthurmensch
Created September 17, 2018 21:15
Show Gist options
  • Save arthurmensch/f6a80691662e59f10283205eb15762ce to your computer and use it in GitHub Desktop.
Save arthurmensch/f6a80691662e59f10283205eb15762ce to your computer and use it in GitHub Desktop.
MCTS + dask
import threading as thr
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed as thr_as_completed
from queue import Queue as ThrQueue
from time import sleep
import numpy as np
import tornado
from distributed import Client, Queue, as_completed, Pub, Sub, get_client, \
get_worker
class _MCTS:
    """Dummy class to be replaced by cython implementation.

    ``grow`` runs two threads: ``_explore`` pushes ``max_eval`` dummy state
    batches onto ``state_q`` while ``_backup`` pops ``max_eval`` evaluations
    from ``eval_q``. Outside code drains and feeds those queues through
    ``get`` and ``put``.
    """

    def __init__(self, max_eval):
        """Set up counters, events, the thread pool and the two queues.

        :param max_eval: number of states produced and evaluations consumed
            by each call to ``grow``.
        """
        self.max_eval = max_eval
        self.state_counter = 0
        self.eval_counter = 0
        self.turn_counter = 0
        # Set while the corresponding worker thread is running.
        self.exploring = thr.Event()
        self.backuping = thr.Event()
        self.thread_pool = ThreadPoolExecutor(max_workers=2)
        self.state_q = ThrQueue()
        self.eval_q = ThrQueue()

    def _explore(self):
        """Push ``max_eval`` zero-filled state batches onto ``state_q``."""
        self.exploring.set()
        while self.state_counter < self.max_eval:
            sleep(.001)
            self.state_q.put(np.zeros((2048, 11, 11), dtype=np.uint8))
            self.state_counter += 1
        self.exploring.clear()

    def _backup(self):
        """Pop ``max_eval`` evaluations from ``eval_q`` (values are unused)."""
        self.backuping.set()
        while self.eval_counter < self.max_eval:
            self.eval_q.get()
            sleep(.001)
            self.eval_counter += 1
        self.backuping.clear()

    def get(self):
        """Return the next state from ``state_q``, blocking until available.

        A ``None`` item is treated as a placeholder: the call then waits for
        exploration to be running and fetches the following item instead.
        """
        res = self.state_q.get()
        if res is None:
            self.exploring.wait()
            res = self.state_q.get()
        # Hand back the item that was dequeued; fetching again here would
        # silently drop one state per call and starve ``_backup``.
        return res

    def put(self, eval):
        """Queue an evaluation once the backup thread is running."""
        self.backuping.wait()
        self.eval_q.put(eval)

    def act(self):
        """Advance the turn counter and return a zero-filled ``Record``."""
        self.turn_counter += 1
        return Record(np.zeros((11, 11), dtype=np.uint8),
                      np.zeros((4, 7), dtype=np.float32))

    def grow(self):
        """Run one explore/backup round and reset both counters.

        :raises ChildProcessError: if either worker thread raised; the
            original exception is attached as ``__cause__``.
        """
        explore_future = self.thread_pool.submit(self._explore)
        backup_future = self.thread_pool.submit(self._backup)
        for future in thr_as_completed((explore_future, backup_future)):
            try:
                future.result()
            except Exception as e:
                raise ChildProcessError from e
        self.eval_counter = 0
        self.state_counter = 0

    def __repr__(self):
        return (f'turn {self.turn_counter}, '
                f'state/eval {self.state_counter}/{self.eval_counter} '
                f'buffers state/eval {self.state_q.qsize()}/{self.eval_q.qsize()}')
class Record:
    """Plain container bundling a ``state`` with its ``target``."""

    def __init__(self, state, target):
        # Both values are stored as given, without copying or validation.
        self.state, self.target = state, target
class Player:
    """Actor that drives an ``_MCTS`` and relays its data over dask queues.

    While ``alive`` is set, four threads run: ``_send`` forwards states from
    the tree to ``state_q``, ``_recv`` feeds evaluations from ``eval_q`` back
    into the tree, ``_loop`` grows the tree and publishes one record per turn
    on ``train_q``, and ``_monitor`` prints the tree status every second.
    """

    @classmethod
    def start(cls, *args, **kwargs):
        """Build a player, launch its threads and block in ``watch``."""
        player = cls(*args, **kwargs)
        player.loop()
        return player.watch()

    def __init__(self, state_q, eval_q, train_q):
        """Store the queues and grab the dask client/worker of this task.

        :param state_q: queue receiving scattered states.
        :param eval_q: queue delivering evaluation futures.
        :param train_q: queue receiving scattered records.
        """
        self.thread_pool = ThreadPoolExecutor(max_workers=5)
        self.alive = thr.Event()
        self.mcts = _MCTS(max_eval=500, )
        self.state_counter = 0
        self.eval_counter = 0
        self.turn_counter = 0
        # Dask communication
        self.client = get_client()
        self.worker = get_worker()
        self.state_q = state_q
        self.eval_q = eval_q
        self.train_q = train_q

    def _send(self):
        """Scatter each state produced by the tree and enqueue its future."""
        while self.alive.is_set():
            state = self.mcts.get()
            future = self.client.scatter(state)
            self.state_q.put(future)
            self.state_counter += 1

    def _recv(self):
        """Fetch evaluations from ``eval_q`` and hand their values to the tree."""
        while self.alive.is_set():
            eval_future = self.eval_q.get()
            # Pass on the gathered value, not the future handle.
            evaluation = self.client.gather(eval_future)
            self.mcts.put(evaluation)
            self.eval_counter += 1

    def _monitor(self):
        """Print the tree status once per second."""
        while self.alive.is_set():
            sleep(1)
            print(f'[Player {self.worker.id[7:14]}] {self.mcts}')

    def _loop(self):
        """Grow the tree, then publish the resulting record, repeatedly."""
        while self.alive.is_set():
            self.mcts.grow()
            self.state_counter = 0
            self.eval_counter = 0
            record = self.mcts.act()
            record_future = self.client.scatter(record)
            self.train_q.put(record_future)
            self.turn_counter += 1

    def loop(self):
        """Mark the player alive and submit the four worker threads."""
        self.alive.set()
        self._futures = {'monitor': self.thread_pool.submit(self._monitor),
                         'send': self.thread_pool.submit(self._send),
                         'recv': self.thread_pool.submit(self._recv),
                         'loop': self.thread_pool.submit(self._loop)}

    def watch(self):
        """Wait for the worker threads; on the first failure, stop and raise.

        :raises ChildProcessError: chained from the worker's exception.
        """
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as ex:
                self.kill()
                raise ChildProcessError from ex
        return

    def kill(self):
        """Clear ``alive`` and wait for every thread, ignoring their errors."""
        self.alive.clear()
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception:
                # Best-effort shutdown: the failure is reported by ``watch``.
                continue
class Model:
    """Placeholder model: constant outputs, no state."""

    def __init__(self):
        """Nothing to initialise."""

    def __call__(self, state):
        """Ignore ``state``, pause briefly and return a 0-d array holding 10."""
        sleep(0.0001)
        return np.array(10)

    def load(self, model):
        """Accept a serialized model and do nothing with it."""
        return None

    def serialize(self):
        """Return a 0-d array holding 100 as the serialized form."""
        serialized = np.array(100)
        return serialized
class Evaluator:
    """Actor that evaluates states from ``state_q`` and answers on ``eval_q``.

    After each evaluation it polls the ``'model_q'`` subscription and, when a
    new model state is available, loads it into the model.
    """

    @classmethod
    def start(cls, *args, **kwargs):
        """Build an evaluator, launch its threads and block in ``watch``."""
        evaluator = cls(*args, **kwargs)
        evaluator.loop()
        return evaluator.watch()

    def __init__(self, model, state_q, eval_q):
        """Store the model and queues and connect to dask.

        :param model: callable mapping a state to an evaluation; must also
            provide ``load``.
        :param state_q: queue delivering state futures.
        :param eval_q: queue receiving scattered evaluations.
        """
        self.model = model
        self.thread_pool = ThreadPoolExecutor(max_workers=3)
        self.alive = thr.Event()
        self.state_counter = 0
        self.eval_counter = 0
        self.update_counter = 0
        # Dask communication
        self.client = get_client()
        self.worker = get_worker()
        self.sub = Sub('model_q')
        self.state_q = state_q
        self.eval_q = eval_q

    def _monitor(self):
        """Print the state/eval/update counters once per second."""
        while self.alive.is_set():
            sleep(1)
            print(f'[Evaluator {self.worker.id[7:14]}] state/eval/update: '
                  f'{self.state_counter}/{self.eval_counter}/{self.update_counter}')

    def _loop(self):
        """Evaluate incoming states and pick up published model updates."""
        while self.alive.is_set():
            state_future = self.state_q.get()
            # Evaluate the gathered data, not the future handle.
            state = self.client.gather(state_future)
            self.state_counter += 1
            evaluation = self.model(state)
            future = self.client.scatter(evaluation)
            self.eval_q.put(future)
            self.eval_counter += 1
            try:
                # Non-blocking poll for a freshly published model.
                model_state = self.sub.get(timeout=0.)
            except tornado.util.TimeoutError:
                continue
            if model_state is not None:
                model_state = self.client.gather(model_state)
                self.model.load(model_state)
                self.update_counter += 1

    def loop(self):
        """Mark the evaluator alive and submit its worker threads."""
        self.alive.set()
        self._futures = {'loop': self.thread_pool.submit(self._loop),
                         'monitor': self.thread_pool.submit(self._monitor)
                         }

    def kill(self):
        """Clear ``alive`` and wait for every thread, ignoring their errors."""
        self.alive.clear()
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception:
                # Best-effort shutdown: the failure is reported by ``watch``.
                continue

    def watch(self):
        """Wait for the worker threads; on the first failure, stop and raise.

        :raises ChildProcessError: chained from the worker's exception.
        """
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as ex:
                self.kill()
                raise ChildProcessError from ex
        return
class ReplayBuffer:
    """Ring buffer of records fed by ``train_q`` and sampled onto ``sample_q``.

    ``_recv`` writes incoming records into fixed-size arrays, wrapping around
    at ``max_size``. ``_send`` starts emitting random batches once more than
    ``min_size`` records have been received.
    """

    @classmethod
    def start(cls, *args, **kwargs):
        """Build a buffer, launch its threads and block in ``watch``."""
        buffer = cls(*args, **kwargs)
        buffer.loop()
        return buffer.watch()

    def __init__(self, train_q, sample_q, max_size=10000, batch_size=32,
                 min_size=3):
        """Allocate storage and connect to dask.

        :param train_q: queue delivering record futures.
        :param sample_q: queue receiving scattered sample batches.
        :param max_size: number of slots in the ring buffer.
        :param batch_size: maximum number of records per emitted batch.
        :param min_size: sampling starts once the received count exceeds this.
        """
        self.max_size = max_size
        self.min_size = min_size
        self.batch_size = batch_size
        self.train_counter = 0
        self.sample_counter = 0
        self.states = np.zeros((max_size, 11, 11))
        self.targets = np.zeros((max_size, 4, 7))
        self.thread_pool = ThreadPoolExecutor(max_workers=4)
        self.alive = thr.Event()
        # Set once enough records are stored for ``_send`` to start.
        self.sampling = thr.Event()
        # Dask communication
        self.client = get_client()
        self.worker = get_worker()
        self.train_q = train_q
        self.sample_q = sample_q

    def _monitor(self):
        """Print the train/sample counters once per second."""
        while self.alive.is_set():
            sleep(1)
            print(f'[Buffer {self.worker.id[7:14]}]'
                  f' train/sample: {self.train_counter}/{self.sample_counter}')

    def _recv(self):
        """Store incoming records, then publish the new count."""
        while self.alive.is_set():
            record_future = self.train_q.get()
            record = self.client.gather(record_future)
            slot = self.train_counter % self.max_size
            # Write the slot before bumping the counter or enabling sampling,
            # so ``_send`` never draws a slot that has not been filled yet.
            self.states[slot] = record.state
            self.targets[slot] = record.target
            self.train_counter += 1
            if self.train_counter > self.min_size:
                self.sampling.set()

    def _send(self):
        """Emit random batches of stored records while sampling is enabled."""
        while self.alive.is_set():
            self.sampling.wait()
            if not self.alive.is_set():
                # ``kill`` sets ``sampling`` only to unblock this thread.
                break
            lim = min(self.train_counter, self.max_size)
            batch = np.random.permutation(lim)[:self.batch_size]
            sample = Record(self.states[batch], self.targets[batch])
            sample_future = self.client.scatter(sample)
            self.sample_q.put(sample_future)
            self.sample_counter += 1

    def loop(self):
        """Mark the buffer alive and submit its worker threads."""
        self.alive.set()
        self.sampling.clear()
        self._futures = {'recv': self.thread_pool.submit(self._recv),
                         'send': self.thread_pool.submit(self._send),
                         'monitor': self.thread_pool.submit(self._monitor)
                         }

    def watch(self):
        """Wait for the worker threads; on the first failure, stop and raise.

        :raises ChildProcessError: chained from the worker's exception.
        """
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as ex:
                self.kill()
                raise ChildProcessError from ex
        return

    def kill(self):
        """Clear ``alive``, release ``_send`` and wait for every thread."""
        self.alive.clear()
        self.sampling.set()
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception:
                # Best-effort shutdown: the failure is reported by ``watch``.
                continue
class Trainer:
    """Actor that consumes batches from ``sample_q`` and publishes the model.

    Every 100 iterations the serialized model is scattered and published on
    the ``'model_q'`` topic.
    """

    @classmethod
    def start(cls, *args, **kwargs):
        """Build a trainer, launch its threads and block in ``watch``."""
        trainer = cls(*args, **kwargs)
        trainer.loop()
        return trainer.watch()

    def __init__(self, model, sample_q):
        """Store the model and queue and connect to dask.

        :param model: object providing ``serialize``.
        :param sample_q: queue delivering sample-batch futures.
        """
        self.model = model
        self.sample_q = sample_q
        self.thread_pool = ThreadPoolExecutor(max_workers=3)
        self.alive = thr.Event()
        self.iter_counter = 0
        # Dask communication
        self.client = get_client()
        self.worker = get_worker()
        self.pub = Pub('model_q')

    def watch(self):
        """Wait for the worker threads; on the first failure, stop and raise.

        :raises ChildProcessError: chained from the worker's exception.
        """
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as ex:
                # Ask the remaining threads to exit so they do not keep
                # running after the failure has been reported.
                self.alive.clear()
                raise ChildProcessError from ex
        return

    def _monitor(self):
        """Print the iteration counter once per second."""
        while self.alive.is_set():
            sleep(1)
            print(f'[Trainer {self.worker.id[7:14]}]'
                  f' iter: {self.iter_counter}')

    def _loop(self):
        """Fetch batches, run the placeholder step and publish periodically."""
        self.iter_counter = 0
        while self.alive.is_set():
            sample_future = self.sample_q.get()
            sample = self.client.gather(sample_future)
            # Train (placeholder: the batch is fetched but not used yet)
            self.model = self.model
            sleep(0.001)
            self.iter_counter += 1
            if self.iter_counter % 100 == 0:
                model_state = self.model.serialize()
                model_future = self.client.scatter(model_state)
                self.pub.put(model_future)

    def loop(self):
        """Mark the trainer alive and submit its worker threads."""
        self.alive.set()
        self._futures = {'loop': self.thread_pool.submit(self._loop),
                         'monitor': self.thread_pool.submit(self._monitor)
                         }

    def kill(self):
        """Clear ``alive`` and wait for every thread.

        :raises ChildProcessError: if a thread ended with an exception.
        """
        self.alive.clear()
        for future in thr_as_completed(self._futures.values()):
            try:
                future.result()
            except Exception as e:
                raise ChildProcessError from e
def monitor(queues):
    """Print every queue's name and current size once per second, forever.

    :param queues: mapping of queue name to an object exposing ``qsize()``.
    """
    while True:
        sleep(1)
        names = '/'.join(queues.keys())
        sizes = '/'.join(str(q.qsize()) for q in queues.values())
        print(f"[Queues] {names}: {sizes}")
if __name__ == '__main__':
    # Actor counts. The evaluator count was previously hard-coded as
    # ``range(2)`` while these settings were ignored; they now drive the
    # submissions, with values matching what was actually launched.
    n_evaluators = 2
    # NOTE(review): all players would share the same eval_q, so raising this
    # above 1 needs per-player routing of evaluations first.
    n_players = 1

    client = Client(processes=False, n_workers=5)

    state_q = Queue(maxsize=1000)
    eval_q = Queue(maxsize=1000)
    train_q = Queue(maxsize=1000)
    sample_q = Queue(maxsize=1000)
    queues = {'state_q': state_q,
              'eval_q': eval_q,
              'train_q': train_q,
              'sample_q': sample_q,}

    model = Model()

    # pure=False: these are long-running, side-effecting tasks, so identical
    # submissions must never be collapsed into a single task.
    futures = [client.submit(Player.start, state_q, eval_q, train_q,
                             pure=False)
               for _ in range(n_players)]
    futures += [client.submit(Evaluator.start, model, state_q, eval_q,
                              pure=False)
                for _ in range(n_evaluators)]
    futures.append(client.submit(ReplayBuffer.start, train_q, sample_q,
                                 pure=False))
    futures.append(client.submit(Trainer.start, model, sample_q, pure=False))
    futures.append(client.submit(monitor, queues, pure=False))

    # Surface the first actor failure as a ChildProcessError.
    for future in as_completed(futures):
        try:
            future.result()
        except Exception as ex:
            raise ChildProcessError from ex
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment