Created March 3, 2019 21:40
Save ericl/e989a2267ffe4fcbecd9682271331977 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@ray.remote
class MyEnvActor:
    """Ray actor that hosts a single environment instance (template).

    ``MyBaseEnv`` unpacks the result of both ``reset`` and ``step`` as a
    4-tuple ``(obs, reward, done, info)``. Any implementation must keep
    that shape for both methods.
    """

    def reset(self):
        """Reset the environment and return ``(obs, 0, False, {})``.

        Only ``obs`` is meaningful. The remaining entries are dummy values
        so the result has the same 4-tuple shape as ``step``.
        """
        # TODO: reset the wrapped environment, then e.g.
        #     return obs, 0, False, {}
        # The original sketch returned an undefined ``obs`` (NameError);
        # fail explicitly until this is filled in.
        raise NotImplementedError("MyEnvActor.reset must be implemented")

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, info)``."""
        # TODO: step the wrapped environment with ``action``, then e.g.
        #     return obs, rew, done, info
        raise NotImplementedError("MyEnvActor.step must be implemented")
class MyBaseEnv(BaseEnv):
    """BaseEnv that steps a pool of ``MyEnvActor`` Ray actors asynchronously.

    Each actor owns one environment, and an env ID is the actor's index in
    ``self.actors``. ``self.pending`` maps the object ID of every in-flight
    ``reset``/``step`` call to the actor that will produce it. ``poll``
    collects the calls that have finished; ``send_actions`` launches new ones.

    NOTE(review): values from the actors are passed through unchanged. If
    RLlib expects per-agent dicts inside each env entry, the actors must
    return them. Confirm against the actor implementation.
    """

    # Upper bound, in seconds, on each individual ray.wait call in poll().
    _POLL_TIMEOUT_S = 0.01

    def __init__(self, num_envs=10):
        """Create ``num_envs`` env actors and start an initial reset on each.

        Args:
            num_envs: Number of remote environment actors (default 10).
        """
        self.actors = [MyEnvActor.remote() for _ in range(num_envs)]
        # reset must be invoked on each actor handle, not on the list.
        self.pending = {a.reset.remote(): a for a in self.actors}

    def poll(self):
        """Return results for every env whose last reset/step has finished.

        Blocks until at least one in-flight call is ready. Each wait round
        lasts at most ``_POLL_TIMEOUT_S``, so envs that finish close
        together are returned in the same batch.

        Returns:
            ``(obs, rewards, dones, infos, off_policy_actions)``. The first
            four are dicts keyed by env ID; the last is always ``{}``. All
            are empty when no call is in flight.
        """
        obs, rewards, dones, infos = {}, {}, {}, {}
        if not self.pending:
            # Nothing in flight: waiting for a ready call would never end.
            return obs, rewards, dones, infos, {}

        ready = []
        while not ready:
            # Asking for all pending IDs makes ray.wait gather every call
            # that completes within the timeout, rather than returning as
            # soon as the first one is ready (the default num_returns=1).
            ready, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self._POLL_TIMEOUT_S,
            )

        # Fetch all ready results with a single ray.get (order matches ready).
        for obj_id, (ob, rew, done, info) in zip(ready, ray.get(ready)):
            actor = self.pending.pop(obj_id)
            env_id = self.actors.index(actor)
            obs[env_id] = ob
            rewards[env_id] = rew
            dones[env_id] = done
            infos[env_id] = info
        return obs, rewards, dones, infos, {}

    def send_actions(self, action_dict):
        """Start an asynchronous ``step`` on each env that received actions.

        Args:
            action_dict: Maps env ID (index into ``self.actors``) to the
                action(s) forwarded unchanged to that actor's ``step``.
        """
        for env_id, actions in action_dict.items():
            actor = self.actors[env_id]
            self.pending[actor.step.remote(actions)] = actor

    def try_reset(self, env_id):
        """Synchronously reset env ``env_id`` and return its observation.

        This blocks on the actor's ``reset``. It does not add anything to
        ``self.pending``, so that env reappears in ``poll`` only after the
        next ``send_actions``.
        """
        obs, _, _, _ = ray.get(self.actors[env_id].reset.remote())
        return obs
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.