Skip to content

Instantly share code, notes, and snippets.

@thunderInfy
Created February 23, 2022 14:14
Show Gist options
  • Save thunderInfy/eb987f297e87d8f9ef45a22ec3510625 to your computer and use it in GitHub Desktop.
Save thunderInfy/eb987f297e87d8f9ef45a22ec3510625 to your computer and use it in GitHub Desktop.
class Node:
    """A node of the MCTS search tree for a two-player ('red' / 'green') board game.

    A node is created either with a concrete game ``state`` or with
    ``state=None``. A ``None`` state marks a lazily-initialised child whose
    state has not been filled in yet. For a concrete state, the constructor
    computes which board cells the player to move may click. It then creates
    an empty child node for every such cell. The network's policy and value
    for the state are computed on first request and cached.

    Attributes:
        state: dict describing the position. The ``'player_turn'`` and
            ``'array_view'`` keys are read here. None if not yet filled.
        model: neural network used to evaluate ``state``.
        W: total reward accumulated through this node.
        N: number of playouts through this node.
        value: cached network value estimate (None until computed).
        policy: cached network policy, masked to valid actions and
            normalised (None until computed).
        valid_actions / invalid_actions: boolean masks over ``array_view``
            (None while ``state`` is None).
        children: dict mapping ``(row, col)`` action tuples to child nodes
            (empty while ``state`` is None).
        win: None if the state is not terminal, otherwise ``'red'`` or
            ``'green'`` naming the winner.
    """

    def __init__(self, state, model):
        # saves state as a dictionary
        self.state = state
        # needs access to the neural network model
        self.model = model
        # W is the total reward and N is the number of playouts
        self.W = 0
        self.N = 0
        self.value = None
        self.policy = None
        # Declare every attribute up front. A lazily-initialised node
        # (state=None) then has the same attributes as an expanded one,
        # instead of raising AttributeError when inspected.
        self.valid_actions = None
        self.invalid_actions = None
        self.children = {}
        # sets self.valid_actions / self.invalid_actions (no-op if state is None)
        self.set_action_validity()
        # for every valid action create a child node without a state
        # (lazy initialisation; presumably filled in by the search code)
        self.initialize_edges()
        # None if it's not a terminal state, otherwise 'red' or 'green'
        # indicating the winner of that terminal state
        self.win = None

    def initialize_edges(self):
        """Create an unexpanded child node for every valid action.

        ``self.children`` is rebuilt as a dict keyed by ``(row, col)`` tuples
        of plain ints, in row-major order. The board size comes from the shape
        of ``self.valid_actions`` rather than from global configuration.
        Does nothing when ``state`` is None.
        """
        if self.state is None:
            return
        n_rows = self.valid_actions.shape[0]
        n_cols = self.valid_actions.shape[1]
        self.children = {}
        for row in range(n_rows):
            for col in range(n_cols):
                if self.valid_actions[row][col]:
                    self.children[(row, col)] = Node(None, self.model)

    def set_action_validity(self):
        """Compute boolean masks of the cells the player to move may click.

        A player can click anywhere on the board except cells holding the
        opponent's orbs:
        - when it is 'red''s turn, cells with negative ``array_view`` values
          are invalid;
        - otherwise, cells with positive values are invalid.

        Sets ``self.invalid_actions`` and its complement
        ``self.valid_actions``. Does nothing when ``state`` is None.
        """
        if self.state is None:
            return
        board = self.state['array_view']
        if self.state['player_turn'] == 'red':
            self.invalid_actions = board < 0
        else:
            self.invalid_actions = board > 0
        self.valid_actions = ~self.invalid_actions

    def make_forward_pass(self):
        """Evaluate ``state`` with the network and cache policy and value.

        The policy is used for tree traversal. The value is used as an
        alternative to Monte Carlo rollouts.

        Invalid actions get probability 0 and the policy is renormalised. If
        the masked policy sums to zero, it falls back to uniform over the
        valid actions. This can happen when low-magnitude values are treated
        as zero.

        If there are no valid actions at all, the policy is left as all
        zeros. Dividing would otherwise be a 0/0 and fill it with NaN.
        """
        with torch.no_grad():
            out = self.model(
                self.model.state_array_view_to_tensor(self.state)
            )
        # NOTE(review): assumes out['policy'][0] has the same shape as
        # array_view, so the boolean masks index it directly — confirm.
        self.policy = out['policy'][0].cpu().numpy()
        self.value = out['value'].cpu().item()
        self.policy[self.invalid_actions] = 0
        # handling rare case where sum becomes zero
        if self.policy.sum() == 0:
            self.policy[self.valid_actions] = 1
        total = self.policy.sum()
        # total is still 0 only when no action is valid; avoid 0/0 -> NaN
        if total != 0:
            self.policy /= total

    def get_policy(self):
        """Return the cached policy, running the network on first use."""
        if self.policy is None:
            self.make_forward_pass()
        return self.policy

    def get_value(self):
        """Return the cached value estimate, running the network on first use."""
        if self.value is None:
            self.make_forward_pass()
        return self.value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment