Created
March 26, 2013 14:59
-
-
Save kcuzner/5246020 to your computer and use it in GitHub Desktop.
Relatively simple database abstraction layer. http://kevincuzner.com/?p=261
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Database abstraction layer | |
by Kevin Cuzner | |
Designed for PEP 249-style cursor objects | |
Objects: | |
- Column: Base column descriptor which tracks changes and such for a DbObject. | |
It also does validation and other things. | |
- DbObject: Represents a table in the database. Instances of this object are | |
rows in the table. dbo_tablename and primary_key are required fields on the | |
class level | |
- DbQuery: The subclasses of this represent queries to the database | |
""" | |
class ColumnSet(object):
    """
    Object which is updated by ColumnInstances to inform changes.

    Holds one ColumnInstance per Column descriptor found on the instance's
    class (including Columns declared on base classes), keyed by the
    Column's name.
    """
    def __init__(self):
        self.__columns = {}  # ColumnInstance objects keyed by column name
        # Merge the class dicts base-first so that a subclass attribute
        # replaces the one it overrides (the same precedence normal attribute
        # lookup uses) and so that Columns inherited from base classes are
        # included as well.
        attributes = {}
        for klass in reversed(type(self).__mro__):
            attributes.update(vars(klass))
        for obj in attributes.values():
            if isinstance(obj, Column):
                self.__columns[obj.name] = ColumnInstance(obj, self)

    @property
    def mutated(self):
        """
        Returns a list of the ColumnInstances whose value was changed through
        the value setter since it was last loaded with update().
        """
        return [instance for instance in self.__columns.values()
                if instance.mutated]

    def get_column(self, name):
        """
        Returns the ColumnInstance for the column called `name`.
        Raises KeyError if this set has no such column.
        """
        return self.__columns[name]
class DbObject(ColumnSet):
    """
    A DbObject is a set of columns linked to a table in the database. This is
    synonymous to a row. The following class attributes must be set:
    dbo_tablename : string table name
    primary_key : Column for the primary key
    A primary key value of None means the row has not been inserted yet.
    """
    def __init__(self, **cols):
        """
        cols: initial values keyed by column name. They are loaded with
        ColumnInstance.update(), so they are validated and are not considered
        mutated. An unknown column name raises KeyError.
        """
        ColumnSet.__init__(self)
        for name, value in cols.items():
            self.get_column(name).update(value)

    @classmethod
    def get_query_columns(cls, prefix):
        """Returns a DbQueryColumnSet for this type using table alias `prefix`"""
        return DbQueryColumnSet(cls, prefix)

    @classmethod
    def select(cls, prefix):
        """
        Returns (query, columns): a DbSelectQuery selecting every column of
        this type from its table aliased as `prefix`, and the DbQueryColumnSet
        used to build it (so callers can add where/orderby clauses).
        Executing the query returns a list of instances of this type, one per
        fetched row.
        """
        columns = cls.get_query_columns(prefix)
        # Fix the column order once: the SELECT list and the row unpacking
        # below must agree position by position.
        names = list(columns)

        def execute(query, cur):
            rows = []
            block = cur.fetchmany()
            while len(block) > 0:
                for row in block:
                    rows.append(cls(**dict(zip(names, row))))
                block = cur.fetchmany()
            return rows

        query = DbSelectQuery(execute)
        query.select(*[columns[name] for name in names])
        query.from_table(cls, prefix)
        return query, columns

    def get_primary_key_name(self):
        """Returns the column name of this type's primary_key Column"""
        # Class-level access makes Column.__get__ return the descriptor
        # itself, and unlike reading type(self).__dict__ this also finds a
        # primary_key inherited from a base class.
        return type(self).primary_key.name

    def save(self, cur):
        """
        Saves any changes to this object to the database using cursor `cur`.

        If the primary key is None the row is INSERTed, otherwise the mutated
        columns are UPDATEd. Returns True; errors raised by the cursor
        propagate to the caller.
        """
        if self.primary_key is None:
            return self._insert_row(cur)
        return self._update_row(cur)

    def _insert_row(self, cur):
        """INSERTs this row, then loads the generated key and immutable columns."""
        columns = self.get_query_columns('x')
        pk_name = self.get_primary_key_name()
        inserted = []  # ColumnInstances whose values are sent in the INSERT

        def after_insert(query, cur):
            # Store the key the database generated; update() also marks the
            # key column as unmutated.
            self.get_column(pk_name).update(cur.lastrowid)
            # The inserted values now match the database, so clear their
            # mutated flags; otherwise the next save() would UPDATE them again.
            for instance in inserted:
                if instance.column.name != pk_name:
                    instance.update(instance.value)
            # Immutable columns were not sent in the INSERT (presumably the
            # database computes them); read them back.
            selection = [columns[name] for name in columns
                         if name != pk_name
                         and not self.get_column(name).column.mutable]
            if not selection:
                return True

            def read_back(query, cur):
                row = cur.fetchone()
                for index, query_column in enumerate(selection):
                    self.get_column(query_column.name).update(row[index])
                return True

            select_query = DbSelectQuery(read_back)
            select_query.select(*selection)
            select_query.from_table(type(self), 'x')
            select_query.where(columns[pk_name] ==
                               self.get_column(pk_name).value)
            return select_query.execute(cur)

        insert_query = DbInsertQuery(type(self), 'x', after_insert)
        for name in columns:
            instance = self.get_column(name)
            if not instance.column.mutable:
                continue
            insert_query.value(columns[name], instance.value)
            inserted.append(instance)
        return insert_query.execute(cur)

    def _update_row(self, cur):
        """UPDATEs the mutated columns of an already-inserted row."""
        modified = self.mutated
        if not modified:
            return True
        columns = self.get_query_columns('x')

        def after_update(query, cur):
            # The database now holds the new values; clear the mutated flags.
            for instance in modified:
                instance.update(instance.value)
            return True

        update_query = DbUpdateQuery(type(self), 'x', after_update)
        for instance in modified:
            update_query.update(columns[instance.column.name], instance.value)
        # NOTE(review): if the primary key itself was modified, this matches
        # on the new key value rather than the one stored in the database.
        update_query.where(columns[self.get_primary_key_name()] ==
                           self.primary_key)
        return update_query.execute(cur)
class ColumnInstance(object):
    """
    Per-instance column data. This is used in ColumnSet objects to hold data
    specific to that particular instance.

    update() loads a value and records it as the original; assigning through
    the `value` property changes only the current value, which makes
    `mutated` true until the next update().
    """
    def __init__(self, column, owner):
        """
        column: Column object this is created for; its default becomes the
            initial value and is validated like any other value
        owner: the ColumnSet this value belongs to
        Raises ValueError if the column's default is not acceptable.
        """
        self.__column = column
        self.__owner = owner
        self.update(column.default)

    def __require_not_none(self, value):
        """Raises ValueError if value is None and the column forbids it."""
        if value is None and not self.__column.allow_none:
            raise ValueError("'None' is invalid for column '" + \
                self.__column.name + "'")

    def __require_valid(self, value):
        """Raises ValueError if the column's validate() rejects value."""
        if not self.__column.validate(value):
            # str() so non-string values yield a message rather than TypeError
            raise ValueError("'" + str(value) + "' is not valid for column '" + \
                self.__column.name + "'")

    def update(self, value):
        """
        Updates the value for this instance, resetting the mutated flag.
        Raises ValueError if the value is not acceptable for the column.
        """
        self.__require_not_none(value)
        self.__require_valid(value)
        self.__value = value
        self.__origvalue = value

    @property
    def column(self):
        """The Column this instance holds data for"""
        return self.__column

    @property
    def owner(self):
        """The ColumnSet this instance belongs to"""
        return self.__owner

    @property
    def mutated(self):
        """True if the value differs from the one last loaded via update()"""
        return self.__value != self.__origvalue

    @property
    def value(self):
        """The current value"""
        return self.__value

    @value.setter
    def value(self, value):
        """
        Changes the current value without resetting the mutated flag.
        Raises ValueError for None (unless allowed) or invalid values and
        AttributeError if the column is not mutable.
        """
        self.__require_not_none(value)
        if not self.__column.mutable:
            raise AttributeError("Column '" + self.__column.name + "' is not" +
                " mutable")
        self.__require_valid(value)
        self.__value = value
class DbQueryError(Exception):
    """
    Raised when there is an error constructing a query
    """
    def __init__(self, msg):
        # Forward msg to Exception so that e.args is populated; without it the
        # exception cannot be unpickled and its repr loses the message.
        Exception.__init__(self, msg)
        self.message = msg

    def __str__(self):
        return self.message
class DbQuery(object):
    """
    Represents a base SQL Query to a database based upon some DbObjects
    All of the methods implemented here are valid on select, update, and
    delete statements.

    Subclasses build their statement in the `sql` property, which returns a
    (query_string, argument_list) pair using %s placeholders.
    """
    def __init__(self, execute_filter=None):
        """
        execute_filter: optional function(query, cursor) called after the
            statement is executed; its return value becomes the return value
            of execute()
        """
        self.__where = []
        self.__limit = None
        self.__execute_filter = execute_filter

    def where(self, expression):
        """
        Specify an expression to append to the WHERE clause. Expressions
        added by separate calls are ANDed together.
        """
        self.__where.append(expression)

    def limit(self, value=None):
        """Specify the limit to the query (None removes it)"""
        self.__limit = value

    def _where_sql(self):
        """
        Returns (sql, args) for the WHERE clause, or ("", []) when no
        expression was added. Each expression must support str() and have an
        `arguments` sequence.
        """
        if not self.__where:
            return "", []
        if len(self.__where) == 1:
            clause = self.__where[0]
            return " WHERE " + str(clause), list(clause.arguments)
        # Parenthesize each clause so an OR inside one clause cannot bind
        # across the AND that joins the clauses.
        text = " AND ".join(["(" + str(clause) + ")" for clause in self.__where])
        args = []
        for clause in self.__where:
            args.extend(clause.arguments)
        return " WHERE " + text, args

    def _limit_sql(self):
        """Returns the LIMIT clause, or "" when no limit is set"""
        if self.__limit is None:
            return ""
        # str() so that integer limits work
        return " LIMIT " + str(self.__limit)

    @property
    def sql(self):
        """(query, args) holding the WHERE and LIMIT clauses of this query"""
        where_sql, args = self._where_sql()
        return where_sql + self._limit_sql(), args

    def execute(self, cur):
        """
        Executes this query on the passed cursor and returns either the result
        of the filter function or the cursor if there is no filter function.
        """
        query, args = self.sql
        cur.execute(query, args)
        if self.__execute_filter:
            return self.__execute_filter(self, cur)
        return cur


class DbSelectQuery(DbQuery):
    """
    Creates a select query to a database based upon DbObjects
    """
    def __init__(self, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.__select = []
        self.__froms = []
        self.__joins = []
        self.__orderby = []

    def select(self, *columns):
        """Specify one or more columns (with .prefix and .name) to select"""
        self.__select.extend(columns)

    def from_table(self, dbo_type, prefix):
        """Specify a table (type with dbo_tablename) to select from, aliased as prefix"""
        self.__froms.append((dbo_type, prefix))

    def join(self, dbo_type, prefix, on):
        """Specify a table to join to; `on` is the join condition expression"""
        self.__joins.append((dbo_type, prefix, on))

    def orderby(self, *columns):
        """Specify one or more columns to order by"""
        self.__orderby.extend(columns)

    @staticmethod
    def _column_ref(col):
        """Renders a column as prefix.name"""
        return col.prefix + "." + col.name

    @property
    def sql(self):
        """
        Returns (query, args) for the full SELECT statement.
        Raises DbQueryError when no columns or no tables were specified.
        """
        if len(self.__select) == 0:
            raise DbQueryError("No selection in DbSelectQuery")
        if len(self.__froms) == 0:
            raise DbQueryError("No FROM clause in DbSelectQuery")
        args = []
        query = "SELECT " + ",".join(
            [self._column_ref(col) for col in self.__select])
        # Several tables belong to a single, comma-separated FROM clause.
        query += " FROM " + ", ".join(
            [dbo_type.dbo_tablename + " " + prefix
             for dbo_type, prefix in self.__froms])
        for dbo_type, prefix, on in self.__joins:
            query += " JOIN " + dbo_type.dbo_tablename + " " + prefix + \
                " ON " + str(on)
            # Literal values in the join condition are %s placeholders too;
            # they precede the WHERE placeholders in the statement.
            args.extend(getattr(on, "arguments", []))
        where_sql, where_args = self._where_sql()
        query += where_sql
        args.extend(where_args)
        if len(self.__orderby) > 0:
            query += " ORDER BY " + ",".join(
                [self._column_ref(col) for col in self.__orderby])
        # SQL requires LIMIT to come after ORDER BY.
        query += self._limit_sql()
        return query, args
class DbInsertQuery(DbQuery):
    """
    Creates an insert query to a database based upon DbObjects. This does not
    include any where or limit expressions.

    Only values whose column prefix equals this query's prefix are written;
    values for other prefixes are left out of the statement.
    """
    def __init__(self, dbo_type, prefix, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)
        self.__values = []

    def value(self, column, value):
        """Queue `value` to be inserted into the query column `column`"""
        self.__values.append((column, value))

    @property
    def sql(self):
        """
        Returns (query, args) for the INSERT statement.
        Raises DbQueryError if no values were queued.
        """
        if not self.__values:
            raise DbQueryError("No values in insert")
        tablename = self.table[0].dbo_tablename
        own_prefix = self.table[1]
        pairs = [(col, val) for col, val in self.__values
                 if col.prefix == own_prefix]
        column_list = ",".join([col.name for col, _ in pairs])
        placeholders = ",".join(["%s"] * len(pairs))
        statement = "INSERT INTO {table} (".format(table=tablename) + \
            column_list + ") VALUES (" + placeholders + ")"
        return statement, [val for _, val in pairs]
class DbUpdateQuery(DbQuery):
    """
    Creates an update query to a database based upon DbObjects
    """
    def __init__(self, dbo_type, prefix, execute_filter=None):
        """
        Initialize the update query
        dbo_type: table type to be updating
        prefix: Prefix the columns are known under
        execute_filter: optional function(query, cursor) whose return value
            becomes the result of execute()
        """
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)
        self.__updates = []

    def update(self, left, right):
        """
        Add a `left = right` assignment to the SET clause. Each side may be a
        DbQueryColumn (rendered as prefix.name) or a literal value (bound as a
        %s parameter).
        """
        self.__updates.append((left, right))

    @property
    def sql(self):
        """
        Returns (query, args) for the UPDATE statement, including any WHERE
        and LIMIT clauses. Raises DbQueryError if no assignment was added.
        """
        if len(self.__updates) == 0:
            raise DbQueryError("No update in DbUpdateQuery")
        query = "UPDATE " + self.table[0].dbo_tablename + " " + self.table[1]
        args = []

        def render(operand):
            # Columns are written inline; anything else becomes a parameter.
            if isinstance(operand, DbQueryColumn):
                return operand.prefix + "." + operand.name
            args.append(operand)
            return "%s"

        assignments = []
        for left, right in self.__updates:
            lhs = render(left)
            rhs = render(right)
            assignments.append(lhs + "=" + rhs)
        # Assignments in a SET clause must be comma-separated.
        query += " SET " + ", ".join(assignments)
        parent_query, parent_args = super(DbUpdateQuery, self).sql
        return query + parent_query, args + list(parent_args)
class DbDeleteQuery(DbQuery):
    """
    Creates a delete query for a database based on a DbObject.
    The rows to delete are restricted by the WHERE (and LIMIT) clauses
    configured through the DbQuery base class.
    """
    def __init__(self, dbo_type, prefix, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)

    @property
    def sql(self):
        """Returns (query, args) for the DELETE statement"""
        dbo_type, alias = self.table
        head = "DELETE FROM " + dbo_type.dbo_tablename + " " + alias
        tail, tail_args = super(DbDeleteQuery, self).sql
        return head + tail, list(tail_args)
class DbQueryExpression(object):
    """
    Query expression created from columns, literals, and operators.

    Subclasses provide str() (SQL text with %s placeholders) and `arguments`
    (the values for those placeholders, in order). The & and | operators
    combine two expressions into an AND / OR expression.
    """
    def __and__(self, other):
        return DbQueryConjunction(self, other)

    def __or__(self, other):
        return DbQueryDisjunction(self, other)

    def __str__(self):
        raise NotImplementedError

    @property
    def arguments(self):
        raise NotImplementedError


class DbQueryConjunction(DbQueryExpression):
    """
    Query expression joining together a left and right expression with an
    AND statement
    """
    def __init__(self, l, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.r = r

    def __str__(self):
        # Parenthesized so the expression keeps its meaning when nested.
        return "(" + str(self.l) + " AND " + str(self.r) + ")"

    @property
    def arguments(self):
        """Placeholder values of the left expression followed by the right's"""
        return list(self.l.arguments) + list(self.r.arguments)


class DbQueryDisjunction(DbQueryExpression):
    """
    Query expression joining together a left and right expression with an
    OR statement
    """
    def __init__(self, l, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.r = r

    def __str__(self):
        # Parenthesized so the expression keeps its meaning when nested.
        return "(" + str(self.l) + " OR " + str(self.r) + ")"

    @property
    def arguments(self):
        """Placeholder values of the left expression followed by the right's"""
        return list(self.l.arguments) + list(self.r.arguments)
class DbQueryColumnComparison(DbQueryExpression):
    """
    Query expression comparing a combination of a column and/or a value.

    Each operand renders as prefix.name for a DbQueryColumn (just name when
    the prefix is None), NULL for None, and a %s placeholder otherwise; only
    placeholder operands are reported in `arguments`.
    """
    def __init__(self, l, op, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.op = op
        self.r = r

    @staticmethod
    def _render_operand(operand):
        """SQL text for a single side of the comparison"""
        if isinstance(operand, DbQueryColumn):
            if operand.prefix is None:
                return operand.name
            return operand.prefix + "." + operand.name
        return "NULL" if operand is None else "%s"

    @staticmethod
    def _is_parameter(operand):
        """True when the operand is rendered as a %s placeholder"""
        return operand is not None and not isinstance(operand, DbQueryColumn)

    def __str__(self):
        left_text = self._render_operand(self.l)
        right_text = self._render_operand(self.r)
        return left_text + self.op + right_text

    @property
    def arguments(self):
        return [side for side in (self.l, self.r) if self._is_parameter(side)]
class DbQueryColumnSet(object):
    """
    Represents a set of columns attached to a specific DbObject type. This
    object dynamically builds itself based on a passed type: every Column
    attribute of the type, including ones inherited from base classes,
    becomes a DbQueryColumn. Each is reachable as an attribute under the same
    attribute name and via [] by column name. The columns attached to this
    set may be used in DbQueries.
    """
    def __init__(self, dbo_type, prefix):
        self.__columns = {}  # DbQueryColumn objects keyed by column name
        # Merge class dicts base-first so a subclass attribute replaces the
        # one it overrides and so inherited Columns are picked up.
        attributes = {}
        for klass in reversed(getattr(dbo_type, "__mro__", (dbo_type,))):
            attributes.update(vars(klass))
        for attr, obj in attributes.items():
            if isinstance(obj, Column):
                column = DbQueryColumn(dbo_type, prefix, obj.name)
                setattr(self, attr, column)
                self.__columns[obj.name] = column

    def __len__(self):
        return len(self.__columns)

    def __getitem__(self, key):
        return self.__columns[key]

    def __iter__(self):
        """Iterates over the column names"""
        return iter(self.__columns)
class DbQueryColumn(object):
    """
    Represents a Column object used in a DbQuery: column `column_name` of
    table type `dbo_type`, referenced under the table alias `prefix`.

    Comparison operators return DbQueryColumnComparison expressions rather
    than booleans; == / != against None render as IS / IS NOT.
    """
    def __init__(self, dbo_type, prefix, column_name):
        self.dbo_type, self.name, self.prefix = dbo_type, column_name, prefix

    def _compare(self, operator, other):
        """Builds the comparison expression `self <operator> other`"""
        return DbQueryColumnComparison(self, operator, other)

    def __lt__(self, other):
        return self._compare("<", other)

    def __le__(self, other):
        return self._compare("<=", other)

    def __eq__(self, other):
        return self._compare("=" if other is not None else " IS ", other)

    def __ne__(self, other):
        return self._compare("!=" if other is not None else " IS NOT ", other)

    def __gt__(self, other):
        return self._compare(">", other)

    def __ge__(self, other):
        return self._compare(">=", other)
class Column(object):
    """
    Column descriptor for a column.

    The descriptor only holds the column's configuration; each row's value
    lives in that row's ColumnInstance, which is looked up by column name.
    Subclasses must override validate().
    """
    def __init__(self, name, default=None, allow_none=False, mutable=True):
        """
        name: name of the database column this attribute maps to
        default: initial value for new instances
        allow_none: whether None (database NULL) is an acceptable value
        mutable: whether the value may be changed by assigning to the attribute
        """
        self.__name, self.__default = name, default
        self.__allow_none, self.__mutable = allow_none, mutable

    def validate(self, value):
        """Return True if value is acceptable; overridden by child classes."""
        raise NotImplementedError

    name = property(lambda self: self.__name,
                    doc="Name of the database column")
    allow_none = property(lambda self: self.__allow_none,
                          doc="Whether None values are allowed")
    mutable = property(lambda self: self.__mutable,
                       doc="Whether the value can be assigned")
    default = property(lambda self: self.__default,
                       doc="Initial value for new instances")

    def __instance_for(self, owner):
        """Returns owner's ColumnInstance for this column, checking its type."""
        if not isinstance(owner, ColumnSet):
            raise TypeError("Columns are only allowed on ColumnSets")
        return owner.get_column(self.name)

    def __get__(self, owner, ownertype=None):
        """
        Gets the value for this column for the passed owner; class-level
        access (owner is None) returns the descriptor itself.
        """
        return self if owner is None else self.__instance_for(owner).value

    def __set__(self, owner, value):
        """
        Sets the value for this column for the passed owner
        """
        self.__instance_for(owner).value = value
class StringColumn(Column):
    """Column whose values must be strings (or None when allow_none is set)."""
    try:
        _STRING_TYPES = basestring  # Python 2: covers str and unicode
    except NameError:
        _STRING_TYPES = str  # Python 3 has a single text type

    def validate(self, value):
        """
        Returns True for string values, and for None only when the column
        allows None. The debug printing this method used to do is removed.
        """
        if value is None:
            return bool(self.allow_none)
        return isinstance(value, self._STRING_TYPES)
class IntColumn(Column):
    """
    Column whose values must be integers (or None when allow_none is set).
    bool values pass too, since bool is a subclass of int.
    """
    try:
        _INTEGER_TYPES = (int, long)  # Python 2 has a separate long type
    except NameError:
        _INTEGER_TYPES = (int,)  # Python 3: int has arbitrary precision

    def validate(self, value):
        """Returns True for integers, and for None only if allowed."""
        if value is None:
            return bool(self.allow_none)
        return isinstance(value, self._INTEGER_TYPES)
class PasswordColumn(Column):
    """
    Column whose assigned values are passed through a salt function before
    being stored.

    Only assignment through the descriptor (obj.attr = value) is salted;
    values given to the DbObject constructor or loaded from the database go
    through ColumnInstance.update() and are stored as-is.
    """
    def __init__(self, name, salt_function, default=None, allow_none=False,
                 mutable=True):
        """
        Create a new password column which uses the specified salt function
        name, default, allow_none, mutable: as for Column
        salt_function: a function(owner, value) which returns the salted string
        """
        Column.__init__(self, name, default=default, allow_none=allow_none,
                        mutable=mutable)
        self.__salt_function = salt_function

    def validate(self, value):
        # Every value is accepted here; None is still rejected by
        # ColumnInstance unless allow_none is set.
        return True

    def __set__(self, owner, value):
        salt = self.__salt_function
        super(PasswordColumn, self).__set__(owner, salt(owner, value))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment