Last active
November 5, 2019 17:12
-
-
Save miloyip/69663b78b26afa0dcc260382a6034b1a to your computer and use it in GitHub Desktop.
Big integer square root
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <type_traits>
#include <vector>
// http://www.embedded.com/electronics-blogs/programmer-s-toolbox/4219659/Integer-Square-Roots
// Integer square root by summing odd numbers, using (k+1)^2 = k^2 + (2k+1).
// Returns floor(sqrt(n)); the loop runs floor(sqrt(n)) + 1 times.
uint32_t isqrt0(uint32_t n) {
    // `square` tracks (root+1)^2. For n >= 65535^2 it reaches 65536^2 == 2^32,
    // which would wrap to 0 in 32 bits and keep the loop running forever,
    // so it is held in 64 bits.
    uint64_t square = 1;
    uint32_t delta = 3;  // next odd increment: 2*(root+1) + 1
    while (square <= n) {
        square += delta;
        delta += 2;
    }
    // On exit delta == 2*(root+1) + 1, hence root == delta/2 - 1.
    return delta / 2 - 1;
}
// Restoring bit-by-bit square root: consumes n two bits at a time starting
// from the most significant pair and appends one bit to the result per step.
// Returns floor(sqrt(n)).
uint32_t isqrt1(uint32_t n) {
    uint32_t root = 0;
    uint32_t rem = 0;
    for (int step = 16; step > 0; --step) {
        // Bring down the next two most significant bits of n.
        rem = (rem << 2) | (n >> 30);
        n <<= 2;
        root <<= 1;
        // Appending a 1 bit costs (2*root + 1) out of the remainder.
        const uint32_t trial = 2 * root + 1;
        if (rem >= trial) {
            rem -= trial;
            root += 1;
        }
    }
    return root;
}
// Bit-by-bit square root that keeps twice the partial root in `acc`, so the
// trial subtrahend (4*root + 1) is formed with one shift and setting the low
// bit. Returns floor(sqrt(n)).
uint32_t isqrt2(uint32_t n) {
    uint32_t acc = 0;  // 2 * (partial root)
    uint32_t rem = 0;
    for (unsigned step = 0; step != 16; ++step) {
        // Bring down the next two most significant bits of n.
        rem = (rem << 2) | (n >> 30);
        n <<= 2;
        acc = (acc << 1) | 1u;  // acc is now the trial subtrahend
        if (acc > rem) {
            acc -= 1;  // next bit is 0: acc becomes 2 * (2 * root)
        } else {
            rem -= acc;
            acc += 1;  // next bit is 1: acc becomes 2 * (2 * root + 1)
        }
    }
    return acc >> 1;
}
template <typename T>
struct isqrt_traits {
    static_assert(std::is_unsigned<T>::value, "generic isqrt only on unsigned types");

    // Bit length of n rounded up to an even number (0 when n == 0), i.e.
    // twice the number of base-4 digits of n.
    static size_t bitCount(const T& n) {
        size_t pairs = 0;
        for (T rest = n; rest != 0; rest >>= 2)
            ++pairs;
        return pairs * 2;
    }

    // Returns bits i+1 and i of n as a value in [0, 3].
    static uint8_t extractTwoBitsAt(const T& n, size_t i) {
        const T shifted = n >> i;
        return static_cast<uint8_t>(shifted & 3u);
    }
};
template <typename T> | |
T isqrt(const T& n) { | |
T remainder{}, root{}; | |
auto bitCount = isqrt_traits<T>::bitCount(n); | |
for (size_t i = bitCount; i > 0; ) { | |
i -= 2; | |
root <<= 1; | |
++root; | |
remainder <<= 2; | |
remainder |= isqrt_traits<T>::extractTwoBitsAt(n, i); | |
if (root <= remainder) { | |
remainder -= root; | |
++root; | |
} | |
else | |
--root; | |
} | |
return root >>= 1; | |
} | |
// Forward declaration so the friend declaration inside biguint is valid on its
// own; it is a harmless redeclaration where the primary template already exists.
template <typename T>
struct isqrt_traits;

// Minimal arbitrary-precision unsigned integer providing the operations the
// generic isqrt() needs.
//
// Representation: `v` holds limbs of type U, least significant first.
// Invariant: `v` is never empty and has no zero limbs above the most
// significant non-zero one (zero is the single limb {0}). operator<= relies
// on this because it compares limb counts first.
template <typename U>
class biguint {
public:
    // Zero.
    biguint() : v{0} {}

    // Builds a value from limbs listed least significant first, e.g.
    // biguint<uint32_t>{lo, hi} == hi * 2^32 + lo. Leading zero limbs are
    // trimmed and an empty list yields zero, so the invariant always holds.
    biguint(std::initializer_list<U> init) : v(init) {
        if (v.empty())
            v.push_back(0);
        trim();
    }

    // Shifts left by 0 <= shift < bits-per-limb, growing by one limb when
    // bits are shifted out of the top.
    biguint& operator<<=(size_t shift) {
        assert(shift < unitBitCount);
        if (shift == 0)  // shifting a limb by its full width would be UB
            return *this;
        U carry = 0;
        for (auto& limb : v) {
            const U out = static_cast<U>(limb >> (unitBitCount - shift));
            limb = static_cast<U>((limb << shift) | carry);
            carry = out;
        }
        if (carry)
            v.push_back(carry);
        return *this;
    }

    // Shifts right by 0 <= shift < bits-per-limb.
    biguint& operator>>=(size_t shift) {
        assert(shift < unitBitCount);
        if (shift == 0)  // shifting a limb by its full width would be UB
            return *this;
        U carry = 0;
        for (auto itr = v.rbegin(); itr != v.rend(); ++itr) {
            const U out = static_cast<U>(*itr << (unitBitCount - shift));
            *itr = static_cast<U>((*itr >> shift) | carry);
            carry = out;
        }
        // trim() keeps at least one limb: shifting 0 or 1 yields {0}, not {}.
        trim();
        return *this;
    }

    // ORs a value of at most 8 bits into the lowest limb.
    biguint& operator|=(uint8_t rhs) {
        v[0] |= rhs;
        return *this;
    }

    // Subtracts rhs; requires rhs <= *this.
    biguint& operator-=(const biguint& rhs) {
        assert(rhs <= *this);
        U borrow = 0;
        for (size_t i = 0; i < v.size(); i++) {
            const U r = i < rhs.v.size() ? rhs.v[i] : U(0);
            // Subtract r and the incoming borrow in two steps: r + borrow
            // would wrap to 0 when r is all ones, silently losing the borrow.
            const U diff = static_cast<U>(v[i] - r);
            const U result = static_cast<U>(diff - borrow);
            borrow = (diff > v[i] || result > diff) ? 1 : 0;
            v[i] = result;
        }
        assert(borrow == 0);
        trim();
        return *this;
    }

    // Prefix increment; grows by one limb on carry out of the top.
    biguint& operator++() {
        for (auto& n : v)
            if (++n != 0)
                return *this;
        v.push_back(1);
        return *this;
    }

    // Prefix decrement; requires a non-zero value.
    biguint& operator--() {
        assert(!(v.size() == 1 && v[0] == 0));  // non-zero
        for (auto& n : v)
            if (n-- != 0)
                break;
        // A borrow can zero the top limb (e.g. {0, 1} -> {max, 0}); trim it
        // so limb-count comparisons in operator<= stay valid.
        trim();
        return *this;
    }

    // Less-than-or-equal on normalized values: more limbs means larger.
    bool operator<=(const biguint& rhs) const {
        if (v.size() != rhs.v.size())
            return v.size() < rhs.v.size();
        for (auto i = v.size(); i-- > 0; ) {
            if (v[i] != rhs.v[i])
                return v[i] < rhs.v[i];
        }
        return true;
    }

    // Writes the value as "0x" followed by hex digits. Every limb below the
    // most significant one is zero-padded to its full width so digits are not
    // lost at limb boundaries. Limbs are widened before printing so 8-bit limbs
    // are not emitted as characters. Stream flags and fill are restored.
    friend std::ostream& operator<<(std::ostream& os, const biguint& n) {
        const auto flags = os.flags();
        const auto fill = os.fill();
        os << "0x" << std::hex;
        auto itr = n.v.rbegin();
        os << static_cast<unsigned long long>(*itr);  // most significant limb, unpadded
        for (++itr; itr != n.v.rend(); ++itr)
            os << std::setw(static_cast<int>(unitBitCount / 4)) << std::setfill('0')
               << static_cast<unsigned long long>(*itr);
        os.fill(fill);
        os.flags(flags);
        return os;
    }

    friend struct isqrt_traits<biguint>;

private:
    // Drops zero limbs above the most significant non-zero one, keeping at
    // least one limb.
    void trim() {
        while (v.size() > 1 && v.back() == 0)
            v.pop_back();
    }

    static constexpr size_t unitBitCount = sizeof(U) * 8;  // bits per limb
    std::vector<U> v;  // limbs, least significant first
};
template<typename U> | |
struct isqrt_traits<biguint<U>> { | |
static size_t bitCount(const biguint<U>& n) { | |
return biguint<U>::unitBitCount * (n.v.size() - 1) + isqrt_traits<U>::bitCount(n.v.back()); | |
} | |
static uint8_t extractTwoBitsAt(const biguint<U>& n, size_t i) { | |
return static_cast<uint8_t>((n.v[i / biguint<U>::unitBitCount] >> (i % biguint<U>::unitBitCount)) & 3); | |
} | |
}; | |
int main() { | |
// floor(sqrt(45765)) = 213 | |
std::cout << isqrt0(45765) << std::endl; | |
std::cout << isqrt1(45765) << std::endl; | |
std::cout << isqrt2(45765) << std::endl; | |
std::cout << isqrt<unsigned>(45765) << std::endl; | |
// 50! = 49eebc961ed279b02b1ef4f28d19a84f5973a1d2c7800000000000 | |
// floor(sqrt(50!)) = 899310e94a8b185249821ebce70 | |
std::cout << isqrt(biguint<uint32_t>{0x00000000, 0xd2c78000, 0x4f5973a1, 0xf28d19a8, 0xb02b1ef4, 0x961ed279, 0x49eebc}) << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment