Skip to content

Instantly share code, notes, and snippets.

@mfornet
Created June 11, 2022 00:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mfornet/ce273690b60e60f1b595902ac940cc65 to your computer and use it in GitHub Desktop.
Save mfornet/ce273690b60e60f1b595902ac940cc65 to your computer and use it in GitHub Desktop.
Find the right shift needed to bring a number `n` down to fewer than `bits` bits (i.e. below `2**bits`)
def shift_slow(n, bits):
    """Return the smallest right-shift that brings ``n`` below ``2**bits``.

    Reference implementation: shifts ``n`` right one bit at a time and counts
    the steps taken. Returns 0 when ``n`` is already below ``2**bits``, which
    includes every negative ``n``.
    """
    limit = 2 ** bits  # loop-invariant threshold, computed once
    value, count = n, 0
    while value >= limit:
        value, count = value >> 1, count + 1
    return count
def compute_logbits(bits):
    """Return the largest ``logbits`` with ``2**logbits < bits``.

    Equivalently ``2**logbits < bits <= 2**(logbits + 1)``; for example
    ``compute_logbits(8) == 2`` and ``compute_logbits(128) == 6``. The
    binary search in ``shift_fast`` uses this to size its step range.

    Raises:
        ValueError: if ``bits`` is not greater than 1, since then no
            ``logbits >= 0`` satisfies ``2**logbits < bits``.
    """
    # Validate explicitly rather than via ``assert``: assertions are stripped
    # under ``python -O``, which would let e.g. bits=1 silently return 0.
    if not bits > 1:
        raise ValueError(f"bits must be greater than 1, got {bits!r}")
    logbits = 0
    while (1 << (logbits + 1)) < bits:
        logbits += 1
    # Invariant: 2**logbits < bits holds on entry (bits > 1) and after every
    # increment; the loop stops as soon as 2**(logbits + 1) >= bits.
    return logbits
def shift_fast(n, bits, expected_logbits=None):
    """Return the smallest right-shift that brings ``n`` below ``2**bits``.

    Binary-search counterpart of ``shift_slow``. Instead of shifting one bit
    at a time, it tries shifts of ``2**i`` for decreasing ``i`` and keeps each
    one that still leaves the value at or above ``2**bits``.

    Args:
        n: Integer to reduce. Any value below ``2**bits`` (including
            negatives) needs no shift and yields 0.
        bits: Target width; must be accepted by ``compute_logbits``.
        expected_logbits: Optional sanity check; if given, it must equal
            ``compute_logbits(bits)``.

    Returns:
        0 if ``n < 2**bits``; otherwise the smallest ``s`` with
        ``n >> s < 2**bits``.
    """
    logbits = compute_logbits(bits)
    if expected_logbits is not None:
        assert expected_logbits == logbits
    if n < 2**bits:
        return 0
    limit = 1 << bits
    # Steps 2**top, ..., 2**0 can accumulate at most 2**(top + 1) - 1, so the
    # result is capped at 2**(top + 1). With top == logbits that covers every
    # n < 2**(2 * bits). For larger n, widen the range; otherwise the answer
    # would be silently too small (e.g. n = 2**16, bits = 8 gave 8, not 9).
    # For in-range n this loop does a single check and leaves top == logbits.
    top = logbits
    while n >> (1 << (top + 1)) >= limit:
        top += 1
    shift = 0
    for i in range(top, -1, -1):
        reduced = n >> (1 << i)
        if reduced >= limit:
            shift += 1 << i
            n = reduced
    # `shift` is now the largest value with (original n) >> shift >= limit;
    # one more bit brings it below the limit.
    return shift + 1
def main():
    """Cross-check ``shift_fast`` against the ``shift_slow`` reference.

    Raises AssertionError on the first mismatch. Checks are explicit raises
    rather than ``assert`` statements, so the script still verifies something
    when run under ``python -O``.
    """

    def check(n, bits):
        # Evaluate each implementation once and report both on mismatch.
        slow = shift_slow(n, bits)
        fast = shift_fast(n, bits)
        if slow != fast:
            raise AssertionError((n, bits, slow, fast))

    def expect_logbits(bits, expected):
        got = compute_logbits(bits)
        if got != expected:
            raise AssertionError((bits, got, expected))

    expect_logbits(8, 2)
    # Every 16-bit value; with bits=8 this exercises shifts 0 through 8.
    for i in range(2**16):
        check(i, 8)
    # All below 2**128, so both implementations must return 0.
    for i in range(2**16):
        check(i, 128)
    expect_logbits(128, 6)
    # Values straddling 2**128: shift 0 just below it, shift 1 at and above.
    for i in range(-2**16, 2**16):
        check(2**128 + i, 128)
    # 256-bit values just below 2**256, which need a shift of 128.
    for i in range(2**16):
        check(2**256 - i - 1, 128)
if __name__ == '__main__':
    # Run the cross-check only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment