Last active: February 9, 2019 07:19
Save cwchentw/f8954546450cc3082ebc9c0b5cfcaf9a to your computer and use it in GitHub Desktop.
Math Matrix in Pure Lua (Apache 2.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- Matrix: a dense two-dimensional matrix of numbers.
--
-- Positions are addressed as (c, r): column first, then row, both 1-based.
-- Elements live in the flat array `_mtx` in row-major order, so element
-- (c, r) is stored at index (r - 1) * _col + c.
--
-- The operators + - * / work element-wise, either between two matrices of
-- identical shape or between a matrix and a number (on either side).
-- `*` is the element-wise (Hadamard) product; use Matrix.dot for the
-- matrix product.
local Matrix = {}
package.loaded["Matrix"] = Matrix
Matrix.__index = Matrix

-- Private: true when `x` was created by Matrix:new (directly or via fromData).
local function is_matrix(x)
  return getmetatable(x) == Matrix
end

-- Private: raise an error unless `a` and `b` have the same shape.
local function assert_same_shape(a, b)
  assert(a:col() == b:col() and a:row() == b:row(),
    "matrix dimensions do not match")
end

-- Private: flat `_mtx` index of element (c, r). Out-of-range positions raise
-- an error (reported at the caller of at/setAt) instead of silently reading
-- or writing a slot that belongs to a neighbouring row.
local function flat_index(m, c, r)
  if c < 1 or c > m._col or r < 1 or r > m._row then
    error(string.format(
      "matrix index (col %s, row %s) out of range for %s columns x %s rows",
      c, r, m._col, m._row), 3)
  end
  return (r - 1) * m._col + c
end

-- Private: new matrix shaped like `m` with `fn` applied to every element.
-- Iterates the flat storage directly; both matrices share the same layout.
local function map(m, fn)
  local out = Matrix:new(m:col(), m:row())
  local src, dst = m._mtx, out._mtx
  for k = 1, m:col() * m:row() do
    dst[k] = fn(src[k])
  end
  return out
end

-- Private: new matrix whose elements are fn(a[k], b[k]) for two matrices of
-- identical shape (so identical flat layout).
local function zip(a, b, fn)
  assert_same_shape(a, b)
  local out = Matrix:new(a:col(), a:row())
  local x, y, dst = a._mtx, b._mtx, out._mtx
  for k = 1, a:col() * a:row() do
    dst[k] = fn(x[k], y[k])
  end
  return out
end

-- Private: build a binary metamethod that applies `op` element-wise.
-- It handles number-op-matrix, matrix-op-number and matrix-op-matrix, and
-- always keeps the operands in their original order (matters for - and /).
local function elementwise(op)
  return function (a, b)
    if type(a) == "number" then
      return map(b, function (x) return op(a, x) end)
    end
    if type(b) == "number" then
      return map(a, function (x) return op(x, b) end)
    end
    return zip(a, b, op)
  end
end

--- Structural equality: true when both operands are matrices of the same
-- shape whose corresponding elements compare equal with `==`.
-- Mismatched shapes or non-matrix operands yield false rather than an error.
Matrix.__eq = function (a, b)
  if not (is_matrix(a) and is_matrix(b)) then
    return false
  end
  if a:col() ~= b:col() or a:row() ~= b:row() then
    return false
  end
  local x, y = a._mtx, b._mtx
  for k = 1, a:col() * a:row() do
    if x[k] ~= y[k] then
      return false
    end
  end
  return true
end

--- Element-wise addition (matrix + matrix, number + matrix, matrix + number).
Matrix.__add = elementwise(function (x, y) return x + y end)

--- Element-wise subtraction (matrix - matrix, number - matrix, matrix - number).
Matrix.__sub = elementwise(function (x, y) return x - y end)

--- Element-wise (Hadamard) product, or scaling by a number on either side.
Matrix.__mul = elementwise(function (x, y) return x * y end)

--- Element-wise division (matrix / matrix, number / matrix, matrix / number).
-- Previously the scalar forms called an undefined helper and always failed.
Matrix.__div = elementwise(function (x, y) return x / y end)

--- Matrix product a x b.
-- @param a matrix with `a:col()` columns
-- @param b matrix with `b:row()` rows; requires a:col() == b:row()
-- @return new matrix with b:col() columns and a:row() rows
Matrix.dot = function (a, b)
  assert(a:col() == b:row(), "matrix dimensions do not match for dot product")
  -- Matrix:new takes (columns, rows): the product has b's column count and
  -- a's row count.
  local out = Matrix:new(b:col(), a:row())
  for i = 1, a:row() do
    for j = 1, b:col() do
      local sum = 0
      for k = 1, a:col() do
        sum = sum + a:at(k, i) * b:at(j, k)
      end
      out:setAt(j, i, sum)
    end
  end
  return out
end

--- Create a zero-filled matrix.
-- @param c number of columns
-- @param r number of rows
-- @return new matrix (metatable Matrix)
function Matrix:new(c, r)
  local obj = { _col = c, _row = r, _mtx = {} }
  for k = 1, c * r do
    obj._mtx[k] = 0
  end
  return setmetatable(obj, Matrix)
end

--- Build a matrix from a list of rows, e.g. { {1, 2, 3}, {4, 5, 6} }.
-- An empty list yields a 0 x 0 matrix. Every row must have the same length
-- as the first one.
-- @param data array of row arrays
-- @return new matrix with #data[1] columns and #data rows
function Matrix:fromData(data)
  local r = #data
  local c = r > 0 and #data[1] or 0
  local m = self:new(c, r)
  for j = 1, r do
    local row = data[j]
    assert(#row == c, "fromData: all rows must have the same length")
    for i = 1, c do
      m:setAt(i, j, row[i])
    end
  end
  return m
end

--- Number of columns.
function Matrix:col()
  return self._col
end

--- Number of rows.
function Matrix:row()
  return self._row
end

--- Element at column `c`, row `r` (1-based). Errors when out of range.
function Matrix:at(c, r)
  return self._mtx[flat_index(self, c, r)]
end

--- Store `value` at column `c`, row `r` (1-based). Errors when out of range.
function Matrix:setAt(c, r, value)
  self._mtx[flat_index(self, c, r)] = value
end
-- Module export: this table is the value `require` hands back to callers.
return Matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Assertion-based checks for the matrix module. Running this script without
-- errors means every check passed.
local matrix = require("matrix")

-- Shorthand: build a matrix from a list of rows.
local function rows(data)
  return matrix:fromData(data)
end

-- fromData keeps row order; at(col, row) reads each element back.
do
  local grid = { {1, 2, 3}, {4, 5, 6} }
  local built = rows({ {1, 2, 3}, {4, 5, 6} })
  for r = 1, 2 do
    for c = 1, 3 do
      assert(built:at(c, r) == grid[r][c])
    end
  end
end

-- Equality compares element by element.
do
  local left = rows({ {1, 2, 3}, {2, 3, 4} })
  local twin = rows({ {1, 2, 3}, {2, 3, 4} })
  local other = rows({ {1, 2, 3}, {4, 5, 6} })
  assert(left == twin)
  assert(left ~= other)
end

-- matrix + number
do
  local shifted = rows({ {1, 2, 3}, {2, 3, 4} }) + 3
  assert(shifted == rows({ {4, 5, 6}, {5, 6, 7} }))
end

-- matrix + matrix
do
  local total = rows({ {1, 2, 3}, {2, 3, 4} }) + rows({ {3, 2, 1}, {4, 3, 2} })
  assert(total == rows({ {4, 4, 4}, {6, 6, 6} }))
end

-- number - matrix
do
  local flipped = 3 - rows({ {1, 2, 3}, {2, 3, 4} })
  assert(flipped == rows({ {2, 1, 0}, {1, 0, -1} }))
end

-- matrix - number
do
  local lowered = rows({ {1, 2, 3}, {2, 3, 4} }) - 3
  assert(lowered == rows({ {-2, -1, 0}, {-1, 0, 1} }))
end

-- matrix - matrix
do
  local delta = rows({ {1, 2, 3}, {2, 3, 4} }) - rows({ {1, 2, 3}, {4, 5, 6} })
  assert(delta == rows({ {0, 0, 0}, {-2, -2, -2} }))
end

-- matrix * number
do
  local scaled = rows({ {1, 2, 3}, {2, 3, 4} }) * 3
  assert(scaled == rows({ {3, 6, 9}, {6, 9, 12} }))
end

-- matrix * matrix (element-wise)
do
  local product = rows({ {1, 2, 3}, {2, 3, 4} }) * rows({ {1, 2, 3}, {4, 5, 6} })
  assert(product == rows({ {1, 4, 9}, {8, 15, 24} }))
end

-- matrix / matrix (element-wise)
do
  local ratio = rows({ {1, 2, 3}, {2, 3, 4} }) / rows({ {2, 3, 4}, {3, 4, 5} })
  assert(ratio == rows({ {1/2, 2/3, 3/4}, {2/3, 3/4, 4/5} }))
end

-- matrix product via dot
do
  local lhs = rows({ {1, 2, 3}, {4, 5, 6} })
  local rhs = rows({ {1, 2}, {3, 4}, {5, 6} })
  assert(matrix.dot(lhs, rhs) == rows({ {22, 28}, {49, 64} }))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.