Skip to content

Instantly share code, notes, and snippets.

@groner
Created September 3, 2010 01:37
Show Gist options
  • Save groner/563270 to your computer and use it in GitHub Desktop.
Save groner/563270 to your computer and use it in GitHub Desktop.
YAML/SQLAlchemy fixture loader inspired by Chris Perkins' bootalchemy
%TAG !foo! tag:gronr.com,2010:
---
- Group:
  - name: blues
  - name: reds
  - name: squares
%TAG !foo! tag:gronr.com,2010:
---
- __aliases__: # Container for loading references in
  - &BLUES_GROUP !foo!ref:Group { name: blues }
  - &REDS_GROUP !foo!ref:Group { name: reds }
  - &SQUARES_GROUP !foo!ref:Group { name: squares }
- User: # Define some User objects
  - name: tom
    groups: # Existing groups
    - *BLUES_GROUP
    emails: # The emails relation will be examined to determine what type of
            # object to construct
    - type: Work
      address: tom@example.com
    - type: Home
      address: hometom@example.com
  - name: matt
    groups:
    - *REDS_GROUP
    - *SQUARES_GROUP
  - name: otto
    groups:
    - *SQUARES_GROUP
import logging
import pylons.test
from itertools import groupby
import yaml
from sqlalchemy.orm import class_mapper, Mapper, RelationProperty
from sqlalchemy.orm.exc import UnmappedClassError, MultipleResultsFound
from <project>.config.environment import load_environment
from <project>.model.meta import Session, Base
import <project>.model as m
log_ = logging.getLogger(__name__)
def get_model_class(cls):
    """Resolve *cls* to a ``(class, mapper)`` pair.

    *cls* may be a mapped class or the name of an attribute of the model
    module ``m``; a name is looked up there first.

    Raises:
        ValueError: *cls* is a string and the model module has no attribute
            of that name.
        TypeError: the resolved class is not a mapped class.
    """
    if isinstance(cls, (str, unicode)):
        model_name = cls
        try:
            cls = getattr(m, model_name)
        except AttributeError:
            raise ValueError('no model named %r exists' % (model_name,))
    try:
        return cls, class_mapper(cls)
    except UnmappedClassError:
        raise TypeError('class %r is not a mapped class' % (cls,))
def make_object(cls, d, seen=None):
    """Build a model instance of *cls* from the plain mapping *d*.

    Args:
        cls: A mapped class, or a model name accepted by get_model_class().
        d: Mapping of property name to value.  Values of relation
            properties are themselves converted with make_object(): a
            single value for scalar relations, each element of a sequence
            for ``uselist`` relations.  If *d* is already an instance of
            ``Base`` it is returned unchanged.
        seen: Memo mapping ``id(d)`` to the object already built for that
            mapping, so that the same mapping object always yields the
            same instance.  A fresh memo is created when omitted.

    Returns:
        The constructed (or memoized, or passed-through) model instance.
    """
    if isinstance(d, Base):
        return d
    if seen is None:
        seen = {}
    memo_key = id(d)
    if memo_key in seen:
        return seen[memo_key]

    cls, mapper = get_model_class(cls)
    # Register the instance before descending into relations so that a
    # mapping reachable from itself resolves to this same object.
    obj = seen[memo_key] = cls()

    for key, value in d.items():
        prop = mapper.get_property(key)
        # TODO: composites
        if not isinstance(prop, RelationProperty):
            setattr(obj, prop.key, value)
            continue

        # TODO: collections
        related_cls = get_relation_property_class(prop)
        if prop.uselist:
            # A null value (e.g. an empty ``groups:`` key in YAML) means
            # "no related objects" rather than something to iterate over.
            members = [] if value is None else value
            related = [make_object(related_cls, member, seen)
                       for member in members]
        elif value is None:
            # A null scalar relation stays unset instead of being treated
            # as a mapping to build an object from.
            related = None
        else:
            related = make_object(related_cls, value, seen)
        setattr(obj, key, related)
    return obj
def get_relation_property_class(prop):
    """Return the class that the relation property *prop* points at.

    ``prop.argument`` is inspected in order: a ``Mapper`` yields its
    ``class_``, a class is returned as is, and any other callable is called
    and its result returned.  If none of these match, ``None`` is returned.
    """
    target = prop.argument
    if isinstance(target, Mapper):
        return target.class_
    if isinstance(target, type):
        return target
    return target() if callable(target) else None
class ORMLoader(yaml.Loader):
    """YAML loader that turns ``tag:gronr.com,2010:ref:<Model>`` nodes into
    rows that already exist in the database."""

    def construct_reference(loader, suffix, node):
        """Look up the single existing row described by a reference node.

        Args:
            loader: The loader instance (this is the method's ``self``).
            suffix: The part of the tag after the registered prefix; it is
                resolved to a model with get_model_class().
            node: Mapping node whose key/value pairs are used as
                ``filter_by`` criteria.

        Returns:
            The one instance matching the criteria.

        Raises:
            MultipleResultsFound: more than one row matches.
            NoResultFound: no row matches.
            Both are logged with the tag suffix and criteria, then re-raised.
        """
        # Imported here so the block does not depend on an edit to the
        # module-level import list.
        from sqlalchemy.orm.exc import NoResultFound

        cls, _ = get_model_class(suffix)
        kw = loader.construct_mapping(node)
        try:
            return cls.query.filter_by(**kw).one()
        except MultipleResultsFound:
            log_.error('!foo!ref:%s %r found multiple rows', suffix, kw)
            raise
        except NoResultFound:
            # Without this the failure surfaces with no hint of which
            # reference in the fixture file was unresolved.
            log_.error('!foo!ref:%s %r found no rows', suffix, kw)
            raise


# Every tag starting with this prefix is handled by construct_reference;
# the remainder of the tag is passed to it as ``suffix``.
ORMLoader.add_multi_constructor('tag:gronr.com,2010:ref:',
                                ORMLoader.construct_reference)
def setup_app(command, conf, vars):
    """Create the database tables and load the configured fixture files.

    ``conf['init_data']`` is read as a newline-separated list of file
    names.  Names ending in ``.yaml`` are passed to load_objects(); names
    ending in ``.sql`` are executed as a script on a raw DB-API connection.
    Blank lines and names with any other ending are ignored.  The Session
    is committed once all files have been processed.

    Args:
        command: Unused here.
        conf: Configuration object providing ``global_conf``,
            ``local_conf`` and item access for ``'init_data'``.
        vars: Unused here.
    """
    # Don't reload the app if it was loaded under the testing environment
    if not pylons.test.pylonsapp:
        load_environment(conf.global_conf, conf.local_conf)

    # Create the tables if they don't already exist
    Base.metadata.create_all(bind=Session.bind)

    for fn in conf['init_data'].splitlines():
        fn = fn.strip()
        if not fn:
            continue
        if fn.endswith('.yaml'):
            load_objects(fn)
        elif fn.endswith('.sql'):
            log_.info("Executing SQL statements in %s", fn)
            # Read the script first so an unreadable file never checks out
            # a connection, and so the file handle is closed promptly.
            with open(fn) as script:
                sql = script.read()
            db = Session.bind.raw_connection()
            try:
                # NOTE(review): executescript() is a sqlite3 connection
                # method; this branch presumably only works when the
                # engine is backed by SQLite -- confirm.
                db.executescript(sql)
            finally:
                # Release the raw connection instead of leaving it to be
                # reclaimed whenever the object happens to be collected.
                db.close()

    Session.commit()
def load_objects(fn):
    """Load the model objects described in the YAML file *fn* into the Session.

    The document must be a list of single-entry mappings.  Each entry maps
    a model name to either one object description or a list of them; the
    descriptions are converted with make_object() and added to the
    Session.  Entries keyed ``__aliases__`` are skipped: they exist only
    so the file can define YAML anchors.  The Session is flushed once at
    the end; committing is left to the caller.

    Args:
        fn: Path of the YAML file to read.

    Raises:
        AssertionError: the document is not a list of single-entry mappings.
    """
    log_.info("Reading objects in %s", fn)
    # NOTE(security): ORMLoader derives from yaml.Loader, which can build
    # arbitrary Python objects -- only load trusted fixture files.
    with open(fn) as stream:
        data = yaml.load(stream, Loader=ORMLoader)

    # Shared memo so a mapping aliased in several places becomes one object.
    objects = {}
    assert isinstance(data, list), \
        'expected a list of mappings, not a %r' % type(data)
    for entry in data:
        assert isinstance(entry, dict), \
            'expected a single-entry mapping, not a %r' % type(entry)
        assert len(entry) == 1, \
            'expected a single-entry mapping'
        cls, spec = list(entry.items())[0]

        # __aliases__ is a place where aliases can be set up
        if cls == '__aliases__':
            continue

        if isinstance(spec, list):
            new_objects = [make_object(cls, item, objects) for item in spec]
        else:
            new_objects = [make_object(cls, spec, objects)]

        # Snapshot the pending set so objects pulled in indirectly by the
        # adds below can be told apart from the ones added explicitly.
        session_new_snapshot = set(Session.new)
        for obj in new_objects:
            Session.add(obj)
        additional_new_objects = (
            set(Session.new) - session_new_snapshot - set(new_objects))

        log_.info(" Loaded %s %s objects", len(new_objects), cls)
        # Count per class explicitly: grouping an unordered set only merges
        # neighbouring items, so one class could be reported several times.
        counts = {}
        for obj in additional_new_objects:
            counts[type(obj)] = counts.get(type(obj), 0) + 1
        for extra_cls, count in sorted(counts.items(),
                                       key=lambda item: item[0].__name__):
            log_.info(" and %s %s objects", count, extra_cls.__name__)

    Session.flush()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment