Skip to content

Instantly share code, notes, and snippets.

@miloyip
Last active November 5, 2019 17:12
Show Gist options
  • Save miloyip/69663b78b26afa0dcc260382a6034b1a to your computer and use it in GitHub Desktop.
Save miloyip/69663b78b26afa0dcc260382a6034b1a to your computer and use it in GitHub Desktop.
Big integer square root
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <type_traits>
#include <vector>
// http://www.embedded.com/electronics-blogs/programmer-s-toolbox/4219659/Integer-Square-Roots
// Naive integer square root: returns floor(sqrt(n)) by walking the perfect
// squares 1, 4, 9, ... (each obtained by adding the next odd number) until one
// exceeds n. Takes O(sqrt(n)) iterations.
//
// square and delta are 64-bit on purpose: with 32-bit arithmetic the running
// square overflows once n >= 65535^2, and for n == UINT32_MAX the condition
// `square <= n` is always true, so the loop would never terminate.
uint32_t isqrt0(uint32_t n) {
    uint64_t square = 1;  // (k + 1)^2 for the current candidate k
    uint64_t delta = 3;   // 2(k + 1) + 1: gap from (k + 1)^2 to (k + 2)^2
    while (square <= n) {
        square += delta;
        delta += 2;
    }
    // On exit delta == 2k + 3 where k == floor(sqrt(n)).
    return static_cast<uint32_t>(delta / 2 - 1);
}
// Digit-by-digit (binary long-division style) square root: consumes n two bits
// at a time from the most significant end and decides one root bit per step.
// Returns floor(sqrt(n)).
uint32_t isqrt1(uint32_t n) {
    uint32_t root = 0;
    uint32_t rem = 0;
    for (int step = 16; step > 0; --step) {
        // Move the top two bits of n into the running remainder.
        rem = (rem << 2) | (n >> 30);
        n <<= 2;
        root <<= 1;
        // Appending a 1 bit to root r costs (2(2r) + 1) more in the square.
        const uint32_t trial = 2 * root + 1;
        if (trial <= rem) {
            rem -= trial;
            root += 1;
        }
    }
    return root;
}
// Same digit-by-digit square root as isqrt1, but the loop keeps twice the root
// so the trial value 4r + 1 is formed with a single shift-and-increment.
// Returns floor(sqrt(n)).
uint32_t isqrt2(uint32_t n) {
    uint32_t rem = 0;
    uint32_t twiceRoot = 0;  // 2 * (root found so far)
    for (int step = 0; step != 16; ++step) {
        // Move the top two bits of n into the running remainder.
        rem = (rem << 2) | (n >> 30);
        n <<= 2;
        twiceRoot = (twiceRoot << 1) + 1;  // trial value 4r + 1
        if (twiceRoot > rem) {
            twiceRoot -= 1;                // bit is 0: back to 4r = 2 * (2r)
        } else {
            rem -= twiceRoot;
            twiceRoot += 1;                // bit is 1: 4r + 2 = 2 * (2r + 1)
        }
    }
    return twiceRoot >> 1;
}
// Bit-access helpers used by the generic isqrt() for built-in unsigned types.
template <typename T>
struct isqrt_traits {
    static_assert(std::is_unsigned<T>::value, "generic isqrt only on unsigned types");

    // Bit length of n rounded up to an even number (0 when n == 0).
    static size_t bitCount(const T& n) {
        size_t bits = 0;
        for (T rest = n; rest != 0; rest >>= 2)
            bits += 2;
        return bits;
    }

    // Bits i+1 and i of n, returned as a value in [0, 3].
    static uint8_t extractTwoBitsAt(const T& n, size_t i) {
        const T shifted = n >> i;
        return static_cast<uint8_t>(shifted & 3);
    }
};
template <typename T>
T isqrt(const T& n) {
T remainder{}, root{};
auto bitCount = isqrt_traits<T>::bitCount(n);
for (size_t i = bitCount; i > 0; ) {
i -= 2;
root <<= 1;
++root;
remainder <<= 2;
remainder |= isqrt_traits<T>::extractTwoBitsAt(n, i);
if (root <= remainder) {
remainder -= root;
++root;
}
else
--root;
}
return root >>= 1;
}
// Forward declaration so this class stands on its own; the primary template
// (and its specialization for biguint) are defined elsewhere. Redeclaring a
// class template is harmless.
template <typename T>
struct isqrt_traits;

// Minimal arbitrary-precision unsigned integer, providing just what isqrt()
// needs: shifts of at most one limb, |= of a small value into the lowest limb,
// subtraction, prefix ++/--, <= comparison and hex output.
//
// Representation: little-endian vector of limbs of type U (v[0] is the least
// significant). Invariant kept by every operation: v is never empty and has no
// most-significant zero limb, except that zero is stored as {0}. operator<=
// relies on this invariant.
template <typename U>
class biguint {
    static_assert(std::is_unsigned<U>::value, "biguint limbs must be an unsigned type");

public:
    // Zero.
    biguint() : v{0} {}

    // Builds a value from little-endian limbs (init[0] is least significant).
    // Leading zero limbs are dropped; an empty list yields zero.
    biguint(std::initializer_list<U> init) : v(init) {
        normalize();
    }

    // Shifts left by `shift` bits, 0 <= shift <= unitBitCount.
    biguint& operator<<=(size_t shift) {
        assert(shift <= unitBitCount);
        if (shift == 0)
            return *this;  // the general path would shift by unitBitCount (UB)
        if (shift == unitBitCount) {
            // Whole-limb shift: insert a zero limb at the bottom.
            if (!isZero())
                v.insert(v.begin(), U(0));
            return *this;
        }
        U inBits = 0;
        for (auto& limb : v) {
            const U outBits = static_cast<U>(limb >> (unitBitCount - shift));
            limb = static_cast<U>((limb << shift) | inBits);
            inBits = outBits;
        }
        if (inBits)
            v.push_back(inBits);
        return *this;
    }

    // Shifts right by `shift` bits, 0 <= shift <= unitBitCount.
    biguint& operator>>=(size_t shift) {
        assert(shift <= unitBitCount);
        if (shift == 0)
            return *this;  // the general path would shift by unitBitCount (UB)
        if (shift == unitBitCount) {
            // Whole-limb shift: drop the lowest limb.
            v.erase(v.begin());
            normalize();
            return *this;
        }
        U inBits = 0;
        for (auto itr = v.rbegin(); itr != v.rend(); ++itr) {
            const U outBits = static_cast<U>(*itr << (unitBitCount - shift));
            *itr = static_cast<U>((*itr >> shift) | inBits);
            inBits = outBits;
        }
        // Keep at least one limb, so zero (and e.g. 1 >> 1) stays {0}.
        normalize();
        return *this;
    }

    // ORs a small value into the least significant limb.
    biguint& operator|=(uint8_t rhs) {
        v[0] |= rhs;
        return *this;
    }

    // Subtracts rhs; requires rhs <= *this.
    biguint& operator-=(const biguint& rhs) {
        assert(rhs <= *this);
        U borrow = 0;
        for (size_t i = 0; i < v.size(); i++) {
            const U a = v[i];
            const U b = i < rhs.v.size() ? rhs.v[i] : U(0);
            // Subtract b and the borrow separately: folding them into
            // (b + borrow) wraps to 0 when b is the maximum limb value and
            // silently loses the borrow.
            const U d1 = static_cast<U>(a - b);
            const U d2 = static_cast<U>(d1 - borrow);
            borrow = (d1 > a || d2 > d1) ? 1 : 0;
            v[i] = d2;
        }
        assert(borrow == 0);
        normalize();
        return *this;
    }

    // Prefix increment; grows by one limb when every limb carries out.
    biguint& operator++() {
        for (auto& limb : v)
            if (++limb != 0)
                return *this;
        v.push_back(1);
        return *this;
    }

    // Prefix decrement; requires a non-zero value.
    biguint& operator--() {
        assert(!isZero());
        for (auto& limb : v)
            if (limb-- != 0)
                break;
        // A borrow into the top limb can leave it zero (e.g. 2^32 - 1 with
        // 32-bit limbs); drop it so operator<= keeps working.
        normalize();
        return *this;
    }

    // Compares normalized values: more limbs means larger; otherwise compare
    // limbs from the most significant down.
    bool operator<=(const biguint& rhs) const {
        if (v.size() != rhs.v.size())
            return v.size() < rhs.v.size();
        for (size_t i = v.size(); i-- > 0; ) {
            if (v[i] != rhs.v[i])
                return v[i] < rhs.v[i];
        }
        return true;
    }

    // Writes the value as "0x" followed by lowercase (or stream-uppercase) hex
    // digits. The stream's flags and fill character are restored afterwards.
    friend std::ostream& operator<<(std::ostream& os, const biguint& n) {
        const std::ios_base::fmtflags savedFlags = os.flags();
        const char savedFill = os.fill();
        os.unsetf(std::ios_base::showbase);
        os.setf(std::ios_base::right, std::ios_base::adjustfield);
        os << "0x" << std::hex;
        auto itr = n.v.rbegin();
        // Most significant limb: no padding. Widen so uint8_t limbs print as
        // numbers rather than characters.
        os << static_cast<unsigned long long>(*itr);
        // Every lower limb must be zero-padded to its full hex width.
        for (++itr; itr != n.v.rend(); ++itr) {
            os.width(static_cast<std::streamsize>(unitBitCount / 4));
            os.fill('0');
            os << static_cast<unsigned long long>(*itr);
        }
        os.fill(savedFill);
        os.flags(savedFlags);
        return os;
    }

    friend struct isqrt_traits<biguint>;

private:
    static constexpr size_t unitBitCount = sizeof(U) * 8;
    std::vector<U> v;

    // True iff the value is zero (relies on the normalization invariant).
    bool isZero() const {
        return v.size() == 1 && v[0] == 0;
    }

    // Restores the invariant: drop most-significant zero limbs, keep >= 1 limb.
    void normalize() {
        while (v.size() > 1 && v.back() == 0)
            v.pop_back();
        if (v.empty())
            v.push_back(0);
    }
};
// isqrt() bit-access helpers for biguint: the value is viewed as one long
// little-endian bit string spread over the limbs.
template<typename U>
struct isqrt_traits<biguint<U>> {
    // Full lower limbs plus the even-rounded bit length of the top limb.
    static size_t bitCount(const biguint<U>& n) {
        const size_t limbBits = biguint<U>::unitBitCount;
        const size_t fullLimbs = n.v.size() - 1;
        const size_t topBits = isqrt_traits<U>::bitCount(n.v.back());
        return fullLimbs * limbBits + topBits;
    }
    // Bits i+1 and i. Because i is even and limbs have an even bit width,
    // the pair never straddles two limbs.
    static uint8_t extractTwoBitsAt(const biguint<U>& n, size_t i) {
        const size_t limbBits = biguint<U>::unitBitCount;
        const U limb = n.v[i / limbBits];
        const size_t offset = i % limbBits;
        return static_cast<uint8_t>((limb >> offset) & 3);
    }
};
int main() {
// floor(sqrt(45765)) = 213
std::cout << isqrt0(45765) << std::endl;
std::cout << isqrt1(45765) << std::endl;
std::cout << isqrt2(45765) << std::endl;
std::cout << isqrt<unsigned>(45765) << std::endl;
// 50! = 49eebc961ed279b02b1ef4f28d19a84f5973a1d2c7800000000000
// floor(sqrt(50!)) = 899310e94a8b185249821ebce70
std::cout << isqrt(biguint<uint32_t>{0x00000000, 0xd2c78000, 0x4f5973a1, 0xf28d19a8, 0xb02b1ef4, 0x961ed279, 0x49eebc}) << std::endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment