@jwarley
Created March 28, 2022 07:00
Querying RLlib Action Probabilities
#####################################################################
I run this script to train the agent...
#####################################################################
import ray
import gym
import numpy as np
from gym.spaces import Discrete, Box
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.env.env_context import EnvContext


class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.
    You can configure the length of the corridor via the env config."""

    def __init__(self, config: EnvContext):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, self.end_pos, shape=(1,), dtype=np.float32)

    def reset(self):
        self.cur_pos = 0
        return [self.cur_pos]

    def step(self, action):
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        # Produce a reward of 1.0 when we reach the goal.
        return [self.cur_pos], 1.0 if done else 0.0, done, {}


def env_creator(config):
    return SimpleCorridor(config={"corridor_length": 6})


register_env("simple_corridor", env_creator)

ray.init()

tune.run(
    DQNTrainer,
    config={
        "framework": "torch",
        "env": "simple_corridor",
        "gamma": 0.90,
        "num_gpus": 0,
        "num_workers": 3,
    },
    checkpoint_freq=2,
    checkpoint_at_end=True,
)
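
# Note (hedged): with Tune's default settings, the trial produced by the
# tune.run() call above lands under ~/ray_results/<experiment>/<trial>/ and
# contains params.pkl plus checkpoint_0000NN/checkpoint-NN subdirectories;
# that trial directory is what the notebook below points at. Assuming the
# Ray 1.x ExperimentAnalysis API, the latest checkpoint can also be read off
# programmatically from tune.run()'s return value, e.g.:
#
#     analysis = tune.run(DQNTrainer, config=..., checkpoint_freq=2,
#                         checkpoint_at_end=True)
#     print(analysis.get_last_checkpoint())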
#####################################################################
Then I run the following in a jupyter notebook...
#####################################################################
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
import gym
import pickle as pkl
import numpy as np
from pathlib import Path


def compute_tabular_policy(agent, actions, observations, log=False):
    """Build an |observations| x |actions| matrix of action (log-)likelihoods."""

    def process_obs(obs):
        # Run the raw observation through the same preprocessor and
        # observation filter the agent uses internally.
        obs = agent.workers.local_worker().preprocessors["default_policy"].transform(obs)
        return agent.workers.local_worker().filters["default_policy"](obs, update=False)

    LL_matrix = np.array([
        agent.get_policy().compute_log_likelihoods(actions, [process_obs(o)]).numpy()
        for o in observations
    ])
    return LL_matrix if log else np.exp(LL_matrix)

# NOTE: this notebook session assumes the SimpleCorridor class, env_creator,
# and register_env("simple_corridor", env_creator) from the training script
# have been re-run here (and ray.init() called); otherwise the
# "simple_corridor" env id below is unknown.
agent_dir = Path('path_to_the_agent')
checkpoint_dir = agent_dir / 'checkpoint_000026/checkpoint-26'

# Restore the trained agent.
with open(agent_dir / 'params.pkl', 'rb') as param_file:
    cfg = pkl.load(param_file)
agent = DQNTrainer(config=cfg, env='simple_corridor')
agent.restore(str(checkpoint_dir))

compute_tabular_policy(agent, list(range(2)), list(range(6)))
# Outputs the following:
array([[0.48007554, 0.5199246 ],
       [0.48117623, 0.51882374],
       [0.4819729 , 0.5180271 ],
       [0.48118111, 0.5188189 ],
       [0.4809635 , 0.5190365 ],
       [0.48091793, 0.51908207]], dtype=float32)

# Which, as I understand it, means the agent assigns roughly equal
# probability to each action (left or right) in every state.
# Is that actually what the agent is doing? Let's see...
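
# Optional extra check (hedged sketch): ask the restored trainer for its
# greedy, exploration-disabled action in each state and compare that with the
# near-uniform probabilities above. explore=False is assumed to be supported
# by this RLlib version's compute_single_action (it is part of the 1.x
# Trainer API).
greedy_actions = [agent.compute_single_action([float(s)], explore=False) for s in range(6)]
greedy_actions  # compare with the ~50/50 probabilities above and the rollouts below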

env = SimpleCorridor(config={'corridor_length': 6})

rollouts = []
rewards = []
lengths = []
for i in range(50):
    rollout = []
    length = 0
    episode_reward = 0
    done = False
    obs = env.reset()
    while not done:
        action = agent.compute_single_action(obs)
        rollout.append(action)
        obs, reward, done, info = env.step(action)
        episode_reward += reward
        length += 1
    rollouts.append(rollout)
    rewards.append(episode_reward)
    lengths.append(length)

# After running the above, nearly every rollout is [1, 1, 1, 1, 1, 1],
# i.e. the agent is going right basically all the time, as is expected
# for a close-to-optimal policy.
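
# A further hedged sketch: to probe the mismatch between the near-uniform
# likelihoods and the all-right rollouts, one can look at the raw Q-values the
# network assigns to each state. This assumes the DQN policy exposes its
# Q-values in the extra action fetches under a "q_values" key when
# full_fetch=True is passed; treat both the flag and the key name as
# assumptions about this RLlib version, not guaranteed API.
for s in range(6):
    action, _, extra = agent.compute_single_action([float(s)], full_fetch=True, explore=False)
    print(s, action, extra.get("q_values"))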