@rain-1
Last active November 2, 2023 19:58
How is a matrix used to count fish?

This explains some background relevant to Advent of Code (AOC) 2021, day 6.

First, let's do Fibonacci numbers, because it's a smaller example (a 2x2 matrix instead of 9x9) and it's familiar ground.

So you can implement fibs like this:

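def fib(n):
    # (x, y) starts as (F(0), F(1)) = (0, 1) and each step advances the pair by one,
    # so fib(n) returns F(n+1): fib(1)..fib(5) print 1, 2, 3, 5, 8.
    x = 0
    y = 1
    for _ in range(0, n):
        x,y = y,x+y
    return y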

print(fib(1))
print(fib(2))
print(fib(3))
print(fib(4))
print(fib(5))

Matrix and vector multiplication

A matrix is a grid of numbers. I forget this every time, but the size is given as rows, then columns. So:

/ 3 4 5 \
\ 7 2 3 /

is a 2x3 matrix.

A vector is just a matrix that's n x 1 (n rows, one column).
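To make the rows-then-columns convention concrete, here is one way to hold the 2x3 matrix above in Python, assuming a list-of-rows representation (an illustrative choice; the 2x2 code below uses a flat list instead):

```python
# A 2x3 matrix stored as a list of rows (row-major nested lists).
m = [[3, 4, 5],
     [7, 2, 3]]

rows = len(m)      # number of rows
cols = len(m[0])   # number of columns
print(f"{rows}x{cols} matrix")  # prints "2x3 matrix"

# A column vector in the same representation: n rows with one entry each.
v = [[1], [0], [2]]
print(f"{len(v)}x{len(v[0])} vector")  # prints "3x1 vector"
```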

I'll show the formula for multiplying a 2x2 matrix by a 2x1 vector (the result is another 2x1 vector):

/ a b \ / x \  _  / ax + by \
\ c d / \ y /  -  \ cx + dy /

And here is the implementation in Python:

def m2x2_times_v2x1(m,v):
    # m is the 2x2 matrix [[a, b], [c, d]] flattened row by row into [a, b, c, d]; v is [x, y]
    a,b,c,d=m[0],m[1],m[2],m[3]
    x,y=v[0],v[1]
    return [a*x+b*y, c*x+d*y]

As for why it's ax+by and not ax+bx or something else, there's a whole body of math that people spend years learning. But all that really matters here is that a 2x2 matrix takes our values x, y and computes two new values, each a weighted sum of x and y.
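The same rule works for any size: entry i of the result is row i of the matrix multiplied entry by entry with the vector and summed. A minimal sketch, using a hypothetical helper `mat_vec` and assuming matrices are stored as lists of rows rather than the flat list used above:

```python
def mat_vec(m, v):
    """Multiply an n x k matrix m (a list of rows) by a length-k vector v."""
    # Entry i of the result is the dot product of row i of m with v.
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]

# Matches the 2x2 formula [a*x + b*y, c*x + d*y].
print(mat_vec([[1, 2], [3, 4]], [5, 6]))           # [17, 39]
# The 2x3 matrix from earlier times a length-3 vector.
print(mat_vec([[3, 4, 5], [7, 2, 3]], [1, 0, 2]))  # [13, 13]
```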

Rewriting fibs using a matrix

Look at what this specific matrix does:

/ 0 1 \ / x \  _  / 0*x + 1*y \  _  /  y  \
\ 1 1 / \ y /  -  \ 1*x + 1*y /  -  \ x+y /

This is exactly the inner loop of the fibs program, so I can rewrite fibs to use a matrix:

def fib_m(n):
    m=[0,1,
       1,1]
    v=[0,1]
    for _ in range(0, n):
        v = m2x2_times_v2x1(m, v)
    return v[1]

Applying binary exponentiation to matrix math

Mathematically, to calculate fib_m(5) (for example) we use a matrix and a vector

m = [0,1; 1,1], v = [0; 1]

and then we calculate

m * m * m * m * m * v

and take the second coordinate of the vector we get.

This can be written m^5 * v.
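Because matrix multiplication is associative, applying m to v five times gives the same vector as multiplying v once by m^5. A small check, assuming the same flat row-major layout as m2x2_times_v2x1 (the helpers `mat_mul_2x2` and `mat_vec_2x2` are hypothetical names for this sketch):

```python
def mat_mul_2x2(p, q):
    """Multiply two 2x2 matrices stored as flat row-major lists [a, b, c, d]."""
    a, b, c, d = p
    e, f, g, h = q
    return [a*e + b*g, a*f + b*h,
            c*e + d*g, c*f + d*h]

def mat_vec_2x2(m, v):
    """Same formula as m2x2_times_v2x1: [a*x + b*y, c*x + d*y]."""
    a, b, c, d = m
    x, y = v
    return [a*x + b*y, c*x + d*y]

m = [0, 1,
     1, 1]
v = [0, 1]

m5 = m
for _ in range(4):           # after the loop, m5 = m * m * m * m * m
    m5 = mat_mul_2x2(m5, m)

print(m5)                    # [3, 5, 5, 8]
print(mat_vec_2x2(m5, v))    # [5, 8]; the second coordinate, 8, equals fib_m(5)
```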

There is a really clever technique for efficiently calculating m^n, called binary exponentiation (also known as exponentiation by squaring).

https://en.wikipedia.org/wiki/Exponentiation_by_squaring

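Binary exponentiation cuts the number of matrix multiplications from O(n) to O(log(n)): it computes m, m^2, m^4, m^8, ... by repeated squaring and multiplies together only the powers that correspond to the 1 bits of n.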
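A minimal sketch of binary exponentiation applied to the Fibonacci matrix. The names `mat_mul`, `mat_pow` and `fib_fast` are hypothetical, and matrices are assumed to be stored as lists of rows. It returns the same values as fib and fib_m above, but needs only about 2*log2(n) matrix multiplications:

```python
def mat_mul(p, q):
    """Multiply two square matrices stored as lists of rows."""
    size = len(p)
    return [[sum(p[i][k] * q[k][j] for k in range(size)) for j in range(size)]
            for i in range(size)]

def mat_pow(m, n):
    """Compute m**n by binary exponentiation (exponentiation by squaring)."""
    size = len(m)
    # Start from the identity matrix, i.e. m**0.
    result = [[1 if i == j else 0 for j in range(size)] for i in range(size)]
    square = m                      # holds m**(2**k) for the current bit k of n
    while n > 0:
        if n & 1:                   # this bit of n is set, so multiply it in
            result = mat_mul(result, square)
        square = mat_mul(square, square)
        n >>= 1
    return result

def fib_fast(n):
    """Same result as fib(n) and fib_m(n): the second coordinate of m**n * v."""
    m = [[0, 1],
         [1, 1]]
    v = [0, 1]
    mn = mat_pow(m, n)
    return mn[1][0] * v[0] + mn[1][1] * v[1]

print([fib_fast(n) for n in range(1, 6)])  # [1, 2, 3, 5, 8]
```

Python integers are arbitrary precision, so fib_fast also works for large n; the multiplications themselves get slower as the numbers grow, which is why the O(log(n)) count refers to matrix multiplications rather than wall-clock time.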

Code

def fib(n):
    a = 0
    b = 1
    for _ in range(0, n):
        a,b = b,a+b
    return b

print(fib(1))
print(fib(2))
print(fib(3))
print(fib(4))
print(fib(5))
print("")

def m2x2_times_v2x1(m,v):
    a,b,c,d=m[0],m[1],m[2],m[3]
    x,y=v[0],v[1]
    return [a*x+b*y, c*x+d*y]

def fib_m(n):
    m=[0,1,
       1,1]
    v=[0,1]
    for _ in range(0, n):
        v = m2x2_times_v2x1(m, v)
    return v[1]

print("")
print(fib_m(1))
print(fib_m(2))
print(fib_m(3))
print(fib_m(4))
print(fib_m(5))

How can this technique be applied to the fish problem?

The puzzle's example shows each fish as a number (its internal timer), day by day:

Initial state: 3,4,3,1,2
After  1 day:  2,3,2,0,1
After  2 days: 1,2,1,6,0,8
After  3 days: 0,1,0,5,6,7,8
After  4 days: 6,0,6,4,5,6,7,8,8

The order within each list doesn't matter, so let's sort them:

Initial state: 1,2,3,3,4
After  1 day:  0,1,2,2,3
After  2 days: 0,1,1,2,6,8
After  3 days: 0,0,1,5,6,7,8
After  4 days: 0,4,5,6,6,6,7,8,8

In fact, the timers are always between 0 and 8, so we can use a fixed-size vector to hold the counts (how many fish have each timer value), which is a really efficient representation:

               0,1,2,3,4,5,6,7,8
Initial state: 0,1,1,2,1,0,0,0,0
After  1 day:  1,1,2,1,0,0,0,0,0
After  2 days: 1,2,1,0,0,0,1,0,1
After  3 days: 2,1,0,0,0,1,1,1,1
After  4 days: 1,0,0,0,1,1,3,1,2
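Turning the comma-separated timers into that count vector takes only a couple of lines. A sketch using `collections.Counter` from the standard library (the function name `to_counts` is hypothetical):

```python
from collections import Counter

def to_counts(line):
    """Turn '3,4,3,1,2' into a length-9 list: how many fish have timer 0..8."""
    tally = Counter(int(t) for t in line.split(","))
    # Counter returns 0 for timer values that never appear.
    return [tally[timer] for timer in range(9)]

print(to_counts("3,4,3,1,2"))  # [0, 1, 1, 2, 1, 0, 0, 0, 0]
```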

Now the goal is to come up with a matrix M that takes a state to the next day's state, i.e.

M * (0,1,1,2,1,0,0,0,0) = (1,1,2,1,0,0,0,0,0)
M * (1,1,2,1,0,0,0,0,0) = (1,2,1,0,0,0,1,0,1)
M * (1,2,1,0,0,0,1,0,1) = (2,1,0,0,0,1,1,1,1)

and so on.

This is actually possible because the problem is "linear": every entry of the next day's state is a sum of entries of the current state. I won't go into the details, but there's a whole area of math about this called linear algebra.

The general pattern that is happening is:

M * (a,b,c,d,e,f,g,h,i) = (b,c,d,e,f,g,h+a,i,a)
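Written directly as code, that pattern is one day of the simulation. A sketch (`step` is a hypothetical name) that reproduces the count table above:

```python
def step(counts):
    """Advance the 9 timer counts by one day: (a,...,i) -> (b,c,d,e,f,g,h+a,i,a)."""
    a, b, c, d, e, f, g, h, i = counts
    # Every timer drops by one; the `a` fish at timer 0 reset to 6
    # (joining the `h` fish coming down from 7) and each spawns a new fish at 8.
    return [b, c, d, e, f, g, h + a, i, a]

state = [0, 1, 1, 2, 1, 0, 0, 0, 0]   # initial state from the table
for day in range(1, 5):
    state = step(state)
    print(day, state)
# The last line printed is: 4 [1, 0, 0, 0, 1, 1, 3, 1, 2]
```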

This is pretty similar to the Fibonacci setup, but it uses a larger vector and hence a larger (9x9) matrix.

There's a standard way to perform this backwards "shift" (with a wrapping around to the end) using a permutation matrix, and all you need to do is add an extra 1 to that matrix to get the h+a term.
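Putting the pieces together: a sketch that builds the 9x9 matrix as a cyclic-shift permutation matrix plus one extra 1, then raises it to a power with binary exponentiation. Matrices are assumed to be lists of rows, and all helper names are hypothetical. The totals in the comments are the ones the puzzle statement gives for its example (26 fish after 18 days, 5934 after 80, 26984457539 after 256):

```python
def mat_mul(p, q):
    """Multiply two square matrices stored as lists of rows."""
    size = len(p)
    return [[sum(p[i][k] * q[k][j] for k in range(size)) for j in range(size)]
            for i in range(size)]

def mat_pow(m, n):
    """m**n by binary exponentiation (repeated squaring)."""
    size = len(m)
    result = [[int(i == j) for j in range(size)] for i in range(size)]  # identity
    while n > 0:
        if n & 1:
            result = mat_mul(result, m)
        m = mat_mul(m, m)
        n >>= 1
    return result

def fish_matrix():
    """9x9 matrix M with M * (a,b,...,i) = (b,c,d,e,f,g,h+a,i,a)."""
    M = [[0] * 9 for _ in range(9)]
    for row in range(9):
        M[row][(row + 1) % 9] = 1   # cyclic shift: new[row] = old[row+1], new[8] = old[0]
    M[6][0] = 1                     # the extra 1: fish at timer 0 also reset to 6
    return M

def count_fish(counts, days):
    """Total fish after `days` days, starting from a length-9 count vector."""
    Mn = mat_pow(fish_matrix(), days)
    final = [sum(Mn[i][j] * counts[j] for j in range(9)) for i in range(9)]
    return sum(final)

start = [0, 1, 1, 2, 1, 0, 0, 0, 0]   # counts for the example "3,4,3,1,2"
print(count_fish(start, 18))    # 26
print(count_fish(start, 80))    # 5934
print(count_fish(start, 256))   # 26984457539
```

For only 256 days, a plain loop over the 9 counts is already very fast; the matrix form mainly pays off when the number of days is huge.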

@dnnnp
dnnnp commented Dec 7, 2021

Interesting read, thanks for sharing.

@timvisee
timvisee commented Dec 7, 2021

I complete 256 iterations/days in 2.8μs without a matrix. I wonder if a matrix will make it faster for such a small iteration count.

https://github.com/timvisee/advent-of-code-2021/blob/master/day06b/src/main.rs
