Last active
October 21, 2024 14:30
-
-
Save alanhamlett/6604662 to your computer and use it in GitHub Desktop.
Serialize SQLAlchemy Model to dictionary (for JSON output) and update Model from dictionary attributes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime
import uuid

import wtforms_json
from flask import current_app as app
from flask import request, json, jsonify, abort
try:
    from flask_sqlalchemy import SQLAlchemy
except ImportError:  # legacy Flask (< 1.0) only exposes the flask.ext shim
    from flask.ext.sqlalchemy import SQLAlchemy
from sqlalchemy import not_
from sqlalchemy.dialects.postgresql import UUID
from wtforms import Form
from wtforms.fields import FormField, FieldList, StringField
from wtforms.validators import Length
# Shared Flask-SQLAlchemy handle used by every model in this module.
# NOTE(review): ``app`` is ``flask.current_app`` (a context-local proxy), so
# this line and the ``@app.route`` decorators below presumably only work when
# the module is imported inside an active application context -- confirm.
db = SQLAlchemy(app)
# Enable the WTForms-JSON extensions (``Form.from_json`` / ``form.patch_data``)
# that update_user relies on.
wtforms_json.init()
class Model(db.Model):
    """Base SQLAlchemy model with automatic (de)serialization.

    ``set_columns`` updates columns and nested relationships from a dict
    (for example validated JSON input) and records what changed;
    ``to_dict`` serializes columns, relationships and JSON-serializable
    public attributes.

    Optional subclass attributes:

    * ``default_fields`` -- names ``to_dict`` includes by default.
    * ``hidden_fields`` -- names ``to_dict`` never includes and
      ``set_columns`` never writes.
    * ``readonly_fields`` -- names ``set_columns`` never writes.

    The constructor bypasses both write restrictions.

    Usage::

        class User(Model):
            id = db.Column(db.Integer(), primary_key=True)
            email = db.Column(db.String(), index=True)
            name = db.Column(db.String())
            password = db.Column(db.String())
            posts = db.relationship('Post', backref='user', lazy='dynamic')
            ...
            default_fields = ['email', 'name']
            hidden_fields = ['password']
            readonly_fields = ['email', 'password']

        class Post(Model):
            id = db.Column(db.Integer(), primary_key=True)
            user_id = db.Column(db.String(), db.ForeignKey('user.id'), nullable=False)
            title = db.Column(db.String())
            ...
            default_fields = ['title']
            readonly_fields = ['user_id']

        model = User(email='john@localhost')
        db.session.add(model)
        db.session.commit()

        # update name and create a new post
        validated_input = {'name': 'John', 'posts': [{'title': 'My First Post'}]}
        model.set_columns(**validated_input)
        db.session.commit()

        print(model.to_dict(show=['password', 'posts']))
        # {'email': 'john@localhost', 'posts': [{'id': 1, 'title': 'My First Post'}],
        #  'name': 'John', 'id': 1}
    """
    __abstract__ = True

    # Names set_columns() never writes (unless forced by __init__), on top of
    # each subclass's readonly_fields and hidden_fields.
    _protected_columns = (
        'id',
        'created',
        'updated',
        'modified',
        'created_at',
        'updated_at',
        'modified_at',
    )

    # Timestamp columns set_columns() stamps with the current UTC time when
    # the model's table has them.
    _touched_columns = ('modified', 'updated', 'modified_at', 'updated_at')

    # Stores changes made by the most recent set_columns() call; read it via
    # ``model.changes``. NOTE: this class-level dict is shared by all
    # instances until set_columns()/reset_changes() assigns a per-instance
    # one, so callers should not mutate ``changes`` in place.
    _changes = {}

    def __init__(self, **kwargs):
        """Create an instance, assigning every given column/relationship.

        Construction is forced: readonly and hidden fields may be set here.
        """
        kwargs['_force'] = True
        self._set_columns(**kwargs)

    def _set_columns(self, **kwargs):
        """Assign columns and relationships from ``kwargs``; return the changes.

        :param kwargs: column or relationship names mapped to new values. A
            truthy ``_force`` entry lifts the readonly/hidden restrictions.
        :returns: dict mapping each changed column to ``{'old', 'new'}``, each
            changed scalar relationship to its nested change dict (or
            ``{'old', 'new'}``), and each list relationship to a list of
            per-item change dicts.
        """
        force = kwargs.get('_force')
        # Build a fresh list rather than extending self.readonly_fields in
        # place, which would mutate the class-level list and grow it on
        # every call.
        readonly = list(getattr(self, 'readonly_fields', []))
        readonly.extend(getattr(self, 'hidden_fields', []))
        readonly.extend(self._protected_columns)

        def writable(name):
            return name in kwargs and (force or name not in readonly)

        changes = {}
        for key in self.__table__.columns.keys():
            if not writable(key):
                continue
            old, new = getattr(self, key), kwargs[key]
            if old != new:
                changes[key] = {'old': old, 'new': new}
                setattr(self, key, new)

        relationships = self.__mapper__.relationships
        for rel in relationships.keys():
            if not writable(rel):
                continue
            prop = relationships[rel]
            if prop.uselist:
                self._set_list_relationship(rel, prop, kwargs[rel], changes)
                continue
            val = getattr(self, rel)
            if prop.query_class is not None:
                # Update the existing related object in place; when there is
                # none, nothing is created.
                if val is not None:
                    nested = val.set_columns(**kwargs[rel])
                    if nested:
                        changes[rel] = nested
            elif val != kwargs[rel]:
                setattr(self, rel, kwargs[rel])
                changes[rel] = {'old': val, 'new': kwargs[rel]}
        return changes

    def _set_list_relationship(self, rel, prop, items, changes):
        """Synchronize the one-to-many relationship ``rel`` with ``items``.

        Each item whose ``id`` already belongs to the relationship is updated
        in place; every other item becomes a new related row; related rows
        whose id is not among those kept or created are deleted. Per-item
        change dicts are appended to ``changes[rel]``.

        The relationship is presumably ``lazy='dynamic'`` (as ``User.posts``
        is): its attribute must support ``filter_by``/``filter``/``append``.
        """
        query = getattr(self, rel)
        cls = prop.mapper.class_
        valid_ids = []
        for item in items:
            obj = query.filter_by(id=item['id']).first() if 'id' in item else None
            if obj is not None:
                col_changes = obj.set_columns(**item)
                item_id = item['id']
            else:
                obj = cls()
                col_changes = obj.set_columns(**item)
                query.append(obj)
                # Flush so the new row has its primary key before we record it.
                db.session.flush()
                item_id = obj.id
            if col_changes:
                col_changes['id'] = str(item_id)
                changes.setdefault(rel, []).append(col_changes)
            valid_ids.append(str(item_id))
        # Delete related rows that were not in ``items``.
        for stale in query.filter(not_(cls.id.in_(valid_ids))).all():
            changes.setdefault(rel, []).append({
                'id': str(stale.id),
                'deleted': True,
            })
            db.session.delete(stale)

    def set_columns(self, **kwargs):
        """Apply ``kwargs`` via _set_columns, stamp timestamps, return changes.

        Every timestamp column in ``_touched_columns`` that exists on the
        table is set to the same ``datetime.utcnow()`` value.
        """
        self._changes = self._set_columns(**kwargs)
        now = datetime.utcnow()
        table_columns = self.__table__.columns
        for name in self._touched_columns:
            if name in table_columns:
                setattr(self, name, now)
        return self._changes

    @property
    def changes(self):
        """Changes recorded by the most recent set_columns() call."""
        return self._changes

    def reset_changes(self):
        """Forget the recorded changes."""
        self._changes = {}

    def to_dict(self, show=None, hide=None, path=None, show_all=None):
        """Return a dictionary representation of this model.

        :param show: field names to include besides ``id`` and
            ``default_fields``; nested fields are dotted relative to this
            model (e.g. ``'posts.text'``).
        :param hide: field names to exclude, same format as ``show``.
        :param path: dotted path of this model from the root object, used when
            recursing. Defaults to the lower-cased table name.
        :param show_all: include every column, relationship and
            JSON-serializable public attribute that is not hidden.
        :returns: dict of field name to value; related objects are serialized
            recursively with their own ``to_dict``.

        NOTE(review): with ``show_all`` (or default_fields that reference each
        other across a relationship and its backref) serialization presumably
        recurses without bound, because each level's path is longer and the
        ``hide`` entries never match again -- confirm before relying on it.
        """
        if not path:
            path = self.__tablename__.lower()
        # show/hide entries are stored fully qualified from the root model
        # (e.g. 'user.posts.title'); the root is the first path segment.
        root = path.split('.', 1)[0]

        def prepend_path(item):
            """Qualify ``item`` with ``path`` unless it is already rooted."""
            item = item.lower()
            # Entries already qualified by an outer call start at the root;
            # re-prefixing them would make nested show/hide unmatchable.
            if not item or item.split('.', 1)[0] == root:
                return item
            if item[0] != '.':
                item = '.' + item
            return path + item

        # Fresh lists: the caller's show/hide are never modified.
        show = [prepend_path(x) for x in (show or [])]
        hide = [prepend_path(x) for x in (hide or [])]
        hidden = getattr(self, 'hidden_fields', [])
        default = getattr(self, 'default_fields', [])

        ret_data = {}
        columns = self.__table__.columns.keys()
        relationships = self.__mapper__.relationships.keys()

        for key in columns:
            check = '%s.%s' % (path, key)
            if check in hide or key in hidden:
                continue
            if show_all or key == 'id' or check in show or key in default:
                ret_data[key] = getattr(self, key)

        for key in relationships:
            check = '%s.%s' % (path, key)
            if check in hide or key in hidden:
                continue
            if not (show_all or check in show or key in default):
                continue
            hide.append(check)
            prop = self.__mapper__.relationships[key]
            child_path = '%s.%s' % (path, key.lower())
            value = getattr(self, key)
            if prop.uselist:
                ret_data[key] = [
                    item.to_dict(
                        show=show,
                        hide=hide,
                        path=child_path,
                        show_all=show_all,
                    )
                    for item in value
                ]
            elif prop.query_class is not None:
                # A missing related object serializes as None instead of
                # failing on None.to_dict().
                ret_data[key] = None if value is None else value.to_dict(
                    show=show,
                    hide=hide,
                    path=child_path,
                    show_all=show_all,
                )
            else:
                ret_data[key] = value

        skip = set(columns) | set(relationships)
        for key in set(dir(self)) - skip:
            # Private names are skipped, and so is Flask-SQLAlchemy's
            # ``query`` property, which is not model data.
            if key.startswith('_') or key == 'query':
                continue
            check = '%s.%s' % (path, key)
            if check in hide or key in hidden:
                continue
            if show_all or check in show or key in default:
                val = getattr(self, key)
                try:
                    # Round-trip through JSON so only plain JSON data is kept.
                    ret_data[key] = json.loads(json.dumps(val))
                except (TypeError, ValueError):
                    # Best effort: values that cannot be JSON-encoded
                    # (methods, query objects, circular data) are omitted.
                    pass
        return ret_data
class User(Model):
    """Example model: a user who can own many posts."""
    # UUID primary key; SQLAlchemy fills it from uuid.uuid4 when none is given.
    id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    first_name = db.Column(db.String(120))
    last_name = db.Column(db.String(120))
    # lazy='dynamic' makes ``user.posts`` a query-like object; Model's
    # list-relationship handling calls filter_by/filter/append on it.
    # backref='user' adds a ``Post.user`` relationship.
    posts = db.relationship('Post', backref='user', lazy='dynamic')
class Post(Model):
    """Example model: a post belonging to one User (see ``User.posts``)."""
    # UUID primary key; SQLAlchemy fills it from uuid.uuid4 when none is given.
    id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Owning user. NOTE(review): 'user.id' presumably relies on
    # Flask-SQLAlchemy naming User's table 'user' automatically -- confirm.
    user_id = db.Column(UUID(as_uuid=True), db.ForeignKey('user.id'), nullable=False)
    title = db.Column(db.String(200))
    text = db.Column(db.String())
class PostForm(Form):
    """Validates one element of ``UserForm.posts``.

    NOTE(review): there is no ``id`` field, so posts submitted through this
    form presumably never carry ids; Model._set_columns then treats every
    submitted post as new and deletes the user's existing posts that were not
    resubmitted with an id -- confirm this is intended.
    """
    # Max length matches Post.title's String(200) column.
    title = StringField(validators=[Length(max=200)])
    text = StringField()
class UserForm(Form):
    """Validates the JSON body of ``PUT /users/<user_id>`` (see update_user)."""
    # Length limits match User's String(120) columns.
    first_name = StringField(validators=[Length(max=120)])
    last_name = StringField(validators=[Length(max=120)])
    # Each element of the submitted ``posts`` list is validated by PostForm.
    posts = FieldList(FormField(PostForm))
def requested_columns(request):
    """Return the field names requested through the ``?show=a,b,c`` query arg.

    :param request: the current request; only ``request.args.get`` is used.
    :returns: list of whitespace-stripped, non-empty names; ``[]`` when the
        ``show`` argument is missing or empty.
    """
    show = request.args.get('show', None)
    if not show:
        return []
    # Strip spaces ("a, b") and drop empty entries ("a,,b", trailing comma):
    # such entries could never match a field path in Model.to_dict.
    return [name.strip() for name in show.split(',') if name.strip()]
def _get_user_or_404(user_id):
    """Return the User whose primary key is ``user_id`` or abort with 404.

    ``user_id`` comes from the URL as an arbitrary string. A value that is not
    a valid UUID can never match the UUID primary key, so it is answered with
    404 here. Otherwise it would be sent to the database, which would
    presumably reject it as invalid uuid input and turn the request into a 500.
    """
    try:
        user_uuid = uuid.UUID(user_id)
    except ValueError:
        abort(404)
    user = User.query.filter_by(id=user_uuid).first()
    if user is None:
        abort(404)
    return user


@app.route('/users/<string:user_id>', methods=['GET'])
def read_user(user_id):
    """Return one user as JSON (``{"data": {...}}``), honoring ``?show=``."""
    user = _get_user_or_404(user_id)
    show = requested_columns(request)
    return jsonify(data=user.to_dict(show=show))


@app.route('/users/<string:user_id>', methods=['PUT'])
def update_user(user_id):
    """Apply a JSON update to a user and return the updated user as JSON.

    Responds 404 for an unknown or malformed id, and 400 when the body is not
    a JSON object or fails UserForm validation.
    """
    user = _get_user_or_404(user_id)
    # force=True parses the body as JSON regardless of its Content-Type.
    input_data = request.get_json(force=True)
    if not isinstance(input_data, dict):
        return jsonify(error='Request data must be a JSON Object'), 400
    # validate json user input using WTForms-JSON
    form = UserForm.from_json(input_data)
    if not form.validate():
        return jsonify(errors=form.errors), 400
    # patch_data presumably holds only the fields the client actually sent
    # (WTForms-JSON), so omitted fields are left untouched.
    user.set_columns(**form.patch_data)
    db.session.commit()
    show = requested_columns(request)
    return jsonify(data=user.to_dict(show=show))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, the to_dict method should include a check that skips the key 'query', because db.Model defines it, and when it is accessed via getattr it makes a call to the database to obtain all elements, like
Model.query.all().
If a User has 10 Posts, every to_dict call on a Post child will query all elements.
Anyway, this to_dict has been very helpful.
Cheers