Skip to content

Instantly share code, notes, and snippets.

@panchishin
Last active December 19, 2020 23:15
Show Gist options
  • Save panchishin/d71f098cbb2fd059d3c0c2cd63a634d0 to your computer and use it in GitHub Desktop.
Save panchishin/d71f098cbb2fd059d3c0c2cd63a634d0 to your computer and use it in GitHub Desktop.
MCTS (Monte Carlo Tree Search) in Python, adapted from pypi.org/project/mcts/, with a unit-test example
import time
import math
import random
def randomPolicy(state):
    """Play uniformly random actions from ``state`` until a terminal state is reached.

    Args:
        state: object exposing ``isTerminal()``, ``getPossibleActions()``
            (any iterable of actions), ``takeAction(action)`` returning the
            successor state, and ``getReward()``.

    Returns:
        The value of ``getReward()`` on the terminal state reached.

    Raises:
        Exception: if a non-terminal state offers no possible actions.
    """
    while not state.isTerminal():
        # Materialise first: random.choice needs an indexable sequence, but
        # getPossibleActions may hand back a set, generator or other iterable.
        actions = list(state.getPossibleActions())
        if not actions:
            raise Exception("Non-terminal state has no possible actions: " + str(state))
        state = state.takeAction(random.choice(actions))
    return state.getReward()
class treeNode:
    """A single node of the MCTS search tree.

    Attributes (declared in ``__slots__``, so instances carry no ``__dict__``):
        state: the game state this node represents.
        isTerminal: cached result of ``state.isTerminal()``.
        isFullyExpanded: True once every action has a child; a terminal
            node starts out fully expanded because it has nothing to expand.
        parent: parent ``treeNode``, or None for the root.
        numVisits: number of rounds that passed through this node.
        totalReward: sum of rewards back-propagated through this node.
        children: mapping of action -> child ``treeNode``.
    """
    __slots__ = ['state', 'isTerminal', 'isFullyExpanded', 'parent', 'numVisits', 'totalReward', 'children']

    def __init__(self, state, parent):
        terminal = state.isTerminal()
        self.state, self.parent = state, parent
        self.isTerminal = terminal
        self.isFullyExpanded = terminal
        self.numVisits, self.totalReward = 0, 0
        self.children = {}
class mcts():
    """Monte Carlo Tree Search using the UCB1 selection rule (UCT).

    Each round selects a node by repeatedly following the child with the best
    UCB1 score, expands one untried action, scores the new node with the
    rollout policy, and adds that reward to every node on the path back to the
    root. Rewards are always maximised from the root's point of view (there is
    no per-player sign flip).

    The search is bounded either by wall-clock time or by a number of rounds;
    exactly one of the two limits must be given.
    """
    __slots__ = ['timeLimit', 'limitType', 'searchLimit', 'explorationConstant', 'rollout', 'root']

    def __init__(self, timeLimitSec=None, iterationLimit=None, explorationConstant=math.sqrt(2),
                 rolloutPolicy=randomPolicy):
        """Configure the search.

        Args:
            timeLimitSec: wall-clock budget for each ``search`` call, in seconds.
            iterationLimit: number of rounds for each ``search`` call (>= 1).
            explorationConstant: weight of the exploration term in UCB1;
                larger values favour less-visited children.
            rolloutPolicy: callable ``state -> reward`` used to score each
                newly expanded node.

        Raises:
            ValueError: if both or neither limit is given, or if
                ``iterationLimit`` is less than one.
        """
        if timeLimitSec is not None:
            if iterationLimit is not None:
                raise ValueError("Cannot have both a time limit and an iteration limit")
            # wall-clock budget for each MCTS search, in seconds
            self.timeLimit = timeLimitSec
            self.limitType = 'time'
        else:
            if iterationLimit is None:
                raise ValueError("Must have either a time limit or an iteration limit")
            if iterationLimit < 1:
                raise ValueError("Iteration limit must be at least one")
            # number of rounds for each MCTS search
            self.searchLimit = iterationLimit
            self.limitType = 'iterations'
        self.explorationConstant = explorationConstant
        self.rollout = rolloutPolicy

    def search(self, initialState):
        """Build a fresh tree from ``initialState`` and return the best action.

        The best action leads to the root child with the highest mean reward
        (exploration term disabled). The tree stays in ``self.root`` for
        inspection.

        NOTE: if ``initialState`` is terminal the root never gets children and
        ``getBestChild`` raises IndexError from ``random.choice``.
        """
        self.root = treeNode(initialState, None)
        if self.limitType == 'time':
            # monotonic clock: immune to system clock adjustments mid-search
            deadline = time.monotonic() + self.timeLimit
            # do-while: always run at least one round so the root has a child
            # to pick even when the budget is tiny or zero
            self.executeRound()
            while time.monotonic() < deadline:
                self.executeRound()
        else:
            for _ in range(self.searchLimit):
                self.executeRound()
        bestChild = self.getBestChild(self.root, 0)
        return self.getAction(self.root, bestChild)

    def executeRound(self):
        """Run one select/expand -> rollout -> backpropagate cycle."""
        node = self.selectNode(self.root)
        reward = self.rollout(node.state)
        self.backpropogate(node, reward)

    def selectNode(self, node):
        """Descend by best UCB1 child until reaching a terminal node or one with
        untried actions.

        In the latter case, expand it and return the new child.
        """
        while not node.isTerminal:
            if node.isFullyExpanded:
                node = self.getBestChild(node, self.explorationConstant)
            else:
                return self.expand(node)
        return node

    def expand(self, node):
        """Create and return a child for the first untried action of ``node``.

        ``node`` is marked fully expanded once every possible action has a child.

        Raises:
            Exception: if every action already has a child (callers only
                expand nodes that are not fully expanded).
        """
        # materialise once: getPossibleActions may return any iterable
        actions = list(node.state.getPossibleActions())
        for action in actions:
            if action not in node.children:
                newNode = treeNode(node.state.takeAction(action), node)
                node.children[action] = newNode
                # membership rather than a length comparison, so duplicate
                # actions cannot leave the node "not fully expanded" forever
                if all(a in node.children for a in actions):
                    node.isFullyExpanded = True
                return newNode
        raise Exception("Should never reach here")

    def backpropogate(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and every ancestor up to the root.

        (The method name keeps its original spelling for compatibility.)
        """
        while node is not None:
            node.numVisits += 1
            node.totalReward += reward
            node = node.parent

    def getBestChild(self, node, explorationValue):
        """Return the child of ``node`` with the highest UCB1 score.

        score = mean reward + explorationValue * sqrt(ln(parent visits) / child visits)

        Ties are broken uniformly at random. With ``explorationValue == 0`` this
        is the greedy highest-mean-reward choice.
        """
        bestValue = float("-inf")
        bestNodes = []
        for child in node.children.values():
            nodeValue = child.totalReward / child.numVisits + explorationValue * math.sqrt(
                math.log(node.numVisits) / child.numVisits)
            if nodeValue > bestValue:
                bestValue = nodeValue
                bestNodes = [child]
            elif nodeValue == bestValue:
                bestNodes.append(child)
        return random.choice(bestNodes)

    def getAction(self, root, bestChild):
        """Return the action whose child under ``root`` is ``bestChild``, or None if absent."""
        return next((action for action, child in root.children.items() if child is bestChild), None)
# --- UNIT TESTS ---
if __name__ == "__main__":
    class DummyState:
        """Small fixed game tree for a deterministic check of the search.

        s0 -a1-> s1 -a3-> s3 / -a4-> s4
        s0 -a2-> s2 -a5-> s5 / -a6-> s6
        s3..s6 are terminal. ``actions`` records the path taken to reach the
        state; ``state`` holds the state label.
        """
        # state label -> legal actions (terminal states are absent, so asking
        # them for actions raises KeyError; they are never expanded)
        _MOVES = {'s0': ('a1', 'a2'), 's1': ('a3', 'a4'), 's2': ('a5', 'a6')}
        # action -> resulting state label
        _NEXT = {'a1': 's1', 'a2': 's2', 'a3': 's3', 'a4': 's4', 'a5': 's5', 'a6': 's6'}
        _TERMINAL = ('s3', 's4', 's5', 's6')
        # reward per state label; anything unlisted scores 0
        _REWARDS = {'s1': 20, 's2': 10, 's3': 0, 's5': 14}

        def __init__(self, value='s0', actions=None):
            self.actions = actions if actions is not None else []
            self.state = value

        def getPossibleActions(self):
            """Return the list of actions available from this state."""
            return list(self._MOVES[self.state])

        def takeAction(self, action):
            """Return the successor state reached by ``action``."""
            successor = self._NEXT[action]
            return DummyState(value=successor, actions=self.actions + [action])

        def isTerminal(self):
            """Return whether this state ends the game."""
            return self.state in self._TERMINAL

        def getReward(self):
            """Return this state's reward (also used for non-terminal states here)."""
            return self._REWARDS.get(self.state, 0)

    def hardCodedPolicy(state):
        # Skip the rollout and score the freshly expanded node by its own
        # reward, so the expected statistics below can be worked out by hand.
        return state.getReward()

    # Four rounds with exploration constant 1:
    #   1. expand a1 -> s1 scores 20
    #   2. expand a2 -> s2 scores 10
    #   3. UCB picks s1, expand a3 -> s3 scores 0
    #   4. UCB picks s2, expand a5 -> s5 scores 14
    # Root children then average a1 = (20 + 0) / 2 = 10 and a2 = (10 + 14) / 2 = 12.
    searcher = mcts(iterationLimit=4, explorationConstant=1, rolloutPolicy=hardCodedPolicy)
    chosen = searcher.search(DummyState())
    assert chosen == 'a2'
    averages = [[action, child.totalReward / child.numVisits]
                for action, child in searcher.root.children.items()]
    assert averages[0] == ['a1', 10]
    assert averages[1] == ['a2', 12]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment