Created
February 22, 2010 17:04
-
-
Save sebmarkbage/311257 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;
using System.Text;
using System.Text.RegularExpressions;
/// <summary>
/// Extension methods on <see cref="IDbConnection"/> for executing SQL, projecting result rows
/// through delegates and building simple INSERT / UPDATE / DELETE statements from the public
/// properties of plain objects.
/// </summary>
/// <remarks>
/// Table names and property names are concatenated into the SQL text verbatim, so they must be
/// trusted identifiers; only the values are sent as parameters. The static mapping and projection
/// caches are plain dictionaries and no synchronization is performed here.
/// </remarks>
public static class DbConnection
{
    private const string paramPrefix = "@";

    // Type -> delegate that builds an instance of that type from column values.
    private static readonly Dictionary<Type, Delegate> typeMappings = new Dictionary<Type, Delegate>();

    // Projection method -> compiled Projection<T> reader for that method.
    private static readonly Dictionary<MethodInfo, object> compiledProjections = new Dictionary<MethodInfo, object>();

    /// <summary>Gets the prefix put in front of every generated parameter name.</summary>
    public static string ParamPrefix
    {
        get { return paramPrefix; }
    }

    /// <summary>
    /// Registers the delegate used to build values of type <typeparamref name="T"/> from columns.
    /// </summary>
    /// <param name="projection">Delegate whose parameters are read from consecutive columns.</param>
    /// <exception cref="ArgumentNullException"><paramref name="projection"/> is null.</exception>
    /// <exception cref="ArgumentException">A mapping for <typeparamref name="T"/> already exists.</exception>
    public static void AddMapping<T>(Delegate projection)
    {
        if (projection == null) throw new ArgumentNullException("projection");
        typeMappings.Add(typeof(T), projection);
        // Projections compiled earlier were built from the mappings known at that time;
        // drop them so the new mapping is picked up on next use.
        compiledProjections.Clear();
    }

    /// <summary>
    /// Runs <paramref name="method"/> against the connection. When the connection is already open the
    /// method is simply invoked; otherwise the connection is opened, the method runs inside a
    /// transaction that is committed on success and rolled back on failure, and the connection is closed.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="method"/> is null.</exception>
    public static void Work(this IDbConnection connection, Action method)
    {
        if (method == null) throw new ArgumentNullException("method");
        Work<object>(connection, () => { method(); return null; });
    }

    /// <summary>
    /// Same as <see cref="Work(IDbConnection, Action)"/> but returns the value produced by
    /// <paramref name="method"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="method"/> is null.</exception>
    public static T Work<T>(this IDbConnection connection, Func<T> method)
    {
        if (method == null) throw new ArgumentNullException("method");
        if (connection.State == ConnectionState.Open) return method();

        connection.Open();
        try
        {
            // BeginTransaction is inside the try so a failure there still closes the connection.
            using (var transaction = connection.BeginTransaction())
            {
                try
                {
                    var result = method();
                    transaction.Commit();
                    return result;
                }
                catch
                {
                    transaction.Rollback();
                    throw;
                }
            }
        }
        finally
        {
            connection.Close();
        }
    }

    /// <summary>
    /// Returns (building and caching on first use) a compiled reader that pulls the arguments of
    /// <paramref name="method"/> from an <see cref="IDataRecord"/> and invokes it on the given target.
    /// </summary>
    private static Projection<T> GetCompiledProjection<T>(MethodInfo method)
    {
        object cached;
        if (compiledProjections.TryGetValue(method, out cached)) return (Projection<T>)cached;

        var recordParam = Expression.Parameter(typeof(IDataRecord), "record");
        var targetParam = Expression.Parameter(typeof(object), "target");

        var parameters = method.GetParameters();
        var arguments = new Expression[parameters.Length];
        int index = 0; // next column ordinal to consume
        for (int i = 0; i < parameters.Length; i++)
            arguments[i] = GetMappingExpression(parameters[i].ParameterType, recordParam, ref index, i == parameters.Length - 1);

        // Expression.Call requires a null instance for static methods.
        Expression instance = null;
        if (!method.IsStatic) instance = Expression.TypeAs(targetParam, method.DeclaringType);

        var callProjector = Expression.Call(instance, method, arguments);
        var compiled = Expression.Lambda<Projection<T>>(callProjector, recordParam, targetParam).Compile();
        compiledProjections[method] = compiled;
        return compiled;
    }

    /// <summary>Builds the call to a registered mapper, passing no instance when its method is static.</summary>
    private static Expression CallMapper(Delegate mapper, Expression[] arguments)
    {
        Expression instance = null;
        if (!mapper.Method.IsStatic) instance = Expression.Constant(mapper.Target);
        return Expression.Call(instance, mapper.Method, arguments);
    }

    /// <summary>
    /// Builds the expression that produces one projection argument of the given type, consuming
    /// columns starting at <paramref name="index"/> and advancing it.
    /// </summary>
    /// <param name="isLast">True when this is the final argument; only then may a general array be read.</param>
    private static Expression GetMappingExpression(Type type, ParameterExpression recordParam, ref int index, bool isLast)
    {
        Delegate mapper;
        if (typeMappings.TryGetValue(type, out mapper))
        {
            var parameters = mapper.Method.GetParameters();
            var arguments = new Expression[parameters.Length];
            for (int i = 0; i < parameters.Length; i++)
                arguments[i] = GetMappingExpression(parameters[i].ParameterType, recordParam, ref index, isLast && i == parameters.Length - 1);
            return CallMapper(mapper, arguments);
        }

        if (!type.IsArray)
            return GetRecordReadingExpression(type, recordParam, Expression.Constant(index++));

        var elementType = type.GetElementType();

        // byte[] is read as a single column value.
        if (elementType == typeof(byte))
            return GetRecordReadingExpression(type, recordParam, Expression.Constant(index++));

        // char[] is read as a string column; a database NULL yields a null array.
        if (elementType == typeof(char))
        {
            var recordType = typeof(IDataRecord);
            var ordinal = Expression.Constant(index++);
            // ToCharArray has two overloads, so the parameterless one must be requested explicitly.
            var toCharArray = typeof(string).GetMethod("ToCharArray", Type.EmptyTypes);
            return Expression.Condition(
                Expression.Call(recordParam, recordType.GetMethod("IsDBNull"), ordinal),
                Expression.Constant(null, type),
                Expression.Call(Expression.Call(recordParam, recordType.GetMethod("GetString"), ordinal), toCharArray)
            );
        }

        if (!isLast) throw new NotSupportedException("An array cannot be used as in parameter in a projection unless it's the very last parameter.");
        if (elementType.IsArray) throw new NotSupportedException("Multidimensional arrays are not supported.");

        // Any other array swallows all remaining columns, indexOffset columns per element.
        var indexOffset = 0;
        var indexParam = Expression.Parameter(typeof(int), "index");
        var elementMap = GetMappingArrayElementExpression(elementType, recordParam, indexParam, ref indexOffset);
        if (indexOffset == 0)
            throw new NotSupportedException("Array elements must be mapped from at least one column.");

        var readArray = typeof(DbConnection).GetMethod("ReadArray", BindingFlags.NonPublic | BindingFlags.Static).MakeGenericMethod(elementType);
        return Expression.Call(
            readArray, recordParam, Expression.Constant(index), Expression.Constant(indexOffset),
            Expression.Lambda(elementMap, recordParam, indexParam)
        );
    }

    /// <summary>
    /// Builds the expression that produces one array element, reading columns relative to
    /// <paramref name="indexParam"/>; <paramref name="indexOffset"/> counts the columns used per element.
    /// </summary>
    private static Expression GetMappingArrayElementExpression(Type type, ParameterExpression recordParam, ParameterExpression indexParam, ref int indexOffset)
    {
        Delegate mapper;
        if (!typeMappings.TryGetValue(type, out mapper))
            return GetRecordReadingExpression(type, recordParam, Expression.Add(indexParam, Expression.Constant(indexOffset++)));

        var parameters = mapper.Method.GetParameters();
        var arguments = new Expression[parameters.Length];
        for (int i = 0; i < parameters.Length; i++)
            arguments[i] = GetMappingArrayElementExpression(parameters[i].ParameterType, recordParam, indexParam, ref indexOffset);
        return CallMapper(mapper, arguments);
    }

    /// <summary>Returns the name of the typed IDataRecord getter for the type, or null when there is none.</summary>
    private static string GetTypedGetterName(Type type)
    {
        if (type == typeof(String)) return "GetString";
        if (type == typeof(byte)) return "GetByte";
        if (type == typeof(char)) return "GetChar";
        if (type == typeof(DateTime)) return "GetDateTime";
        if (type == typeof(Decimal)) return "GetDecimal";
        if (type == typeof(Double)) return "GetDouble";
        if (type == typeof(Single)) return "GetFloat";
        if (type == typeof(Guid)) return "GetGuid";
        if (type == typeof(Int16)) return "GetInt16";
        if (type == typeof(Int32)) return "GetInt32";
        if (type == typeof(Int64)) return "GetInt64";
        if (type == typeof(bool)) return "GetBoolean";
        return null;
    }

    /// <summary>
    /// Builds an expression that reads the column at <paramref name="index"/> as <paramref name="type"/>,
    /// yielding the type's default value (null for reference types) when the column is DBNull.
    /// </summary>
    private static Expression GetRecordReadingExpression(Type type, ParameterExpression recordParam, Expression index)
    {
        var recordType = typeof(IDataRecord);
        var getterName = GetTypedGetterName(type);

        Expression readValue;
        if (getterName != null)
        {
            readValue = Expression.Call(recordParam, recordType.GetMethod(getterName), index);
        }
        else
        {
            var rawValue = Expression.Call(recordParam, recordType.GetMethod("GetValue"), index);
            // TypeAs is only valid for reference and nullable types; other value types
            // (enums, unsigned integers, ...) have to be unboxed with Convert instead.
            if (type.IsValueType && Nullable.GetUnderlyingType(type) == null)
                readValue = Expression.Convert(rawValue, type);
            else
                readValue = Expression.TypeAs(rawValue, type);
        }

        Expression valueWhenNull;
        if (type.IsValueType)
            valueWhenNull = Expression.New(type);
        else
            valueWhenNull = Expression.Constant(null, type);

        return Expression.Condition(
            Expression.Call(recordParam, recordType.GetMethod("IsDBNull"), index),
            valueWhenNull,
            readValue
        );
    }

    /// <summary>
    /// Reads the columns from <paramref name="firstIndex"/> to the end of the record as an array,
    /// handing <paramref name="projection"/> the first ordinal of each group of
    /// <paramref name="columns"/> columns.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="columns"/> is not positive.</exception>
    private static T[] ReadArray<T>(IDataRecord record, int firstIndex, int columns, Func<IDataRecord, int, T> projection)
    {
        if (columns <= 0) throw new ArgumentOutOfRangeException("columns", "Each array element must span at least one column.");

        // Clamp so a record with fewer columns than firstIndex yields an empty array.
        var length = Math.Max(0, (record.FieldCount - firstIndex) / columns);
        var result = new T[length];
        for (int i = 0; i < result.Length; i++)
            result[i] = projection(record, firstIndex + (i * columns));
        return result;
    }

    /// <summary>Executes a statement without parameters and returns the affected row count.</summary>
    public static int Exec(this IDbConnection connection, string sql)
    {
        return Exec(connection, sql, null);
    }

    /// <summary>
    /// Executes a statement and returns the affected row count. Public readable properties of
    /// <paramref name="parameters"/> (may be null) become command parameters. A closed connection
    /// is opened for the call and closed again afterwards.
    /// </summary>
    public static int Exec(this IDbConnection connection, string sql, object parameters)
    {
        using (var cmd = GetCommand(connection, sql, parameters))
        {
            var openedHere = connection.State != ConnectionState.Open;
            if (openedHere) connection.Open();
            try
            {
                return cmd.ExecuteNonQuery();
            }
            finally
            {
                if (openedHere) connection.Close();
            }
        }
    }

    /// <summary>
    /// Runs a query and lazily yields one projected value per row. Nothing is executed until the
    /// result is enumerated; the reader (and a connection opened here) stay open while enumerating
    /// and are released when enumeration finishes or the enumerator is disposed.
    /// </summary>
    /// <param name="projection">Delegate whose arguments are read from the row; null means read a single value of type T.</param>
    public static IEnumerable<T> Query<T>(this IDbConnection connection, string sql, object parameters, Delegate projection)
    {
        if (projection == null) projection = (Func<T, T>)(i => i);
        var proj = GetCompiledProjection<T>(projection.Method);
        var target = projection.Target;

        using (var cmd = GetCommand(connection, sql, parameters))
        {
            var openedHere = connection.State != ConnectionState.Open;
            if (openedHere) connection.Open();
            try
            {
                using (var reader = cmd.ExecuteReader())
                {
                    while (reader.Read()) yield return proj(reader, target);
                }
            }
            finally
            {
                if (openedHere) connection.Close();
            }
        }
    }

    /// <summary>Registers a one-argument mapping for <typeparamref name="T"/>.</summary>
    public static void AddMapping<T1, T>(Func<T1, T> projection)
    {
        AddMapping<T>((Delegate)projection);
    }

    /// <summary>Registers a two-argument mapping for <typeparamref name="T"/>.</summary>
    public static void AddMapping<T1, T2, T>(Func<T1, T2, T> projection)
    {
        AddMapping<T>((Delegate)projection);
    }

    /// <summary>Registers a three-argument mapping for <typeparamref name="T"/>.</summary>
    public static void AddMapping<T1, T2, T3, T>(Func<T1, T2, T3, T> projection)
    {
        AddMapping<T>((Delegate)projection);
    }

    /// <summary>Registers a four-argument mapping for <typeparamref name="T"/>.</summary>
    public static void AddMapping<T1, T2, T3, T4, T>(Func<T1, T2, T3, T4, T> projection)
    {
        AddMapping<T>((Delegate)projection);
    }

    /// <summary>Registers a five-argument mapping for <typeparamref name="T"/>.</summary>
    public static void AddMapping<T1, T2, T3, T4, T5, T>(Func<T1, T2, T3, T4, T5, T> projection)
    {
        AddMapping<T>((Delegate)projection);
    }

    /// <summary>Registers a six-argument mapping for <typeparamref name="T"/>.</summary>
    public static void AddMapping<T1, T2, T3, T4, T5, T6, T>(Func<T1, T2, T3, T4, T5, T6, T> projection)
    {
        AddMapping<T>((Delegate)projection);
    }

    /// <summary>Queries with parameters, reading each row as a single <typeparamref name="T"/>.</summary>
    public static IEnumerable<T> Query<T>(this IDbConnection connection, string sql, object parameters)
    {
        return Query<T>(connection, sql, parameters, (Delegate)null);
    }

    /// <summary>Queries with parameters, projecting each row through a one-argument function.</summary>
    public static IEnumerable<T> Query<T1, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T> projection)
    {
        return Query<T>(connection, sql, parameters, (Delegate)projection);
    }

    /// <summary>Queries with parameters, projecting each row through a two-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T> projection)
    {
        return Query<T>(connection, sql, parameters, (Delegate)projection);
    }

    /// <summary>Queries with parameters, projecting each row through a three-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T3, T> projection)
    {
        return Query<T>(connection, sql, parameters, (Delegate)projection);
    }

    /// <summary>Queries with parameters, projecting each row through a four-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T3, T4, T> projection)
    {
        return Query<T>(connection, sql, parameters, (Delegate)projection);
    }

    /// <summary>Queries with parameters, projecting each row through a five-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T5, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T3, T4, T5, T> projection)
    {
        return Query<T>(connection, sql, parameters, (Delegate)projection);
    }

    /// <summary>Queries with parameters, projecting each row through a six-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T5, T6, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T3, T4, T5, T6, T> projection)
    {
        return Query<T>(connection, sql, parameters, (Delegate)projection);
    }

    /// <summary>Queries without parameters, reading each row as a single <typeparamref name="T"/>.</summary>
    public static IEnumerable<T> Query<T>(this IDbConnection connection, string sql)
    {
        return Query<T>(connection, sql, null, (Delegate)null);
    }

    /// <summary>Queries without parameters, projecting each row through a one-argument function.</summary>
    public static IEnumerable<T> Query<T1, T>(this IDbConnection connection, string sql, Func<T1, T> projection)
    {
        return Query<T>(connection, sql, null, (Delegate)projection);
    }

    /// <summary>Queries without parameters, projecting each row through a two-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T>(this IDbConnection connection, string sql, Func<T1, T2, T> projection)
    {
        return Query<T>(connection, sql, null, (Delegate)projection);
    }

    /// <summary>Queries without parameters, projecting each row through a three-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T>(this IDbConnection connection, string sql, Func<T1, T2, T3, T> projection)
    {
        return Query<T>(connection, sql, null, (Delegate)projection);
    }

    /// <summary>Queries without parameters, projecting each row through a four-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T>(this IDbConnection connection, string sql, Func<T1, T2, T3, T4, T> projection)
    {
        return Query<T>(connection, sql, null, (Delegate)projection);
    }

    /// <summary>Queries without parameters, projecting each row through a five-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T5, T>(this IDbConnection connection, string sql, Func<T1, T2, T3, T4, T5, T> projection)
    {
        return Query<T>(connection, sql, null, (Delegate)projection);
    }

    /// <summary>Queries without parameters, projecting each row through a six-argument function.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T5, T6, T>(this IDbConnection connection, string sql, Func<T1, T2, T3, T4, T5, T6, T> projection)
    {
        return Query<T>(connection, sql, null, (Delegate)projection);
    }

    /// <summary>Returns the public instance properties of a type that can be read without index arguments.</summary>
    private static PropertyInfo[] GetReadableProperties(Type type)
    {
        return type.GetProperties(BindingFlags.Instance | BindingFlags.Public)
            .Where(p => p.CanRead && p.GetIndexParameters().Length == 0)
            .ToArray();
    }

    /// <summary>Creates a parameter on the command, sending a CLR null as DBNull.Value.</summary>
    private static void AddParameter(IDbCommand cmd, string name, object value)
    {
        var param = cmd.CreateParameter();
        param.ParameterName = name;
        param.Value = value ?? DBNull.Value;
        cmd.Parameters.Add(param);
    }

    /// <summary>
    /// Appends "Name = @tagName" for every property, joined by <paramref name="separator"/>,
    /// and adds the matching parameters with the property values read from <paramref name="source"/>.
    /// </summary>
    private static void AppendAssignments(StringBuilder sql, IDbCommand cmd, PropertyInfo[] properties, object source, string tag, string separator)
    {
        for (int i = 0; i < properties.Length; i++)
        {
            if (i > 0) sql.Append(separator);
            var parameterName = paramPrefix + tag + properties[i].Name;
            sql.Append(properties[i].Name).Append(" = ").Append(parameterName);
            AddParameter(cmd, parameterName, properties[i].GetValue(source, null));
        }
    }

    /// <summary>Inserts the given rows into <paramref name="table"/> with a single statement.</summary>
    public static void Insert<T>(this IDbConnection connection, string table, params T[] values)
    {
        Insert<T>(connection, table, (IEnumerable<T>)values);
    }

    /// <summary>
    /// Inserts the given rows into <paramref name="table"/> with a single multi-row INSERT whose
    /// columns are the readable public properties of <typeparamref name="T"/>. Does nothing when
    /// there are no such properties or no rows.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="table"/> or <paramref name="values"/> is null.</exception>
    /// <exception cref="DataException">The affected row count differs from the number of rows supplied.</exception>
    public static void Insert<T>(this IDbConnection connection, string table, IEnumerable<T> values)
    {
        if (table == null) throw new ArgumentNullException("table");
        if (values == null) throw new ArgumentNullException("values");

        // The same filtered list drives both the column list and the value list,
        // so the two can never disagree on which properties are included.
        var properties = GetReadableProperties(typeof(T));
        if (properties.Length == 0) return;

        var openedHere = connection.State != ConnectionState.Open;
        if (openedHere) connection.Open();
        try
        {
            using (var cmd = connection.CreateCommand())
            {
                var sb = new StringBuilder("INSERT INTO ");
                sb.Append(table).Append(" (");
                for (int p = 0; p < properties.Length; p++)
                {
                    if (p > 0) sb.Append(", ");
                    sb.Append(properties[p].Name);
                }
                sb.Append(") VALUES (");

                int rows = 0;
                foreach (var value in values)
                {
                    if (rows > 0) sb.Append("),(");
                    for (int p = 0; p < properties.Length; p++)
                    {
                        if (p > 0) sb.Append(", ");
                        // The row number keeps parameter names unique across rows.
                        var parameterName = paramPrefix + "i" + rows.ToString() + properties[p].Name;
                        sb.Append(parameterName);
                        AddParameter(cmd, parameterName, properties[p].GetValue(value, null));
                    }
                    rows++;
                }
                if (rows == 0) return;
                sb.Append(")");

                cmd.CommandText = sb.ToString();
                int result = cmd.ExecuteNonQuery();
                if (result != rows) throw new DataException("Not all rows was inserted or the database caused unexpected updates.");
            }
        }
        finally
        {
            if (openedHere) connection.Close();
        }
    }

    /// <summary>Inserts a single row into <paramref name="table"/>.</summary>
    public static void InsertSingle<T>(this IDbConnection connection, string table, T values)
    {
        Insert<T>(connection, table, (IEnumerable<T>)new[] { values });
    }

    /// <summary>
    /// Updates rows of <paramref name="table"/>: the readable properties of <paramref name="set"/>
    /// become the SET assignments and those of <paramref name="where"/> become equality conditions
    /// joined with AND. Returns the affected row count, or 0 without executing anything when
    /// <paramref name="set"/> has no readable properties.
    /// </summary>
    /// <exception cref="ArgumentNullException">Any of table, set or where is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="where"/> has no readable properties.</exception>
    public static int Update(this IDbConnection connection, string table, object set, object where)
    {
        if (table == null) throw new ArgumentNullException("table");
        if (set == null) throw new ArgumentNullException("set");
        if (where == null) throw new ArgumentNullException("where");

        var setProperties = GetReadableProperties(set.GetType());
        if (setProperties.Length == 0) return 0;

        // An empty condition list would leave a dangling WHERE in the statement.
        var whereProperties = GetReadableProperties(where.GetType());
        if (whereProperties.Length == 0)
            throw new ArgumentException("At least one readable property is required to build the WHERE clause.", "where");

        var openedHere = connection.State != ConnectionState.Open;
        if (openedHere) connection.Open();
        try
        {
            using (var cmd = connection.CreateCommand())
            {
                var sb = new StringBuilder("UPDATE ");
                sb.Append(table).Append(" SET ");
                AppendAssignments(sb, cmd, setProperties, set, "set", ", ");
                sb.Append(" WHERE ");
                AppendAssignments(sb, cmd, whereProperties, where, "where", " AND ");

                cmd.CommandText = sb.ToString();
                return cmd.ExecuteNonQuery();
            }
        }
        finally
        {
            if (openedHere) connection.Close();
        }
    }

    /// <summary>
    /// Deletes the rows of <paramref name="table"/> matching every readable property of
    /// <paramref name="where"/> (equality, joined with AND) and returns the affected row count.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="table"/> or <paramref name="where"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="where"/> has no readable properties.</exception>
    public static int Delete(this IDbConnection connection, string table, object where)
    {
        if (table == null) throw new ArgumentNullException("table");
        if (where == null) throw new ArgumentNullException("where");

        // An empty condition list would leave a dangling WHERE in the statement.
        var whereProperties = GetReadableProperties(where.GetType());
        if (whereProperties.Length == 0)
            throw new ArgumentException("At least one readable property is required to build the WHERE clause.", "where");

        var openedHere = connection.State != ConnectionState.Open;
        if (openedHere) connection.Open();
        try
        {
            using (var cmd = connection.CreateCommand())
            {
                var sb = new StringBuilder("DELETE FROM ");
                sb.Append(table).Append(" WHERE ");
                AppendAssignments(sb, cmd, whereProperties, where, "where", " AND ");

                cmd.CommandText = sb.ToString();
                return cmd.ExecuteNonQuery();
            }
        }
        finally
        {
            if (openedHere) connection.Close();
        }
    }

    /// <summary>Throws unless exactly one row was affected by an update-style statement.</summary>
    private static void EnsureSingleRowUpdated(int affectedRows)
    {
        if (affectedRows == 0)
            throw new DataException("No row was updated. Perhaps the row has been removed.");
        if (affectedRows > 1)
            throw new DataException("More than one row was updated. Only a single updated row was expected.");
    }

    /// <summary>Runs <see cref="Update"/> and requires that exactly one row was affected.</summary>
    /// <exception cref="DataException">Zero rows or more than one row were updated.</exception>
    public static void UpdateSingle(this IDbConnection connection, string table, object set, object where)
    {
        EnsureSingleRowUpdated(Update(connection, table, set, where));
    }

    /// <summary>Runs <see cref="Delete"/> and requires that at most one row was deleted.</summary>
    /// <exception cref="DataException">More than one row was deleted.</exception>
    public static void DeleteSingle(this IDbConnection connection, string table, object where)
    {
        int result = Delete(connection, table, where);
        if (result > 1) throw new DataException("More than one row was deleted. Only a single row deletion was expected.");
    }

    /// <summary>Executes a statement without parameters and requires that exactly one row was affected.</summary>
    public static void ExecSingle(this IDbConnection connection, string sql)
    {
        ExecSingle(connection, sql, null);
    }

    /// <summary>Executes a statement and requires that exactly one row was affected.</summary>
    /// <exception cref="DataException">Zero rows or more than one row were affected.</exception>
    public static void ExecSingle(this IDbConnection connection, string sql, object parameters)
    {
        EnsureSingleRowUpdated(Exec(connection, sql, parameters));
    }

    /// <summary>
    /// Creates a command for <paramref name="sql"/> and adds one parameter per readable public
    /// property of <paramref name="parameters"/> (may be null). Array properties other than byte[]
    /// are expanded into a parenthesised list of element parameters inside the command text.
    /// </summary>
    /// <exception cref="ArgumentException">An array property that must be expanded is null.</exception>
    private static IDbCommand GetCommand(IDbConnection connection, string sql, object parameters)
    {
        var cmd = connection.CreateCommand();
        cmd.CommandText = sql;
        if (parameters == null) return cmd;

        foreach (var property in GetReadableProperties(parameters.GetType()))
        {
            var parameterName = paramPrefix + property.Name;
            var value = property.GetValue(parameters, null);

            if (property.PropertyType.IsArray && property.PropertyType != typeof(byte[]))
            {
                if (value == null)
                    throw new ArgumentException("Array parameter '" + property.Name + "' must not be null.", "parameters");
                AddArrayProperties(parameterName, cmd, (Array)value);
            }
            else
            {
                AddParameter(cmd, parameterName, value);
            }
        }
        return cmd;
    }

    /// <summary>
    /// Adds one parameter per array element (named paramName_0, paramName_1, ...) and replaces
    /// each occurrence of <paramref name="paramName"/> in the command text with the parenthesised
    /// list of those names.
    /// </summary>
    private static void AddArrayProperties(string paramName, IDbCommand cmd, Array values)
    {
        var paramList = new StringBuilder("(");
        int i = 0;
        foreach (var value in values)
        {
            // '_' rather than '-': a dash inside a parameter name is parsed by SQL as a minus sign.
            var elementName = paramName + "_" + i.ToString();
            AddParameter(cmd, elementName, value);
            if (i > 0) paramList.Append(", ");
            paramList.Append(elementName);
            i++;
        }
        paramList.Append(")");

        var expanded = paramList.ToString();
        // The lookahead keeps longer parameter names that merely start with paramName intact.
        cmd.CommandText = Regex.Replace(
            cmd.CommandText,
            Regex.Escape(paramName) + @"(?![\w@#$])",
            match => expanded);
    }

    // Compiled row reader: pulls the projection's arguments from the record and invokes it on target.
    private delegate T Projection<T>(IDataRecord record, object target);

    // Function delegates with five to nine arguments, declared here under the name Func.
    public delegate TResult Func<T1, T2, T3, T4, T5, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5);
    public delegate TResult Func<T1, T2, T3, T4, T5, T6, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6);
    public delegate TResult Func<T1, T2, T3, T4, T5, T6, T7, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, T7 arg7);
    public delegate TResult Func<T1, T2, T3, T4, T5, T6, T7, T8, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, T7 arg7, T8 arg8);
    public delegate TResult Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, T7 arg7, T8 arg8, T9 arg9);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment