Skip to content

Instantly share code, notes, and snippets.

@valtron
Created July 19, 2012 22:49
Show Gist options
  • Save valtron/3147414 to your computer and use it in GitHub Desktop.
Save valtron/3147414 to your computer and use it in GitHub Desktop.
Caches simple queries on one or more unique fields
from django.db import models
class CachedManager(models.Manager):
    """
    Caches accesses in the form
        Model.objects.get(field=value)
    for each field in cached_fields (``id`` and ``pk`` are always included).

    The cache is cleared whenever delete() or update() is called on one of
    this manager's querysets, regardless of whether anything was actually
    deleted or updated.
    """

    # Ask Django to use this manager class for related-object access as well
    # (presumably so foreign-key traversal can hit the cache too).
    use_for_related_fields = True

    def __init__(self, cached_fields=(), *args, **kwargs):
        """
        :param cached_fields: iterable of additional (unique) field names
            whose single-field ``get()`` lookups should be cached. An
            immutable default is used so no list is shared between managers.
        """
        super(CachedManager, self).__init__(*args, **kwargs)
        cached_fields = ['id', 'pk'] + list(cached_fields)
        self._cache = Cache(cached_fields)

    def get_query_set(self):
        """Return a CachedQuerySet bound to this manager's shared cache."""
        return CachedQuerySet(self._cache, self.model, using=self._db)

    def get_queryset(self):
        """
        Django >= 1.6 name for ``get_query_set`` (the old name was removed
        in Django 1.8). Delegates so subclasses overriding either name work.
        """
        return self.get_query_set()
class CachedQuerySet(models.query.QuerySet):
    """
    QuerySet that serves single-keyword ``get()`` lookups from a shared cache
    object and clears that cache whenever rows are updated or deleted
    through it.
    """

    def __init__(self, cache=None, *args, **kwargs):
        """
        :param cache: shared cache object, or None to disable caching.
            Remaining arguments are passed through to ``QuerySet.__init__``.
        """
        super(CachedQuerySet, self).__init__(*args, **kwargs)
        self._cache = cache

    def get(self, *args, **kwargs):
        """
        Like ``QuerySet.get``, but a lookup on exactly one keyword argument is
        answered from the cache when this queryset has no filters of its own.
        """
        if (self._cache is None or args or len(kwargs) != 1
                or self.query.where):
            # A filtered queryset (qs.filter(...).get(pk=...), or a
            # related-object manager's queryset) must apply its filters; a
            # cache hit would ignore them and could return a row this
            # queryset does not contain.
            return super(CachedQuerySet, self).get(*args, **kwargs)
        field, value = next(iter(kwargs.items()))
        uncached_get = lambda: super(CachedQuerySet, self).get(**kwargs)
        return self._cache.get(field, value, uncached_get)

    def delete(self):
        """Delete the matched rows, then clear the cache (even on error)."""
        try:
            return super(CachedQuerySet, self).delete()
        finally:
            self._invalidate_cache()
    # Re-apply Django's markers, which are lost when the method is
    # overridden: keep templates from calling it, and keep it off managers
    # built via as_manager()/from_queryset().
    delete.alters_data = True
    delete.queryset_only = True

    def update(self, **kwargs):
        """Update the matched rows, then clear the cache (even on error)."""
        try:
            return super(CachedQuerySet, self).update(**kwargs)
        finally:
            self._invalidate_cache()
    update.alters_data = True

    def _update(self, *args, **kwargs):
        """
        Clear the cache on Django's internal update path too.

        NOTE(review): Django's Model.save() updates existing rows through
        ``QuerySet._update`` on the model's base manager; without this a
        saved instance could leave a stale copy cached. Confirm against the
        Django version in use.
        """
        try:
            return super(CachedQuerySet, self)._update(*args, **kwargs)
        finally:
            self._invalidate_cache()
    _update.alters_data = True

    def _clone(self, *args, **kwargs):
        """Clone as usual, sharing the same cache object with the clone."""
        clone = super(CachedQuerySet, self)._clone(*args, **kwargs)
        clone._cache = self._cache
        return clone

    def _invalidate_cache(self):
        """Drop every cached entry, if this queryset has a cache."""
        if self._cache is not None:
            self._cache.clear()
class Cache(object):
    """
    In-memory map from (field, value) to an object.

    One dictionary is kept per cached field. On a miss the object is fetched
    with a caller-supplied callable and indexed under every cached field it
    has, so later lookups by any of those fields are hits.
    """

    _EXACT_SUFFIX = '__exact'

    def __init__(self, cached_fields):
        """
        :param cached_fields: iterable of attribute names to index objects by.
        """
        self._cached_fields = list(cached_fields)
        self._cache = self._create_cache()

    def get(self, field, value, uncached_get):
        """
        Return the object whose ``field`` equals ``value``.

        :param field: lookup name; a trailing ``__exact`` is ignored.
        :param value: lookup value.
        :param uncached_get: zero-argument callable that fetches the object on
            a miss. Any exception it raises propagates, and nothing is cached.
        """
        if field.endswith(self._EXACT_SUFFIX):
            field = field[:-len(self._EXACT_SUFFIX)]
        field_cache = self._cache.get(field)
        if field_cache is None:
            # Not a cached field: always go to the backing store.
            return uncached_get()
        try:
            return field_cache[value]
        except KeyError:
            pass
        except TypeError:
            # Unhashable lookup value (e.g. a list) cannot be a cache key.
            return uncached_get()
        obj = uncached_get()
        self._add_to_cache(obj)
        # hack to get around '1' vs. 1 problem: the lookup value's type may
        # differ from the attribute's, so also index under the value given.
        field_cache[value] = obj
        return obj

    def clear(self):
        """Forget every cached object."""
        self._cache = self._create_cache()

    def _add_to_cache(self, obj):
        """Index ``obj`` under each cached field it actually has."""
        for field in self._cached_fields:
            try:
                key = getattr(obj, field)
            except AttributeError:
                # e.g. a Django model whose primary key is not named ``id``.
                continue
            self._cache[field][key] = obj

    def _create_cache(self):
        """Build an empty per-field lookup table."""
        return {field: {} for field in self._cached_fields}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment