Blackjack simulator where the rewards of a fixed policy are estimated with the Monte Carlo method, as described in Section 5.1 (Example 5.1) of Reinforcement Learning: An Introduction by Sutton and Barto.
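For reference, the estimator this script computes is the plain Monte Carlo sample mean: per state $s$ it accumulates a visit count $N(s)$ and the summed episode returns, and reports

$$V(s) \approx \frac{1}{N(s)} \sum_{i=1}^{N(s)} G_i,$$

where $G_i$ is the undiscounted final reward (+1 win, 0 draw, -1 loss) of the $i$-th episode that visited $s$.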
""" | |
Blackjack simulator where rewards of a fixed policy are calculated using Monte Carlo method. | |
As described in Chapter 5(.1) of Reinforcement Learning, an introduction by Sutton and Barto | |
""" | |
from dataclasses import dataclass, field
from typing import List, Tuple

import numpy as np
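# One suit's card values: ace (1), 2-9, and four ten-valued cards (10, jack, queen, king).
# Sampling from this array with replacement amounts to drawing from an infinite deck.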
DECK = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10])

@dataclass
class Person:
    """Superclass for both Dealer and Player."""
    has_usable_ace: bool = False
    total: int = 0

    def hit(self) -> None:
        """Take a card from an (infinite) deck and check whether there is a usable ace."""
        card = np.random.choice(DECK)
        # Card 1 = ace: count it as 11 whenever that does not bust the hand.
        if card == 1 and (self.total + 11) <= 21:
            self.has_usable_ace = True
            self.total += 11
        else:
            self.total += card
        # Spend the usable ace (11 -> 1) to decrease the total if over 21.
        if self.total > 21 and self.has_usable_ace:
            self.has_usable_ace = False
            self.total -= 10

@dataclass
class Player(Person):
    threshold: int = 20
    visited_states: List[Tuple[int, int, int]] = field(default_factory=list)

    def __post_init__(self) -> None:
        """The policy only applies from a total of 12 upward; below 12, always hit (busting is impossible)."""
        while self.total < 12:
            self.hit()

    def play(self, dealer_card_shown: int) -> None:
        """Apply the fixed policy: hit below the threshold, stick at or above it."""
        while self.total < self.threshold:
            # Record every visited state; each one receives the episode's final reward (every-visit MC).
            self.visited_states.append((self.total, dealer_card_shown, int(self.has_usable_ace)))
            self.hit()
        # Player sticks: record the final state unless the hand busted.
        if self.threshold <= self.total <= 21:
            self.visited_states.append((self.total, dealer_card_shown, int(self.has_usable_ace)))

@dataclass
class Dealer(Person):
    threshold: int = 17
    showing_card: int = field(init=False)

    def __post_init__(self) -> None:
        """Take the two starting cards and expose the first one."""
        self.hit()
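        # An ace dealt first is counted as 11 by hit(), so showing_card can be 11;
        # the `% 10` indexing in Simulation.run folds it into the ace column regardless.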
        self.showing_card = self.total
        self.hit()

    def play(self) -> None:
        """Dealer follows the fixed casino strategy: hit until the total reaches the threshold, then stick."""
        while self.total < self.threshold:
            self.hit()

@dataclass
class Simulation:
    n_iter: int
    rewards: np.ndarray = field(default_factory=lambda: np.zeros((10, 10, 2)))
    visits: np.ndarray = field(default_factory=lambda: np.zeros((10, 10, 2)))
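    # Axes: player total (12-21 -> rows 0-9), dealer card shown (ace, 2-10 -> cols 0-9),
    # usable ace (0 = no, 1 = yes): the 200 states of the book's formulation.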

    def run(self) -> None:
        """Iterate through episodes (games), determine who won, then attribute rewards and count state visits."""
        for _ in range(self.n_iter):
            p = Player()
            d = Dealer()
            p.play(d.showing_card)
            if p.total > 21:
                # Player busted and loses immediately; the dealer does not play.
                reward = -1
                # print(f'Player busted with {p.total}')
            else:
                d.play()
                if p.total == d.total:
                    reward = 0
                    # print(f'Player and dealer drew with {p.total}')
                elif p.total > d.total or d.total > 21:
                    reward = 1
                    # print(f'Player wins with {p.total} vs {d.total}')
                else:
                    reward = -1
                    # print(f'Player loses with {p.total} vs {d.total}')
            for state in p.visited_states:
                # Rows: player total - 12; columns: dealer card - 1, where `% 10` folds a shown
                # ace counted as 11 into column 0 (the same column as card value 1).
                self.visits[state[0] - 12, (state[1] - 1) % 10, state[2]] += 1
                self.rewards[state[0] - 12, (state[1] - 1) % 10, state[2]] += reward

    @staticmethod
    def plot_heatmap(array: np.ndarray, title: str) -> None:
        """Plot a 10x10 state-value slice as an annotated heatmap."""
        import matplotlib.pyplot as plt
        import seaborn as sns

        sns.heatmap(array, linewidths=0.5, vmin=-1, vmax=1, center=0, cmap='coolwarm_r', annot=True, fmt=".2f")
        plt.ylabel('Player hand')
        plt.xlabel('Dealer card shown')
        plt.title(title)
        plt.xticks(np.arange(10) + 0.5, ['Ace'] + list(range(2, 11)))
        plt.yticks(np.arange(10) + 0.5, np.arange(10) + 12)
        # plt.savefig(f'{title}.png')
        plt.show()
        plt.close()

    def evaluate(self) -> None:
        """Per state, divide the summed reward by the visit count to get the average reward of that state."""
        # 0/0 yields NaN for never-visited states; with 1,000,000 episodes every state is visited in practice.
        mean_reward = self.rewards / self.visits
        with_usable_ace = mean_reward[:, :, 1]
        without_usable_ace = mean_reward[:, :, 0]
        self.plot_heatmap(with_usable_ace, 'With usable ace')
        self.plot_heatmap(without_usable_ace, 'Without usable ace')

if __name__ == "__main__":
    np.random.seed(42)
    s = Simulation(1_000_000)
    s.run()
    s.evaluate()
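    # With this many episodes the two heatmaps should resemble Figure 5.1 of Sutton and Barto:
    # the stick-on-20 policy only achieves clearly positive value at totals of 20 and 21.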