@t-abe
Created November 20, 2019 13:02

import random
import time
from collections import OrderedDict
from copy import copy

import numpy as np
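
# Vanilla (full-traversal) counterfactual regret minimization (CFR) on what
# appears to be Kuhn poker: three cards (0, 1, 2), one private card per
# player, an ante of 0.5 per player, and a single fixed bet size of 0.5.
# Each game-tree node stores the action that led to it; terminal nodes cache
# the utility of player 0 (the game is zero-sum, so player 1's utility is
# the negation).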


class Node(object):
    def __init__(self, prev=None, action=(), player=None):
        self.prev = prev
        self.action: tuple = action
        self.player = player
        self.depth = 0
        if prev is not None:
            self.depth = self.prev.depth + 1
        self.next = self.list_next()
        self.utility = 0
        if self.is_terminal():
            self.utility = self.compute_utility()

    def is_terminal(self):
        return len(self.next) == 0

    def list_next(self):
        betsize = 0.5
        if self.depth == 0:  # depth=0 is root, next is dealing the first card
            return [Node(self, ('dealt', c), 0) for c in (0, 1, 2)]
        if self.depth == 1:  # next is dealing the second card
            return [Node(self, ('dealt', c), 1) for c in (0, 1, 2) if c != self.action[1]]
        if self.depth == 2:  # player 0's first action
            return [Node(self, ('bet', 0), 0), Node(self, ('bet', betsize), 0)]
        if self.depth == 3:  # player 1's action
            return [Node(self, ('bet', 0), 1), Node(self, ('bet', betsize), 1)]
        if self.depth >= 4:
            is_terminal = self.prev.action[1] >= self.action[1]
            if is_terminal:
                return []
            # check -> bet -> ?
            next_p = int(not self.player)
            return [Node(self, ('bet', 0), next_p), Node(self, ('bet', betsize), next_p)]
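
    # The tree implied by list_next(): depths 0 and 1 are chance nodes dealing
    # the two private cards; from depth 2 the players alternate, where
    # ('bet', 0) means check/fold and ('bet', 0.5) means bet/call. A betting
    # line ends once a player does not put in more than the previous action:
    # check-check and bet-call go to showdown, bet-fold ends the hand.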

    def is_root(self):
        return self.prev is None

    def is_fold(self):
        assert self.player is not None
        return self.prev.action[1] > self.action[1]

    def is_call(self):
        assert self.player is not None
        return self.prev.action[1] == self.action[1]

    def compute_utility(self):  # of player 0
        u = -0.5  # ante
        pot = 1  # ante
        ptr = self
        cards = [0, 0]
        while ptr.prev is not None:
            if ptr.action[0] == 'dealt':
                cards[ptr.player] = ptr.action[1]
            elif ptr.action[0] == 'bet':
                pot += ptr.action[1]
                if ptr.player == 0:
                    u -= ptr.action[1]
            ptr = ptr.prev
        win = False  # flag of player 0 win
        if self.is_fold() and self.player == 1:
            win = True
        if self.is_call() and cards[0] > cards[1]:
            win = True
        if win:
            u += pot
        return u
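
    # Worked examples for compute_utility() (utility is for player 0; each
    # player antes 0.5, so the pot starts at 1):
    #   dealt 2, dealt 1, check, check     -> showdown, 2 > 1: u = -0.5 + 1 = +0.5
    #   dealt 1, dealt 2, check, bet, fold -> player 0 folds:  u = -0.5 (the ante)
    #   dealt 2, dealt 1, bet, call        -> showdown:        u = -0.5 - 0.5 + 2 = +1.0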

    def trace_up_actions(self):
        nodes = [self]
        while nodes[-1].prev is not None:
            nodes.append(nodes[-1].prev)
        # `!= ()` rather than `is not ()`: identity comparison against a tuple
        # literal is unreliable (and a SyntaxWarning on recent Pythons).
        return [n.action for n in reversed(nodes) if n.action != ()]

    def get_infoset(self, player):
        actions = self.trace_up_actions()
        if player == 0:
            actions[1] = ('dealt', -1)
        elif player == 1:
            actions[0] = ('dealt', -1)
        return (player,) + tuple(actions)
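
# An information set is the game as one player sees it: the full action
# history with the opponent's card masked out as ('dealt', -1). The two
# players share the public betting line but key their strategies on
# different infosets.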


class CFR(object):
    def __init__(self):
        self.regret_table = OrderedDict()  # r[I][a]
        self.cum_strategy_table = OrderedDict()  # s[I][a]
        self.profiles = [
            OrderedDict()
        ]
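
    # regret_table accumulates, per infoset and action, the cumulative
    # counterfactual regret r[I][a]; cum_strategy_table accumulates the
    # reach-weighted strategies s[I][a], whose normalization is the average
    # strategy (the quantity CFR guarantees converges to a Nash equilibrium
    # in two-player zero-sum games). profiles[t] is the strategy profile in
    # effect at iteration t.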

    def __call__(self, node: Node, learning_player: int, t: int, reach_probs):
        if node.is_terminal():
            u = node.utility
            if learning_player == 1:
                u = -u
            return u
        if node.next[0].action[0] == 'dealt':  # chance node
            # return self(random.choice(node.next), learning_player, t, reach_probs)
            return sum([self(n, learning_player, t, reach_probs) for n in node.next]) / len(node.next)
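
        # The commented-out line above would give chance-sampled (Monte Carlo)
        # CFR; summing over node.next instead enumerates every deal exactly,
        # which is cheap in a game this small.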
        player = node.next[0].player  # player of NEXT action
        infoset = node.get_infoset(player)
        n_actions = len(node.next)
        value = 0
        value_given_action = np.zeros((n_actions,), dtype=np.float32)
        strategy_profile = self.profiles[t]
        if infoset not in strategy_profile:
            strategy_profile[infoset] = np.ones((n_actions,), dtype=np.float32) * 1. / n_actions
        for j, n in enumerate(node.next):
            action_prob = strategy_profile[infoset][j]
            if player == 0:
                value_given_action[j] = self(n, learning_player, t,
                                             (action_prob * reach_probs[0], reach_probs[1]))
            elif player == 1:
                value_given_action[j] = self(n, learning_player, t,
                                             (reach_probs[0], action_prob * reach_probs[1]))
            value += action_prob * value_given_action[j]
        if player == learning_player:
            # TODO: move this initialization out of the traversal so it is not
            # re-run on every visit
            if infoset not in self.regret_table:
                self.regret_table[infoset] = np.zeros((n_actions,), dtype=np.float32)
            if infoset not in self.cum_strategy_table:
                self.cum_strategy_table[infoset] = np.zeros((n_actions,), dtype=np.float32)
            self.regret_table[infoset] += reach_probs[1 - learning_player] * (value_given_action - value)
            self.cum_strategy_table[infoset] += reach_probs[learning_player] * strategy_profile[infoset]
        return value
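
    # The two accumulator updates above are the standard CFR rules,
    #     r[I][a] += pi_{-i}(I) * (v(I, a) - v(I))
    #     s[I][a] += pi_{i}(I)  * sigma_t(I, a)
    # with pi_i / pi_{-i} the learning player's and opponent's reach
    # probabilities (chance probabilities are handled by the averaging at
    # chance nodes).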

    def update_profile(self):
        new_profile = copy(self.profiles[-1])
        for infoset in self.regret_table.keys():
            n_actions = self.regret_table[infoset].shape[0]
            pos_reg = np.maximum(self.regret_table[infoset], 0)
            pos_sum = np.sum(pos_reg)
            if pos_sum <= 0:
                new_profile[infoset] = np.ones(n_actions, dtype=np.float32) * 1. / n_actions
            else:
                new_profile[infoset] = pos_reg / pos_sum
        self.profiles.append(new_profile)
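
    # update_profile() is regret matching: play each action in proportion to
    # its positive cumulative regret,
    #     sigma_{t+1}(I, a) = max(r[I][a], 0) / sum_b max(r[I][b], 0),
    # falling back to uniform when no action has positive regret.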

    def get_mes(self, player):
        # Builds a pure strategy for `player` that plays the action with the
        # highest cumulative regret at each of that player's infosets: an
        # approximate best response (maximally exploitative strategy).
        mes = copy(self.profiles[-1])
        for infoset in self.regret_table:
            if infoset[0] == player:
                mes_j = self.regret_table[infoset].argmax()
                mes[infoset] = np.zeros_like(self.regret_table[infoset])
                mes[infoset][mes_j] = 1.0
        return mes
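
# Evaluating get_mes(p) for one player against the other player's average
# strategy (as main() does below) measures how much that player could gain
# by deviating, i.e. the exploitability of the learned profile; near a Nash
# equilibrium both MES values approach the game value.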


def main():
    tree = Node()
    # traverse the tree and print every terminal line with its utility
    next_nodes = [n for n in tree.next]
    while len(next_nodes) > 0:
        next_node = next_nodes.pop(0)
        if next_node.is_terminal():
            print(next_node.trace_up_actions(), next_node.utility)
        next_nodes += next_node.next
    cfr = CFR()
    timer = time.time()
    for t in range(15000):
        for learning_player in (0, 1):
            cfr(tree, learning_player, t, (1., 1.))
        cfr.update_profile()
    print(time.time() - timer)
    # normalize the cumulative strategies to get the average strategy
    cumst_list = sorted([(infoset, probs)
                         for infoset, probs in cfr.cum_strategy_table.items()])
    for infoset, probs in cumst_list:
        print(infoset, probs / np.sum(probs))
    # evaluate the average strategy profile itself
    ev_calc = CFR()
    for infoset, probs in cfr.cum_strategy_table.items():
        ev_calc.profiles[0][infoset] = probs / np.sum(probs)
    evs = [ev_calc(tree, p, 0, [1, 1]) for p in (0, 1)]
    print("EV=", evs)
    # evaluate each player's best response against the average strategy
    print("MES:")
    for p in (0, 1):
        mes_ev_calc = CFR()
        mes_ev_calc.profiles[0] = ev_calc.get_mes(p)
        print(mes_ev_calc(tree, p, 0, (1, 1)))


if __name__ == '__main__':
    main()
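
# Note: with a 0.5 ante and a 0.5 bet this is standard Kuhn poker scaled by
# one half, so the printed EV for player 0 should converge toward the game
# value of about -1/36 (~ -0.028); the exact numbers after 15000 iterations
# will vary slightly.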