Skip to content

Instantly share code, notes, and snippets.

@gdemarcsek
Created April 19, 2020 14:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gdemarcsek/0f917a14db5ae2e62859a71f67279023 to your computer and use it in GitHub Desktop.
Save gdemarcsek/0f917a14db5ae2e62859a71f67279023 to your computer and use it in GitHub Desktop.
Fun with graph algos
import logging
from queue import PriorityQueue
from functools import total_ordering
# Some fun with graph algos I was playing with
class NoSolutionException(Exception):
    """Raised when a search gives up without reaching its target.

    ``Node.traverse_astar`` raises it when the open queue runs dry or the
    iteration limit is exceeded.
    """
# Module-level logger; the search routines emit their trace at DEBUG level.
logger = logging.getLogger(name=__name__)
@total_ordering
class Node:
    """A graph / game-tree node usable for A* search and alpha-beta minimax.

    The expected shape of ``children`` depends on the algorithm used:

    * ``traverse_astar`` reads weighted edges as ``(node, weight)`` tuples.
    * ``alphabeta_minimax`` reads plain ``Node`` instances.

    Labels must be unique within a graph: they drive ``__hash__`` and
    ``__repr__``. Equality is identity (inherited from ``object``), while
    ordering compares only the heuristic ``h``. That ordering breaks ties
    between equal priorities in the A* open queue.
    """

    def __init__(self, label, h=0, children=None):
        """Create a node.

        Args:
            label: Identifier of the node; must be unique within a graph.
            h: Heuristic value, or a callable ``h(node)`` that is evaluated
                on every access of the ``h`` property.
            children: Outgoing edges (see the class docstring for their
                shape). Defaults to a new, empty list for each node.
        """
        self._h = h
        # A fresh list per node: a shared ``[]`` default would make every
        # default-constructed node share (and mutate) the same children list.
        self.children = [] if children is None else children
        self.label = label

    @property
    def h(self):
        """Heuristic value; calls ``self._h(self)`` when ``_h`` is callable."""
        if callable(self._h):
            return self._h(self)
        return self._h

    def alphabeta_minimax(self, depth, alpha, beta, is_max):
        """Return the minimax value of this node using alpha-beta pruning.

        Nodes at ``depth == 0`` and nodes without children evaluate to ``h``.
        Here ``children`` must be ``Node`` instances (not weighted tuples).

        Args:
            depth: Remaining plies to search.
            alpha: Lower bound the maximizer is already guaranteed.
            beta: Upper bound the minimizer is already guaranteed.
            is_max: True if this node is a maximizing node.

        Returns:
            The (possibly pruned) minimax value of this node.
        """
        if depth == 0 or not self.children:
            return self.h
        if is_max:
            value = float('-inf')
            for child in self.children:
                value = max(value, child.alphabeta_minimax(depth - 1, alpha, beta, False))
                alpha = max(alpha, value)
                if alpha >= beta:
                    # Remaining siblings cannot change the parent's choice.
                    logger.debug("alpha-cutoff at %s -> %s", self, child)
                    break
            return value
        value = float('inf')
        for child in self.children:
            value = min(value, child.alphabeta_minimax(depth - 1, alpha, beta, True))
            beta = min(beta, value)
            if alpha >= beta:
                # Remaining siblings cannot change the parent's choice.
                logger.debug("beta-cutoff at node %s -> %s", self, child)
                break
        return value

    def traverse_astar(self, target, iter_limit=None):
        """Search from this node towards ``target`` with A*.

        Edges are read from ``children`` as ``(node, weight)`` tuples. A
        node's queue priority is its path cost plus its heuristic ``h``.

        Args:
            target: Node to reach; compared with ``==`` (identity by default).
            iter_limit: Maximum number of nodes popped from the open queue
                before giving up. ``None`` means unlimited; ``0`` is a real
                limit, not "unlimited".

        Returns:
            ``(path, costs)``: ``path`` maps each discovered node to its
            predecessor on the cheapest route found so far (``None`` for the
            start node). ``costs`` maps each discovered node to that route's
            cost.

        Raises:
            NoSolutionException: if the open queue empties before ``target``
                is popped, or ``iter_limit`` is exceeded.
        """
        open_queue = PriorityQueue()
        costs = {self: 0}
        path = {self: None}
        open_queue.put((0, self))
        iteration = 0
        logger.debug("A* started for: %s ~~> %s", self, target)
        while True:
            iteration += 1
            # ``is not None`` so that iter_limit=0 is honoured instead of being
            # mistaken for "no limit" by a truthiness test.
            if open_queue.empty() or (iter_limit is not None and iteration > iter_limit):
                raise NoSolutionException()
            _, current = open_queue.get()
            logger.debug("next node is chosen to be: %s", current)
            if current == target:
                return path, costs
            for child_node, child_weight in current.children:
                new_cost = costs[current] + child_weight
                if child_node not in costs or new_cost < costs[child_node]:
                    path[child_node] = current
                    costs[child_node] = new_cost
                    open_queue.put((new_cost + child_node.h, child_node))
                    logger.debug("added to queue: %s", child_node)

    def __repr__(self):
        # str() so that non-string labels (e.g. ints) do not break repr/logging.
        return str(self.label)

    def __hash__(self):
        return hash(self.label)

    def __lt__(self, other):
        """Order nodes by heuristic value only (used for queue tie-breaking)."""
        return self.h < other.h
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment