Last active
December 23, 2015 06:59
-
-
Save diegows/6597765 to your computer and use it in GitHub Desktop.
CSDocument: a mixin class to be combined with MongoEngine's Document class when defining documents. It allows switching the active database on the fly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
import warnings | |
from collections import defaultdict | |
from mongoengine import * | |
from mongoengine.connection import get_db | |
def dbconnect(company):
    """Connect to the ``cs_<company>`` database and make it the active DB.

    The connection is registered with MongoEngine under the alias
    ``company``. That alias is then selected with
    ``CSDocument.switch_active_db()``, and the company name is recorded in
    ``CSDocument.company`` (read back by ``current_company()``).
    """
    # w=0 / safe=False: ask the server not to acknowledge writes.
    # NOTE(review): ``safe`` is a legacy driver option -- confirm the
    # installed pymongo still accepts it.
    connect('cs_' + company, company, w=0, safe=False)
    CSDocument.switch_active_db(company)
    CSDocument.company = company
def current_db():
    """Return the database alias currently selected on CSDocument.

    This is whatever was last passed to ``CSDocument.switch_active_db()``.
    """
    selected = CSDocument.active_db
    return selected
def current_company():
    """Return the company name stored on CSDocument by ``dbconnect()``."""
    name = CSDocument.company
    return name
class CSMongoError(Exception):
    """Error raised by the CSDocument helpers.

    One example is a lookup that expects a single document but matches
    several. The object passed to the constructor is kept in ``value``.
    """

    def __init__(self, value):
        # Forward to Exception so ``args`` and ``str()`` are populated the
        # standard way.
        super(CSMongoError, self).__init__(value)
        self.value = value

    def __repr__(self):
        # __repr__ must return a str. Coerce it so that non-string values
        # (ints, dicts, ...) do not make repr() raise TypeError.
        return str(self.value)
class CSDocument(object):
    """Mixin that lets MongoEngine documents switch database on the fly.

    Combine it with MongoEngine's ``Document``, e.g.
    ``class Foo(CSDocument, Document)``. Collections are resolved against
    ``active_db`` (set with ``switch_active_db()``) unless the class sets
    ``shared_db``, in which case that database is always used.

    ``active_db`` is a class attribute, so a switch made through
    ``CSDocument`` is seen by every thread and every subclass that has not
    shadowed it.
    """

    # Cache of collection handles keyed by "<thread id>-<db>-<collection>"
    # (see _build_key). It is shared by every subclass.
    _collections = defaultdict(lambda: None)
    # Connection alias used by classes that do not set shared_db.
    active_db = None
    # Company name last recorded by dbconnect(). None until then, so
    # current_company() no longer raises AttributeError before the first
    # connection.
    company = None
    # Set this to the name of the shared DB in classes that live there.
    shared_db = None

    def usave(self, **kwargs):
        """Backward compatibility, use save().

        Emits a DeprecationWarning attributed to the caller, then returns
        whatever ``save()`` returns.
        """
        warnings.warn("Don't use usave() anymore, use save().",
                      category=DeprecationWarning, stacklevel=2)
        return self.save(**kwargs)

    def save(self, **kwargs):
        """Run ``csvalidate()``, then delegate to the next ``save()`` in the MRO.

        When mixed in as intended, that is MongoEngine's ``Document.save``.
        Its return value is passed through, so chaining keeps working.
        """
        self.csvalidate()
        return super(CSDocument, self).save(**kwargs)

    def csvalidate(self):
        """Rewrite this function in a derived class if you need some
        custom validation.

        ``save()`` calls it before delegating to the parent ``save()``.
        The default implementation does nothing.
        """
        pass

    def __repr__(self):
        # NOTE(review): this relies on a __str__ defined later in the MRO
        # (e.g. by the MongoEngine Document this is mixed into). If __str__
        # resolves to object.__str__, str() calls back into __repr__ and
        # recurses without end.
        return str(self)

    @classmethod
    def switch_active_db(cls, dbname):
        """Make ``dbname`` (a connection alias) the active database for ``cls``."""
        cls.active_db = dbname

    @classmethod
    def _cs_db_alias(cls):
        """Return the alias this class uses: shared_db if set, else active_db."""
        if cls.shared_db is None:
            return cls.active_db
        return cls.shared_db

    @classmethod
    def _get_db(cls):
        """Return the MongoEngine database object for ``_cs_db_alias()``."""
        return get_db(cls._cs_db_alias())

    @classmethod
    def _build_key(cls, dbname, collection_name):
        """Build the ``_collections`` cache key for the calling thread."""
        # XXX: check if we really need this per thread thing!!
        thread_id = threading.current_thread().ident
        return '%s-%s-%s' % (thread_id, dbname, collection_name)

    @classmethod
    def _get_collection(cls):
        """Custom MongoEngine function that honours the current DB selection.

        Collection handles are cached per (thread, database, collection).
        The first lookup for a key also runs ``ensure_indexes()``, unless
        ``_meta['auto_create_index']`` is false.
        """
        collection_name = cls._get_collection_name()
        key = cls._build_key(cls._cs_db_alias(), collection_name)
        # .get() avoids inserting placeholder None entries into the
        # defaultdict on every cache miss.
        collection = cls._collections.get(key)
        if collection is None:
            collection = cls._get_db()[collection_name]
            # Keep cls._collection in sync as well; presumably other
            # MongoEngine code paths read it directly.
            cls._collection = collection
            cls._collections[key] = collection
            if cls._meta.get('auto_create_index', True):
                cls.ensure_indexes()
        return collection

    @classmethod
    def create(cls, **kwargs):
        """Create and save a new obj in the DB with information provided in
        kwargs, and return it.
        """
        obj = cls(**kwargs)
        obj.save()
        return obj

    @classmethod
    def get(cls, **kwargs):
        """Return the single document matching ``kwargs``.

        If nothing matches, return None. If ``create=True`` is passed, a new
        document is instead built from the remaining kwargs, saved and
        returned. Raise CSMongoError when more than one document matches.

        There is a get_or_create() function in MongoEngine but it's
        deprecated.
        """
        # dict.has_key() no longer exists in Python 3; pop works in both.
        create = kwargs.pop('create', False)
        matches = cls.objects(**kwargs)
        count = matches.count()
        if count == 0 and create:
            return cls.create(**kwargs)
        if count > 1:
            raise CSMongoError("More than one document found for " + str(cls))
        return matches.first()

    @classmethod
    def get_index_by_name(cls, name):
        """Get index spec to be used in .hint() cursor method.

        Return the ``fields`` of the first entry in ``_meta['index_specs']``
        whose ``name`` equals ``name``, or None if there is no such entry.

        Example:
            hint = mdb.Rcpt.get_index_by_name('sort0')
            cursor = mdb.Rcpt.objects.hint(hint)
        """
        for index_spec in cls._meta['index_specs']:
            index_name = index_spec.get('name')
            if index_name and index_name == name:
                return index_spec['fields']
        return None

    def dump(self, *attrs):
        """Return space-separated ``attr=value`` pairs for the given attribute names."""
        return " ".join('%s=%s' % (attr, str(getattr(self, attr)))
                        for attr in attrs)

    @property
    def tid(self):
        """The document's ``id`` as a string."""
        return str(self.id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from base import * | |
from mongoengine import * | |
class Data(CSDocument, Document):
    """Example document that mixes CSDocument into MongoEngine's Document.

    CSDocument is listed first, so its save()/_get_collection() overrides
    take precedence over Document's.
    """
    info = StringField()
    moreinfo = StringField()
# Save one Data document into each database in turn. dbconnect() switches
# the active DB, so each save() lands in the database selected just before.
for dbname in ("somedb", "anotherdb"):
    dbconnect(dbname)
    data = Data()
    do_something(data)
    data.save()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment