Skip to content

Instantly share code, notes, and snippets.

@demotomohiro
Created August 11, 2022 14:00
Show Gist options
  • Save demotomohiro/704e6110919cf163de18ab36b8bbef4c to your computer and use it in GitHub Desktop.
Save demotomohiro/704e6110919cf163de18ab36b8bbef4c to your computer and use it in GitHub Desktop.
Fast approximated BigInt sqrt
import std/[math, options]
import bigints
func sqrt(a: BigInt): Option[BigInt] =
  ## Approximate integer square root of `a`.
  ##
  ## Returns `none(BigInt)` when `a` is negative. Otherwise `a` is split as
  ## `top * 2^dropped`, where `dropped` is even and `top < 2^62`, so `top`
  ## fits in an `int64`. Then `sqrt(a) ≈ sqrt(float(top)) * 2^(dropped div 2)`.
  ##
  ## The square root is computed in `float64` (53-bit mantissa) and then
  ## truncated. The result is therefore an approximation and is not always
  ## `floor(sqrt(a))`: for example, `sqrt(2^60 - 1)` returns `2^30` rather
  ## than `2^30 - 1`.
  if a < 0'bi:
    return none(BigInt)

  # Round the index of the highest set bit down to an even number, so the
  # power of two that gets shifted out has an exact square root.
  let evenLog2 = a.fastLog2 and (not 1)
  # Number of low bits discarded. It is even, and chosen so that `top` keeps
  # at most 62 bits.
  let dropped = max(evenLog2 - 60, 0)
  let halfDropped = dropped div 2
  # Apply up to 2^30 of the 2^halfDropped factor in floating point, before
  # truncating to int64, so fractional bits of sqrt(top) are not discarded.
  # Whatever remains is applied afterwards as an exact BigInt shift.
  let floatScale = min(halfDropped, 30)
  let bigScale = halfDropped - floatScale
  let top = (a shr dropped).toInt[:int64]().get
  let root = int64(sqrt(float(top)) * pow(2.0, float(floatScale)))
  return some(initBigInt(root) shl bigScale)
proc test =
  ## Self-check for the approximate BigInt `sqrt`.
  ##
  ## Each pair is `(input, expected value of sqrt(input).get)`. Several
  ## expectations are built with the same float formula the implementation
  ## relies on (e.g. `sqrt(8.0) * 16^8`). They therefore pin the approximate
  ## result, not necessarily the exact integer square root.

  # A negative input has no square root.
  doAssert sqrt(-400'bi).isNone

  let cases = [
    # Small inputs.
    (4'bi, 2'bi),
    (10'bi, 3'bi),
    (400'bi, 20'bi),
    (4000'bi, 63'bi),
    (12345'bi, 111'bi),
    (12345678987654321'bi, 111111111'bi),
    # 2^60 - 1: the float64 path yields 2^30, one above floor(sqrt).
    (0xfff_ffff_ffff_ffff'bi, 1073741824'bi),
    # Around 2^66 .. 2^68.
    (0x4_0000_0000_0000_0000'bi, 0x2_0000_0000'bi),
    (0x8_0000_0000_0000_0000'bi, (sqrt(8.0) * pow(16.0, 8)).int.initBigInt),
    (0x9_0000_0000_0000_0000'bi, 0x3_0000_0000.int.initBigInt),
    (0xc_0000_0000_0000_0000'bi,
     (2.0 * sqrt(3.0) * pow(16.0, 8)).int.initBigInt),
    # Around 2^80 .. 2^84.
    (0x1_0000_0000_0000_0000_0000'bi, 0x100_0000_0000'bi),
    (0x2_0000_0000_0000_0000_0000'bi,
     (sqrt(2.0) * pow(16.0, 10)).int.initBigInt),
    (0x4_0000_0000_0000_0000_0000'bi, 0x200_0000_0000'bi),
    (0x9_0000_0000_0000_0000_0000'bi, 0x300_0000_0000'bi),
    (0xc_0000_0000_0000_0000_0000'bi,
     (2.0 * sqrt(3.0) * pow(16.0, 10)).int.initBigInt),
    # Around 2^96 .. 2^100.
    (0x1_0000_0000_0000_0000_0000_0000'bi, 0x1_0000_0000_0000'bi),
    (0x2_0000_0000_0000_0000_0000_0000'bi,
     (sqrt(2.0) * pow(16.0, 12)).int.initBigInt),
    (0x4_0000_0000_0000_0000_0000_0000'bi, 0x2_0000_0000_0000'bi),
    (0x9_0000_0000_0000_0000_0000_0000'bi, 0x3_0000_0000_0000'bi),
    (0xc_0000_0000_0000_0000_0000_0000'bi,
     (2.0 * sqrt(3.0) * pow(16.0, 12)).int.initBigInt),
    # Perfect squares scaled by 2^1000.
    (pow(2'bi, 1000), pow(2'bi, 500)),
    (4'bi * pow(2'bi, 1000), pow(2'bi, 501)),
    (9'bi * pow(2'bi, 1000), 3'bi * pow(2'bi, 500)),
    (25'bi * pow(2'bi, 1000), 5'bi * pow(2'bi, 500)),
    # Perfect squares scaled by 2^10000 and neighbouring even powers.
    (pow(2'bi, 10000), pow(2'bi, 5000)),
    (4'bi * pow(2'bi, 10000), 2'bi * pow(2'bi, 5000)),
    (9'bi * pow(2'bi, 10000), 3'bi * pow(2'bi, 5000)),
    (16'bi * pow(2'bi, 10000), 4'bi * pow(2'bi, 5000)),
    (25'bi * pow(2'bi, 10000), 5'bi * pow(2'bi, 5000)),
    (12345678987654321'bi * pow(2'bi, 10000),
     111111111'bi * pow(2'bi, 5000)),
    (12345678987654321'bi * pow(2'bi, 10002),
     111111111'bi * pow(2'bi, 5001)),
    (12345678987654321'bi * pow(2'bi, 10004),
     111111111'bi * pow(2'bi, 5002)),
  ]

  for (input, expected) in cases:
    doAssert sqrt(input).get == expected
when isMainModule:
  # Run the self-check only when this file is compiled as the main program,
  # not as a side effect of another module importing it.
  test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment