Created
July 19, 2012 22:49
-
-
Save valtron/3147414 to your computer and use it in GitHub Desktop.
Caches simple queries on one or more unique fields
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.db import models | |
class CachedManager(models.Manager):
    """
    Manager whose querysets cache single-field ``get()`` lookups.

    Caches accesses of the form

        Model.objects.get(field=value)

    for ``field`` in ``'id'``, ``'pk'`` and every name passed in
    ``cached_fields``. See ``CachedQuerySet`` for exactly when the cache is
    consulted and when it is invalidated.

    Each model this manager is attached to gets its own ``Cache`` instance,
    so models that inherit the manager never share cached objects.
    """

    # Ask Django to use this manager (and therefore its cache) for
    # related-object access as well, e.g. when following a ForeignKey.
    use_for_related_fields = True

    def __init__(self, cached_fields=(), *args, **kwargs):
        """
        :param cached_fields: iterable of extra field names (in addition to
            ``id`` and ``pk``) whose ``get()`` lookups should be cached.
        Remaining arguments are passed through to ``models.Manager``.
        """
        super(CachedManager, self).__init__(*args, **kwargs)
        # Build a fresh list: an immutable default avoids the shared
        # mutable-default pitfall and the caller's sequence is never aliased.
        self._cached_fields = ['id', 'pk'] + list(cached_fields)
        self._cache = Cache(self._cached_fields)

    def contribute_to_class(self, model, name, *args, **kwargs):
        # NOTE(review): Django 1.x attaches inherited managers (abstract
        # bases, proxies) as shallow copies of the parent manager, which
        # would otherwise share one Cache across models and return objects
        # of the wrong model for the same pk. Give every model its own
        # cache; at attach time the cache is still empty, so nothing is lost.
        self._cache = Cache(self._cached_fields)
        return super(CachedManager, self).contribute_to_class(
            model, name, *args, **kwargs)

    def get_queryset(self):
        """Return a ``CachedQuerySet`` bound to this manager's cache."""
        return CachedQuerySet(self._cache, self.model, using=self._db)

    # Pre-Django-1.6 spelling; a plain alias (not a delegating method) so
    # subclasses overriding either name cannot recurse into each other.
    get_query_set = get_queryset
class CachedQuerySet(models.query.QuerySet):
    """
    QuerySet that answers single-keyword ``get()`` calls from a shared Cache.

    The cache is consulted only for ``get(field=value)`` with exactly one
    keyword argument and no positional arguments, on a queryset that has
    no WHERE conditions and no ``select_for_update()``. Any filter/exclude
    (including the implicit filter a related manager adds) could exclude
    the cached object, so such lookups always go to the database.

    The entire cache is cleared after every ``delete()``, ``update()`` and
    ``_update()`` issued through this queryset, whether or not any row was
    affected.

    NOTE(review): deletes/updates that bypass these QuerySet methods (e.g.
    cascades performed by Django's deletion collector) presumably do not
    invalidate the cache; the cache is also shared across ``using()``
    database aliases. Confirm both are acceptable for your setup.
    """

    def __init__(self, cache=None, *args, **kwargs):
        """
        :param cache: the ``Cache`` to consult, or ``None`` to disable caching.
        Remaining arguments are passed through to ``QuerySet``.
        """
        super(CachedQuerySet, self).__init__(*args, **kwargs)
        self._cache = cache

    def get(self, *args, **kwargs):
        """Return the single matching object, from the cache when allowed."""
        if not self._can_use_cache(args, kwargs):
            return super(CachedQuerySet, self).get(*args, **kwargs)
        field, value = next(iter(kwargs.items()))

        def uncached_get():
            return super(CachedQuerySet, self).get(**kwargs)

        return self._cache.get(field, value, uncached_get)

    def delete(self):
        """Delete the matched rows, then clear the whole cache."""
        try:
            # Return the result so newer Django's (count, per-model) tuple
            # is not swallowed; older versions return None either way.
            return super(CachedQuerySet, self).delete()
        finally:
            # Invalidate even if the delete failed part-way through.
            self._invalidate_cache()

    def update(self, **kwargs):
        """Update the matched rows, then clear the whole cache."""
        try:
            return super(CachedQuerySet, self).update(**kwargs)
        finally:
            self._invalidate_cache()

    def _update(self, *args, **kwargs):
        # Model.save() on an existing row issues its UPDATE through
        # QuerySet._update() rather than update(); without this override
        # saved changes would leave stale objects in the cache.
        try:
            return super(CachedQuerySet, self)._update(*args, **kwargs)
        finally:
            self._invalidate_cache()

    def _clone(self, *args, **kwargs):
        clone = super(CachedQuerySet, self)._clone(*args, **kwargs)
        # Share (not copy) the cache so every derived queryset sees the
        # same entries and the same invalidations.
        clone._cache = self._cache
        return clone

    def _can_use_cache(self, args, kwargs):
        """Return True if this ``get()`` call may be answered from the cache."""
        if self._cache is None or args or len(kwargs) != 1:
            return False
        query = self.query
        # Any WHERE condition could exclude the cached object.
        if getattr(query.where, 'children', None):
            return False
        # A cached instance cannot honour a requested row lock.
        if getattr(query, 'select_for_update', False):
            return False
        return True

    def _invalidate_cache(self):
        if self._cache is not None:
            self._cache.clear()
class Cache(object):
    """
    Per-field lookup tables mapping a field value to a fetched object.

    One dict is kept for each name in ``cached_fields``. Every object
    fetched on a miss is indexed under the value of each of those fields,
    so a later lookup on any cached field is served without refetching.
    """

    def __init__(self, cached_fields):
        self._cached_fields = cached_fields
        self._cache = self._create_cache()

    def get(self, field, value, uncached_get):
        """
        Return the object whose ``field`` equals ``value``.

        A trailing ``__exact`` on ``field`` is ignored. Fields that are not
        cached always call ``uncached_get()``. Otherwise a miss calls
        ``uncached_get()`` and stores its result; any exception it raises
        propagates and nothing is cached.
        """
        suffix = '__exact'
        if field.endswith(suffix):
            field = field[:-len(suffix)]
        try:
            by_value = self._cache[field]
        except KeyError:
            return uncached_get()
        if value in by_value:
            return by_value[value]
        found = uncached_get()
        self._add_to_cache(found)
        # Also remember the lookup value exactly as given: it may differ in
        # type from the object's attribute (e.g. '1' vs. 1).
        by_value[value] = found
        return found

    def clear(self):
        """Forget every cached object by swapping in empty tables."""
        self._cache = self._create_cache()

    def _add_to_cache(self, obj):
        # Index obj under its current value for every cached field.
        for name in self._cached_fields:
            self._cache[name][getattr(obj, name)] = obj

    def _create_cache(self):
        # One empty value->object dict per cached field name.
        tables = {}
        for name in self._cached_fields:
            tables[name] = {}
        return tables
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment