Skip to content

Instantly share code, notes, and snippets.

@ample-samples
Last active May 5, 2024 03:29
Show Gist options
  • Save ample-samples/9223ce833370df85b323987c85e9d94e to your computer and use it in GitHub Desktop.
Save ample-samples/9223ce833370df85b323987c85e9d94e to your computer and use it in GitHub Desktop.
Compute dot products between scalars, 1-D vectors, and 2-D (M×N) matrices
-- Utility functions for matrix.dot
--- Multiply every element of a 1-D table by a scalar.
-- The result is passed through `matrix(...)`, which is defined outside this
-- view (presumably callable via a __call metamethod — TODO confirm).
-- @tparam number scalar factor applied to each element
-- @tparam table v1 sequence of numbers
-- @return whatever `matrix(...)` returns for the scaled sequence
local function dotScalar1dVector(scalar, v1)
  local scaled = {}
  local count = #v1
  for idx = 1, count do
    scaled[idx] = v1[idx] * scalar
  end
  return matrix(scaled)
end
--- Multiply every element of a 2-D table (a table of rows) by a scalar.
-- Each row is scaled with dotScalar1dVector; the outer table of scaled rows is
-- then passed through `matrix(...)` as well.
-- @tparam number scalar factor applied to each element
-- @tparam table v1 table of rows, each a sequence of numbers
-- @return whatever `matrix(...)` returns for the table of scaled rows
local function dotScalar2dVector(scalar, v1)
  local scaledRows = {}
  for rowIndex = 1, #v1 do
    local row = v1[rowIndex]
    scaledRows[rowIndex] = dotScalar1dVector(scalar, row)
  end
  return matrix(scaledRows)
end
--- Dot product of two 1-D sequences: sum of v1[k] * v2[k].
-- @tparam table v1 sequence of numbers
-- @tparam table v2 sequence of numbers, same length as v1
-- @treturn number the sum of element-wise products (0 for two empty tables)
-- @raise "Shape error: #m1 ~= #m2, ..." when the lengths differ
local function dot1dVector1dVector(v1, v2)
  local len1, len2 = #v1, #v2
  if len1 ~= len2 then
    error("Shape error: #m1 ~= #m2, " .. len1 .. "~=" .. len2, nil)
  end
  local total = 0
  for k = 1, len1 do
    total = total + v1[k] * v2[k]
  end
  return total
end
--- Combine a 1-D vector with a 2-D table by dotting `v1` with every ROW of `v2`.
-- Note the orientation: result[i] = v1 · v2[i], i.e. the matrix-vector
-- product `v2 · v1` (numpy's np.dot(v2, v1)), not the textbook row-vector
-- product `v1 · v2`. This presumably suits layer weights stored one row per
-- output (hence `layerOutput`) — confirm with callers before changing it.
-- @tparam table v1 sequence of numbers, length equal to the row length of v2
-- @tparam table v2 table of rows, each a sequence of numbers
-- @treturn table plain table with one number per row of v2
-- @raise "Shape error: #m1 ~= m2 columns, ..." when #v1 differs from the
--   length of the first row of v2; later rows are checked by
--   dot1dVector1dVector
local function dot1dVector2dVector(v1, v2)
  local columns = #v2[1]
  if #v1 ~= columns then
    -- The check compares against v2's row length (its column count),
    -- so the message says "columns", not "rows".
    error("Shape error: #m1 ~= m2 columns, " .. #v1 .. "~=" .. columns, nil)
  end
  local layerOutput = {}
  for i = 1, #v2 do
    layerOutput[i] = dot1dVector1dVector(v1, v2[i])
  end
  return layerOutput
end
--- Standard matrix product of two 2-D tables: result[i][j] = row i of v1 · column j of v2.
-- @tparam table v1 table of rows (m × n)
-- @tparam table v2 table of rows (n × p)
-- @treturn table plain table of rows (m × p)
-- @raise "Shape error: m1 columns ~= m2 rows, ..." when #v1[1] ~= #v2; ragged
--   rows are caught by dot1dVector1dVector
local function dot2dVector2dVector(v1, v2)
  if #v1[1] ~= #v2 then
    error("Shape error: m1 columns ~= m2 rows, " .. #v1[1] .. "~=" .. #v2, nil)
  end
  -- Transpose once, outside both loops, so column j of v2 can be read as a
  -- row. Transposing per (i, j) pair rebuilt the whole matrix for every cell.
  local v2Columns = transpose(v2)
  local columnCount = #v2[1]
  local result = {}
  for i = 1, #v1 do
    local v1Rowi = v1[i]
    local resultRow = {}
    for j = 1, columnCount do
      resultRow[j] = dot1dVector1dVector(v1Rowi, v2Columns[j])
    end
    result[i] = resultRow
  end
  return result
end
--- Transpose a 2-D table: row i of the result holds column i of `v1`.
-- Deliberately global: dot2dVector2dVector (defined earlier in the file)
-- looks it up by name at call time, and other code may rely on it too.
-- The column count comes from the first row (#v1[1]). Cells are appended with
-- `#column + 1`, so a nil cell from a short (jagged) row is skipped rather
-- than leaving a hole; later values in that column shift up.
-- @tparam table v1 non-empty table of rows
-- @treturn table the transposed table of rows
function transpose(v1)
  local columnCount, rowCount = #v1[1], #v1
  local transposed = {}
  for col = 1, columnCount do
    local column = {}
    for row = 1, rowCount do
      local cell = v1[row][col]
      column[#column + 1] = cell
    end
    transposed[col] = column
  end
  return transposed
end
-- Classify a matrix.dot argument by inspecting only its first element(s):
-- 0 = number, 1 = table whose first element is a number (1-D),
-- 2 = table whose first element is a table starting with a number (2-D),
-- nil = anything else.
local function dotRank(x)
  if type(x) == "number" then
    return 0
  end
  if type(x) == "table" then
    local first = x[1]
    if type(first) == "number" then
      return 1
    end
    if type(first) == "table" and type(first[1]) == "number" then
      return 2
    end
  end
  return nil
end

--// matrix.dot ( v1, v2 )
--- Dot product / multiplication dispatcher for scalars, 1-D vectors and 2-D
-- matrices (nested tables of numbers). Shapes are decided from the first
-- element only:
--   number, 1-D  -> every element of v2 scaled by v1 (wrapped by matrix(...))
--   number, 2-D  -> every element of v2 scaled by v1 (wrapped by matrix(...))
--   1-D, number  -> same as (number, 1-D): scalar multiplication commutes
--   2-D, number  -> same as (number, 2-D)
--   1-D, 1-D     -> number: sum of element-wise products
--   1-D, 2-D     -> plain table: v1 dotted with each ROW of v2 (i.e. v2 · v1)
--   2-D, 2-D     -> plain table: matrix product v1 × v2
-- Raises "Not valid matrix inputs" for any other combination; shape mismatches
-- raise "Shape error: ..." from the helpers.
function matrix.dot(v1, v2)
  local r1, r2 = dotRank(v1), dotRank(v2)
  -- Accept the scalar on either side by normalizing to (number, table).
  if r2 == 0 and (r1 == 1 or r1 == 2) then
    v1, v2, r1, r2 = v2, v1, r2, r1
  end
  if r1 == 0 and r2 == 1 then
    return dotScalar1dVector(v1, v2)
  elseif r1 == 0 and r2 == 2 then
    return dotScalar2dVector(v1, v2)
  elseif r1 == 1 and r2 == 1 then
    return dot1dVector1dVector(v1, v2)
  elseif r1 == 1 and r2 == 2 then
    return dot1dVector2dVector(v1, v2)
  elseif r1 == 2 and r2 == 2 then
    return dot2dVector2dVector(v1, v2)
  end
  error("Not valid matrix inputs")
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment