Skip to content

Instantly share code, notes, and snippets.

@commandodev
Last active September 11, 2018 20:03
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save commandodev/5108455 to your computer and use it in GitHub Desktop.
Save commandodev/5108455 to your computer and use it in GitHub Desktop.
REST traversal in Pyramid, with a small example of usage.
from pyramid.view import view_config
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker, object_mapper, ColumnProperty, SynonymProperty
# Module-wide scoped session registry. The sessionmaker is created without an
# explicit bind here; it backs `_PrettyPrintBase.query` and `ModelResource.ses`.
Session = scoped_session(sessionmaker())
class _PrettyPrintBase(object):
    """Base mixin for all of our declarative tables.

    Adds a ``query`` property backed by ``Session`` and a readable
    ``str``/``repr`` built from the instance's primary key columns.

    .. note:: Don't use this directly; it's a mixin to be used with
        :func:`~sqlalchemy.ext.declarative.declarative_base`
    """

    query = Session.query_property()

    def __str__(self):
        return self._pk

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    @property
    def _pk(self):
        """Return ``"key=value, ..."`` for the primary key columns.

        Returns ``"No Table"`` when the class has no ``__table__``.
        """
        if not hasattr(self, '__table__'):
            return "No Table"
        om = object_mapper(self)
        pk_cols = om.primary_key
        pk_column_keys = [p.key for p in om.iterate_properties
                          if isinstance(p, ColumnProperty)
                          and p.columns[0] in pk_cols]
        pk_to_val = [(k, getattr(self, k)) for k in pk_column_keys]
        # Only an unset key (None) is rendered as NONE. A plain truthiness
        # test would also hide legitimate key values such as 0.
        return ', '.join('%s=%s' % (k, 'NONE' if v is None else v)
                         for k, v in pk_to_val)
# Declarative base for the models; built with _PrettyPrintBase as its `cls`.
Base = declarative_base(cls=_PrettyPrintBase)
class BaseTraverser(object):
    """Base class for traversal resources.

    Subclasses override :meth:`getitem` to resolve a child resource;
    :meth:`__getitem__` then stamps that child with ``__name__`` and
    ``__parent__`` before returning it.
    """

    # Declarative base used by subclasses that look models up; unset here.
    base = None

    def getitem(self, item):
        """Return the child resource called ``item``.

        :raises KeyError: always, in this base class (it has no children)
        """
        # Include the key so the failure says which child was missing.
        raise KeyError(item)

    def __getitem__(self, item):
        """Resolve ``item`` via :meth:`getitem` and record its location."""
        resource = self.getitem(item)
        resource.__name__ = item
        resource.__parent__ = self
        return resource
class DBTraverser(BaseTraverser):
    """Traverser whose children are the model classes registered on ``base``."""

    def getitem(self, item):
        """Wrap the model class named ``item`` in a :class:`ModelResource`."""
        return ModelResource(item, get_model(self.base, item))
def model_to_dict(model_inst):
    """Convert a mapped model instance into a plain ``dict``.

    The result holds the instance's column attributes (a column shadowed by a
    synonym is reported under the synonym's key instead) plus any association
    proxies declared directly on its class. Intended as a first step towards
    JSON serialization.
    """
    properties = list(object_mapper(model_inst).iterate_properties)
    synonym_props = [prop for prop in properties
                     if isinstance(prop, SynonymProperty)]
    shadowed_names = [syn.name for syn in synonym_props]

    # Column keys first (minus those a synonym stands in for), then synonyms.
    column_keys = []
    for prop in properties:
        if not isinstance(prop, ColumnProperty):
            continue
        if prop.key in shadowed_names:
            continue
        column_keys.append(prop.key)
    column_keys.extend(syn.key for syn in synonym_props)

    proxy_names = [attr_name
                   for attr_name, attr in vars(model_inst.__class__).items()
                   if isinstance(attr, AssociationProxy)]

    proxy_pairs = []
    for attr_name in proxy_names:
        try:
            value = getattr(model_inst, attr_name)
        except AttributeError:
            # Proxy cannot be resolved on this instance: leave it out.
            continue
        try:
            # Prefer a detached copy when the proxied value offers one.
            value = value.copy()
        except AttributeError:
            pass
        proxy_pairs.append((attr_name, value))

    result = {}
    for key in column_keys:
        result[key] = getattr(model_inst, key)
    result.update(proxy_pairs)
    return result
def get_model(base, name):
    """Look up a mapped class based on its name.

    :param base: A :ref:`sqlalchemy:declarative_toplevel` base class
    :type base: Subclass of :class:`_PrettyPrintBase`
    :param name: The name of the class
    :returns: whatever the base's class registry stores under ``name``
    :raises KeyError: if no class called ``name`` is registered
    """
    # Older SQLAlchemy exposes the registry directly on the base; from 1.4 on
    # it lives at ``base.registry._class_registry``.
    class_registry = getattr(base, "_decl_class_registry", None)
    if class_registry is None:
        class_registry = base.registry._class_registry
    return class_registry[name]
class ModelResource(BaseTraverser):
    """Traversal resource standing for one mapped model class.

    :param name: the name the model was looked up under
    :param Model: the mapped class itself
    """

    def __init__(self, name, Model):
        self.name = name
        self.ses = Session()
        self.Model = Model

    @property
    def q(self):
        """A new query over ``self.Model`` on this resource's session."""
        return self.ses.query(self.Model)

    def getitem(self, primary_key):
        """Return the :class:`ItemResource` for the row with ``primary_key``.

        :raises KeyError: if ``primary_key`` is not an integer, or no row
            with that primary key exists
        """
        # Report a bad or unknown key as KeyError so that traversal sees a
        # missing child instead of an unhandled ValueError, or an
        # ItemResource wrapping None.
        try:
            pk = int(primary_key)
        except (TypeError, ValueError):
            raise KeyError(primary_key)
        instance = self.q.get(pk)
        if instance is None:
            raise KeyError(primary_key)
        return ItemResource(self.name, instance)
class ItemResource(BaseTraverser):
    """Traversal resource wrapping a single model instance."""

    def __init__(self, model_name, model_instance):
        self.model = model_instance
        self.name = model_name

    def getitem(self, relation_name):
        """Wrap the instance attribute ``relation_name`` in a RelationResource.

        :raises KeyError: if the instance has no such attribute
        """
        try:
            related = getattr(self.model, relation_name)
        except AttributeError:
            raise KeyError("No relation %s" % relation_name)
        return RelationResource(related)
class RelationResource(object):
    """Leaf resource holding the value of one attribute of a model instance."""

    def __init__(self, child_list):
        # Stored as given; the `list_related` view iterates over it.
        self.children = child_list
def model_is(model_class):
    """Build a Pyramid custom predicate that matches a specific Model.

    The returned callable takes ``(context, request)`` and is true only when
    the context has a ``Model`` attribute that is exactly ``model_class``.
    """
    def _model_is(context, request):
        if not hasattr(context, "Model"):
            return False
        return context.Model is model_class

    # Give each predicate a distinguishable name.
    _model_is.__name__ = "model_is_%s" % model_class.__name__
    return _model_is
@view_config(context=ModelResource, renderer="safe_json")
def list_model(context, request):
    """List the rows of the context's model, each converted to a dict.

    An optional ``q`` query-string parameter carries a JSON object whose
    members are applied as ``filter_by`` keyword filters.

    :raises HTTPBadRequest: if ``q`` is not valid JSON or not a JSON object
    """
    # ``json`` is not among the module's visible imports, so any request
    # carrying ``q`` failed with NameError. Import it where it is used.
    import json
    from pyramid.httpexceptions import HTTPBadRequest

    q = context.q
    filters = request.GET.get("q")
    if filters:
        # ``q`` is client-supplied: reject anything that is not a JSON object
        # rather than letting it surface as a server error.
        try:
            criteria = json.loads(filters)
        except ValueError:
            raise HTTPBadRequest("q must be valid JSON")
        if not isinstance(criteria, dict):
            raise HTTPBadRequest("q must be a JSON object")
        q = q.filter_by(**criteria)
    return [model_to_dict(mkt) for mkt in q.all()]
@view_config(context=ItemResource, renderer="safe_json")
def model_detail(context, request):
    """Render the single model instance held by the context as a dict."""
    instance = context.model
    return model_to_dict(instance)
@view_config(context=RelationResource, renderer="safe_json")
def list_related(context, request):
    """Render each child held by the relation resource as a dict."""
    rendered = []
    for child in context.children:
        rendered.append(model_to_dict(child))
    return rendered
@view_config(context=DBTraverser, name="models", renderer="safe_json")
def models(context, request):
    """Return the names registered on the traverser's declarative base."""
    base = context.base
    # Older SQLAlchemy exposes the registry directly on the base; from 1.4 on
    # it lives at ``base.registry._class_registry``.
    class_registry = getattr(base, "_decl_class_registry", None)
    if class_registry is None:
        class_registry = base.registry._class_registry
    # Return a real list: on Python 3 ``dict.keys()`` is a view object, which
    # a JSON encoder cannot serialize.
    return list(class_registry)
from pyramid.view import view_config
from rest_traversal import Base, DBTraverser, BaseTraverser
class MyTraverser(DBTraverser):
    """DBTraverser wired to this application's declarative ``Base``."""

    base = Base
class Root(BaseTraverser):
    """Root of the resource tree; dispatches first path segments via ROUTES."""

    __parent__ = None
    __name__ = None

    # Use MyTraverser, not a bare DBTraverser: DBTraverser.base is None, so
    # looking up any model through it would fail.
    ROUTES = {
        "db": MyTraverser(),
    }

    def getitem(self, item):
        """Return the traverser registered under ``item``.

        :raises KeyError: if ``item`` is not a key of ``ROUTES``
        """
        return self.ROUTES[item]
# One Root instance, created at import time and shared by every request.
root = Root()


def app_root_factory(request):
    """Root factory: return the shared ``root`` object for any request."""
    return root
class MyModel(Base):
    """Example model mapped to ``a_schema.a_table``."""

    __tablename__ = "a_table"
    __table_args__ = {
        'schema': 'a_schema',
        'extend_existing': True,
    }
    # columns for a_table go here
@view_config(context=ModelResource, request_method="POST", renderer="safe_json",
             custom_predicates=[model_is(MyModel)])
def list_model(context, request):
    """POST view used only when the context's model is ``MyModel``."""
    # special logic for MyModel
    return {}
def main(global_conf, **settings):
    """Build and return the WSGI application.

    :param global_conf: mapping merged into ``settings`` (its values win)
    :param settings: application settings; also used to build the SQLAlchemy
        engine via ``engine_from_config``
    :returns: the WSGI app produced by the Pyramid ``Configurator``
    """
    # Neither name is among the module's visible imports, so calling main()
    # failed with NameError. Import them where they are used.
    from pyramid.config import Configurator
    from sqlalchemy import engine_from_config

    settings.update(global_conf)
    engine = engine_from_config(settings)
    initialize_sql(engine, Base)
    config = Configurator(settings=settings, root_factory=app_root_factory)
    config.include("pyramid_jinja2")
    config.add_view('pyramid.view.append_slash_notfound_view',
                    context='pyramid.httpexceptions.HTTPNotFound')
    config.scan()
    return config.make_wsgi_app()
def initialize_sql(engine, base):
    """Bind ``engine`` to the metadata of the declarative ``base``."""
    metadata = base.metadata
    metadata.bind = engine
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment