Skip to content

Instantly share code, notes, and snippets.

@afiodorov
Last active December 17, 2023 14:26
Show Gist options
  • Save afiodorov/9002ae0c012e77c4a60b53bed7181673 to your computer and use it in GitHub Desktop.
Save afiodorov/9002ae0c012e77c4a60b53bed7181673 to your computer and use it in GitHub Desktop.
Advent of Code 2023, day 17: A* search
import pandas as pd
from queue import PriorityQueue
from dataclasses import dataclass
from functools import cache
# Heat-loss map: every input line becomes one row and every digit one integer
# column, so a cell is read as grid.iloc[y, x] (y = row, x = column).
grid = pd.DataFrame(
    [list(text) for text in pd.read_csv("./data/17.txt", names=[0], dtype='str')[0]]
).astype(int)

# Unit moves as (dx, dy), where x is the column and y the row:
directions = [
    (0, 1),   # 0: one row down (y + 1)
    (1, 0),   # 1: one column right (x + 1)
    (-1, 0),  # 2: one column left (x - 1)
    (0, -1),  # 3: one row up (y - 1)
]
@dataclass(order=True, frozen=True)
class NodeState:
    """Search state: which cell we are on, our heading, and how long we have held it.

    Frozen (hence hashable) so states can key the search's cost/parent dicts;
    ordered so ``(priority, state)`` queue entries with equal priority still
    compare.
    """

    # (x, y) cell coordinates: x is the column, y is the row.
    position: tuple
    # Index into the module-level ``directions`` list for the current heading.
    direction: int
    # Consecutive cells travelled in ``direction`` so far.
    step_count: int
@cache
def enhanced_heuristic(position, goal):
    """Estimate the heat loss still to pay on the way from ``position``.

    Takes the sub-grid from ``position`` down to the bottom-right corner and
    returns the smaller of two sums: the cheapest cell of each column, and the
    cheapest cell of each row. ``goal`` is accepted but not read; the sub-grid
    always ends at the bottom-right cell, which is the goal ``astar`` searches
    for. Results are memoised per ``(position, goal)`` by ``functools.cache``.
    """
    col, row = position
    remaining = grid.iloc[row:, col:]
    by_column = remaining.min(axis=0).sum()  # one cheapest cell per column
    by_row = remaining.min(axis=1).sum()  # one cheapest cell per row
    return by_column if by_column <= by_row else by_row
def line(pos0, pos1):
    """Yield the (x, y) cells of a straight segment, excluding pos0 and including pos1.

    The segment must be vertical or horizontal; identical endpoints yield
    nothing.

    Raises:
        ValueError: (on iteration) if pos0 and pos1 share neither x nor y.
            Previously such a segment silently yielded no cells, so a caller
            summing cell costs over it would have priced the move at zero.
    """
    x0, y0 = pos0
    x1, y1 = pos1
    if x0 == x1:  # Vertical line (also covers pos0 == pos1 -> empty range)
        step = 1 if y1 > y0 else -1
        for y in range(y0 + step, y1 + step, step):
            yield (x0, y)
    elif y0 == y1:  # Horizontal line
        step = 1 if x1 > x0 else -1
        for x in range(x0 + step, x1 + step, step):
            yield (x, y0)
    else:
        raise ValueError(
            f"line() needs an axis-aligned segment, got {pos0!r} -> {pos1!r}"
        )
def get_neighbors_with_restriction(node_state):
    """Yield the part-1 successor states of ``node_state``.

    The crucible moves one cell at a time. It may not reverse its heading,
    and it may continue straight only while it has made fewer than three
    consecutive moves in that direction. Moves that would leave ``grid``
    are dropped.
    """
    n_rows, n_cols = grid.shape
    heading = node_state.direction
    hx, hy = directions[heading]
    cx, cy = node_state.position
    for idx, step in enumerate(directions):
        dx, dy = step
        straight = idx == heading
        if straight and node_state.step_count >= 3:
            continue  # a fourth straight move is not allowed
        if (dx, dy) == (-hx, -hy):
            continue  # no turning back
        nx, ny = cx + dx, cy + dy
        if not (0 <= nx < n_cols and 0 <= ny < n_rows):
            continue  # off the map
        yield NodeState(
            position=(nx, ny),
            direction=idx,
            step_count=node_state.step_count + 1 if straight else 1,
        )
# In[2]:
def astar(get_neighbors_with_restriction=get_neighbors_with_restriction):
    """Find the cheapest route from the top-left to the bottom-right cell of ``grid``.

    Args:
        get_neighbors_with_restriction: callable mapping a ``NodeState`` to an
            iterable of successor ``NodeState``s; it encodes the movement
            rules (part 1 by default, ``part2`` for part 2). A move costs the
            sum of the grid cells returned by ``line(current, successor)``.

    Returns:
        ``(cost, came_from, final_state)``. ``cost`` is the heat loss of the
        first goal state popped. ``came_from`` maps every reached state to its
        predecessor (``None`` for the start states). ``final_state`` is that
        goal state.

    Raises:
        ValueError: if no goal state is reachable under the movement rules.
    """
    start = (0, 0)
    goal = (grid.shape[1] - 1, grid.shape[0] - 1)
    # Positional array lookup: avoids DataFrame.iloc overhead once per cell.
    heat = grid.to_numpy()
    frontier = PriorityQueue()
    # Both start states use step_count 0 because no cell has been travelled
    # yet. With 1, part2 would treat the start as already one cell into a run
    # and jump only three cells before allowing the first turn. Under the
    # part-1 rules, 0 still permits at most three cells straight from the start.
    start_states = [
        NodeState(position=start, direction=1, step_count=0),
        NodeState(position=start, direction=0, step_count=0),
    ]
    cost_so_far = {}
    came_from = {}
    for state in start_states:
        frontier.put((0, state))
        cost_so_far[state] = 0
        came_from[state] = None
    while not frontier.empty():
        _, current = frontier.get()
        if current.position == goal:
            return cost_so_far[current], came_from, current
        for next_state in get_neighbors_with_restriction(current):
            move_cost = sum(
                heat[y, x] for x, y in line(current.position, next_state.position)
            )
            new_cost = cost_so_far[current] + move_cost
            if next_state not in cost_so_far or new_cost < cost_so_far[next_state]:
                cost_so_far[next_state] = new_cost
                priority = new_cost + enhanced_heuristic(next_state.position, goal)
                frontier.put((priority, next_state))
                came_from[next_state] = current
    # Previously the last popped (non-goal) state's cost was returned here.
    raise ValueError("goal is unreachable under the given movement rules")
c, came_from, current = astar()
print(c)  # part 1 answer; a bare `c` is only displayed inside a notebook
# In[3]:
def part2(node_state):
    """Yield the part-2 ("ultra crucible") successor states of ``node_state``.

    A run in one direction covers at least four and at most ten cells. The
    mandatory cells are not expanded one by one: a state with
    ``step_count == 1`` jumps straight to the fourth cell of its run
    (``step_count == 4``). ``astar`` charges every cell in between via
    ``line``. Other states may turn, or continue straight while below ten
    steps, but never reverse.

    Args:
        node_state: the ``NodeState`` being expanded.

    Yields:
        Successor ``NodeState``s that stay on ``grid``.
    """

    def on_grid(x, y):
        return 0 <= x < grid.shape[1] and 0 <= y < grid.shape[0]

    if node_state.step_count == 1:
        dx, dy = directions[node_state.direction]
        x, y = node_state.position[0] + 3 * dx, node_state.position[1] + 3 * dy
        if on_grid(x, y):
            yield NodeState(
                position=(x, y),
                direction=node_state.direction,
                step_count=4,
            )
        return
    for idx, (dx, dy) in enumerate(directions):
        if idx == node_state.direction and node_state.step_count >= 10:
            continue
        if directions[node_state.direction] == (-dx, -dy):
            continue
        x, y = node_state.position[0] + dx, node_state.position[1] + dy
        if not on_grid(x, y):
            continue
        new_step_count = node_state.step_count + 1 if idx == node_state.direction else 1
        # A new run (step_count 1) is only viable if its three follow-up cells
        # also fit on the grid; otherwise the jump above yields nothing and the
        # state is a dead end. Pruning it also matters because astar stops on
        # the first popped state whose position is the goal. Without this check,
        # a fresh run that lands on the goal corner would end the route after
        # a single cell, below the four-cell minimum per run.
        if new_step_count == 1 and not on_grid(x + 3 * dx, y + 3 * dy):
            continue
        yield NodeState(position=(x, y), direction=idx, step_count=new_step_count)
c, came_from, current = astar(part2)
print(c)  # part 2 answer; a bare `c` is only displayed inside a notebook
# In[4]:
# path = [current]
# while (n := came_from[path[-1]]) is not None:
# path.append(n)
# path = list(reversed(path))
# g = grid.copy().astype(str)
# for f, t in zip(path, path[1:]):
# for x, y in line(f.position, t.position):
# g.iloc[y, x] = '#'
# g
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment