Querying RLlib Action Probabilities
#####################################################################
# I run this script to train the agent...
#####################################################################
import ray
import gym
import numpy as np
from gym.spaces import Discrete, Box
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.env.env_context import EnvContext
class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.
    You can configure the length of the corridor via the env config."""

    def __init__(self, config: EnvContext):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, self.end_pos, shape=(1,), dtype=np.float32)

    def reset(self):
        self.cur_pos = 0
        return [self.cur_pos]

    def step(self, action):
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        # Produce a reward of 1.0 when we reach the goal.
        return [self.cur_pos], 1.0 if done else 0.0, done, {}
def env_creator(config):
    return SimpleCorridor(config={'corridor_length': 6})

register_env("simple_corridor", env_creator)

ray.init()

tune.run(
    DQNTrainer,
    config={
        "framework": "torch",
        "env": "simple_corridor",
        "gamma": 0.90,
        "num_gpus": 0,
        "num_workers": 3,
    },
    checkpoint_freq=2,
    checkpoint_at_end=True,
)
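
# (Side note on where the checkpoint used below comes from: by default,
# Tune writes trial results under ~/ray_results/<trainable>/<trial>/,
# and checkpoint_freq=2 plus checkpoint_at_end produce directories like
# checkpoint_000026/ inside each trial. A small sketch for listing them;
# the "DQN" results subdirectory is an assumption from default trial naming.)
from pathlib import Path
for ckpt in sorted((Path.home() / "ray_results" / "DQN").glob("*/checkpoint_*")):
    print(ckpt)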
#####################################################################
# Then I run the following in a Jupyter notebook...
#####################################################################
from ray.rllib.agents.dqn import DQNTrainer
import pickle as pkl
import numpy as np
from pathlib import Path

# The env class and its registration are needed in this process too.
# (Assumes the training script above was saved as a local module named
# train_corridor.py; that module name is hypothetical.)
from ray.tune.registry import register_env
from train_corridor import SimpleCorridor, env_creator

register_env("simple_corridor", env_creator)
def compute_tabular_policy(agent, actions, observations, log=False):
    """Query the policy's (log-)likelihood of each action in each observation."""

    def process_obs(obs):
        # Run the observation through the same preprocessor and filter the
        # rollout workers use, so the policy sees it in the format it was
        # trained on.
        obs = agent.workers.local_worker().preprocessors['default_policy'].transform(obs)
        return agent.workers.local_worker().filters['default_policy'](obs, update=False)

    LL_matrix = np.array([
        agent.get_policy().compute_log_likelihoods(actions, [process_obs(o)]).numpy()
        for o in observations
    ])
    return LL_matrix if log else np.exp(LL_matrix)
agent_dir = Path('path_to_the_agent')
checkpoint_dir = agent_dir / 'checkpoint_000026/checkpoint-26'

# Restore the trained agent
with open(agent_dir / 'params.pkl', 'rb') as param_file:
    cfg = pkl.load(param_file)

agent = DQNTrainer(config=cfg, env='simple_corridor')
agent.restore(str(checkpoint_dir))

compute_tabular_policy(agent, list(range(2)), list(range(6)))
# Outputs the following:
# array([[0.48007554, 0.5199246 ],
#        [0.48117623, 0.51882374],
#        [0.4819729 , 0.5180271 ],
#        [0.48118111, 0.5188189 ],
#        [0.4809635 , 0.5190365 ],
#        [0.48091793, 0.51908207]], dtype=float32)
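
# (Quick sanity check, an addition to the original: with log=False each
# row should be a probability distribution over the two actions, so the
# rows should sum to ~1.)
probs = compute_tabular_policy(agent, list(range(2)), list(range(6)))
assert np.allclose(probs.sum(axis=1), 1.0, atol=1e-4)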
# The matrix above, as I understand it, means the agent assigns roughly
# equal probability to each action (left or right) in every state.
# Is that actually what the agent is doing? Let's see...
env = SimpleCorridor(config={'corridor_length': 6})

rollouts = []
rewards = []
lengths = []

for i in range(50):
    rollout = []
    length = 0
    episode_reward = 0
    done = False
    obs = env.reset()
    while not done:
        action = agent.compute_single_action(obs)
        rollout.append(action)
        obs, reward, done, info = env.step(action)
        episode_reward += reward
        length += 1
    rollouts.append(rollout)
    rewards.append(episode_reward)
    lengths.append(length)
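
# (Small addition for eyeballing the results the comments below refer to:)
print("mean episode length:", np.mean(lengths))
print("first few rollouts:", rollouts[:5])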
# After running the above, nearly every rollout is [1, 1, 1, 1, 1, 1],
# i.e. the agent is going right basically all the time, as is expected
# for a close-to-optimal policy.
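
#####################################################################
# A plausible reconciliation of the two observations (hedged): for DQN,
# compute_log_likelihoods scores actions under the policy's action
# distribution, a softmax over the Q-values, while compute_single_action
# returns the argmax. Here Q(left) and Q(right) can be numerically close
# (both actions eventually reach the goal, just with more discounting,
# e.g. roughly 0.9**6 vs. 0.9**5 at the start), so the softmax is
# near-uniform even though the argmax always says "right". A sketch for
# probing the raw Q-values; the "q_values" key in the extra fetches is
# an assumption based on RLlib 1.x's DQN policies.
#####################################################################
for pos in range(6):
    action, _, extras = agent.compute_single_action(
        [float(pos)], full_fetch=True, explore=False
    )
    print(pos, action, extras.get("q_values"))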