Last active: October 3, 2022 15:32
Save murbard/5c711d8efe114c1ecb2152579937064c to your computer and use it in GitHub Desktop.
Average number of random pairwise queries needed to sort n elements
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import random | |
class Oracle:
    """Random comparison oracle over ``n`` elements whose true order is 0 < 1 < ... < n-1.

    Each query reveals the relative order of one uniformly random pair of
    distinct elements, returned as ``(smaller, larger)``.

    Args:
        n: number of elements.
        replace: if True, pairs are drawn with replacement (the same pair may be
            returned again); if False, each unordered pair is returned at most once.

    Attributes:
        count: number of queries answered so far.
        returned: pairs already returned, in order (only filled when ``replace`` is False).
    """

    def __init__(self, n, replace=True):
        self.n = n
        self.count = 0
        self.replace = replace
        self.returned = []
        # Set mirror of ``returned`` so membership tests are O(1) instead of O(k).
        self._seen = set()

    def query(self):
        """Return a random pair ``(i, j)`` with ``i < j`` and increment ``count``.

        Raises:
            ValueError: if ``n < 2`` (no pair of distinct elements exists).
            RuntimeError: if sampling without replacement and every pair has
                already been returned (previously this looped forever).
        """
        if self.n < 2:
            raise ValueError(f"need at least 2 elements to form a pair, got n={self.n}")
        if not self.replace and len(self._seen) >= self.n * (self.n - 1) // 2:
            raise RuntimeError("all pairs have already been returned")
        # Rejection sampling: draw uniform pairs until one is acceptable; without
        # replacement this yields a uniform pair among those not yet returned.
        while True:
            i, j = random.sample(range(self.n), 2)
            ordered = (i, j) if i < j else (j, i)
            if self.replace or ordered not in self._seen:
                break
        if not self.replace:
            self.returned.append(ordered)
            self._seen.add(ordered)
        self.count += 1
        return ordered
def test(n, replace=True):
    """Simulate sorting ``n`` elements via random pairwise queries; return the query count.

    The true order is the index order ``0 < 1 < ... < n-1``. Pairs are drawn from
    an ``Oracle`` until the known comparisons, closed under transitivity,
    determine the relative order of every pair of elements.

    Args:
        n: number of elements to sort.
        replace: forwarded to ``Oracle``; whether a pair may be queried more than once.

    Returns:
        The number of oracle queries made (0 when ``n < 2``: nothing to sort).
    """
    oracle = Oracle(n, replace)
    # smaller[a, b] is True iff "a precedes b" is known; kept reflexive and
    # transitively closed at all times.
    smaller = np.eye(n, dtype=bool)
    # Check before querying so that n < 2 (already sorted) needs no queries.
    while not (smaller | smaller.T).all():
        i, j = oracle.query()
        # Incremental transitive closure: because ``smaller`` is already closed
        # and reflexive, adding i -> j makes a precede b exactly when a precedes
        # (or is) i and j precedes (or is) b. O(n^2) instead of a full O(n^3) recompute.
        smaller |= np.logical_and.outer(smaller[:, i], smaller[j, :])
    return oracle.count
if __name__ == "__main__":
    # For each sampling mode and each n, print the mean number of queries needed
    # to sort n elements together with the standard error of that mean.
    trials = 10000
    for label, replace in (('w/ replace', True), ('wo replace', False)):
        for n in range(2, 20):
            counts = np.array([test(n, replace=replace) for _ in range(trials)])
            # Standard error of the mean uses the sample standard deviation (ddof=1).
            print(label, n, counts.mean(), counts.std(ddof=1) / np.sqrt(trials))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.