Skip to content

Instantly share code, notes, and snippets.

@KivApple
Created January 3, 2018 10:50

Revisions

  1. KivApple created this gist Jan 3, 2018.
    502 changes: 502 additions & 0 deletions BigInteger.kt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,502 @@
    package com.eternal_search.bignum

    import kotlin.math.max

    class BigInteger(private val sign: Int, private val digits: IntArray, private var digitCount: Int,
    private var digitsOffset: Int = 0): Number(), Comparable<BigInteger> {
    override fun toLong(): Long {
    var result = 0L
    for (i in digitCount - 1 downTo 0) {
    result = result shl INTERNAL_DIGIT_WIDTH
    result = result or digits[i + digitsOffset].toLong()
    }
    return result * sign
    }

    override fun toInt(): Int {
    var result = 0
    for (i in digitCount - 1 downTo 0) {
    result = result shl INTERNAL_DIGIT_WIDTH
    result = result or digits[i + digitsOffset]
    }
    return result * sign
    }

    override fun toShort(): Short = toInt().toShort()

    override fun toByte(): Byte = toInt().toByte()

    override fun toChar(): Char = toInt().toChar()

    override fun toDouble(): Double = toLong().toDouble()

    override fun toFloat(): Float = toLong().toFloat()

    override fun toString(): String = toString(10)

    fun toString(radix: Int): String {
    if ((radix < 2) || (radix > 16))
    throw IllegalArgumentException("Radix must be between 2 and 16")
    val builder = StringBuilder()
    val tmp = BigInteger(1, IntArray(digitCount) { digits[it + digitsOffset] }, digitCount)
    do {
    val digit = tmp.fastDivRem(radix)
    builder.append(DIGITS[digit])
    } while ((tmp.digitCount > 1) || (tmp.digits[0] != 0))
    if (sign < 0) builder.append('-')
    return builder.reverse().toString()
    }

    operator fun unaryPlus() = this

    operator fun unaryMinus() = BigInteger(-sign, digits, digitCount)

    fun inv(): BigInteger {
    val newDigits = IntArray(digitCount)
    for (i in 0 until digitCount) {
    newDigits[i] = digits[i + digitsOffset].inv().and(INTERNAL_DIGIT_MASK)
    }
    return BigInteger(sign, newDigits, digitCount)
    }

    operator fun plus(other: BigInteger): BigInteger {
    var carry = 0
    if (sign == other.sign) {
    val newDigitCount = max(digitCount, other.digitCount) + 1
    val newDigits = IntArray(newDigitCount)
    for (i in 0 until newDigitCount - 1) {
    val a = if (i < digitCount) digits[i + digitsOffset] else 0
    val b = if (i < other.digitCount) other.digits[i + other.digitsOffset] else 0
    val c = a + b + carry
    newDigits[i] = if (c >= INTERNAL_RADIX) {
    carry = 1
    c and INTERNAL_DIGIT_MASK
    } else {
    carry = 0
    c
    }
    }
    newDigits[newDigitCount - 1] = carry
    return BigInteger(sign, newDigits, if (carry > 0) newDigitCount else newDigitCount - 1)
    } else {
    val newDigitCount = max(digitCount, other.digitCount)
    val newDigits = IntArray(newDigitCount)
    var lastNonZeroIndex = 0
    for (i in 0 until newDigitCount) {
    val a = if (i < digitCount) digits[i + digitsOffset] else 0
    val b = if (i < other.digitCount) other.digits[i + other.digitsOffset] else 0
    val c = a - b + carry
    val d = if (c < 0) {
    carry = -1
    INTERNAL_RADIX + c
    } else {
    carry = 0
    c
    }
    newDigits[i] = d
    if (d != 0) lastNonZeroIndex = i
    }
    if (carry < 0) {
    newDigits[0] = INTERNAL_RADIX - newDigits[0]
    for (i in 1..lastNonZeroIndex) {
    val d = INTERNAL_RADIX - newDigits[i] - 1
    newDigits[i] = d
    if (d != 0) lastNonZeroIndex = i
    }
    }
    return BigInteger(if (carry < 0) -1 else 1, newDigits, lastNonZeroIndex + 1)
    }
    }

    operator fun minus(other: BigInteger): BigInteger = this + -other

    private fun fastAdd(value: Int) {
    var c = digits[0 + digitsOffset] + value
    digits[0 + digitsOffset] = c and INTERNAL_DIGIT_MASK
    var carry = c ushr INTERNAL_DIGIT_WIDTH
    if (carry > 0) {
    var i = 1
    while ((carry != 0) && (i < (digits.size - digitsOffset))) {
    c = digits[i + digitsOffset] + carry
    digits[i + digitsOffset] = c and INTERNAL_DIGIT_MASK
    carry = c ushr INTERNAL_DIGIT_WIDTH
    i++
    }
    if (i >= digitCount) digitCount = i + 1
    }
    }

    private fun fastSub(value: BigInteger, dst: BigInteger) { // this >= value
    var carry = 0
    for (i in 0 until digitCount) {
    val a = digits[i + digitsOffset]
    val b = if (i < value.digitCount) value.digits[i + value.digitsOffset] else 0
    val c = a - b + carry
    carry = if (c < 0) {
    dst.digits[i + dst.digitsOffset] = INTERNAL_RADIX + c
    -1
    } else {
    dst.digits[i + dst.digitsOffset] = c
    0
    }
    }
    dst.digitCount = digitCount
    while ((dst.digitCount > 1) && (dst.digits[dst.digitCount - 1 + dst.digitsOffset] == 0)) dst.digitCount--
    }

    private fun fastMul(value: Int) {
    var carry = 0
    for (i in 0 until digitCount) {
    val c = digits[i + digitsOffset] * value + carry
    digits[i + digitsOffset] = c and INTERNAL_DIGIT_MASK
    carry = c ushr INTERNAL_DIGIT_WIDTH
    }
    if ((carry != 0) && (digitCount < (digits.size - digitsOffset))) {
    digits[digitCount++ + digitsOffset] = carry
    }
    }

    private fun fastMul(value: Int, dst: BigInteger) {
    dst.digitCount = digitCount
    var carry = 0
    for (i in 0 until digitCount) {
    val c = digits[i + digitsOffset] * value + carry
    dst.digits[i + dst.digitsOffset] = c and INTERNAL_DIGIT_MASK
    carry = c ushr INTERNAL_DIGIT_WIDTH
    }
    if ((carry != 0) && (dst.digitCount < (dst.digits.size - dst.digitsOffset))) {
    dst.digits[dst.digitCount++ + dst.digitsOffset] = carry
    }
    }

    operator fun times(other: BigInteger): BigInteger {
    if ((digitCount == 1) && (digits[0 + digitsOffset] == 0)) return ZERO
    if ((other.digitCount == 1) && (other.digits[0 + digitsOffset] == 0)) return ZERO
    val sign = this.sign * other.sign
    if ((digitCount == 1) && (digits[0 + digitsOffset] == 1))
    return if (sign >= 0) other else -other
    else if ((other.digitCount == 1) && (other.digits[0 + digitsOffset] == 1))
    return if (sign >= 0) this else -this
    val newDigits = IntArray(max(digitCount, other.digitCount) * 2 + 1)
    val tmp = IntArray(other.digitCount + 1)
    var newDigitCount = 1
    for (i in 0 until digitCount) {
    val a = digits[i + digitsOffset]
    if (a == 0) continue
    var carry = 0
    for (j in 0 until other.digitCount) {
    val b = other.digits[j + digitsOffset]
    val c = a * b + carry
    tmp[j] = if (c >= INTERNAL_RADIX) {
    carry = c ushr INTERNAL_DIGIT_WIDTH
    c and INTERNAL_DIGIT_MASK
    } else {
    carry = 0
    c
    }
    }
    tmp[other.digitCount] = carry
    val tmpCount = if (carry > 0) other.digitCount + 1 else other.digitCount
    carry = 0
    for (j in 0 until tmpCount) {
    val x = newDigits[j + i]
    val y = tmp[j]
    val z = x + y + carry
    newDigits[j + i] = if (z >= INTERNAL_RADIX) {
    carry = 1
    z and INTERNAL_DIGIT_MASK
    } else {
    carry = 0
    z
    }
    }
    newDigits[tmpCount + i] = carry
    newDigitCount = if (carry > 0) tmpCount + i + 1 else tmpCount + i
    }
    return BigInteger(sign, newDigits, newDigitCount)
    }

    private fun fastDivRem(value: Int): Int {
    var i = digitCount - 1
    var rem = 0L
    div@while (i >= 0) {
    var group = rem * INTERNAL_RADIX + digits[i-- + digitsOffset]
    while (group < value) {
    digits[i + 1 + digitsOffset] = 0
    if (i < 0) {
    rem = group
    break@div
    }
    group = group * INTERNAL_RADIX + digits[i-- + digitsOffset]
    }
    digits[i + 1 + digitsOffset] = (group / value).toInt()
    rem = group % value
    }
    val start = i + 1
    var stop = digitCount - 1
    while ((stop > start) && (digits[stop + digitsOffset] == 0)) stop--
    for (j in 0..stop - start) {
    digits[j + digitsOffset] = digits[j + start + digitsOffset]
    }
    for (j in stop - start until stop) {
    digits[j + digitsOffset] = 0
    }
    digitCount = stop + 1
    return rem.toInt()
    }

    fun divRem(other: BigInteger): Pair<BigInteger, BigInteger> {
    if (other.digitCount == 1)
    if (other.digits[0 + digitsOffset] == 0)
    throw IllegalArgumentException("Division by zero")
    else if (other.digits[0 + digitsOffset] == 1)
    return Pair(this, ZERO)
    if ((digitCount == 1) && (digits[0 + digitsOffset] == 0)) return Pair(ZERO, ZERO)
    val newDigits = IntArray(digitCount)
    var newDigitCount = 0
    val dividend = BigInteger(1, IntArray(digitCount * 2), 1, digitCount - 1)
    val temp = BigInteger(1, IntArray(digitCount + 1), 1, 0)
    var i = digitCount - 1
    div@while (i >= 0) {
    val digit = digits[i-- + digitsOffset]
    if ((digit == 0) && (dividend.digitCount == 1) && (dividend.digits[0 + dividend.digitsOffset] == 0)) {
    newDigitCount++
    continue
    }
    if ((dividend.digitCount == 1) && (dividend.digits[0 + dividend.digitsOffset] == 0)) {
    dividend.digits[0 + dividend.digitsOffset] = digit
    } else {
    dividend.digitsOffset--
    dividend.digits[0 + dividend.digitsOffset] = digit
    dividend.digitCount++
    }
    while (dividend < other) {
    if (i < 0) {
    break@div
    }
    dividend.digits[--dividend.digitsOffset] = digits[i-- + digitsOffset]
    dividend.digitCount++
    newDigitCount++
    }
    if ((dividend.digitCount == 1) && (other.digitCount == 1)) {
    newDigits[newDigitCount++] = dividend.digits[0 + dividend.digitsOffset] /
    other.digits[0 + other.digitsOffset]
    dividend.digits[digitCount - 1] = dividend.digits[0 + dividend.digitsOffset] %
    other.digits[0 + other.digitsOffset]
    dividend.digitsOffset = digitCount - 1
    dividend.digitCount = 1
    continue
    }
    var minD = 0
    var maxD = INTERNAL_RADIX
    do {
    val d = (minD + maxD) / 2
    other.fastMul(d, temp)
    val cmp = temp.compareTo(dividend)
    if (cmp == 0) {
    minD = d
    maxD = d
    } else if (cmp < 0) {
    if (minD == d) {
    maxD = d
    } else {
    minD = d
    }
    } else { // cmp > 0
    maxD = d
    }
    } while (minD != maxD)
    newDigits[newDigitCount++] = minD
    dividend.fastSub(temp, dividend)
    }
    newDigits.reverse()
    while ((newDigitCount > 0) && (newDigits[newDigitCount - 1] == 0)) newDigitCount--
    val sign = this.sign * other.sign
    return Pair(
    BigInteger(sign, newDigits, newDigitCount),
    BigInteger(sign, dividend.digits, dividend.digitCount, dividend.digitsOffset)
    )
    }

    operator fun div(other: BigInteger): BigInteger = divRem(other).first

    operator fun rem(other: BigInteger): BigInteger = divRem(other).second

    infix fun or(other: BigInteger): BigInteger {
    val newDigitCount = max(digitCount, other.digitCount)
    val newDigits = IntArray(newDigitCount)
    for (i in 0 until newDigitCount) {
    val a = if (i < digitCount) digits[i + digitsOffset] else 0
    val b = if (i < other.digitCount) other.digits[i + digitsOffset] else 0
    newDigits[i] = a or b
    }
    return BigInteger(1, newDigits, newDigitCount)
    }

    infix fun xor(other: BigInteger): BigInteger {
    val newDigitCount = max(digitCount, other.digitCount)
    val newDigits = IntArray(newDigitCount)
    for (i in 0 until newDigitCount) {
    val a = if (i < digitCount) digits[i + digitsOffset] else 0
    val b = if (i < other.digitCount) other.digits[i + digitsOffset] else 0
    newDigits[i] = a xor b
    }
    return BigInteger(1, newDigits, newDigitCount)
    }

    infix fun and(other: BigInteger): BigInteger {
    val newDigitCount = max(digitCount, other.digitCount)
    val newDigits = IntArray(newDigitCount)
    for (i in 0 until newDigitCount) {
    val a = if (i < digitCount) digits[i + digitsOffset] else 0
    val b = if (i < other.digitCount) other.digits[i + digitsOffset] else 0
    newDigits[i] = a and b
    }
    return BigInteger(1, newDigits, newDigitCount)
    }

    infix fun shl(shift: Int): BigInteger {
    if (shift < 0) return this shr -shift
    val fastShift = shift / INTERNAL_DIGIT_WIDTH
    val slowShift = shift % INTERNAL_DIGIT_WIDTH
    val newDigits = IntArray(digitCount + fastShift)
    for (i in 0 until digitCount) {
    newDigits[i + fastShift] = digits[i + digitsOffset]
    }
    return BigInteger(sign, newDigits, newDigits.size) * valueOf(1 shl slowShift)
    }

    infix fun shr(shift: Int): BigInteger {
    if (shift < 0) return this shl -shift
    val fastShift = shift / INTERNAL_DIGIT_WIDTH
    val slowShift = shift % INTERNAL_DIGIT_WIDTH
    val newDigits = IntArray(digitCount - fastShift)
    for (i in 0 until newDigits.size) {
    newDigits[i] = digits[i + fastShift + digitsOffset]
    }
    return BigInteger(sign, newDigits, newDigits.size) / valueOf(1 shl slowShift)
    }

    override fun compareTo(other: BigInteger): Int {
    if (sign == other.sign) {
    if (digitCount > other.digitCount) {
    return sign
    } else if (digitCount < other.digitCount) {
    return -sign
    } else {
    for (i in digitCount - 1 downTo 0) {
    val a = digits[i + digitsOffset]
    val b = other.digits[i + other.digitsOffset]
    if (a > b) {
    return sign
    } else if (a < b) {
    return -sign
    }
    }
    return 0
    }
    } else {
    return sign
    }
    }

    override fun equals(other: Any?): Boolean = (other is BigInteger) && compareTo(other) == 0

    override fun hashCode(): Int = digits.sliceArray(0 until digitCount).contentHashCode() * sign

    companion object {
    private val INTERNAL_DIGIT_WIDTH = 16
    private val INTERNAL_RADIX = 1 shl INTERNAL_DIGIT_WIDTH
    private val INTERNAL_DIGIT_MASK = INTERNAL_RADIX - 1

    private val BITS_PER_DIGIT = intArrayOf(0, 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4)
    private val DIGITS = "0123456789ABCDEF"

    val ZERO = valueOf(0)
    val ONE = valueOf(1)
    val TEN = valueOf(10)

    fun valueOf(value: Long): BigInteger {
    var n = value
    val sign = if (n < 0) { n = -n; -1 } else 1
    val digits = IntArray(((64 + (INTERNAL_DIGIT_WIDTH - 1))) / INTERNAL_DIGIT_WIDTH)
    var i = 0
    while (n > 0) {
    digits[i++] = n.toInt() and INTERNAL_DIGIT_MASK
    n = n ushr INTERNAL_DIGIT_WIDTH
    }
    return BigInteger(sign, digits, i)
    }

    fun valueOf(value: Int): BigInteger {
    var n = value
    val sign = if (n < 0) { n = -n; -1 } else 1
    val digits = IntArray((64 + (INTERNAL_DIGIT_WIDTH - 1)) / INTERNAL_DIGIT_WIDTH)
    var i = 0
    while (n > 0) {
    digits[i++] = n and INTERNAL_DIGIT_MASK
    n = n ushr INTERNAL_DIGIT_WIDTH
    }
    return BigInteger(sign, digits, i)
    }

    fun valueOf(value: Short): BigInteger {
    return if (value >= 0)
    BigInteger(1, intArrayOf(value.toInt()), 1)
    else
    BigInteger(-1, intArrayOf(-value.toInt()), 1)
    }

    fun valueOf(value: Byte): BigInteger {
    return if (value >= 0)
    BigInteger(1, intArrayOf(value.toInt()), 1)
    else
    BigInteger(-1, intArrayOf(-value.toInt()), 1)
    }

    fun valueOf(value: Number): BigInteger = when (value) {
    is BigInteger -> value
    is Long -> valueOf(value)
    is Int -> valueOf(value)
    is Short -> valueOf(value)
    is Byte -> valueOf(value)
    else -> valueOf(value.toLong())
    }

    fun valueOf(value: Char): BigInteger = valueOf(value.toInt())

    fun valueOf(value: String, radix: Int = 10): BigInteger {
    if ((radix < 2) || (radix > 16))
    throw IllegalArgumentException("Radix must be between 2 and 16")
    if (value.isEmpty())
    throw IllegalArgumentException("String is empty")
    var firstIndex = 0
    val sign = if (value[0] == '-') {
    firstIndex = 1
    -1
    } else {
    if (value[0] == '+') firstIndex++
    1
    }
    if (firstIndex >= value.length) throw IllegalArgumentException("String does not contain digits")
    while ((firstIndex < value.length) && (value[firstIndex] == '0')) firstIndex++
    if (firstIndex == value.length) return ZERO
    val bitCount = (value.length - firstIndex) * BITS_PER_DIGIT[radix]
    val result = BigInteger(
    sign, IntArray((bitCount + INTERNAL_DIGIT_WIDTH - 1) / INTERNAL_DIGIT_WIDTH), 1
    )
    for (i in firstIndex until value.length) {
    val digit = when (value[i]) {
    in '0'..'9' -> value[i] - '0'
    in 'A'..'F' -> value[i] - 'A' + 10
    in 'a'..'f' -> value[i] - 'a' + 10
    else -> throw IllegalArgumentException("Invalid character in the string")
    }
    if (digit >= radix) throw IllegalArgumentException("Invalid character in the string")
    result.fastMul(radix)
    result.fastAdd(digit)
    }
    return result
    }
    }
    }