Created
January 14, 2017 22:39
-
-
Save Egor-Skriptunoff/7eb078ccbc8923d7c3fcfe53ed341b50 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- long_unsigned_integers.lua | |
--------------------------------------------------------------------------------------------------------- | |
-- LONG ARITHMETIC (non-negative integers of arbitrary length) | |
--------------------------------------------------------------------------------------------------------- | |
-- Compatible with Lua 5.1, Lua 5.2, Lua 5.3, LuaJIT | |
-- THIS MODULE WAS NOT OPTIMIZED FOR VERY LONG NUMBERS (multiplication has quadratic time complexity) | |
-- Usage: | |
-- local L = require("long_unsigned_integers") | |
-- print( "13" * L("42") / 5 + L(314) ^ 7 ) | |
-- if L("12321") % 37 == L(0) then print("multiple of 37") end -- compare against L(0): "== 0" is always false, since Lua calls __eq only when both operands are tables
--------------------------------------------------------------------------------------------------------- | |
local sqrt, floor, min, type, setmetatable, getmetatable = math.sqrt, math.floor, math.min, type, setmetatable, getmetatable | |
-- Limb base: 2^24 (hex 0x1000000). It must be a power of 2; a product of two
-- limbs then stays below 2^48, well inside exact double precision.
local L_M = 0x1000000
-- Forward declarations: L_new (string branch) calls these before they are defined.
local L_mul_short, L_add_short
-- Normalizes a digit array in place and returns the same table.
-- Limbs are little-endian (dest[1] is least significant) in base L_M.
-- On entry a limb may hold any integer value (negative or >= L_M). On exit:
--   * every limb is in [0, L_M);
--   * a final carry is spilled into newly appended high-order limbs;
--   * high-order zero limbs are removed, so zero becomes {}.
-- Raises "Negative numbers are not allowed" if the represented value is < 0.
local function L_norm(dest)
  local size = #dest
  local carry = 0
  local limb
  -- `%` is floor-modulo, so every stored limb lands in [0, L_M).
  -- A negative overall value therefore surfaces as a negative final carry.
  for idx = 1, size do
    carry = carry + dest[idx]
    limb = carry % L_M
    dest[idx] = limb
    carry = (carry - limb) / L_M
  end
  if carry < 0 then
    error"Negative numbers are not allowed"
  end
  -- Append the remaining carry as new high-order limbs.
  while carry ~= 0 do
    limb = carry % L_M
    size = size + 1
    dest[size] = limb
    carry = (carry - limb) / L_M
  end
  -- Strip high-order zero limbs.
  while size > 0 and dest[size] == 0 do
    dest[size] = nil
    size = size - 1
  end
  return dest
end
-- Converts a digit array to a float using Horner's scheme from the most
-- significant limb down. Precision beyond a double's 53-bit mantissa is lost.
local function L_to_float(src)
  local acc = 0.0
  local idx = #src
  while idx >= 1 do
    acc = acc * L_M + src[idx]
    idx = idx - 1
  end
  return acc
end
-- Creates a new normalized digit array from:
--   * a number   -- must be a finite, non-negative integer;
--   * a table    -- another digit array, which is copied limb by limb;
--   * a string   -- every decimal digit character is used in order,
--                   all other characters are ignored.
-- Returns nil for any other argument type.
-- Raises "Only finite integer numbers are allowed" for NaN, +/-inf or
-- fractional numbers. Without this check, NaN/inf make L_norm loop forever,
-- and fractions would silently produce non-integer limbs.
-- Negative numbers are rejected by L_norm.
local function L_new(n)
  local t = type(n)
  if t == "number" then -- create from number
    -- n - n is NaN for NaN and +/-inf, and 0 for every finite number.
    if n - n ~= 0 or floor(n) ~= n then
      error"Only finite integer numbers are allowed"
    end
    return L_norm{n}
  elseif t == "table" then -- create from another long
    local result = {}
    for j = 1, #n do
      result[j] = n[j]
    end
    return result
  elseif t == "string" then -- create from string
    local result = {}
    for digit in n:gmatch"%d" do
      result = L_add_short(L_mul_short(result, 10), tonumber(digit))
    end
    return result
  end
end
-- Three-way comparison of two normalized digit arrays.
-- Returns -1 if src1 < src2, 0 if equal, 1 if src1 > src2.
-- A longer normalized array is always the larger value. Equal lengths are
-- compared limb by limb from the most significant end.
local function L_compare(src1, src2)
  local n1, n2 = #src1, #src2
  if n1 ~= n2 then
    if n1 < n2 then
      return -1
    end
    return 1
  end
  for j = n1, 1, -1 do
    local a, b = src1[j], src2[j]
    if a ~= b then
      if a < b then
        return -1
      end
      return 1
    end
  end
  return 0
end
-- Returns a new normalized array holding src1 + src2 * short_factor2.
-- short_factor2 defaults to 1; passing -1 performs subtraction, which
-- raises an error from L_norm if the result would be negative.
local function L_add(src1, src2, short_factor2)
  local factor = short_factor2 or 1
  local n1, n2 = #src1, #src2
  local longest = n1
  if n2 > longest then
    longest = n2
  end
  local dest = {}
  for j = 1, longest do
    local a, b = src1[j], src2[j]
    if a and b then
      dest[j] = a + b * factor
    elseif a then
      dest[j] = a
    else
      dest[j] = b * factor
    end
  end
  return L_norm(dest)
end
-- Returns a new normalized array holding src + short_value.
-- short_value may be negative; the result must not be. (Assigns the
-- forward-declared local L_add_short.)
function L_add_short(src, short_value)
  local dest = {}
  dest[1] = short_value + (src[1] or 0)
  local j = 2
  while j <= #src do
    dest[j] = src[j]
    j = j + 1
  end
  return L_norm(dest)
end
-- Returns a new normalized array holding src1 * src2 (schoolbook, O(n*m)).
--
-- Every limb is < 2^24, so a single limb product is < 2^48. A column can
-- absorb at most 16 such products before it risks reaching 2^53, the limit
-- of exact double arithmetic. Carries are therefore propagated every 16 rows.
--
-- The product always fits in #src1 + #src2 limbs. The working array is
-- pre-filled with zeros to that length, and carries are propagated in place
-- WITHOUT trimming high zero limbs.
-- The previous version called L_norm mid-loop, which trims. When the shorter
-- operand had many zero low limbs (e.g. squaring 2^500 inside L(2)^1000),
-- the whole partial product was zero and trimmed away. Later rows then wrote
-- past a gap, leaving nil holes in the array.
local function L_mul(src1, src2)
  if #src1 > #src2 then
    src1, src2 = src2, src1
  end
  local len = #src1 + #src2
  local result = {}
  for j = 1, len do
    result[j] = 0
  end
  for j1 = 1, #src1 do
    if j1 % 16 == 0 then
      -- Bring every column back into [0, L_M). The partial product is
      -- < L_M^len, so no carry is left over past column len.
      local carry = 0
      for j = 1, len do
        carry = carry + result[j]
        local r = carry % L_M
        result[j] = r
        carry = (carry - r) / L_M
      end
    end
    local f = src1[j1]
    for j2 = 1, #src2 do
      local k = j1 + j2 - 1
      result[k] = result[k] + f * src2[j2]
    end
  end
  return L_norm(result)
end
-- Returns a new normalized array holding src * short_value.
-- (Assigns the forward-declared local L_mul_short.)
function L_mul_short(src, short_value)
  local scaled = {}
  for j, limb in ipairs(src) do
    scaled[j] = limb * short_value
  end
  return L_norm(scaled)
end
-- Returns true when dividend is an exact multiple of short_divisor.
-- Precondition: 0 < short_divisor < L_M.
-- Only the running remainder is kept (always < short_divisor), so the value
-- rem * L_M + limb stays below 2^48.
local function L_is_divisible_short(dividend, short_divisor) -- returns true/false
  local rem = 0
  local j = #dividend
  while j > 0 do
    rem = (rem * L_M + dividend[j]) % short_divisor
    j = j - 1
  end
  return rem == 0
end
-- Divides a long value by a short one.
-- Returns the normalized long quotient and the short remainder.
-- Precondition: 0 < short_divisor < L_M.
local function L_div_short(dividend, short_divisor) -- returns long quotient, short remainder
  local quotient, remainder = {}, 0
  for j = #dividend, 1, -1 do
    local current = remainder * L_M + dividend[j]
    local limb_rem = current % short_divisor
    quotient[j] = (current - limb_rem) / short_divisor
    remainder = limb_rem
  end
  return L_norm(quotient), remainder
end
-- Long division of two normalized digit arrays.
-- Returns the long quotient and the long remainder (0 <= remainder < divisor).
-- Raises "Division by zero" when divisor is zero. Without this check the
-- float estimate below becomes inf or NaN and the limb-extraction loops
-- never terminate.
local function L_div(dividend, divisor) -- returns long quotient, long remainder
  if not divisor[1] then -- a normalized zero is the empty array
    error"Division by zero"
  end
  local quotient = {}
  local remainder = L_new(dividend)
  local float_divisor = L_to_float(divisor)
  local q = L_to_float(remainder)/float_divisor
  -- Phase 1: while the float estimate of remainder/divisor is large, peel off
  -- a chunk shrunk by the factor (1 - 2^-44). The shrink keeps the chunk
  -- below the true quotient despite float rounding. Subtract
  -- chunk * divisor from the remainder and re-estimate.
  while q > 2^45 do
    q = floor(q * (1 - 2^(-44)))
    local delta_quotient = {}
    -- Very large estimates: emit zero low limbs and keep only the top bits.
    -- Truncating only lowers the chunk, so it stays an underestimate.
    while q > 2^50 do
      delta_quotient[#delta_quotient + 1] = 0
      q = floor(q / L_M)
    end
    while q ~= 0 do
      delta_quotient[#delta_quotient + 1] = q % L_M
      q = floor(q / L_M)
    end
    quotient = L_add(quotient, delta_quotient)
    remainder = L_add(remainder, L_mul(delta_quotient, divisor), -1)
    q = L_to_float(remainder)/float_divisor
  end
  -- Phase 2 (q <= 2^45): take a deliberate overestimate of the final chunk,
  -- then step it down until chunk * divisor <= remainder.
  local delta_quotient = L_new(floor(q + 1.5))
  quotient = L_add(quotient, delta_quotient)
  local delta_remainder = L_mul(delta_quotient, divisor)
  while L_compare(delta_remainder, remainder) == 1 do
    quotient = L_add_short(quotient, -1)
    delta_remainder = L_add(delta_remainder, divisor, -1)
  end
  remainder = L_add(remainder, delta_remainder, -1)
  return quotient, remainder
end
-- Returns true when the normalized long_value equals short_value.
-- Precondition: 0 <= short_value < L_M.
-- Any value with a second limb is >= L_M, so it cannot be equal.
local function L_equals_short(long_value, short_value)
  if long_value[2] then
    return false
  end
  local low = long_value[1] or 0
  return low == short_value
end
-- Modular inverse via the extended Euclidean algorithm.
-- Precondition: long_value < long_modulo.
-- Returns x with (long_value * x) % long_modulo == 1, or returns nothing
-- when long_value and long_modulo are not coprime.
-- The Bezout coefficients are kept as unsigned magnitudes. Their sign
-- alternates each step, tracked by `odd_steps`. A negative coefficient is
-- mapped into range as long_modulo - coefficient.
local function L_inv_mod(long_value, long_modulo)
  local prev_rem, cur_rem = long_modulo, long_value
  local prev_coef, cur_coef = L_new(0), L_new(1)
  local odd_steps = false
  -- loop while cur_rem > 1
  while cur_rem[2] ~= nil or (cur_rem[1] or 0) > 1 do
    local step_q, step_r = L_div(prev_rem, cur_rem)
    prev_coef, cur_coef = cur_coef, L_add(L_mul(step_q, cur_coef), prev_coef)
    prev_rem, cur_rem = cur_rem, step_r
    odd_steps = not odd_steps
  end
  if cur_rem[1] == nil then -- gcd ended at zero: no inverse
    return
  end
  if odd_steps then
    return L_add(long_modulo, cur_coef, -1)
  end
  return cur_coef
end
-- Raises long_base to the power long_exponent, optionally modulo long_modulo.
-- Preconditions: long_base < long_modulo when a modulus is given.
-- long_modulo == nil means plain exponentiation without reduction.
--
-- Left-to-right binary square-and-multiply: the exponent limbs are scanned
-- from the most significant end, and within each limb `mask` walks the
-- 24 bit positions from the top down.
-- For exponent 0 the result is 1. When a modulus is given, it is now also
-- reduced (so 1 mod 1 == 0); previously an unreduced 1 was returned.
local function L_power_mod(long_base, long_exponent, long_modulo)
  if not long_exponent[1] then -- long_exponent == 0
    local one = L_new(1)
    if long_modulo then
      local _, reduced = L_div(one, long_modulo)
      return reduced
    end
    return one
  end
  local result
  for j = #long_exponent, 1, -1 do
    local exp = long_exponent[j]
    local mask = L_M
    if not result then
      -- Top limb: start from its highest set bit, which the base itself
      -- accounts for (no squaring needed yet).
      result = long_base
      repeat
        mask = mask / 2
      until mask <= exp
      exp = exp - mask
    end
    -- For each remaining bit: square, then multiply by the base if it is set.
    while mask > 1 do
      result = L_mul(result, result)
      mask = mask / 2
      if exp >= mask then
        exp = exp - mask
        result = L_mul(result, long_base)
      end
      if long_modulo then
        local _, remainder = L_div(result, long_modulo)
        result = remainder
      end
    end
  end
  return result
end
-- Formats a digit array as a decimal string ("0" for zero).
-- Digits are produced least significant first by repeated division by 10,
-- then reversed.
local function L_to_string(src)
  local digits = {}
  local rest = L_new(src)
  local d
  while rest[1] ~= nil do
    rest, d = L_div_short(rest, 10)
    digits[#digits + 1] = ("0123456789"):sub(d + 1, d + 1)
  end
  if #digits == 0 then
    return "0"
  end
  return table.concat(digits):reverse()
end
------------------------------------------------------------------------------------ | |
-- Shared metatable for all long-number objects. An object is {digit_array}.
local mt = {}

-- Public constructor: builds a long-number object from a number, a decimal
-- string, or another long-number object (whose digits are copied).
-- Raises an error for any other datatype.
local function create_long_number(value)
  local tp = type(value)
  local acceptable = tp == "number" or tp == "string"
    or (tp == "table" and getmetatable(value) == mt)
  if not acceptable then
    error("Can't create long number from datatype '"..tp.."'")
  end
  local source = value
  if tp == "table" then
    source = value[1]
  end
  return setmetatable({L_new(source)}, mt)
end
-- Extracts a digit array from an operand of an arithmetic/comparison
-- metamethod. A long-number object yields its own array; a number or
-- string is converted with L_new. Anything else raises an error.
local function get_digits(L)
  local tp = type(L)
  if tp == "table" and getmetatable(L) == mt then
    return L[1]
  end
  if tp == "number" or tp == "string" then
    return L_new(L)
  end
  error("Can't perform an operation with long number and datatype '"..tp.."'")
end
-- Dispatch table: operation name -> function on two digit arrays.
-- Arithmetic entries return exactly one digit array. Comparison entries
-- return a boolean.
local operations = {}
operations.add = function(x, y) return L_add(x, y) end
operations.sub = function(x, y) return L_add(x, y, -1) end
operations.mul = function(x, y) return L_mul(x, y) end
operations.div = function(x, y)
  local quotient = L_div(x, y)
  return quotient
end
operations.mod = function(x, y)
  local _, remainder = L_div(x, y)
  return remainder
end
operations.pow = function(x, y) return L_power_mod(x, y) end
operations.eq = function(x, y) return L_compare(x, y) == 0 end
operations.lt = function(x, y) return L_compare(x, y) == -1 end
-- Applies an arithmetic operation to two operands and wraps the resulting
-- digit array in a new long-number object.
local function binary_operation(operation, L1, L2)
  local digits1 = get_digits(L1)
  local digits2 = get_digits(L2)
  local handler = operations[operation]
  return setmetatable({handler(digits1, digits2)}, mt)
end
-- Applies a comparison operation to two operands and returns its boolean.
local function comparison_operation(operation, L1, L2)
  local digits1 = get_digits(L1)
  local digits2 = get_digits(L2)
  local handler = operations[operation]
  return handler(digits1, digits2)
end
-- Wire the arithmetic metamethods: each __<name> forwards to the
-- `operations` entry of the same name.
local arithmetic_names = {"add", "sub", "mul", "div", "mod", "pow"}
for i = 1, #arithmetic_names do
  local name = arithmetic_names[i]
  mt["__" .. name] = function(L1, L2) return binary_operation(name, L1, L2) end
end
mt.__eq = function(L1, L2) return comparison_operation("eq", L1, L2) end
mt.__lt = function(L1, L2) return comparison_operation("lt", L1, L2) end
mt.__tostring = function(L) return L_to_string(L[1]) end
-- The module's value is the constructor itself: require(...)(value).
return create_long_number
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment