Skip to content

Instantly share code, notes, and snippets.

@algmyr
Last active July 1, 2023 09:54
Show Gist options
  • Save algmyr/4ef4080a617f3b04e8d0f01fedccbee3 to your computer and use it in GitHub Desktop.
Save algmyr/4ef4080a617f3b04e8d0f01fedccbee3 to your computer and use it in GitHub Desktop.
descramble.py
import heapq
from collections import defaultdict
from dataclasses import dataclass
import sys
from typing import DefaultDict, Optional
import numpy as np
from PIL import Image, ImageDraw
# Pixel size of one cell of the scrambled grid, and the grid dimensions in
# cells (columns, rows). chop_up_image cuts the input into
# GRID_SIZE_X * GRID_SIZE_Y pieces of CELL_WIDTH x CELL_HEIGHT pixels.
CELL_WIDTH, CELL_HEIGHT = 96, 128
GRID_SIZE_X, GRID_SIZE_Y = 14, 16
@dataclass
class Piece:
    """One rectangular cell cut out of the scrambled image.

    Pieces compare and hash by ``ident`` only, so they can be used as dict
    keys and set members independently of their pixel content.
    """

    img: Image.Image
    ident: int  # chop_up_image sets this to the piece's index in its output list

    def __eq__(self, other):
        # Returning NotImplemented lets Python fall back to the other operand
        # (or identity) instead of raising AttributeError for non-Pieces.
        if not isinstance(other, Piece):
            return NotImplemented
        return self.ident == other.ident

    def __hash__(self):
        return hash(self.ident)

    def _strip(self, box: tuple[int, int, int, int]) -> np.ndarray:
        """Return the pixels inside ``box`` as an array with one row per pixel.

        Pixels are in row-major order. Multi-band images (e.g. RGB) give shape
        ``(n_pixels, n_bands)``; single-band images give a 1D array.
        """
        return np.array(self.img.crop(box).getdata())

    def left_edge(self) -> np.ndarray:
        """Return the leftmost pixel column of the piece, top to bottom."""
        return self._strip((0, 0, 1, self.img.height))

    def right_edge(self) -> np.ndarray:
        """Return the rightmost pixel column of the piece, top to bottom."""
        return self._strip(
            (self.img.width - 1, 0, self.img.width, self.img.height)
        )

    def top_edge(self) -> np.ndarray:
        """Return the topmost pixel row of the piece, left to right."""
        return self._strip((0, 0, self.img.width, 1))

    def bottom_edge(self) -> np.ndarray:
        """Return the bottommost pixel row of the piece, left to right."""
        return self._strip(
            (0, self.img.height - 1, self.img.width, self.img.height)
        )
# Neighbour directions, numbered counter-clockwise starting from RIGHT, so
# (d + 2) % 4 is always the opposite direction of d.
RIGHT, TOP, LEFT, BOTTOM = range(4)
@dataclass
class Edge:
    """A directed, weighted adjacency candidate between two pieces.

    ``direction`` (RIGHT/TOP/LEFT/BOTTOM) says on which side of ``start`` the
    ``end`` piece would be placed; ``weight`` is the boundary distance, and the
    smallest weight is taken first by build_mst. Edges order by weight alone so
    they can live in a heap.
    """

    start: Piece
    end: Piece
    weight: float
    direction: int

    def as_tuple(self):
        """Return all fields as a tuple; used for equality and hashing."""
        return (self.start, self.end, self.weight, self.direction)

    def __eq__(self, other):
        if not isinstance(other, Edge):
            return NotImplemented
        return self.as_tuple() == other.as_tuple()

    def __hash__(self):
        return hash(self.as_tuple())

    def __lt__(self, other):
        # Only the weight matters for heap ordering; ties are broken by heapq.
        if not isinstance(other, Edge):
            return NotImplemented
        return self.weight < other.weight
class Heap:
    """Small object wrapper around a ``heapq`` min-heap kept in ``data``."""

    def __init__(self):
        self.data = []

    def push(self, item):
        """Insert ``item`` into the heap."""
        heapq.heappush(self.data, item)

    def push_all(self, items):
        """Insert every element of ``items``, in iteration order."""
        heap = self.data
        for entry in items:
            heapq.heappush(heap, entry)

    def pop(self):
        """Remove and return the smallest item (IndexError if empty)."""
        return heapq.heappop(self.data)
def distance(pixels1: np.ndarray, pixels2: np.ndarray) -> Optional[float]:
    """Return a dissimilarity score between two equally sized pixel strips.

    The score is the Euclidean distance between the strips (after scaling
    values from 0-255 to 0-1) divided by the number of array elements, so a
    lower score means a better match.

    Each argument holds one pixel per row: shape ``(n, channels)`` for
    multi-channel strips (RGB, RGBA, ...) or ``(n,)`` for single-channel ones.

    Returns ``None`` when either strip is empty, or when the variance of its
    per-pixel mean intensity is below 1e-3 (a nearly solid colour). Such edges
    are deliberately ignored as uninformative.
    """
    a = pixels1 / 255.0
    b = pixels2 / 255.0
    if a.size == 0 or b.size == 0:
        return None

    def intensity_variance(strip: np.ndarray) -> float:
        # Average over the channel axis (if there is one) to get one
        # brightness value per pixel, whatever the number of channels.
        per_pixel = strip.mean(axis=-1) if strip.ndim > 1 else strip
        return np.var(per_pixel)

    # Hacky way to ignore solid color edges.
    eps = 1e-3
    if intensity_variance(a) < eps or intensity_variance(b) < eps:
        return None
    return np.sqrt(np.sum(np.square(a - b))) / a.size
def chop_up_image(img: Image.Image) -> list[Piece]:
    """Cut ``img`` into GRID_SIZE_X x GRID_SIZE_Y cells of CELL_WIDTH x CELL_HEIGHT.

    Cells are produced column by column (every row of column 0, then every
    row of column 1, ...), and each piece's ``ident`` is its index in the
    returned list, so ``pieces[0]`` is the top-left cell.
    """
    boxes = [
        (
            col * CELL_WIDTH,
            row * CELL_HEIGHT,
            (col + 1) * CELL_WIDTH,
            (row + 1) * CELL_HEIGHT,
        )
        for col in range(GRID_SIZE_X)
        for row in range(GRID_SIZE_Y)
    ]
    return [Piece(img.crop(box), ident) for ident, box in enumerate(boxes)]
def create_graph(pieces: list[Piece]) -> DefaultDict[Piece, list[Edge]]:
    """Build the candidate-adjacency graph over all pairs of pieces.

    For every unordered pair (a, b) and each of the four directions, the
    boundary distance is computed. When ``distance`` returns a value, an edge
    is recorded in both directions: ``a -> b`` with that direction, and
    ``b -> a`` with the opposite one. The index of each outer-loop piece is
    printed as progress output.

    Returns a mapping from piece to its outgoing edges. Pieces without any
    measurable edge are not inserted as keys (lookups still yield ``[]``).
    """
    graph: DefaultDict[Piece, list[Edge]] = defaultdict(list)

    def add_pair(a: Piece, b: Piece, dist: Optional[float], direction: int):
        if dist is None:
            return
        opposite = (direction + 2) % 4
        graph[a].append(Edge(a, b, dist, direction))
        graph[b].append(Edge(b, a, dist, opposite))

    # Extract each piece's boundary strips once instead of once per pair;
    # cropping and converting them dominated the original O(n^2) loop.
    strips = [
        (p.right_edge(), p.top_edge(), p.left_edge(), p.bottom_edge())
        for p in pieces
    ]

    for i, a in enumerate(pieces):
        print(i)
        a_right, a_top, a_left, a_bottom = strips[i]
        for j in range(i + 1, len(pieces)):
            b = pieces[j]
            b_right, b_top, b_left, b_bottom = strips[j]
            add_pair(a, b, distance(a_right, b_left), RIGHT)
            add_pair(a, b, distance(a_top, b_bottom), TOP)
            add_pair(a, b, distance(a_left, b_right), LEFT)
            add_pair(a, b, distance(a_bottom, b_top), BOTTOM)
    return graph
def build_mst(
    start_piece: Piece, num_pieces: int, graph: DefaultDict[Piece, list[Edge]]
) -> tuple[dict[Piece, tuple[int, int]], set[Piece], list[tuple[Piece, Piece]]]:
    """Greedily grow a placement of pieces on an integer grid (Prim-style).

    Starting with ``start_piece`` at (0, 0), repeatedly take the lowest-weight
    edge out of the placed set. If its target piece is unplaced and its target
    grid cell is free, place the piece there.

    Stops when ``num_pieces`` pieces are placed or no candidate edges remain
    (e.g. some piece has no measurable edges), so the result may be partial.

    Returns:
        locations: grid coordinate (x, y) of each placed piece; y grows downward.
        placed: the set of placed pieces.
        pairs: (parent, child) for each placement, in placement order.

    Raises:
        ValueError: if an edge carries an unknown direction.
    """
    # Grid offset of the end piece relative to the start piece.
    offsets = {
        RIGHT: (1, 0),
        TOP: (0, -1),
        LEFT: (-1, 0),
        BOTTOM: (0, 1),
    }
    placed = {start_piece}
    used_locations = {(0, 0)}
    locations = {start_piece: (0, 0)}
    active_edges = Heap()
    active_edges.push_all(graph[start_piece])
    pairs: list[tuple[Piece, Piece]] = []

    # Each piece's edges are pushed once (when it is placed) and every edge is
    # popped at most once, so checking for an empty heap guarantees progress
    # instead of an IndexError from popping an exhausted heap.
    while len(placed) < num_pieces and active_edges.data:
        min_edge = active_edges.pop()
        if min_edge.end in placed:
            continue
        try:
            dx, dy = offsets[min_edge.direction]
        except KeyError:
            raise ValueError(
                f"unknown edge direction: {min_edge.direction!r}"
            ) from None
        start_x, start_y = locations[min_edge.start]
        end = (start_x + dx, start_y + dy)
        if end in used_locations:
            continue
        pairs.append((min_edge.start, min_edge.end))
        used_locations.add(end)
        locations[min_edge.end] = end
        placed.add(min_edge.end)
        active_edges.push_all(graph[min_edge.end])
    return locations, placed, pairs
def stuff(
    pieces: list[Piece],
    graph: DefaultDict[Piece, list[Edge]],
    out_file: str = 'out.png',
):
    """Assemble the pieces from their greedy tree layout and save the result.

    Places pieces with ``build_mst`` starting from ``pieces[0]`` and pastes
    them onto a black canvas. It then overlays the tree edges as lines whose
    colour fades from magenta (first placement) to cyan (last), and writes
    the image to ``out_file``.
    """
    locations, placed, pairs = build_mst(pieces[0], len(pieces), graph)

    # Shift the layout so the top-left occupied cell becomes (0, 0).
    min_x = min(x for x, _ in locations.values())
    min_y = min(y for _, y in locations.values())
    locations = {p: (x - min_x, y - min_y) for p, (x, y) in locations.items()}
    max_x = max(x for x, _ in locations.values())
    max_y = max(y for _, y in locations.values())

    # Canvas defaults to 1.5x the source grid; enlarge it when the layout
    # extends further so that no piece is silently clipped.
    width = max(round(1.5 * CELL_WIDTH * GRID_SIZE_X), (max_x + 1) * CELL_WIDTH)
    height = max(round(1.5 * CELL_HEIGHT * GRID_SIZE_Y), (max_y + 1) * CELL_HEIGHT)
    result = Image.new('RGB', (width, height), (0, 0, 0))

    # Paste the pieces into the image
    for piece in placed:
        x, y = locations[piece]
        result.paste(piece.img, (x * CELL_WIDTH, y * CELL_HEIGHT))

    # Draw the tree in a colour gradient
    draw = ImageDraw.Draw(result)
    color1 = (255, 0, 255)  # magenta
    color2 = (0, 255, 255)  # cyan
    # max(..., 1) avoids a ZeroDivisionError when there is exactly one edge.
    steps = max(len(pairs) - 1, 1)

    def interp(i: int) -> tuple[int, int, int]:
        t = i / steps
        return tuple(int((1 - t) * c1 + t * c2) for c1, c2 in zip(color1, color2))

    def center(cell: tuple[int, int]) -> tuple[int, int]:
        x, y = cell
        return (x * CELL_WIDTH + CELL_WIDTH // 2, y * CELL_HEIGHT + CELL_HEIGHT // 2)

    for i, (a, b) in enumerate(pairs):
        draw.line(
            [center(locations[a]), center(locations[b])],
            fill=interp(i),
            width=5,
        )

    result.save(out_file)
def main(in_file):
    """Descramble the image at ``in_file`` and write the result to out.png."""
    # Close the source file promptly, and normalise to RGB so that every piece
    # shares one pixel format (palette or RGBA inputs would otherwise produce
    # differently shaped edge arrays than the RGB the pipeline is built for).
    with Image.open(in_file) as img:
        rgb = img.convert('RGB')
    pieces = chop_up_image(rgb)
    graph = create_graph(pieces)
    stuff(pieces, graph)
if __name__ == '__main__':
    # Print a usage line instead of an IndexError traceback when no file is given.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} IMAGE')
    main(sys.argv[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment