Skip to content

Instantly share code, notes, and snippets.

@Egor-Skriptunoff
Created January 14, 2017 22:39
Show Gist options
  • Save Egor-Skriptunoff/7eb078ccbc8923d7c3fcfe53ed341b50 to your computer and use it in GitHub Desktop.
Save Egor-Skriptunoff/7eb078ccbc8923d7c3fcfe53ed341b50 to your computer and use it in GitHub Desktop.
-- long_unsigned_integers.lua
---------------------------------------------------------------------------------------------------------
-- LONG ARITHMETIC (non-negative integers of arbitrary length)
---------------------------------------------------------------------------------------------------------
-- Compatible with Lua 5.1, Lua 5.2, Lua 5.3, LuaJIT
-- THIS MODULE WAS NOT OPTIMIZED FOR VERY LONG NUMBERS (multiplication has quadratic time complexity)
-- Usage:
-- local L = require("long_unsigned_integers")
-- print( "13" * L("42") / 5 + L(314) ^ 7 )
-- if L("12321") % 37 == L(0) then print("multiple of 37") end -- "== 0" will not work
---------------------------------------------------------------------------------------------------------
-- Library functions cached as locals (upvalue access is cheaper than global lookups).
local floor, min, sqrt = math.floor, math.min, math.sqrt
local type, setmetatable, getmetatable = type, setmetatable, getmetatable
-- Every long number is a little-endian array of limbs, each in [0, L_M).
-- L_M must be a power of 2 (L_power_mod walks exponent bits by halving it);
-- 2^24 keeps a product of two limbs (< 2^48) exactly representable in a double.
local L_M = 16777216 -- 2^24
-- Forward declarations: L_new refers to these before they are defined below.
local L_add_short, L_mul_short
-- Normalizes a raw limb array in place and returns that same table:
--   * carries/borrows are propagated so every limb ends up in [0, L_M);
--   * a leftover positive carry is appended as additional high limbs;
--   * zero limbs at the high end are removed (zero is the empty array).
-- Raises an error if the represented value is negative.
local function L_norm(dest)
   local carry = 0
   for j = 1, #dest do
      local sum = carry + dest[j]
      local limb = sum % L_M -- floored modulo: non-negative even for a negative sum
      dest[j] = limb
      carry = (sum - limb) / L_M -- exact division, limb was subtracted first
   end
   if carry < 0 then
      error"Negative numbers are not allowed"
   end
   while carry ~= 0 do
      local limb = carry % L_M
      dest[#dest + 1] = limb
      carry = (carry - limb) / L_M
   end
   local top = #dest
   while top > 0 and dest[top] == 0 do
      dest[top] = nil
      top = top - 1
   end
   return dest
end
-- Converts a limb array to a (possibly rounded) float by Horner's scheme,
-- starting from the most significant limb.
local function L_to_float(src)
   local acc = 0.0
   local j = #src
   while j > 0 do
      acc = acc * L_M + src[j]
      j = j - 1
   end
   return acc
end
-- Powers of ten for parsing decimal strings in groups of up to 7 digits
-- (10^7 < L_M, so each group is a valid short value for L_add_short/L_mul_short).
local L_POW10 = {10, 100, 1000, 10000, 100000, 1000000, 10000000}
-- Creates a new raw limb array (without metatable) from:
--   number - a finite integer (negative values are rejected by L_norm);
--   table  - another limb array, which is copied;
--   string - decimal digits; every non-digit character is skipped, so digit
--            separators such as "1 000 000" are accepted (note that "1.5" or
--            "1e3" are therefore read as 15 / 13). A leading minus sign is
--            rejected, consistently with negative numbers.
-- Returns nothing for any other type.
local function L_new(n)
   local t = type(n)
   if t == "number" then -- create from number
      -- n % 1 is NaN for +-inf/NaN and non-zero for fractions. Such values must
      -- not reach L_norm: inf/NaN keep its carry loop running forever, and a
      -- fraction would be stored as a non-integer limb.
      if n % 1 ~= 0 then
         error"Can't create long number from non-integer value"
      end
      return L_norm{n}
   elseif t == "table" then -- create from another long
      local result = {}
      for j = 1, #n do
         result[j] = n[j]
      end
      return result
   elseif t == "string" then -- create from string
      if n:find"^%s*%-" then
         error"Negative numbers are not allowed"
      end
      local digits = n:gsub("%D", "")
      local result = {}
      -- The first group takes (#digits - 1) % 7 + 1 digits so that every
      -- following group has exactly 7; one long multiply per group instead
      -- of one per digit.
      local pos, size = 1, (#digits - 1) % 7 + 1
      while pos <= #digits do
         local group = tonumber(digits:sub(pos, pos + size - 1))
         result = L_add_short(L_mul_short(result, L_POW10[size]), group)
         pos = pos + size
         size = 7
      end
      return result
   end
end
-- Three-way comparison of two normalized limb arrays.
-- Returns -1, 0 or 1 as sign(src1 - src2).
local function L_compare(src1, src2)
   local len1, len2 = #src1, #src2
   if len1 ~= len2 then
      -- normalized arrays: more limbs means a larger value
      return len1 < len2 and -1 or 1
   end
   for j = len1, 1, -1 do
      local a, b = src1[j], src2[j]
      if a ~= b then
         return a < b and -1 or 1
      end
   end
   return 0
end
-- Returns a new normalized array equal to src1 + src2 * short_factor2
-- (short_factor2 defaults to 1; pass -1 to subtract). Raises an error via
-- L_norm if the result would be negative.
local function L_add(src1, src2, short_factor2)
   local factor = short_factor2 or 1
   local dest = {}
   for j = 1, #src1 do
      dest[j] = src1[j]
   end
   for j = 1, #src2 do
      local scaled = src2[j] * factor
      local base = dest[j]
      if base then
         dest[j] = base + scaled
      else
         dest[j] = scaled -- src2 is longer than src1 here
      end
   end
   return L_norm(dest)
end
-- Adds a short (single-limb-sized, possibly negative) value to a long number.
-- Assigns the forward-declared local, so L_new sees this definition.
L_add_short = function(src, short_value)
   local low_limb = src[1] or 0
   local dest = {short_value + low_limb}
   for j = 2, #src do
      dest[#dest + 1] = src[j]
   end
   return L_norm(dest)
end
-- Multiplies two long numbers (schoolbook method, O(#src1 * #src2)).
-- A product of two limbs is < 2^48, so at most 16 of them may be accumulated
-- in one slot before the sum could reach 2^52; carries are therefore
-- propagated every 16 rows to keep all intermediate values exact in doubles.
local function L_mul(src1, src2)
   if #src1 > #src2 then
      src1, src2 = src2, src1 -- outer loop runs over the shorter factor
   end
   local len1, len2 = #src1, #src2
   local size = len1 + len2
   -- Pre-fill every slot with 0 so `result` stays a proper sequence. Trimming
   -- zero high limbs mid-computation (as L_norm does) empties the table when
   -- the low limbs of src1 are zero (e.g. squaring 2^500), leaving nil holes
   -- below the next row and making `#result` unreliable.
   local result = {}
   for j = 1, size do
      result[j] = 0
   end
   for j1 = 1, len1 do
      if j1 % 16 == 0 then
         -- Propagate carries without trimming. The partial product is smaller
         -- than L_M^size, so no carry is left over past the last slot.
         local carry = 0
         for j = 1, size do
            carry = carry + result[j]
            local r = carry % L_M
            result[j] = r
            carry = (carry - r) / L_M
         end
      end
      local f = src1[j1]
      if f ~= 0 then -- zero limbs contribute nothing
         for j2 = 1, len2 do
            local k = j1 + j2 - 1
            result[k] = result[k] + f * src2[j2]
         end
      end
   end
   return L_norm(result)
end
-- Multiplies a long number by a short factor and returns a new normalized array.
-- Assigns the forward-declared local, so L_new sees this definition.
L_mul_short = function(src, short_value)
   local product = {}
   local n = #src
   for j = 1, n do
      product[j] = short_value * src[j]
   end
   return L_norm(product)
end
-- Returns true if the long dividend is a multiple of short_divisor.
-- Requires 0 < short_divisor < L_M; the running remainder stays below
-- short_divisor, so remainder * L_M + limb is exact in a double.
local function L_is_divisible_short(dividend, short_divisor)
   local rest = 0
   local j = #dividend
   while j >= 1 do
      rest = (rest * L_M + dividend[j]) % short_divisor
      j = j - 1
   end
   return rest == 0
end
-- Long division by a short value: returns (long quotient, short remainder).
-- Requires 0 < short_divisor < L_M. Limbs are processed from the most
-- significant one down; the dividend is not modified.
local function L_div_short(dividend, short_divisor)
   local quotient, carry = {}, 0
   for j = #dividend, 1, -1 do
      local current = carry * L_M + dividend[j]
      carry = current % short_divisor
      quotient[j] = (current - carry) / short_divisor
   end
   return L_norm(quotient), carry
end
-- Returns f, e such that src is approximately f * L_M^e, using at most the
-- 4 most significant limbs. Unlike L_to_float, this never overflows to inf
-- however long src is (f < 2^96). The relative error is far below 2^-44,
-- the safety margin used by L_div.
local function L_top_to_float(src)
   local len = #src
   local low = len > 4 and len - 4 or 0
   local f = 0.0
   for j = len, low + 1, -1 do
      f = f * L_M + src[j]
   end
   return f, low
end
-- Estimates remainder / divisor (divisor ~= 0) as q * L_M^shift.
-- Guarantees: shift >= 0 and q <= 2^50; when shift > 0, also q > 2^26.
-- Scaling by L_M (a power of 2) is exact, so only the float division rounds.
local function L_estimate_quotient(remainder, divisor)
   local f_rem, e_rem = L_top_to_float(remainder)
   local f_div, e_div = L_top_to_float(divisor)
   local q, shift = f_rem / f_div, e_rem - e_div
   while q > 2^50 do
      q = q / L_M
      shift = shift + 1
   end
   while shift > 0 and q <= 2^26 do
      q = q * L_M
      shift = shift - 1
   end
   if shift < 0 then
      q = q * L_M^shift -- may underflow towards 0; the true quotient is tiny then
      shift = 0
   end
   return q, shift
end
-- Long division: returns (long quotient, long remainder).
-- Raises an error when divisor is zero; previously the float estimate became
-- inf/NaN and the loops below never terminated.
local function L_div(dividend, divisor)
   if not divisor[1] then
      error"Division by zero"
   end
   local quotient = {}
   local remainder = L_new(dividend)
   local q, shift = L_estimate_quotient(remainder, divisor)
   -- Phase 1: while the estimated quotient is large, subtract a guaranteed
   -- underestimate of it. The estimate is shrunk by 2^-44 to absorb rounding
   -- error and then floored, so the remainder never becomes negative.
   while shift > 0 or q > 2^45 do
      local delta_quotient = {}
      for j = 1, shift do
         delta_quotient[j] = 0
      end
      q = floor(q * (1 - 2^(-44)))
      while q ~= 0 do
         delta_quotient[#delta_quotient + 1] = q % L_M
         q = floor(q / L_M)
      end
      quotient = L_add(quotient, delta_quotient)
      remainder = L_add(remainder, L_mul(delta_quotient, divisor), -1)
      q, shift = L_estimate_quotient(remainder, divisor)
   end
   -- Phase 2: q <= 2^45, so the estimate is off by only a small fraction of 1.
   -- Overshoot by rounding up, then step back until delta * divisor <= remainder.
   local delta_quotient = L_new(floor(q + 1.5))
   quotient = L_add(quotient, delta_quotient)
   local delta_remainder = L_mul(delta_quotient, divisor)
   while L_compare(delta_remainder, remainder) == 1 do
      quotient = L_add_short(quotient, -1)
      delta_remainder = L_add(delta_remainder, divisor, -1)
   end
   remainder = L_add(remainder, delta_remainder, -1)
   return quotient, remainder
end
-- Returns true if the normalized long value equals short_value
-- (0 <= short_value < L_M). Zero is the empty array.
local function L_equals_short(long_value, short_value)
   if long_value[2] then
      return false -- two or more limbs: at least L_M
   end
   return (long_value[1] or 0) == short_value
end
-- Modular inverse of long_value modulo long_modulo (requires long_value < long_modulo).
-- Uses the extended Euclidean algorithm, tracking only the coefficient of
-- long_value; that coefficient's sign alternates each step, so it is kept
-- as a magnitude plus a parity flag.
-- Returns nothing when no inverse exists (gcd > 1, including long_value == 0).
local function L_inv_mod(long_value, long_modulo)
   local prev, cur = long_modulo, long_value
   local coef_prev, coef_cur = L_new(0), L_new(1)
   local negative = false
   while cur[2] ~= nil or (cur[1] or 0) > 1 do -- while cur > 1
      local quot, rest = L_div(prev, cur)
      coef_prev, coef_cur = coef_cur, L_add(L_mul(quot, coef_cur), coef_prev)
      prev, cur = cur, rest
      negative = not negative
   end
   if cur[1] == nil then -- cur == 0: gcd is not 1
      return
   end
   if negative then
      return L_add(long_modulo, coef_cur, -1)
   end
   return coef_cur
end
-- Raises long_base to the power long_exponent by left-to-right binary
-- exponentiation, scanning exponent bits from the most significant limb down.
-- If long_modulo is given, every intermediate result is reduced modulo it
-- (long_base is expected to be < long_modulo already); with long_modulo == nil
-- the plain power is returned. An exponent of 0 always yields 1.
-- Note: for an exponent of 1 the long_base table itself is returned.
local function L_power_mod(long_base, long_exponent, long_modulo)
   if long_exponent[1] == nil then -- exponent == 0
      return L_new(1)
   end
   local acc -- nil until the leading 1-bit of the exponent has been consumed
   for limb_index = #long_exponent, 1, -1 do
      local bits = long_exponent[limb_index]
      local bit_value = L_M
      if acc == nil then
         -- Locate the most significant set bit; acc = base accounts for it.
         acc = long_base
         repeat
            bit_value = bit_value / 2
         until bit_value <= bits
         bits = bits - bit_value
      end
      while bit_value > 1 do
         bit_value = bit_value / 2
         acc = L_mul(acc, acc)
         if bits >= bit_value then
            bits = bits - bit_value
            acc = L_mul(acc, long_base)
         end
         if long_modulo then
            local _, reduced = L_div(acc, long_modulo)
            acc = reduced
         end
      end
   end
   return acc
end
-- Converts a limb array to its decimal string representation ("0" for zero).
-- Each L_div_short call peels off 7 decimal digits at once (10^7 < L_M keeps
-- the divisor a valid short value), so 7x fewer passes over the number are
-- needed than when dividing by 10 for every digit. src is not modified.
local function L_to_string(src)
   local groups, tmp, group = {}, src
   while tmp[1] do -- while tmp ~= 0
      tmp, group = L_div_short(tmp, 10000000)
      groups[#groups + 1] = group
   end
   local count = #groups
   if count == 0 then
      return "0"
   end
   -- The most significant group is printed as-is; all lower groups are
   -- zero-padded to exactly 7 digits.
   local parts = {("%d"):format(groups[count])}
   for j = count - 1, 1, -1 do
      parts[#parts + 1] = ("%07d"):format(groups[j])
   end
   return table.concat(parts)
end
------------------------------------------------------------------------------------
-- Shared metatable of all long-number objects; each object wraps its limb array at index 1.
local mt = {}
-- Public constructor: builds a long number from a Lua number, a decimal
-- string, or another long number (whose limbs are copied).
local function create_long_number(value)
   local tp = type(value)
   local is_long = tp == "table" and getmetatable(value) == mt
   if tp ~= "number" and tp ~= "string" and not is_long then
      error("Can't create long number from datatype '"..tp.."'")
   end
   local source = value
   if is_long then
      source = value[1]
   end
   return setmetatable({L_new(source)}, mt)
end
-- Returns the raw limb array for an operand of a metamethod: the wrapped
-- array of a long number (not copied), or a fresh one for a number/string.
local function get_digits(L)
   local tp = type(L)
   if tp == "table" and getmetatable(L) == mt then
      return L[1]
   end
   if tp == "number" or tp == "string" then
      return L_new(L)
   end
   error("Can't perform an operation with long number and datatype '"..tp.."'")
end
-- Dispatch table: operation name -> function over raw limb arrays.
local operations = {}
function operations.add(a, b) return L_add(a, b) end
function operations.sub(a, b) return L_add(a, b, -1) end
function operations.mul(a, b) return L_mul(a, b) end
function operations.div(a, b)
   local quotient = L_div(a, b) -- keep only the quotient
   return quotient
end
function operations.mod(a, b)
   local _, remainder = L_div(a, b)
   return remainder
end
function operations.pow(a, b) return L_power_mod(a, b) end
function operations.eq(a, b) return L_compare(a, b) == 0 end
function operations.lt(a, b) return L_compare(a, b) < 0 end
-- Applies the named operation to two operands (long numbers, numbers or
-- decimal strings) and wraps the resulting limb array as a new long number.
local function binary_operation(operation, L1, L2)
   local left, right = get_digits(L1), get_digits(L2)
   local digits = operations[operation](left, right)
   return setmetatable({digits}, mt)
end
-- Like binary_operation, but returns the operation's (boolean) result unwrapped.
local function comparison_operation(operation, L1, L2)
   local left = get_digits(L1)
   local right = get_digits(L2)
   return operations[operation](left, right)
end
-- Metamethods: arithmetic results are new long numbers. Either operand may
-- also be a plain Lua number or a decimal string (see get_digits).
mt.__add = function(L1, L2) return binary_operation("add", L1, L2) end
mt.__sub = function(L1, L2) return binary_operation("sub", L1, L2) end
mt.__mul = function(L1, L2) return binary_operation("mul", L1, L2) end
mt.__div = function(L1, L2) return binary_operation("div", L1, L2) end
-- Lua 5.3+ floor division `//`: identical to `/` for non-negative integers.
mt.__idiv = function(L1, L2) return binary_operation("div", L1, L2) end
mt.__mod = function(L1, L2) return binary_operation("mod", L1, L2) end
mt.__pow = function(L1, L2) return binary_operation("pow", L1, L2) end
mt.__eq = function(L1, L2) return comparison_operation("eq", L1, L2) end
mt.__lt = function(L1, L2) return comparison_operation("lt", L1, L2) end
-- Lua 5.4 no longer emulates `a <= b` as `not (b < a)` when __le is missing,
-- so provide it explicitly (same result as that fallback in Lua 5.1-5.3).
mt.__le = function(L1, L2) return not comparison_operation("lt", L2, L1) end
mt.__tostring = function(L) return L_to_string(L[1]) end
return create_long_number
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment