Skip to content

Instantly share code, notes, and snippets.

@michaelbartnett
Created January 17, 2013 22:18
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save michaelbartnett/4560351 to your computer and use it in GitHub Desktop.
Save michaelbartnett/4560351 to your computer and use it in GitHub Desktop.
mgeutils makes MongoEngine more fun to use ;P
"""mgeutils module
Decorators and convenience functions for using mongoengine
"""
from __future__ import unicode_literals
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import logging
import bson
import mongoengine as mge
import datetime
import dateutil.parser as dateparser
import utils
# bson types that are JSON-serialized by converting them to text.
_MGE_SERIALIZE_TO_UNICODE = tuple([bson.ObjectId])
# MongoEngine document base classes whose instances may expose to_dict().
_MGE_SERIALIZABLE_DOCUMENT = tuple([mge.Document, mge.EmbeddedDocument])
def _mge_convert_to_json_builtin(obj):
    """``default`` hook for JSON encoding of MongoEngine-related objects.

    - Documents / embedded documents with a ``to_dict`` method are encoded
      via ``to_dict()``; ones without it are logged as an error.
    - ``bson.ObjectId`` values become their text form.
    - ``mongoengine.ValidationError`` becomes a dict of its truthy
      ``errors`` / ``field_name`` / ``message`` attributes.
    - Anything else is handed to ``utils._convert_to_json_builtin``.
    """
    if isinstance(obj, _MGE_SERIALIZABLE_DOCUMENT):
        if hasattr(obj, 'to_dict'):
            return obj.to_dict()
        logging.error("Got a MongoEngine Document, but it wasn't a superdoc")
        return utils._convert_to_json_builtin(obj)

    if isinstance(obj, _MGE_SERIALIZE_TO_UNICODE):
        return unicode(obj)

    if isinstance(obj, mge.ValidationError):
        summary = {}
        for attr_name in ('errors', 'field_name', 'message'):
            attr_value = getattr(obj, attr_name)
            # Skip string values that merely look like a ValidationError repr.
            looks_like_repr = (isinstance(attr_value, basestring) and
                               attr_value.startswith('ValidationError('))
            if attr_value and not looks_like_repr:
                summary[attr_name] = attr_value
        return summary

    return utils._convert_to_json_builtin(obj)
def json_encode(value, *args, **kwargs):
    """JSON-encode ``value``, with support for MongoEngine objects.

    Extra positional and keyword arguments are forwarded to
    ``utils.json_encode`` (and from there to ``json.dumps``). A truthy
    caller-supplied ``default`` hook replaces ``_mge_convert_to_json_builtin``.
    """
    # Previously *args/**kwargs were silently dropped; forward them.
    kwargs['default'] = (kwargs.pop('default', None) or
                         _mge_convert_to_json_builtin)
    return utils.json_encode(value, *args, **kwargs)
class LegitDateTimeField(mge.DateTimeField):
    """DateTimeField that accepts dates, datetimes, or dateutil-parsable values."""

    def validate(self, value):
        """Report a validation error unless ``value`` is usable as a date.

        ``datetime.datetime`` / ``datetime.date`` pass directly; anything else
        must be accepted by ``dateutil.parser.parse``.
        """
        if isinstance(value, (datetime.datetime, datetime.date)):
            return
        try:
            dateparser.parse(value)
        except Exception:
            # dateutil may raise several exception types for bad input
            # (e.g. ValueError, OverflowError). Catch them all, but never
            # BaseException (KeyboardInterrupt/SystemExit).
            self.error('Could not parse date {0}'.format(value))

    def prepare_query_value(self, op, value):
        """Convert a query operand to a ``datetime.datetime``.

        ``None`` and datetimes pass through; a plain ``date`` becomes
        midnight of that day; anything else is parsed with dateutil.
        """
        if value is None or isinstance(value, datetime.datetime):
            return value
        if isinstance(value, datetime.date):
            return datetime.datetime(value.year, value.month, value.day)
        # Parse here so malformed input fails with dateutil's error, and
        # hand the parsed datetime (not the raw value) to the base class.
        parsed = dateparser.parse(value)
        return super(LegitDateTimeField, self).prepare_query_value(op, parsed)
def __ensure_class_not_hasattr(cls, attr_name):
    """Assert that ``cls`` does not already provide ``attr_name``.

    superdoc() installs its helpers under fixed attribute names, so an
    existing (own or inherited) attribute would otherwise be overwritten.

    Raises:
        AssertionError: if ``cls`` already has ``attr_name`` (only while
            assertions are enabled, i.e. not under ``python -O``).
    """
    # {0} is the class name, {1} the clashing attribute (previously swapped).
    assert not hasattr(cls, attr_name), (
        'Class {0} already has {1} attribute. Superdoc needs that name.'
        ''.format(getattr(cls, '__name__', cls), attr_name))
def superdoc(cls):  # Decorator
    """Class decorator that mixes helper methods into a MongoEngine document.

    Installs on ``cls``:

    * ``field_names(include=..., exclude=...)`` -- classmethod returning an
      iterator over field names (every key of ``cls._fields`` when called
      without arguments).
    * ``get_field_set(include=..., exclude=...)`` -- classmethod returning a
      frozenset of field names; raises ValueError for undefined names.
    * ``reference_fields`` -- read-only property: names of fields that are
      ``mongoengine.ReferenceField`` instances.
    * ``__getattr__`` -- adds ``<field>__id`` and ``<field>__dbref`` aliases
      for reference fields, read from the instance's ``_data``.
    * ``to_dict(include=..., exclude=..., include_nulls=True,
      include_empties=True)`` -- dict of field values.
    * ``update_fields(**kwargs)`` -- assigns kwargs that name a defined field
      other than ``id``; other kwargs are ignored.

    Returns the same class, modified in place.

    Raises:
        AssertionError: if ``cls`` is not a MongoEngine document class or it
            already has one of the attribute names checked below.
    """
    # Make some guarantees. Why decorate if you've
    # already defined your helper methods?
    assert isinstance(cls, mge.base.DocumentMetaclass), (
        'Class {0} decorated by superdoc must be a mongoengine.Document'
        ''.format(getattr(cls, '__name__', cls)))
    __ensure_class_not_hasattr(cls, 'field_names')
    __ensure_class_not_hasattr(cls, 'reference_fields')
    __ensure_class_not_hasattr(cls, 'to_dict')
    __ensure_class_not_hasattr(cls, 'update_fields')
    # NOTE(review): get_field_set and __getattr__ are also (re)assigned below
    # but not checked, preserving the original behavior.

    def _as_name_set(names):
        # A lone string (or None) is one name, not an iterable of chars.
        if utils.is_iter_not_str(names):
            return frozenset(names)
        return frozenset((names,))

    @classmethod
    @utils.restrict_kwargs('include', 'exclude')
    def field_names(cls, **kwargs):
        """Iterate field names, optionally filtered by include/exclude."""
        if kwargs:
            # Honor include/exclude instead of silently ignoring them.
            return iter(cls.get_field_set(**kwargs))
        return iter(cls._fields)

    # ``object`` defines no __getattr__, so this is None unless the class
    # (or a base) supplied its own hook.
    old_getattr = getattr(cls, '__getattr__', None)

    def __getattr__(self, attr_name):
        """Resolve ``<ref>__id`` / ``<ref>__dbref`` for reference fields.

        Anything not resolved here goes to the class's previous
        ``__getattr__`` (if any); failing that, AttributeError is raised.
        """
        try:
            if attr_name.endswith('__id'):
                realattr = attr_name[:-len('__id')]
                ref_field = self._data.get(realattr, None)
                # Only guarantee success for presence in reference_fields
                # and exact type match
                if ref_field is None and realattr in self.reference_fields:
                    # Sometimes a ReferenceField may be present but not set
                    return None
                if type(ref_field) is bson.ObjectId:
                    return ref_field
                if type(ref_field) is bson.DBRef:
                    return ref_field.id
            elif attr_name.endswith('__dbref'):
                realattr = attr_name[:-len('__dbref')]
                ref_field = self._data.get(realattr, None)
                # Only makes sense to return DBRef objects
                if ref_field is None and realattr in self.reference_fields:
                    return None
                if type(ref_field) is bson.DBRef:
                    return ref_field
            if old_getattr is not None:
                return old_getattr(self, attr_name)
        except Exception:
            # Any lookup failure surfaces as a plain AttributeError below.
            pass
        raise AttributeError(
            "'{0}' object has no attribute '{1}'"
            "".format(type(self).__name__, attr_name))

    def update_fields(self, **kwargs):
        """Assign each kwarg whose name is a defined field other than id."""
        field_set = self.get_field_set(exclude='id')
        for name, value in kwargs.items():
            if name in field_set:
                setattr(self, name, value)

    @classmethod
    @utils.restrict_kwargs('include', 'exclude')
    def get_field_set(cls, **kwargs):
        """Return the frozenset of field names selected by include/exclude.

        ``include`` defaults to every defined field; ``include`` and
        ``exclude`` may each be a single name or an iterable of names.

        Raises:
            ValueError: if the selection contains names not defined on cls.
        """
        defined_fields = frozenset(cls.field_names())
        include_set = _as_name_set(kwargs.get('include', defined_fields))
        exclude_set = _as_name_set(kwargs.get('exclude', None))
        field_set = include_set - exclude_set
        undefined = field_set - defined_fields
        if undefined:
            raise ValueError(
                'The fields {0} are not defined in {1}'
                ''.format(undefined, cls))
        return field_set

    @utils.restrict_kwargs('include', 'exclude', 'include_nulls', 'include_empties')
    def to_dict(self, **kwargs):
        """Return a dict of the selected field values.

        Reference fields are emitted as their ``<field>__id`` value. Other
        values are passed through the field's ``to_python`` (when the name
        is in ``_fields``). ``None`` values are kept only if
        ``include_nulls``; other falsy values are kept only if
        ``include_empties``. Excluded values are omitted.
        """
        include_nulls = kwargs.pop('include_nulls', True)
        include_empties = kwargs.pop('include_empties', True)
        field_set = self.get_field_set(**kwargs)
        reference_fields = self.reference_fields
        result = {}
        for fieldname in field_set:
            if fieldname in reference_fields:
                # NOTE(review): raises AttributeError if _data holds neither
                # None, an ObjectId nor a DBRef for this field.
                result[fieldname] = getattr(self, '{0}__id'.format(fieldname))
                continue
            value = getattr(self, fieldname)
            if value is None:
                if include_nulls:
                    result[fieldname] = None
                continue
            if not value and not include_empties:
                continue
            field = self._fields.get(fieldname)
            result[fieldname] = value if field is None else field.to_python(value)
        return result

    the_ref_fields = frozenset(
        name for name, field in cls._fields.items()
        if isinstance(field, mge.ReferenceField))
    cls.reference_fields = property(lambda self: the_ref_fields)
    cls.__getattr__ = __getattr__
    cls.field_names = field_names
    cls.to_dict = to_dict
    cls.update_fields = update_fields
    cls.get_field_set = get_field_set
    return cls
""""utils module
Module containing helper functions used throughout
the project, and not tied to a specify library other
than what's in the stdlib (2.7).
"""
from __future__ import unicode_literals
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import inspect
import functools
import json
import collections
import datetime
from tornado import escape
# Template for the TypeError raised by restrict_kwargs() when a caller passes
# an unsupported keyword argument. It is filled with str.format() using the
# function name, the offending kwarg, the supported names, and the calling
# frame's file, line, function and source line.
# NOTE(review): whitespace inside these literals is runtime text; confirm it
# against the original file (indentation was lost in this copy).
_RESTRICT_KWARGS_MSG = """\
{fn_name}() does not support a kwarg named {kwarg_name}
Supported kwargs are: {supported_kwargs}
File "{filename}", line {lineno}, in {caller_fn_name}
Context:
{context}
"""
# Template for the TypeError raised by require_kwargs() when a required
# keyword argument is missing. It has the same placeholders, except that
# {required_kwargs} replaces {supported_kwargs}.
_REQUIRE_KWARGS_MSG = """\
{fn_name}() missing a kwarg named {kwarg_name}.
Required kwargs are: {required_kwargs}
File "{filename}", line {lineno}, in {caller_fn_name}
Context:
{context}
"""
try:  # Python 3.3+: the ABCs live in collections.abc
    from collections.abc import Iterable as _Iterable
except ImportError:  # Python 2.7
    from collections import Iterable as _Iterable

try:
    _STRING_TYPES = (basestring,)  # Python 2: str and unicode
except NameError:
    _STRING_TYPES = (str, bytes)  # Python 3


def is_iter_not_str(arg):
    """Return True if ``arg`` is iterable but is not a text/byte string.

    Strings are iterable in Python but are usually meant as scalar values,
    so they are excluded.
    """
    return isinstance(arg, _Iterable) and not isinstance(arg, _STRING_TYPES)
def flatten(iter):
    """Lazily yield the leaves of an arbitrarily nested iterable.

    Strings are treated as leaves, not as iterables. Returns a generator.
    """
    for item in iter:
        if not is_iter_not_str(item):
            yield item
            continue
        for leaf in flatten(item):
            yield leaf
def flatten_and_split(strings, separator=','):
    """Split every string in a (possibly nested) iterable on ``separator``.

    ``strings`` may be a single string or any nesting of iterables of
    strings. Each string is split on ``separator`` (previously the argument
    was ignored and ',' was always used). The non-empty pieces are yielded
    in order. Returns a generator, not a list.
    """
    if not is_iter_not_str(strings):
        # A bare string would otherwise be flattened into single characters.
        strings = (strings,)
    for text in flatten(strings):
        for piece in text.split(separator):
            if piece:
                yield piece
def _convert_to_json_builtin(value):
    """``default`` hook for ``json.dumps`` covering a few stdlib types.

    * ``datetime.datetime`` / ``datetime.date`` / ``datetime.time`` become
      ISO 8601 strings via ``isoformat()``.
    * Any ``Set`` (``set``, ``frozenset``, ...) becomes a list, in iteration
      order.

    Raises:
        TypeError: for any other type, as ``json.dumps`` requires of hooks.
    """
    try:  # Python 3.3+
        from collections.abc import Set
    except ImportError:  # Python 2.7
        from collections import Set
    # datetime.datetime subclasses datetime.date, so this covers both.
    if isinstance(value, (datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, Set):
        return list(value)
    raise TypeError('Could not serialize type {0}'.format(type(value)))
def json_encode(value, *args, **kwargs):
    """Serialize ``value`` to a JSON string that is safe inside HTML.

    Byte strings are first converted with ``escape.recursive_unicode``. A
    falsy or missing ``default`` hook falls back to
    ``_convert_to_json_builtin``; all other arguments go to ``json.dumps``.
    """
    default_hook = kwargs.pop('default', None)
    if not default_hook:
        default_hook = _convert_to_json_builtin
    encoded = json.dumps(escape.recursive_unicode(value),
                         default=default_hook, *args, **kwargs)
    # JSON allows (but does not require) "/" to be escaped. Escaping "</"
    # keeps a "</script>" inside the data from ending an HTML <script> block
    # early; the stdlib encoder does not do this by itself.
    # stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
    return encoded.replace("</", "<\\/")
def json_decode(s):
    """Parse the JSON document ``s`` and return the resulting Python value."""
    decoded = json.loads(s)
    return decoded
def restrict_kwargs(*supported_kwargs):
    """Decorator factory rejecting keyword arguments not in ``supported_kwargs``.

    Gets you part of the effect of keyword-only arguments in Python 2.7.x.

    Raises:
        TypeError: at call time, built from ``_RESTRICT_KWARGS_MSG``. It names
            the unsupported kwarg, the supported names, and the caller's file,
            line, function and source line.
    """
    allowed = frozenset(supported_kwargs)

    def decorator(fn):
        @functools.wraps(fn)
        def fn_with_kwargs_restriction(*args, **kwargs):
            for kwarg_name in kwargs:
                if kwarg_name in allowed:
                    continue
                # stack()[1] is the frame that called this wrapper; fields
                # 1..4 are filename, lineno, function and code_context.
                filename, lineno, caller_fn_name, context = (
                    inspect.stack()[1][1:5])
                raise TypeError(_RESTRICT_KWARGS_MSG.format(
                    fn_name=fn.__name__,
                    kwarg_name=kwarg_name,
                    supported_kwargs=supported_kwargs,
                    filename=filename,
                    lineno=lineno,
                    caller_fn_name=caller_fn_name,
                    # code_context is None when the source is unavailable.
                    context=''.join(context or ()),
                ))
            return fn(*args, **kwargs)
        return fn_with_kwargs_restriction
    return decorator
def require_kwargs(*required_kwargs):
    """Decorator factory requiring every name in ``required_kwargs`` as a kwarg.

    Gets you part of the effect of keyword-only arguments in Python 2.7.x.

    Raises:
        TypeError: at call time, built from ``_REQUIRE_KWARGS_MSG``. It names
            the first missing kwarg, the required names, and the caller's
            file, line, function and source line.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def fn_with_kwargs_requirement(*args, **kwargs):
            for kwarg_name in required_kwargs:
                if kwarg_name in kwargs:
                    continue
                # stack()[1] is the frame that called this wrapper; fields
                # 1..4 are filename, lineno, function and code_context.
                filename, lineno, caller_fn_name, context = (
                    inspect.stack()[1][1:5])
                raise TypeError(_REQUIRE_KWARGS_MSG.format(
                    fn_name=fn.__name__,
                    kwarg_name=kwarg_name,
                    required_kwargs=required_kwargs,
                    filename=filename,
                    lineno=lineno,
                    caller_fn_name=caller_fn_name,
                    # code_context is None when the source is unavailable.
                    context=''.join(context or ()),
                ))
            return fn(*args, **kwargs)
        return fn_with_kwargs_requirement
    return decorator
def shallow_memoize(fn):
    """Cache ``fn``'s results, keyed on its arguments.

    The key is the positional-argument tuple plus the keyword arguments
    sorted by name, so keyword order does not cause cache misses. Arguments
    must be hashable, and they are compared shallowly (by hash/equality).
    The cache is unbounded and lives as long as the returned wrapper.
    """
    lookup = {}

    @functools.wraps(fn)
    def memoized_func(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key in lookup:
            return lookup[key]
        result = fn(*args, **kwargs)
        lookup[key] = result
        return result

    return memoized_func
@shallow_memoize
def sparse_bitcount(val, abs_when_negative=True):
    """Count the set bits of integer ``val`` (Kernighan's method).

    Each ``val &= val - 1`` clears the lowest set bit, so the loop runs once
    per set bit, which is cheap for sparse values. Results are memoized by
    ``shallow_memoize``.

    Args:
        val: integer whose bits are counted.
        abs_when_negative: if true, a negative ``val`` is replaced by its
            absolute value; if false, a negative ``val`` is rejected.

    Raises:
        ValueError: if ``val`` is negative and ``abs_when_negative`` is false.
    """
    count = 0
    if abs_when_negative:
        val = abs(val)
    elif val < 0:
        # Message previously misspelled the parameter and named the wrong
        # function (count_bits).
        raise ValueError(
            "Either specify abs_when_negative=True, or "
            "don't pass negative values to sparse_bitcount.")
    while val:
        val &= val - 1
        count += 1
    return count
@michaelbartnett
Copy link
Author

Future readers, please note that this thing is a great sadness that I would not wish upon anyone. It has destroyed so many lives, and none yet know the full extent of the damage it has wrought.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment