Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Serialize SQLAlchemy Model to dictionary (for JSON output) and update Model from dictionary attributes.
import uuid
from datetime import datetime

import wtforms_json
from flask import current_app as app
from flask import request, json, jsonify, abort
from flask.ext.sqlalchemy import SQLAlchemy
from sqlalchemy import not_
from sqlalchemy.dialects.postgresql import UUID
from wtforms import Form
from wtforms.fields import FormField, FieldList
from wtforms.fields import StringField
from wtforms.validators import Length
# Flask-SQLAlchemy handle; ``db.Model``, ``db.Column``, ``db.relationship``
# and ``db.session`` used throughout this file all come from this object.
# NOTE(review): ``app`` is Flask's ``current_app`` proxy, so this presumably
# requires an active application context at import time -- confirm.
db = SQLAlchemy(app)
# Initialise WTForms-JSON; the views below rely on ``Form.from_json`` and
# ``form.patch_data``.
wtforms_json.init()
class Model(db.Model):
    """Abstract base model with dict serialization and dict-driven updates.

    Subclasses may declare three optional class-level lists of attribute
    names:

    * ``default_fields`` -- always included by :meth:`to_dict`.
    * ``hidden_fields`` -- never included by :meth:`to_dict` and not
      writable through :meth:`set_columns`.
    * ``readonly_fields`` -- not writable through :meth:`set_columns`.

    To-many relationships are updated by calling ``filter_by()``,
    ``filter()`` and ``append()`` on the relationship attribute itself, so
    they must be declared with ``lazy='dynamic'`` as in the example.

    Usage::

        class User(Model):
            id = db.Column(db.Integer(), primary_key=True)
            email = db.Column(db.String(), index=True)
            name = db.Column(db.String())
            password = db.Column(db.String())
            posts = db.relationship('Post', backref='user', lazy='dynamic')
            ...
            default_fields = ['email', 'name']
            hidden_fields = ['password']
            readonly_fields = ['email', 'password']

        class Post(Model):
            id = db.Column(db.Integer(), primary_key=True)
            user_id = db.Column(db.String(), db.ForeignKey('user.id'),
                                nullable=False)
            title = db.Column(db.String())
            ...
            default_fields = ['title']
            readonly_fields = ['user_id']

        model = User(email='john@localhost')
        db.session.add(model)
        db.session.commit()

        # update name and create a new post
        validated_input = {'name': 'John',
                           'posts': [{'title': 'My First Post'}]}
        model.set_columns(**validated_input)
        db.session.commit()

        print(model.to_dict(show=['password', 'posts']))
        # {u'email': u'john@localhost',
        #  u'posts': [{u'id': 1, u'title': u'My First Post'}],
        #  u'name': u'John', u'id': 1}
    """

    __abstract__ = True

    # Stores changes made to this model's attributes. Can be retrieved
    # with model.changes.  It is only ever rebound per instance (never
    # mutated in place), so the shared class-level default stays empty.
    _changes = {}

    # Columns that are read-only for every model unless ``_force`` is set.
    _protected_columns = (
        'id',
        'created',
        'updated',
        'modified',
        'created_at',
        'updated_at',
        'modified_at',
    )

    # Columns refreshed with the current UTC time by ``set_columns``.
    _timestamp_columns = ('modified', 'updated', 'modified_at', 'updated_at')

    def __init__(self, **kwargs):
        """Create an instance, assigning every given column/relationship.

        Construction forces the assignment, so read-only and hidden
        fields can be given an initial value here.
        """
        kwargs['_force'] = True
        self._set_columns(**kwargs)

    def _set_columns(self, **kwargs):
        """Apply ``kwargs`` to columns and relationships; return the changes.

        :param kwargs: attribute name -> new value.  For a to-many
            relationship the value is a list of dicts, for a scalar
            relationship a dict (or a plain value when the relationship
            has no ``query_class``).  A truthy ``_force`` entry bypasses
            all read-only protection, so never pass unvalidated client
            input straight into this method.
        :returns: dict describing what changed: ``{'old': .., 'new': ..}``
            per column, and nested change dicts per relationship.
        """
        force = kwargs.get('_force')

        # Build a fresh set on every call.  Extending ``readonly_fields``
        # in place would permanently grow the subclass's class attribute.
        readonly = set(self._protected_columns)
        readonly.update(getattr(self, 'readonly_fields', ()))
        readonly.update(getattr(self, 'hidden_fields', ()))

        changes = {}

        for key in self.__table__.columns.keys():
            if key not in kwargs or (key in readonly and not force):
                continue
            old = getattr(self, key)
            new = kwargs[key]
            if old != new:
                changes[key] = {'old': old, 'new': new}
                setattr(self, key, new)

        for rel in self.__mapper__.relationships.keys():
            if rel not in kwargs or (rel in readonly and not force):
                continue
            prop = self.__mapper__.relationships[rel]
            if prop.uselist:
                records = self._sync_collection(rel, kwargs[rel])
                if records:
                    changes.setdefault(rel, []).extend(records)
                continue

            current = getattr(self, rel)
            if prop.query_class is not None:
                # Related model: update it in place when one is attached.
                if current is not None:
                    col_changes = current.set_columns(**kwargs[rel])
                    if col_changes:
                        changes[rel] = col_changes
            elif current != kwargs[rel]:
                setattr(self, rel, kwargs[rel])
                changes[rel] = {'old': current, 'new': kwargs[rel]}

        return changes

    def _sync_collection(self, rel, items):
        """Make the to-many relationship ``rel`` match ``items``.

        Rows of the relationship whose ``id`` appears in ``items`` are
        updated, items without a matching row are created and appended,
        and every remaining row of the relationship is deleted.

        :param rel: relationship attribute name.
        :param items: iterable of dicts, one per related row.
        :returns: list of change dicts, each carrying the row ``id``.
        """
        # ``mapper.class_`` is the related class however the relationship
        # target was declared (string name or class object).
        related_cls = self.__mapper__.relationships[rel].mapper.class_
        query = getattr(self, rel)
        records = []
        valid_ids = []

        for item in items:
            existing = None
            if 'id' in item:
                # One lookup, scoped to this relationship, both checks
                # ownership and fetches the row.
                existing = query.filter_by(id=item['id']).first()

            if existing is not None:
                col_changes = existing.set_columns(**item)
                row_id = str(item['id'])
            else:
                created = related_cls()
                col_changes = created.set_columns(**item)
                query.append(created)
                # Flush so the new row has its primary key assigned.
                db.session.flush()
                row_id = str(created.id)

            if col_changes:
                col_changes['id'] = row_id
                records.append(col_changes)
            valid_ids.append(row_id)

        # delete related rows that were not in items
        for stale in query.filter(not_(related_cls.id.in_(valid_ids))).all():
            records.append({
                'id': str(stale.id),
                'deleted': True,
            })
            db.session.delete(stale)

        return records

    def set_columns(self, **kwargs):
        """Update this model from ``kwargs`` and bump its timestamps.

        Read-only, hidden and bookkeeping columns are skipped unless a
        truthy ``_force`` keyword is supplied.  Whichever of the columns
        ``modified``, ``updated``, ``modified_at`` and ``updated_at``
        exist on the table are set to the current UTC time, whether or
        not anything changed.

        :returns: the change dict, also available as :attr:`changes`.
        """
        self._changes = self._set_columns(**kwargs)
        present = [name for name in self._timestamp_columns
                   if name in self.__table__.columns]
        if present:
            now = datetime.utcnow()
            for name in present:
                setattr(self, name, now)
        return self._changes

    @property
    def changes(self):
        """Changes recorded by the most recent :meth:`set_columns` call."""
        return self._changes

    def reset_changes(self):
        """Forget the recorded changes."""
        self._changes = {}

    def to_dict(self, show=None, hide=None, path=None, show_all=None):
        """Return a dictionary representation of this model.

        :param show: extra attribute paths to include, e.g. ``'posts'``
            or ``'user.posts.title'``; a path not starting with the table
            name is taken relative to it.
        :param hide: attribute paths to leave out, same format as ``show``.
        :param path: dotted path of this object inside the serialized
            tree; callers normally omit it, it is supplied on recursion.
        :param show_all: when truthy, include everything not hidden.
        :returns: dict with ``id``, ``default_fields`` and the requested
            columns, relationships and JSON-serializable attributes.
        """
        if not show:
            show = []
        if not hide:
            hide = []

        hidden = getattr(self, 'hidden_fields', [])
        default = getattr(self, 'default_fields', [])

        ret_data = {}

        if not path:
            path = self.__tablename__.lower()

            def prepend_path(item):
                item = item.lower()
                if item.split('.', 1)[0] == path:
                    return item
                if len(item) == 0:
                    return item
                if item[0] != '.':
                    item = '.%s' % item
                item = '%s%s' % (path, item)
                return item

            # New lists, so the caller's arguments are left untouched.
            # ``hide`` is shared with (and extended by) nested calls.
            show = [prepend_path(x) for x in show]
            hide = [prepend_path(x) for x in hide]

        columns = self.__table__.columns.keys()
        relationships = self.__mapper__.relationships.keys()
        properties = dir(self)

        for key in columns:
            check = '%s.%s' % (path, key)
            if check in hide or key in hidden:
                continue
            if show_all or key == 'id' or check in show or key in default:
                ret_data[key] = getattr(self, key)

        for key in relationships:
            check = '%s.%s' % (path, key)
            if check in hide or key in hidden:
                continue
            if not (show_all or check in show or key in default):
                continue
            hide.append(check)
            prop = self.__mapper__.relationships[key]
            nested_path = '%s.%s' % (path, key.lower())
            if prop.uselist:
                ret_data[key] = [
                    item.to_dict(
                        show=show,
                        hide=hide,
                        path=nested_path,
                        show_all=show_all,
                    )
                    for item in getattr(self, key)
                ]
                continue
            related = getattr(self, key)
            if prop.query_class is not None and related is not None:
                ret_data[key] = related.to_dict(
                    show=show,
                    hide=hide,
                    path=nested_path,
                    show_all=show_all,
                )
            else:
                # Either a plain value or an unset (None) relationship.
                ret_data[key] = related

        for key in set(properties) - set(columns) - set(relationships):
            # ``query`` is skipped so that the model's query attribute is
            # never read while scanning for extra attributes.
            if key.startswith('_') or key == 'query':
                continue
            check = '%s.%s' % (path, key)
            if check in hide or key in hidden:
                continue
            if show_all or check in show or key in default:
                val = getattr(self, key)
                try:
                    ret_data[key] = json.loads(json.dumps(val))
                except Exception:
                    # Best effort: values that cannot be JSON-encoded
                    # (methods, metadata, ...) are left out.
                    pass

        return ret_data
class User(Model):
    """Example user model with a UUID primary key and a ``posts`` collection."""

    id = db.Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    first_name = db.Column(db.String(120))
    last_name = db.Column(db.String(120))
    # Dynamic relationship to Post; the backref adds ``Post.user``.
    posts = db.relationship('Post', backref='user', lazy='dynamic')
class Post(Model):
    """Example post model; each row references its user via ``user_id``."""

    id = db.Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # Foreign key to ``user.id``; a post cannot exist without a user.
    user_id = db.Column(
        UUID(as_uuid=True),
        db.ForeignKey('user.id'),
        nullable=False,
    )
    title = db.Column(db.String(200))
    text = db.Column(db.String())
class PostForm(Form):
    """Input validation for a single post."""

    # Same length limit as the ``Post.title`` column.
    title = StringField(
        validators=[Length(max=200)],
    )
    text = StringField()
class UserForm(Form):
    """Input validation for a user, including a nested list of posts."""

    # Same length limits as the corresponding ``User`` columns.
    first_name = StringField(
        validators=[Length(max=120)],
    )
    last_name = StringField(
        validators=[Length(max=120)],
    )
    # Each list entry is validated with PostForm.
    posts = FieldList(FormField(PostForm))
def requested_columns(request):
    """Return the attribute names listed in the ``show`` query parameter.

    The parameter is a comma-separated list such as ``?show=email,posts``.
    Whitespace around each name is stripped and empty entries (for
    example from a trailing comma) are dropped, since neither can ever
    match an attribute path.

    :param request: object exposing a mapping-like ``args`` attribute.
    :returns: list of names; empty when the parameter is absent or blank.
    """
    raw = request.args.get('show', None)
    if not raw:
        return []
    names = (part.strip() for part in raw.split(','))
    return [name for name in names if name]
@app.route('/users/<string:user_id>', methods=['GET'])
def read_user(user_id):
    """Return one user as JSON.

    Responds with 404 when ``user_id`` is not a well-formed UUID or no
    such user exists.  Extra attributes can be requested through the
    ``show`` query parameter.
    """
    # Parse the id first: the primary key is a UUID column, and passing an
    # arbitrary string into the query would fail in the database layer
    # instead of producing a clean 404.
    try:
        user_uuid = uuid.UUID(user_id)
    except ValueError:
        abort(404)

    # get user from database
    user = User.query.filter_by(id=user_uuid).first()
    if user is None:
        abort(404)

    # return user as json
    show = requested_columns(request)
    return jsonify(data=user.to_dict(show=show))
@app.route('/users/<string:user_id>', methods=['PUT'])
def update_user(user_id):
    """Update one user from a JSON object and return the result as JSON.

    Responds with 404 when ``user_id`` is not a well-formed UUID or no
    such user exists, and with 400 when the body is not a JSON object or
    fails form validation.
    """
    # Parse the id first: the primary key is a UUID column, and passing an
    # arbitrary string into the query would fail in the database layer
    # instead of producing a clean 404.
    try:
        user_uuid = uuid.UUID(user_id)
    except ValueError:
        abort(404)

    # get user from database
    user = User.query.filter_by(id=user_uuid).first()
    if user is None:
        abort(404)

    # force=True parses the body as JSON regardless of Content-Type
    input_data = request.get_json(force=True)
    if not isinstance(input_data, dict):
        return jsonify(error='Request data must be a JSON Object'), 400

    # validate json user input using WTForms-JSON
    form = UserForm.from_json(input_data)
    if not form.validate():
        return jsonify(errors=form.errors), 400

    # update user in database
    user.set_columns(**form.patch_data)
    db.session.commit()

    # return user as json
    show = requested_columns(request)
    return jsonify(data=user.to_dict(show=show))
@paramire

This comment has been minimized.

Copy link

@paramire paramire commented Apr 5, 2017

Hi, the to_dict method should skip the key 'query'. db.Model defines that attribute, and when it is read with getattr it makes a call to the database that fetches all elements, like Model.query.all().
If a User has 10 Posts, every to_dict call on a Post child will query all elements.

for key in list(set(properties) - set(columns) - set(relationships)):
   if key.startswith('_'):
       continue
   check = '%s.%s' % (path, key)
   if check in hide or key in hidden or key == 'query':
       continue
   if show_all or check in show or key in default:
       val = getattr(self, key)
       try:
           ret_data[key] = json.loads(json.dumps(val))
       except:
           pass

Anyway, this to_dict has been very helpful.
Cheers

@piotr-dobrogost

This comment has been minimized.

Copy link

@piotr-dobrogost piotr-dobrogost commented Jan 3, 2018

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment