Higher-lower (simple card game) optimizer using epsilon-greedy Monte Carlo learning. For educational purposes.
""" | |
Higher lower (simple card game) optimizer using epsilon greedy Monte Carlo learning. For educational purposes. | |
""" | |
from dataclasses import dataclass, field | |
from typing import List, Tuple | |
import numpy as np | |
@dataclass
class Player:
    action: int = field(init=False)
    epsilon: float = 0.1
    state_actions: List[Tuple[int, int]] = field(default_factory=list)

    def play(self, card: int, policy: np.ndarray) -> None:
        """If random < epsilon, pick a random action (exploration); else exploit the current best policy."""
        # Action 0 means guessing the next card will be lower, 1 means guessing it will be higher.
        if np.random.random() < self.epsilon:
            self.action = np.random.randint(0, 2)
        else:
            self.action = int(policy[card - 2])
        self.state_actions.append((card, self.action))


@dataclass
class Simulation:
    n_iter: int
    # Visit counts and Q-values per (card, action) pair; cards 2-14 map to rows 0-12.
    visits: np.ndarray = field(default_factory=lambda: np.zeros((13, 2)))
    q_table: np.ndarray = field(default_factory=lambda: np.random.uniform(-0.001, 0.001, (13, 2)))
    policy: np.ndarray = field(default_factory=lambda: np.zeros(13))

    def __post_init__(self):
        """Set the initial policy: guess 'higher' for cards 2-8, 'lower' for the rest."""
        self.policy[:7] = 1

    def run(self) -> None:
        """Iterate through episodes (games), check win/lose, then attribute rewards and count state visits."""
        card = np.random.randint(2, 15)
        for i in range(self.n_iter):
            p = Player()
            p.play(card, self.policy)
            next_card = np.random.randint(2, 15)
            if next_card == card:
                reward = 0
            # Action 0 means guessing the next card will be lower, 1 means guessing it will be higher.
            elif (next_card > card and p.action == 1) or (next_card < card and p.action == 0):
                reward = 1
            else:
                reward = -1
            for state in p.state_actions:
                self.visits[state[0] - 2, state[1]] += 1
                # Incremental mean: Q += (reward - Q) / n keeps Q equal to the average
                # reward observed so far for this (card, action) pair.
                expected_reward = self.q_table[state[0] - 2, state[1]]
                n_visits = self.visits[state[0] - 2, state[1]]
                self.q_table[state[0] - 2, state[1]] += (reward - expected_reward) / n_visits
            # For each state, greedily select the action with the highest expected reward as the policy.
            self.policy = np.argmax(self.q_table, axis=-1)
            card = next_card
            if (i + 1) % (self.n_iter // 20) == 0:
                print(f'Simulation {(i + 1) / self.n_iter:4.0%} complete')

    @staticmethod
    def plot_heatmap(array: np.ndarray, title: str) -> None:
        import seaborn as sns
        import matplotlib.pyplot as plt
        plt.figure(figsize=(2, 5))
        sns.heatmap(np.flip(np.expand_dims(array, axis=1), 0),
                    linewidth=0.5, vmin=-1, vmax=1, center=0,
                    cmap='coolwarm_r', annot=True, fmt=".2f")
        plt.ylabel('Card shown')
        plt.xlabel('')
        plt.title(title)
        plt.yticks(np.arange(13) + 0.5, np.flip(np.arange(13) + 2))
        plt.savefig(f'{title}.png')
        plt.show()
        plt.close()

    def evaluate(self) -> None:
        """Plot the learned policy, both Q-table columns, and the value of the greedy policy."""
        value_optimal_policy = np.max(self.q_table, axis=-1)
        self.plot_heatmap(self.policy, 'Policy')
        self.plot_heatmap(self.q_table[:, 0], 'Q table lower')
        self.plot_heatmap(self.q_table[:, 1], 'Q table higher')
        self.plot_heatmap(value_optimal_policy, 'V*')


if __name__ == "__main__":
    np.random.seed(1)
    s = Simulation(1000000)
    s.run()
    s.evaluate()
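
Because the next card is drawn independently and uniformly from 2-14, the true Q-values have a closed form, which makes a handy sanity check for the learned q_table. A minimal sketch (the analytic_q helper below is an illustration, not part of the gist):

import numpy as np

def analytic_q() -> np.ndarray:
    """Closed-form Q-values E[reward | card, action] for cards 2-14.

    With the next card uniform on 2-14, P(next > c) = (14 - c) / 13 and
    P(next < c) = (c - 2) / 13; ties pay 0, so guessing 'higher' is worth
    (16 - 2c) / 13 and guessing 'lower' the negative of that.
    """
    cards = np.arange(2, 15)
    q_higher = (16 - 2 * cards) / 13
    # Column 0: expected reward for guessing lower, column 1: for guessing higher.
    return np.stack([-q_higher, q_higher], axis=1)

After s.run(), np.abs(s.q_table - analytic_q()).max() should shrink toward zero, and the greedy policy implied by analytic_q() matches the learned one: guess higher for cards 2-7, lower for 9-14, with card 8 indifferent (expected reward zero either way).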