Skip to content

Instantly share code, notes, and snippets.

@ebobby
Last active March 30, 2020 21:08
Show Gist options
  • Save ebobby/2ad575279d59f8d5529c70268e4182d5 to your computer and use it in GitHub Desktop.
Save ebobby/2ad575279d59f8d5529c70268e4182d5 to your computer and use it in GitHub Desktop.
import heapq
class PriorityQueue:
    """Min-priority queue built on the ``heapq`` module.

    Entries are kept as ``(priority, item)`` pairs in a heap-ordered list, so
    the entry with the smallest priority is returned first.  When two
    priorities are equal, the items themselves are compared.
    """

    def __init__(self):
        # Heap-ordered list of (priority, item) pairs.
        self.elements = []

    def empty(self):
        """Return True when the queue holds no entries."""
        return not self.elements

    def put(self, item, priority):
        """Insert *item* with *priority* (smaller values come out first)."""
        entry = (priority, item)
        heapq.heappush(self.elements, entry)

    def get(self):
        """Remove and return the item with the smallest priority.

        Raises IndexError (from ``heapq.heappop``) when the queue is empty.
        """
        _priority, item = heapq.heappop(self.elements)
        return item
class ChainLinkGraph:
    """State graph for joining chain sections into one chain.

    A state is a tuple sorted in descending order.  Each positive entry is
    the length (in links) of a chain section, and each ``0`` is a loose open
    link.  Opening a link costs ``OPEN_COST`` and closing an open link costs
    ``CLOSE_COST``.  A single-entry state such as ``(12,)`` is treated as the
    finished chain.
    """

    OPEN_COST = 2
    CLOSE_COST = 3

    def neighbors(self, n, gn):
        """Return all possible moves from ``n``.

        Args:
            n: current state (descending tuple; ``0`` marks an open link).
            gn: accumulated cost of reaching ``n``.

        Returns:
            A list ``[(possible_n, total_cost), ...]`` with no duplicate
            states, where ``total_cost`` is ``gn`` plus the cost of the move.
        """
        if len(n) == 1:
            if n[0] > 0:
                # Already a single chain: only the trivial self-move remains.
                return [(n, gn)]
            # A lone open link: close it into a one-link chain.  The new
            # state must be the 1-tuple ``(1,)``; a bare ``(1)`` is the int 1.
            return [((1,), gn + self.CLOSE_COST)]

        seen = set()
        result = []
        for i, section in enumerate(n):
            new_n = list(n)
            if section > 1:
                # Open an end link: the section shrinks by one and a loose
                # open link (0) appears.
                new_n[i:i + 1] = section - 1, 0
                new_gn = gn + self.OPEN_COST
            elif section == 1:
                # Open a one-link section, leaving only a loose open link.
                new_n[i] = 0
                new_gn = gn + self.OPEN_COST
            else:
                # section == 0: close this open link somewhere.
                new_gn = gn + self.CLOSE_COST
                if new_n[0] == 0:
                    # The state is sorted descending, so everything is an
                    # open link: close this one into a one-link section.
                    new_n[i] = 1
                elif len(new_n) == 2:
                    # One section plus this link: closing it yields the
                    # single-entry (finished) state.
                    new_n = [new_n[0] + 1]
                elif new_n[1] == 0:
                    # One section plus several open links: attach this link
                    # to the section.
                    new_n[0] += 1
                    new_n.pop(i)
                elif len(new_n) == 3:
                    # Two sections and a single open link: the model treats
                    # this move as invalid (presumably because joining them
                    # would leave no open link to close the chain).
                    continue
                else:
                    # Use this link to join the two largest sections.
                    new_n.pop(i)
                    new_n[0:2] = (new_n[0] + new_n[1] + 1,)
            state = tuple(sorted(new_n, reverse=True))
            if state not in seen:
                seen.add(state)
                result.append((state, new_gn))
        return result

    def heuristic(self, n):
        """Return the estimated cost of moving from ``n`` to a closed chain.

        Opening a link costs ``OPEN_COST`` (2) and closing an open link costs
        ``CLOSE_COST`` (3).  Every open link, whether it already exists or is
        opened later, must eventually be closed.  Joining ``s`` sections needs
        at least ``s`` open links.  When there are too few, links are opened
        from the smallest sections first, because fully opening a section
        also removes it from the sections that need joining.  This computes
        the minimum number of opens, so the estimate never exceeds the real
        remaining cost.
        """
        if len(n) == 1:
            return 0
        we_have = n.count(0)
        sections = [s for s in n if s > 0]
        we_need = len(sections)
        total = 0
        # ``n`` is sorted descending, so reversing visits the smallest first.
        for size in reversed(sections):
            diff = we_need - we_have
            if diff <= 0:
                break
            if diff > size:
                # Open the whole section: it becomes ``size`` loose links and
                # no longer needs joining.
                total += size * self.OPEN_COST
                we_have += size
                we_need -= 1
            else:
                # Opening ``diff`` links is enough; if that empties the
                # section, it also drops out of the sections to join.
                total += diff * self.OPEN_COST
                we_have += diff
                if diff == size:
                    we_need -= 1
        # Every open link must be closed.  Using ``we_need`` here would
        # undercount by one link whenever a section was opened exactly to
        # exhaustion above.
        return total + we_have * self.CLOSE_COST
def a_star_search(graph, start, goal):
    """Run A* from ``start`` until ``goal`` is popped or the frontier empties.

    Args:
        graph: object providing ``neighbors(state, cost)``, which returns
            ``(next_state, total_cost)`` pairs, and ``heuristic(state)``,
            which returns an estimate of the remaining cost.
        start: initial state (must be hashable).
        goal: state at which the search stops.

    Returns:
        ``(came_from, cost_so_far)``: dicts mapping every discovered state to
        its best-known predecessor (``None`` for ``start``) and to the
        cheapest cost found for it.  When ``goal`` was reached, the path can
        be rebuilt by following ``came_from`` back from ``goal``.
    """
    frontier = PriorityQueue()
    frontier.put(start, 0)
    cost_so_far = {start: 0}
    came_from = {start: None}

    while not frontier.empty():
        current = frontier.get()
        if current == goal:
            break
        # ``neighbor`` rather than ``next`` avoids shadowing the builtin.
        for neighbor, new_cost in graph.neighbors(current, cost_so_far[current]):
            # A state may be pushed again when a cheaper route is found.  The
            # older heap entry is left in place and, if popped, is expanded
            # with the best known cost.
            if neighbor not in cost_so_far or new_cost < cost_so_far[neighbor]:
                cost_so_far[neighbor] = new_cost
                came_from[neighbor] = current
                frontier.put(neighbor, new_cost + graph.heuristic(neighbor))
    return came_from, cost_so_far
# Demo: join four 3-link sections into a single 12-link chain.
graph = ChainLinkGraph()

if __name__ == "__main__":
    # Run the search only when executed as a script, and report the result
    # instead of discarding it.
    start, goal = (3, 3, 3, 3), (12,)
    came_from, cost_so_far = a_star_search(graph, start, goal)
    if goal in cost_so_far:
        # Walk the predecessor links back from the goal to rebuild the path.
        path = [goal]
        while came_from[path[-1]] is not None:
            path.append(came_from[path[-1]])
        path.reverse()
        print("cost:", cost_so_far[goal])
        print("path:", " -> ".join(map(str, path)))
    else:
        print("goal not reachable from", start)
# In [14]: a_star_search(graph, (3,3,3,3), (12,))
# Out[14]:
# ({(3, 3, 3, 3): None,
# (3, 3, 3, 2, 0): (3, 3, 3, 3),
# (3, 3, 2, 2, 0, 0): (3, 3, 3, 2, 0),
# (3, 3, 3, 1, 0, 0): (3, 3, 3, 2, 0),
# (7, 3, 2): (3, 3, 3, 2, 0),
# (3, 3, 2, 1, 0, 0, 0): (3, 3, 3, 1, 0, 0),
# (3, 3, 3, 0, 0, 0): (3, 3, 3, 1, 0, 0),
# (7, 3, 1, 0): (3, 3, 3, 1, 0, 0),
# (3, 3, 2, 0, 0, 0, 0): (3, 3, 3, 0, 0, 0),
# (7, 3, 0, 0): (3, 3, 3, 0, 0, 0),
# (6, 3, 0, 0, 0): (7, 3, 0, 0),
# (7, 2, 0, 0, 0): (7, 3, 0, 0),
# (11, 0): (7, 3, 0, 0),
# (6, 3, 1, 0, 0): (7, 3, 1, 0),
# (7, 2, 1, 0, 0): (7, 3, 1, 0),
# (11, 1): (7, 3, 1, 0),
# (6, 3, 2, 0): (7, 3, 2),
# (7, 2, 2, 0): (7, 3, 2),
# (10, 0, 0): (11, 0),
# (12,): (11, 0),
# (10, 1, 0): (11, 1)},
# {(3, 3, 3, 3): 0,
# (3, 3, 3, 2, 0): 2,
# (3, 3, 2, 2, 0, 0): 4,
# (3, 3, 3, 1, 0, 0): 4,
# (7, 3, 2): 5,
# (3, 3, 2, 1, 0, 0, 0): 6,
# (3, 3, 3, 0, 0, 0): 6,
# (7, 3, 1, 0): 7,
# (3, 3, 2, 0, 0, 0, 0): 8,
# (7, 3, 0, 0): 9,
# (6, 3, 0, 0, 0): 11,
# (7, 2, 0, 0, 0): 11,
# (11, 0): 12,
# (6, 3, 1, 0, 0): 9,
# (7, 2, 1, 0, 0): 9,
# (11, 1): 10,
# (6, 3, 2, 0): 7,
# (7, 2, 2, 0): 7,
# (10, 0, 0): 14,
# (12,): 15, <-- shortest path: the chain is joined at a cost of 15
# (10, 1, 0): 12})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment