Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Get a Fibonacci number with O(log(n)).
# coding: utf-8
"""Get a Fibonacci number with O(log(n)).
Sequence, indexed from n = 0: 1, 1, 2, 3, 5, 8, 13, 21, ...
The O(log(n)) matrix-power method used here is explained at: https://kukuruku.co/post/the-nth-fibonacci-number-in-olog-n
"""
from functools import lru_cache, reduce
# Step matrix of the recurrence. get_fibonacci raises it to the n-th power
# and reads the top-left entry of the result.
P = ((1, 1), (1, 0))

# 2x2 identity matrix: the neutral element for multiply_matrix, used as the
# starting value of the product (so n == 0 yields E itself).
E = ((1, 0), (0, 1))
def get_fibonacci(n, cache=False):
    """Return the n-th Fibonacci number, indexed so that F(0) == F(1) == 1.

    The result is the top-left entry of P**n. P**n is assembled from the
    factors P**(2**k) for every set bit k of n, so only O(log(n)) matrix
    products are needed.

    Args:
        n: Non-negative integer index.
        cache: If true, each factor P**(2**k) is taken from the memoized
            ``pow_matrix``, whose LRU cache keeps results across calls.
            If false, the factors are produced by repeated squaring inside
            this call and nothing is cached.

    Returns:
        The top-left entry of P**n.

    Raises:
        ValueError: If n is negative.
    """
    if n < 0:
        # The bit loop below would never terminate for negative n
        # (-1 >> 1 == -1), so reject it up front.
        raise ValueError('n must be equal to or greater than 0.')
    if cache:
        factors = (pow_matrix(P, k) for k in get_2exponent_series(n))
        return reduce(multiply_matrix, factors, E)[0][0]
    # Uncached binary exponentiation: `square` walks P, P**2, P**4, ... and
    # is folded into `result` whenever the matching bit of n is set.
    # Powers of P commute, so the multiplication order does not matter.
    result = E
    square = P
    while n:
        if n & 1:
            result = multiply_matrix(result, square)
        n >>= 1
        if n:
            square = multiply_matrix(square, square)
    return result[0][0]
def get_2exponent_series(n):
    """Return the exponents of the powers of two that sum to n.

    The exponents are the positions of the set bits of n, in ascending
    order. For example, 5 (0b101) gives [0, 2], because 5 == 2**0 + 2**2.

    Args:
        n: Non-negative integer.

    Returns:
        A list of bit positions; empty when n == 0.

    Raises:
        ValueError: If n is negative. bin() of a negative number starts with
            '-0b', so slicing off two characters would leave 'b...', drop
            the sign, and silently return the exponents of abs(n).
    """
    if n < 0:
        raise ValueError('n must be equal to or greater than 0.')
    # bin(n)[2:] strips the '0b' prefix; reversing puts bit 0 first.
    binary_digits = reversed(bin(n)[2:])
    return [i for i, digit in enumerate(binary_digits) if digit == '1']
@lru_cache(maxsize=128)
def pow_matrix(m, exp2):
    """Calculate m^(2^exp2) by repeated squaring.

    Uses m^(2^exp2) == (m^(2^(exp2-1)))^2 and recurses down to exp2 == 0.
    Results are memoized by lru_cache, keyed on (m, exp2), so m must be
    hashable (e.g. a tuple of tuples such as P).

    Args:
        m: 2x2 matrix indexable as m[row][col], combined via multiply_matrix.
        exp2: Non-negative integer; the exponent applied to m is 2**exp2.

    Returns:
        m itself when exp2 == 0, otherwise the product from multiply_matrix.

    Raises:
        ValueError: If exp2 is negative. ValueError subclasses Exception, so
            callers that catch Exception still work.
    """
    if exp2 < 0:
        raise ValueError('exp2 must be equal to or greater than 0.')
    if exp2 == 0:
        return m
    # Recurse through the cached wrapper so every intermediate square is
    # memoized and reused by later calls with a larger exp2.
    half = pow_matrix(m, exp2 - 1)
    return multiply_matrix(half, half)
def multiply_matrix(m1, m2):
    """Return the 2x2 matrix product m1 x m2.

    Each operand is indexed as m[row][column], and only entries [0..1][0..1]
    are read. The product is returned as a tuple of two row tuples.
    """
    def dot(row, col):
        # Inner product of row `row` of m1 with column `col` of m2.
        return m1[row][0] * m2[0][col] + m1[row][1] * m2[1][col]

    return tuple(tuple(dot(r, c) for c in (0, 1)) for r in (0, 1))
if __name__ == '__main__':
    import unittest

    class FibonacciTestCase(unittest.TestCase):
        """Self-tests, run only when this file is executed as a script."""

        def test_get_fibonacci(self):
            # This module indexes the sequence so that F(0) == F(1) == 1.
            pairs = [
                [0, 1],
                [1, 1],
                [2, 2],
                [3, 3],
                [4, 5],
                [5, 8],
                [6, 13],
                [7, 21],
            ]
            for n, expected in pairs:
                with self.subTest(n=n):
                    self.assertEqual(get_fibonacci(n), expected)

        def test_get_fibonacci_cache_flag_agrees(self):
            # `cache` may only change how the powers of P are obtained,
            # never the result.
            for n in range(40):
                with self.subTest(n=n):
                    self.assertEqual(get_fibonacci(n, cache=True),
                                     get_fibonacci(n, cache=False))

        def test_get_fibonacci_large_n(self):
            # Cross-check against a straightforward linear iteration.
            a, b = 1, 1  # F(0), F(1) in this module's indexing
            for _ in range(100):
                a, b = b, a + b
            self.assertEqual(get_fibonacci(100), a)
            self.assertEqual(get_fibonacci(100, cache=True), a)

        def test_get_2exponent_series(self):
            pairs = [
                [0, []],
                [1, [0]],
                [2, [1]],
                [5, [0, 2]],
                [9, [0, 3]],
                [15, [0, 1, 2, 3]],
            ]
            for n, expected in pairs:
                with self.subTest(n=n):
                    self.assertEqual(get_2exponent_series(n), expected)

        def test_pow_matrix(self):
            self.assertEqual(pow_matrix(E, 0), E)
            self.assertEqual(pow_matrix(E, 5), E)
            self.assertEqual(pow_matrix(P, 0), P)
            self.assertEqual(pow_matrix(P, 1), ((2, 1), (1, 1)))

        def test_pow_matrix_negative_exponent(self):
            with self.assertRaises(Exception):
                pow_matrix(P, -1)

        def test_multiply_matrix(self):
            self.assertEqual(multiply_matrix(E, E), E)
            self.assertEqual(multiply_matrix(E, P), P)
            self.assertEqual(multiply_matrix(P, E), P)
            self.assertEqual(multiply_matrix(P, P), ((2, 1), (1, 1)))

        def test_pow_matrix_cached(self):
            # pow_matrix(P, 6) recurses through pow_matrix(P, 3), which the
            # first call left in the cache, so at least one hit is expected.
            pow_matrix.cache_clear()
            pow_matrix(P, 3)
            pow_matrix(P, 6)
            pow_matrix(P, 9)
            self.assertTrue(pow_matrix.cache_info().hits > 0)

    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment