Skip to content

Instantly share code, notes, and snippets.

@jeroenboeye
Created August 21, 2020 09:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jeroenboeye/676c9cb02ed48fd853aceeec98a7b577 to your computer and use it in GitHub Desktop.
Save jeroenboeye/676c9cb02ed48fd853aceeec98a7b577 to your computer and use it in GitHub Desktop.
Blackjack simulator that estimates the rewards of a fixed policy with the Monte Carlo method, as described in Chapter 5.1 of "Reinforcement Learning: An Introduction" by Sutton and Barto.
"""
Blackjack simulator where rewards of a fixed policy are calculated using Monte Carlo method.
As described in Chapter 5(.1) of Reinforcement Learning, an introduction by Sutton and Barto
"""
from dataclasses import dataclass, field
from typing import List, Tuple
import numpy as np
# Card values of one suit: Ace (1) through 10, plus three face cards worth 10 each.
DECK = np.array(list(range(1, 11)) + [10] * 3)
@dataclass
class Person:
    """Superclass for both Dealer and Player.

    Tracks the running hand total and whether the hand currently holds an Ace
    that is being counted as 11 (a "usable" Ace).
    """

    has_usable_ace: bool = False
    total: int = 0

    def hit(self, card=None) -> None:
        """Add one card to the hand, counting an Ace as 11 whenever that does not bust.

        Args:
            card: Card value to add (1 = Ace, 2-10). When omitted, a card is drawn
                uniformly at random from ``DECK`` (sampling with replacement, i.e.
                an infinite deck), exactly as before this parameter existed.
        """
        if card is None:
            card = np.random.choice(DECK)
        # Keep totals as plain Python ints rather than numpy scalars.
        card = int(card)
        # card 1 = Ace; count it as 11 only if that keeps the hand at 21 or below.
        if card == 1 and (self.total + 11) <= 21:
            self.has_usable_ace = True
            self.total += 11
        else:
            self.total += card
        # Spend the usable Ace (recount it from 11 to 1) instead of busting.
        if self.total > 21 and self.has_usable_ace:
            self.has_usable_ace = False
            self.total -= 10
@dataclass
class Player(Person):
    """Player following the fixed policy: hit while below ``threshold``, otherwise stick."""

    threshold: int = 20
    # One (player total, dealer card shown, usable-ace flag as 0/1) tuple per
    # decision the policy made during the episode.
    visited_states: List[Tuple[int, int, int]] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Strategy only starts as soon as you have 12 or more. Until then, always hit.

        With a total of 11 or less no single card can bust the hand, so these
        forced hits are not recorded as decision states.
        """
        while self.total < 12:
            self.hit()

    def _state(self, dealer_card_shown: int) -> Tuple[int, int, int]:
        """Return the current decision state as (total, dealer card shown, ace flag)."""
        return self.total, dealer_card_shown, int(self.has_usable_ace)

    def play(self, dealer_card_shown: int) -> None:
        """Apply the policy, recording every state in which a decision was made.

        Args:
            dealer_card_shown: Value of the dealer's face-up card.
        """
        while self.total < self.threshold:
            self.visited_states.append(self._state(dealer_card_shown))
            self.hit()
        # Player sticks; a busted hand (> 21) is terminal and records no state.
        if self.threshold <= self.total <= 21:
            self.visited_states.append(self._state(dealer_card_shown))
@dataclass
class Dealer(Person):
    """Dealer with a fixed house policy: hit while below ``threshold``, then stick."""

    threshold: int = 17
    # Value of the dealer's face-up (first) card. It is read from the running
    # total right after the first hit, so an Ace is stored as 11, not 1.
    showing_card: int = field(init=False)

    def __post_init__(self) -> None:
        """Deal two cards and make the first one visible to the player."""
        self.hit()
        self.showing_card = self.total
        self.hit()

    def play(self) -> None:
        """Hit until the total reaches ``threshold``.

        The dealer sticks on any total of ``threshold`` or more; usable Aces are not
        considered, so the dealer also sticks on a soft ``threshold``.
        """
        while self.total < self.threshold:
            self.hit()
@dataclass
class Simulation:
    """Monte Carlo evaluation of the fixed Player policy over ``n_iter`` episodes.

    The state arrays are indexed ``[player_total - 12, dealer_column, usable_ace]``.
    Player totals 12..21 map to rows 0..9. The dealer column is 0 for an Ace and
    1..9 for cards 2..10.
    """

    n_iter: int
    # Sum of episode returns credited to each state.
    rewards: np.ndarray = field(default_factory=lambda: np.zeros((10, 10, 2)))
    # Number of times each state was visited.
    visits: np.ndarray = field(default_factory=lambda: np.zeros((10, 10, 2)))

    @staticmethod
    def _state_index(state: Tuple[int, int, int]) -> Tuple[int, int, int]:
        """Map a (total, dealer card shown, ace flag) state to an index into the arrays.

        The dealer's Ace may be reported as 1 or 11; ``(card - 1) % 10`` sends both
        to column 0, and cards 2..10 to columns 1..9.
        """
        total, dealer_card, usable_ace = state
        return total - 12, (dealer_card - 1) % 10, usable_ace

    def run(self) -> None:
        """Play ``n_iter`` episodes and credit each episode's return to the visited states.

        The return is +1 for a player win, 0 for a draw and -1 for a loss (a player
        bust is always a loss). Every visit counts once in ``visits``.
        """
        for _ in range(self.n_iter):
            player = Player()
            dealer = Dealer()
            player.play(dealer.showing_card)
            if player.total > 21:
                # Player busted: the episode is lost without the dealer playing.
                reward = -1
            else:
                dealer.play()
                if player.total == dealer.total:
                    reward = 0
                elif player.total > dealer.total or dealer.total > 21:
                    reward = 1
                else:
                    reward = -1
            for state in player.visited_states:
                index = self._state_index(state)
                self.visits[index] += 1
                self.rewards[index] += reward

    @staticmethod
    def plot_heatmap(array: np.ndarray, title: str) -> None:
        """Show a 10x10 (player total x dealer card) heatmap of ``array`` titled ``title``."""
        import seaborn as sns
        import matplotlib.pylab as plt
        ax = sns.heatmap(array, linewidth=0.5, vmin=-1, vmax=1, center=0, cmap='coolwarm_r', annot=True, fmt=".2f")
        plt.ylabel('Player hand')
        plt.xlabel('Dealer card shown')
        plt.title(title)
        plt.xticks(np.arange(10) + 0.5, ['Ace'] + list(range(2, 11)))
        plt.yticks(np.arange(10) + 0.5, np.arange(10) + 12)
        # NOTE: uncomment to also write the figure to disk.
        # plt.savefig(f'{title}.png')
        plt.show()
        plt.close()

    def mean_rewards(self) -> np.ndarray:
        """Return the average return per state; states that were never visited are NaN.

        A masked divide is used so unvisited states produce NaN without the
        divide-by-zero RuntimeWarning that a plain ``rewards / visits`` emits.
        """
        mean = np.full(self.rewards.shape, np.nan)
        np.divide(self.rewards, self.visits, out=mean, where=self.visits > 0)
        return mean

    def evaluate(self) -> None:
        """Per state, divide summed reward by visits to get average reward of state, and plot it."""
        mean_reward = self.mean_rewards()
        with_usable_ace = mean_reward[:, :, 1]
        without_usable_ace = mean_reward[:, :, 0]
        self.plot_heatmap(with_usable_ace, 'With usable ace')
        self.plot_heatmap(without_usable_ace, 'Without usable ace')
if __name__ == "__main__":
    # Fixed seed so repeated runs produce the same estimates and plots.
    np.random.seed(42)
    simulation = Simulation(n_iter=1_000_000)
    simulation.run()
    simulation.evaluate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment