@kadereub
Last active September 22, 2020 08:41
A `numba` implementation of Heap's permutation algorithm (non-recursive)
import numpy as np
import numba as nb

# References
# [1] https://en.wikipedia.org/wiki/Heap%27s_algorithm


@nb.njit
def _factorial(n):
    # Base case covers n == 0 as well as n == 1 (0! = 1! = 1)
    if n <= 1:
        return 1
    else:
        return n * _factorial(n - 1)


@nb.njit
def numba_heap_permutations(arr, d):
    """
    Generate all permutations of an array using Heap's algorithm (non-recursive form, see [1]).

    Args:
        arr (numpy.array): A vector of ints/floats to permute; it is copied, not modified in place
        d (int): The number of elements to permute; must equal arr.shape[0]

    Returns:
        (numpy.array): A float64 array of d! rows and d columns containing all permutations of arr
    """
    perm = arr.copy()  # work on a copy so the caller's array is not permuted in place
    d_fact = _factorial(d)
    # c[i] encodes the for-loop counter at level i of the recursive version
    c = np.zeros(d, dtype=np.int32)
    res = np.zeros(shape=(d_fact, d))
    counter = 0
    i = 0
    # The unpermuted input is the first permutation
    res[counter] = perm
    counter += 1
    while i < d:
        if c[i] < i:
            if i % 2 == 0:
                perm[0], perm[i] = perm[i], perm[0]
            else:
                perm[c[i]], perm[i] = perm[i], perm[c[i]]
            # A swap has occurred, ending the for-loop iteration. Simulate the increment of the for-loop counter
            c[i] += 1
            res[counter] = perm
            counter += 1
            # Simulate the recursive call reaching the base case by setting i back to the base case
            i = 0
        else:
            # The call func(i+1, A) has ended because its for-loop terminated.
            # Reset the state and increment i.
            c[i] = 0
            i += 1
    # Return an array of d! rows and d columns (all permutations)
    return res


# Example use
numba_heap_permutations(np.array([1, 2, 3, 4]), 4)
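To check the loop logic without numba installed, here is a minimal pure-Python sketch of the same non-recursive Heap's algorithm, written as a generator. `heap_permutations` is a hypothetical helper for illustration, not part of the gist. It resets `i` to 1 as in the Wikipedia pseudocode [1]; the gist's `i = 0` only adds one extra iteration that does nothing except move on to `i = 1`, so the output order is the same.

```python
def heap_permutations(values):
    """Yield every permutation of `values` using the non-recursive form of
    Heap's algorithm. Hypothetical pure-Python helper (no numba)."""
    a = list(values)   # work on a copy so the caller's sequence is untouched
    n = len(a)
    c = [0] * n        # c[i] encodes the for-loop counter at level i of the recursive version
    yield tuple(a)     # the unpermuted input is the first permutation
    i = 1
    while i < n:
        if c[i] < i:
            if i % 2 == 0:
                a[0], a[i] = a[i], a[0]        # even level: swap with the first element
            else:
                a[c[i]], a[i] = a[i], a[c[i]]  # odd level: swap with position c[i]
            yield tuple(a)
            c[i] += 1   # simulate the for-loop increment
            i = 1       # simulate the recursive call reaching the base case
        else:
            c[i] = 0    # this level is exhausted: reset it and move up one level
            i += 1


print(list(heap_permutations([1, 2, 3])))
# [(1, 2, 3), (2, 1, 3), (3, 1, 2), (1, 3, 2), (2, 3, 1), (3, 2, 1)]
```

This is the same row order `numba_heap_permutations` produces (as floats). Unlike the generator, the numba version preallocates all d! rows, so its memory grows factorially: for d = 10 that is 3,628,800 × 10 float64 values, about 290 MB.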
kadereub commented Feb 16, 2020

Very crude speed comparison:

Using numba heap permutations

%%timeit
numba_heap_permutations(np.array([1, 2, 3, 4, 5]), 5)

3.79 µs ± 165 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Using itertools permutations

%%timeit
list(itertools.permutations([1, 2, 3, 4, 5]))

7.19 µs ± 233 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In the 5-element run above the numba version is roughly twice as fast (3.79 µs vs 7.19 µs), and my quick comparisons suggest it stays faster for larger arrays. It is also convenient if you want to work exclusively with numpy arrays rather than Python lists.
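The timings above compare a 120-row NumPy array against a Python list of tuples, so they are not quite like-for-like, and the two methods emit permutations in different orders (`itertools.permutations` uses lexicographic order by input position; Heap's algorithm does not). Below is a minimal sketch, assuming only NumPy and the standard library, that materialises the itertools result as the same (n!, n) float64 array and compares rows while ignoring their order. `itertools_permutation_array` and `same_rows` are hypothetical helpers, not part of the gist.

```python
import itertools
import timeit

import numpy as np


def itertools_permutation_array(values):
    """Materialise itertools.permutations as an (n!, n) float64 array,
    the same shape and dtype that numba_heap_permutations returns."""
    return np.array(list(itertools.permutations(values)), dtype=np.float64)


def same_rows(a, b):
    """True if two 2-D arrays contain the same rows, ignoring row order
    (Heap's algorithm and itertools emit permutations in different orders)."""
    return sorted(map(tuple, a.tolist())) == sorted(map(tuple, b.tolist()))


if __name__ == "__main__":
    values = [1, 2, 3, 4, 5]
    n_calls = 10_000
    # Time the itertools path including the conversion to a NumPy array
    t = timeit.timeit(lambda: itertools_permutation_array(values), number=n_calls)
    print(f"itertools -> ndarray: {t / n_calls * 1e6:.2f} us per call")
```

With numba installed, `numba_heap_permutations` can be timed the same way with `timeit.timeit`, after one warm-up call so that JIT compilation is not counted, and its output checked against `itertools_permutation_array` with `same_rows`.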
