Skip to content

Instantly share code, notes, and snippets.

@Bananattack
Last active February 2, 2022 16:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save Bananattack/6242ba7b8265c90ce6f3c2d84670c52d to your computer and use it in GitHub Desktop.
Save Bananattack/6242ba7b8265c90ce6f3c2d84670c52d to your computer and use it in GitHub Desktop.
A signed 128-bit integer type. Internally, it stores its value in two's complement form, split across two unsigned 64-bit integer parts. It still needs more testing, but the basic implementation is fairly complete.
#ifndef WIZ_UTILITY_INT128_H
#define WIZ_UTILITY_INT128_H
#include <cassert>
#include <cstdint>
#include <cstddef>
#include <cstring>
#include <cstdlib>
#include <utility>
#include <string>
#include <iostream>
namespace wiz {
/// Signed 128-bit integer stored in two's complement form across two unsigned
/// 64-bit words: `low` holds bits 0..63 and `high` holds bits 64..127.
///
/// The plain arithmetic operators (+, -, *, <<, ++, --, unary -) wrap modulo
/// 2^128; the `checked*` members report overflow instead. Division and modulo
/// truncate toward zero, like the built-in integer types.
struct Int128 {
    /// Constructs the value zero.
    Int128()
        : low(0), high(0) {}
    Int128(const Int128& other) = default;
    Int128(Int128&& other) = default;

    // Signed sources. Converting a signed integer to std::uint64_t is defined
    // to wrap modulo 2^64, which yields exactly the sign-extended low word;
    // unlike negating `value` first, it cannot overflow for INT32_MIN or
    // INT64_MIN. The high word is all ones for negative inputs.
    explicit Int128(std::int8_t value)
        : low(static_cast<std::uint64_t>(value)), high(value < 0 ? UINT64_MAX : 0) {}
    explicit Int128(std::int16_t value)
        : low(static_cast<std::uint64_t>(value)), high(value < 0 ? UINT64_MAX : 0) {}
    explicit Int128(std::int32_t value)
        : low(static_cast<std::uint64_t>(value)), high(value < 0 ? UINT64_MAX : 0) {}
    explicit Int128(std::int64_t value)
        : low(static_cast<std::uint64_t>(value)), high(value < 0 ? UINT64_MAX : 0) {}

    // Unsigned sources are zero-extended.
    explicit Int128(std::uint8_t value)
        : low(value), high(0) {}
    explicit Int128(std::uint16_t value)
        : low(value), high(0) {}
    explicit Int128(std::uint32_t value)
        : low(value), high(0) {}
    explicit Int128(std::uint64_t value)
        : low(value), high(0) {}

    /// Constructs directly from the raw two's complement words.
    Int128(std::uint64_t low, std::uint64_t high)
        : low(low), high(high) {}

    static Int128 zero() {
        return Int128(0, 0);
    }
    static Int128 one() {
        return Int128(1, 0);
    }
    /// -2^127, the most negative representable value.
    static Int128 minValue() {
        return Int128(0, 0x8000000000000000);
    }
    /// 2^127 - 1, the most positive representable value.
    static Int128 maxValue() {
        return Int128(UINT64_MAX, 0x7FFFFFFFFFFFFFFF);
    }

    enum class ParseResult {
        Success,
        InvalidArgument,
        FormatError,
        RangeError,
    };

    enum class CheckedArithmeticResult {
        Success,
        OverflowError,
        DivideByZeroError
    };

    /// Parses a NUL-terminated string; see the iterator overload for the
    /// accepted syntax. A null `str` yields InvalidArgument.
    static std::pair<ParseResult, Int128> parse(const char* str, std::size_t base = 10) {
        if (str == nullptr) {
            return {ParseResult::InvalidArgument, zero()};
        }
        return parse(str, str + std::strlen(str), base);
    }

    /// Parses an optionally signed integer from [first, last) in `base` (2..36).
    ///
    /// Digits above 9 are the letters a-z or A-Z. A leading '-' or '+' is
    /// consumed unless `negative` is already true, in which case the digits are
    /// read as the magnitude of a negative number.
    ///
    /// @return {Success, value}; {InvalidArgument, 0} for a bad base or an empty
    ///         range; {FormatError, 0} when no digits follow the sign or a
    ///         character is not a digit of `base`; {RangeError, 0} when the
    ///         value does not fit in 128 bits.
    template <class InputIterator>
    static std::pair<ParseResult, Int128> parse(InputIterator first, InputIterator last, std::size_t base = 10, bool negative = false) {
        if (base < 2 || base > 36 || first == last) {
            return {ParseResult::InvalidArgument, zero()};
        }
        if (!negative) {
            if (*first == '-') {
                negative = true;
                ++first;
            } else if (*first == '+') {
                ++first;
            }
        }
        if (first == last) {
            return {ParseResult::FormatError, zero()};
        }
        const Int128 radix(static_cast<std::uint64_t>(base), 0);
        std::pair<CheckedArithmeticResult, Int128> result = {CheckedArithmeticResult::Success, zero()};
        for (; first != last; ++first) {
            // Scale before validating the digit, so an overflow at this position
            // is reported as RangeError even if the character is also invalid.
            result = result.second.checkedMultiply(radix);
            if (result.first == CheckedArithmeticResult::OverflowError) {
                return {ParseResult::RangeError, zero()};
            }
            const int value = digitValue(static_cast<unsigned char>(*first));
            // Strict `>=`: in base b only the digits 0..b-1 are valid.
            if (value < 0 || static_cast<std::size_t>(value) >= base) {
                return {ParseResult::FormatError, zero()};
            }
            const Int128 digit(static_cast<std::uint64_t>(value), 0);
            // Negative numbers accumulate downward so minValue() is reachable.
            result = result.second.checkedAdd(negative ? -digit : digit);
            if (result.first == CheckedArithmeticResult::OverflowError) {
                return {ParseResult::RangeError, zero()};
            }
        }
        return {ParseResult::Success, result.second};
    }

    bool isZero() const {
        return low == 0 && high == 0;
    }
    bool isPositive() const {
        return !isZero() && !isNegative();
    }
    bool isNegative() const {
        return (high & 0x8000000000000000) != 0;
    }
    /// Returns |*this|. minValue() has no positive counterpart and is returned
    /// unchanged (its bit pattern is 2^127 when read as unsigned).
    Int128 getAbsoluteValue() const {
        return isNegative() ? -*this : *this;
    }

    /// Returns bit `bit` (0 = least significant); false for bit >= 128.
    bool getBit(std::size_t bit) const {
        if (bit >= 128) {
            return false;
        } else if (bit >= 64) {
            return (high & (static_cast<std::uint64_t>(1) << (bit - 64))) != 0;
        } else {
            return (low & (static_cast<std::uint64_t>(1) << bit)) != 0;
        }
    }
    /// Sets or clears bit `bit`; does nothing for bit >= 128.
    void setBit(std::size_t bit, bool value) {
        if (bit >= 128) {
            return;
        }
        std::uint64_t& word = bit >= 64 ? high : low;
        const std::uint64_t mask = static_cast<std::uint64_t>(1) << (bit % 64);
        if (value) {
            word |= mask;
        } else {
            word &= ~mask;
        }
    }

    Int128 logicalShiftLeftOnce() const {
        return Int128(low << 1, (high << 1) | (low >> 63));
    }
    Int128 logicalShiftRightOnce() const {
        return Int128((low >> 1) | (high << 63), high >> 1);
    }
    Int128 arithmeticShiftRightOnce() const {
        return Int128((low >> 1) | (high << 63), (high >> 1) | (high & 0x8000000000000000));
    }
    Int128 logicalShiftLeft(std::size_t bits) const {
        return *this << bits;
    }
    /// Shifts right filling with zeros; 0 for bits >= 128.
    Int128 logicalShiftRight(std::size_t bits) const {
        if (bits == 0) {
            return *this;
        } else if (bits >= 128) {
            return zero();
        } else if (bits >= 64) {
            return Int128(high >> (bits - 64), 0);
        } else {
            return Int128((low >> bits) | (high << (64 - bits)), high >> bits);
        }
    }
    Int128 arithmeticShiftLeft(std::size_t bits) const {
        return *this << bits;
    }
    /// Shifts right filling with copies of the sign bit (floor division by
    /// 2^bits); for bits >= 128 the result is 0 or -1.
    Int128 arithmeticShiftRight(std::size_t bits) const {
        const std::uint64_t fill = isNegative() ? UINT64_MAX : 0;
        if (bits == 0) {
            return *this;
        } else if (bits >= 128) {
            return Int128(fill, fill);
        } else if (bits >= 64) {
            const std::size_t shift = bits - 64;
            if (shift == 0) {
                // Avoid `fill << 64`, which is undefined.
                return Int128(high, fill);
            }
            return Int128((high >> shift) | (fill << (64 - shift)), fill);
        } else {
            return Int128((low >> bits) | (high << (64 - bits)), (high >> bits) | (fill << (64 - bits)));
        }
    }

    /// Divides treating both operands as unsigned 128-bit integers.
    /// Asserts and then aborts on division by zero.
    /// @return {quotient, remainder}
    std::pair<Int128, Int128> unsignedDivisionWithRemainder(Int128 other) const {
        if (other.isZero()) {
            assert(!other.isZero());
            std::abort();
        }
        if (other == one()) {
            return {*this, zero()};
        }
        if (*this == other) {
            return {one(), zero()};
        }
        // Unsigned comparison: a divisor such as minValue() is 2^127 here, not
        // the smallest number.
        if (unsignedLessThan(*this, other)) {
            return {zero(), *this};
        }
        if (high == 0 && other.high == 0) {
            return {Int128(low / other.low, 0), Int128(low % other.low, 0)};
        }
        // Restoring long division, one dividend bit at a time from the top.
        // Before each shift the remainder is at most the dividend prefix seen
        // so far (< 2^127), so the shift can never carry out of bit 127.
        Int128 quotient;
        Int128 remainder;
        for (std::size_t i = findMostSignificantBit(); i-- > 0;) {
            remainder = remainder.logicalShiftLeftOnce();
            remainder.setBit(0, getBit(i));
            if (!unsignedLessThan(remainder, other)) {
                remainder -= other;
                quotient.setBit(i, true);
            }
        }
        return {quotient, remainder};
    }
    /// Signed division truncating toward zero; the remainder takes the sign of
    /// the dividend. minValue() / -1 wraps to minValue().
    /// @return {quotient, remainder}
    std::pair<Int128, Int128> divisionWithRemainder(Int128 other) const {
        if (isNegative()) {
            const auto negativeThis = -*this;
            if (other.isNegative()) {
                const auto result = negativeThis.unsignedDivisionWithRemainder(-other);
                return {result.first, -result.second};
            } else {
                const auto result = negativeThis.unsignedDivisionWithRemainder(other);
                return {-result.first, -result.second};
            }
        } else {
            if (other.isNegative()) {
                const auto result = unsignedDivisionWithRemainder(-other);
                return {-result.first, result.second};
            } else {
                return unsignedDivisionWithRemainder(other);
            }
        }
    }

    /// Index of the lowest set bit, or 128 when the value is zero.
    std::size_t findLeastSignificantBit() const {
        if (isZero()) {
            return 128;
        }
        std::size_t index = 0;
        while (!getBit(index)) {
            ++index;
        }
        return index;
    }
    /// Number of significant bits, i.e. one more than the index of the highest
    /// set bit; 0 when the value is zero.
    std::size_t findMostSignificantBit() const {
        std::size_t length = 128;
        while (length > 0 && !getBit(length - 1)) {
            --length;
        }
        return length;
    }

    /// Formats the value in `base` (2..36) with lowercase digits and a leading
    /// '-' for negative values. Returns "" for an unsupported base.
    std::string toString(std::size_t base = 10) const {
        if (base < 2 || base > 36) {
            return "";
        }
        if (base == 10) {
            if (high == 0) {
                return std::to_string(low);
            }
            if (high == UINT64_MAX && low >= 0x8000000000000000) {
                // The value fits in int64_t; its magnitude 2^64 - low is at most
                // 2^63, so it is computed in unsigned arithmetic without overflow.
                return "-" + std::to_string(std::uint64_t{0} - low);
            }
        }
        const char* const digits = "0123456789abcdefghijklmnopqrstuvwxyz";
        const Int128 radix(static_cast<std::uint64_t>(base), 0);
        std::string reversed;
        Int128 remaining = getAbsoluteValue();
        do {
            const auto quotientAndRemainder = remaining.unsignedDivisionWithRemainder(radix);
            reversed.push_back(digits[quotientAndRemainder.second.low]);
            remaining = quotientAndRemainder.first;
        } while (!remaining.isZero());
        if (isNegative()) {
            reversed.push_back('-');
        }
        return std::string(reversed.rbegin(), reversed.rend());
    }

    std::pair<CheckedArithmeticResult, Int128> checkedAdd(Int128 other) const {
        if (isNegative()) {
            if (other.isNegative() && *this < minValue() - other) {
                return {CheckedArithmeticResult::OverflowError, zero()};
            }
        } else {
            if (!other.isNegative() && *this > maxValue() - other) {
                return {CheckedArithmeticResult::OverflowError, zero()};
            }
        }
        return {CheckedArithmeticResult::Success, *this + other};
    }
    std::pair<CheckedArithmeticResult, Int128> checkedSubtract(Int128 other) const {
        if (isNegative()) {
            if (!other.isNegative() && *this < minValue() + other) {
                return {CheckedArithmeticResult::OverflowError, zero()};
            }
        } else {
            if (other.isNegative() && *this > maxValue() + other) {
                return {CheckedArithmeticResult::OverflowError, zero()};
            }
        }
        return {CheckedArithmeticResult::Success, *this - other};
    }
    /// Multiplies, reporting OverflowError when the exact product does not fit.
    /// The limits below rely on division truncating toward zero.
    std::pair<CheckedArithmeticResult, Int128> checkedMultiply(Int128 other) const {
        if (isZero() || other.isZero()) {
            return {CheckedArithmeticResult::Success, Int128()};
        }
        if (isNegative()) {
            if (other.isNegative()) {
                if (other < maxValue() / *this) {
                    return {CheckedArithmeticResult::OverflowError, Int128()};
                }
            } else {
                const auto limit = minValue() / other;
                if (*this < limit) {
                    return {CheckedArithmeticResult::OverflowError, Int128()};
                }
            }
        } else {
            if (other.isNegative()) {
                if (other < minValue() / *this) {
                    return {CheckedArithmeticResult::OverflowError, Int128()};
                }
            } else {
                if (*this > maxValue() / other) {
                    return {CheckedArithmeticResult::OverflowError, Int128()};
                }
            }
        }
        return {CheckedArithmeticResult::Success, *this * other};
    }
    std::pair<CheckedArithmeticResult, Int128> checkedDivide(Int128 other) const {
        if (other.isZero()) {
            return {CheckedArithmeticResult::DivideByZeroError, Int128()};
        } else if (*this == minValue() && other == Int128(-1)) {
            return {CheckedArithmeticResult::OverflowError, Int128()};
        }
        return {CheckedArithmeticResult::Success, *this / other};
    }
    std::pair<CheckedArithmeticResult, Int128> checkedModulo(Int128 other) const {
        if (other.isZero()) {
            return {CheckedArithmeticResult::DivideByZeroError, Int128()};
        } else if (*this == minValue() && other == Int128(-1)) {
            return {CheckedArithmeticResult::OverflowError, Int128()};
        }
        return {CheckedArithmeticResult::Success, *this % other};
    }
    /// Shifts left, reporting OverflowError when a set bit would be shifted
    /// out or into the sign bit (negative non-zero inputs always overflow when
    /// bits > 0).
    std::pair<CheckedArithmeticResult, Int128> checkedLogicalShiftLeft(std::size_t bits) const {
        if (!isZero()) {
            if (bits >= 128 || (bits > 0 && !logicalShiftRight(127 - bits).isZero())) {
                return {CheckedArithmeticResult::OverflowError, Int128()};
            }
        }
        return {CheckedArithmeticResult::Success, *this << bits};
    }

    // Narrowing conversions keep the low-order bits, like the built-in integer
    // conversions. The signed ones go through lowAsInt64(), which avoids the
    // overflowing negation of INT64_MIN.
    explicit operator std::int8_t() const {
        return static_cast<std::int8_t>(lowAsInt64());
    }
    explicit operator std::int16_t() const {
        return static_cast<std::int16_t>(lowAsInt64());
    }
    explicit operator std::int32_t() const {
        return static_cast<std::int32_t>(lowAsInt64());
    }
    explicit operator std::int64_t() const {
        return lowAsInt64();
    }
    explicit operator std::uint8_t() const {
        return static_cast<std::uint8_t>(low);
    }
    explicit operator std::uint16_t() const {
        return static_cast<std::uint16_t>(low);
    }
    explicit operator std::uint32_t() const {
        return static_cast<std::uint32_t>(low);
    }
    explicit operator std::uint64_t() const {
        return low;
    }

    Int128& operator =(const Int128& other) = default;
    Int128& operator =(Int128&& other) = default;

    bool operator ==(Int128 other) const {
        return low == other.low && high == other.high;
    }
    bool operator !=(Int128 other) const {
        return !(*this == other);
    }
    /// Signed comparison.
    bool operator <(Int128 other) const {
        if (isNegative() != other.isNegative()) {
            return isNegative();
        }
        // With equal signs, two's complement order matches unsigned order.
        return unsignedLessThan(*this, other);
    }
    bool operator <=(Int128 other) const {
        return !(other < *this);
    }
    bool operator >(Int128 other) const {
        return other < *this;
    }
    bool operator >=(Int128 other) const {
        return !(*this < other);
    }

    Int128 operator ~() const {
        return Int128(~low, ~high);
    }
    Int128 operator +() const {
        return *this;
    }
    /// Two's complement negation; -minValue() wraps to minValue().
    Int128 operator -() const {
        Int128 result = ~*this;
        ++result;
        return result;
    }
    Int128 operator &(Int128 other) const {
        return Int128(low & other.low, high & other.high);
    }
    Int128 operator |(Int128 other) const {
        return Int128(low | other.low, high | other.high);
    }
    Int128 operator ^(Int128 other) const {
        return Int128(low ^ other.low, high ^ other.high);
    }
    /// Shifts left filling with zeros; 0 for bits >= 128.
    Int128 operator <<(std::size_t bits) const {
        if (bits == 0) {
            return *this;
        } else if (bits >= 128) {
            return zero();
        } else if (bits >= 64) {
            return Int128(0, low << (bits - 64));
        } else {
            return Int128(low << bits, (high << bits) | (low >> (64 - bits)));
        }
    }
    Int128& operator --() {
        if (low == 0) {
            --high;
        }
        --low;
        return *this;
    }
    Int128& operator ++() {
        ++low;
        if (low == 0) {
            ++high;
        }
        return *this;
    }
    Int128 operator --(int) {
        Int128 result = *this;
        --*this;
        return result;
    }
    Int128 operator ++(int) {
        Int128 result = *this;
        ++*this;
        return result;
    }
    Int128 operator +(Int128 other) const {
        const auto carry = other.low > UINT64_MAX - low;
        return Int128(low + other.low, high + other.high + (carry ? 1 : 0));
    }
    Int128 operator -(Int128 other) const {
        return *this + -other;
    }
    /// Product modulo 2^128.
    Int128 operator *(Int128 other) const {
        if (other == one()) {
            return *this;
        } else if (isZero() || other.isZero()) {
            return Int128();
        } else if (high == 0 && other.high == 0 && low <= UINT32_MAX && other.low <= UINT32_MAX) {
            // Both operands are small non-negative values: the product fits in
            // 64 bits. (A high word of all ones does NOT mean "small negative",
            // so such operands take the general path.)
            return Int128(low * other.low, 0);
        } else {
            // First do a 64 x 64 -> 128-bit multiply.
            //
            // a * b
            // = (2^32 * ah + al) * (2^32 * bh + bl) [rewriting 64-bit values in their split 32-bit form]
            // = 2^64 * ah * bh + 2^32 * ah * bl + 2^32 * al * bh + al * bl [expanding product]
            // = w + z + y + x [giving names to each sub-product]
            //
            // x: 32 x 32 -> 64-bit product, bits 0..63
            // y and z: 32 x 32 -> 64-bit product, shifted by 32 bits, bits 32..95
            // w: 32 x 32 -> 64-bit product, bits 64..127
            // addition happens across lo/hi word boundaries and can generate middle carries, so we need to add each piece separately as 128-bit integers.
            //
            // Then to make a 128 x 128 -> 128-bit multiply, we do similarly, but since we're asking for 128-bit instead of 256-bit result,
            // we discard the upper product, and just keep the lower 128 bits of the result.
            //
            // 2^128 * (ah64 * bh64) + 2^64 * (ah64 * bl64 + al64 * bh64) + (al64 * bl64)
            // = 2^64 * (ah64 * bl64 + al64 * bh64) + (al64 * bl64) modulo 2^128 [note: the cross terms only need their low 64 bits]
            const auto al = low & 0xFFFFFFFF;
            const auto ah = (low >> 32) & 0xFFFFFFFF;
            const auto bl = other.low & 0xFFFFFFFF;
            const auto bh = (other.low >> 32) & 0xFFFFFFFF;
            const auto x = al * bl;
            const auto y = al * bh;
            const auto z = ah * bl;
            const auto w = ah * bh;
            return Int128(x, 0) + Int128(y << 32, y >> 32) + Int128(z << 32, z >> 32) + Int128(0, w + low * other.high + high * other.low);
        }
    }
    Int128 operator /(Int128 other) const {
        return divisionWithRemainder(other).first;
    }
    Int128 operator %(Int128 other) const {
        return divisionWithRemainder(other).second;
    }
    Int128& operator +=(Int128 other) {
        *this = *this + other;
        return *this;
    }
    Int128& operator -=(Int128 other) {
        *this = *this - other;
        return *this;
    }
    Int128& operator *=(Int128 other) {
        *this = *this * other;
        return *this;
    }
    Int128& operator /=(Int128 other) {
        *this = *this / other;
        return *this;
    }
    Int128& operator %=(Int128 other) {
        *this = *this % other;
        return *this;
    }
    Int128& operator &=(Int128 other) {
        *this = *this & other;
        return *this;
    }
    Int128& operator |=(Int128 other) {
        *this = *this | other;
        return *this;
    }
    Int128& operator ^=(Int128 other) {
        *this = *this ^ other;
        return *this;
    }
    Int128& operator <<=(std::size_t bits) {
        *this = *this << bits;
        return *this;
    }

    friend std::ostream& operator<<(std::ostream& out, const Int128& value);

    std::uint64_t low;
    std::uint64_t high;

private:
    /// Compares the raw 128-bit patterns as unsigned integers.
    static bool unsignedLessThan(const Int128& a, const Int128& b) {
        return a.high < b.high || (a.high == b.high && a.low < b.low);
    }

    /// Value of an ASCII alphanumeric digit (0-35, letters case-insensitive),
    /// or -1 for any other character.
    static int digitValue(unsigned char c) {
        if (c >= '0' && c <= '9') {
            return c - '0';
        }
        if (c >= 'a' && c <= 'z') {
            return c - 'a' + 10;
        }
        if (c >= 'A' && c <= 'Z') {
            return c - 'A' + 10;
        }
        return -1;
    }

    /// The low word reinterpreted as a two's complement int64_t, computed
    /// without implementation-defined or overflowing conversions.
    std::int64_t lowAsInt64() const {
        return low <= static_cast<std::uint64_t>(INT64_MAX)
            ? static_cast<std::int64_t>(low)
            : -static_cast<std::int64_t>(~low) - 1;
    }
};
inline std::ostream& operator <<(std::ostream& out, const Int128& value) {
const auto negative = value.isNegative();
if (value.high == 0 || (value.low < 0x8000000000000000 && value.high == UINT64_MAX)) {
if (negative) {
out << -static_cast<int64_t>((value.low - 1) ^ UINT64_MAX);
} else {
out << value.low;
}
} else {
std::size_t base = 10;
if ((out.flags() & out.dec) != 0) {
base = 10;
} else if ((out.flags() & out.hex) != 0) {
base = 16;
} else if ((out.flags() & out.oct) != 0) {
base = 8;
}
char buffer[129] = {0};
std::size_t bufferIndex = 128;
std::pair<Int128, Int128> quotientAndRemainder(value.getAbsoluteValue(), Int128::zero());
do {
quotientAndRemainder = quotientAndRemainder.first.unsignedDivisionWithRemainder(Int128(base));
if (quotientAndRemainder.second.low < 10) {
buffer[--bufferIndex] = static_cast<char>(quotientAndRemainder.second.low + '0');
} else if (quotientAndRemainder.second.low < 36) {
buffer[--bufferIndex] = static_cast<char>(quotientAndRemainder.second.low - 10 + 'a');
}
} while (!quotientAndRemainder.first.isZero());
if (negative) {
buffer[--bufferIndex] = '-';
}
out << &buffer[bufferIndex];
}
return out;
}
}
#endif
#include <iostream>
#include "int128.h"
// Demonstrates construction, printing in several bases, parsing and arithmetic.
// '\n' is used instead of std::endl: the output is identical and the stream is
// flushed at program exit anyway.
int main() {
    {
        // Create an Int128 from an uint64_t (also possible from int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, and uint32_t).
        // The cast selects the uint64_t constructor on every platform, even where
        // the literal's own type (unsigned long) is not std::uint64_t.
        wiz::Int128 a(static_cast<std::uint64_t>(0xFFFFFFFFFFFFFFFF));
        // prints 18446744073709551615
        std::cout << a.toString() << '\n';
    }
    // The largest negative number. Uses constructor that takes two unsigned 64-bit integers representing the raw two's complement data.
    {
        wiz::Int128 b(0x0000000000000000, 0x8000000000000000);
        // Print in a few different bases.
        // prints -170141183460469231731687303715884105728
        std::cout << b << '\n';
        // prints -2000000000000000000000000000000000000000000
        std::cout << std::oct << b << '\n';
        // prints -80000000000000000000000000000000
        std::cout << std::hex << b << '\n';
        // prints -1 followed by 127 zeros (-2^127 in binary)
        std::cout << b.toString(2) << '\n';
    }
    // Parse a 128-bit number from a string. Last argument is the base, which defaults to 10.
    // Parsing returns a std::pair<ParseResult, Int128>, where ParseResult is Success when a value could be successfully read.
    {
        const auto parseResult = wiz::Int128::parse("7fffffffffffffffffffffffffffffff", 16);
        // prints success = true value = 7fffffffffffffffffffffffffffffff (no "0x" prefix is written)
        std::cout << std::hex << "success = " << (parseResult.first == wiz::Int128::ParseResult::Success ? "true" : "false") << " value = " << parseResult.second << '\n';
    }
    // Parsing will also ensure valid radix, input number format and input number range
    {
        // Range error, needs to be in 128-bit range.
        const auto parseResult = wiz::Int128::parse("123456789123456789123456789123456789123456789123456789123456789123456789");
        // prints success = false value = 0
        std::cout << std::hex << "success = " << (parseResult.first == wiz::Int128::ParseResult::Success ? "true" : "false") << " value = " << parseResult.second << '\n';
    }
    {
        // Hex digits are not possible in base 10
        const auto parseResult = wiz::Int128::parse("abc");
        // prints success = false value = 0
        std::cout << std::hex << "success = " << (parseResult.first == wiz::Int128::ParseResult::Success ? "true" : "false") << " value = " << parseResult.second << '\n';
    }
    // prints 670629624197529624197529624197492745
    std::cout << std::dec << (wiz::Int128::parse("12345678912345678912345678912345").second * wiz::Int128(54321)) << '\n';
    // prints 579
    std::cout << std::dec << (wiz::Int128(123) + wiz::Int128(456)) << '\n';
    // prints 0x40000
    std::cout << std::hex << "0x" << (wiz::Int128(0x1000) << 6) << '\n';
    return 0;
}
@SPyofgame200
Copy link

Can you also implement a fast square root?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment