Skip to content

Instantly share code, notes, and snippets.

@cwchentw
Last active February 9, 2019 07:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cwchentw/f8954546450cc3082ebc9c0b5cfcaf9a to your computer and use it in GitHub Desktop.
Save cwchentw/f8954546450cc3082ebc9c0b5cfcaf9a to your computer and use it in GitHub Desktop.
Math Matrix in Pure Lua (Apache 2.0)
-- Matrix "class" table. Instances get Matrix as their metatable, and
-- because __index points back at Matrix, method lookups on an instance
-- fall through to the functions defined on this table.
local Matrix = {}
Matrix.__index = Matrix
-- Also register the table under the key "Matrix" in package.loaded, so a
-- later require("Matrix") returns this same table instead of searching
-- for a file.
package.loaded["Matrix"] = Matrix
--- Equality metamethod, used by `==` and `~=`.
-- Two matrices are equal when they have the same number of columns and
-- rows and every pair of corresponding elements compares equal with `==`.
-- Matrices of different shapes are reported as unequal rather than
-- raising an error, so `m1 ~= m2` is always safe to evaluate.
-- @tparam Matrix a
-- @tparam Matrix b
-- @treturn boolean
Matrix.__eq = function (a, b)
  -- A shape mismatch is simply "not equal"; asserting here would make a
  -- plain comparison throw.
  if a:col() ~= b:col() or a:row() ~= b:row() then
    return false
  end
  for i = 1, a:col() do
    for j = 1, a:row() do
      if a:at(i, j) ~= b:at(i, j) then
        return false
      end
    end
  end
  return true
end
--- Addition metamethod.
-- Handles three cases:
-- * scalar + matrix and matrix + scalar add the scalar to every element;
-- * matrix + matrix adds element-wise, and both operands must have the
--   same number of columns and rows.
-- A new matrix is always returned; neither operand is modified.
Matrix.__add = function (a, b)
  -- Allocate a cols x rows matrix and fill cell (c, r) with cell(c, r).
  local function build(cols, rows, cell)
    local result = Matrix:new(cols, rows)
    for c = 1, cols do
      for r = 1, rows do
        result:setAt(c, r, cell(c, r))
      end
    end
    return result
  end

  if type(a) == "number" then
    return build(b:col(), b:row(), function (c, r)
      return a + b:at(c, r)
    end)
  end
  if type(b) == "number" then
    return build(a:col(), a:row(), function (c, r)
      return b + a:at(c, r)
    end)
  end

  assert(a:col() == b:col())
  assert(a:row() == b:row())
  return build(a:col(), a:row(), function (c, r)
    return a:at(c, r) + b:at(c, r)
  end)
end
--- Subtraction metamethod.
-- Handles three cases:
-- * scalar - matrix subtracts every element from the scalar;
-- * matrix - scalar subtracts the scalar from every element;
-- * matrix - matrix subtracts element-wise, and both operands must have
--   the same number of columns and rows.
-- A new matrix is always returned; neither operand is modified.
Matrix.__sub = function (a, b)
  -- Allocate a cols x rows matrix and fill cell (c, r) with cell(c, r).
  local function build(cols, rows, cell)
    local result = Matrix:new(cols, rows)
    for c = 1, cols do
      for r = 1, rows do
        result:setAt(c, r, cell(c, r))
      end
    end
    return result
  end

  -- Subtraction is not commutative, so each scalar position keeps its
  -- own operand order.
  if type(a) == "number" then
    return build(b:col(), b:row(), function (c, r)
      return a - b:at(c, r)
    end)
  end
  if type(b) == "number" then
    return build(a:col(), a:row(), function (c, r)
      return a:at(c, r) - b
    end)
  end

  assert(a:col() == b:col())
  assert(a:row() == b:row())
  return build(a:col(), a:row(), function (c, r)
    return a:at(c, r) - b:at(c, r)
  end)
end
--- Multiplication metamethod (element-wise; see Matrix.dot for the
-- matrix product).
-- Handles three cases:
-- * scalar * matrix and matrix * scalar scale every element;
-- * matrix * matrix multiplies corresponding elements, and both operands
--   must have the same number of columns and rows.
-- A new matrix is always returned; neither operand is modified.
Matrix.__mul = function (a, b)
  -- Allocate a cols x rows matrix whose cell (c, r) is produce(c, r).
  local function tabulate(cols, rows, produce)
    local product = Matrix:new(cols, rows)
    for c = 1, cols do
      for r = 1, rows do
        product:setAt(c, r, produce(c, r))
      end
    end
    return product
  end

  if type(a) == "number" then
    return tabulate(b:col(), b:row(), function (c, r)
      return a * b:at(c, r)
    end)
  end
  if type(b) == "number" then
    return tabulate(a:col(), a:row(), function (c, r)
      return b * a:at(c, r)
    end)
  end

  assert(a:col() == b:col())
  assert(a:row() == b:row())
  return tabulate(a:col(), a:row(), function (c, r)
    return a:at(c, r) * b:at(c, r)
  end)
end
--- Division metamethod (element-wise).
-- Handles three cases:
-- * scalar / matrix divides the scalar by every element;
-- * matrix / scalar divides every element by the scalar;
-- * matrix / matrix divides corresponding elements, and both operands
--   must have the same number of columns and rows.
-- A new matrix is always returned; neither operand is modified.
Matrix.__div = function (a, b)
  -- New matrix shaped like m whose cells are s / m[c][r].
  local _scalar_div_first = function (s, m)
    local out = Matrix:new(m:col(), m:row())
    for i = 1, m:col() do
      for j = 1, m:row() do
        out:setAt(i, j, s / m:at(i, j))
      end
    end
    return out
  end
  -- New matrix shaped like m whose cells are m[c][r] / s.
  local _scalar_div_second = function (m, s)
    local out = Matrix:new(m:col(), m:row())
    for i = 1, m:col() do
      for j = 1, m:row() do
        out:setAt(i, j, m:at(i, j) / s)
      end
    end
    return out
  end
  -- Division is not commutative, so each scalar position needs its own
  -- helper. The scalar-multiplication helper lives inside __mul and is not
  -- visible here.
  if type(a) == "number" then
    return _scalar_div_first(a, b)
  end
  if type(b) == "number" then
    return _scalar_div_second(a, b)
  end
  assert(a:col() == b:col())
  assert(a:row() == b:row())
  local out = Matrix:new(a:col(), a:row())
  for i = 1, a:col() do
    for j = 1, a:row() do
      out:setAt(i, j, a:at(i, j) / b:at(i, j))
    end
  end
  return out
end
--- Matrix product a . b.
-- `a` must have as many columns as `b` has rows. The result has a:row()
-- rows and b:col() columns; cell (col j, row i) is the sum over k of
-- a(col k, row i) * b(col j, row k).
-- @tparam Matrix a left operand
-- @tparam Matrix b right operand
-- @treturn Matrix
Matrix.dot = function (a, b)
  assert(a:col() == b:row(), "dot: a:col() must equal b:row()")
  -- Matrix:new takes (columns, rows): the product is b:col() columns wide
  -- and a:row() rows tall. Swapping these only goes unnoticed when the
  -- result is square.
  local out = Matrix:new(b:col(), a:row())
  for i = 1, a:row() do
    for j = 1, b:col() do
      local sum = 0
      for k = 1, a:col() do
        sum = sum + a:at(k, i) * b:at(j, k)
      end
      out:setAt(j, i, sum)
    end
  end
  return out
end
--- Create a matrix with c columns and r rows, every element set to 0.
-- Elements live in the flat array `_mtx` (c * r slots); at/setAt index it
-- row by row. The metatable is always Matrix, whatever the receiver is.
-- @tparam number c number of columns
-- @tparam number r number of rows
-- @treturn Matrix
function Matrix:new(c, r)
  local cells = {}
  for slot = 1, c * r do
    cells[slot] = 0
  end
  return setmetatable({ _col = c, _row = r, _mtx = cells }, Matrix)
end
--- Build a matrix from a list of rows.
-- data[j][i] becomes the element at column i, row j, so
-- {{1, 2, 3}, {4, 5, 6}} yields a matrix with 2 rows and 3 columns.
-- An empty `data` table yields a 0x0 matrix. Every row must have the
-- same length as the first one; a ragged row raises an error instead of
-- silently dropping values or leaving nil cells.
-- @tparam table data sequence of equal-length row sequences
-- @treturn Matrix
function Matrix:fromData(data)
  local r = #data
  -- #data[1] is always a number, so the and/or idiom is safe here.
  local c = r > 0 and #data[1] or 0
  local m = self:new(c, r)
  for j = 1, r do
    local line = data[j]
    assert(#line == c, "fromData: all rows must have the same length")
    for i = 1, c do
      m:setAt(i, j, line[i])
    end
  end
  return m
end
--- Number of columns in the matrix.
-- @treturn number
function Matrix.col(self)
  local columns = self._col
  return columns
end
--- Number of rows in the matrix.
-- @treturn number
function Matrix.row(self)
  local rows = self._row
  return rows
end
--- Read the element at column c, row r (both 1-based).
-- Storage is row by row: row r occupies slots (r-1)*col+1 .. r*col of
-- `_mtx`. There is no bounds check, so an out-of-range column silently
-- addresses a cell in a neighbouring row, or yields nil past the end.
-- @tparam number c column index
-- @tparam number r row index
-- @return the stored element
function Matrix:at(c, r)
  local slot = (r - 1) * self:col() + c
  return self._mtx[slot]
end
--- Store `value` at column c, row r (both 1-based).
-- Uses the same row-by-row slot layout as Matrix:at and, like it, performs
-- no bounds check.
-- @tparam number c column index
-- @tparam number r row index
-- @param value element to store
function Matrix:setAt(c, r, value)
  local slot = (r - 1) * self:col() + c
  self._mtx[slot] = value
end

-- The module's return value is what require() hands back to callers.
return Matrix
-- Test script for the matrix module. Each do ... end block is an
-- independent case with its own fixtures; the first failing assert aborts
-- the run.
local Mx = require("matrix")

-- Shorthand for Mx:fromData(list_of_rows).
local function rows(list)
  return Mx:fromData(list)
end

-- fromData places data[row][col] at at(col, row).
do
  local expected = {
    {1, 2, 3},
    {4, 5, 6}
  }
  local grid = rows(expected)
  for r = 1, 2 do
    for c = 1, 3 do
      assert(grid:at(c, r) == expected[r][c])
    end
  end
end

-- == compares element by element; ~= is its negation.
do
  local a = rows({ {1, 2, 3}, {2, 3, 4} })
  local b = rows({ {1, 2, 3}, {2, 3, 4} })
  local c = rows({ {1, 2, 3}, {4, 5, 6} })
  assert(a == b)
  assert(a ~= c)
end

-- matrix + scalar
do
  local sum = rows({ {1, 2, 3}, {2, 3, 4} }) + 3
  assert(sum == rows({ {4, 5, 6}, {5, 6, 7} }))
end

-- matrix + matrix
do
  local lhs = rows({ {1, 2, 3}, {2, 3, 4} })
  local rhs = rows({ {3, 2, 1}, {4, 3, 2} })
  assert(lhs + rhs == rows({ {4, 4, 4}, {6, 6, 6} }))
end

-- scalar - matrix
do
  local diff = 3 - rows({ {1, 2, 3}, {2, 3, 4} })
  assert(diff == rows({ {2, 1, 0}, {1, 0, -1} }))
end

-- matrix - scalar
do
  local diff = rows({ {1, 2, 3}, {2, 3, 4} }) - 3
  assert(diff == rows({ {-2, -1, 0}, {-1, 0, 1} }))
end

-- matrix - matrix
do
  local lhs = rows({ {1, 2, 3}, {2, 3, 4} })
  local rhs = rows({ {1, 2, 3}, {4, 5, 6} })
  assert(lhs - rhs == rows({ {0, 0, 0}, {-2, -2, -2} }))
end

-- matrix * scalar
do
  local scaled = rows({ {1, 2, 3}, {2, 3, 4} }) * 3
  assert(scaled == rows({ {3, 6, 9}, {6, 9, 12} }))
end

-- matrix * matrix (element-wise)
do
  local lhs = rows({ {1, 2, 3}, {2, 3, 4} })
  local rhs = rows({ {1, 2, 3}, {4, 5, 6} })
  assert(lhs * rhs == rows({ {1, 4, 9}, {8, 15, 24} }))
end

-- matrix / matrix (element-wise)
do
  local lhs = rows({ {1, 2, 3}, {2, 3, 4} })
  local rhs = rows({ {2, 3, 4}, {3, 4, 5} })
  assert(lhs / rhs == rows({ {1/2, 2/3, 3/4}, {2/3, 3/4, 4/5} }))
end

-- dot: (2 rows x 3 cols) . (3 rows x 2 cols) gives a 2 x 2 product
do
  local lhs = rows({ {1, 2, 3}, {4, 5, 6} })
  local rhs = rows({ {1, 2}, {3, 4}, {5, 6} })
  assert(lhs:dot(rhs) == rows({ {22, 28}, {49, 64} }))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment