Last active
September 2, 2015 19:11
-
-
Save mbarkhau/c4671ba5c0087a7631d7 to your computer and use it in GitHub Desktop.
Failed attempt at a performant SortedCounter based on sortedcontainers.SortedDict and defaultdict
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Failed attempt at a performant SortedCounter based on | |
sortedcontainers.SortedDict and defaultdict | |
At least it is not useful for my use case which is insertion | |
heavy, YMMV. | |
Since the collections.Counter class doesn't keep the sort order | |
of its items, it has to recalculate the order for every call | |
to most_common is done. Also it doesn't provide any function | |
for the least common items. This class is always sorted and | |
simple iteration/reverse iteration gives the least/most common | |
keys without any recalculation. | |
""" | |
from __future__ import print_function | |
import collections | |
from sortedcontainers import SortedDict, SortedListWithKey | |
from sortedcontainers.sorteddict import _IlocWrapper | |
class SortedCounter(SortedDict, collections.defaultdict):
    """Counter-like mapping whose keys are kept ordered by their count.

    Same as SortedDict but with defaultdict as an additional base class.
    The defaultdict part is always initialised with ``int`` as its
    factory, and the sort key is hardcoded to :meth:`_key`, which returns
    the value currently stored for a key.
    """

    def __init__(self, *args, **kwargs):
        """Create the counter and populate it via ``update``.

        If the first positional argument is an integer it is consumed as
        the ``load`` passed to the underlying ``SortedListWithKey``;
        otherwise a load of 1000 is used. All remaining positional and
        keyword arguments are forwarded to ``self.update``.
        """
        # isinstance (rather than an exact type comparison) also accepts
        # int subclasses as the load argument.
        if args and isinstance(args[0], int):
            self._load = args[0]
            args = args[1:]
        else:
            self._load = 1000

        # The sorted key list must exist before any item is stored,
        # because __setitem__ adds to / removes from it.
        self._list = SortedListWithKey(key=self._key, load=self._load)
        collections.defaultdict.__init__(self, int)

        # Cache function pointers to defaultdict methods.
        # super(SortedDict, self) resolves to whatever follows SortedDict
        # in the MRO -- presumably defaultdict; verify against the
        # installed sortedcontainers version.
        _dict = super(SortedDict, self)
        self._dict = _dict
        self._clear = _dict.clear
        self._delitem = _dict.__delitem__
        self._iter = _dict.__iter__
        self._pop = _dict.pop
        self._setdefault = _dict.setdefault
        self._setitem = _dict.__setitem__
        self._update = _dict.update

        # Cache function pointers to SortedList methods.
        _list = self._list
        self._list_add = _list.add
        self.bisect_left = _list.bisect_left
        self.bisect = _list.bisect_right
        self.bisect_right = _list.bisect_right
        self._list_clear = _list.clear
        self.index = _list.index
        self._list_pop = _list.pop
        self._list_remove = _list.remove
        self._list_update = _list.update
        self._list_index = _list.index
        self._list__pos = _list._pos
        self.irange = _list.irange
        self.islice = _list.islice

        # _key is a method of this class and therefore never None, so the
        # key-based helpers are always exposed.
        self.bisect_key_left = _list.bisect_key_left
        self.bisect_key_right = _list.bisect_key_right
        self.bisect_key = _list.bisect_key
        self.irange_key = _list.irange_key

        self.iloc = _IlocWrapper(self)

        self.update(*args, **kwargs)

    def __setitem__(self, key, value):
        """Store ``value`` for ``key`` and reposition ``key`` in the list.

        An existing key is removed from the sorted list *before* its value
        changes (so the removal still sees the old sort key) and is added
        again afterwards so that it lands at its new position.
        """
        if key in self:
            self._list_remove(key)
        self._setitem(key, value)
        self._list_add(key)

    def __setitem_fail__(self, key, value):
        """Abandoned variant of ``__setitem__``; kept for reference only.

        Nothing in this class calls this method.
        """
        # I tried only removing if the value of the left item is
        # greater than the new value for the key, or when the value
        # of the right item is less than the new value for the key.
        # As it turns out though, the calls to self._list_index are
        # even more expensive than simply removing the key and
        # reinserting it so that it will be automatically repositioned.
        # Also there seems to be a bug in here somewhere, but I don't
        # think it affects performance.
        if key not in self:
            self._setitem(key, value)
            self._list_add(key)
            return

        key_idx = self._list_index(key)
        resort_required = False

        # NOTE(review): self._list holds keys, so left_val / right_val
        # below are neighbouring *keys*, yet they are compared directly
        # with the new *value* -- possibly the bug mentioned above.
        if key_idx > 0:
            left_val = self._list[key_idx - 1]
            if left_val > value:
                resort_required = True

        if key_idx < len(self) - 1:
            right_val = self._list[key_idx + 1]
            if right_val < value:
                resort_required = True

        # NOTE(review): when no re-sort is flagged the new value is never
        # stored -- confirm the intended nesting of this branch.
        if resort_required:
            if key in self:
                self._list_remove(key)
            self._setitem(key, value)
            self._list_add(key)

    def _key(self, key):
        """Return the sort key for ``key``: the value stored under it."""
        return self[key]
import random | |
# Fixture for main(): one (key, random float) pair per letter "a".."k".
_test_items = [(char, random.random()) for char in "abcdefghijk"]
def main():
    """Self-check: SortedCounter values stay sorted after every mutation.

    Builds a counter from ``_test_items``, updates it with the same items
    again, then touches two new keys through ``+=`` and one through plain
    assignment, asserting after each step that the values come out in
    ascending order.
    """
    def check_sorted(counter):
        # Iteration order of the values must already be ascending.
        assert sorted(counter.values()) == list(counter.values())

    expected_size = len(_test_items)

    sc = SortedCounter(_test_items)
    check_sorted(sc)
    assert len(sc) == expected_size

    # Updating with the same keys must not add or lose entries.
    sc.update(_test_items)
    check_sorted(sc)
    assert len(sc) == expected_size

    sc['foo'] += -100
    check_sorted(sc)

    sc['bar'] += 100
    check_sorted(sc)

    sc['baz'] = 1000
    check_sorted(sc)

    # for k, v in sc.items():
    #     print(k, v)

    # for k, v in sc.irange_items(reverse=True):
    #     print(k, v)
import time | |
# Benchmark workload shared between bench() and _bench(): bench() rebinds
# this to a list of (key, increment) pairs; it stays None until then.
_rand_increments = None
def _bench(cls):
    """Time one pass of the shared increment workload.

    ``cls`` is called without arguments to create the container; every
    ``(key, increment)`` pair in the module-level ``_rand_increments`` is
    then applied with ``+=``. Returns the elapsed wall-clock time in
    seconds as measured by ``time.time()``.
    """
    started = time.time()
    counter = cls()
    for key, delta in _rand_increments:
        counter[key] += delta
    elapsed = time.time() - started
    return elapsed
def bench():
    """Profile and time SortedCounter against a plain ``defaultdict(int)``.

    Generates 10000 random ``(letter, increment)`` pairs into the
    module-level ``_rand_increments``, then runs the workload five times.
    Each run is profiled with cProfile (writing
    ``profile_stats_<i>.stats`` into the current directory) and timed once
    for SortedCounter and once for ``defaultdict(int)``. Finally the
    combined profile is printed sorted by ``tottime``, followed by the
    best timing of each container.
    """
    import cProfile
    import pstats
    import random
    import string

    global _rand_increments
    # string.ascii_letters is available on Python 2 and 3 alike, whereas
    # string.letters only exists on Python 2.
    _rand_increments = [
        (random.choice(string.ascii_letters), random.random())
        for _ in range(10000)
    ]

    runs = 5
    stat_files = ['profile_stats_%d.stats' % i for i in range(runs)]

    sc_times = []
    dd_times = []   # bench insertion of defaultdict for comparison
    for filename in stat_files:
        # cProfile.run evaluates the statement in the __main__ namespace,
        # so this relies on the file being executed as a script.
        cProfile.run("_bench(SortedCounter)", filename)
        sc_times.append(_bench(SortedCounter))
        dd_times.append(_bench(lambda: collections.defaultdict(int)))

    # Read all stats files into a single object
    stats = pstats.Stats(*stat_files)
    # stats.strip_dirs()
    stats.sort_stats('tottime')
    stats.print_stats()

    print("sc best of five runs:", min(sc_times))
    print("dd best of five runs:", min(dd_times))
# Script entry point: run the sortedness self-checks, then the benchmark.
if __name__ == '__main__':
    main()
    bench()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment