Skip to content

Instantly share code, notes, and snippets.

@sebmarkbage
Created February 22, 2010 17:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sebmarkbage/311257 to your computer and use it in GitHub Desktop.
Save sebmarkbage/311257 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Data;
using System.Reflection;
using System.Reflection.Emit;
using System.Linq.Expressions;
using System.Text;
/// <summary>
/// Extension methods on <see cref="IDbConnection"/> for running SQL, projecting result rows
/// through delegates, and building simple INSERT/UPDATE/DELETE statements from the public
/// properties of plain objects.
/// </summary>
/// <remarks>
/// Each helper opens the connection if it is not already open and closes it again afterwards,
/// so callers may pass either an open or a closed connection. Values always travel as command
/// parameters, but table names and property (column) names are written into the SQL text
/// verbatim and must therefore never come from untrusted input.
/// </remarks>
public static class DbConnection
{
    private const string paramPrefix = "@";

    // Guards typeMappings and compiledProjections; both are shared by every thread calling Query.
    private static readonly object syncRoot = new object();
    private static readonly Dictionary<Type, Delegate> typeMappings = new Dictionary<Type, Delegate>();
    private static readonly Dictionary<MethodInfo, object> compiledProjections = new Dictionary<MethodInfo, object>();

    // Strongly-typed IDataRecord getter used for each column type that has one.
    private static readonly Dictionary<Type, string> recordGetters = new Dictionary<Type, string>
    {
        { typeof(String), "GetString" },
        { typeof(byte), "GetByte" },
        { typeof(char), "GetChar" },
        { typeof(DateTime), "GetDateTime" },
        { typeof(Decimal), "GetDecimal" },
        { typeof(Double), "GetDouble" },
        { typeof(Single), "GetFloat" },
        { typeof(Guid), "GetGuid" },
        { typeof(Int16), "GetInt16" },
        { typeof(Int32), "GetInt32" },
        { typeof(Int64), "GetInt64" },
        { typeof(bool), "GetBoolean" }
    };

    // Transactions started by Work on the current thread, keyed by connection, so that every
    // command created inside a unit of work is enlisted in it (providers such as SqlClient refuse
    // to run a command on a connection with a pending local transaction unless
    // IDbCommand.Transaction is set). Created lazily: a [ThreadStatic] initializer would only run
    // on the first thread.
    [ThreadStatic]
    private static Dictionary<IDbConnection, IDbTransaction> activeTransactions;

    /// <summary>Gets the prefix prepended to every parameter name generated by this class.</summary>
    public static string ParamPrefix
    {
        get { return paramPrefix; }
    }

    /// <summary>
    /// Registers <paramref name="projection"/> as the way to build a <typeparamref name="T"/> from
    /// record columns. Whenever a projection (or another mapper) has a parameter of type
    /// <typeparamref name="T"/>, this mapper's own parameters are read from the next columns and
    /// its return value is passed in instead.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="projection"/> is null.</exception>
    /// <exception cref="ArgumentException">A mapping for <typeparamref name="T"/> already exists.</exception>
    public static void AddMapping<T>(Delegate projection)
    {
        if (projection == null) throw new ArgumentNullException("projection");
        lock (syncRoot)
        {
            typeMappings.Add(typeof(T), projection);
            // Projections compiled before this call were built without the new mapping.
            compiledProjections.Clear();
        }
    }

    /// <summary>
    /// Runs <paramref name="method"/> as a unit of work; see <see cref="Work{T}(IDbConnection, Func{T})"/>.
    /// </summary>
    public static void Work(this IDbConnection connection, Action method)
    {
        Work<object>(connection, () => { method(); return null; });
    }

    /// <summary>
    /// Runs <paramref name="method"/> as a unit of work. If the connection is closed, it is opened,
    /// a transaction is started, and every command this class creates on the connection inside
    /// <paramref name="method"/> is enlisted in it. The transaction is committed when
    /// <paramref name="method"/> returns and rolled back (then the exception rethrown) when it
    /// throws; the connection is closed afterwards. If the connection is already open,
    /// <paramref name="method"/> simply runs within whatever scope the caller established.
    /// </summary>
    /// <returns>The value returned by <paramref name="method"/>.</returns>
    public static T Work<T>(this IDbConnection connection, Func<T> method)
    {
        if (connection.State == ConnectionState.Open) return method();
        connection.Open();
        try
        {
            // Inside the outer try so the connection is closed even if BeginTransaction fails.
            using (var transaction = connection.BeginTransaction())
            {
                if (activeTransactions == null)
                    activeTransactions = new Dictionary<IDbConnection, IDbTransaction>();
                activeTransactions[connection] = transaction;
                try
                {
                    var result = method();
                    transaction.Commit();
                    return result;
                }
                catch
                {
                    transaction.Rollback();
                    throw;
                }
                finally
                {
                    activeTransactions.Remove(connection);
                }
            }
        }
        finally
        {
            connection.Close();
        }
    }

    // Runs action with the connection open, opening and closing it only if it was not open already.
    private static TResult WithOpenConnection<TResult>(IDbConnection connection, Func<TResult> action)
    {
        if (connection.State == ConnectionState.Open) return action();
        connection.Open();
        try
        {
            return action();
        }
        finally
        {
            connection.Close();
        }
    }

    // Returns the cached compiled projection for method, compiling and caching it on first use.
    private static Projection<T> GetCompiledProjection<T>(MethodInfo method)
    {
        lock (syncRoot)
        {
            object cached;
            if (compiledProjections.TryGetValue(method, out cached)) return (Projection<T>)cached;
            var projection = CompileProjection<T>(method);
            compiledProjections.Add(method, projection);
            return projection;
        }
    }

    // Builds (record, target) => target.method(arg0, arg1, ...) where each argument is read from
    // the record's columns in parameter order. Caller must hold syncRoot (reads typeMappings).
    private static Projection<T> CompileProjection<T>(MethodInfo method)
    {
        var recordParam = Expression.Parameter(typeof(IDataRecord), "record");
        var targetParam = Expression.Parameter(typeof(object), "target");
        var parameters = method.GetParameters();
        var arguments = new Expression[parameters.Length];
        int index = 0;
        for (int i = 0; i < parameters.Length; i++)
            arguments[i] = GetMappingExpression(parameters[i].ParameterType, recordParam, ref index, i == parameters.Length - 1);
        // Static methods (e.g. method groups) must be called without an instance expression.
        Expression instance = method.IsStatic ? null : Expression.Convert(targetParam, method.DeclaringType);
        var body = Expression.Call(instance, method, arguments);
        return Expression.Lambda<Projection<T>>(body, recordParam, targetParam).Compile();
    }

    // Produces an expression that reads a value of `type` starting at column `index`, advancing
    // `index` past every column consumed. A trailing array parameter consumes all remaining columns.
    private static Expression GetMappingExpression(Type type, ParameterExpression recordParam, ref int index, bool isLast)
    {
        Delegate mapper;
        if (typeMappings.TryGetValue(type, out mapper))
        {
            var mapperParameters = mapper.Method.GetParameters();
            var mapperArguments = new Expression[mapperParameters.Length];
            for (int i = 0; i < mapperParameters.Length; i++)
                mapperArguments[i] = GetMappingExpression(mapperParameters[i].ParameterType, recordParam, ref index, isLast && i == mapperParameters.Length - 1);
            return CallMapper(mapper, mapperArguments);
        }
        if (!type.IsArray)
            return GetRecordReadingExpression(type, recordParam, Expression.Constant(index++));

        var elementType = type.GetElementType();
        // byte[] and char[] are single-column values (binary / text), not row-spanning arrays.
        if (elementType == typeof(byte))
            return GetRecordReadingExpression(type, recordParam, Expression.Constant(index++));
        if (elementType == typeof(char))
            return GetCharArrayReadingExpression(recordParam, Expression.Constant(index++));
        if (!isLast) throw new NotSupportedException("An array cannot be used as in parameter in a projection unless it's the very last parameter.");
        if (elementType.IsArray || type.GetArrayRank() != 1) throw new NotSupportedException("Multidimensional arrays are not supported.");

        var columnsPerElement = 0;
        var indexParam = Expression.Parameter(typeof(int), "index");
        var elementMap = GetMappingArrayElementExpression(elementType, recordParam, indexParam, ref columnsPerElement);
        // ReadArray divides the remaining columns by this count.
        if (columnsPerElement == 0) throw new NotSupportedException("Array elements must be built from at least one column.");
        var readArray = typeof(DbConnection).GetMethod("ReadArray", BindingFlags.NonPublic | BindingFlags.Static).MakeGenericMethod(elementType);
        return Expression.Call(
            readArray, recordParam, Expression.Constant(index), Expression.Constant(columnsPerElement),
            Expression.Lambda(elementMap, recordParam, indexParam));
    }

    // Like GetMappingExpression, but column positions are relative to indexParam (the first column
    // of the current array element); indexOffset counts the columns one element consumes.
    private static Expression GetMappingArrayElementExpression(Type type, ParameterExpression recordParam, ParameterExpression indexParam, ref int indexOffset)
    {
        Delegate mapper;
        if (typeMappings.TryGetValue(type, out mapper))
        {
            var mapperParameters = mapper.Method.GetParameters();
            var mapperArguments = new Expression[mapperParameters.Length];
            for (int i = 0; i < mapperParameters.Length; i++)
                mapperArguments[i] = GetMappingArrayElementExpression(mapperParameters[i].ParameterType, recordParam, indexParam, ref indexOffset);
            return CallMapper(mapper, mapperArguments);
        }
        return GetRecordReadingExpression(type, recordParam, Expression.Add(indexParam, Expression.Constant(indexOffset++)));
    }

    // Calls a registered mapper directly, passing no instance when its method is static.
    private static Expression CallMapper(Delegate mapper, Expression[] arguments)
    {
        Expression instance = mapper.Method.IsStatic ? null : Expression.Constant(mapper.Target);
        return Expression.Call(instance, mapper.Method, arguments);
    }

    // Reads one column as `type`: default(type) / null when the column is DBNull, otherwise the
    // strongly-typed getter when one exists, else GetValue converted to `type`.
    private static Expression GetRecordReadingExpression(Type type, ParameterExpression recordParam, Expression index)
    {
        var recordType = typeof(IDataRecord);
        Expression whenNull = type.IsValueType ? (Expression)Expression.New(type) : Expression.Constant(null, type);
        Expression whenPresent;
        string getterName;
        if (recordGetters.TryGetValue(type, out getterName))
        {
            whenPresent = Expression.Call(recordParam, recordType.GetMethod(getterName), index);
        }
        else
        {
            Expression boxed = Expression.Call(recordParam, recordType.GetMethod("GetValue"), index);
            // TypeAs only accepts reference and Nullable<> targets; other value types
            // (enums, unsigned integers, ...) need an unboxing conversion instead.
            whenPresent = type.IsValueType && Nullable.GetUnderlyingType(type) == null
                ? (Expression)Expression.Convert(boxed, type)
                : Expression.TypeAs(boxed, type);
        }
        return Expression.Condition(Expression.Call(recordParam, recordType.GetMethod("IsDBNull"), index), whenNull, whenPresent);
    }

    // Reads a text column as char[], yielding null for DBNull instead of calling ToCharArray on null.
    private static Expression GetCharArrayReadingExpression(ParameterExpression recordParam, Expression index)
    {
        var recordType = typeof(IDataRecord);
        var text = Expression.Call(recordParam, recordType.GetMethod("GetString"), index);
        // String.ToCharArray is overloaded; select the parameterless one explicitly.
        var toCharArray = typeof(string).GetMethod("ToCharArray", Type.EmptyTypes);
        return Expression.Condition(
            Expression.Call(recordParam, recordType.GetMethod("IsDBNull"), index),
            Expression.Constant(null, typeof(char[])),
            Expression.Call(text, toCharArray));
    }

    // Splits the columns from firstIndex onward into groups of `columns` and projects each group;
    // a trailing partial group is ignored.
    private static T[] ReadArray<T>(IDataRecord record, int firstIndex, int columns, Func<IDataRecord, int, T> projection)
    {
        var result = new T[(record.FieldCount - firstIndex) / columns];
        for (int i = 0; i < result.Length; i++)
            result[i] = projection(record, firstIndex + (i * columns));
        return result;
    }

    /// <summary>Executes <paramref name="sql"/> and returns the number of affected rows.</summary>
    public static int Exec(this IDbConnection connection, string sql)
    {
        return Exec(connection, sql, null);
    }

    /// <summary>
    /// Executes <paramref name="sql"/> with one parameter per readable public property of
    /// <paramref name="parameters"/> (named <c>@PropertyName</c>) and returns the number of
    /// affected rows. Array properties other than <c>byte[]</c> are expanded into an IN-list.
    /// </summary>
    public static int Exec(this IDbConnection connection, string sql, object parameters)
    {
        using (var cmd = GetCommand(connection, sql, parameters))
        {
            return WithOpenConnection(connection, () => cmd.ExecuteNonQuery());
        }
    }

    /// <summary>
    /// Executes <paramref name="sql"/> and projects every row through <paramref name="projection"/>.
    /// The projection's parameters are filled from the row's columns in order: scalar types use one
    /// column, types registered with <see cref="AddMapping{T}(Delegate)"/> use as many columns as their
    /// mapper has parameters, and a trailing array parameter consumes all remaining columns in
    /// equal-sized groups. DBNull becomes <c>default</c>/null. A null <paramref name="projection"/>
    /// maps each row directly to <typeparamref name="T"/>.
    /// </summary>
    /// <remarks>
    /// Execution is deferred until the sequence is enumerated. If the connection was closed, it is
    /// opened for the enumeration and closed when enumeration finishes or the enumerator is disposed.
    /// </remarks>
    public static IEnumerable<T> Query<T>(this IDbConnection connection, string sql, object parameters, Delegate projection)
    {
        if (projection == null) projection = (Func<T, T>)(i => i);
        var compiled = GetCompiledProjection<T>(projection.Method);
        var target = projection.Target;
        using (var cmd = GetCommand(connection, sql, parameters))
        {
            var mustOpen = connection.State != ConnectionState.Open;
            if (mustOpen) connection.Open();
            try
            {
                using (var reader = cmd.ExecuteReader())
                {
                    while (reader.Read()) yield return compiled(reader, target);
                }
            }
            finally
            {
                if (mustOpen) connection.Close();
            }
        }
    }

    /// <summary>Typed overload of <see cref="AddMapping{T}(Delegate)"/> for a one-column mapper.</summary>
    public static void AddMapping<T1, T>(Func<T1, T> projection)
    {
        AddMapping<T>(projection);
    }

    /// <summary>Typed overload of <see cref="AddMapping{T}(Delegate)"/> for a two-column mapper.</summary>
    public static void AddMapping<T1, T2, T>(Func<T1, T2, T> projection)
    {
        AddMapping<T>(projection);
    }

    /// <summary>Typed overload of <see cref="AddMapping{T}(Delegate)"/> for a three-column mapper.</summary>
    public static void AddMapping<T1, T2, T3, T>(Func<T1, T2, T3, T> projection)
    {
        AddMapping<T>(projection);
    }

    /// <summary>Typed overload of <see cref="AddMapping{T}(Delegate)"/> for a four-column mapper.</summary>
    public static void AddMapping<T1, T2, T3, T4, T>(Func<T1, T2, T3, T4, T> projection)
    {
        AddMapping<T>(projection);
    }

    /// <summary>Typed overload of <see cref="AddMapping{T}(Delegate)"/> for a five-column mapper.</summary>
    public static void AddMapping<T1, T2, T3, T4, T5, T>(Func<T1, T2, T3, T4, T5, T> projection)
    {
        AddMapping<T>(projection);
    }

    /// <summary>Typed overload of <see cref="AddMapping{T}(Delegate)"/> for a six-column mapper.</summary>
    public static void AddMapping<T1, T2, T3, T4, T5, T6, T>(Func<T1, T2, T3, T4, T5, T6, T> projection)
    {
        AddMapping<T>(projection);
    }

    /// <summary>Parameterized query mapping each row directly to <typeparamref name="T"/>.</summary>
    public static IEnumerable<T> Query<T>(this IDbConnection connection, string sql, object parameters)
    {
        return Query<T>(connection, sql, parameters, null);
    }

    /// <summary>Typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T> projection)
    {
        return Query<T>(connection, sql, parameters, projection);
    }

    /// <summary>Typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T> projection)
    {
        return Query<T>(connection, sql, parameters, projection);
    }

    /// <summary>Typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T3, T> projection)
    {
        return Query<T>(connection, sql, parameters, projection);
    }

    /// <summary>Typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T3, T4, T> projection)
    {
        return Query<T>(connection, sql, parameters, projection);
    }

    /// <summary>Typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T5, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T3, T4, T5, T> projection)
    {
        return Query<T>(connection, sql, parameters, projection);
    }

    /// <summary>Typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T5, T6, T>(this IDbConnection connection, string sql, object parameters, Func<T1, T2, T3, T4, T5, T6, T> projection)
    {
        return Query<T>(connection, sql, parameters, projection);
    }

    /// <summary>Parameterless query mapping each row directly to <typeparamref name="T"/>.</summary>
    public static IEnumerable<T> Query<T>(this IDbConnection connection, string sql)
    {
        return Query<T>(connection, sql, null, null);
    }

    /// <summary>Parameterless typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T>(this IDbConnection connection, string sql, Func<T1, T> projection)
    {
        return Query<T>(connection, sql, null, projection);
    }

    /// <summary>Parameterless typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T>(this IDbConnection connection, string sql, Func<T1, T2, T> projection)
    {
        return Query<T>(connection, sql, null, projection);
    }

    /// <summary>Parameterless typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T>(this IDbConnection connection, string sql, Func<T1, T2, T3, T> projection)
    {
        return Query<T>(connection, sql, null, projection);
    }

    /// <summary>Parameterless typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T>(this IDbConnection connection, string sql, Func<T1, T2, T3, T4, T> projection)
    {
        return Query<T>(connection, sql, null, projection);
    }

    /// <summary>Parameterless typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T5, T>(this IDbConnection connection, string sql, Func<T1, T2, T3, T4, T5, T> projection)
    {
        return Query<T>(connection, sql, null, projection);
    }

    /// <summary>Parameterless typed overload of <see cref="Query{T}(IDbConnection, string, object, Delegate)"/>.</summary>
    public static IEnumerable<T> Query<T1, T2, T3, T4, T5, T6, T>(this IDbConnection connection, string sql, Func<T1, T2, T3, T4, T5, T6, T> projection)
    {
        return Query<T>(connection, sql, null, projection);
    }

    /// <summary>Inserts <paramref name="values"/>; see the <see cref="IEnumerable{T}"/> overload.</summary>
    public static void Insert<T>(this IDbConnection connection, string table, params T[] values)
    {
        Insert<T>(connection, table, (IEnumerable<T>)values);
    }

    /// <summary>
    /// Inserts every item of <paramref name="values"/> into <paramref name="table"/> with a single
    /// multi-row <c>INSERT ... VALUES (...),(...)</c> statement. Columns are the readable public
    /// instance properties of <typeparamref name="T"/>. Does nothing when <typeparamref name="T"/>
    /// has no such properties or <paramref name="values"/> is empty.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="table"/> or <paramref name="values"/> is null.</exception>
    /// <exception cref="DataException">The affected-row count differs from the number of items.</exception>
    public static void Insert<T>(this IDbConnection connection, string table, IEnumerable<T> values)
    {
        if (table == null) throw new ArgumentNullException("table");
        if (values == null) throw new ArgumentNullException("values");
        var properties = GetReadableProperties(typeof(T));
        if (properties.Length == 0) return;

        int rows = 0;
        int affected = WithOpenConnection<int>(connection, () =>
        {
            using (var cmd = CreateEnlistedCommand(connection))
            {
                var sql = new StringBuilder("INSERT INTO ");
                sql.Append(table).Append(" (");
                for (int p = 0; p < properties.Length; p++)
                {
                    if (p > 0) sql.Append(", ");
                    sql.Append(properties[p].Name);
                }
                sql.Append(") VALUES (");
                // values is enumerated while the connection is open, as before.
                foreach (var value in values)
                {
                    if (rows > 0) sql.Append("),(");
                    for (int p = 0; p < properties.Length; p++)
                    {
                        if (p > 0) sql.Append(", ");
                        var name = paramPrefix + "i" + rows.ToString() + properties[p].Name;
                        sql.Append(name);
                        AddParameter(cmd, name, properties[p].GetValue(value, null));
                    }
                    rows++;
                }
                if (rows == 0) return 0;
                sql.Append(")");
                cmd.CommandText = sql.ToString();
                return cmd.ExecuteNonQuery();
            }
        });
        if (affected != rows) throw new DataException("Not all rows was inserted or the database caused unexpected updates.");
    }

    /// <summary>Inserts the single item <paramref name="values"/> into <paramref name="table"/>.</summary>
    public static void InsertSingle<T>(this IDbConnection connection, string table, T values)
    {
        Insert<T>(connection, table, new T[] { values });
    }

    /// <summary>
    /// Runs <c>UPDATE table SET p = @setp, ... WHERE q = @whereq AND ...</c>, taking column names
    /// and values from the readable public properties of <paramref name="set"/> and
    /// <paramref name="where"/>. Null values are sent as SQL NULL, so a null <paramref name="where"/>
    /// value compares with <c>=</c> and matches no rows.
    /// </summary>
    /// <returns>The number of affected rows; 0 without touching the database when <paramref name="set"/> has no properties.</returns>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="where"/> has no readable properties.</exception>
    public static int Update(this IDbConnection connection, string table, object set, object where)
    {
        if (table == null) throw new ArgumentNullException("table");
        if (set == null) throw new ArgumentNullException("set");
        if (where == null) throw new ArgumentNullException("where");
        var setProperties = GetReadableProperties(set.GetType());
        if (setProperties.Length == 0) return 0;
        var whereProperties = GetWhereProperties(where);

        return WithOpenConnection<int>(connection, () =>
        {
            using (var cmd = CreateEnlistedCommand(connection))
            {
                var sql = new StringBuilder("UPDATE ");
                sql.Append(table).Append(" SET ");
                for (int p = 0; p < setProperties.Length; p++)
                {
                    if (p > 0) sql.Append(", ");
                    var name = paramPrefix + "set" + setProperties[p].Name;
                    sql.Append(setProperties[p].Name).Append(" = ").Append(name);
                    AddParameter(cmd, name, setProperties[p].GetValue(set, null));
                }
                AppendWhereClause(sql, cmd, where, whereProperties);
                cmd.CommandText = sql.ToString();
                return cmd.ExecuteNonQuery();
            }
        });
    }

    /// <summary>
    /// Runs <c>DELETE FROM table WHERE q = @whereq AND ...</c>, taking column names and values from
    /// the readable public properties of <paramref name="where"/>.
    /// </summary>
    /// <returns>The number of affected rows.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="table"/> or <paramref name="where"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="where"/> has no readable properties.</exception>
    public static int Delete(this IDbConnection connection, string table, object where)
    {
        if (table == null) throw new ArgumentNullException("table");
        if (where == null) throw new ArgumentNullException("where");
        var whereProperties = GetWhereProperties(where);

        return WithOpenConnection<int>(connection, () =>
        {
            using (var cmd = CreateEnlistedCommand(connection))
            {
                var sql = new StringBuilder("DELETE FROM ");
                sql.Append(table);
                AppendWhereClause(sql, cmd, where, whereProperties);
                cmd.CommandText = sql.ToString();
                return cmd.ExecuteNonQuery();
            }
        });
    }

    /// <summary>
    /// Like <see cref="Update"/>, but throws <see cref="DataException"/> unless exactly one row
    /// was affected. The update has already executed when the exception is raised; it is only undone
    /// if the call runs inside <see cref="Work{T}(IDbConnection, Func{T})"/>, which rolls back on exceptions.
    /// </summary>
    public static void UpdateSingle(this IDbConnection connection, string table, object set, object where)
    {
        int result = Update(connection, table, set, where);
        if (result == 0)
            throw new DataException("No row was updated. Perhaps the row has been removed.");
        else if (result > 1)
            throw new DataException("More than one row was updated. Only a single updated row was expected.");
    }

    /// <summary>
    /// Like <see cref="Delete"/>, but throws <see cref="DataException"/> when more than one row was
    /// deleted; deleting no row is accepted.
    /// </summary>
    public static void DeleteSingle(this IDbConnection connection, string table, object where)
    {
        int result = Delete(connection, table, where);
        if (result > 1) throw new DataException("More than one row was deleted. Only a single row deletion was expected.");
    }

    /// <summary>Executes <paramref name="sql"/> and throws <see cref="DataException"/> unless exactly one row was affected.</summary>
    public static void ExecSingle(this IDbConnection connection, string sql)
    {
        ExecSingle(connection, sql, null);
    }

    /// <summary>Parameterized <see cref="ExecSingle(IDbConnection, string)"/>.</summary>
    public static void ExecSingle(this IDbConnection connection, string sql, object parameters)
    {
        int result = Exec(connection, sql, parameters);
        if (result == 0)
            throw new DataException("No row was updated. Perhaps the row has been removed.");
        else if (result > 1)
            throw new DataException("More than one row was updated. Only a single updated row was expected.");
    }

    // Creates a command for sql with one parameter per readable property of `parameters`;
    // array properties (except byte[]) are expanded into "(@name_0, @name_1, ...)".
    private static IDbCommand GetCommand(IDbConnection connection, string sql, object parameters)
    {
        var cmd = CreateEnlistedCommand(connection);
        cmd.CommandText = sql;
        if (parameters == null) return cmd;
        foreach (var property in GetReadableProperties(parameters.GetType()))
        {
            var name = paramPrefix + property.Name;
            var value = property.GetValue(parameters, null);
            if (property.PropertyType.IsArray && property.PropertyType != typeof(byte[]))
                AddArrayParameters(cmd, name, (Array)value);
            else
                AddParameter(cmd, name, value);
        }
        return cmd;
    }

    // Creates a command on connection, enlisted in the transaction Work started for it on this thread.
    private static IDbCommand CreateEnlistedCommand(IDbConnection connection)
    {
        var cmd = connection.CreateCommand();
        var transactions = activeTransactions;
        IDbTransaction transaction;
        if (transactions != null && transactions.TryGetValue(connection, out transaction))
            cmd.Transaction = transaction;
        return cmd;
    }

    // Adds a named parameter. Null becomes DBNull.Value because providers such as SqlClient treat a
    // null Value as "parameter not supplied" rather than SQL NULL.
    private static void AddParameter(IDbCommand cmd, string name, object value)
    {
        var param = cmd.CreateParameter();
        param.ParameterName = name;
        param.Value = value ?? DBNull.Value;
        cmd.Parameters.Add(param);
    }

    // Adds one parameter per element and rewrites the placeholder paramName in the command text to
    // "(p0, p1, ...)". A null or empty array becomes "(NULL)" so "x IN (...)" stays valid SQL and
    // matches nothing.
    private static void AddArrayParameters(IDbCommand cmd, string paramName, Array values)
    {
        var list = new StringBuilder("(");
        int count = values == null ? 0 : values.Length;
        for (int i = 0; i < count; i++)
        {
            // '_' keeps the generated name a single SQL token ('-' would parse as subtraction).
            var name = paramName + "_" + i.ToString();
            if (i > 0) list.Append(", ");
            list.Append(name);
            AddParameter(cmd, name, values.GetValue(i));
        }
        if (count == 0) list.Append("NULL");
        list.Append(")");
        cmd.CommandText = ReplaceParameterToken(cmd.CommandText, paramName, list.ToString());
    }

    // Replaces whole-token occurrences of `token` in `text`, leaving longer names that merely start
    // with it (e.g. "@ids" when replacing "@id") untouched.
    private static string ReplaceParameterToken(string text, string token, string replacement)
    {
        if (string.IsNullOrEmpty(text)) return text;
        var result = new StringBuilder(text.Length);
        int position = 0;
        int found;
        while ((found = text.IndexOf(token, position, StringComparison.Ordinal)) >= 0)
        {
            int end = found + token.Length;
            bool partOfLongerName = end < text.Length && (char.IsLetterOrDigit(text[end]) || text[end] == '_');
            result.Append(text, position, found - position);
            result.Append(partOfLongerName ? token : replacement);
            position = end;
        }
        result.Append(text, position, text.Length - position);
        return result.ToString();
    }

    // Public instance properties that can be read without index arguments, in reflection order.
    private static PropertyInfo[] GetReadableProperties(Type type)
    {
        return type.GetProperties(BindingFlags.Instance | BindingFlags.Public)
            .Where(p => p.CanRead && p.GetIndexParameters().Length == 0)
            .ToArray();
    }

    // Readable properties of a where-object; rejects objects that would produce an empty WHERE clause.
    private static PropertyInfo[] GetWhereProperties(object where)
    {
        var properties = GetReadableProperties(where.GetType());
        if (properties.Length == 0)
            throw new ArgumentException("The where object must expose at least one readable public property.", "where");
        return properties;
    }

    // Appends " WHERE a = @wherea AND b = @whereb ..." and the matching parameters.
    private static void AppendWhereClause(StringBuilder sql, IDbCommand cmd, object where, PropertyInfo[] properties)
    {
        sql.Append(" WHERE ");
        for (int p = 0; p < properties.Length; p++)
        {
            if (p > 0) sql.Append(" AND ");
            var name = paramPrefix + "where" + properties[p].Name;
            sql.Append(properties[p].Name).Append(" = ").Append(name);
            AddParameter(cmd, name, properties[p].GetValue(where, null));
        }
    }

    // Compiled row projection: reads the current record and invokes the user delegate's method on target.
    private delegate T Projection<T>(IDataRecord record, object target);

    // Larger-arity Func delegates for the Query/AddMapping overloads; presumably provided because the
    // targeted framework's System.Func stops at four arguments.
    public delegate TResult Func<T1, T2, T3, T4, T5, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5);
    public delegate TResult Func<T1, T2, T3, T4, T5, T6, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6);
    public delegate TResult Func<T1, T2, T3, T4, T5, T6, T7, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, T7 arg7);
    public delegate TResult Func<T1, T2, T3, T4, T5, T6, T7, T8, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, T7 arg7, T8 arg8);
    public delegate TResult Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5, T6 arg6, T7 arg7, T8 arg8, T9 arg9);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment