Skip to content

Instantly share code, notes, and snippets.

@shazow
Created December 17, 2009 00:59
Show Gist options
  • Save shazow/258394 to your computer and use it in GitHub Desktop.
Save shazow/258394 to your computer and use it in GitHub Desktop.
BaseModel class for SQLAlchemy declarative model abstractions (with query statistics tracking and a JSON encoder).
"""SQLAlchemy Metadata and Session object"""
import time
from datetime import date, datetime

from sqlalchemy import MetaData
from sqlalchemy.interfaces import ConnectionProxy
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm.session import Session as SessionBase
# Public names exported by ``from <this module> import *``.
__all__ = [
    "Session",
    "metadata",
    "BaseModel",
]
# Per-session query statistics tracking, plus the Session/metadata setup.
class QueryStats(object):
    """
    Accumulate the statements executed in a session and their timings.

    Attributes:
        query_count: number of statements recorded so far (plus the initial
            ``query_count`` passed to the constructor).
        time_elapsed: running sum of the ``elapsed`` values passed to
            ``add_query`` (plus the initial ``time_elapsed``).
        queries: list of ``(statement, elapsed)`` tuples in recording order.
    """

    def __init__(self, query_count=0, time_elapsed=0):
        self.query_count = query_count
        self.time_elapsed = time_elapsed
        self.queries = []

    def add_query(self, statement, elapsed):
        """Record one executed ``statement`` that took ``elapsed`` time."""
        self.queries.append((statement, elapsed))
        self.time_elapsed += elapsed
        self.query_count += 1

    def __repr__(self):
        cls_name = type(self).__name__
        return "%s(query_count=%d, time_elapsed=%0.4f)" % (
            cls_name,
            self.query_count,
            self.time_elapsed,
        )
class QueryStatsProxy(ConnectionProxy):
    """
    Connection proxy that times each cursor execution and records it on the
    current scoped ``Session``'s ``stats`` object.

    Install it when creating the engine::

        engine = create_engine("...", proxy=QueryStatsProxy())
    """

    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
        started_at = time.time()
        try:
            return execute(cursor, statement, parameters, context)
        finally:
            # Runs on success and on failure alike, so statements that raise
            # are recorded too. The duration is taken before Session() is
            # looked up, so session creation is not counted in the timing.
            duration = time.time() - started_at
            Session().stats.add_query(statement, duration)
class SessionStatsBase(SessionBase):
    """
    Session subclass carrying a ``stats`` attribute (a ``QueryStats``), which
    makes per-session query statistics reachable from the scoped ``Session``.
    """

    def __init__(self, *args, **kwargs):
        SessionBase.__init__(self, *args, **kwargs)
        # Every new session instance starts with its own fresh counters.
        self.stats = QueryStats()
# Scoped session registry. Every session it produces is a SessionStatsBase,
# so ``Session().stats`` is always available.
Session = scoped_session(
    sessionmaker(class_=SessionStatsBase),
)

# MetaData object shared by the models; BaseModel is created with it.
metadata = MetaData()
# Declarative base
from sqlalchemy.ext.declarative import declarative_base
class _Base(object):
_repr_hide = ['time_created', 'time_updated']
_json_hide = ['time_updated']
@classmethod
def get(cls, id):
return Session.query(cls).get(id)
@classmethod
def get_by(cls, **kw):
return Session.query(cls).filter_by(**kw).first()
@classmethod
def get_or_create(cls, **kw):
r = cls.get_by(**kw)
if not r:
r = cls(**kw)
Session.add(r)
return r
@classmethod
def create(cls, **kw):
r = cls(**kw)
Session.add(r)
return r
def delete(self):
Session.delete(self)
def _is_loaded(self, key):
return key in self.__dict__
def _is_loaded_all(self, path):
"""
Check if the given path of properties are eager-loaded.
`path` is similar to sqlalchemy.orm.eagerload_all, checking happens
by inspecting obj.__data__.
"""
current = self
for k in path.split('.'):
if not current._is_loaded(k):
return False
current = getattr(current, k)
if not current:
return False
if isinstance(current, list):
current = current[0]
return True
def __repr__(self):
values = ', '.join("%s=%r" % (n, getattr(self, n)) for n in self.__table__.c.keys() if n not in self._repr_hide)
return "%s(%s)" % (self.__class__.__name__, values)
def __json__(self):
## Only include local table attributes:
#return dict((n, getattr(self, n)) for n in self.__table__.c.keys() if n not in self._json_hide)
## Include loaded relations recursively:
return dict((n,v) for n,v in self.__dict__.iteritems() if not n.startswith('_') and n not in self._json_hide)
# Declarative base for all models: their tables are registered on the shared
# ``metadata``, and every model inherits the helpers defined on ``_Base``.
BaseModel = declarative_base(cls=_Base, metadata=metadata)
# JSON encoding of SQLAlchemy-returned objects:
import json
class SchemaEncoder(json.JSONEncoder):
    """
    JSON encoder that understands model instances and date/datetime values.

    Usage: ``json.dumps(obj, cls=SchemaEncoder)``. Any other unsupported type
    raises ``TypeError`` from ``json.JSONEncoder.default``.
    """

    def default(self, obj):
        # ``datetime`` subclasses ``date``, so this one check covers both.
        # Datetimes still use their own isoformat(), so their output is
        # unchanged. Plain ``date`` values (e.g. from Date columns) no longer
        # raise TypeError.
        if isinstance(obj, date):
            return obj.isoformat()
        if isinstance(obj, BaseModel):
            # Models expose their serializable state via __json__. Any nested
            # models it returns come back through this method.
            return obj.__json__()
        return json.JSONEncoder.default(self, obj)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment