Skip to content

Instantly share code, notes, and snippets.

@charmasaur
Created December 19, 2021 12:32
Show Gist options
Save charmasaur/7734c895bad585ae3bcc8221fb8572c3 to your computer and use it in GitHub Desktop.
Advent of Code 2021 — Day 19
import numpy as onp
import jax
import jax.numpy as np
from jax import lax
# Input: parse input.txt into one 2-D integer array per scanner (one row per
# coordinate line). A line starting with "--" starts a new scanner, blank
# lines are skipped, and every other line is a comma-separated list of ints.
scanners = []
with open("input.txt", "r") as input_file:  # closed even if parsing fails
    for line in input_file:
        line = line.strip()
        if line.startswith("--"):
            scanners.append([])
        elif line:
            if not scanners:
                raise ValueError(f"coordinate line before any scanner header: {line!r}")
            scanners[-1].append(np.array([int(v) for v in line.split(",")]))
scanners = list(map(np.stack, scanners))
# Rotation matrices: for every ordered pair of distinct signed unit axes (x, y),
# build the matrix whose columns are x, y and x cross y. Because the third
# column is the cross product of the first two, every matrix has determinant
# +1; the 2 * 2 * 6 combinations give the 24 axis-aligned rotations, stacked
# into an array of shape (24, 3, 3).
X = np.array([1, 0, 0])
Y = np.array([0, 1, 0])
Z = np.array([0, 0, 1])
rot_mats = np.stack([
    np.stack([sign_a * axis_a, sign_b * axis_b,
              np.cross(sign_a * axis_a, sign_b * axis_b)], axis=1)
    for sign_a in (-1, 1)
    for sign_b in (-1, 1)
    for axis_a in (X, Y, Z)
    for axis_b in (X, Y, Z)
    if not np.array_equal(axis_b, axis_a)
])
# Convert 3D coordinates to a single integer code per point, so two point sets
# can be compared with one broadcast equality.
#
# An exact packing such as (p + 50000) @ [100000*100000, 100000, 1] needs
# 64-bit integers, but JAX uses 32-bit integers unless x64 mode is enabled:
# 10**10 does not fit in int32, so that constant gets truncated or rejected
# (OverflowError), depending on the JAX version. No exact packing of 3D points
# into int32 exists for realistic coordinate ranges, so instead use multipliers
# that fit in int32 and let the weighted sum overflow (wrap modulo 2**32).
# Equal points always get equal codes; distinct points may, rarely, collide,
# so equal codes mean "very probably the same point", not certainty.
to_1d_vec = np.array([73856093, 19349663, 83492791])[:, None]
def to_1d(x):
    """Return one integer code per point: shape (..., 3) -> (...)."""
    return (x @ to_1d_vec)[..., 0]
@jax.jit
def check_overlap(s1, s2):
    """Try to move scanner ``s2`` into scanner ``s1``'s coordinate system.

    Scans every rotation in the module-level ``rot_mats``; for each rotation,
    every point of ``s1`` is paired with every point of the rotated ``s2``
    and tested by ``check_overlap_with_s2_pivot``.

    Returns ``(done, transformed, center)``:
      done        -- whether any alignment was accepted;
      transformed -- ``s2`` rotated and shifted into ``s1``'s frame when
                     accepted, otherwise ``s2`` unchanged;
      center      -- the shift that was applied (the position of ``s2``'s
                     origin in ``s1``'s frame), or ``[0, 0, 0]`` otherwise.
    Later accepted alignments overwrite earlier ones, so the last one visited
    wins. As a jitted function it is retraced for each new pair of input shapes.
    """
    initial_carry = (s1, s2, False, s2, np.array([0, 0, 0]))
    final_carry, _ = lax.scan(check_overlap_with_rot, initial_carry, rot_mats)
    return final_carry[2], final_carry[3], final_carry[4]
def check_overlap_with_rot(carry, mat):
    """``lax.scan`` body over rotation matrices.

    ``carry`` is ``(s1, s2_original, done, transformed, center)``. The raw
    ``s2_original`` is passed through unchanged so every rotation starts from
    the same points; only ``done``/``transformed``/``center`` evolve.
    Returns the updated carry and no per-step output.
    """
    s1, s2_original, done, transformed, center = carry
    # Row-vector form of rotating each point p (a row) by mat: mat @ p.
    rotated = s2_original @ mat.T
    pivot_carry = (s1, rotated, done, transformed, center)
    pivot_carry, _ = lax.scan(check_overlap_with_s1_pivot, pivot_carry, s1)
    done, transformed, center = pivot_carry[2:]
    return (s1, s2_original, done, transformed, center), None
def check_overlap_with_s1_pivot(carry, p1):
    """``lax.scan`` body over the points of ``s1``: fix pivot ``p1``.

    ``carry`` is ``(s1, s2, done, transformed, center)`` with ``s2`` already
    rotated. Every point of ``s2`` is tried as ``p1``'s partner via
    ``check_overlap_with_s2_pivot``. Returns the updated carry and no
    per-step output.
    """
    s1, s2, done, transformed, center = carry
    inner_carry = (s1, s2, p1, done, transformed, center)
    inner_carry, _ = lax.scan(check_overlap_with_s2_pivot, inner_carry, s2)
    done, transformed, center = inner_carry[3:]
    return (s1, s2, done, transformed, center), None
def check_overlap_with_s2_pivot(carry, p2, min_matches=12):
    """``lax.scan`` body: test the alignment that maps ``s2``'s point ``p2``
    onto ``s1``'s point ``p1``.

    Both point sets are re-expressed relative to their pivot. If at least
    ``min_matches`` (point of s1, point of s2) pairs coincide exactly, the
    alignment is accepted: ``done`` becomes True, ``transformed`` becomes
    ``s2 - p2 + p1`` and ``center`` becomes the applied shift ``p1 - p2``.
    Otherwise those three carry entries pass through unchanged, so a later
    accepted alignment overwrites an earlier one.

    Args:
      carry: ``(s1, s2, p1, done, transformed, center)``; ``s2`` is already
        rotated.
      p2: the point of ``s2`` currently tried as the pivot.
      min_matches: number of coinciding points required to accept an
        alignment (default 12).

    Returns:
      The updated carry and ``None`` (no per-step scan output).
    """
    s1, s2, p1, done, transformed, center = carry
    s1_rel = s1 - p1
    s2_rel = s2 - p2
    # Compare full (x, y, z) coordinates pairwise instead of via to_1d's packed
    # codes: with JAX's default 32-bit integers that packing cannot be
    # collision-free, whereas this test is exact.
    coincide = np.all(s1_rel[:, None, :] == s2_rel[None, :, :], axis=-1)
    found = np.count_nonzero(coincide) >= min_matches
    done, transformed, center = lax.cond(
        found,
        lambda: (True, s2_rel + p1, p1 - p2),
        lambda: (done, transformed, center),
    )
    return (s1, s2, p1, done, transformed, center), None
# Process input: align scanners breadth-first starting from scanner 0. Each
# aligned scanner in turn becomes the reference for all not-yet-aligned
# scanners; a successful match replaces scanners[i] with its points expressed
# in the reference's frame, which (by induction) is scanner 0's frame.
done = [0]  # indices of aligned scanners, in the order they were aligned
done_index = 0  # position in `done` of the next reference scanner
transformed = [scanners[0]]  # aligned point sets, all in scanner 0's frame
centers = [np.array([0, 0, 0])]  # scanner positions in scanner 0's frame
while len(done) != len(scanners):
    if done_index >= len(done):
        # Every aligned scanner has already served as a reference and no new
        # scanner matched: the remaining ones cannot be linked to scanner 0.
        unaligned = [i for i in range(len(scanners)) if i not in done]
        raise RuntimeError(f"Could not align scanners {unaligned} with scanner 0")
    # Never reassigned below: only indices not in `done` are overwritten.
    reference = scanners[done[done_index]]
    for i, scanner in enumerate(scanners):
        if i in done:
            continue
        r, scanners[i], center = check_overlap(reference, scanner)
        if r:
            print(f"Found {i}")
            transformed.append(scanners[i])
            done.append(i)
            centers.append(center)
    done_index += 1
# Part 1: count distinct beacon positions across all aligned scanners (each
# point set in `transformed` is already in scanner 0's frame).
combined = {
    tuple(onp.asarray(point))
    for sensor in transformed
    for point in sensor
}
print(len(combined))
# Part 2: largest Manhattan distance between any two scanner positions.
# A single vectorized pairwise computation replaces a Python double loop that
# dispatched a separate JAX operation for every ordered pair (including self-
# and duplicate pairs). The diagonal is 0, so the result is never negative,
# matching the old starting value of 0.
center_stack = np.stack(centers)
pairwise = np.sum(np.abs(center_stack[:, None, :] - center_stack[None, :, :]), axis=-1)
best = np.max(pairwise)
print(best)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment