Created
July 30, 2018 20:47
-
-
Save AurelianTactics/37420f03d587e8f1a9fc5235d8f6253a to your computer and use it in GitHub Desktop.
multi env implementation issues with ray
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ray | |
import gym | |
from ray.rllib.agents import ppo, dqn | |
from ray.tune.registry import register_env | |
# Name of the custom environment; passed to the agent constructor as `env` below.
env_name = "multienv" | |
class MultiEnv(gym.Env):
    """Gym environment that delegates to an inner env picked per worker.

    The inner environment is selected by ``choose_env_for`` from the
    ``env_config`` object handed to the constructor; ``reset`` and ``step``
    are forwarded to it unchanged.
    """

    def __init__(self, env_config):
        """Build the inner environment and mirror its spaces.

        Args:
            env_config: Environment config supplied by the env creator.
                ``choose_env_for`` reads ``worker_index`` and
                ``vector_index`` from it as attributes.
        """
        # choose_env_for takes the whole config and already returns a
        # constructed environment, so it must not be wrapped in gym.make().
        self.env = choose_env_for(env_config)
        # gym.Env leaves both spaces as None by default. Expose the inner
        # env's spaces so callers that inspect them (e.g. a preprocessor
        # reading `space.shape`) do not receive None.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

    def reset(self):
        """Reset the inner environment and return its initial observation."""
        return self.env.reset()

    def step(self, action):
        """Forward ``action`` to the inner environment and return its result."""
        return self.env.step(action)
def choose_env_for(env_config):
    """Create the CartPole variant that matches the worker's parity.

    Prints the config and its ``worker_index`` / ``vector_index`` for
    debugging, then returns ``CartPole-v0`` for even worker indexes and
    ``CartPole-v1`` for odd ones.

    Args:
        env_config: Config object exposing ``worker_index`` and
            ``vector_index`` as attributes.

    Returns:
        The environment built by ``gym.make``.
    """
    worker_idx = env_config.worker_index
    print(env_config)
    print("worker index is {}".format(worker_idx))
    print("testing vector_index {}".format(env_config.vector_index))

    # Guard clause for odd workers; even workers fall through.
    if worker_idx % 2 != 0:
        print('in odd env')
        return gym.make('CartPole-v1')

    print('in even env')
    return gym.make('CartPole-v0')
# Register an env *creator*: a function that receives the env config and
# returns an environment INSTANCE.
#
# Variants that fail:
#   register_env("multienv", lambda _: MultiEnv)
#       Returns the class itself instead of an instance, and fails with:
#       File "/home/jim/projects/retroui/ray/python/ray/rllib/models/preprocessors.py", line 161, in legacy_patch_shapes return space.shape
#       AttributeError: 'NoneType' object has no attribute 'shape'
#   register_env("multienv", MultiEnv)
#       TypeError("Second argument must be a function.", env_creator)
#
# Variants that work:
#   Making an env the normal way:
#       register_env("multienv", lambda config: gym.make('CartPole-v0'))
#       register_env("multienv", lambda config: gym.make('FrozenLake-v0'))
#   Making a multi env the naive way:
#       register_env("multienv", lambda env_config: choose_env_for(env_config))
register_env(env_name, lambda env_config: MultiEnv(env_config))

ray.init()

config = dqn.DEFAULT_CONFIG.copy()
config.update({
    # NOTE(review): the original author was unsure how to assign
    # vector_index. choose_env_for reads `worker_index` / `vector_index` as
    # attributes of the config object, whereas these are plain dict entries.
    # Presumably RLlib sets those attributes itself; confirm whether these
    # keys are needed at all.
    'env_config': {"vector_index": 1, "vector": 2},
    'num_workers': 2,
})

alg = dqn.DQNAgent(config=config, env=env_name)

for i in range(1000):
    result = alg.train()
    print('result = {}'.format(result))
    # Checkpoint every 10th iteration, including the first.
    if i % 10 == 0:
        checkpoint = alg.save()
        print('checkpoint saved at', checkpoint)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment