Created December 28, 2023 17:28
-
-
Save llaisdy/dc8d2891fa3f592e1f4912aa221642e1 to your computer and use it in GitHub Desktop.
Elixir matrix multiplication with Nx.tensors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule MatMul do
  @moduledoc """
  Naive matrix multiplication over rank-2 `Nx` tensors, together with
  ExUnit tests that check it against hand-computed products.
  """

  use ExUnit.Case

  @doc """
  Multiplies two rank-2 tensors `m1` (r1 x c1) and `m2` (r2 x c2).

  Each result cell `(row, col)` is the sum of the element-wise product of
  row `row` of `m1` and column `col` of `m2`. The result has shape `{r1, c2}`.

  Raises `ArgumentError` when the inner dimensions differ (`c1 != r2`).
  Raises `MatchError` when either argument is not rank 2.
  """
  def matmul(m1, m2) do
    {rows1, cols1} = Nx.shape(m1)
    {rows2, cols2} = Nx.shape(m2)

    # Validate with a real exception instead of ExUnit's `assert`, which
    # raised a test-framework error and gave no hint about the shapes.
    if cols1 != rows2 do
      raise ArgumentError,
            "matmul: inner dimensions must match, got " <>
              "#{inspect(Nx.shape(m1))} x #{inspect(Nx.shape(m2))}"
    end

    # Transposing once makes each column of m2 addressable as a row, so the
    # columns are sliced a single time rather than once per output cell.
    m2t = Nx.transpose(m2)
    columns = for col <- 0..(cols2 - 1), do: m2t[col]

    rows =
      for row <- 0..(rows1 - 1) do
        # Slice the row of m1 once and reuse it for every column.
        lhs = m1[row]
        cells = for column <- columns, do: Nx.sum(Nx.multiply(lhs, column))
        Nx.stack(cells)
      end

    Nx.stack(rows)
  end

  test "matrices 2x3 x 3x2" do
    m1 = 1..6 |> Enum.chunk_every(3) |> Nx.tensor()
    m2 = 7..12 |> Enum.chunk_every(2) |> Nx.tensor()
    m3 = [58, 64, 139, 154] |> Enum.chunk_every(2) |> Nx.tensor()

    assert matmul(m1, m2) == m3
  end

  test "matrices 3x2 x 2x3" do
    m1 = 1..6 |> Enum.chunk_every(2) |> Nx.tensor()
    m2 = 7..12 |> Enum.chunk_every(3) |> Nx.tensor()
    m3 = [27, 30, 33, 61, 68, 75, 95, 106, 117] |> Enum.chunk_every(3) |> Nx.tensor()

    assert matmul(m1, m2) == m3
  end

  test "matrices 4x2 x 2x3" do
    m1 = 1..8 |> Enum.chunk_every(2) |> Nx.tensor()
    m2 = 9..14 |> Enum.chunk_every(3) |> Nx.tensor()

    m3 =
      [33, 36, 39, 75, 82, 89, 117, 128, 139, 159, 174, 189]
      |> Enum.chunk_every(3)
      |> Nx.tensor()

    assert matmul(m1, m2) == m3
  end

  test "mismatched inner dimensions raise ArgumentError" do
    # 1x3 times 1x2: inner dimensions 3 and 1 differ.
    m1 = Nx.tensor([[1, 2, 3]])
    m2 = Nx.tensor([[1, 2]])

    assert_raise ArgumentError, fn -> matmul(m1, m2) end
  end
end
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment