Skip to content

Instantly share code, notes, and snippets.

@sirius5871
Created July 21, 2022 08:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sirius5871/a50ae10cd0596bca91f3288a233dc208 to your computer and use it in GitHub Desktop.
import logging
import gym
import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.tune import register_env
logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s (%(asctime)s; %(filename)s:%(lineno)d)',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO)
LOGGER = logging.getLogger(__name__)
class SimpleCorridor(gym.Env):
    """A 1-D corridor the agent must traverse left-to-right to exit.

    The observation is the agent's current position (a single float in a
    Box); the two discrete actions are 0 (step left) and 1 (step right).
    The episode ends when the position reaches ``corridor_length``.
    """

    def __init__(self, config):
        """Build the corridor.

        Args:
            config: dict with key ``corridor_length`` — the goal position.
        """
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        # Two discrete actions: 0 = left, 1 = right.
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(0.0, self.end_pos, shape=(1,))
        LOGGER.info('The environment initialization is doneee!!!!!!!!!')

    def reset(self):
        """Reset the episode and return the initial observation."""
        self.cur_pos = 0
        LOGGER.info('The reset function is executed')
        return [self.cur_pos]

    def step(self, action):
        """Advance one timestep given `action`.

        Returns:
            Tuple of (new observation, reward, done flag, empty info dict).
        """
        if action == 1:
            # Step right, toward the goal.
            self.cur_pos += 1
        elif action == 0 and self.cur_pos > 0:
            # Step left, but never past the left wall at position 0.
            self.cur_pos -= 1
        # Episode terminates once the goal position is reached.
        done = self.cur_pos >= self.end_pos
        # +1.0 on reaching the goal; -0.1 living penalty on every other step.
        reward = 1.0 if done else -0.1
        LOGGER.info('The step function is executed')
        return [self.cur_pos], reward, done, {}
def env_creator(env_config=None):
    """Factory registered with RLlib; builds a NEW SimpleCorridor per call.

    The original returned one module-level instance, so every rollout
    worker shared the same env object and its mutable `cur_pos` state —
    broken whenever num_workers > 1. Each call now constructs a fresh env.

    Args:
        env_config: optional dict; ``corridor_length`` defaults to 20,
            matching the original hard-coded environment.

    Returns:
        A newly constructed SimpleCorridor.
    """
    config = dict(env_config or {})
    config.setdefault("corridor_length", 20)
    return SimpleCorridor(config)


# Kept for backward compatibility with anything referencing `environment`;
# it is no longer the instance handed to RLlib workers.
environment = env_creator()

register_env("my_env", env_creator)
# NOTE(review): with num_workers > 0, the env runs inside separate Ray worker
# processes, so LOGGER output from the env appears on the driver only in
# single-worker mode — presumably what the author's note below refers to.
# logs only printed in the single worker mode, doesn't work for multiple worker mode
ray.init()
# Create an RLlib Trainer instance.
# 4 parallel rollout workers; `horizon: 1` force-truncates every episode
# after a single environment step.
trainer = A2CTrainer(env="my_env", config={"num_workers": 4, "horizon": 1})
# Run one training iteration and keep its result/metrics dict.
results = trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment