Skip to content

Instantly share code, notes, and snippets.

@seanmor5
Last active April 19, 2022 19:22
Show Gist options
  • Star 11 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save seanmor5/8a27ff8048040e22ae012983981f97b7 to your computer and use it in GitHub Desktop.
Save seanmor5/8a27ff8048040e22ae012983981f97b7 to your computer and use it in GitHub Desktop.
defmodule Day3 do
  @moduledoc """
  Day 3 solution written with Nx: power consumption (part 1) and the
  product of the oxygen and CO2 ratings (part 2).

  The input is parsed in plain Elixir into a `{samples, bitwidth}` tensor of
  0/1 digits; everything after parsing is Nx `defn` code. The bit width is
  read from the parsed tensor's shape rather than hard-coded.
  """

  # Since my last answer wasn't purely Nx, I'm going
  # to try to stick to Nx as much as is possible, but
  # we don't have string manipulation stuff so that will
  # have to be done in Elixir
  import Nx.Defn

  @default_path "aoc/3.txt"

  @doc """
  Part 1: reads the report at `path` and returns gamma * epsilon as a scalar
  tensor. Raises `File.Error` if `path` cannot be read.
  """
  def part1(path \\ @default_path) do
    path
    |> File.read!()
    |> parse_input()
    |> power_consumption()
  end

  @doc """
  Part 2: reads the report at `path` and returns oxygen rating * CO2 rating
  as a scalar tensor. Raises `File.Error` if `path` cannot be read.
  """
  def part2(path \\ @default_path) do
    path
    |> File.read!()
    |> parse_input()
    |> compute_ratings()
  end

  # Turns the raw file contents into a {samples, bitwidth} tensor of 0/1.
  defp parse_input(file) do
    file
    # we are all children of windows
    |> String.replace("\r", "")
    # `trim: true` drops the empty string produced by a trailing newline;
    # otherwise it becomes a zero-length tensor and Nx.stack/1 fails
    |> String.split("\n", trim: true)
    # each digit is an ASCII byte, so subtracting ?0 (48) maps "0"/"1" to 0/1
    |> Enum.map(&Nx.subtract(Nx.from_binary(&1, {:u, 8}), Nx.tensor(?0)))
    |> Nx.stack()
  end

  # gamma = most common bit per column, epsilon = its complement;
  # returns their decimal values multiplied together.
  defnp power_consumption(bytes) do
    # both counts have shape {bitwidth}: one entry per bit position
    count_ones = count_value(bytes, 1, axis: 0)
    count_zeros = count_value(bytes, 0, axis: 0)
    # 1 where ones outnumber zeros in that column, 0 otherwise
    gamma = Nx.greater(count_ones, count_zeros)
    # gamma holds the most prevalent bits, so epsilon is the opposite
    epsilon = Nx.logical_not(gamma)
    Nx.multiply(bin2dec(gamma), bin2dec(epsilon))
  end

  # Oxygen keeps the most common bit (ties -> 1); CO2 keeps the least
  # common bit (ties -> 0). Returns the product of both ratings.
  defnp compute_ratings(bytes) do
    oxygen_rating = rating(bytes, &Nx.greater_equal/2)
    co2_rating = rating(bytes, &Nx.less/2)
    Nx.multiply(oxygen_rating, co2_rating)
  end

  # Filters rows with `condition` (see build_mask/2) and converts the
  # surviving row to decimal.
  defnp rating(bytes, condition) do
    mask = build_mask(bytes, condition)

    # Zero out eliminated rows, then take the column-wise max. Surviving
    # rows are identical, so the max is exactly that row; a sum would
    # double-count if the input contained duplicate rows.
    mask
    |> Nx.select(bytes, 0)
    |> Nx.reduce_max(axes: [0])
    |> bin2dec()
  end

  # Converts 0/1 digits ordered MSB -> LSB along the last axis to decimal.
  # The width is taken from the tensor's own shape, so any bit width works.
  defnp bin2dec(x) do
    width = Nx.axis_size(x, -1)
    # iota gives exponents 0..width-1; reversing puts the largest place
    # value first to line up with the MSB-first digits
    2
    |> Nx.power(Nx.iota({width}))
    |> Nx.reverse()
    |> then(&Nx.dot(x, &1))
  end

  # Counts occurrences of `val` in `x` along `opts[:axis]` (default 0).
  # Nx.equal/2 yields a 0/1 tensor (the scalar is broadcast), so summing
  # it counts the matches.
  defnp count_value(x, val, opts \\ []) do
    opts = keyword!(opts, axis: 0)
    Nx.sum(Nx.equal(x, val), axes: [opts[:axis]])
  end

  # Iteratively builds a {samples, bitwidth} row mask: 1 for rows still in
  # the running, 0 for eliminated rows. At bit position i, `condition` is
  # called as condition.(count_ones, count_zeros) and must return the bit
  # value (0 or 1) to keep.
  defnp build_mask(bytes, condition) do
    width = Nx.axis_size(bytes, 1)

    # Start with nothing masked. Stop after the last bit position, or once
    # exactly `width` ones remain in the mask (one whole row left).
    {_, mask, _} =
      while {i = Nx.tensor(0), mask = Nx.broadcast(Nx.tensor(1, type: {:u, 8}), bytes), bytes},
            Nx.logical_and(Nx.less(i, width), Nx.not_equal(Nx.sum(mask), width)) do
        # Eliminated rows become -1 so they never count as a 0 or a 1;
        # slice out column i, which keeps shape {samples, 1}.
        bytes_slice =
          mask
          |> Nx.select(bytes, -1)
          |> Nx.slice_axis(i, 1, 1)

        # Squeeze only the column axis so a single-row input still gives a
        # 1-D {samples} tensor (a full squeeze would produce a scalar).
        column = Nx.squeeze(bytes_slice, axes: [1])
        count_zeros = count_value(column, 0)
        count_ones = count_value(column, 1)

        # If every remaining row has the same bit here, keep that bit.
        # Otherwise a "least common" condition such as Nx.less/2 would pick
        # the absent value and eliminate every row, zeroing the rating.
        value =
          Nx.select(
            Nx.equal(count_ones, 0),
            0,
            Nx.select(Nx.equal(count_zeros, 0), 1, condition.(count_ones, count_zeros))
          )

        # bytes_slice has shape {samples, 1}, so the match broadcasts across
        # the bit width: whole rows stay or go together.
        updated_mask =
          bytes_slice
          |> Nx.equal(value)
          |> Nx.logical_and(mask)

        {Nx.add(i, 1), updated_mask, bytes}
      end

    mask
  end
end
# Print the inspected answer for part 1, then for part 2.
IO.inspect(Day3.part1())
IO.inspect(Day3.part2())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment