Skip to content

Instantly share code, notes, and snippets.

@jepler
Created February 24, 2021 03:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jepler/a256bb6058bb64492aa548365d15cf5c to your computer and use it in GitHub Desktop.
Save jepler/a256bb6058bb64492aa548365d15cf5c to your computer and use it in GitHub Desktop.
#!/usr/bin/python3
import sys
import math
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool as Pool
def add(x):
    """Return the sum of the first two items of ``x``.

    The operands arrive packed in one indexable argument so the function
    can be passed directly to a pool's ``map``/``imap``, which hand each
    worker call a single item.
    """
    lhs, rhs = x[0], x[1]
    return lhs + rhs
def mul(x):
    """Return the product of the first two items of ``x``.

    The operands arrive packed in one indexable argument so the function
    can be passed directly to a pool's ``map``/``imap``, which hand each
    worker call a single item.
    """
    lhs, rhs = x[0], x[1]
    return lhs * rhs
# Pick the integer type used for all matrix entries: gmpy2's ``mpz`` when
# the optional gmpy2 package is importable, otherwise the built-in ``int``.
try:
    import gmpy2
except ImportError:
    data_type = int
else:
    data_type = gmpy2.mpz

# Shared pool of eight worker threads (``Pool`` is ThreadPool here), used to
# evaluate the scalar products and sums of a matrix multiplication.
pool = Pool(processes=8)
@dataclass
class M2:
    """A 2x2 matrix stored as four scalar entries.

    ``mRC`` is the entry in row R, column C.  ``*`` is the matrix product
    and ``**`` raises the matrix to an integer power via ``fast_power``.
    The multiplicative identity is expected on the class attribute
    ``mul_idn``, which is assigned after the class is defined.
    """

    m00: data_type
    m01: data_type
    m10: data_type
    m11: data_type

    def __mul__(l, r):
        """Return the matrix product of ``l`` and ``r`` as a new ``M2``.

        Each result entry is the sum of two scalar products.  The eight
        products and the four sums are dispatched to the module-level
        ``pool``.
        """
        # First term of every entry: left-hand column 0 times right-hand row 0.
        first_terms = [
            (l.m00, r.m00),
            (l.m00, r.m01),
            (l.m10, r.m00),
            (l.m10, r.m01),
        ]
        # Second term of every entry: left-hand column 1 times right-hand row 1.
        second_terms = [
            (l.m01, r.m10),
            (l.m01, r.m11),
            (l.m11, r.m10),
            (l.m11, r.m11),
        ]
        first_products = pool.imap(mul, first_terms)
        second_products = pool.imap(mul, second_terms)
        # Pair the partial products entry by entry, in m00, m01, m10, m11 order.
        entries = pool.map(add, zip(first_products, second_products))
        return M2(*entries)

    def __pow__(self, exp):
        """Return this matrix raised to the power ``exp``."""
        identity = self.mul_idn
        return fast_power(self, exp, identity)
# Scalar constants expressed in the selected integer type.
zero, one = data_type(0), data_type(1)

# 2x2 identity matrix; M2.__pow__ uses it as the starting value of a power.
M2.mul_idn = M2(m00=one, m01=zero, m10=zero, m11=one)

# Fibonacci generator matrix [[0, 1], [1, 1]]: its k-th power is
# [[F(k-1), F(k)], [F(k), F(k+1)]].
fib_gen = M2(m00=zero, m01=one, m10=one, m11=one)
def fast_power(base, exp, mul_idn):
    """Raise ``base`` to the non-negative integer power ``exp``.

    Uses binary (square-and-multiply) exponentiation, scanning the bits of
    ``exp`` from least to most significant.  Progress ("bit/total") is
    written to stderr while the loop runs and blanked out afterwards.

    Args:
        base: Value to exponentiate; only needs to support ``*``.
        exp: Non-negative integer exponent.
        mul_idn: Multiplicative identity for ``base``'s type; returned
            unchanged when ``exp`` is 0.

    Returns:
        ``base`` multiplied by itself ``exp`` times.

    Raises:
        ValueError: If ``exp`` is negative.  The bit scan below would
            otherwise silently produce a wrong result.
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")
    result = mul_idn
    if not exp:
        return result

    bits = exp.bit_length()
    last = bits - 1
    for i in range(bits):
        print(end=f"{i}/{bits}\r", file=sys.stderr)
        sys.stderr.flush()
        if exp & (1 << i):
            result *= base
        # Squaring after the top bit is never used, and it would be the
        # single most expensive multiplication of the run, so skip it.
        if i != last:
            base *= base

    # Overwrite the last progress message with spaces.
    print(end=" " * len(f"{last}/{bits}") + "\r", file=sys.stderr)
    sys.stderr.flush()
    return result
def fib(n):
    """Return the n-th Fibonacci number, with F(0) = 0 and F(1) = 1.

    For ``n >= 2`` the value is read from the (n-1)-th power of the
    generator matrix ``fib_gen``.  Negative indices use the identity
    F(-n) = (-1)**(n + 1) * F(n) instead of echoing ``n`` back.

    Args:
        n: Integer index; may be negative.

    Returns:
        F(n) as a ``data_type`` value.
    """
    if n < 0:
        magnitude = fib(-n)
        # F(-n) keeps F(n)'s sign for odd n and flips it for even n.
        return magnitude if n % 2 else -magnitude
    if n < 2:
        # F(0) = 0 and F(1) = 1; convert so every result has the same type.
        return data_type(n)
    return (fib_gen ** (n - 1)).m11
if __name__ == '__main__':
def conv_bit_length(n):
    """Format ``n`` as its bit length, e.g. ``"8 bits"`` for 255."""
    bits = n.bit_length()
    return f"{bits} bits"
def conv_dec_length(n):
    """Format ``n`` as its decimal digit count plus a digit preview.

    Numbers longer than 48 digits are shown as their first and last 24
    digits joined by an ellipsis; shorter numbers are shown in full, since
    eliding would not drop any digit.

    Args:
        n: Integer to describe (``int`` or an int-like type such as
            ``gmpy2.mpz`` that supports ``bit_length``).

    Returns:
        A string of the form ``"<count> digits: <head>…<tail>"`` or
        ``"<count> digits: <n>"``.
    """
    if n == 0:
        return "1 digits: 0"

    sign = "-" if n < 0 else ""
    magnitude = abs(n)
    # Build powers of ten in the argument's own integer type.
    ten = type(magnitude)(10)

    # Estimate the digit count from the bit length, then correct upward
    # until 10**dig_est exceeds the value.
    dig_est = math.floor(magnitude.bit_length() * math.log(2, 10))
    nn = ten ** dig_est
    while nn <= magnitude:
        dig_est += 1
        nn *= 10

    if dig_est <= 48:
        # Head and tail would touch or overlap. Below 24 digits the head
        # divisor would also be a negative power of ten.
        return f"{dig_est} digits: {sign}{magnitude}"

    tail = magnitude % (ten ** 24)
    head = magnitude // (ten ** (dig_est - 24))
    # Zero-pad the tail so leading zeros among the last 24 digits survive.
    return f"{dig_est} digits: {sign}{head}…{str(tail).zfill(24)}"
def conv_auto(n):
    """Format ``n`` in full when short, otherwise as a digit-count summary.

    Values under 250 bits are returned as ``str(n)``; anything longer is
    delegated to ``conv_dec_length``.
    """
    # This cut-off happens to be where the line exceeds 80 chars
    is_short = n.bit_length() < 250
    return str(n) if is_short else conv_dec_length(n)
# Command-line interface: parse the requested indices and output format,
# then print one "N: value" line per Fibonacci number.
import argparse

parser = argparse.ArgumentParser(description='Calculate fibonacci numbers')
parser.add_argument('integers', metavar='N', type=int, nargs='+',
                    help='Values for which to calculate fib(N)')

# All format flags write to the same ``converter`` destination, so the last
# one given on the command line wins.  ``--auto`` carries the default.
parser.add_argument('--auto', dest='converter', action='store_const',
                    const=conv_auto, default=conv_auto,
                    help='Print short numbers in base-10 and long numbers '
                         'as a decimal length with their leading and '
                         'trailing digits (default)')
parser.add_argument('--dec', dest='converter', action='store_const',
                    const=str,
                    help='Print the number in base-10')
parser.add_argument('--bit-length', dest='converter', action='store_const',
                    const=conv_bit_length,
                    help='Print the bit length of the number, instead of '
                         'the number')
parser.add_argument('--dec-length', dest='converter', action='store_const',
                    const=conv_dec_length,
                    help='Print the decimal length of the number with its '
                         'leading and trailing digits, instead of the '
                         'whole number')
parser.add_argument('--bin', dest='converter', action='store_const',
                    const=bin,
                    help='Print the number in base-2')
parser.add_argument('--hex', dest='converter', action='store_const',
                    const=hex,
                    help='Print the number in base-16')
parser.add_argument('--oct', dest='converter', action='store_const',
                    const=oct,
                    help='Print the number in base-8')
parser.add_argument('--ranges', dest='ranges', action='store_true',
                    default=False,
                    help='Treat pairs of numbers as inclusive ranges')
args = parser.parse_args()

if args.ranges:
    # Ranges are given as START END pairs; an unpaired trailing value
    # would otherwise fail with an IndexError.
    if len(args.integers) % 2:
        parser.error('--ranges requires an even number of integers '
                     '(START END pairs)')
    bounds = iter(args.integers)
    # zip over one shared iterator yields consecutive (start, end) pairs.
    targets = (n for start, end in zip(bounds, bounds)
               for n in range(start, end + 1))
else:
    targets = args.integers

for n in targets:
    print(f'{n}: {args.converter(fib(n))}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment