Skip to content

Instantly share code, notes, and snippets.

@internetimagery
Last active May 30, 2021 10:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save internetimagery/41e0a824620c5959df6a0ac5fb5f826d to your computer and use it in GitHub Desktop.
Global Cache. Manage cache through proxies. Invalidate data via tags, and automatically invalidate dependencies.
import time
import heapq
import types
import itertools
import functools
import threading
import contextlib
import collections
import operator
from concurrent.futures import Future
def get_cache():
    """Return the shared, process-wide Cache singleton."""
    return Cache.instance()
def cache(ttl=0, lfu=0, lru=0, tags=None):
    """Decorator factory: route the wrapped function's results through the cache.

    ttl  -- seconds before a cached value expires (0 disables expiry)
    lfu  -- entry cap, trimming the least frequently used values first
    lru  -- entry cap, trimming the least recently used values first
    tags -- extra tags associated with every cached value for invalidation
    """
    def wrap(func):
        return CacheProxy(func, ttl, lfu, lru, tags)
    return wrap
class CacheProxy(object):
    """Callable wrapper that caches a function's return values in the global Cache.

    Produced by the @cache decorator. Holds the caching policy (ttl / lfu /
    lru / tags) and tags every stored value with the wrapped function itself,
    so invalidate() can clear all of that function's entries at once.
    """

    def __init__(self, func, ttl=0, lfu=0, lru=0, tags=None):
        self._func = func
        self._ttl = ttl  # seconds before an entry expires (0 = never)
        self._lfu = lfu  # entry cap, evicting least-frequently-used first
        self._lru = lru  # entry cap, evicting least-recently-used first
        self._tags = set(tags) if tags else set()
        # The function doubles as a tag so invalidate() can target it.
        self._tags.add(func)

    def invalidate(self):
        """Drop every cached value produced by this function (and dependants)."""
        cache = get_cache()
        cache.invalidate(tags=self._tags)

    def get_key(self, *args, **kwargs):
        """Return the cache key for this call signature (None if unhashable)."""
        cache = get_cache()
        return cache.generate_key(self._func, *args, **kwargs)

    def is_cached(self, *args, **kwargs):
        """Return True when a value for these arguments is currently cached."""
        cache = get_cache()
        key = cache.generate_key(self._func, *args, **kwargs)
        try:
            cache.get(key)
        except cache.Missing:
            return False
        return True

    def __call__(self, *args, **kwargs):
        """Return the cached value for these arguments, computing it on a miss."""
        cache = get_cache()
        key = cache.generate_key(self._func, *args, **kwargs)
        with cache.context(key) as parent_key:
            try:
                return cache.get(key, parent_key=parent_key)
            except cache.Missing:
                pass
            if self._lfu or self._lru:
                # Trim this function's entries to the cap before adding one.
                cache.invalidate(tags=self._tags, lfu=self._lfu, lru=self._lru)
            value = self._func(*args, **kwargs)
            cache.set(key, value, parent_key=parent_key, tags=self._tags, ttl=self._ttl)
            if isinstance(value, Future):
                # A failed/cancelled future must not stay cached, or the error
                # would replay on every later call.
                # BUGFIX: the original tested `future.exception() or
                # future.cancelled() and self._data_cache.get(key) is future`,
                # which has wrong precedence AND references a _data_cache
                # attribute CacheProxy never had (AttributeError). Moreover,
                # Future.exception() raises CancelledError on a cancelled
                # future, so cancellation has to be checked first.
                def callback(future):
                    if future.cancelled() or future.exception() is not None:
                        cache.invalidate(keys=[key])
                value.add_done_callback(callback)
            return value

    @property
    def __wrapped__(self):
        # Expose the underlying function (inspect.unwrap compatible).
        return self._func

    def __get__(self, inst, _=None):
        # Descriptor protocol: bind as a method when accessed on an instance.
        # BUGFIX: was `if inst`, which skipped binding for falsy instances
        # (objects whose __bool__/__len__ evaluates false).
        return types.MethodType(self, inst) if inst is not None else self

    def __repr__(self):
        return "<{0.__class__.__name__} wrapping {0._func!r} at {1:#x}>".format(self, id(self))
class Item(object):
    """A single cache entry: the payload plus access bookkeeping for eviction."""

    __slots__ = ("value", "_expires", "access_count", "access_time")

    def __init__(self, value, ttl=0):
        now = time.time()
        self.value = value
        self.access_time = now      # updated on every cache hit
        self.access_count = 0       # bumped on every cache hit
        # A ttl of 0 means the entry never expires.
        if ttl:
            self._expires = now + ttl
        else:
            self._expires = None

    def expired(self, current_time):
        """Return True when a deadline exists and has already passed."""
        if not self._expires:
            return False
        return current_time >= self._expires
class Key(object):
    """Immutable composite cache key with its hash computed exactly once."""

    __slots__ = ("_hash", "_value")

    def __init__(self, *value):
        self._value = value
        self._hash = hash(value)

    def __hash__(self):
        # Precomputed in __init__; keys are hashed constantly on lookups.
        return self._hash

    def __eq__(self, other):
        if isinstance(other, Key):
            return self._value == other._value
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, Key):
            return self._value != other._value
        return NotImplemented

    def __repr__(self):
        return "<{0.__class__.__name__} {1} at {2:#x}>".format(self, self._hash, id(self))
class Cache(object):
    """Tag-aware global cache with TTL expiry, LFU/LRU trimming and links.

    Values are stored against hashed keys, optionally expire after a TTL,
    and can be linked to arbitrary "tags" (including other keys), so that
    invalidating one entry cascades to everything depending on it. A daemon
    thread periodically sweeps out expired and orphaned entries.
    """

    _instance = None                    # lazily-built singleton (see instance())
    _lock = threading.RLock()           # guards the shared data structures
    _local = threading.local()          # per-thread "current key" for dependency tracking
    _cleanup_frequency = 60 * 3         # seconds between background sweeps

    class Missing(Exception):
        """Raised by get() when the key is absent, expired, or None."""
        pass

    @classmethod
    def instance(cls, *args, **kwargs):
        """ Keep a global cache instance if desired """
        # NOTE(review): this check-then-set is not performed under _lock, so
        # two threads racing here could each build an instance — confirm the
        # brief double-creation window is acceptable.
        inst = cls._instance
        if not inst:
            inst = cls._instance = cls(*args, **kwargs)
        return inst

    def generate_key(self, func, *args, **kwargs):
        """ Create something we can use as a key for storing information """
        try:
            return Key(
                self._hash(func),
                # False stands in for "no args" so the tuple stays hashable.
                tuple(self._hash(arg) for arg in args) if args else False,
                frozenset((k,self._hash(v)) for k, v in kwargs.items()) if kwargs else False,
            )
        except TypeError:
            # Some argument was unhashable; None marks the call uncacheable.
            return None

    def get(self, key, parent_key=None):
        """ Query cache for info. Raise Missing error if not present """
        if key is None:
            raise self.Missing("Null key")
        if parent_key:
            # The caller's entry depends on this one: record the edge so
            # invalidating `key` also invalidates `parent_key`.
            self.link(key, parent_key)
        try:
            item = self._data_cache[key]
        except KeyError:
            raise self.Missing("Key not in cache")
        else:
            # Touch access stats used by the LFU / LRU eviction heuristics.
            item.access_time = time.time()
            item.access_count += 1
            if item.expired(item.access_time):
                self.invalidate(keys=(key,))
                raise self.Missing("Key expired")
            return item.value

    def set(self, key, value, parent_key=None, tags=None, ttl=0):
        """ Insert / Update data in the cache for the given key. """
        if key is None:
            # Uncacheable call signature (see generate_key); silently skip.
            return
        with self._lock:
            self._data_cache[key] = Item(value, ttl=ttl)
            if parent_key:
                self.link(key, parent_key)
            if tags:
                for tag in tags:
                    self.link(tag, key)

    def link(self, tag, key):
        """ Link a piece of information with a cache data key """
        if key is None:
            return
        # Forward map (tag -> keys) plus a reverse map for cheap unlinking.
        self._links[tag].add(key)
        self._links_reverse_lookup[key].add(tag)

    def invalidate(self, keys=None, tags=None, lfu=0, lru=0):
        """
        Invalidate data that meets specific criteria.
        Usage:
            >>> cache.invalidate() # Invalidate EVERYTHING!
            >>> key = cache.generate_key(func, *args, **kwargs)
            >>> cache.invalidate(keys=[key]) # Invalidate a specific key (and dependants)
            >>> key = cache.generate_key(func, *args, **kwargs)
            >>> cache.link("sometag", key)
            >>> cache.invalidate(tags=["sometag"]) # Invalidate all sharing a specific tag
            >>> cache.invalidate(lfu=300) # Drop less frequently accessed data when total exceeds given count
        """
        with self._lock:
            queue = self._filter(lfu, lru, *self._select(keys, tags))
            if not queue:
                return
            # Efficiently clear everything if requested
            if len(queue) == len(self._data_cache):
                self._data_cache.clear()
                self._links.clear()
                self._links_reverse_lookup.clear()
                return
            while queue:
                key = queue.pop()
                if key is None:
                    continue
                try:
                    del self._data_cache[key]
                except KeyError:
                    pass
                # Track dependencies: everything linked under this key is
                # invalid too, so feed it back into the work queue.
                queue.update(self._links.pop(key, ()))
                for tag in self._links_reverse_lookup.pop(key, ()):
                    links = self._links[tag]
                    links.discard(key)
                    if not links:
                        del self._links[tag]

    @contextlib.contextmanager
    def context(self, key):
        """ Maintain context among caches, so cache dependencies can be taken into account """
        # Push `key` as this thread's current "parent" for nested cached
        # calls, restoring the previous one on exit (re-entrancy safe).
        old_key = getattr(self._local, "key", None)
        self._local.key = key
        try:
            yield old_key
        finally:
            self._local.key = old_key

    def cleanup(self):
        """ Run over all the data and remove stragglers """
        with self._lock:
            current_time = time.time()
            self._next_cleanup = current_time + self._cleanup_frequency
            # Entries past their TTL...
            expired_keys = (
                key
                for key, item in self._data_cache.items()
                if item.expired(current_time)
            )
            # ...and link-table keys whose data has already vanished.
            missing_keys = (
                key
                for key in self._links_reverse_lookup
                if key not in self._data_cache
            )
            self.invalidate(keys=itertools.chain(expired_keys, missing_keys))

    def __init__(self):
        self._data_cache = {}                                   # Key -> Item
        self._links = collections.defaultdict(set)              # tag/key -> dependant keys
        self._links_reverse_lookup = collections.defaultdict(set)  # key -> tags
        self._next_cleanup = time.time()
        # Background daemon thread running the periodic sweep.
        # NOTE(review): setDaemon() is a deprecated alias for `daemon=True`.
        thread = threading.Thread(target=self._cleanup_loop)
        thread.setDaemon(True)
        thread.start()

    def _hash(self, obj):
        # Recursively reduce mutable containers to hashable equivalents, and
        # pair every hash with its type so e.g. 1 / 1.0 / True don't collide.
        if isinstance(obj, (int, float, bool)):
            hashed = obj
        elif isinstance(obj, dict):
            hashed = frozenset((self._hash(k), self._hash(v)) for k, v in obj.items()) if obj else False
        elif isinstance(obj, list):
            hashed = tuple(self._hash(v) for v in obj) if obj else False
        elif isinstance(obj, set):
            hashed = frozenset(self._hash(v) for v in obj) if obj else False
        else:
            hashed = hash(obj)
        return hashed, type(obj)

    def _select(self, keys, tags):
        # Resolve an invalidation request into (keys, per-tag key sets).
        # No criteria at all means "select everything".
        if keys is None and tags is None:
            return self._data_cache, None
        selected_keys = selected_tags = None
        if tags:
            selected_tags = tuple(map(self._links.__getitem__, tags))
        if keys:
            selected_keys = keys
        return selected_keys, selected_tags

    def _filter(self, lfu, lru, selected_keys, selected_tags):
        # Reduce a selection to the set of keys to actually drop, applying
        # the LFU/LRU size limit when one was requested.
        if not selected_keys and not selected_tags:
            return None
        if not lfu and not lru:
            return self._merge_selection(selected_keys, selected_tags)
        limit = lfu or lru
        # Quick check first
        quick_total = 0
        if selected_tags:
            quick_total += max(map(len, selected_tags))
        if selected_keys:
            try:
                quick_total += len(selected_keys)
            except TypeError:
                # Unsized iterable (e.g. generator); force the full merge below.
                quick_total = -1
        if 0 <= quick_total and quick_total < limit:
            return None
        selection = self._merge_selection(selected_keys, selected_tags)
        if len(selection) < limit:
            return None
        chunk = int(limit * 0.25) # Remove chunk to make up for slow sort
        heuristic = operator.attrgetter(
            "access_count" if lfu else "access_time"
        )
        selection, view = itertools.tee(selection)
        heuristics = (
            # Missing items score -1 so stale links are dropped before live data.
            heuristic(item) if item else -1
            for item in map(self._data_cache.get, view)
        )
        return set(
            map(
                operator.itemgetter(1),
                heapq.nsmallest(
                    chunk,
                    zip(heuristics, selection),
                    key=operator.itemgetter(0),
                )
            )
        )

    def _merge_selection(self, selected_keys, selected_tags):
        # Keys sharing ALL requested tags, plus any explicitly named keys.
        selection = set.intersection(*selected_tags) if selected_tags else set()
        if selected_keys:
            selection.update(selected_keys)
        return selection

    def _cleanup_loop(self):
        # Daemon loop: sleep until the next scheduled sweep, then run it.
        while True:
            current_time = time.time()
            if current_time < self._next_cleanup:
                time.sleep(self._next_cleanup - current_time)
            else:
                self.cleanup()
if __name__ == "__main__":
    # Smoke tests / demonstration of the caching behaviour.

    @cache()
    def concat(prefix, suffix):
        time.sleep(0.5)  # simulate expensive work
        return prefix + suffix

    @cache()
    def double_concat(prefix, suffix):
        time.sleep(0.5)
        # Nested cached calls: concat's entries become dependencies of this one.
        return concat(concat(prefix, suffix), concat(prefix, suffix))

    print(">>", double_concat)
    cache_ = get_cache()

    # Basic caching round-trip and global invalidation.
    assert not concat.is_cached("abc", "def"), "We are not cached"
    concat("abc", "def")
    assert concat.is_cached("abc", "def"), "We are now cached"
    concat("abc", "def")
    cache_.invalidate()
    assert not concat.is_cached("abc", "def"), "We are again uncached"
    cache_.invalidate()

    # Dependency tracking: invalidating an inner result cascades outward.
    double_concat("abc", "def")
    assert concat.is_cached("abc", "def"), "Base concat is cached"
    assert double_concat.is_cached("abc", "def"), "Outer double concat is cached"
    cache_.invalidate(keys=[concat.get_key("abc", "def")])
    assert not double_concat.is_cached("abc", "def"), "Dependant data is invalidated"

    # Tag-based invalidation only touches entries carrying the tag.
    concat("abc", "def")
    concat("efg", "hij")
    concat("klm", "nop")
    assert concat.is_cached("abc", "def"), "Data is cached"
    assert concat.is_cached("efg", "hij"), "Data is cached"
    assert concat.is_cached("klm", "nop"), "Data is cached"
    cache_.link("tag", concat.get_key("abc", "def"))
    cache_.link("tag", concat.get_key("efg", "hij"))
    cache_.invalidate(tags=["tag"])
    assert not concat.is_cached("abc", "def"), "Data is not cached"
    assert not concat.is_cached("efg", "hij"), "Data is not cached"
    assert concat.is_cached("klm", "nop"), "Data is still cached"

    # Function-level invalidation leaves other functions' data alone.
    double_concat("abc", "def")
    assert double_concat.is_cached("abc", "def"), "Data is cached"
    assert concat.is_cached("abc", "def"), "Data is cached"
    double_concat.invalidate()
    assert not double_concat.is_cached("abc", "def"), "Data cleared at function level"
    assert concat.is_cached("abc", "def"), "Data still cached"

    # Time-to-live expiry.
    @cache(ttl=1)
    def ttl():
        return "value"

    ttl()
    assert ttl.is_cached(), "Data is cached"
    time.sleep(1.5)
    assert not ttl.is_cached(), "Data no longer cached"

    # Least-frequently-used cap keeps entry count bounded.
    @cache(lfu=5)
    def lfu(num):
        return num

    for i in range(10):
        lfu(i)
        assert lfu.is_cached(i), i
    assert len([i for i in range(10) if lfu.is_cached(i)]) <= 5, "Maintain no more than given number"

    # Least-recently-used cap keeps entry count bounded.
    @cache(lru=5)
    def lru(num):
        return num

    for i in range(10):
        lru(i)
        assert lru.is_cached(i), i
    assert len([i for i in range(10) if lru.is_cached(i)]) <= 5, "Maintain no more than given number"

    # Unhashable (nested) arguments are reduced to hashable keys.
    @cache()
    def nested_hashable(value=None):
        return value

    nested_hashable(value={1:2})
    assert nested_hashable.is_cached(value={1:2}), "Value cached"
    cache_.cleanup()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment