Skip to content

Instantly share code, notes, and snippets.

@martijnluinstra
Last active July 24, 2021 01:21
Show Gist options
  • Save martijnluinstra/f10ff7f2125d8b618a7df858f834cd66 to your computer and use it in GitHub Desktop.
Save martijnluinstra/f10ff7f2125d8b618a7df858f834cd66 to your computer and use it in GitHub Desktop.
Python 3 classes for filtering SQLAlchemy queries
"""
This set of Python 3 classes is created to make automatic filtering of
SQLAlchemy queries easier. It is designed to have an api similar to
Django-filter (https://github.com/carltongibson/django-filter).
You may use and modify this code however you like for non-commercial purposes.
I will appreciate it if you mention my name when you do so.
Copyright (c) 2021 Martijn Luinstra
"""
import datetime
class Filter:
    """
    Base filter: restricts a query by comparing a single field to a value.
    """

    def __init__(self, operation='action', field=None, action=None, nullable=False):
        """
        operation: name of the method on this filter that performs the filtering
        field: explicit field to filter on; when unset, the field passed to apply() is used
        action: optional callable taking (query, compare) and returning a filtered query;
                when set, it is used instead of the operation method
        nullable: when true, a compare value of None is applied instead of skipped
        """
        self.operation = operation
        self.field = field
        self._action = action
        self.nullable = nullable

    def prepare_compare_value(self, compare):
        """ Normalise the raw compare value (returned unchanged by the base class) """
        return compare

    def apply(self, query, field, compare):
        """ Run this filter against query and return the resulting query """
        value = self.prepare_compare_value(compare)
        if value is None and not self.nullable:
            # No usable compare value: hand the query back untouched
            return query
        if self._action:
            # A custom action only receives the query and the compare value
            return self._action(query, value)
        target = self.field if self.field else field
        handler = getattr(self, self.operation)
        return handler(query, target, value)

    def action(self, query, field, compare):
        """ Default operation: keep rows where field equals compare """
        criterion = field == compare
        return query.filter(criterion)
class CompareFilter(Filter):
    """
    Filter base offering the six basic comparison operations.

    Each method takes (query, field, compare) and returns the query
    restricted by the corresponding comparison of field against compare.
    """

    def equal_to(self, query, field, compare):
        """ Keep rows where field == compare """
        criterion = field == compare
        return query.filter(criterion)

    def not_equal_to(self, query, field, compare):
        """ Keep rows where field != compare """
        criterion = field != compare
        return query.filter(criterion)

    def greater_equal_to(self, query, field, compare):
        """ Keep rows where field >= compare """
        criterion = field >= compare
        return query.filter(criterion)

    def greater_than(self, query, field, compare):
        """ Keep rows where field > compare """
        criterion = field > compare
        return query.filter(criterion)

    def less_equal_to(self, query, field, compare):
        """ Keep rows where field <= compare """
        criterion = field <= compare
        return query.filter(criterion)

    def less_than(self, query, field, compare):
        """ Keep rows where field < compare """
        criterion = field < compare
        return query.filter(criterion)
class ListFilter(Filter):
    """
    Base filter that accepts comma separated lists as compare value
    """

    def prepare_compare_value(self, compare):
        """
        Split a comma separated string into a list of stripped items.

        Returns None when the input is empty or consists of a single blank
        item. Blank items inside a longer list are kept as empty strings.
        """
        if not compare:
            return None
        items = [item.strip() for item in compare.split(',')]
        # Check after stripping so whitespace-only input counts as
        # "no value" instead of producing [''].
        if items == ['']:
            return None
        return items
class RangeFilter(CompareFilter, ListFilter):
    """
    Base filter that restricts a field to a range given as "left,right".

    A single item is treated as an exact match; with two items, each
    non-empty boundary is applied as a strict (exclusive) comparison.
    """

    def prepare_compare_value(self, compare):
        """ Split the value into boundaries; more than two items yields None """
        bounds = super().prepare_compare_value(compare)
        too_many = bool(bounds) and len(bounds) > 2
        return None if too_many else bounds

    def action(self, query, field, compare):
        """ Apply the exact match or the boundaries present in compare """
        if len(compare) == 1:
            return self.equal_to(query, field, compare[0])
        lower, upper = compare[0], compare[1]
        if lower:
            # Left boundary
            query = self.greater_than(query, field, lower)
        if upper:
            # Right boundary
            query = self.less_than(query, field, upper)
        return query
class InFilter(ListFilter):
    """
    Base filter that keeps rows whose field value occurs in the list of
    compare values
    """

    def action(self, query, field, compare):
        """ Restrict query to rows where field is one of the compare values """
        membership = field.in_(compare)
        return query.filter(membership)
class BooleanFilter(Filter):
    """
    Filter that compares boolean values
    """

    # Case-insensitive spellings recognised as booleans
    _TRUE_WORDS = frozenset(('true', 'y', 'yes', 't', 'on'))
    _FALSE_WORDS = frozenset(('false', 'n', 'no', 'f', 'off'))

    def prepare_compare_value(self, compare):
        """
        Convert the compare value to a boolean.

        Booleans and None pass through unchanged. Strings are matched
        case-insensitively against the recognised true/false spellings.
        Anything else, including non-string values, returns None.
        """
        if isinstance(compare, bool) or compare is None:
            return compare
        if not isinstance(compare, str):
            # Non-string input (e.g. a number) has no .lower(); treat it as
            # unrecognised instead of raising AttributeError.
            return None
        word = compare.lower()
        if word in self._TRUE_WORDS:
            return True
        if word in self._FALSE_WORDS:
            return False
        return None
class StringFilter(Filter):
    """
    Filter that compares string values
    """

    def __init__(self, template='{}', **kwargs):
        """
        template: str.format template the compare value is inserted into
                  before it is handed to like()
        kwargs: forwarded to Filter; operation defaults to 'equal_to'
        """
        self.template = template
        # The default used to be 'equals', which named no method on this
        # class, so a default-constructed filter failed when applied.
        kwargs.setdefault('operation', 'equal_to')
        super().__init__(**kwargs)

    def prepare_compare_value(self, compare):
        """ Return the value as a string, or None when it is empty/falsy """
        if not compare:
            return None
        return str(compare)

    def equal_to(self, query, field, compare):
        """ Keep rows where field equals compare """
        return query.filter(field == compare)

    # Alias so an explicit operation='equals' resolves to a real method.
    equals = equal_to

    def like(self, query, field, compare):
        """ Keep rows where field matches the template filled with compare """
        return query.filter(field.like(self.template.format(compare)))
class IntegerFilter(CompareFilter):
    """
    Filter that compares integer values
    """

    def prepare_compare_value(self, compare):
        """ Convert the value with int(); None when conversion fails """
        try:
            number = int(compare)
        except Exception:
            # Not convertible: returning None skips the filter
            return None
        return number
class ForeignKeyFilter(IntegerFilter):
    """
    Filter that compares foreign keys (only accepts non-zero compare values)
    """

    def prepare_compare_value(self, compare):
        """ Convert the value with int(); None when it is zero or not convertible """
        try:
            key = int(compare)
        except Exception:
            # Not convertible: returning None skips the filter
            return None
        # Zero is mapped to None as well
        return key or None
class DateTimeFilter(CompareFilter):
    """
    Filter that compares datetime values
    """

    def __init__(self, format='%Y-%m-%d', **kwargs):
        """
        format: strptime format used to parse the compare value
        kwargs: forwarded to the parent filter
        """
        self.format = format
        super().__init__(**kwargs)

    def prepare_compare_value(self, compare):
        """ Parse the value using self.format; None when parsing fails """
        try:
            moment = datetime.datetime.strptime(compare, self.format)
        except Exception:
            # Unparseable: returning None skips the filter
            return None
        return moment
class DateTimeRangeFilter(RangeFilter):
    """
    Filter that compares datetime ranges
    """

    def __init__(self, format='%Y-%m-%d', **kwargs):
        """
        format: strptime format used to parse each boundary
        kwargs: forwarded to the parent filter
        """
        self.format = format
        super().__init__(**kwargs)

    def prepare_compare_value(self, compare):
        """
        Parse each boundary of the range into a datetime.

        Boundaries that fail to parse become None. When no boundary
        parses at all, None is returned for the whole value.
        """
        items = super().prepare_compare_value(compare)
        if not items:
            return None
        parsed = [self._parse(item) for item in items]
        # A lone unparseable item used to come back as [None], which the
        # range action turns into an equality test against None.
        if all(moment is None for moment in parsed):
            return None
        return parsed

    def _parse(self, text):
        """ Parse one boundary using self.format; None when it does not match """
        try:
            return datetime.datetime.strptime(text, self.format)
        except (TypeError, ValueError):
            # Skip failed item in range
            return None
class FilteredQuery:
    """
    Base class for filtered queries.

    Filter instances found as attributes are applied when their attribute
    name occurs as a key in the supplied data. Fields are looked up by
    name on the ``__model__`` attribute.
    """

    def __init__(self, query, data, default=None):
        """
        query: the query to filter
        data: mapping of filter name -> raw compare value; the optional
              'order_by' entry is a list or comma separated string of
              field names, where a leading '-' means descending
        default: optional mapping of fallback values, overridden by data
        """
        # Copy the defaults so update() does not write the request data
        # into the caller's dict.
        self.data = dict(default) if default else {}
        self.data.update(data)
        self.query = self.apply(query)
        requested = self.data.get('order_by')
        if isinstance(requested, list):
            self._order = requested
        elif isinstance(requested, str):
            self._order = [name.strip() for name in requested.split(',')]
        else:
            self._order = None

    def apply(self, query):
        """
        Applies filters to query
        """
        for key in dir(self):
            if key not in self.data:
                continue
            candidate = self._lookup_filter(key)
            if candidate is not None:
                query = candidate.apply(query, self._get_field(key), self.data[key])
        return query

    def _lookup_filter(self, name):
        """
        Return the Filter stored under name, or None.

        Looks in the instance dict and then on the class, so properties
        are not evaluated. A data key such as 'ordered_query' would
        otherwise run the property before self.query is set.
        """
        candidate = self.__dict__.get(name)
        if candidate is None:
            candidate = getattr(type(self), name, None)
        return candidate if isinstance(candidate, Filter) else None

    def order(self, query):
        """
        Orders query
        """
        if not self._order:
            return query
        clauses = []
        for name in self._order:
            if not name:
                continue
            descending = name[0] == '-'
            field = self._get_field(name[1:] if descending else name)
            if field is None:
                # Unknown names are skipped in both directions; the
                # ascending branch used to pass None on to order_by().
                continue
            clauses.append(field.desc() if descending else field)
        if clauses:
            return query.order_by(*clauses)
        return query

    @property
    def ordered_query(self):
        """ The filtered query with the requested ordering applied """
        return self.order(self.query)

    def _get_field(self, fieldname):
        """
        Gets field object based on __model__ property and the filter property name
        """
        model = getattr(self, '__model__', None)
        if model is None:
            return None
        return getattr(model, fieldname, None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment