Skip to content

Instantly share code, notes, and snippets.

@Kittoes0124
Last active April 13, 2024 17:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Kittoes0124/275bc8e97f1a1844ab4084e310346a26 to your computer and use it in GitHub Desktop.
Save Kittoes0124/275bc8e97f1a1844ab4084e310346a26 to your computer and use it in GitHub Desktop.
using System.Numerics;
using System.Runtime.CompilerServices;
// Demo driver: take the largest UInt128, narrow it to each unsigned width, and print the
// SquareRoot() of the narrowed value, one result per line (8, 16, 32, 64 and 128 bits).
// NOTE(review): the explicit narrowing casts are assumed to run in the default unchecked
// context, where they keep only the low-order bits - confirm against the project settings.
var allBitsSet = UInt128.MaxValue;

Console.WriteLine($"{((byte)allBitsSet).SquareRoot()}");
Console.WriteLine($"{((ushort)allBitsSet).SquareRoot()}");
Console.WriteLine($"{((uint)allBitsSet).SquareRoot()}");
Console.WriteLine($"{((ulong)allBitsSet).SquareRoot()}");
Console.WriteLine($"{UInt128.CreateChecked(value: allBitsSet).SquareRoot()}");
/// <summary>
/// Constants derived from the bit layout of the binary integer type <typeparamref name="T"/>.
/// All values are computed once, in the static constructor, the first time the type is touched.
/// </summary>
public static class BinaryIntegerConstants<T> where T : IBinaryInteger<T>
{
    /// <summary>
    /// Derives the fixed-point constant that approximates (1 - SQRT(0.5)) for <typeparamref name="T"/>.
    /// <c>byte</c> and <c>ushort</c> use the hard-coded values 5 and 75; every other type runs a
    /// digit-by-digit extraction of SQRT(2) in base 256 and converts the digits found.
    /// </summary>
    /// <returns>
    /// The result of (1 - SQRT(0.5)) in the fixed-point format Q0.((T.PopCount(value: T.AllBitsSet) >> 1) - 1).
    /// </returns>
    /// <remarks>
    /// NOTE(review): <c>UnsignedNumberFunctions.SquareRoot</c> multiplies by this constant and then
    /// shifts right by (bit width >> 1); confirm the Q-format wording above against that usage.
    /// </remarks>
    internal static T ComputeOneMinusSquareRootOfOneHalf() {
        if (typeof(T) == typeof(byte)) { return T.CreateChecked(value: 5); }
        if (typeof(T) == typeof(ushort)) { return T.CreateChecked(value: 75); }

        var @base = (T.One << 8);
        var length = ((T.PopCount(value: T.AllBitsSet) >> 4) + T.One);
        // The stopping bound never changes inside the loop, so compute it once instead of
        // re-running Exponentiate on every iteration.
        var limit = @base.Exponentiate(exponent: length);
        var x = (T.One << 1);
        var y = T.Zero;

        // Starting from x = 2, y = 0 the loop keeps x equal to (2 * base^(2k)) - (y * y), where k is
        // the number of scaling steps taken so far:
        //   - while another odd number (2y + 1) still fits, subtract it and advance y by one;
        //   - otherwise append a base-256 digit by scaling x by base^2 and y by base.
        // x is not read after the loop, so wrap-around in the final scaling step does not affect the result.
        do {
            if (x < ((y << 1) + T.One)) {
                x *= (@base * @base);
                y *= @base;
            }
            else {
                x = ((x - (y << 1)) - T.One);
                y += T.One;
            }
        } while (y < limit);

        // Drop the digit appended by the last scaling step, halve the SQRT(2) digits to get SQRT(0.5),
        // and subtract from base^(length - 1); the trailing "- 1" turns the ceiling into a floor.
        return ((@base.Exponentiate(exponent: (length - T.One)) - ((y / @base) >> 1)) - T.One);
    }

    /// <summary>Base-2 logarithm of <see cref="Size"/>, as returned by <c>T.Log2</c>.</summary>
    public static T Log2Size { get; }

    /// <summary>Value produced by <see cref="ComputeOneMinusSquareRootOfOneHalf"/> for <typeparamref name="T"/>.</summary>
    public static T OneMinusSquareRootOfOneHalf { get; }

    /// <summary>Number of set bits in <c>T.AllBitsSet</c>, used throughout as the bit width of <typeparamref name="T"/>.</summary>
    public static T Size { get; }

    /// <summary>The value ten, expressed as a <typeparamref name="T"/>.</summary>
    public static T Ten { get; }

    static BinaryIntegerConstants() {
        var size = T.PopCount(value: T.AllBitsSet);

        Log2Size = T.Log2(value: size);
        OneMinusSquareRootOfOneHalf = ComputeOneMinusSquareRootOfOneHalf();
        Size = size;
        Ten = T.CreateChecked(value: 10U);
    }
}
/// <summary>
/// Generic helpers that work on any <see cref="IBinaryInteger{TSelf}"/> implementation.
/// </summary>
public static class BinaryIntegerFunctions
{
    /// <summary>
    /// Converts a <see cref="bool"/> to <typeparamref name="T"/> without branching by reinterpreting
    /// its storage as a byte and truncating that byte into <typeparamref name="T"/>.
    /// </summary>
    /// <param name="value">The flag to convert.</param>
    /// <returns>
    /// The byte representation of <paramref name="value"/> as a <typeparamref name="T"/>
    /// (presumably 1 for <c>true</c> and 0 for <c>false</c>; relies on the runtime's bool layout).
    /// </returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    internal static T As<T>(this bool value) where T : IBinaryInteger<T> =>
        T.CreateTruncating(value: Unsafe.As<bool, byte>(source: ref value));

    /// <summary>
    /// Raises <paramref name="value"/> to the power <paramref name="exponent"/> by repeated squaring,
    /// using whatever overflow behavior the multiplication operator of <typeparamref name="T"/> has.
    /// </summary>
    /// <param name="value">The base.</param>
    /// <param name="exponent">
    /// The exponent. Only the lowest bit is examined when the shifted exponent is not greater than zero,
    /// so negative exponents are not meaningfully supported.
    /// </param>
    /// <returns><paramref name="value"/> multiplied by itself <paramref name="exponent"/> times; one when the exponent is zero.</returns>
    public static T Exponentiate<T>(this T value, T exponent) where T : IBinaryInteger<T> {
        var result = T.One;

        while (true) {
            if (T.IsOddInteger(value: exponent)) {
                result *= value;
            }

            exponent >>= 1;

            // Stop before squaring again: once no exponent bits remain the squared base would never
            // be used, and computing it could overflow (and throw in a checked build) for no reason.
            if (!(T.Zero < exponent)) {
                return result;
            }

            value *= value;
        }
    }

    /// <summary>
    /// Computes the bit width of <typeparamref name="T"/> minus the number of leading zero bits in
    /// <paramref name="value"/>, i.e. the one-based position of the highest set bit.
    /// </summary>
    /// <param name="value">The value to inspect.</param>
    /// <returns><c>BinaryIntegerConstants&lt;T&gt;.Size - T.LeadingZeroCount(value)</c>.</returns>
    public static T MostSignificantBit<T>(this T value) where T : IBinaryInteger<T> =>
        (BinaryIntegerConstants<T>.Size - T.LeadingZeroCount(value: value));
}
/// <summary>
/// Square-root support for unsigned binary integer types.
/// </summary>
public static class UnsignedNumberFunctions
{
    /// <summary>
    /// Branch-free comparison: converts the result of <c>value &gt; other</c> to <typeparamref name="T"/>.
    /// </summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static T IsGreaterThan<T>(this T value, T other) where T : IBinaryInteger<T> =>
        (value > other).As<T>();

    /// <summary>
    /// Computes the integer square root of <paramref name="value"/>.
    /// Types that are 8, 16, 32 or 64 bits wide use the floating-point square root and truncate the
    /// result (unless <c>FORCE_SOFTWARE_SQRT</c> is defined); every other width uses the software routine.
    /// </summary>
    /// <param name="value">The unsigned value whose root is wanted.</param>
    /// <returns>The square root of <paramref name="value"/> with the fractional part discarded.</returns>
    public static T SquareRoot<T>(this T value) where T : IBinaryInteger<T>, IUnsignedNumber<T> {
        var bitCount = int.CreateChecked(value: BinaryIntegerConstants<T>.Size);

        return bitCount switch {
#if !FORCE_SOFTWARE_SQRT
            8 => T.CreateTruncating(value: ((uint)MathF.Sqrt(x: uint.CreateTruncating(value: value)))),
            16 => T.CreateTruncating(value: ((uint)MathF.Sqrt(x: uint.CreateTruncating(value: value)))),
            32 => T.CreateTruncating(value: ((uint)Math.Sqrt(d: uint.CreateTruncating(value: value)))),
            64 => T.CreateTruncating(value: Sqrt(value: ulong.CreateTruncating(value: value))),
#endif
            _ => SoftwareImplementation(value: value),
        };

        /*
            Credit goes to njuffa for providing a reference implementation:
                https://stackoverflow.com/a/31149161/1186165

            Notes:
                - This implementation of the algorithm runs in constant time, based on the size of T.
                - Ignoring the loop that is entered when the size of T exceeds 64, all branches get eliminated during JIT compilation.
        */
        static T SoftwareImplementation(T value) {
            var bitCount = int.CreateChecked(value: BinaryIntegerConstants<T>.Size);
            var msb = int.CreateTruncating(value: value.MostSignificantBit());
            var msbIsOdd = (msb & 1);
            var m = ((msb + 1) >> 1);
            var mMinusOne = (m - 1);
            var mPlusOne = (m + 1);
            var x = (T.One << mMinusOne);
            var y = (x - (value >> (mPlusOne - msbIsOdd)));
            var z = y;

            x += x;

            // Fixed number of refinement steps of the form y = ((y * y) >> (m + 1)) + z.
            // The count grows with the width of T: 2 above 8 bits, 4 more above 16 bits,
            // 8 more above 32 bits, then batches of 8 driven by Size for anything wider than 64 bits.
            if (bitCount > 8) {
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
            }

            if (bitCount > 16) {
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
            }

            if (bitCount > 32) {
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
                y = (((y * y) >> mPlusOne) + z);
            }

            if (bitCount > 64) {
                var i = T.CreateTruncating(value: (BinaryIntegerConstants<T>.Size >> 3));

                do {
                    i -= (T.One << 3);
                    y = (((y * y) >> mPlusOne) + z);
                    y = (((y * y) >> mPlusOne) + z);
                    y = (((y * y) >> mPlusOne) + z);
                    y = (((y * y) >> mPlusOne) + z);
                    y = (((y * y) >> mPlusOne) + z);
                    y = (((y * y) >> mPlusOne) + z);
                    y = (((y * y) >> mPlusOne) + z);
                    y = (((y * y) >> mPlusOne) + z);
                } while (T.Zero < i);
            }

            y = (x - y);
            x = T.CreateTruncating(value: msbIsOdd);
            // When the most-significant-bit position is odd (x == 1), scale the estimate down by the
            // fixed-point (1 - SQRT(0.5)) constant for this width; when it is even (x == 0) this subtracts nothing.
            y -= bitCount switch {
                8 => (x * ((y * T.CreateChecked(value: 5UL)) >> 4)),
                16 => (x * ((y * T.CreateChecked(value: 75UL)) >> 8)),
                _ => (x * ((y * BinaryIntegerConstants<T>.OneMinusSquareRootOfOneHalf) >> (bitCount >> 1))),
            };
            x = (T.One << (bitCount - 1));
            // Final corrections: subtract one from y each time (value - y * y) compares greater than the
            // top-bit threshold in x. NOTE(review): presumably that signals the unsigned subtraction
            // wrapped because y * y exceeds value - confirm against the reference implementation.
            y -= (value - (y * y)).IsGreaterThan(other: x);

            if (bitCount > 8) {
                y -= (value - (y * y)).IsGreaterThan(other: x);
                y -= (value - (y * y)).IsGreaterThan(other: x);
            }

            if (bitCount > 32) {
                y -= (value - (y * y)).IsGreaterThan(other: x);
                y -= (value - (y * y)).IsGreaterThan(other: x);
                y -= (value - (y * y)).IsGreaterThan(other: x);
            }

            if ((bitCount > 128) && T.IsEvenInteger(value: BinaryIntegerConstants<T>.Log2Size) && T.IsPow2(value: BinaryIntegerConstants<T>.Size)) {
                var i = int.CreateChecked(value: (BinaryIntegerConstants<T>.Log2Size >> 1));

                do {
                    y -= (value - (y * y)).IsGreaterThan(other: x);
                } while (0 < --i);
            }

            // Clear the top bit of the result before returning it.
            return (y & (T.AllBitsSet >> 1));
        }

        // Integer square root of a 64-bit value via the double-precision square root.
        // The input is converted as an unsigned number; reinterpreting it as a signed long would make
        // every value at or above 2^63 negative and turn the root into NaN.
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        static uint Sqrt(ulong value) {
            var root = ((ulong)Math.Sqrt(d: value));

            // Inputs close to ulong.MaxValue round up to 2^64 as a double, whose root (2^32) does not fit in a uint.
            if (root > uint.MaxValue) {
                root = uint.MaxValue;
            }

            // A double carries 53 significant bits, so the estimate can be off by one in either direction.
            if ((root * root) > value) {
                --root;
            }
            else if ((value - (root * root)) > (root << 1)) {
                // (root + 1)^2 <= value, rewritten as value - root^2 > 2 * root so nothing can overflow.
                ++root;
            }

            return ((uint)root);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment