Skip to content

Instantly share code, notes, and snippets.

@alanhamlett
Last active January 24, 2023 21:03
Show Gist options
  • Save alanhamlett/6604662 to your computer and use it in GitHub Desktop.
Save alanhamlett/6604662 to your computer and use it in GitHub Desktop.
Serialize SQLAlchemy Model to dictionary (for JSON output) and update Model from dictionary attributes.
import uuid
from datetime import datetime

import wtforms_json
from sqlalchemy import not_
from sqlalchemy.dialects.postgresql import UUID
from wtforms import Form
from wtforms.fields import FormField, FieldList, StringField
from wtforms.validators import Length
from flask import current_app as app
from flask import request, json, jsonify, abort
try:
    from flask_sqlalchemy import SQLAlchemy
except ImportError:  # legacy Flask < 1.0 import path
    from flask.ext.sqlalchemy import SQLAlchemy
# Patch WTForms so forms gain from_json() / patch_data (used by update_user).
wtforms_json.init()

# Flask-SQLAlchemy extension bound to ``app``.
# NOTE(review): ``app`` is flask.current_app, which only resolves inside an
# application context -- confirm this module is imported within one.
db = SQLAlchemy(app)
class Model(db.Model):
    """Base SQLAlchemy model with automatic (de)serialization.

    Subclasses get:

    * ``to_dict()`` -- serialize columns, nested relationships and
      JSON-serializable public attributes to a ``dict`` (for JSON output);
    * ``set_columns(**kwargs)`` -- update columns and nested relationships
      from a ``dict`` and record what changed (see ``changes``).

    Optional class attributes control both directions:

    * ``default_fields`` -- names ``to_dict()`` includes without being asked;
    * ``hidden_fields`` -- names never serialized and never writable via
      ``set_columns()``;
    * ``readonly_fields`` -- names not writable via ``set_columns()``.

    ``id``, ``created``, ``updated``, ``modified``, ``created_at``,
    ``updated_at`` and ``modified_at`` are always read-only for
    ``set_columns()``.

    Usage::

        class User(Model):
            id = db.Column(db.Integer(), primary_key=True)
            email = db.Column(db.String(), index=True)
            name = db.Column(db.String())
            password = db.Column(db.String())
            posts = db.relationship('Post', backref='user', lazy='dynamic')
            default_fields = ['email', 'name']
            hidden_fields = ['password']
            readonly_fields = ['email', 'password']

        class Post(Model):
            id = db.Column(db.Integer(), primary_key=True)
            user_id = db.Column(db.Integer(), db.ForeignKey('user.id'), nullable=False)
            title = db.Column(db.String())
            default_fields = ['title']
            readonly_fields = ['user_id']

        model = User(email='john@localhost')
        db.session.add(model)
        db.session.commit()

        # update name and create a new post
        validated_input = {'name': 'John', 'posts': [{'title': 'My First Post'}]}
        model.set_columns(**validated_input)
        db.session.commit()

        # 'password' is in hidden_fields, so it is omitted even when shown
        print(model.to_dict(show=['password', 'posts']))
        # {'id': 1, 'email': 'john@localhost', 'name': 'John',
        #  'posts': [{'id': 1, 'title': 'My First Post'}]}
    """

    __abstract__ = True

    # Names set_columns() never writes, in addition to readonly_fields and
    # hidden_fields (the constructor bypasses this via ``_force``).
    _always_readonly = (
        'id',
        'created',
        'updated',
        'modified',
        'created_at',
        'updated_at',
        'modified_at',
    )

    # Timestamp columns that set_columns() refreshes when the table has them.
    _touch_fields = ('modified', 'updated', 'modified_at', 'updated_at')

    # Changes made by the last set_columns() call; read via ``model.changes``.
    # This class-level default is shared, but it is only ever rebound on the
    # instance, never mutated in place.
    _changes = {}

    def __init__(self, **kwargs):
        """Create a model, assigning every given column and relationship.

        Unlike ``set_columns()``, read-only and hidden fields are writable
        here (``_force``), so e.g. a primary key may be supplied.
        """
        kwargs['_force'] = True
        self._set_columns(**kwargs)

    def _readonly_names(self):
        """Return a new set of names that set_columns() must not write."""
        # Build a fresh set: extending self.readonly_fields in place would
        # permanently grow the list shared by every instance of the class.
        readonly = set(getattr(self, 'readonly_fields', ()))
        readonly.update(getattr(self, 'hidden_fields', ()))
        readonly.update(self._always_readonly)
        return readonly

    def _set_columns(self, **kwargs):
        """Assign columns and relationships from ``kwargs``; return the changes.

        :param kwargs: values keyed by column / relationship name. Unknown
            keys are ignored. ``_force=True`` disables the read-only checks.
        :returns: ``dict`` keyed by changed name. A column or plain scalar
            maps to ``{'old': ..., 'new': ...}``. A single related model
            maps to its own changes dict. A list relationship maps to a
            list of per-row change dicts, each carrying ``'id'`` and, for
            removed rows, ``'deleted': True``.
        """
        force = kwargs.get('_force')
        readonly = set() if force else self._readonly_names()
        relationships = self.__mapper__.relationships
        changes = {}

        for key in self.__table__.columns.keys():
            if key in readonly or key not in kwargs:
                continue
            old, new = getattr(self, key), kwargs[key]
            if old != new:
                changes[key] = {'old': old, 'new': new}
                setattr(self, key, new)

        for rel in relationships.keys():
            if rel in readonly or rel not in kwargs:
                continue
            prop = relationships[rel]
            if prop.uselist:
                rel_changes = self._set_list_relationship(rel, prop, kwargs[rel])
                if rel_changes:
                    changes[rel] = rel_changes
                continue
            val = getattr(self, rel)
            if prop.query_class is not None:
                # Update the existing related model in place (a missing one
                # is not created).
                if val is not None:
                    col_changes = val.set_columns(**kwargs[rel])
                    if col_changes:
                        changes[rel] = col_changes
            elif val != kwargs[rel]:
                setattr(self, rel, kwargs[rel])
                changes[rel] = {'old': val, 'new': kwargs[rel]}
        return changes

    def _set_list_relationship(self, rel, prop, items):
        """Sync list relationship ``rel`` with ``items``; return row changes.

        An item whose ``id`` matches an existing related row updates that
        row. Any other item creates a new row. Related rows whose id is not
        among the kept or created ones are deleted.

        NOTE(review): uses ``.filter_by`` / ``.append`` / ``.filter`` on the
        relationship, so it assumes ``lazy='dynamic'`` -- confirm for others.
        """
        row_changes = []
        valid_ids = []
        query = getattr(self, rel)
        # The mapper's class is correct however the relationship target was
        # declared (string name, class, or callable).
        cls = prop.mapper.class_
        for item in items:
            if 'id' in item and query.filter_by(id=item['id']).limit(1).count() == 1:
                obj = cls.query.filter_by(id=item['id']).first()
                col_changes = obj.set_columns(**item)
                row_id = str(item['id'])
            else:
                obj = cls()
                col_changes = obj.set_columns(**item)
                query.append(obj)
                # flush so the new row is assigned its primary key
                db.session.flush()
                row_id = str(obj.id)
            if col_changes:
                col_changes['id'] = row_id
                row_changes.append(col_changes)
            valid_ids.append(row_id)
        # delete related rows that were not in ``items``
        for obj in query.filter(not_(cls.id.in_(valid_ids))).all():
            row_changes.append({'id': str(obj.id), 'deleted': True})
            db.session.delete(obj)
        return row_changes

    def set_columns(self, **kwargs):
        """Update this model from ``kwargs`` and refresh timestamp columns.

        Read-only and hidden fields are ignored. Each of the ``modified``,
        ``updated``, ``modified_at`` and ``updated_at`` columns that exists
        is set to the same ``datetime.utcnow()`` value.

        :returns: the recorded changes (also available as ``self.changes``).
        """
        self._changes = self._set_columns(**kwargs)
        now = datetime.utcnow()
        table_columns = self.__table__.columns
        for name in self._touch_fields:
            if name in table_columns:
                setattr(self, name, now)
        return self._changes

    @property
    def changes(self):
        """Changes recorded by the last ``set_columns()`` call."""
        return self._changes

    def reset_changes(self):
        """Forget the changes recorded by the last ``set_columns()`` call."""
        self._changes = {}

    def to_dict(self, show=None, hide=None, path=None, show_all=None):
        """Return a dictionary representation of this model.

        :param show: extra names to include, relative to this model
            (``'posts'``, ``'posts.title'``) or prefixed with the root table
            name (``'user.posts.title'``). Matching is case-insensitive. The
            list passed in is not modified.
        :param hide: names to exclude, in the same format as ``show``. Not
            modified.
        :param path: dotted path of this model from the root object. Defaults
            to the lower-cased table name; used internally when nesting.
        :param show_all: include every column, relationship and
            JSON-serializable public attribute that is not hidden.
        :returns: ``dict`` of name to value. ``id`` and ``default_fields``
            are included unless hidden; ``hidden_fields`` never are. Related
            models are serialized through their own ``to_dict``. A missing
            (``None``) single related model is returned as ``None``.
        """
        hidden = getattr(self, 'hidden_fields', [])
        default = getattr(self, 'default_fields', [])
        path = (path or self.__tablename__).lower()
        root = path.split('.', 1)[0]

        def prepend_path(item):
            # Qualify a name with this model's path. Names already rooted at
            # the top-level table (as handed down to nested calls) are kept,
            # so nesting never prefixes them a second time.
            item = item.lower()
            if not item or item.split('.', 1)[0] == root:
                return item
            if not item.startswith('.'):
                item = '.' + item
            return path + item

        # Work on copies so neither the caller's lists nor sibling calls in a
        # list relationship see each other's modifications.
        show = [prepend_path(x) for x in show or ()]
        hide = [prepend_path(x) for x in hide or ()]

        columns = self.__table__.columns.keys()
        relationships = self.__mapper__.relationships
        ret_data = {}

        for key in columns:
            check = ('%s.%s' % (path, key)).lower()
            if check in hide or key in hidden:
                continue
            if show_all or key == 'id' or check in show or key in default:
                ret_data[key] = getattr(self, key)

        for key in relationships.keys():
            check = ('%s.%s' % (path, key)).lower()
            if check in hide or key in hidden:
                continue
            if not (show_all or check in show or key in default):
                continue
            prop = relationships[key]
            child_path = '%s.%s' % (path, key.lower())
            value = getattr(self, key)
            if prop.uselist:
                ret_data[key] = [
                    item.to_dict(show=show, hide=hide, path=child_path, show_all=show_all)
                    for item in value
                ]
            elif prop.query_class is not None and value is not None:
                ret_data[key] = value.to_dict(
                    show=show,
                    hide=hide,
                    path=child_path,
                    show_all=show_all,
                )
            else:
                ret_data[key] = value

        # Remaining public attributes (e.g. @property values) are included
        # only when they survive a JSON round-trip. ``query`` is skipped so
        # the model's query object is never built just to be discarded.
        skip = set(columns).union(relationships.keys())
        for key in dir(self):
            if key in skip or key.startswith('_') or key == 'query':
                continue
            check = ('%s.%s' % (path, key)).lower()
            if check in hide or key in hidden:
                continue
            if show_all or check in show or key in default:
                val = getattr(self, key)
                try:
                    ret_data[key] = json.loads(json.dumps(val))
                except (TypeError, ValueError, OverflowError):
                    # not JSON-serializable (methods, query objects, ...)
                    pass
        return ret_data
class User(Model):
    """A user account that owns any number of ``Post`` rows via ``posts``."""

    # Primary key: a random UUID produced in Python (uuid.uuid4) on insert.
    id = db.Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    first_name = db.Column(db.String(length=120))
    last_name = db.Column(db.String(length=120))
    # lazy='dynamic': ``user.posts`` is a query; each Post gets ``post.user``.
    posts = db.relationship('Post', lazy='dynamic', backref='user')
class Post(Model):
    """A post written by a ``User`` (reachable as ``post.user``)."""

    id = db.Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    # Every post must reference an existing user.
    user_id = db.Column(
        UUID(as_uuid=True),
        db.ForeignKey('user.id'),
        nullable=False,
    )
    title = db.Column(db.String(length=200))
    text = db.Column(db.String())
class PostForm(Form):
    """Validates one entry of the ``posts`` list in a user update body."""

    # At most 200 characters, the size of the Post.title column.
    title = StringField(validators=[Length(max=200)])
    text = StringField()
class UserForm(Form):
    """Validates the JSON body accepted by ``update_user``."""

    # At most 120 characters, the size of the User name columns.
    first_name = StringField(validators=[Length(max=120)])
    last_name = StringField(validators=[Length(max=120)])
    # Every element of ``posts`` is validated with PostForm.
    posts = FieldList(FormField(PostForm))
def requested_columns(request):
    """Return the field names requested via the ``show`` query parameter.

    ``?show=posts, name`` yields ``['posts', 'name']``: the value is split on
    commas, surrounding whitespace is stripped and empty entries are dropped.

    :param request: object with an ``args`` mapping (e.g. ``flask.request``).
    :returns: list of requested names; empty when ``show`` is absent or blank.
    """
    show = request.args.get('show', None)
    if not show:
        return []
    return [name.strip() for name in show.split(',') if name.strip()]
@app.route('/users/<string:user_id>', methods=['GET'])
def read_user(user_id):
    """Return one user as JSON: ``{"data": {...}}``.

    Extra fields can be requested with ``?show=a,b`` (see requested_columns).
    Responds 404 when ``user_id`` is not a valid UUID or no such user exists.
    """
    # The primary key is a UUID column. Reject malformed ids as "not found"
    # instead of handing an arbitrary string to the database, where it may
    # surface as a bind/database error (HTTP 500).
    try:
        user_uuid = uuid.UUID(user_id)
    except ValueError:
        abort(404)
    user = User.query.filter_by(id=user_uuid).first()
    if user is None:
        abort(404)
    show = requested_columns(request)
    return jsonify(data=user.to_dict(show=show))
@app.route('/users/<string:user_id>', methods=['PUT'])
def update_user(user_id):
    """Update one user from a JSON body and return it as JSON.

    The body is validated with ``UserForm`` (WTForms-JSON) and its
    ``patch_data`` is applied through ``set_columns()``, then committed.
    Responds 404 when ``user_id`` is not a valid UUID or no such user exists,
    and 400 when the body is not a JSON object or fails validation.
    """
    # Reject malformed ids as "not found" rather than passing an arbitrary
    # string to the UUID primary-key column.
    try:
        user_uuid = uuid.UUID(user_id)
    except ValueError:
        abort(404)
    user = User.query.filter_by(id=user_uuid).first()
    if user is None:
        abort(404)

    input_data = request.get_json(force=True)
    if not isinstance(input_data, dict):
        return jsonify(error='Request data must be a JSON Object'), 400

    # validate json user input using WTForms-JSON
    form = UserForm.from_json(input_data)
    if not form.validate():
        return jsonify(errors=form.errors), 400

    # NOTE(review): PostForm declares no ``id`` field, so submitted ``posts``
    # entries presumably never carry one; set_columns() would then create new
    # rows and delete the existing ones -- confirm this is intended.
    user.set_columns(**form.patch_data)
    db.session.commit()

    show = requested_columns(request)
    return jsonify(data=user.to_dict(show=show))
@paramire
Copy link

paramire commented Apr 5, 2017

Hi, the to_dict method should skip the key 'query': db.Model defines it, and accessing it through getattr makes a call to the database to fetch all rows, like Model.query.all().
If a User has 10 Posts, every to_dict call on a child Post will query all rows.

# Suggested replacement for the final loop of Model.to_dict: serialize the
# remaining public attributes (not columns or relationships).
for key in list(set(properties) - set(columns) - set(relationships)):
   # Skip private and dunder attributes.
   if key.startswith('_'):
       continue
   check = '%s.%s' % (path, key)
   # Also skip the model's ``query`` attribute so it is never fetched via getattr.
   if check in hide or key in hidden or key == 'query':
       continue
   if show_all or check in show or key in default:
       val = getattr(self, key)
       # Keep only values that survive a JSON round-trip; any failure is silently ignored.
       try:
           ret_data[key] = json.loads(json.dumps(val))
       except:
           pass

Anyway, this to_dict has been very helpful.
Cheers

@piotr-dobrogost
Copy link

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment