Skip to content

Instantly share code, notes, and snippets.

@andrealaforgia
Last active December 12, 2021 16:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save andrealaforgia/9932688d70b01acdcff0bfc135a60ef6 to your computer and use it in GitHub Desktop.
Save andrealaforgia/9932688d70b01acdcff0bfc135a60ef6 to your computer and use it in GitHub Desktop.
class Node:
    """A cave in the cave graph, identified by its name.

    Attributes:
        adjacent_nodes: Nodes linked to this one via ``add_adjacent_node``,
            in insertion order (an edge added twice appears twice).
        name: The cave name as given to the constructor.
        traversed: Visit counter, initialised to 0; the path search and the
            script entry point in this file read and update it.
    """

    def __init__(self, name):
        self.adjacent_nodes = []
        self.name = name
        self.traversed = 0

    def __repr__(self):
        # Readable form when inspecting paths (lists of nodes) in a debugger.
        return f"{type(self).__name__}({self.name!r})"

    def add_adjacent_node(self, adjacent_node):
        """Record a one-directional link from this node to *adjacent_node*."""
        self.adjacent_nodes.append(adjacent_node)

    def is_small_cave(self):
        """Return True for a lowercase-initial cave other than "start"/"end".

        An empty name is treated as not small instead of raising IndexError.
        """
        # Slicing ([:1]) instead of indexing ([0]) keeps "" from raising.
        return self.name not in ("start", "end") and self.name[:1].islower()
def load_nodes(lines):
    """Build the cave graph from ``"name1-name2"`` edge lines.

    Each line describes an undirected tunnel between two caves. Both caves
    are created on first mention, and each is added to the other's adjacency
    list. Blank (or whitespace-only) lines are skipped.

    Args:
        lines: Iterable of edge strings such as ``"start-A"``; surrounding
            whitespace on each line is stripped.

    Returns:
        A ``(start_node, end_node)`` tuple holding the nodes named ``"start"``
        and ``"end"``. Either element is ``None`` if that cave never appears.

    Raises:
        ValueError: If a non-blank line does not split into exactly two names
            on ``"-"``.
    """
    node_dict = {}

    def get_or_create(name):
        # Reuse the existing Node for a name so all edges share one object.
        node = node_dict.get(name)
        if node is None:
            node = node_dict[name] = Node(name)
        return node

    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            # A blank line (e.g. a trailing empty line) is not an edge.
            continue
        cave_names = line.split("-")
        if len(cave_names) != 2:
            raise ValueError(f"malformed edge line: {raw_line!r}")
        node1 = get_or_create(cave_names[0])
        node2 = get_or_create(cave_names[1])
        node1.add_adjacent_node(node2)
        node2.add_adjacent_node(node1)

    return node_dict.get("start"), node_dict.get("end")
def node_can_be_traversed(node, path):
    """Return whether the search may step into *node* next.

    A small cave may be entered as long as no small cave on *path* has a
    ``traversed`` count above 1. Once one has, a small cave is allowed only
    if its own count is 0. Any other node is allowed exactly when its
    ``traversed`` count is 0.

    Args:
        node: Candidate next node.
        path: Nodes visited so far on the current branch.
    """
    if node.is_small_cave():
        # any() short-circuits and avoids materialising a throwaway list.
        no_double_visit_yet = not any(
            n.is_small_cave() and n.traversed > 1 for n in path
        )
        if no_double_visit_yet:
            return True
    return node.traversed == 0
def find_paths(current_node, path, paths):
    """Depth-first collect every path from *current_node* to the "end" cave.

    Which neighbours may be entered is decided by ``node_can_be_traversed``.
    While a branch through a small cave is explored, that cave's
    ``traversed`` counter is incremented. It is restored afterwards, even if
    the recursion raises.

    Args:
        current_node: Node the search is currently at; callers in this file
            pass it as the last element of *path*.
        path: Nodes visited so far, in order. Not mutated; each branch gets
            its own extended copy.
        paths: Output list; each complete path (a list of nodes whose last
            element is the "end" cave) is appended to it.
    """
    # Identify the goal by name rather than through a module-level
    # ``end_node`` global, so the function also works when imported.
    if current_node.name == "end":
        paths.append(path)
        return

    # Reachability is decided once up front. Counters are restored after
    # every branch, so the answer would be the same on each iteration.
    reachable_nodes = [
        node
        for node in current_node.adjacent_nodes
        if node_can_be_traversed(node, path)
    ]
    for node in reachable_nodes:
        # Only small caves are counted. Other reachable nodes have a zero
        # counter, which this search never changes.
        small = node.is_small_cave()
        if small:
            node.traversed += 1
        try:
            find_paths(node, path + [node], paths)
        finally:
            if small:
                node.traversed -= 1
def print_path(path):
    """Print the names of the nodes on *path* as one comma-separated line.

    Args:
        path: Sequence of nodes (anything with a ``name`` attribute) in
            visiting order. An empty path prints an empty line.
    """
    # join() puts commas only between names, so there is no trailing comma
    # even when the path does not finish at "end".
    print(",".join(node.name for node in path))
if __name__ == "__main__":
    with open("input3.txt", encoding="utf-8") as file:
        # Skip blank lines (e.g. trailing empty lines) so they are never
        # handed to load_nodes as edges.
        lines = [line.strip() for line in file if line.strip()]
    # end_node stays a module-level name here; find_paths may rely on it.
    start_node, end_node = load_nodes(lines)
    if start_node is None or end_node is None:
        raise SystemExit(
            "input3.txt: input must contain both 'start' and 'end' caves"
        )
    # "start" is not a small cave, so node_can_be_traversed admits it only
    # when its counter is 0. Setting it to 1 prevents re-entering it.
    start_node.traversed = 1
    paths = []
    find_paths(start_node, [start_node], paths)
    for path in paths:
        print_path(path)
    print(len(paths))
    # Regression check: the known path count for this particular input file.
    assert len(paths) == 99448
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment