Created
June 11, 2022 00:12
-
-
Save mfornet/ce273690b60e60f1b595902ac940cc65 to your computer and use it in GitHub Desktop.
Find the smallest right shift that brings a number `n` down to at most `bits` bits (i.e. below `2**bits`)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def shift_slow(n, bits):
    """Return how many one-bit right shifts bring ``n`` below ``2**bits``.

    Straightforward linear reference implementation: halve ``n`` one bit at
    a time and count the steps. Returns 0 when ``n`` is already below
    ``2**bits``.
    """
    threshold = 2**bits
    halvings = 0
    while n >= threshold:
        n, halvings = n >> 1, halvings + 1
    return halvings
def compute_logbits(bits):
    """Return the exponent ``k`` such that ``2**k < bits <= 2**(k + 1)``.

    Equivalently ``ceil(log2(bits)) - 1``; for example 8 -> 2 and 128 -> 6.

    Args:
        bits: Target width. Must be at least 2, since no non-negative ``k``
            satisfies the relation for smaller values.

    Returns:
        The non-negative integer ``k`` described above.

    Raises:
        ValueError: If ``bits`` is less than 2.
    """
    # Raise rather than assert: asserts are stripped under ``python -O``,
    # which would let an out-of-range width silently yield 0.
    if bits < 2:
        raise ValueError("bits must be at least 2, got %r" % (bits,))
    logbits = 0
    while (1 << (logbits + 1)) < bits:
        logbits += 1
    # Postconditions of the loop, kept as cheap internal sanity checks.
    assert (1 << logbits) < bits
    assert (1 << (logbits + 1)) >= bits
    return logbits
def shift_fast(n, bits, expected_logbits=None):
    """Return the smallest right shift that brings ``n`` below ``2**bits``.

    Produces the same result as counting one-bit shifts, but finds it with a
    binary search over the shift amount: it tries power-of-two steps from
    large to small and keeps every step after which the value is still at
    least ``2**bits``.

    Args:
        n: Integer to shrink.
        bits: Target width; passed to ``compute_logbits`` to pick the
            largest step of the search.
        expected_logbits: Optional cross-check. When given, it is asserted
            to equal ``compute_logbits(bits)``.

    Returns:
        0 if ``n < 2**bits``; otherwise the smallest ``s >= 1`` such that
        ``n >> s < 2**bits``.
    """
    logbits = compute_logbits(bits)
    if expected_logbits is not None:
        assert expected_logbits == logbits
    limit = 2**bits
    if n < limit:
        return 0
    # Steps 2**logbits .. 1 sum to 2**(logbits + 1) - 1 >= bits - 1, which is
    # enough whenever n < 2**(2 * bits). Python integers are unbounded, so
    # for larger n keep doubling the top step until the search range covers
    # the answer; otherwise the result would silently be too small.
    top = logbits
    while (n >> (2 << top)) >= limit:
        top += 1
    shift = 0
    for i in range(top, -1, -1):
        step = 1 << i
        reduced = n >> step
        # Keep the step only if the value is still too wide afterwards.
        if reduced >= limit:
            shift += step
            n = reduced
    # n is now the last value still >= 2**bits; one more shift goes below.
    return shift + 1
def main():
    """Cross-check ``shift_fast`` against ``shift_slow`` on sample ranges.

    Covers every 16-bit value for widths 8 and 128, values around
    ``2**128``, and the top ``2**16`` values below ``2**256``, plus two
    spot checks of ``compute_logbits``.
    """
    assert compute_logbits(8) == 2

    for value in range(2**16):
        slow = shift_slow(value, 8)
        fast = shift_fast(value, 8)
        assert slow == fast, (value, slow, fast)

    for value in range(2**16):
        assert shift_slow(value, 128) == shift_fast(value, 128)

    assert compute_logbits(128) == 6

    # Values straddling the 128-bit boundary.
    boundary = 2**128
    for offset in range(-2**16, 2**16):
        value = boundary + offset
        assert shift_slow(value, 128) == shift_fast(value, 128)

    # The largest 256-bit values, counting down from 2**256 - 1.
    largest = 2**256 - 1
    for i in range(2**16):
        value = largest - i
        assert shift_slow(value, 128) == shift_fast(value, 128), i
# Run the self-checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment