Skip to content

Instantly share code, notes, and snippets.

@ItsDrike
Last active January 7, 2022 09:48
Show Gist options
  • Save ItsDrike/4507190723a8887d2b0668c06a79cc94 to your computer and use it in GitHub Desktop.
Save ItsDrike/4507190723a8887d2b0668c06a79cc94 to your computer and use it in GitHub Desktop.
Automatic caching for selected methods of a Python class
from __future__ import annotations

import inspect
import logging
from functools import wraps
from typing import Callable, cast
_MISSING = object()  # Sentinel value to avoid conflicts with vars set to None

# Diagnostics go through logging (silent unless DEBUG is enabled) instead of print().
_log = logging.getLogger(__name__)

# Instance attributes holding the per-instance cache state. These are the names
# that ``self.__hash`` / ``self.__cache`` mangle to inside AutoCacheMeta.
_HASH_ATTR = "_AutoCacheMeta__hash"
_CACHE_ATTR = "_AutoCacheMeta__cache"

# Flag set on every wrapper produced by AutoCacheMeta.cache, so members inherited
# from an already-processed base class are not wrapped a second time.
_WRAPPED_MARKER = "_autocache_wrapped"


class _CachedGetter:
    """Non-data descriptor that memoizes another descriptor's instance-level ``__get__``.

    Assigning ``__get__`` on a descriptor *instance* has no effect, because Python
    looks descriptor methods up on the type. This proxy is therefore placed in the
    owner class in place of the original descriptor.
    """

    def __init__(self, descriptor: object, cached_get: Callable) -> None:
        self.descriptor = descriptor  # the original, unwrapped descriptor
        self._cached_get = cached_get  # AutoCacheMeta.cache-wrapped getter(instance, owner)

    def __get__(self, instance: object, owner: type | None = None) -> object:
        if instance is None:
            # Class-level access is not cached: forward it unchanged.
            return type(self.descriptor).__get__(self.descriptor, None, owner)
        return self._cached_get(instance, owner)


class AutoCacheMeta(type):
    """Metaclass that memoizes the members named in a class's ``_cached`` tuple.

    Each name in ``_cached`` must refer to a regular method, a property, or a
    descriptor that defines ``__get__`` but neither ``__set__`` nor ``__delete__``.
    Results are stored on each instance and keyed by the call arguments. The whole
    per-instance cache is discarded whenever ``hash(instance)`` differs from the
    value recorded when it was built.

    Pass ``allow_missing_cache=True`` in the class statement to let that class omit
    ``_cached``; it is then treated as an empty tuple.
    """

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        clsdict: dict[str, object],
        **kwargs: object,
    ) -> AutoCacheMeta:
        # Consume our own keyword before delegating to type.__new__, which would
        # otherwise forward it to __init_subclass__.
        allow_missing_cache = kwargs.pop("allow_missing_cache", False)
        clsobj = super().__new__(cls, name, bases, clsdict, **kwargs)
        cache = cls.get_cache(clsobj, allow_missing_cache)
        cls.cache_methods(clsobj, cache)
        return clsobj

    @staticmethod
    def get_cache(clsobj: type, allow_missing: bool) -> tuple[str, ...]:
        """Return the validated ``_cached`` class variable of *clsobj*.

        :param allow_missing: if true, a missing ``_cached`` is treated as ``()``.
        :raises ValueError: ``_cached`` is missing and *allow_missing* is false.
        :raises TypeError: ``_cached`` is not a tuple, or contains non-strings.
        """
        cache = getattr(clsobj, "_cached", _MISSING)
        if cache is _MISSING:
            if not allow_missing:
                raise ValueError("AutoCacheMeta requires _cached class variable.")
            cache = ()
        if not isinstance(cache, tuple):
            raise TypeError("_cached class variable must be a tuple of string method names.")
        if not all(isinstance(el, str) for el in cache):
            raise TypeError("_cached class variable can only contain strings (method names)")
        # String form: a subscripted builtin in cast() would be evaluated at runtime.
        return cast("tuple[str, ...]", cache)

    @staticmethod
    def is_descriptor(obj: object) -> bool:
        """Return True if *obj*'s type defines ``__get__``, ``__set__`` or ``__delete__``."""
        # The protocol methods are looked up on the type, as Python itself does.
        obj_type = type(obj)
        return any(hasattr(obj_type, dunder) for dunder in ("__get__", "__set__", "__delete__"))

    @staticmethod
    def _is_already_cached(attribute: object) -> bool:
        """Return True if *attribute* was already produced by this metaclass's caching."""
        if isinstance(attribute, _CachedGetter):
            return True
        if isinstance(attribute, property):
            attribute = attribute.fget
        return getattr(attribute, _WRAPPED_MARKER, False) is True

    @classmethod
    def _cache_descriptor(cls, clsobj: type, name: str) -> None:
        """Replace descriptor *name* of *clsobj* with one whose getter is memoized.

        Properties are re-created with :meth:`property.getter`. Any other get-only
        descriptor is wrapped in a ``_CachedGetter`` proxy.

        :raises ValueError: the descriptor has no getter.
        :raises NotImplementedError: the descriptor has a setter or deleter; writes
            would not invalidate the cache.
        """
        descriptor = inspect.getattr_static(clsobj, name)
        if isinstance(descriptor, property):
            getter_f = descriptor.fget
            setter_f = descriptor.fset
            deleter_f = descriptor.fdel
        else:
            descriptor_type = type(descriptor)
            getter_f = getattr(descriptor_type, "__get__", None)
            setter_f = getattr(descriptor_type, "__set__", None)
            deleter_f = getattr(descriptor_type, "__delete__", None)

        if getter_f is None:
            raise ValueError(f"Tried to cache getter function of '{name}' descriptor, but it wasn't defined.")
        if setter_f is not None:
            # TODO: We need to make a function that overrides setter which
            # when called, resets the cache for the descriptor's getter
            raise NotImplementedError("Setter functionality isn't yet supported with caching")
        if deleter_f is not None:
            # TODO: We need to make a function that overrides deleter which
            # when called, removes the cache for the descriptor's getter
            raise NotImplementedError("Deleter functionality isn't yet supported with caching")

        if isinstance(descriptor, property):
            # property objects are immutable; getter() returns a copy (same type,
            # setter, deleter and doc) with the cached fget swapped in.
            setattr(clsobj, name, descriptor.getter(cls.cache(getter_f)))
            return

        def get(instance, owner=None):
            return getter_f(descriptor, instance, owner)

        setattr(clsobj, name, _CachedGetter(descriptor, cls.cache(get)))

    @classmethod
    def cache_methods(cls, clsobj: type, cache: tuple[str, ...]) -> None:
        """Wrap every member of *clsobj* named in *cache* with the memoization decorator.

        Members inherited already-cached from a processed base class are left as is.

        :raises AttributeError: a name in *cache* is not defined on *clsobj*.
        :raises NotImplementedError: a name refers to a staticmethod/classmethod,
            or to a descriptor with a setter or deleter.
        :raises ValueError: a name refers to a descriptor without a getter.
        :raises TypeError: a name refers to something neither callable nor a descriptor.
        """
        for name in cache:
            # getattr_static returns the raw class-dict object. Plain getattr would
            # run the descriptor protocol and unwrap staticmethod/classmethod,
            # hiding them from the check below.
            attribute = inspect.getattr_static(clsobj, name, _MISSING)
            if attribute is _MISSING:
                raise AttributeError(f"Tried to cache non-existent attribute: '{name}'.")
            # Caching methods without self isn't possible without class-bound cache,
            # we're only using instance-bound cache here though.
            if isinstance(attribute, (staticmethod, classmethod)):
                raise NotImplementedError("Can't cache static/class methods, they can't access the instance-bound cache.")
            if cls._is_already_cached(attribute):
                _log.debug("Already cached (inherited), skipping: %s", name)
                continue
            if callable(attribute):
                _log.debug("Found callable: %s", name)
                setattr(clsobj, name, cls.cache(attribute))
                continue
            if cls.is_descriptor(attribute):
                _log.debug("Found descriptor: %s", name)
                cls._cache_descriptor(clsobj, name)
                continue
            raise TypeError(f"Tried to cache non-callable attribute (can only cache methods/descriptors): '{name}'.")

    @staticmethod
    def cache(func: Callable) -> Callable:
        """Decorate method *func* with per-instance memoization.

        Results are stored on the instance (attribute ``_AutoCacheMeta__cache``).
        They are keyed first by *func* and then by the positional and keyword
        arguments, which must therefore be hashable.

        ``hash(self)`` is recorded next to the cache (``_AutoCacheMeta__hash``). If
        it differs on a later call, the instance's whole cache is cleared first.
        Instances must therefore be hashable and accept new attributes.
        """
        kwd_mark = object()  # Sentinel for separating args from kwargs in the key

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            current_hash = hash(self)
            if getattr(self, _HASH_ATTR, _MISSING) != current_hash:
                # First cached call on this instance, or its hash changed since
                # the cache was built: start over with an empty cache.
                _log.debug("Resetting cache of %s instance", type(self).__name__)
                setattr(self, _HASH_ATTR, current_hash)
                setattr(self, _CACHE_ATTR, {})
            # Keep a direct reference: storing through it stays safe even if a
            # nested cached call replaces the instance's cache dict meanwhile.
            func_cache = getattr(self, _CACHE_ATTR).setdefault(func, {})
            key = args + (kwd_mark,) + tuple(sorted(kwargs.items()))
            if key not in func_cache:
                _log.debug("Cache miss for %r, calling it", func)
                func_cache[key] = func(self, *args, **kwargs)
            return func_cache[key]

        setattr(wrapper, _WRAPPED_MARKER, True)
        return wrapper
class AutoCacheMixin(metaclass=AutoCacheMeta, allow_missing_cache=True):
    """Base class giving subclasses automatic per-instance caching via AutoCacheMeta.

    Subclasses set ``_cached`` to a tuple of names of the methods/properties to
    memoize; the metaclass wraps them when the subclass is created. Only this mixin
    itself may omit ``_cached`` (it is declared with ``allow_missing_cache=True``).
    A subclass that neither defines nor inherits ``_cached`` raises ``ValueError``
    at class creation.
    """

    _cached: tuple[str, ...]  # names of the members to cache (any number of str)
@ItsDrike
Copy link
Author

ItsDrike commented Jan 7, 2022

Note: This code snippet is licensed under the MIT license (i.e., you can use it basically anywhere, sublicense it, etc., so long as you mention the original source). See https://spdx.org/licenses/MIT.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment