Skip to content

Instantly share code, notes, and snippets.

@llaisdy
Created December 28, 2023 17:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save llaisdy/dc8d2891fa3f592e1f4912aa221642e1 to your computer and use it in GitHub Desktop.
Save llaisdy/dc8d2891fa3f592e1f4912aa221642e1 to your computer and use it in GitHub Desktop.
Elixir matrix multiplication with Nx tensors
defmodule MatMul do
  @moduledoc """
  Naive matrix multiplication over rank-2 `Nx` tensors, together with the
  ExUnit tests that exercise it.

  Each output element `(row, col)` is computed explicitly as the sum of the
  element-wise product of row `row` of the left matrix and column `col` of the
  right matrix, rather than delegating to `Nx.dot/2`.
  """

  # NOTE(review): the tests live in the same module as the implementation, so
  # this module is itself an ExUnit test case and needs ExUnit to be started
  # before it is loaded. Consider moving the tests into a separate MatMulTest.
  use ExUnit.Case, async: true

  @doc """
  Multiplies the `r1 x c1` tensor `m1` by the `r2 x c2` tensor `m2`.

  Returns an `r1 x c2` tensor whose entry `(i, j)` is
  `Nx.sum(Nx.multiply(m1[i], Nx.transpose(m2)[j]))`.

  Raises `ArgumentError` if either argument is not rank 2, or if the inner
  dimensions differ (`c1 != r2`).
  """
  def matmul(m1, m2) do
    {rows1, cols1} = matrix_shape!(m1, "first")
    {rows2, cols2} = matrix_shape!(m2, "second")

    # Library code must not use ExUnit's `assert`; report bad input as ArgumentError.
    if cols1 != rows2 do
      raise ArgumentError,
            "cannot multiply a #{rows1}x#{cols1} matrix by a #{rows2}x#{cols2} matrix: " <>
              "inner dimensions #{cols1} and #{rows2} differ"
    end

    # Transposing once lets column `col` of m2 be read as the row m2t[col].
    m2t = Nx.transpose(m2)

    for row <- 0..(rows1 - 1) do
      # Slice the left-hand row once per output row, not once per element.
      m1_row = m1[row]

      row_values =
        for col <- 0..(cols2 - 1) do
          Nx.sum(Nx.multiply(m1_row, m2t[col]))
        end

      Nx.stack(row_values)
    end
    |> Nx.stack()
  end

  # Returns the {rows, cols} shape of a rank-2 tensor, or raises ArgumentError
  # naming which operand (`position`) had the wrong rank.
  defp matrix_shape!(tensor, position) do
    case Nx.shape(tensor) do
      {_rows, _cols} = shape ->
        shape

      other ->
        raise ArgumentError,
              "expected the #{position} argument to be a rank-2 tensor, " <>
                "got shape #{inspect(other)}"
    end
  end

  # Test helper: builds a matrix tensor from a flat enumerable, `cols` values per row.
  defp matrix(values, cols) do
    values
    |> Enum.chunk_every(cols)
    |> Nx.tensor()
  end

  # Results are compared as plain nested lists via Nx.to_list/1 so the check
  # depends only on the values, not on the tensor's numeric type or backend.

  test "matrices 2x3 x 3x2" do
    result = matmul(matrix(1..6, 3), matrix(7..12, 2))

    assert Nx.to_list(result) == [[58, 64], [139, 154]]
  end

  test "matrices 3x2 x 2x3" do
    result = matmul(matrix(1..6, 2), matrix(7..12, 3))

    assert Nx.to_list(result) == [[27, 30, 33], [61, 68, 75], [95, 106, 117]]
  end

  test "matrices 4x2 x 2x3" do
    result = matmul(matrix(1..8, 2), matrix(9..14, 3))

    assert Nx.to_list(result) ==
             [[33, 36, 39], [75, 82, 89], [117, 128, 139], [159, 174, 189]]
  end

  test "agrees with Nx.dot/2" do
    m1 = matrix(1..12, 4)
    m2 = matrix(13..24, 3)

    assert Nx.to_list(matmul(m1, m2)) == Nx.to_list(Nx.dot(m1, m2))
  end

  test "raises ArgumentError when inner dimensions differ" do
    assert_raise ArgumentError, fn ->
      matmul(matrix(1..6, 3), matrix(7..12, 3))
    end
  end

  test "raises ArgumentError for a non-matrix argument" do
    assert_raise ArgumentError, fn ->
      matmul(Nx.tensor([1, 2, 3]), matrix(1..6, 2))
    end
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment