Skip to content

Instantly share code, notes, and snippets.

@mcm
Created April 12, 2015 23:03
Show Gist options
  • Save mcm/c2c4503d59daae521a8c to your computer and use it in GitHub Desktop.
Save mcm/c2c4503d59daae521a8c to your computer and use it in GitHub Desktop.
import cerberus
import cerberus.errors
from bson.objectid import ObjectId
class SchemaValidationError(Exception):
    """Raised when a document fails schema validation.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary
    ``except Exception`` handlers catch it; ``BaseException`` is reserved for
    system-exiting exceptions such as ``KeyboardInterrupt``.

    Attributes:
        errors: the validation error report passed via the ``errors`` keyword,
            or ``None`` when it was not supplied.
    """

    def __init__(self, *args, errors=None):
        # Positional args become the exception message/args as usual.
        # Keyword arguments are not forwarded because Exception.__init__
        # rejects them.
        super().__init__(*args)
        # Always define the attribute so handlers can read ``exc.errors``
        # without a hasattr() check.
        self.errors = errors
class MongoValidator(cerberus.Validator):
    """Cerberus validator that additionally understands an ``oid`` type.

    Cerberus (0.x API) dispatches a schema ``'type': 'oid'`` rule to the
    ``_validate_type_oid`` method below.
    """

    def _validate_type_oid(self, field, value):
        """Record a bad-type error for *field* unless *value* is an ObjectId."""
        from bson.objectid import ObjectId

        if isinstance(value, ObjectId):
            return
        message = cerberus.errors.ERROR_BAD_TYPE % 'ObjectId'
        self._error(field, message)
def get_number_of_pages(total_entries, per_page):
    """Return how many pages of *per_page* items are needed for *total_entries*.

    Uses ceiling integer division so integer inputs yield an ``int``.
    Plain ``/`` under Python 3 would produce e.g. ``10 / 3 + 1 == 4.33...``.

    Raises:
        ZeroDivisionError: if *per_page* is 0.
    """
    pages, remainder = divmod(total_entries, per_page)
    if remainder > 0:
        # A partially filled last page still counts as a page.
        pages += 1
    return pages
class BaseMongoEndpoint:
    """Base class for endpoints backed by a single MongoDB collection.

    Subclasses set ``collection_name`` (looked up as an attribute on the
    database object) and optionally ``schema`` (used by ``validate`` through
    ``MongoValidator``).
    """

    collection_name = None  # attribute name of the collection on ``database``
    schema = None  # validator schema; None means validation is skipped

    def __init__(self, database):
        # Presumably a pymongo Database (collections reachable as attributes)
        # -- TODO confirm against callers.
        self.database = database

    @property
    def collection(self):
        """Return the collection named by ``collection_name``."""
        return getattr(self.database, self.collection_name)

    @property
    def validator(self):
        """Return the validator for ``schema``, building and caching it on demand.

        Raises:
            AttributeError: if no callable validator is cached and ``schema``
                is None.
        """
        if hasattr(self, "_validator") and callable(self._validator):
            return self._validator
        if self.schema is not None:
            self._validator = MongoValidator(self.schema)
            return self._validator
        raise AttributeError("no schema configured for %s" % type(self).__name__)

    def validate(self, doc):
        """Validate *doc* against ``schema``; do nothing when no schema is set.

        Raises:
            SchemaValidationError: if the validator rejects *doc*. The error
                report is passed both as the exception argument and as the
                ``errors`` keyword.
        """
        try:
            validator = self.validator
        except AttributeError:
            # No schema configured: validation is deliberately skipped.
            # Only the property lookup is guarded, so AttributeErrors raised
            # inside validate() itself are no longer silently swallowed.
            return
        # Use the same validator object for the check and the error report.
        # Re-reading the property may build a fresh validator whose errors
        # are empty.
        if validator.validate(doc) is False:
            raise SchemaValidationError(validator.errors, errors=validator.errors)

    def get_data(self, req):
        """Query the collection for *req* and store the documents in ``req.context["data"]``.

        Selection, in priority order:
          * ``req.params["id"]`` -- the document with that ``_id``;
          * ``req.context["filter"]`` -- documents matching that query;
          * otherwise every document in the collection.

        If ``req.context["pagination"]`` is present, it is applied to the
        cursor and then replaced with a summary of string values (see
        ``_paginate``).
        """
        queryset = self.collection.find(self._build_query(req))
        if "pagination" in req.context:
            self._paginate(queryset, req)
        # Stored on req.context like the other per-request values.
        # Item assignment on the request object itself is not supported.
        req.context["data"] = list(queryset)

    def _build_query(self, req):
        """Return the filter document selecting the records *req* asks for."""
        if "id" in req.params:
            # Get one. The id comes from the request, not the literal "id",
            # and find() expects a mapping, so wrap it as an _id filter.
            return {"_id": ObjectId(req.params["id"])}
        if "filter" in req.context:
            # Get some.
            return req.context["filter"]
        # Get all.
        return {}

    def _paginate(self, queryset, req):
        """Apply ``req.context["pagination"]`` to *queryset* and replace it with a summary.

        The settings dict is read for ``page`` (1-based), ``per_page`` and
        optionally ``sort_by``/``order``. The summary holds string values for
        page, per_page, total_entries and total_pages, plus sort_by/order when
        sorting was requested.
        """
        settings = req.context["pagination"]
        page = settings["page"]
        per_page = settings["per_page"]
        sorting = "sort_by" in settings
        if sorting:
            queryset.sort(settings["sort_by"], settings["order"])
        queryset.skip((page - 1) * per_page)
        queryset.limit(per_page)
        # Assuming a pymongo Cursor, count() ignores skip/limit by default,
        # so this is the total number of matches, not the page size.
        total_entries = queryset.count()
        summary = {
            "page": str(page),
            "per_page": str(per_page),
            "total_entries": str(total_entries),
            "total_pages": str(get_number_of_pages(total_entries, per_page)),
        }
        if sorting:
            summary.update({
                "sort_by": settings["sort_by"],
                "order": str(settings["order"]),
            })
        req.context["pagination"] = summary

    def new(self, req):
        """Placeholder for creating a document; currently does nothing."""
        pass

    def update(self, req):
        """Placeholder for updating a document; currently does nothing."""
        pass

    def delete(self, req):
        """Placeholder for deleting a document; currently does nothing."""
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment