Skip to content

Instantly share code, notes, and snippets.

@kcuzner
Created March 26, 2013 14:59
Show Gist options
  • Save kcuzner/5246020 to your computer and use it in GitHub Desktop.
Save kcuzner/5246020 to your computer and use it in GitHub Desktop.
Relatively simple database abstraction layer. http://kevincuzner.com/?p=261
"""
Database abstraction layer
by Kevin Cuzner

Designed for PEP 249-style cursor objects. The generated SQL uses "%s"
placeholders with a separate argument list, i.e. it targets drivers whose
paramstyle is "format".

Objects:
- Column: Descriptor for one database column, declared as a class attribute
  of a ColumnSet/DbObject. Subclasses (StringColumn, IntColumn,
  PasswordColumn) implement validate() for the values they accept.
- ColumnSet / ColumnInstance: a ColumnSet creates one ColumnInstance per
  Column; the ColumnInstance holds that object's value and tracks whether it
  has been changed since it was last loaded.
- DbObject: Represents a table in the database. Instances of this object are
  rows in the table. dbo_tablename and primary_key are required fields on the
  class level.
- DbQuery: The subclasses of this (DbSelectQuery, DbInsertQuery,
  DbUpdateQuery, DbDeleteQuery) represent queries to the database.
  Comparisons between DbQueryColumn objects and values build the WHERE
  expressions.
"""
class ColumnSet(object):
    """
    Base class for objects whose class attributes include Column descriptors.

    On construction, one ColumnInstance is created for every Column found
    directly in the instance's class __dict__. The ColumnInstance holds this
    object's value for that column and tracks whether it has changed.
    """

    def __init__(self):
        class_attrs = type(self).__dict__
        # Keyed by Column.name (the database column name), not by the Python
        # attribute name the descriptor happens to be bound to.
        self.__columns = dict(
            (candidate.name, ColumnInstance(candidate, self))
            for candidate in class_attrs.values()
            if isinstance(candidate, Column))

    @property
    def mutated(self):
        """
        List of ColumnInstances whose current value differs from the value
        recorded by their last update() call.
        """
        instances = map(self.get_column, self.__columns)
        return [instance for instance in instances if instance.mutated]

    def get_column(self, name):
        """Return the ColumnInstance for database column *name* (KeyError if unknown)."""
        return self.__columns[name]
class DbObject(ColumnSet):
    """
    A DbObject is a set of columns linked to a table in the database. Each
    instance is synonymous with one row. The following class attributes must
    be set:
        dbo_tablename : string table name
        primary_key   : Column for the primary key
    """

    def __init__(self, **cols):
        """
        Create a row object.

        Every column first receives its Column default (through
        ColumnSet/ColumnInstance). Each keyword argument then overwrites the
        column with that database name via ColumnInstance.update, so passed
        values are validated and do not count as "mutated".

        Raises KeyError for an unknown column name and ValueError for an
        invalid value.
        """
        ColumnSet.__init__(self)
        for name in cols:
            self.get_column(name).update(cols[name])

    @classmethod
    def get_query_columns(cls, prefix):
        """Return a DbQueryColumnSet for this table aliased as *prefix*."""
        return DbQueryColumnSet(cls, prefix)

    @classmethod
    def select(cls, prefix):
        """
        Build a SELECT over every column of this table.

        prefix: table alias used in the generated SQL

        Returns ``(query, columns)``:
        - *query* is a DbSelectQuery. Executing it returns a list of instances
          of this class, one per fetched row.
        - *columns* is the DbQueryColumnSet used to build WHERE/ORDER BY
          expressions.
        """
        columns = cls.get_query_columns(prefix)
        # Fix the column order once so that the SELECT list and the unpacking
        # of each fetched row are guaranteed to line up.
        names = list(columns)

        def execute(query, cursor):
            output = []
            block = cursor.fetchmany()
            while block:
                for row in block:
                    output.append(cls(**dict(zip(names, row))))
                block = cursor.fetchmany()
            return output

        query = DbSelectQuery(execute)
        query.select(*[columns[name] for name in names])
        query.from_table(cls, prefix)
        return query, columns

    def get_primary_key_name(self):
        """Return the database column name of this class's primary_key Column."""
        return type(self).__dict__['primary_key'].name

    def save(self, cur):
        """
        Persist this object through the PEP 249 cursor *cur*.

        - If the primary key value is None, the row is INSERTed.
        - Otherwise only the mutated columns are UPDATEd.

        Returns the result of the executed query's filter (True on success);
        returns True without touching the database when nothing changed.
        """
        if self.primary_key is None:
            return self._save_insert(cur)
        return self._save_update(cur)

    def _save_insert(self, cur):
        """INSERT a not-yet-saved row, then read back key and computed values."""
        columns = self.get_query_columns('x')
        pk_name = self.get_primary_key_name()

        def after_insert(query, cursor):
            # Adopt the key generated by the database for the new row.
            self.get_column(pk_name).update(cursor.lastrowid)
            # Immutable non-key columns are treated as computed by the
            # database, so their values are selected back for the new row.
            selection = [columns[name] for name in columns
                         if name != pk_name
                         and not self.get_column(name).column.mutable]
            if not selection:
                return True

            def after_select(query, cursor):
                row = cursor.fetchone()
                for query_column, fetched in zip(selection, row):
                    self.get_column(query_column.name).update(fetched)
                return True

            refresh = DbSelectQuery(after_select)
            refresh.select(*selection)
            refresh.from_table(type(self), 'x')
            refresh.where(columns[pk_name] ==
                          self.get_column(pk_name).value)
            return refresh.execute(cursor)

        query = DbInsertQuery(type(self), 'x', after_insert)
        for name in columns:
            column_instance = self.get_column(name)
            # Immutable columns are never written by the application.
            if column_instance.column.mutable:
                query.value(columns[name], column_instance.value)
        return query.execute(cur)

    def _save_update(self, cur):
        """UPDATE only the mutated columns of an already-saved row."""
        modified = self.mutated
        if len(modified) == 0:
            return True
        columns = self.get_query_columns('x')

        def after_update(query, cursor):
            # update() with the current value clears the mutated flag.
            for mod in modified:
                mod.update(mod.value)
            return True

        query = DbUpdateQuery(type(self), 'x', after_update)
        for mod in modified:
            query.update(columns[mod.column.name], mod.value)
        query.where(columns[self.get_primary_key_name()] == self.primary_key)
        return query.execute(cur)
class ColumnInstance(object):
    """
    Per-instance column data. A ColumnSet creates one of these per Column to
    hold the value specific to that particular object and to track whether
    the value has been changed since it was last set with update().
    """

    def __init__(self, column, owner):
        """
        column: Column object this instance stores a value for
        owner: the ColumnSet object the value belongs to

        The value starts as column.default. Because this goes through
        update(), a column with allow_none=False and a default of None
        raises ValueError here.
        """
        self.__column = column
        self.__owner = owner
        self.update(column.default)

    def _reject_none(self, value):
        """Raise ValueError if *value* is None and the column is not nullable."""
        if value is None and not self.__column.allow_none:
            raise ValueError("'None' is invalid for column '" +
                             self.__column.name + "'")

    def _require_valid(self, value):
        """Raise ValueError if the column's validate() rejects *value*."""
        if not self.__column.validate(value):
            # %-formatting (not "+") so non-string values still produce the
            # ValueError instead of a TypeError.
            raise ValueError("'%s' is not valid for column '%s'" %
                             (value, self.__column.name))

    def update(self, value):
        """
        Set the value and make it the new baseline, resetting the mutated flag.

        This ignores Column.mutable. It is used for defaults, for values read
        from the database and after a successful save.

        Raises ValueError if *value* is None on a non-nullable column or fails
        the column's validate().
        """
        self._reject_none(value)
        self._require_valid(value)
        self.__value = value
        self.__origvalue = value

    @property
    def column(self):
        """The Column descriptor this instance belongs to."""
        return self.__column

    @property
    def owner(self):
        """The ColumnSet object this value belongs to."""
        return self.__owner

    @property
    def mutated(self):
        """True when the value differs from the last update() baseline."""
        return self.__value != self.__origvalue

    @property
    def value(self):
        """The current value."""
        return self.__value

    @value.setter
    def value(self, value):
        """
        Assign a new value (marks the column as mutated if it differs).

        Raises:
        - ValueError for None on a non-nullable column;
        - AttributeError if the column is not mutable;
        - ValueError if validate() rejects the value.
        """
        self._reject_none(value)
        if not self.__column.mutable:
            raise AttributeError("Column '" + self.__column.name + "' is not" +
                                 " mutable")
        self._require_valid(value)
        self.__value = value
class DbQueryError(Exception):
    """
    Raised when a query object cannot be rendered to SQL, e.g. a SELECT with
    no selected columns or an UPDATE with no assignments.
    """

    def __init__(self, msg):
        # Initialise the base class so args/repr carry the message as well.
        Exception.__init__(self, msg)
        self.message = msg

    def __str__(self):
        return self.message
class DbQuery(object):
    """
    Represents a base SQL Query to a database based upon some DbObjects.

    The methods implemented here (WHERE, LIMIT, execution) are valid on
    select, update, and delete statements.
    """

    def __init__(self, execute_filter=None):
        """
        execute_filter: optional function(query, cursor), called after the
            statement has been executed. Its return value becomes the return
            value of execute().
        """
        self.__where = []
        self.__limit = None
        self.__execute_filter = execute_filter

    def where(self, expression):
        """Append an expression to the WHERE clause (multiple clauses are ANDed)."""
        self.__where.append(expression)

    def limit(self, value=None):
        """Specify the limit to the query (None removes it)."""
        self.__limit = value

    @property
    def sql(self):
        """
        Return ``(fragment, args)`` for the WHERE/LIMIT part of the statement.

        The fragment is empty or starts with a space, so subclasses can
        append it to their own statement text.
        """
        query = ""
        args = []
        if len(self.__where) == 1:
            clause = self.__where[0]
            query += " WHERE " + str(clause)
            args += clause.arguments
        elif len(self.__where) > 1:
            # Parenthesise each clause so an OR inside one clause cannot bind
            # across the AND that joins them (SQL binds AND tighter than OR).
            query += " WHERE " + " AND ".join(
                "(" + str(clause) + ")" for clause in self.__where)
            for clause in self.__where:
                args += clause.arguments
        if self.__limit is not None:
            # str() so integer limits work as well as string ones.
            query += " LIMIT " + str(self.__limit)
        return query, args

    def execute(self, cur):
        """
        Execute this query on the PEP 249 cursor *cur*.

        Returns the result of the execute filter, or the cursor itself when
        no filter was given.
        """
        query, args = self.sql
        cur.execute(query, args)
        if self.__execute_filter:
            return self.__execute_filter(self, cur)
        return cur
class DbSelectQuery(DbQuery):
    """
    Creates a SELECT query to a database based upon DbObjects.

    Clause order in the generated SQL:
    SELECT ... FROM ... [JOIN ... ON ...] [WHERE ...] [ORDER BY ...] [LIMIT ...]
    """

    def __init__(self, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.__select = []
        self.__froms = []
        self.__joins = []
        self.__orderby = []
        # LIMIT is kept here rather than in DbQuery because SQL requires it
        # after ORDER BY, which only this class emits.
        self.__limit = None

    @staticmethod
    def _column_sql(col):
        """Render a DbQueryColumn as ``prefix.name`` (just ``name`` without a prefix)."""
        if col.prefix is None:
            return col.name
        return col.prefix + "." + col.name

    def select(self, *columns):
        """Specify one or more columns (DbQueryColumn) to select."""
        self.__select += columns

    def from_table(self, dbo_type, prefix):
        """Specify a table (DbObject subclass) to select from, aliased as *prefix*."""
        self.__froms.append((dbo_type, prefix))

    def join(self, dbo_type, prefix, on):
        """Specify a table to join to, using expression *on* as the join condition."""
        self.__joins.append((dbo_type, prefix, on))

    def orderby(self, *columns):
        """Specify one or more columns to order by."""
        self.__orderby += columns

    def limit(self, value=None):
        """Specify the limit to the query (emitted after ORDER BY); None removes it."""
        self.__limit = value

    @property
    def sql(self):
        """
        Return ``(sql, args)`` for this SELECT.

        Raises DbQueryError when no columns are selected or no FROM table
        was given.
        """
        if len(self.__select) == 0:
            raise DbQueryError("No selection in DbSelectQuery")
        if len(self.__froms) == 0:
            raise DbQueryError("No FROM clause in DbSelectQuery")
        args = []
        query = "SELECT " + ",".join(
            self._column_sql(col) for col in self.__select)
        # Several tables form one comma-separated FROM list.
        query += " FROM " + ", ".join(
            dbo_type.dbo_tablename + " " + prefix
            for dbo_type, prefix in self.__froms)
        for dbo_type, prefix, on in self.__joins:
            query += (" JOIN " + dbo_type.dbo_tablename + " " + prefix +
                      " ON " + str(on))
            # Literal values in the ON expression are rendered as %s
            # placeholders; their arguments precede the WHERE arguments.
            args += getattr(on, "arguments", [])
        parent_query, parent_args = super(DbSelectQuery, self).sql
        query += parent_query
        args += parent_args
        if len(self.__orderby) > 0:
            query += " ORDER BY " + ",".join(
                self._column_sql(col) for col in self.__orderby)
        if self.__limit is not None:
            query += " LIMIT " + str(self.__limit)
        return query, args
class DbInsertQuery(DbQuery):
    """
    Builds an INSERT statement for one DbObject table.

    Only values whose column carries this query's prefix are written; others
    are silently skipped. WHERE/LIMIT are inherited but never emitted.
    """

    def __init__(self, dbo_type, prefix, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)
        self.__values = []

    def value(self, column, value):
        """Queue *value* to be inserted into *column* (a DbQueryColumn)."""
        self.__values.append((column, value))

    @property
    def sql(self):
        """Return ``(sql, args)``; raises DbQueryError when no values were queued."""
        if not len(self.__values):
            raise DbQueryError("No values in insert")
        tablename = self.table[0].dbo_tablename
        alias = self.table[1]
        kept = [pair for pair in self.__values if pair[0].prefix == alias]
        values = [pair[1] for pair in kept]
        column_list = ",".join([pair[0].name for pair in kept])
        placeholders = ",".join(["%s" for _ in values])
        statement = "INSERT INTO {table} (".format(table=tablename)
        statement = statement + column_list + ") VALUES (" + placeholders + ")"
        return statement, values
class DbUpdateQuery(DbQuery):
    """
    Creates an UPDATE query to a database based upon DbObjects.
    """

    def __init__(self, dbo_type, prefix, execute_filter=None):
        """
        Initialize the update query
        dbo_type: table type (DbObject subclass) to be updated
        prefix: alias the table's columns are known under in this query
        execute_filter: optional function(query, cursor) run after execution
        """
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)
        self.__updates = []

    def update(self, left, right):
        """
        Add the assignment ``left = right``.

        Either side may be a DbQueryColumn (rendered as prefix.name) or a
        literal (rendered as a %s placeholder whose value becomes a query
        argument).
        """
        self.__updates.append((left, right))

    @staticmethod
    def _operand(value, args):
        """Render one side of an assignment, appending literal values to *args*."""
        if isinstance(value, DbQueryColumn):
            if value.prefix is None:
                return value.name
            return value.prefix + "." + value.name
        args.append(value)
        return "%s"

    @property
    def sql(self):
        """Return ``(sql, args)``; raises DbQueryError when there are no assignments."""
        if len(self.__updates) == 0:
            raise DbQueryError("No update in DbUpdateQuery")
        args = []
        assignments = []
        for left, right in self.__updates:
            # Render left before right so args follow placeholder order.
            target = self._operand(left, args)
            source = self._operand(right, args)
            assignments.append(target + "=" + source)
        # SQL requires multiple SET assignments to be comma-separated.
        query = ("UPDATE " + self.table[0].dbo_tablename + " " +
                 self.table[1] + " SET " + ", ".join(assignments))
        parent_query, parent_args = super(DbUpdateQuery, self).sql
        return query + parent_query, args + parent_args
class DbDeleteQuery(DbQuery):
    """
    Builds a DELETE statement for one DbObject table, restricted by any
    WHERE/LIMIT set through the DbQuery base class.
    """

    def __init__(self, dbo_type, prefix, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)

    @property
    def sql(self):
        """Return ``(sql, args)`` for ``DELETE FROM <table> <alias>`` plus WHERE/LIMIT."""
        table_type = self.table[0]
        alias = self.table[1]
        head = "DELETE FROM " + table_type.dbo_tablename + " " + alias
        tail, tail_args = super(DbDeleteQuery, self).sql
        return head + tail, list(tail_args)
class DbQueryExpression(object):
    """
    Base class for WHERE/ON expressions.

    Subclasses:
    - render themselves with str(), using %s placeholders for literal values;
    - expose those literal values, in placeholder order, via ``arguments``.

    ``a & b`` and ``a | b`` combine two expressions with AND / OR.
    """

    def __str__(self):
        raise NotImplementedError

    @property
    def arguments(self):
        raise NotImplementedError

    def __and__(self, other):
        conjunction = DbQueryConjunction(self, other)
        return conjunction

    def __or__(self, other):
        disjunction = DbQueryDisjunction(self, other)
        return disjunction
class DbQueryConjunction(DbQueryExpression):
    """
    Query expression joining together a left and right expression with an
    AND statement.

    Both operands are parenthesised so that an OR inside either one keeps its
    meaning (SQL binds AND tighter than OR).
    """

    def __init__(self, l, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.r = r

    def __str__(self):
        return "(" + str(self.l) + ") AND (" + str(self.r) + ")"

    @property
    def arguments(self):
        """Arguments of the left operand followed by those of the right."""
        return self.l.arguments + self.r.arguments
class DbQueryDisjunction(DbQueryExpression):
    """
    Query expression joining together a left and right expression with an
    OR statement.

    Both operands are parenthesised so the OR cannot be regrouped by a
    neighbouring AND.
    """

    def __init__(self, l, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.r = r

    def __str__(self):
        # Left operand first, matching the order of ``arguments``.
        return "(" + str(self.l) + ") OR (" + str(self.r) + ")"

    @property
    def arguments(self):
        """Arguments of the left operand followed by those of the right."""
        return self.l.arguments + self.r.arguments
class DbQueryColumnComparison(DbQueryExpression):
    """
    Binary comparison ``l <op> r``.

    Each operand may be:
    - a DbQueryColumn, rendered as prefix.name (or name without a prefix);
    - None, rendered as NULL;
    - any other value, rendered as a %s placeholder and reported in
      ``arguments``.
    """

    def __init__(self, l, op, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.op = op
        self.r = r

    @staticmethod
    def _render(operand):
        """SQL text for one side of the comparison."""
        if isinstance(operand, DbQueryColumn):
            if operand.prefix is None:
                return operand.name
            return operand.prefix + "." + operand.name
        if operand is None:
            return "NULL"
        return "%s"

    def __str__(self):
        return self._render(self.l) + self.op + self._render(self.r)

    @property
    def arguments(self):
        """Literal (non-column, non-None) operands, left before right."""
        return [side for side in (self.l, self.r)
                if side is not None and not isinstance(side, DbQueryColumn)]
class DbQueryColumnSet(object):
    """
    Represents a set of columns attached to a specific DbObject type. This
    object builds itself dynamically from the Column descriptors in the
    passed type's own __dict__. The columns attached to this set may be used
    in DbQueries.

    Each DbQueryColumn is reachable two ways:
    - as an attribute, under the Python attribute name;
    - by indexing, under the database column name.
    Iteration yields the database column names.
    """

    def __init__(self, dbo_type, prefix):
        namespace = dbo_type.__dict__
        self.__columns = {}
        for attr_name, candidate in namespace.items():
            if not isinstance(candidate, Column):
                continue
            query_column = DbQueryColumn(dbo_type, prefix, candidate.name)
            setattr(self, attr_name, query_column)
            self.__columns[candidate.name] = query_column

    def __len__(self):
        return len(self.__columns)

    def __getitem__(self, key):
        return self.__columns[key]

    def __iter__(self):
        return iter(self.__columns)
class DbQueryColumn(object):
    """
    A column reference usable inside a DbQuery.

    The comparison operators do not return booleans. Each one returns a
    DbQueryColumnComparison expression; comparing with None produces
    ``IS`` / ``IS NOT`` instead of ``=`` / ``!=``.
    """

    def __init__(self, dbo_type, prefix, column_name):
        self.dbo_type = dbo_type
        self.name = column_name
        self.prefix = prefix

    def _compare(self, op, other):
        """Build the comparison expression ``self <op> other``."""
        return DbQueryColumnComparison(self, op, other)

    def __lt__(self, other):
        return self._compare("<", other)

    def __le__(self, other):
        return self._compare("<=", other)

    def __eq__(self, other):
        return self._compare(" IS " if other is None else "=", other)

    def __ne__(self, other):
        return self._compare(" IS NOT " if other is None else "!=", other)

    def __gt__(self, other):
        return self._compare(">", other)

    def __ge__(self, other):
        return self._compare(">=", other)
class Column(object):
    """
    Data descriptor for one database column.

    Declare it as a class attribute of a ColumnSet (e.g. a DbObject
    subclass). Reading or assigning the attribute on an instance goes through
    that instance's ColumnInstance; reading it on the class returns the
    descriptor itself. Subclasses must implement validate().
    """

    def __init__(self, name, default=None, allow_none=False, mutable=True):
        """
        name: database column name this descriptor maps to
        default: value each new instance starts with
        allow_none: whether None (database NULL) is an acceptable value
        mutable: whether the value may be changed through the attribute setter
        """
        self.__name, self.__default = name, default
        self.__allow_none, self.__mutable = allow_none, mutable

    def validate(self, value):
        """Return True if *value* is acceptable; subclasses must override."""
        raise NotImplementedError

    @property
    def name(self):
        """Database column name."""
        return self.__name

    @property
    def allow_none(self):
        """Whether None is accepted."""
        return self.__allow_none

    @property
    def mutable(self):
        """Whether the attribute setter may change the value."""
        return self.__mutable

    @property
    def default(self):
        """Initial value for new instances."""
        return self.__default

    def _column_instance(self, owner):
        """Return *owner*'s ColumnInstance for this column; TypeError for non-ColumnSets."""
        if isinstance(owner, ColumnSet):
            return owner.get_column(self.name)
        raise TypeError("Columns are only allowed on ColumnSets")

    def __get__(self, owner, ownertype=None):
        # Class-level access (owner is None) yields the descriptor itself.
        if owner is None:
            return self
        return self._column_instance(owner).value

    def __set__(self, owner, value):
        self._column_instance(owner).value = value
class StringColumn(Column):
    """
    Column accepting text values, plus None when allow_none is set.

    Text means basestring (str/unicode) on Python 2 and str on Python 3.
    """

    try:
        _TEXT_TYPES = basestring  # Python 2: str and unicode
    except NameError:
        _TEXT_TYPES = str  # Python 3

    def validate(self, value):
        """Return True for text values, or for None when the column allows it."""
        if value is None:
            return bool(self.allow_none)
        return isinstance(value, self._TEXT_TYPES)
class IntColumn(Column):
    """
    Column accepting integer values, plus None when allow_none is set.

    Integer means int/long on Python 2 and int on Python 3. bool values also
    pass, since bool is a subclass of int.
    """

    try:
        _INTEGER_TYPES = (int, long)  # Python 2
    except NameError:
        _INTEGER_TYPES = (int,)  # Python 3: int is arbitrary precision

    def validate(self, value):
        """Return True for integers, or for None when the column allows it."""
        if value is None:
            return bool(self.allow_none)
        return isinstance(value, self._INTEGER_TYPES)
class PasswordColumn(Column):
    """
    Column whose assigned values are salted before being stored.

    Only assignment through the attribute (``obj.password = ...``) runs the
    salt function. Values set with ColumnInstance.update (defaults and values
    loaded from the database) are stored unchanged.
    """

    def __init__(self, name, salt_function, default=None, allow_none=False, \
            mutable=True):
        """
        salt_function: function(owner, value) returning the salted value to
            store; owner is the ColumnSet object being assigned to.
        The remaining parameters are as for Column.
        """
        Column.__init__(self, name, default, allow_none, mutable)
        self.__salt_function = salt_function

    def validate(self, value):
        # Every value is accepted here; ColumnInstance still rejects None on
        # columns with allow_none=False before validate() is consulted.
        return True

    def __set__(self, owner, value):
        salt = self.__salt_function
        super(PasswordColumn, self).__set__(owner, salt(owner, value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment