Skip to content

Instantly share code, notes, and snippets.

@tkokof
Created January 4, 2022 13:15
Show Gist options
  • Save tkokof/deaaaa07846a01b2769e44c3161b0c83 to your computer and use it in GitHub Desktop.
Save tkokof/deaaaa07846a01b2769e44c3161b0c83 to your computer and use it in GitHub Desktop.
big int add & mul
--- big_int: addition and multiplication of arbitrarily large non-negative
-- integers represented as strings of decimal digits.
-- Maintainer: hugoyu
local big_int = {}

-- min_add_digit_num: digits per chunk in big_int.add. Two 9-digit chunks plus
-- a carry sum to less than 2 * 10^9, well inside the exactly representable
-- range of both doubles and 64-bit integers.
-- min_mul_digit_num: big_int.mul multiplies natively once the two operands
-- together have at most 2 * min_mul_digit_num digits (product below 10^10).
local min_add_digit_num, min_mul_digit_num = 9, 5
-- Length of the string form of `v`; for a non-negative integer or a string of
-- decimal digits this is its digit count.
local function digit_num(v)
  local s = tostring(v)
  return s:len()
end
--- Add two non-negative integers held as decimal-digit strings.
-- The operands are walked from the least-significant end in chunks of
-- `min_add_digit_num` digits; each chunk pair is added natively together with
-- the carry from the previous chunk. As soon as one operand is exhausted and
-- no carry is pending, the untouched prefix of the other operand is copied
-- through unchanged.
-- Negative numbers and non-digit characters are not supported.
-- @param a non-negative integer or string of decimal digits
-- @param b non-negative integer or string of decimal digits
-- @treturn string the decimal digits of a + b
function big_int.add(a, b)
  a = tostring(a)
  b = tostring(b)
  -- Chunks are appended least-significant first and reversed once at the end;
  -- prepending with table.insert(t, 1, v) would shift the whole buffer on
  -- every step and make long additions quadratic.
  local chunks = {}
  local a_start_index = -min_add_digit_num
  local a_end_index = -1
  local b_start_index = -min_add_digit_num
  local b_end_index = -1
  local last_carry = 0
  while true do
    -- Negative indices count from the end of the string; once a window has
    -- slid past the front, sub() yields "" and tonumber("") is nil.
    local sub_a_str = a:sub(a_start_index, a_end_index)
    local sub_b_str = b:sub(b_start_index, b_end_index)
    local sub_a = tonumber(sub_a_str)
    local sub_b = tonumber(sub_b_str)
    if last_carry == 0 then
      if not sub_a and not sub_b then
        break
      elseif not sub_a then
        -- a is exhausted: the rest of b is copied verbatim.
        chunks[#chunks + 1] = b:sub(1, b_end_index)
        break
      elseif not sub_b then
        -- b is exhausted: the rest of a is copied verbatim.
        chunks[#chunks + 1] = a:sub(1, a_end_index)
        break
      end
    end
    sub_a = sub_a or 0
    sub_b = sub_b or 0
    local sub_result = sub_a + sub_b + last_carry
    local sub_result_digit_num = digit_num(sub_result)
    local sub_max_digit_num = math.max(digit_num(sub_a_str), digit_num(sub_b_str))
    if sub_max_digit_num >= min_add_digit_num then
      -- Full-width chunk: the emitted text must be exactly chunk-wide.
      if sub_result_digit_num > sub_max_digit_num then
        -- The sum overflowed into an extra leading "1": drop it and carry.
        last_carry = 1
        chunks[#chunks + 1] = tostring(sub_result):sub(2)
      else
        -- Restore the leading zeros tonumber() discarded (none when the
        -- widths already match).
        last_carry = 0
        local padding = string.rep("0", sub_max_digit_num - sub_result_digit_num)
        chunks[#chunks + 1] = padding .. tostring(sub_result)
      end
    else
      -- Both windows are shorter than a chunk, so they reached the front of
      -- both operands: this is the most-significant part and is written in
      -- full, leaving no carry.
      last_carry = 0
      chunks[#chunks + 1] = tostring(sub_result)
    end
    a_end_index = a_start_index - 1
    a_start_index = a_start_index - min_add_digit_num
    b_end_index = b_start_index - 1
    b_start_index = b_start_index - min_add_digit_num
  end
  -- Reverse into most-significant-first order.
  local n = #chunks
  for i = 1, math.floor(n / 2) do
    chunks[i], chunks[n + 1 - i] = chunks[n + 1 - i], chunks[i]
  end
  return table.concat(chunks)
end
--- Add two non-negative integers held as decimal-digit strings, one digit at
-- a time (schoolbook addition from the least-significant digit).
-- Same contract as big_int.add; kept as a simpler reference implementation.
-- Once one operand is exhausted and no carry is pending, the remaining prefix
-- of the other operand is copied through unchanged.
-- Negative numbers and non-digit characters are not supported.
-- @param a non-negative integer or string of decimal digits
-- @param b non-negative integer or string of decimal digits
-- @treturn string the decimal digits of a + b
function big_int.add_2(a, b)
  a = tostring(a)
  b = tostring(b)
  -- Digits are appended least-significant first and reversed once at the end;
  -- prepending with table.insert(t, 1, v) for every digit would make the
  -- addition quadratic in the operand length.
  local digits = {}
  local cur_a_index = -1
  local cur_b_index = -1
  local last_carry = 0
  while true do
    -- Past the front of a string, sub() yields "" and tonumber("") is nil.
    local sub_a = tonumber(a:sub(cur_a_index, cur_a_index))
    local sub_b = tonumber(b:sub(cur_b_index, cur_b_index))
    if last_carry == 0 then
      if not sub_a and not sub_b then
        break
      elseif not sub_a then
        digits[#digits + 1] = b:sub(1, cur_b_index)
        break
      elseif not sub_b then
        digits[#digits + 1] = a:sub(1, cur_a_index)
        break
      end
    end
    -- At most 9 + 9 + 1 = 19, so a carry is always exactly 1.
    local sub_result = (sub_a or 0) + (sub_b or 0) + last_carry
    if sub_result >= 10 then
      last_carry = 1
      digits[#digits + 1] = tostring(sub_result):sub(2)
    else
      last_carry = 0
      digits[#digits + 1] = tostring(sub_result)
    end
    cur_a_index = cur_a_index - 1
    cur_b_index = cur_b_index - 1
  end
  -- Reverse into most-significant-first order.
  local n = #digits
  for i = 1, math.floor(n / 2) do
    digits[i], digits[n + 1 - i] = digits[n + 1 - i], digits[i]
  end
  return table.concat(digits)
end
--- Multiply two non-negative integers held as decimal-digit strings.
-- If the operands together have at most 2 * min_mul_digit_num digits they are
-- multiplied natively. Otherwise each operand is split into a high and a low
-- half,
--   a = ah * 10^a_l + al,   b = bh * 10^b_l + bl,
-- the four partial products (ah*bh, ah*bl, al*bh, al*bl) are computed
-- recursively, shifted by appending zeros, and summed with big_int.add.
-- Negative numbers and non-digit characters are not supported.
-- @param a non-negative integer or string of decimal digits
-- @param b non-negative integer or string of decimal digits
-- @treturn string the decimal digits of a * b, without leading zeros
function big_int.mul(a, b)
  a = tostring(a)
  b = tostring(b)
  local a_digit_num = digit_num(a)
  local b_digit_num = digit_num(b)
  if a_digit_num + b_digit_num <= 2 * min_mul_digit_num then
    -- tonumber("") is nil, so an empty operand (e.g. the low half of a
    -- one-digit number) counts as 0.
    return tostring((tonumber(a) or 0) * (tonumber(b) or 0))
  end
  local a_digit_num_h = math.ceil(a_digit_num / 2)
  local a_digit_num_l = a_digit_num - a_digit_num_h
  local b_digit_num_h = math.ceil(b_digit_num / 2)
  local b_digit_num_l = b_digit_num - b_digit_num_h
  local ah = a:sub(1, a_digit_num_h)
  local al = a:sub(a_digit_num_h + 1, -1)
  local bh = b:sub(1, b_digit_num_h)
  local bl = b:sub(b_digit_num_h + 1, -1)
  local ah_mul_bh = big_int.mul(ah, bh)
  local ah_mul_bl = big_int.mul(ah, bl)
  local al_mul_bh = big_int.mul(al, bh)
  local al_mul_bl = big_int.mul(al, bl)
  -- Multiplying by a power of ten is appending that many zeros.
  local result = ah_mul_bh .. string.rep("0", a_digit_num_l + b_digit_num_l)
  result = big_int.add(result, ah_mul_bl .. string.rep("0", a_digit_num_l))
  result = big_int.add(result, al_mul_bh .. string.rep("0", b_digit_num_l))
  result = big_int.add(result, al_mul_bl)
  -- A zero partial product still carries its shift zeros, and big_int.add
  -- keeps them as zero-padded chunks (mul(0, "123456789012345678") would
  -- otherwise yield "0000000000"). Strip them, keeping a lone "0".
  result = result:gsub("^0+", "")
  if result == "" then
    result = "0"
  end
  return result
end
--- Self-test and micro-benchmark for the module.
-- 1. Cross-checks big_int.add and big_int.add_2 against native addition on
--    random operands in [1, 2^60], printing every mismatch.
-- 2. Prints the CPU time (os.clock) spent by big_int.add, then by
--    big_int.add_2, on test_count random additions each (operand generation
--    is included in the timing).
-- 3. Cross-checks big_int.mul against native multiplication on random
--    operands in [1, 2^30], printing every mismatch.
-- NOTE(review): the native reference results assume Lua 5.3+ 64-bit
-- integers (sums up to 2^61 and products up to 2^60 must be exact and must
-- print without exponent notation); under Lua 5.1/LuaJIT doubles this check
-- is not meaningful.
-- @tparam[opt=100000] integer test_count random cases per phase
-- @treturn integer number of mismatches found (0 means every check passed)
function big_int.test(test_count)
  test_count = test_count or 100000
  local failures = 0

  for _ = 1, test_count do
    local a = math.random(1, 2 ^ 60)
    local b = math.random(1, 2 ^ 60)
    local a_add_b = big_int.add(a, b)
    local a_add_b_2 = big_int.add_2(a, b)
    local a_add_b_str = tostring(a + b)
    if a_add_b ~= a_add_b_str or a_add_b_2 ~= a_add_b_str then
      failures = failures + 1
      print(a .. " + " .. b)
      print(a_add_b)
      print(a_add_b_2)
      print(a_add_b_str)
    end
  end

  local t = os.clock()
  for _ = 1, test_count do
    local a = math.random(1, 2 ^ 60)
    local b = math.random(1, 2 ^ 60)
    big_int.add(a, b)
  end
  print(os.clock() - t)

  t = os.clock()
  for _ = 1, test_count do
    local a = math.random(1, 2 ^ 60)
    local b = math.random(1, 2 ^ 60)
    big_int.add_2(a, b)
  end
  print(os.clock() - t)

  for _ = 1, test_count do
    local a = math.random(1, 2 ^ 30)
    local b = math.random(1, 2 ^ 30)
    local a_mul_b = big_int.mul(a, b)
    local a_mul_b_str = tostring(a * b)
    if a_mul_b ~= a_mul_b_str then
      failures = failures + 1
      print(a .. " * " .. b)
      print(a_mul_b)
      print(a_mul_b_str)
    end
  end

  return failures
end
-- Module export: loading this file (e.g. via require) yields the big_int table.
return big_int
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment