Created
April 19, 2020 14:49
-
-
Save gdemarcsek/0f917a14db5ae2e62859a71f67279023 to your computer and use it in GitHub Desktop.
Fun with graph algos
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from queue import PriorityQueue | |
from functools import total_ordering | |
# Some fun with graph algos I was playing with | |
class NoSolutionException(Exception):
    """Raised by ``Node.traverse_astar`` when no path to the target is found.

    This happens either when the open queue is exhausted (the target is not
    reachable from the start node) or when the caller-supplied iteration
    limit is exceeded.
    """


logger = logging.getLogger(__name__)


@total_ordering
class Node:
    """A labelled graph / game-tree node carrying a heuristic value ``h``.

    ``children`` is interpreted differently by the two algorithms:

    * ``alphabeta_minimax`` expects plain ``Node`` instances.
    * ``traverse_astar`` expects weighted edges given as
      ``(<node>, <weight>)`` tuples.

    Nodes hash by ``label``, so labels must be unique within a graph.
    Ordering (``<`` and the rest filled in by ``total_ordering``) compares
    ``h``; equality is still identity, inherited from ``object``.

    ``h`` may be a plain value or a callable that receives the node and
    returns its value; a callable is re-evaluated on every access.
    """

    def __init__(self, label, h=0, children=None):
        self._h = h
        # Give every node its own list. A shared ``[]`` default would make all
        # nodes created without explicit children alias one list object.
        self.children = [] if children is None else children
        self.label = label

    @property
    def h(self):
        """Heuristic value: ``_h(self)`` if ``_h`` is callable, else ``_h``."""
        if callable(self._h):
            return self._h(self)
        return self._h

    def alphabeta_minimax(self, depth, alpha=float('-inf'),
                          beta=float('inf'), is_max=True):
        """Return the minimax value of this node using alpha-beta pruning.

        Children are expected to be ``Node`` instances. Leaves (no children)
        and nodes at the depth cut-off are scored by ``h``.

        Args:
            depth: remaining search depth; 0 scores this node with ``h``.
            alpha: best value the maximizer can already guarantee.
            beta: best value the minimizer can already guarantee.
            is_max: True if this node is a maximizing node.

        Returns:
            The (pruned) minimax value of this node.
        """
        if depth == 0 or not self.children:
            return self.h
        if is_max:
            v = float('-inf')
            for child in self.children:
                v = max(v, child.alphabeta_minimax(depth - 1, alpha, beta, False))
                alpha = max(alpha, v)
                if alpha >= beta:
                    logger.debug("alpha-cutoff at %s -> %s", self, child)
                    break
            return v
        v = float('inf')
        for child in self.children:
            v = min(v, child.alphabeta_minimax(depth - 1, alpha, beta, True))
            beta = min(beta, v)
            if alpha >= beta:
                logger.debug("beta-cutoff at node %s -> %s", self, child)
                break
        return v

    def traverse_astar(self, target, iter_limit=None):
        """Search from this node to ``target`` with A*.

        Children must be ``(<node>, <weight>)`` tuples. The queue priority is
        ``cost-so-far + child.h``. As usual for A*, the returned cost to the
        target is only guaranteed optimal when ``h`` is admissible.

        Args:
            target: goal node, compared with ``==`` (identity unless
                overridden).
            iter_limit: maximum number of queue pops before giving up;
                ``None`` means no limit.

        Returns:
            ``(path, costs)``. ``path`` maps each discovered node to its
            predecessor on the best known path (``None`` for the start node).
            ``costs`` maps each discovered node to its best known cost from
            the start.

        Raises:
            NoSolutionException: if the queue empties before ``target`` is
                popped, or if ``iter_limit`` is exceeded.
        """
        open_queue = PriorityQueue()
        costs = {self: 0}
        path = {self: None}
        open_queue.put((0, self))
        iteration = 0
        logger.debug("A* started for: %s ~~> %s", self, target)
        while True:
            iteration += 1
            # ``is not None`` so that an explicit limit of 0 is honoured
            # instead of being treated as "no limit".
            if open_queue.empty() or (iter_limit is not None
                                      and iteration > iter_limit):
                raise NoSolutionException()
            _, n = open_queue.get()
            logger.debug("next node is chosen to be: %s", n)
            if n == target:
                return path, costs
            # NOTE: superseded (stale) queue entries are not skipped. Re-expanding
            # them relaxes edges with the already-improved costs[n], so no
            # entry changes.
            for child_node, child_weight in n.children:
                new_cost = costs[n] + child_weight
                if child_node not in costs or new_cost < costs[child_node]:
                    path[child_node] = n
                    costs[child_node] = new_cost
                    open_queue.put((new_cost + child_node.h, child_node))
                    logger.debug("added to queue: %s", child_node)

    def __repr__(self):
        # str() so that non-string labels (e.g. ints) do not make repr raise.
        return str(self.label)

    def __hash__(self):
        return hash(self.label)

    def __lt__(self, other):
        # Order nodes by heuristic value; this also breaks ties between equal
        # priorities in the A* queue.
        return self.h < other.h
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment