Skip to content

Instantly share code, notes, and snippets.

@seanmor5
Created December 5, 2021 01:30
Show Gist options
  • Save seanmor5/7cfd9f283528454d62e79841d1a2a525 to your computer and use it in GitHub Desktop.
Save seanmor5/7cfd9f283528454d62e79841d1a2a525 to your computer and use it in GitHub Desktop.
defmodule Day4 do
  @moduledoc """
  Advent of Code Day 4 (bingo), solved with `Nx.Defn`.

  The input holds one line of comma-separated draws, followed by
  blank-line-separated 5x5 boards of whitespace-separated integers.

  A board wins when every number in one of its rows or columns has been drawn.
  Diagonals do not count. A board's score is the sum of its unmarked numbers
  multiplied by the number drawn last.
  """

  import Nx.Defn

  @default_path "aoc/4.txt"

  @doc """
  Returns the score of the first board to win, as a scalar tensor.

  `path` is the puzzle input file and defaults to `"aoc/4.txt"`.
  """
  def part1(path \\ @default_path) do
    path
    |> File.read!()
    |> parse_input()
    |> play_bingo()
    |> find_winning_board()
  end

  @doc """
  Returns the score of the last board to win, as a scalar tensor.

  `path` is the puzzle input file and defaults to `"aoc/4.txt"`.
  """
  def part2(path \\ @default_path) do
    path
    |> File.read!()
    |> parse_input()
    |> play_bingo_until_last()
    |> compute_last_score()
  end

  # Splits the raw input into `{draws, boards}`:
  #   * `draws` is a 1-D tensor of the drawn numbers, in order.
  #   * `boards` is an `{n, 5, 5}` tensor whose last two axes are named
  #     `:rows` and `:columns`.
  defp parse_input(input) do
    [draws | boards] =
      input
      |> String.replace("\r", "")
      # trimming first keeps trailing blank lines from becoming an empty "board"
      |> String.trim()
      |> String.split("\n\n")

    draws =
      draws
      |> String.split(",")
      |> Enum.map(&String.to_integer/1)
      |> Nx.tensor()

    {draws, to_matrices(boards)}
  end

  # Stacks the parsed boards along a new leading axis.
  defp to_matrices(boards) do
    boards
    |> Enum.map(&to_matrix/1)
    |> Nx.stack()
  end

  # Parses one board's text into a 5x5 tensor with axes named `:rows` and
  # `:columns`. Splitting on any run of whitespace copes with space-padded
  # numbers and with a trailing newline, so no fixed cell width is assumed.
  defp to_matrix(board) do
    board
    |> String.split()
    |> Enum.map(&String.to_integer/1)
    |> Nx.tensor()
    |> Nx.reshape({5, 5}, names: [:rows, :columns])
  end

  # Marks draws on every board until at least one board has a bingo, or the
  # draws run out. Returns `{last_drawn, mask, boards}`, where `mask` is a
  # u8 tensor shaped like `boards` with 1 at every marked cell.
  defnp play_bingo({draws, boards}) do
    # index of the next draw
    current = Nx.tensor(0)
    # nothing is marked yet
    mask = Nx.broadcast(Nx.tensor(0, type: {:u, 8}), boards)

    {current, bingo_mask, _, boards} =
      while {current, mask, draws, boards},
            Nx.logical_and(draws_left?(current, draws), Nx.logical_not(bingo?(mask))) do
        next_draw = Nx.squeeze(draws[current])
        update_mask = Nx.logical_or(Nx.equal(boards, next_draw), mask)
        {current + 1, update_mask, draws, boards}
      end

    {last_drawn(draws, current), bingo_mask, boards}
  end

  # Scalar u8 that is 1 while `current` still indexes into `draws`. It keeps
  # the loops from reading past the end of `draws` if no board ever wins.
  defnp draws_left?(current, draws) do
    Nx.less(current, Nx.axis_size(draws, 0))
  end

  # The draw just before index `current`, i.e. the last number that was marked.
  defnp last_drawn(draws, current) do
    Nx.squeeze(Nx.slice_axis(draws, current - 1, 1, 0))
  end

  # Scalar u8 that is 1 if any board in `mask` has a complete row or column.
  defnp bingo?(mask) do
    mask
    |> board_wins()
    |> Nx.any?()
  end

  # For an `{n, 5, 5}` mask with axes named `:rows` and `:columns`, returns
  # an `{n}` u8 tensor. An entry is 1 when that board has a fully marked row
  # or column. A board with several full lines still yields a single 1.
  defnp board_wins(mask) do
    # summing over :rows counts the marked cells in each column
    full_column? =
      mask
      |> Nx.sum(axes: [:rows])
      |> Nx.equal(5)
      |> Nx.reduce_max(axes: [:columns])

    # summing over :columns counts the marked cells in each row
    full_row? =
      mask
      |> Nx.sum(axes: [:columns])
      |> Nx.equal(5)
      |> Nx.reduce_max(axes: [:rows])

    Nx.logical_or(full_row?, full_column?)
  end

  # Scores the first board that has won.
  defnp find_winning_board({last_drawn, mask, boards}) do
    # argmax of the 0/1 win vector returns the lowest winning index by
    # default. It stays a single valid index even when several boards
    # complete on the same draw.
    winning_board_index =
      mask
      |> board_wins()
      |> Nx.argmax()

    board_mask = Nx.slice_axis(mask, winning_board_index, 1, 0)
    winning_board = Nx.slice_axis(boards, winning_board_index, 1, 0)

    score(board_mask, winning_board, last_drawn)
  end

  # Sum of the board's unmarked numbers times the last number drawn.
  defnp score(board_mask, board, last_drawn) do
    board_mask
    |> Nx.logical_not()
    |> Nx.select(board, 0)
    |> Nx.sum()
    |> Nx.multiply(last_drawn)
  end

  # Marks draws until all but one board have won, or the draws run out.
  # Returns `{current, mask, draws, boards}`, where `current` is the index of
  # the next draw.
  # NOTE(review): if the last two boards win on the same draw, no board is
  # left un-won; compute_last_score then falls back to board 0.
  defnp play_bingo_until_last({draws, boards}) do
    current = Nx.tensor(0)
    # presumably u64 so the loop-state type matches what count_bingos (an
    # Nx.sum over a u8 tensor) returns on each iteration
    num_bingos = Nx.tensor(0, type: {:u, 64})
    mask = Nx.broadcast(Nx.tensor(0, type: {:u, 8}), boards)

    {current, bingo_mask, _, draws, boards} =
      while {current, mask, num_bingos, draws, boards},
            Nx.logical_and(
              draws_left?(current, draws),
              Nx.less(num_bingos, Nx.axis_size(boards, 0) - 1)
            ) do
        next_draw = Nx.squeeze(draws[current])
        update_mask = Nx.logical_or(Nx.equal(boards, next_draw), mask)
        {current + 1, update_mask, count_bingos(update_mask), draws, boards}
      end

    {current, bingo_mask, draws, boards}
  end

  # Number of boards in `mask` with a complete row or column. Each board
  # counts once, however many lines it has completed.
  defnp count_bingos(mask) do
    mask
    |> board_wins()
    |> Nx.sum()
  end

  # Takes the board that has not won yet, keeps drawing for it alone from
  # index `current` until it wins, then scores it.
  defnp compute_last_score({current, mask, draws, boards}) do
    # argmax of the "not won" vector is the first board still in play
    loser_idx =
      mask
      |> board_wins()
      |> Nx.logical_not()
      |> Nx.argmax()

    loser_mask = Nx.slice_axis(mask, loser_idx, 1, 0)
    loser_board = Nx.slice_axis(boards, loser_idx, 1, 0)

    {current, loser_mask, _, loser_board} =
      while {current, loser_mask, draws, loser_board},
            Nx.logical_and(draws_left?(current, draws), Nx.logical_not(bingo?(loser_mask))) do
        next_draw = Nx.squeeze(draws[current])
        update_mask = Nx.logical_or(loser_mask, Nx.equal(loser_board, next_draw))
        {current + 1, update_mask, draws, loser_board}
      end

    score(loser_mask, loser_board, last_drawn(draws, current))
  end
end
# Print the answers to both parts.
IO.inspect(Day4.part1())
IO.inspect(Day4.part2())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment