@mrdmnd, created January 23, 2024 21:28
outlaw rogue DQN toy problem: a toy World of Warcraft Outlaw rogue rotation modeled as a Gymnasium environment and trained with a stable-baselines3 DQN agent. Two files: outlaw_agent.py (the training script) and outlaw_environment.py (the environment).
outlaw_agent.py:

from outlaw_environment import OutlawEnvironment
from stable_baselines3 import DQN
from stable_baselines3.common.env_checker import check_env
env = OutlawEnvironment()
check_env(env, warn=True)
model = DQN(
    "MlpPolicy",
    env,
    learning_rate=0.02,
    verbose=1,
    tensorboard_log='./tensorboard_logdir/',
)
model.learn(
    total_timesteps=int(1e5),
    progress_bar=True,
)
vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(20):
    vec_env.render("console")
    action, _states = model.predict(obs, deterministic=True)
    print("best action: ", action[0])
    obs, rewards, dones, info = vec_env.step(action)
    print("reward: ", rewards[0])
# poetry run python3 outlaw_agent.py
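Once training finishes, the policy could be scored with stable-baselines3's evaluate_policy helper. A minimal sketch, assuming model is the DQN trained above; because OutlawEnvironment never sets terminated or truncated, the evaluation copy is wrapped in gymnasium's TimeLimit wrapper (the 100-step cap and 20-episode count are arbitrary choices, not something in the gist) so that episodes actually end:

from gymnasium.wrappers import TimeLimit
from stable_baselines3.common.evaluation import evaluate_policy

# Bound episode length externally; the raw env never terminates or truncates,
# so evaluate_policy would otherwise loop forever.
eval_env = TimeLimit(OutlawEnvironment(), max_episode_steps=100)
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=20, deterministic=True)
print(f"mean episode reward: {mean_reward:.0f} +/- {std_reward:.0f}")

Training curves written via tensorboard_log can be viewed with tensorboard --logdir ./tensorboard_logdir/.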
outlaw_environment.py:

import gymnasium as gym
import numpy as np

# Roll the Bones buff flags (0 = inactive, 1 = active).
BROADSIDES = 0
SKULL = 0

# Actions are just (impure) functions that return the same tuple type as the
# environment's step() fn. They are allowed to modify the agent_state object!
def PistolShot(agent_state):
    # Builder: with an fth stack available, deals triple damage and generates
    # 3 bonus combo points, consuming one stack in the process.
    BASE_DAM = 8708
    fth = agent_state.fth_stacks
    current_cp = agent_state.combo_points
    cp_generated = 1
    cp_generated += (3 if fth > 0 else 0)
    cp_generated *= (2 if BROADSIDES else 1)
    agent_state.fth_stacks = max(0, fth - 1)
    agent_state.combo_points = min(7, current_cp + cp_generated)
    reward = 3 * BASE_DAM if (fth > 0) else BASE_DAM
    terminated = False
    truncated = False
    observation = agent_state.get_observation()
    info = agent_state.get_info()
    return (observation, reward, terminated, truncated, info)
def SinisterStrike(agent_state):
    # Builder: has a chance (raised by Skull and Crossbones) to strike twice,
    # doubling the damage and combo points and granting an fth stack.
    BASE_DAM = 13896
    fth = agent_state.fth_stacks
    current_cp = agent_state.combo_points
    cp_generated = 2 if BROADSIDES else 1
    prob_double = 0.45 + 0.25 * (1 if SKULL else 0)
    if np.random.random() < prob_double:
        agent_state.fth_stacks = min(2, fth + 1)
        agent_state.combo_points = min(7, current_cp + 2 * cp_generated)
        reward = 2 * BASE_DAM
    else:
        agent_state.combo_points = min(7, current_cp + 1 * cp_generated)
        reward = BASE_DAM
    terminated = False
    truncated = False
    observation = agent_state.get_observation()
    info = agent_state.get_info()
    return (observation, reward, terminated, truncated, info)
def Dispatch(agent_state):
    # Finisher: consumes combo points for damage proportional to the current
    # total; Ruthlessness may refund a combo point afterwards.
    BASE_DAM = 7842
    current_cp = agent_state.combo_points
    new_cp = 0 if current_cp <= 5 else 1
    prob_extra_cp = (0.2 * current_cp) if current_cp <= 5 else (0.2 * (current_cp - 5))
    reward = current_cp * BASE_DAM
    # Ruthlessness proc or not.
    agent_state.combo_points = new_cp + (1 if np.random.random() <= prob_extra_cp else 0)
    terminated = False
    truncated = False
    observation = agent_state.get_observation()
    info = agent_state.get_info()
    return (observation, reward, terminated, truncated, info)
# A helper type to keep track of our agent's state.
class AgentState():
    def __init__(self):
        self.fth_stacks = 0  # normalized to 0, 1, 2 instead of 0, 3, 6
        self.combo_points = 0

    def __repr__(self):
        return f"{self.fth_stacks}, {self.combo_points}"

    def describe_observation_space(self):
        return gym.spaces.Box(low=np.array([0, 0]), high=np.array([2, 7]), shape=(2,), dtype=np.uint8)

    def get_observation(self):
        return np.array([self.fth_stacks, self.combo_points], dtype=np.uint8)

    def get_info(self):
        return {}
class OutlawEnvironment(gym.Env):
    metadata = {"render_modes": ["console"], "render_fps": 30}

    # A list of functions that take an agent_state, (potentially) mutate it, and
    # return a tuple of the same form that step() is looking for.
    ACTIONS = [
        PistolShot,
        SinisterStrike,
        Dispatch,
    ]

    def __init__(self, render_mode="console"):
        super(OutlawEnvironment, self).__init__()
        self.render_mode = render_mode
        self.agent_state = AgentState()
        self.action_space = gym.spaces.Discrete(len(OutlawEnvironment.ACTIONS))
        self.observation_space = self.agent_state.describe_observation_space()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        self.sim_time = 0
        self.agent_state = AgentState()
        return (self.agent_state.get_observation(), self.agent_state.get_info())

    def step(self, action):
        self.sim_time += 1
        return OutlawEnvironment.ACTIONS[action](self.agent_state)

    def render(self):
        if self.render_mode == "console":
            print(self.agent_state)

    def close(self):
        pass
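Before training, the environment can be exercised by hand with random actions; a minimal sketch of a smoke test (a suggested addition, not part of the gist) that could sit at the bottom of this file:

if __name__ == "__main__":
    # Smoke test: take random actions and print each transition.
    env = OutlawEnvironment()
    obs, info = env.reset(seed=0)
    for _ in range(10):
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        print(f"action={action} obs={obs} reward={reward}")

The __main__ guard keeps the rollout from running when outlaw_agent.py imports this module.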