Last active: November 11, 2017 16:24
Save AdrienLemaire/20d428684a65b32d7c02 to your computer and use it in GitHub Desktop.
Benchmark two methods for getting the indices of the N highest values in a multi-dimensional array
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# $ benchmark.py
#
# Sample output. Both methods find the same five coordinates, but
# test_argpartition returns them in no particular order while test_where
# returns them in row-major order:
#
# Test with test_argpartition:
# Results:
# [[78 55 72]
# [ 0 50 15]
# [ 8 44 27]
# [51 93 75]
# [38 24 82]]
# Time needed to execute 100: 0.996204137802sec
#
# Test with test_where:
# Results:
# [[ 0 50 15]
# [ 8 44 27]
# [38 24 82]
# [51 93 75]
# [78 55 72]]
# Time needed to execute 100: 22.2839028835sec
import heapq
import timeit

import numpy as np
# Benchmark input: a 100x100x100 array of uniform floats in [0, 1).
# No seed is set, so a fresh array is drawn on every run.
HIST = np.random.random(size=(100, 100, 100))

# Number of largest values whose coordinates each method must locate.
NB_VAL = 5

# Number of timed calls made per method.
NB_BENCHMARK = 100
def test_argpartition():
    """Find the coordinates of the NB_VAL largest entries of HIST via argpartition.

    Returns an integer array of shape (NB_VAL, HIST.ndim); each row is the
    index tuple of one of the largest values. np.argpartition only
    guarantees that the last NB_VAL slots hold the NB_VAL largest values,
    so the rows come back in no particular order.
    """
    flat = HIST.flatten()
    # Partition around the NB_VAL-th largest element; the tail slice then
    # holds the flat positions of the NB_VAL largest values.
    top_positions = np.argpartition(flat, -NB_VAL)[-NB_VAL:]
    # Convert flat positions back into one index array per dimension...
    per_axis = np.unravel_index(top_positions, HIST.shape)
    # ...and stack/transpose them so each row is one coordinate tuple.
    return np.vstack(per_axis).T
def test_where():
    """Find the coordinates of the NB_VAL largest entries of HIST via a threshold mask.

    heapq.nlargest iterates over every element of the flattened array to
    find the NB_VAL-th largest value. That value is then used as a cut-off
    for np.where, so the coordinates come back in row-major order. Every
    value tied with the cut-off is kept, which means more than NB_VAL rows
    can be returned when HIST contains duplicates.
    """
    cutoff = heapq.nlargest(NB_VAL, HIST.flatten())[-1]
    at_or_above = HIST >= cutoff
    # np.where yields one index array per dimension; stack/transpose them so
    # each row is one coordinate tuple.
    return np.vstack(np.where(at_or_above)).T
def run_test(f, number=None):
    """Print the result of ``f()`` and the time taken by ``number`` calls to it.

    f: zero-argument callable to benchmark; its ``__name__`` labels the
        printed report.
    number: how many times ``timeit`` calls ``f``. Defaults to the
        module-level NB_BENCHMARK when omitted.

    Note that ``f`` is called once, untimed, to print its result, and then
    ``number`` more times while being timed.
    """
    if number is None:
        number = NB_BENCHMARK
    # Single-argument print() behaves the same on Python 2 and 3. The
    # original print statements are a SyntaxError on Python 3.
    print("\nTest with {}:".format(f.__name__))
    print("Results:\n{}".format(f()))
    # Timing the callable directly avoids importing it from __main__ in a
    # setup string. That import only worked when this file was run as a
    # script and f was a top-level name there.
    elapsed = timeit.timeit(f, number=number)
    print("Time needed to execute {}: {}sec".format(number, elapsed))
if __name__ == '__main__':
    # Benchmark each method in turn: argpartition first, then heapq + where.
    for benchmark in (test_argpartition, test_where):
        run_test(benchmark)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.