Skip to content

Instantly share code, notes, and snippets.

@heiner
Created September 21, 2021 11:20
Show Gist options
  • Save heiner/67955538cce6375dbdd6fa97ffd11ce9 to your computer and use it in GitHub Desktop.
Save heiner/67955538cce6375dbdd6fa97ffd11ce9 to your computer and use it in GitHub Desktop.
import queue
import threading
import gym
def target(resetqueue, readyqueue):
    """Worker loop: reset every env received on ``resetqueue``.

    Blocks on ``resetqueue.get()``. Each env taken from the queue has its
    ``reset()`` called, and the pair ``(observation, env)`` is put on
    ``readyqueue``. Receiving ``None`` ends the loop.

    Args:
        resetqueue: Queue yielding envs to reset, or ``None`` to stop.
        readyqueue: Queue that receives ``(obs, env)`` tuples.
    """
    pending = resetqueue.get()
    while pending is not None:
        readyqueue.put((pending.reset(), pending))
        pending = resetqueue.get()
class CachedEnvWrapper(gym.Env):
    """Gym env that hides ``reset()`` latency behind a pool of envs.

    Exactly one env from ``envs`` is active at a time and receives every
    ``step()`` call. The remaining envs are reset by background worker
    threads. ``reset()`` hands the active env to the workers and switches to
    whichever env finishes resetting first.
    """

    def __init__(self, envs, num_threads=2):
        """Start the worker threads and queue all but the first env for reset.

        Args:
            envs: Non-empty sequence of environments. ``envs[0]`` becomes the
                initially active env; it is not reset here.
            num_threads: Number of background threads running ``env.reset()``.

        Raises:
            IndexError: If ``envs`` is empty.
        """
        self._envs = envs
        # Resolve the active env and the spares *before* spawning threads.
        # With empty (or non-indexable) ``envs`` the exception is then raised
        # without leaving non-daemon workers blocked on the queue forever,
        # which would keep the interpreter from exiting.
        self._env = envs[0]
        spares = envs[1:]
        self._closed = False

        # This could alternatively also use concurrent.futures. I hesitate to
        # do that as futures.wait would have me deal with sets all the time
        # where they are really not necessary.
        self._resetqueue = queue.SimpleQueue()
        self._readyqueue = queue.SimpleQueue()
        self._threads = [
            threading.Thread(
                target=target, args=(self._resetqueue, self._readyqueue)
            )
            for _ in range(num_threads)
        ]
        for t in self._threads:
            t.start()
        for env in spares:
            self._resetqueue.put(env)

    def reset(self):
        """Swap in a freshly reset env and return its first observation.

        The currently active env is queued for a background reset. The call
        then blocks until some env is available on the ready queue.
        """
        self._resetqueue.put(self._env)
        obs, self._env = self._readyqueue.get()
        return obs

    def step(self, action):
        """Forward ``action`` to the active env and return its result."""
        return self._env.step(action)

    def close(self):
        """Stop the workers, then close every env. Safe to call repeatedly."""
        if self._closed:
            # E.g. ``with`` block exit followed by an explicit close().
            return
        self._closed = True
        # One ``None`` sentinel per worker; each worker exits on its first.
        for _ in self._threads:
            self._resetqueue.put(None)
        # Join before closing so no env is closed while a worker resets it.
        for t in self._threads:
            t.join()
        for env in self._envs:
            env.close()

    def seed(self, seed=None):
        """Seed the currently active env only."""
        self._env.seed(seed)  # Unclear if this should happen in all envs?

    def unwrapped(self):
        """Return the currently active env."""
        # NOTE(review): gym.Env declares ``unwrapped`` as a property, while
        # this override is a plain method. Kept as-is for existing callers;
        # confirm which form callers expect.
        return self._env

    def __str__(self):
        return "<CachedEnvWrapper envs=%s>" % [str(env) for env in self._envs]

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
        return False  # Propagate exception.
from concurrent import futures
import gym
def target(env):
    """Reset ``env`` and return ``(observation, env)``.

    Returning the env alongside the observation lets the caller tell which
    env a completed task belongs to.
    """
    return env.reset(), env
class CachedEnvWrapper2(gym.Env):
    """Cached-reset env wrapper built on ``concurrent.futures``.

    One env from ``envs`` is active and receives every ``step()`` call. The
    others are reset in the background by tasks submitted to ``threadpool``.
    ``reset()`` submits the active env for reset and switches to whichever
    env finishes resetting first.
    """

    def __init__(self, envs, threadpool, num_workers=2):
        """Submit background resets for every env but the first.

        Args:
            envs: Non-empty sequence of environments. ``envs[0]`` becomes the
                initially active env; it is not reset here.
            threadpool: Executor whose ``submit(fn, *args)`` returns a
                ``concurrent.futures.Future``.
            num_workers: Stored on the instance; not otherwise used by this
                class.

        Raises:
            IndexError: If ``envs`` is empty.
        """
        self._envs = envs
        self._threadpool = threadpool
        # Previously hard-coded to 2, silently ignoring the argument.
        self._num_workers = num_workers
        self._env = envs[0]
        # Pending and completed-but-unclaimed reset tasks.
        self._futures = {threadpool.submit(target, env) for env in envs[1:]}

    def step(self, action):
        """Forward ``action`` to the active env and return its result."""
        return self._env.step(action)

    def reset(self):
        """Swap in a freshly reset env and return its first observation.

        The active env is submitted for a background reset. The call then
        blocks until at least one outstanding reset has completed.

        Raises:
            Exception: Whatever ``env.reset()`` raised in the worker. The
                failed task is discarded, so later calls do not re-raise it.
        """
        self._futures.add(self._threadpool.submit(target, self._env))
        done, _ = futures.wait(
            self._futures, return_when=futures.FIRST_COMPLETED
        )
        future = next(iter(done))
        # Drop the future before reading its result. If result() raises, a
        # future left in the set would stay "done" and could be re-raised by
        # every subsequent reset().
        self._futures.remove(future)
        obs, self._env = future.result()
        return obs

    def close(self):
        """Wait for in-flight resets to finish, then close every env."""
        # Closing an env while a worker is still inside its reset() is a race.
        futures.wait(self._futures)
        for env in self._envs:
            env.close()

    def seed(self, seed=None):
        """Seed the currently active env only."""
        self._env.seed(seed)  # Unclear if this should happen in all envs?

    def unwrapped(self):
        """Return the currently active env."""
        # NOTE(review): gym.Env declares ``unwrapped`` as a property, while
        # this override is a plain method. Kept as-is for existing callers;
        # confirm which form callers expect.
        return self._env

    def __str__(self):
        return "<CachedEnvWrapper envs=%s>" % [str(env) for env in self._envs]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment