Instantly share code, notes, and snippets.

# BasedLukas/double_q.py

Created March 16, 2024 18:34
Show Gist options
• Save BasedLukas/bda5cfed389e42108fc9f6a8daeb7cd7 to your computer and use it in GitHub Desktop.
Comparison of Q-learning with Double Q-learning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 import numpy as np import random import matplotlib.pyplot as plt from typing import Dict, Optional, Tuple def argmax_a(state: str, q: Dict) -> int: """ Find the action that maximizes the Q-value in a given state. Args: - state: The current state. - q: The Q-value dictionary. Returns: - An integer representing the action that maximizes the Q-value. """ if state == 'terminal': return 0 actions = range(len(transitions[state])) values = [q.get((state, action), 0) for action in actions] max_value = max(values) best_actions = [action for action, value in enumerate(values) if value == max_value] return np.random.choice(best_actions) def max_a(state: str, q: Dict) -> float: """ Find the maximum Q-value for any action in a given state. Args: - state: The current state. - q: The Q-value dictionary. Returns: - The maximum Q-value for the given state. """ if state == 'terminal': return 0 return max([q.get((state, action), 0) for action in range(len(transitions[state]))]) def policy( state: str, epoch: int, q1: Dict, q2: Optional[Dict] = None ) -> int: """ Select an action based on the epsilon-greedy policy derived from Q-values. Args: - state: The current state in the environment. - epoch: The current epoch of training, used to decay epsilon. - q1: The primary Q-value dictionary. - q2: An optional secondary Q-value dictionary for double Q-learning. Returns: - An integer representing the chosen action. """ eps = epsilon * (eps_decay ** epoch) number_of_possible_actions = len(transitions[state]) #exploration if np.random.random() < eps: return np.random.choice(range(number_of_possible_actions)) action1 = argmax_a(state, q1) if not q2: return action1 action2 = argmax_a(state, q2) return np.random.choice([action1, action2]) def get_reward(state: str) -> float: """ Returns the reward for transitioning into a given state. Args: - state: The state transitioned into. Returns: - A float representing the reward for that transition. Raises: - ValueError: If an invalid state is provided. """ if state == "a": raise ValueError("a should not be passed as a param as it's the starting state") if state == 'b' or state == 'terminal': return 0 if 'c' in state: return np.random.normal(-0.1, 1) raise ValueError(f"state: {state} not recognized") def q_update( state: str, action: int, new_state: str, reward: float, alpha: float, q: Dict ) -> None: """ In-place update of Q-values for Q-learning. Args: state: The current state. action: The action taken in the current state. new_state: The state reached after taking the action. reward: The reward received after taking the action. alpha: The learning rate. q: The Q-values dictionary. """ current_q = q.get((state, action), 0) # Current Q-value estimation max_next = max_a(new_state, q) # Maximum Q-value for the next state target = reward + gamma * max_next # TD Target td_error = target - current_q # TD Error update = alpha * td_error # TD Update q[(state, action)] = current_q + update def double_q_update( state: str, action: int, new_state: str, reward: float, alpha: float, q1: Dict, q2: Dict ) -> None: """ In-place update of Q-values for Double Q-learning. Args: state: The current state. action: The action taken in the current state. new_state: The state reached after taking the action. reward: The reward received after taking the action. alpha: The learning rate. q1: The first Q-values dictionary. q2: The second Q-values dictionary. """ qs = [q1, q2] # List of Q dictionaries random.shuffle(qs) # Randomly shuffle to choose one for updating qa, qb = qs # qa is the Q to update, qb is used for target calculation current_q = qa.get((state, action), 0) # Current Q-value estimation best_action = argmax_a(new_state, qa) # Best action based on qa target = reward + gamma * qb.get((new_state, best_action), 0) # TD Target using qb error = target - current_q # TD Error update = alpha * error # TD Update qa[(state, action)] = current_q + update def simulate( epoch: int, q: Dict, q2: Optional[Dict] = None ) -> None: """ Simulate an epoch of the agent's interaction with the environment, updating Q-values based on observed transitions. Args: epoch: The current epoch of the simulation. q: The Q-values dictionary for Q-learning or the primary Q-values dictionary for Double Q-learning. q2: The secondary Q-values dictionary for Double Q-learning, if applicable. """ double = q2 is not None state = 'a' while state != 'terminal': if double: action = policy(state, epoch, q, q2) else: action = policy(state, epoch, q) new_state = transitions[state][action] reward = get_reward(new_state) if double: double_q_update( state=state, action=action, new_state=new_state, reward=reward, alpha=lr, q1=q, q2=q2 ) else: q_update(state, action, new_state, reward, lr, q) state = new_state if __name__ == "__main__": ### PARAMS ### lr = 0.001 epsilon = 0.1 eps_decay = 0.995 number_of_c_states = 10 gamma = 1 epochs = 1000 ### MDP ### transitions = { "a": ["terminal", "b"], "b": ["c"+str(i) for i in range(number_of_c_states)] } for i in range(number_of_c_states): transitions[f"c{i}"] = ["terminal"] # Track the evolution of Q-values for a specific action over epochs normal = [] # For standard Q-learning double = [] # For Double Q-learning # Q-values dictionaries: key=(state,action) q: Dict[Tuple[str, int], float] = {} q1: Dict[Tuple[str, int], float] = {} q2: Dict[Tuple[str, int], float] = {} for epoch in range(epochs): simulate(epoch, q) # normal simulate(epoch, q1, q2) # double normal_q_value: float = q.get(('a', 1), 0) double_q_value: float = (q1.get(('a', 1), 0) + q2.get(('a', 1), 0)) / 2 normal.append(normal_q_value) double.append(double_q_value) plt.plot(normal, label='Standard Q-Learning') plt.plot(double, label='Double Q-Learning') plt.legend() plt.xlabel('Epochs') plt.ylabel('Value of A -> B') plt.title('Comparison of Q-Learning and Double Q-Learning') plt.show()