Skip to content

Instantly share code, notes, and snippets.

@AurelianTactics
Created July 30, 2018 20:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AurelianTactics/37420f03d587e8f1a9fc5235d8f6253a to your computer and use it in GitHub Desktop.
Save AurelianTactics/37420f03d587e8f1a9fc5235d8f6253a to your computer and use it in GitHub Desktop.
Multi-env implementation issues with Ray
import ray
import gym
from ray.rllib.agents import ppo, dqn
from ray.tune.registry import register_env
# Name the environment is registered under; the agent below looks it up by this string.
env_name = "multienv"
class MultiEnv(gym.Env):
    """Gym environment that delegates to a per-worker wrapped environment.

    The wrapped environment is chosen by ``choose_env_for`` from the env
    context, so different rollout workers can run different environments
    under a single registered name.
    """

    def __init__(self, env_config):
        """Build the wrapped environment for this worker.

        Args:
            env_config: Env context handed over by RLlib. It is forwarded
                unchanged to ``choose_env_for``, which reads its
                ``worker_index`` and ``vector_index`` attributes.
        """
        # choose_env_for takes the whole context and already returns a
        # constructed environment, so it must not be wrapped in gym.make()
        # nor called with the two indices as separate arguments.
        self.env = choose_env_for(env_config)
        # RLlib reads the spaces from the outer env to build its
        # preprocessors; without them it sees no space at all and fails with
        # "'NoneType' object has no attribute 'shape'".
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

    def reset(self):
        """Reset the wrapped environment and return its initial observation."""
        return self.env.reset()

    def step(self, action):
        """Apply ``action`` to the wrapped environment and return its result."""
        return self.env.step(action)
def choose_env_for(env_config):
    """Create the CartPole variant assigned to the calling worker.

    Workers whose ``worker_index`` is even get ``CartPole-v0``; all others
    get ``CartPole-v1``. The context and both of its indices are printed
    first as a debugging aid.

    Args:
        env_config: Env context exposing ``worker_index`` and
            ``vector_index`` attributes.

    Returns:
        The environment created by ``gym.make``.
    """
    print(env_config)
    print("worker index is {}".format(env_config.worker_index))
    print("testing vector_index {}".format(env_config.vector_index))

    is_even_worker = env_config.worker_index % 2 == 0
    if is_even_worker:
        print('in even env')
        return gym.make('CartPole-v0')

    print('in odd env')
    return gym.make('CartPole-v1')
# The env creator must return an environment *instance* built from the env
# context it is called with. Returning the MultiEnv class itself
# (``lambda _: MultiEnv``) fails, because the class carries no spaces:
#   File "/home/jim/projects/retroui/ray/python/ray/rllib/models/preprocessors.py",
#       line 161, in legacy_patch_shapes return space.shape
#   AttributeError: 'NoneType' object has no attribute 'shape'
register_env(env_name, lambda env_config: MultiEnv(env_config))

# Passing the class directly fails as well:
#   TypeError("Second argument must be a function.", env_creator)
# register_env("multienv", MultiEnv)

# Making an env the normal way works:
# register_env("multienv", lambda config: gym.make('CartPole-v0'))
# register_env("multienv", lambda config: gym.make('FrozenLake-v0'))

# Making a multi env the naive way (no wrapper class) works:
# register_env("multienv", lambda env_config: choose_env_for(env_config))

ray.init()

config = dqn.DEFAULT_CONFIG.copy()
config.update({
    # NOTE(review): per the RLlib docs, worker_index / vector_index are
    # attributes RLlib sets on the env context; these dict entries are plain
    # user settings and presumably do not assign them -- confirm.
    'env_config': {"vector_index": 1, "vector": 2},
    'num_workers': 2,
})

alg = dqn.DQNAgent(config=config, env=env_name)

for i in range(1000):
    result = alg.train()
    print('result = {}'.format(result))
    # Checkpoint every tenth iteration (including the first).
    if i % 10 == 0:
        checkpoint = alg.save()
        print('checkpoint saved at', checkpoint)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment