Last active
July 28, 2018 03:01
-
-
Save MikuAuahDark/9229c39b54946d0355f381e0aed9495c to your computer and use it in GitHub Desktop.
Lua GLSL-like vector & matrix (with vector swizzle mask)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- luagvect: GLSL-style vectors (vec2/vec3/vec4 with swizzling), matrices and
-- the common GLSL built-in functions, written in plain Lua.
local math = require("math")

-- Module table; every public constructor and function is attached to it.
local vec = {}
------------------
-- Vector class --
------------------
-- A vector is a plain sequence of numbers ({x, y, z, w}) whose metatable is
-- `swizzle_mt`. Components live only at integer keys 1..n, so every named
-- access (`v.x`, `v.rgb`, `v.zx = ...`) is routed through __index/__newindex.
local swizzle_mt = {mask_map = {
	x = 1, y = 2, z = 3, w = 4,
	r = 1, g = 2, b = 3, a = 4,
	s = 1, t = 2, p = 3, q = 4
}}

-- Map one swizzle letter to a component index of `v`. Raises
-- "Invalid swizzle mask" for unknown letters and for letters naming a
-- component that `v` does not have (e.g. `.z` on a vec2).
local function swizzle_index(v, char)
	local idx = swizzle_mt.mask_map[char]
	assert(idx and idx <= #v, "Invalid swizzle mask")
	return idx
end

-- Read access. `v.x` returns a number and `v.zyx` returns a new vector. As in
-- GLSL, a swizzle may repeat letters and be up to 4 long (`v.xxxx` works on a
-- vec2). Non-string keys (e.g. the index one past the end) yield nil like a
-- plain table. This keeps `ipairs` working on Lua 5.3+, where ipairs honours
-- __index.
function swizzle_mt.__index(v, var)
	if type(var) ~= "string" then
		return nil
	end
	assert(#var > 0 and #var <= 4, "Invalid index")
	if #var == 1 then
		return v[swizzle_index(v, var)]
	end
	local ret_vec = {}
	for i = 1, #var do
		ret_vec[i] = v[swizzle_index(v, var:sub(i, i))]
	end
	return (setmetatable(ret_vec, swizzle_mt))
end

-- Write access. `v.x = 1` stores a number. `v.zx = vec2(a, b)` stores a into z
-- and b into x: the i-th component of the value goes to the component named
-- by the i-th mask letter. rawset is used so that a write can never append new
-- keys or recurse back into this metamethod.
function swizzle_mt.__newindex(v, var, val)
	assert(type(var) == "string" and #var > 0 and #var <= #v, "Invalid index")
	if #var == 1 then
		-- float
		assert(type(val) == "number", "Invalid value")
		rawset(v, swizzle_index(v, var), val)
	else
		-- vector
		assert(type(val) == "table" and #val == #var, "Invalid vector")
		for i = 1, #var do
			rawset(v, swizzle_index(v, var:sub(i, i)), val[i])
		end
	end
end

-- Component i of an arithmetic operand. A number stands for itself in every
-- component (GLSL scalar/vector broadcasting).
local function operand_component(x, i)
	if type(x) == "number" then
		return x
	end
	return x[i]
end

-- Apply the number operator `op` component-wise. Accepts two vectors of equal
-- length, or a vector and a number in either order. Returns a new vector.
local function componentwise(veca, vecb, op)
	local n
	if type(veca) == "number" then
		n = #vecb
	elseif type(vecb) == "number" then
		n = #veca
	else
		assert(#veca == #vecb, "Invalid vector")
		n = #veca
	end
	local vecr = {}
	for i = 1, n do
		vecr[i] = op(operand_component(veca, i), operand_component(vecb, i))
	end
	return (setmetatable(vecr, swizzle_mt))
end

local function op_add(x, y) return x + y end
local function op_sub(x, y) return x - y end
local function op_mul(x, y) return x * y end
local function op_div(x, y) return x / y end

-- vec + vec, vec + number, number + vec
function swizzle_mt.__add(veca, vecb)
	return componentwise(veca, vecb, op_add)
end

function swizzle_mt.__unm(veca)
	local vecr = {}
	for i = 1, #veca do
		vecr[i] = -veca[i]
	end
	return (setmetatable(vecr, swizzle_mt))
end

-- vec - vec, vec - number, number - vec
function swizzle_mt.__sub(veca, vecb)
	return componentwise(veca, vecb, op_sub)
end

-- Component-wise product (vec * vec) or scaling (vec * number, number * vec)
function swizzle_mt.__mul(veca, b)
	return componentwise(veca, b, op_mul)
end

-- Component-wise quotient: vec / vec, vec / number, number / vec
function swizzle_mt.__div(veca, b)
	return componentwise(veca, b, op_div)
end

-- Render as e.g. "vec3(1.000000, 2.000000, 3.000000)"; an empty vector
-- renders as "vec0()".
function swizzle_mt.__tostring(v)
	local parts = {}
	for i = 1, #v do
		parts[i] = string.format("%f", v[i])
	end
	return "vec"..tostring(#v).."("..table.concat(parts, ", ")..")"
end
-- Build the GLSL-style constructor for an n-component vector. Arguments:
--   vecN(s)          a single number is copied into all n components
--   vecN(a, v, ...)  numbers and vectors are flattened in order; components
--                    beyond n are dropped, so vec3(someVec4) truncates and
--                    vec4(someVec4) copies
-- Raises "Invalid type" for an argument that is neither a number nor a
-- vector, and "Invalid vector" when fewer than n components are supplied.
local function make_vecn(n)
	return function(...)
		local argc = select("#", ...)
		local vecs = {}
		if argc == 1 and type((...)) == "number" then
			-- scalar broadcast
			local s = (...)
			for i = 1, n do
				vecs[i] = s
			end
			return (setmetatable(vecs, swizzle_mt))
		end
		local count = 0
		for i = 1, argc do
			local arg = select(i, ...)
			if type(arg) == "number" then
				count = count + 1
				vecs[count] = arg
			else
				assert(getmetatable(arg) == swizzle_mt, "Invalid type")
				for k = 1, #arg do
					count = count + 1
					vecs[count] = arg[k]
				end
			end
		end
		assert(count >= n, "Invalid vector")
		-- Drop surplus components (from the end) so exactly n remain.
		for i = count, n + 1, -1 do
			vecs[i] = nil
		end
		return (setmetatable(vecs, swizzle_mt))
	end
end
vec.vec2 = make_vecn(2)
vec.vec3 = make_vecn(3)
vec.vec4 = make_vecn(4)
------------------
-- Matrix class --
------------------
-- A matrix is a sequence of vectors (each carrying `swizzle_mt`). The outer
-- table carries `matrix_mt`. The product in __mul treats m[i] as row i.
local matrix_mt = {}

-- Element (i, j) of an operand; a number stands for itself everywhere.
local function matrix_element(x, i, j)
	if type(x) == "number" then
		return x
	end
	return x[i][j]
end

-- Apply `op` element by element. Accepts two matrices of the same shape, or a
-- matrix and a number in either order (the number is broadcast). Returns a
-- new matrix whose rows are vectors.
local function matrix_componentwise(ma, mb, op)
	local shape = ma
	if type(ma) == "number" then
		shape = mb
	elseif type(mb) ~= "number" then
		assert(getmetatable(ma) == getmetatable(mb), "Invalid type")
		assert(#ma == #mb, "Invalid matrix")
		for i = 1, #ma do
			assert(#ma[i] == #mb[i], "Invalid matrix")
		end
	end
	local mr = {}
	for i = 1, #shape do
		local row = {}
		for j = 1, #shape[i] do
			row[j] = op(matrix_element(ma, i, j), matrix_element(mb, i, j))
		end
		mr[i] = setmetatable(row, swizzle_mt)
	end
	return (setmetatable(mr, matrix_mt))
end

local function m_add(x, y) return x + y end
local function m_sub(x, y) return x - y end
local function m_div(x, y) return x / y end

-- mat + mat, mat + number, number + mat
function matrix_mt.__add(ma, mb)
	return matrix_componentwise(ma, mb, m_add)
end

function matrix_mt.__unm(ma)
	local mr = {}
	for i = 1, #ma do
		local a = {}
		for j = 1, #ma[i] do
			a[j] = -ma[i][j]
		end
		mr[i] = setmetatable(a, swizzle_mt)
	end
	return (setmetatable(mr, matrix_mt))
end

-- mat - mat, mat - number, number - mat
function matrix_mt.__sub(ma, mb)
	return matrix_componentwise(ma, mb, m_sub)
end

-- Multiplication:
--   number * matrix, matrix * number  scale every element
--   matrix * vector                   vector used as a single column; the
--                                     result is returned as a vector
--   matrix * matrix                   r[i][j] = sum_k a[i][k] * b[k][j]
-- A product with exactly one column is returned as a vector.
function matrix_mt.__mul(a, mb)
	if type(mb) == "number" then
		-- matrix * scalar: move the scalar to the left, scaling commutes
		a, mb = mb, a
	end
	if type(a) == "number" then
		return matrix_componentwise(a, mb, function(s, x) return x * s end)
	end
	if getmetatable(mb) == swizzle_mt then
		-- matrix with vector: treat the vector as an n x 1 matrix
		local column = {}
		for i = 1, #mb do
			column[i] = {mb[i]}
		end
		mb = column
	end
	assert(#a[1] == #mb, "Invalid matrix")
	local mr = {}
	for i = 1, #a do
		local row = {}
		for j = 1, #mb[1] do
			local sum = 0
			for k = 1, #mb do
				sum = sum + a[i][k] * mb[k][j]
			end
			row[j] = sum
		end
		mr[i] = setmetatable(row, swizzle_mt)
	end
	if #mr[1] == 1 then
		-- Single-column result: change to vector
		local v = {}
		for i = 1, #mr do
			v[i] = mr[i][1]
		end
		return (setmetatable(v, swizzle_mt))
	end
	return (setmetatable(mr, matrix_mt))
end

-- Component-wise divide: mat / mat, mat / number, number / mat
function matrix_mt.__div(ma, mb)
	return matrix_componentwise(ma, mb, m_div)
end

-- Render as "{vecN(...), vecN(...), ...}", one entry per stored vector.
function matrix_mt.__tostring(m)
	local parts = {}
	for i = 1, #m do
		parts[i] = tostring(m[i])
	end
	return "{"..table.concat(parts, ", ").."}"
end
-- FIXME: the matrix code has not been fully checked against GLSL. In
-- particular, GLSL constructors fill a matrix column by column, while here
-- each consecutive run of n values becomes one stored vector m[i].
--
-- Build the constructor for a matrix made of `m` stored vectors of `n`
-- components each. Accepted arguments:
--   ()           all elements 0
--   (s)          number s on the diagonal (m[i][i]), 0 elsewhere
--   (mat)        elements present in `mat` are copied; the rest come from the
--                identity matrix
--   (x, v, ...)  numbers and vectors flattened in order into the m*n
--                elements; missing elements are 0, surplus values are ignored
-- Raises "Invalid type" for any other kind of argument.
local function make_mat(m, n)
	return function(...)
		local argc = select("#", ...)
		local first = (...)
		local element -- element(i, j) -> value stored at mr[i][j]
		if argc <= 1 and (first == nil or type(first) == "number") then
			local diag = first or 0
			element = function(i, j)
				if i == j then
					return diag
				end
				return 0
			end
		elseif argc == 1 and getmetatable(first) == matrix_mt then
			element = function(i, j)
				local src = first[i]
				if src ~= nil and j <= #src then
					return src[j]
				end
				if i == j then
					return 1
				end
				return 0
			end
		else
			local vals = {}
			for i = 1, argc do
				local arg = select(i, ...)
				if type(arg) == "number" then
					vals[#vals + 1] = arg
				else
					assert(getmetatable(arg) == swizzle_mt, "Invalid type")
					for k = 1, #arg do
						vals[#vals + 1] = arg[k]
					end
				end
			end
			element = function(i, j)
				return vals[(i - 1) * n + j] or 0
			end
		end
		local mr = {}
		for i = 1, m do
			local v = {}
			for j = 1, n do
				v[j] = element(i, j)
			end
			mr[i] = setmetatable(v, swizzle_mt)
		end
		return (setmetatable(mr, matrix_mt))
	end
end
vec.mat2 = make_mat(2, 2)
vec.mat3 = make_mat(3, 3)
vec.mat4 = make_mat(4, 4)
vec.mat2x2 = vec.mat2
vec.mat3x3 = vec.mat3
vec.mat4x4 = vec.mat4
vec.mat2x3 = make_mat(2, 3)
vec.mat2x4 = make_mat(2, 4)
vec.mat3x2 = make_mat(3, 2)
vec.mat4x2 = make_mat(4, 2)
vec.mat3x4 = make_mat(3, 4)
vec.mat4x3 = make_mat(4, 3)
-------------------------------------
-- Component-wise vector functions --
-------------------------------------
-- Lift a one-argument number function `func` so it also accepts vectors.
-- A number is passed straight through to `func`. Any other argument is
-- treated as a vector: `func` is applied to each component and the results
-- are returned as a new vector.
local function make_component_wise_func(func)
	return function(x)
		if type(x) ~= "number" then
			local out = {}
			for idx = 1, #x do
				out[idx] = func(x[idx])
			end
			return (setmetatable(out, swizzle_mt))
		end
		return func(x)
	end
end
-- Lift a two-argument number function `func` so it also accepts vectors:
--   (number, number)  -> func(a, b)
--   (vector, vector)  -> func applied component by component; both vectors
--                        must have the same length
--   (vector, number)  -> the number is used for every component, as GLSL
--                        allows for e.g. mod(vec3, float)
-- Raises "Invalid type" for (number, vector) or otherwise mismatched argument
-- types, and "Invalid vector" when two vectors differ in length.
local function make_component_wise_func2(func)
	return function(veca, vecb)
		local b_is_number = type(vecb) == "number"
		if type(veca) == "number" then
			assert(b_is_number, "Invalid type")
			return func(veca, vecb)
		end
		if not b_is_number then
			assert(type(veca) == type(vecb), "Invalid type")
			assert(#veca == #vecb, "Invalid vector")
		end
		local vecr = {}
		for i = 1, #veca do
			local b = vecb
			if not b_is_number then
				b = vecb[i]
			end
			vecr[i] = func(veca[i], b)
		end
		return (setmetatable(vecr, swizzle_mt))
	end
end
-- GLSL built-in functions, from http://www.shaderific.com/glsl-functions/
-- math.atan2 and math.pow were deprecated in Lua 5.3 and are only present when
-- the interpreter is built with compatibility options. Fall back to the
-- two-argument math.atan (5.3+) and the ^ operator so these stay callable.
local atan2 = math.atan2 or function(y, x) return math.atan(y, x) end
local pow = math.pow or function(x, y) return x ^ y end

vec.radians = make_component_wise_func(math.rad)
vec.degrees = make_component_wise_func(math.deg)
vec.sin = make_component_wise_func(math.sin)
vec.cos = make_component_wise_func(math.cos)
vec.tan = make_component_wise_func(math.tan)
vec.asin = make_component_wise_func(math.asin)
vec.acos = make_component_wise_func(math.acos)

local atan_one = make_component_wise_func(math.atan)
local atan_two = make_component_wise_func2(atan2)
vec.atan2 = atan_two
-- GLSL atan() is overloaded: atan(y_over_x) and atan(y, x).
function vec.atan(y, x)
	if x == nil then
		return atan_one(y)
	end
	return atan_two(y, x)
end

vec.pow = make_component_wise_func2(pow)
vec.exp = make_component_wise_func(math.exp)
vec.log = make_component_wise_func(math.log)
vec.exp2 = make_component_wise_func(function(x) return 2^x end)
vec.log2 = make_component_wise_func(function(x) return math.log(x) / math.log(2) end)
vec.sqrt = make_component_wise_func(math.sqrt)
vec.inversesqrt = make_component_wise_func(function(x) return 1 / math.sqrt(x) end)
vec.abs = make_component_wise_func(math.abs)
vec.sign = make_component_wise_func(function(x) return x > 0 and 1 or (x < 0 and -1 or 0) end)
vec.floor = make_component_wise_func(math.floor)
vec.ceil = make_component_wise_func(math.ceil)
vec.fract = make_component_wise_func(function(x) return x - math.floor(x) end)
-- Lua's % on numbers is x - floor(x/y)*y, which matches GLSL mod().
vec.mod = make_component_wise_func2(function(x, y) return x % y end)
-- GLSL min(): the smaller of x and y. Two numbers are compared directly. For
-- a vector x, y is either a vector of the same length or a number that is
-- compared against every component.
function vec.min(x, y)
	if type(x) == "number" then
		return math.min(x, y)
	end
	local other = y
	if type(y) == "number" then
		other = {y, y, y, y}
	else
		assert(#x == #y, "Invalid vector")
	end
	local result = {}
	for i = 1, #x do
		result[i] = math.min(x[i], other[i])
	end
	return (setmetatable(result, swizzle_mt))
end
-- GLSL max(): the larger of x and y. Two numbers are compared directly. For
-- a vector x, y is either a vector of the same length or a number that is
-- compared against every component.
function vec.max(x, y)
	if type(x) == "number" then
		return math.max(x, y)
	end
	local other = y
	if type(y) == "number" then
		other = {y, y, y, y}
	else
		assert(#x == #y, "Invalid vector")
	end
	local result = {}
	for i = 1, #x do
		result[i] = math.max(x[i], other[i])
	end
	return (setmetatable(result, swizzle_mt))
end
-- GLSL clamp(): min(max(x, minVal), maxVal). For a number x the bounds go
-- straight to math.max/math.min. For a vector x, each bound may be a number
-- (shared by every component) or a vector of the same length. GLSL allows
-- both clamp(vec, float, float) and clamp(vec, vec, vec).
-- Raises "Invalid vector" when a vector bound has the wrong length.
function vec.clamp(x, min, max)
	if type(x) == "number" then
		return math.min(math.max(x, min), max)
	end
	local n = #x
	if type(min) ~= "number" then
		assert(#min == n, "Invalid vector")
	end
	if type(max) ~= "number" then
		assert(#max == n, "Invalid vector")
	end
	-- Component i of a bound: numbers broadcast, vectors are indexed.
	local function bound(b, i)
		if type(b) == "number" then
			return b
		end
		return b[i]
	end
	local vecr = {}
	for i = 1, n do
		vecr[i] = math.min(math.max(x[i], bound(min, i)), bound(max, i))
	end
	return (setmetatable(vecr, swizzle_mt))
end
-- GLSL mix(): linear interpolation x*(1-a) + y*a. Works on two numbers (a must
-- then be a number) or on two equal-length vectors. With vectors, a is either
-- a number used for every component or a vector of the same length.
function vec.mix(x, y, a)
	local both_numbers = type(x) == "number" and type(y) == "number"
	if both_numbers then
		-- Number lerp
		assert(type(a) == "number", "Invalid type")
		return x * (1 - a) + y * a
	end
	assert(#x == #y, "Invalid vector")
	local weights = a
	if type(a) == "number" then
		weights = {a, a, a, a}
	else
		assert(#x == #a, "Invalid vector")
	end
	local result = {}
	for i = 1, #x do
		local t = weights[i]
		result[i] = x[i] * (1 - t) + y[i] * t
	end
	return (setmetatable(result, swizzle_mt))
end
--------------------------------
-- Vector geometric functions --
--------------------------------
-- GLSL length(): Euclidean norm, the square root of the sum of the squared
-- components (accumulated in component order).
function vec.length(v)
	local sum_sq = 0
	for i = 1, #v do
		local c = v[i]
		sum_sq = sum_sq + c * c
	end
	return math.sqrt(sum_sq)
end
-- GLSL distance(): the length of the difference vector veca - vecb.
function vec.distance(veca, vecb)
	local delta = veca - vecb
	return vec.length(delta)
end
-- GLSL dot(): sum of the component-wise products of two equal-length vectors.
-- Raises "Invalid vector" when the lengths differ.
function vec.dot(veca, vecb)
	assert(#veca == #vecb, "Invalid vector")
	local n = #veca
	local acc = 0
	for i = 1, n do
		acc = acc + veca[i] * vecb[i]
	end
	return acc
end
-- GLSL cross(): cross product of two 3-component vectors.
-- Raises "Invalid vector (must be vec3)" for any other length.
function vec.cross(veca, vecb)
	assert(#veca == 3 and #vecb == 3, "Invalid vector (must be vec3)")
	local ax, ay, az = veca[1], veca[2], veca[3]
	local bx, by, bz = vecb[1], vecb[2], vecb[3]
	return (setmetatable({
		ay * bz - az * by,
		az * bx - ax * bz,
		ax * by - ay * bx,
	}, swizzle_mt))
end
-- GLSL normalize(): v scaled to unit length. A zero-length vector yields a
-- zero vector of the same size instead of dividing by zero.
function vec.normalize(v)
	local len = vec.length(v)
	if len ~= 0 then
		return v / len
	end
	return v * 0
end
-- GLSL reflect(): reflects incident vector I about N as I - 2*dot(N, I)*N.
-- NOTE(review): as in GLSL, N presumably should be normalized; not checked here.
function vec.reflect(I, N)
	local scale = 2 * vec.dot(N, I)
	return I - scale * N
end
------------------------------
-- Vector compare functions --
------------------------------
-- `unpack` is a global only in Lua 5.1/LuaJIT; Lua 5.2+ moved it to
-- table.unpack.
local unpack = table.unpack or unpack

-- Build a GLSL-style relational function (lessThan, equal, ...) from a
-- per-component comparison `cmpfunc`. GLSL returns a bvec; here the result for
-- each component is returned as a separate value, in component order.
-- Raises "Invalid vector" when the two vectors differ in length.
local function make_cmp_function(cmpfunc)
	return function(veca, vecb)
		assert(#veca == #vecb, "Invalid vector")
		local len = #veca
		local results = {}
		for i = 1, len do
			results[i] = cmpfunc(veca[i], vecb[i])
		end
		-- explicit bounds: results may hold `false` values
		return unpack(results, 1, len)
	end
end
-- GLSL relational functions; each returns one boolean per component.
do
	local relations = {
		lessThan = function(a, b) return a < b end,
		lessThanEqual = function(a, b) return a <= b end,
		greaterThan = function(a, b) return a > b end,
		greaterThanEqual = function(a, b) return a >= b end,
		equal = function(a, b) return a == b end,
		notEqual = function(a, b) return a ~= b end,
	}
	for name, cmp in pairs(relations) do
		vec[name] = make_cmp_function(cmp)
	end
end
-- Matrix functions
-- GLSL matrixCompMult(): element-by-element product of two matrices of the
-- same shape. This is not the linear-algebra product, which is `ma * mb`.
-- Raises "Invalid type" / "Invalid matrix" on mismatched arguments.
function vec.matrixCompMult(ma, mb)
	assert(getmetatable(ma) == getmetatable(mb), "Invalid type")
	assert(#ma == #mb, "Invalid matrix")
	local result = {}
	for row = 1, #ma do
		local ra, rb = ma[row], mb[row]
		assert(#ra == #rb, "Invalid matrix")
		local out = {}
		for col = 1, #ra do
			out[col] = ra[col] * rb[col]
		end
		result[row] = setmetatable(out, swizzle_mt)
	end
	return (setmetatable(result, matrix_mt))
end

return vec
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment