Skip to content

Instantly share code, notes, and snippets.

@chetandhembre
Last active October 8, 2016 18:24
Show Gist options
  • Save chetandhembre/7fc7d6d24f22f98a9db1ab4d2e8128dc to your computer and use it in GitHub Desktop.
Easy21
#!/usr/bin/python2
#plot link: https://dl.dropboxusercontent.com/u/47591917/easy21_mc.png
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Card colors.
RED = 0
BLACK = 1
# Player actions.
STICK = 0
HIT = 1
# Episode rewards (BUST doubles as "player lost").
WIN = 1
BUST = -1
DRAW = 0
# Policy modes (not used elsewhere in this script).
GREEDY = 0  # fixed typo: was misspelled GREDDY
GREDDY = GREEDY  # deprecated misspelled alias, kept for backward compatibility
EXPLORE = 1
def plot_surface(X, Y, Z, title):
    """Render Z as a 3-D surface over the (X, Y) grid, save the figure
    to 'easy21_mc.png' and show it on screen."""
    figure = plt.figure(figsize=(20, 10))
    axes = figure.add_subplot(111, projection='3d')
    surface = axes.plot_surface(
        X, Y, Z,
        rstride=1,
        cstride=1,
        cmap=matplotlib.cm.coolwarm,
        vmin=-1.0,
        vmax=1.0)
    axes.set_xlabel('Player Sum')
    axes.set_ylabel('Dealer Showing')
    axes.set_zlabel('Value')
    axes.set_title(title)
    # Rotate the default camera so both axes read naturally.
    axes.view_init(axes.elev, -120)
    figure.colorbar(surface)
    plt.savefig('easy21_mc.png')
    plt.show()
class State(object):
    """The player's view of the game: own sum plus the dealer's showing card.

    Easy21 has no aces, so no usable-ace flag is tracked.  Instances are
    hashable and compare by value so they can key dictionaries.
    """

    def __init__(self, players_sum, dealer_card):
        self.player_sum = players_sum
        self.dealer_sum = dealer_card

    def __eq__(self, other):
        same_player = self.player_sum == other.player_sum
        same_dealer = self.dealer_sum == other.dealer_sum
        return same_player and same_dealer

    def __hash__(self):
        return hash((self.player_sum, self.dealer_sum))
class StateAction(object):
    """A (state, action) pair; the key type of the Q-value tables."""

    def __init__(self, player_sum, dealer_sum, action):
        self.state = State(player_sum, dealer_sum)
        self.action = action

    def __hash__(self):
        return hash((self.state, self.action))

    def __eq__(self, other):
        if self.action != other.action:
            return False
        return self.state == other.state
class Card(object):
    """A playing card: a number in 1..10 and a color (RED or BLACK)."""

    def __init__(self, number, color):
        self.number = number
        self.color = color
class ValueMap(object):
    """Tabular Q-value store with visit counts and an epsilon-greedy
    action selector, where epsilon = N0 / (N0 + N(state))."""

    def __init__(self, N0=100):
        self.states = {}               # StateAction -> estimated value
        self.visited_count = {}        # StateAction -> visit count
        self.state_visited_count = {}  # State -> visit count
        self.N0 = N0
        self.initialize()

    def initialize(self):
        # Zero every reachable (player sum, dealer card, action) entry.
        for total in range(1, 22):
            for shown in range(1, 11):
                for move in (HIT, STICK):
                    key = StateAction(total, shown, move)
                    self.states[key] = 0
                    self.visited_count[key] = 0

    def state_visited(self, state):
        # Bump N(state), creating the entry on first visit.
        seen = self.state_visited_count.get(state, 0)
        self.state_visited_count[state] = seen + 1

    def greedy_or_explore(self, state, action):
        # Epsilon decays as the state is visited more often.
        visits = self.state_visited_count.get(state, 0)
        epsilon = self.N0 / float(self.N0 + visits)
        probs = np.ones(2) * epsilon / 2
        probs[action] = probs[action] + (1 - epsilon)
        return np.random.choice(len(probs), p=probs)

    def select_action_state(self, state):
        # Epsilon-greedy: usually the greedy action, sometimes random.
        best = self.select_greedy_action(state)
        return self.greedy_or_explore(state, best)

    def select_greedy_action(self, state):
        # Ties resolve to STICK, matching the strict > comparison.
        q_hit = self.states[StateAction(state.player_sum, state.dealer_sum, HIT)]
        q_stick = self.states[StateAction(state.player_sum, state.dealer_sum, STICK)]
        if q_hit > q_stick:
            return HIT
        return STICK
def select_card():
    """Draw a random card: number uniform over 1..10, color RED with
    probability 1/3 and BLACK with probability 2/3.

    Bug fix: np.random.choice was called with size=1, which returns a
    one-element ndarray rather than a scalar; color is now a plain int
    so comparisons like ``card.color == RED`` behave as expected.
    """
    color = int(np.random.choice([RED, BLACK], p=[1 / float(3), 2 / float(3)]))
    # floor of uniform[1, 11) is uniform over the integers 1..10
    number = int(np.random.uniform(1, 11))
    return Card(number, color)
class Easy21Env(object):
    """Easy21 table state: the player's and dealer's running sums.

    Bug fix: the original drew card colors (RED with prob 1/3) but then
    ignored them when updating sums.  Per the Easy21 rules black cards
    add to a sum and red cards subtract from it -- which is the only way
    a sum can fall below 1 and bust, a case get_reward already checked.
    The dealer loop now also stops (as a bust) when his sum drops below 1.
    """

    def __init__(self):
        # The opening card of each side counts as black (added), per the
        # rule that the game starts with one black card each.
        self.player_sum = select_card().number
        self.dealer_sum = select_card().number

    def _card_value(self, card):
        # Black cards add, red cards subtract.
        return card.number if card.color == BLACK else -card.number

    def get_reward(self):
        """Score a finished episode: WIN, DRAW, or BUST (player lost)."""
        if self.player_sum > 21 or self.player_sum < 1:
            return BUST
        # Dealer busted (above 21 or below 1): player wins.
        if self.dealer_sum > 21 or self.dealer_sum < 1:
            return WIN
        if self.player_sum < self.dealer_sum:
            return BUST
        return WIN if self.player_sum > self.dealer_sum else DRAW

    def take_player_hit(self):
        """Deal one card to the player, honoring its color."""
        self.player_sum = self.player_sum + self._card_value(select_card())

    def dealer_move(self):
        # NOTE(review): the Easy21 spec has the dealer stick on 17+; this
        # keeps the original policy of hitting until the dealer beats the
        # player or busts (now including busting below 1 via red cards).
        while 1 <= self.dealer_sum <= 21 and self.dealer_sum <= self.player_sum:
            self.dealer_sum = self.dealer_sum + self._card_value(select_card())

    def get_state(self):
        return State(self.player_sum, self.dealer_sum)
def select_move(player_sum):
    """Naive fixed policy: stick at 21, otherwise hit (unused here)."""
    if player_sum > 20:
        return STICK
    return HIT
class Game(object):
def __init__(self, no_episodes):
self.no_episodes = no_episodes
self.value_map = ValueMap()
def _handle_value_map(self, reward, state_visited_order, states_visited):
for action_state in state_visited_order:
n = states_visited[action_state]
self.value_map.visited_count[action_state] = self.value_map.visited_count[action_state] + 1
self.value_map.states[action_state] = self.value_map.states[action_state] + ((reward - self.value_map.states[action_state]) / float(self.value_map.visited_count[action_state]))
def plot(self):
player_sum = []
dealer_sum = []
result = []
policy_action = []
policy_a = ''
for player in range(12, 21 + 1):
line = []
for card in range(1, 10 + 1):
state = State(player, card)
player_sum.append(player)
dealer_sum.append(card)
action = self.value_map.select_greedy_action(State(player, card))
action_state = StateAction(player, card, action)
result.append(self.value_map.states[action_state])
policy_action.append(action)
line.append(str(action))
policy_a = ''.join(line) + '\n' + policy_a
print policy_a
min_x = min(k for k in dealer_sum)
max_x = max(k for k in dealer_sum)
min_y = min(k for k in player_sum)
max_y = max(k for k in player_sum)
x_range = np.arange(min_x, max_x + 1)
y_range = np.arange(min_y, max_y + 1)
X, Y = np.meshgrid(x_range, y_range)
# Find value for all (x, y) coordinates
Z_noace = np.apply_along_axis(lambda _: self.value_map.states[StateAction(_[0], _[1], self.value_map.select_greedy_action(State(_[0], _[1])))], 2, np.dstack([Y, X]))
# Z_ace = np.apply_along_axis(lambda _: V[(_[0], _[1], True)], 2, np.dstack([X, Y]))
title = "easy21"
plot_surface(Y, X, Z_noace, "{} (No Usable Ace)".format(title))
def play(self):
for i in range(self.no_episodes):
eps = Easy21Env()
is_busted = False
states_visited = {}
state_visited_order = []
state = eps.get_state()
while True:
action = self.value_map.select_action_state(state)
action_state = StateAction(state.player_sum, state.dealer_sum, action)
states_visited[action_state] = states_visited.get(action_state, 0) + 1
state_visited_order.append(action_state)
if action == STICK:
break
eps.take_player_hit()
if eps.player_sum > 21 or eps.player_sum < 1:
is_busted = True
break
state = eps.get_state()
self.value_map.state_visited(state)
if is_busted:
self._handle_value_map(BUST, reversed(state_visited_order), states_visited)
else:
eps.dealer_move()
reward = eps.get_reward()
self._handle_value_map(reward, reversed(state_visited_order), states_visited)
if __name__ == '__main__':
    # 500k episodes of Monte-Carlo control, then print/plot the result.
    # The guard lets the module be imported without training side effects.
    game = Game(500000)
    game.play()
    game.plot()
"""
0000000000
0000000000
0000000000
0000000000
0000000000
1111101111
1111111111
1111111111
1111111111
1111111111
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment