Skip to content

Instantly share code, notes, and snippets.

@emrul
Last active August 25, 2023 13:45
Show Gist options
  • Save emrul/38fcade5eb4f5e6f639fe8273f5da3f7 to your computer and use it in GitHub Desktop.
Save emrul/38fcade5eb4f5e6f639fe8273f5da3f7 to your computer and use it in GitHub Desktop.
Tianshou DQN with temporally-extended epsilon-greedy exploration
from typing import Union
from tianshou.data import Batch
from tianshou.policy import RainbowPolicy
import numpy as np
import argparse
# See https://arxiv.org/abs/2006.01782 for paper - Temporally-Extended ε-Greedy Exploration
# See https://www.youtube.com/watch?v=Gi_B0IqscBE for video explaining where it's helpful
# See https://github.com/thu-ml/tianshou/blob/master/examples/atari/atari_rainbow.py for a full example of how to set up args/net/collectors/etc.
def get_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour; passing a list lets callers and tests drive
            the parser without touching the real command line.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # ... add the arguments run_rainbow reads (gamma, num_atoms, v_min, v_max,
    # n_step, target_update_freq, device, ...) -- see atari_rainbow.py.
    return parser.parse_args(argv)
def run_rainbow(args=None):
    """Set up (and, in the full script, train) a RainbowPolicy that explores with ez-greedy.

    Args:
        args: Parsed options (as returned by ``get_args``). ``None`` means
            "parse the command line now". The parsing is deliberately done at
            call time: a ``get_args()`` default would run when the module is
            imported and consume the importing program's ``sys.argv``.
    """
    if args is None:
        args = get_args()
    # ...
    # Class-level monkey-patch: every RainbowPolicy instance in this process
    # (not only the one built below) now uses ez_greedy_exploration_noise
    # instead of its default exploration_noise.
    RainbowPolicy.exploration_noise = ez_greedy_exploration_noise
    # net = Rainbow(...
    # NOTE(review): `net` and `optim` are placeholders that must be created in
    # the elided setup above (see atari_rainbow.py); they are not defined here.
    # define policy
    policy = RainbowPolicy(
        net,
        optim,
        args.gamma,
        args.num_atoms,
        args.v_min,
        args.v_max,
        args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # ...
def generate_action_sequence(num_sequences, num_actions, min_repeat=2, max_repeat=20):
    """
    Sample random (num_repeats, action) pairs, one per sequence, for ez-greedy options.

    Both columns are drawn uniformly with the global ``np.random`` state.
    NOTE(review): the ez-greedy paper draws durations from a heavy-tailed
    (zeta) distribution; a uniform range is a simplification here.

    Parameters:
    - num_sequences: Number of sequences (rows) to generate
    - num_actions: Size of the discrete action space; actions are drawn from
      0 .. num_actions - 1 (num_actions itself is never returned)
    - min_repeat: Smallest repeat count (inclusive)
    - max_repeat: Largest repeat count (inclusive)
    Returns:
    - Integer numpy array of shape (num_sequences, 2): column 0 is the repeat
      count in [min_repeat, max_repeat], column 1 the action in [0, num_actions)
    """
    # randint's upper bound is exclusive, hence the +1 for the repeat count.
    num_repeats = np.random.randint(min_repeat, max_repeat + 1, size=num_sequences)
    actions = np.random.randint(0, num_actions, size=num_sequences)
    return np.column_stack([num_repeats, actions])
def ez_greedy_exploration_noise(
    self,
    act: Union[np.ndarray, Batch],
    batch: Batch,
) -> Union[np.ndarray, Batch]:
    """Temporally-extended epsilon-greedy replacement for a policy's ``exploration_noise``.

    Intended to be assigned onto a policy class, so ``self`` is the policy.
    Every environment in the batch has its own "option": a remaining repeat
    count and an action. An env whose option has run out independently starts
    a new one with probability ``self.eps``. While an option is active, its
    action overrides the policy's action for that env, and the count is
    decremented once per call.

    Option state is kept on ``self.ez_greedies`` as an int array of shape
    (bsz, 2): column 0 = remaining repeats, column 1 = action.
    NOTE(review): that state is positional (row i <-> batch entry i). If the
    number of envs changes, it is reset. If the *order* of ready envs changes
    (e.g. async collection), the alignment is silently lost.

    Args:
        self: The policy. Reads ``eps`` and ``max_action_num`` and
            reads/writes ``ez_greedies``.
        act: Actions chosen by the policy. Modified in place when it is an ndarray.
        batch: The current batch. If ``batch.obs`` has a ``mask`` attribute,
            it is used as a (bsz, num_actions) array of legal actions.

    Returns:
        ``act``, with entries of envs that have an active option replaced.

    Raises:
        ValueError: If a masked env has no legal action at all.
    """
    # Non-ndarray actions and eps == 0 (e.g. evaluation) get no noise; any
    # option state is left untouched for the next exploring call.
    if not isinstance(act, np.ndarray) or np.isclose(self.eps, 0.0):
        return act
    bsz = len(act)
    ez_greedies = getattr(self, "ez_greedies", None)
    if ez_greedies is None or len(ez_greedies) != bsz:
        # First call, or a collector with a different number of envs is using
        # this policy: start every env without an active option instead of
        # indexing a state array of the wrong length.
        ez_greedies = np.zeros((bsz, 2), dtype=int)
    # Each idle env flips its own eps-coin to decide whether to start an option.
    idle = ez_greedies[:, 0] <= 0
    starting = idle & (np.random.random(bsz) < self.eps)
    if starting.any():
        ez_greedies[starting] = generate_action_sequence(
            int(starting.sum()), self.max_action_num
        )
    active = np.flatnonzero(ez_greedies[:, 0] > 0)
    if active.size > 0:
        mask = getattr(batch.obs, "mask", None)
        if mask is not None:
            # Cast to bool so `~` is a logical (not bitwise) negation even
            # for 0/1 integer masks.
            legal_mask = np.asarray(mask, dtype=bool)
            is_legal = legal_mask[active, ez_greedies[active, 1]]
            for idx in active[~is_legal]:
                available_actions = np.flatnonzero(legal_mask[idx])
                if available_actions.size == 0:
                    raise ValueError(f"No available actions for env {idx}")
                # Swap in a random legal action; the rest of this option's
                # repeats reuse it.
                ez_greedies[idx, 1] = np.random.choice(available_actions)
        act[active] = ez_greedies[active, 1]
        ez_greedies[active, 0] -= 1
    self.ez_greedies = ez_greedies
    return act
if __name__ == "__main__":
    # Script entry point: read the command line once, then start training.
    cli_args = get_args()
    run_rainbow(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment