Last active
May 5, 2024 03:29
-
-
Save ample-samples/9223ce833370df85b323987c85e9d94e to your computer and use it in GitHub Desktop.
Compute the dot product of scalars, vectors, and MxN matrices with compatible shapes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Utility functions for matrix.dot

--- Multiply every element of a 1-D vector by a scalar.
-- @param scalar number multiplier
-- @param v1 sequence of numbers (not modified)
-- @return a new table of scaled elements, wrapped with matrix(...)
local function dotScalar1dVector(scalar, v1)
	local scaled = {}
	local count = #v1
	local idx = 1
	while idx <= count do
		scaled[idx] = v1[idx] * scalar
		idx = idx + 1
	end
	return matrix(scaled)
end
--- Multiply every element of a 2-D vector (a table of rows) by a scalar.
-- Each row is scaled by dotScalar1dVector, so every row of the result is the
-- matrix(...) value that function returns; the outer table is wrapped with
-- matrix(...) as well.
-- @param scalar number multiplier
-- @param v1 table of rows, each a sequence of numbers (not modified)
local function dotScalar2dVector(scalar, v1)
	local rows = {}
	for rowIndex = 1, #v1 do
		local row = v1[rowIndex]
		rows[rowIndex] = dotScalar1dVector(scalar, row)
	end
	return matrix(rows)
end
--- Inner (dot) product of two 1-D vectors of equal length.
-- @param v1 sequence of numbers
-- @param v2 sequence of numbers with the same length as v1
-- @return number: the sum of v1[k] * v2[k]; 0 when both vectors are empty
-- Raises a "Shape error" when the lengths differ.
local function dot1dVector1dVector(v1, v2)
	local len1, len2 = #v1, #v2
	if len1 ~= len2 then
		error("Shape error: #m1 ~= #m2, " .. len1 .. "~=" .. len2, nil)
	end
	local total = 0
	for k = 1, len1 do
		total = total + v1[k] * v2[k]
	end
	return total
end
--- Dot a 1-D vector against every row of a 2-D vector.
-- Every row of v2 must have the same length as v1. The result holds one number
-- per row of v2: result[i] = sum over k of v1[k] * v2[i][k].
-- NOTE(review): numpy's dot(1-D, 2-D) instead contracts v1 with the *columns*
-- of v2 (v1 @ v2, requiring #v1 == #v2). This function keeps its row-wise
-- behaviour; confirm which convention callers expect.
-- @param v1 sequence of numbers
-- @param v2 table of rows, each a sequence of numbers
-- @return plain table of numbers (not wrapped with matrix(...)), length #v2
-- Raises a "Shape error" when a row's length differs from #v1.
local function dot1dVector2dVector(v1, v2)
	-- Only the first row is checked here; dot1dVector1dVector re-checks each row.
	if #v1 ~= #v2[1] then
		-- The constraint is on v2's column count (#v2[1]), so the message says so.
		error("Shape error: #m1 ~= m2 columns, " .. #v1 .. "~=" .. #v2[1], nil)
	end
	local layerOutput = {}
	for i = 1, #v2 do
		layerOutput[i] = dot1dVector1dVector(v1, v2[i])
	end
	return layerOutput
end
--- Matrix product of two 2-D vectors (tables of rows).
-- v1 is m x n and v2 is n x p. The result is an m x p plain table of rows (not
-- wrapped with matrix(...)) with result[i][j] = dot(row i of v1, column j of v2).
-- Raises a "Shape error" when v1's column count differs from v2's row count.
-- Rows of v1 with a different length fail inside dot1dVector1dVector.
local function dot2dVector2dVector(v1, v2)
	if #v1[1] ~= #v2 then
		error("Shape error: m1 columns ~= m2 rows, " .. #v1[1] .. "~=" .. #v2, nil)
	end
	-- Transpose once, instead of once per (i, j) pair, so column j of v2 is
	-- v2Columns[j]. `transpose` is the global function defined below; globals are
	-- resolved at call time, so the forward reference works.
	local v2Columns = transpose(v2)
	local columnCount = #v2[1]
	local result = {}
	for i = 1, #v1 do
		local v1Rowi = v1[i]
		local resultRow = {}
		for j = 1, columnCount do
			resultRow[j] = dot1dVector1dVector(v1Rowi, v2Columns[j])
		end
		result[i] = resultRow
	end
	return result
end
--- Transpose a 2-D vector (a table of rows): row i of the result is column i of v1.
-- Kept global: dot2dVector2dVector, defined above, calls it by name before this
-- definition runs, and other code may rely on the global as well.
-- The column count is taken from the first row. Cells are appended, so a nil
-- cell shortens the resulting row rather than leaving a hole in it.
-- @param v1 table of rows
-- @return a new table of rows; an empty table when v1 has no rows
function transpose(v1)
	-- Without this guard, #v1[1] raises "attempt to get length of a nil value".
	if v1[1] == nil then
		return {}
	end
	local transposedBoard = {}
	for i = 1, #v1[1] do
		local newRow = {}
		for j = 1, #v1 do
			newRow[#newRow + 1] = v1[j][i]
		end
		transposedBoard[#transposedBoard + 1] = newRow
	end
	return transposedBoard
end
-- Classify a matrix.dot operand by rank, looking only at its leading element(s):
-- 0 = number, 1 = table whose first element is a number (1-D vector),
-- 2 = table whose first element is a table that starts with a number (2-D vector).
-- Returns nil for anything else.
local function dotOperandRank(v)
	if type(v) == "number" then
		return 0
	elseif type(v) ~= "table" then
		return nil
	end
	local first = v[1]
	if type(first) == "number" then
		return 1
	elseif type(first) == "table" and type(first[1]) == "number" then
		return 2
	end
	return nil
end

--// matrix.dot ( m1, m2 )
--- Dot product of two operands, each a scalar, a 1-D vector, or a 2-D vector.
-- Supported combinations:
--   scalar . 1-D or 2-D (scalar on either side) -> elements scaled, wrapped with matrix(...)
--   1-D . 1-D -> number (inner product)
--   1-D . 2-D -> plain table with one entry per row of m2 (v1 dotted with each row)
--   2-D . 2-D -> matrix product as a plain table of rows
-- Mismatched shapes raise a "Shape error" from the helpers. Any other
-- combination raises "Not valid matrix inputs".
function matrix.dot(v1, v2)
	local r1, r2 = dotOperandRank(v1), dotOperandRank(v2)
	if r1 == 0 and r2 == 1 then
		return dotScalar1dVector(v1, v2)
	elseif r1 == 0 and r2 == 2 then
		return dotScalar2dVector(v1, v2)
	elseif r1 == 1 and r2 == 0 then
		-- Scaling is commutative, so a scalar on the right is accepted too.
		return dotScalar1dVector(v2, v1)
	elseif r1 == 2 and r2 == 0 then
		return dotScalar2dVector(v2, v1)
	elseif r1 == 1 and r2 == 1 then
		return dot1dVector1dVector(v1, v2)
	elseif r1 == 1 and r2 == 2 then
		return dot1dVector2dVector(v1, v2)
	elseif r1 == 2 and r2 == 2 then
		return dot2dVector2dVector(v1, v2)
	end
	-- Level 2 attributes the error to the caller that passed the bad operands.
	error("Not valid matrix inputs", 2)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment