Last active
December 19, 2020 23:15
-
-
Save panchishin/d71f098cbb2fd059d3c0c2cd63a634d0 to your computer and use it in GitHub Desktop.
MCTS in Python (from pypi.org/project/mcts/), with a unit-test example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import math | |
import random | |
def randomPolicy(state):
    """Play uniformly random actions from ``state`` until a terminal state is reached.

    This is the default rollout policy of the ``mcts`` class.

    Args:
        state: An object exposing ``isTerminal()``, ``getPossibleActions()``,
            ``takeAction(action)`` and ``getReward()``.

    Returns:
        Whatever ``getReward()`` returns for the terminal state reached.

    Raises:
        Exception: If a non-terminal state reports no possible actions
            (the underlying ``IndexError`` is chained as the cause).
    """
    while not state.isTerminal():
        # getPossibleActions() may return any iterable (set, generator, ...);
        # random.choice needs an indexable sequence, so materialize it first.
        actions = list(state.getPossibleActions())
        try:
            action = random.choice(actions)
        except IndexError as err:
            raise Exception(
                "Non-terminal state has no possible actions: " + str(state)
            ) from err
        state = state.takeAction(action)
    return state.getReward()
class treeNode:
    """One node of the MCTS search tree.

    A node wraps a game ``state`` together with the statistics gathered for it
    during search: how many times it has been visited and the sum of the rewards
    back-propagated through it. ``children`` starts empty and is filled by the
    search, keyed by the action that leads from this node to each child.
    """

    # Fixed attribute set: instances carry no per-instance __dict__.
    __slots__ = [
        'state',
        'isTerminal',
        'isFullyExpanded',
        'parent',
        'numVisits',
        'totalReward',
        'children',
    ]

    def __init__(self, state, parent):
        """Create an unvisited node for ``state`` hanging below ``parent``.

        ``parent`` is ``None`` for the root of the tree.
        """
        self.state = state
        terminal = state.isTerminal()
        self.isTerminal = terminal
        # A terminal node has nothing left to expand.
        self.isFullyExpanded = terminal
        self.parent = parent
        self.numVisits, self.totalReward = 0, 0
        self.children = {}
class mcts:
    """Monte Carlo Tree Search driver using the UCT selection rule.

    Each search is bounded either by wall-clock time (``timeLimitSec``) or by a
    fixed number of rounds (``iterationLimit``); exactly one must be given.

    The search itself calls ``isTerminal()``, ``getPossibleActions()`` and
    ``takeAction(action)`` on states; the rollout policy decides what else it
    needs (the default ``randomPolicy`` also calls ``getReward()``). Actions
    must be hashable because they are used as dictionary keys.
    """

    __slots__ = ['timeLimit', 'limitType', 'searchLimit', 'explorationConstant', 'rollout', 'root']

    def __init__(self, timeLimitSec=None, iterationLimit=None, explorationConstant=math.sqrt(2),
                 rolloutPolicy=randomPolicy):
        """Configure the search budget and policies.

        Args:
            timeLimitSec: Wall-clock budget for each ``search`` call, in seconds.
            iterationLimit: Number of select/expand/rollout/backpropagate rounds
                for each ``search`` call; must be at least one.
            explorationConstant: Weight of the UCT exploration term.
            rolloutPolicy: Callable ``f(state) -> reward`` used to score the node
                reached in each round.

        Raises:
            ValueError: If both or neither limits are given, or if
                ``iterationLimit`` is less than one.
        """
        if timeLimitSec is not None:
            if iterationLimit is not None:
                raise ValueError("Cannot have both a time limit and an iteration limit")
            # Wall-clock budget for each MCTS search, in seconds.
            self.timeLimit = timeLimitSec
            self.limitType = 'time'
        else:
            if iterationLimit is None:
                raise ValueError("Must have either a time limit or an iteration limit")
            if iterationLimit < 1:
                raise ValueError("Iteration limit must be at least one")
            # Number of rounds executed by each search.
            self.searchLimit = iterationLimit
            self.limitType = 'iterations'
        self.explorationConstant = explorationConstant
        self.rollout = rolloutPolicy

    def search(self, initialState):
        """Run MCTS from ``initialState`` and return the most promising action.

        The tree is rebuilt on every call and kept in ``self.root`` so callers
        can inspect the gathered statistics afterwards. In time-limited mode at
        least one round is always executed, so a non-terminal ``initialState``
        yields an action even with a tiny (or zero) budget.

        Returns:
            The key of ``self.root.children`` whose child has the highest mean
            reward.

        Raises:
            IndexError: If the root has no children after the search, e.g. when
                ``initialState`` is terminal.
        """
        self.root = treeNode(initialState, None)
        if self.limitType == 'time':
            # monotonic() is not affected by system clock adjustments, unlike time().
            deadline = time.monotonic() + self.timeLimit
            while True:
                self.executeRound()
                if time.monotonic() >= deadline:
                    break
        else:
            for _ in range(self.searchLimit):
                self.executeRound()
        # Exploration weight 0: choose purely by mean reward.
        bestChild = self.getBestChild(self.root, 0)
        return self.getAction(self.root, bestChild)

    def executeRound(self):
        """Perform one select/expand -> rollout -> backpropagate cycle."""
        node = self.selectNode(self.root)
        reward = self.rollout(node.state)
        self.backpropogate(node, reward)

    def selectNode(self, node):
        """Descend from ``node`` via UCT until a node to evaluate is found.

        Returns a freshly expanded child as soon as a node on the path still has
        untried actions; otherwise returns the terminal node that was reached.
        """
        while not node.isTerminal:
            if not node.isFullyExpanded:
                return self.expand(node)
            node = self.getBestChild(node, self.explorationConstant)
        return node

    def expand(self, node):
        """Create a child of ``node`` for its first untried action and return it.

        ``node`` is marked fully expanded once every distinct possible action has
        a child.

        Raises:
            Exception: If ``node`` has no untried actions, e.g. a non-terminal
                state that reports no possible actions.
        """
        # dict.fromkeys de-duplicates while preserving order, so repeated actions
        # cannot leave the node permanently "not fully expanded"; it also accepts
        # any iterable (sets, generators) returned by getPossibleActions().
        untried = [action for action in dict.fromkeys(node.state.getPossibleActions())
                   if action not in node.children]
        if not untried:
            raise Exception("Non-terminal state has no untried actions: " + str(node.state))
        action = untried[0]
        newNode = treeNode(node.state.takeAction(action), node)
        node.children[action] = newNode
        if len(untried) == 1:
            # The action just expanded was the last untried one.
            node.isFullyExpanded = True
        return newNode

    def backpropogate(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and every ancestor up to the root."""
        while node is not None:
            node.numVisits += 1
            node.totalReward += reward
            node = node.parent

    def getBestChild(self, node, explorationValue):
        """Return the child of ``node`` with the highest UCT score.

        score = mean reward + explorationValue * sqrt(ln(parent visits) / child visits)

        Ties are broken uniformly at random. Raises ``IndexError`` if ``node``
        has no children.
        """
        bestValue = float("-inf")
        bestNodes = []
        for child in node.children.values():
            nodeValue = child.totalReward / child.numVisits + explorationValue * math.sqrt(
                math.log(node.numVisits) / child.numVisits)
            if nodeValue > bestValue:
                bestValue = nodeValue
                bestNodes = [child]
            elif nodeValue == bestValue:
                bestNodes.append(child)
        return random.choice(bestNodes)

    def getAction(self, root, bestChild):
        """Return the action under which ``bestChild`` is stored in ``root.children``.

        Returns ``None`` if ``bestChild`` is not a child of ``root``.
        """
        for action, node in root.children.items():
            if node is bestChild:
                return action
        return None
# --- UNIT TESTS ---
if __name__ == "__main__":

    class DummyState:
        """Tiny fixed game tree used to check the search deterministically.

        s0 -a1-> s1 -a3-> s3 | s1 -a4-> s4
        s0 -a2-> s2 -a5-> s5 | s2 -a6-> s6
        """

        # Actions available from each non-terminal state.
        _MOVES = {
            's0': ('a1', 'a2'),
            's1': ('a3', 'a4'),
            's2': ('a5', 'a6'),
        }
        # State reached by each action.
        _NEXT = {
            'a1': 's1', 'a2': 's2',
            'a3': 's3', 'a4': 's4',
            'a5': 's5', 'a6': 's6',
        }
        _TERMINALS = frozenset(('s3', 's4', 's5', 's6'))
        # Rewards are also defined for the inner states s1/s2 because the
        # hard-coded rollout policy below scores states without playing them out.
        # States not listed score 0.
        _REWARDS = {'s1': 20, 's2': 10, 's3': 0, 's5': 14}

        def __init__(self, value='s0', actions=None):
            # History of actions taken to reach this state.
            self.actions = actions if actions is not None else []
            self.state = value

        def getPossibleActions(self):
            """Return a new list of the actions available from this state."""
            return list(self._MOVES[self.state])

        def takeAction(self, action):
            """Return the state reached by ``action``, extending the action history."""
            nextValue = self._NEXT[action]
            return DummyState(value=nextValue, actions=self.actions + [action])

        def isTerminal(self):
            """Return True for the leaf states s3..s6."""
            return self.state in self._TERMINALS

        def getReward(self):
            """Return the fixed reward of this state (0 when unlisted)."""
            return self._REWARDS.get(self.state, 0)

    def hardCodedPolicy(state):
        """Rollout policy that scores a state directly, without simulating."""
        return state.getReward()

    searcher = mcts(iterationLimit=4, explorationConstant=1, rolloutPolicy=hardCodedPolicy)
    assert searcher.search(DummyState()) == 'a2'

    # Four rounds expand a1 (reward 20), a2 (10), then a3 below a1 (0) and
    # a5 below a2 (14), giving mean rewards of 10 for a1 and 12 for a2.
    meanRewards = [
        (action, child.totalReward / child.numVisits)
        for action, child in searcher.root.children.items()
    ]
    assert meanRewards[0] == ('a1', 10)
    assert meanRewards[1] == ('a2', 12)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment