Created
August 11, 2014 21:54
-
-
Save jamesrcounts/c40badcb2b85098affeb to your computer and use it in GitHub Desktop.
LINQ to SQL adapter for ApprovalTests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections; | |
using System.Collections.Generic; | |
using System.Data.Common; | |
using System.Data.Linq; | |
using System.Data.Linq.SqlClient; | |
using System.Linq; | |
using System.Linq.Expressions; | |
using System.Reflection; | |
using ApprovalUtilities.Persistence.Database; | |
/// <summary>
/// LINQ to SQL adapter for ApprovalTests: exposes the connection of a <see cref="DataContext"/>
/// and the SQL text of a query with its parameter placeholders replaced by literal values.
/// </summary>
/// <remarks>
/// LINQ to SQL offers no public API for obtaining the parameter values of a query, so this class
/// reads non-public fields of <see cref="DataContext"/> and invokes non-public methods of its
/// provider through reflection. It therefore depends on private implementation details of
/// System.Data.Linq and throws <see cref="InvalidOperationException"/> when an expected member is missing.
/// </remarks>
public class DataQueryAdaptor : IDatabaseToExecuteableQueryAdaptor
{
    private readonly DataContext ctx;
    private readonly IQueryable queryable;

    /// <summary>
    /// Initializes a new instance of the <see cref="DataQueryAdaptor"/> class.
    /// </summary>
    /// <param name="queryable">The LINQ to SQL query whose SQL text is captured.</param>
    /// <param name="ctx">The data context that supplies the connection and query provider.</param>
    /// <exception cref="ArgumentNullException"><paramref name="queryable"/> or <paramref name="ctx"/> is null.</exception>
    public DataQueryAdaptor(IQueryable queryable, DataContext ctx)
    {
        if (queryable == null)
        {
            throw new ArgumentNullException("queryable");
        }

        if (ctx == null)
        {
            throw new ArgumentNullException("ctx");
        }

        this.queryable = queryable;
        this.ctx = ctx;
    }

    /// <summary>
    /// Gets the connection of the wrapped <see cref="DataContext"/>.
    /// </summary>
    /// <returns>The value of <see cref="DataContext.Connection"/>.</returns>
    public DbConnection GetConnection()
    {
        return this.ctx.Connection;
    }

    /// <summary>
    /// Gets the query's SQL text (from <c>queryable.ToString()</c>) with every parameter
    /// placeholder of the first built query replaced by a SQL literal.
    /// </summary>
    /// <returns>The SQL text with parameters inlined; unchanged when no query info is produced.</returns>
    public string GetQuery()
    {
        var sqlProvider = this.InitializeProviderMode();
        var queryInfos = this.BuildQuery(sqlProvider);
        var queryInfo = queryInfos == null ? null : queryInfos.FirstOrDefault();
        var sql = this.queryable.ToString();

        if (queryInfo == null)
        {
            // Nothing to substitute; previously this dereferenced null inside the reflection helpers.
            return sql;
        }

        return SubstituteParameters(sql, GetParameters(queryInfo));
    }

    /// <summary>
    /// Resolves a non-public instance method once and returns a delegate that invokes it.
    /// </summary>
    /// <typeparam name="TReturn">Type the method's return value is cast to.</typeparam>
    /// <param name="instance">Object whose runtime type declares the method.</param>
    /// <param name="name">Method name.</param>
    /// <param name="types">Parameter types used to select the overload.</param>
    /// <returns>A delegate forwarding its argument array to the method.</returns>
    /// <exception cref="InvalidOperationException">No matching method exists.</exception>
    private static Func<object[], TReturn> BindMethod<TReturn>(object instance, string name, params Type[] types)
    {
        var instanceType = instance.GetType();
        var method = instanceType.GetMethod(name, BindingFlags.NonPublic | BindingFlags.Instance, null, types, null);
        if (method == null)
        {
            throw new InvalidOperationException(instanceType + " has no method named: " + name);
        }

        return arguments => (TReturn)method.Invoke(instance, arguments);
    }

    /// <summary>
    /// Converts one provider parameter-info object into a (placeholder name, SQL literal) pair.
    /// </summary>
    /// <param name="parameter">An element of the query info's non-public <c>parameters</c> collection.</param>
    /// <returns>The placeholder name and the literal that replaces it.</returns>
    private static KeyValuePair<string, string> CreateParameter(object parameter)
    {
        var sqlParameter = GetFieldValue(parameter, "parameter");
        var name = GetFieldValue(sqlParameter, "name");
        var value = GetFieldValue(parameter, "value");
        return new KeyValuePair<string, string>(name + "", ToSqlLiteral(value));
    }

    /// <summary>
    /// Renders a parameter value as a quoted SQL literal, or <c>NULL</c> for null/DBNull.
    /// </summary>
    /// <param name="value">The parameter value.</param>
    /// <returns>The literal text.</returns>
    private static string ToSqlLiteral(object value)
    {
        if (value == null || value == DBNull.Value)
        {
            // A quoted empty string would make predicates such as "@p0 IS NULL" evaluate incorrectly.
            return "NULL";
        }

        // Invariant culture keeps numbers and dates machine-readable regardless of the current thread culture.
        var text = Convert.ToString(value, System.Globalization.CultureInfo.InvariantCulture) ?? string.Empty;

        // Double embedded quotes so values such as O'Brien stay inside the literal.
        return "'" + text.Replace("'", "''") + "'";
    }

    /// <summary>
    /// Creates an instance of the internal <c>System.Data.Linq.SqlClient.SqlNodeAnnotations</c> type,
    /// which the provider's <c>BuildQuery</c> method takes as an argument.
    /// </summary>
    /// <returns>A new annotations instance.</returns>
    private static object CreateSqlNodeAnnotations()
    {
        var assembly = Assembly.GetAssembly(typeof(SqlProvider));

        // throwOnError: true reports a missing type clearly instead of failing later with a NullReferenceException.
        var annotationType = assembly.GetType("System.Data.Linq.SqlClient.SqlNodeAnnotations", true);

        // nonPublic: true also accepts a non-public parameterless constructor.
        return Activator.CreateInstance(annotationType, true);
    }

    /// <summary>
    /// Reads a non-public instance field declared on the runtime type of <paramref name="instance"/>.
    /// </summary>
    private static object GetFieldValue(object instance, string name)
    {
        if (instance == null)
        {
            throw new ArgumentNullException("instance", "Cannot read field '" + name + "' from a null instance.");
        }

        return GetFieldValue(instance, name, instance.GetType());
    }

    /// <summary>
    /// Reads a non-public instance field declared on <paramref name="type"/>.
    /// </summary>
    /// <param name="instance">Object to read from.</param>
    /// <param name="name">Field name.</param>
    /// <param name="type">Type that declares the field; private fields of base types are not found through a derived type.</param>
    /// <returns>The field's value.</returns>
    /// <exception cref="InvalidOperationException">The field does not exist on <paramref name="type"/>.</exception>
    private static object GetFieldValue(object instance, string name, IReflect type)
    {
        var fieldInfo = type.GetField(name, BindingFlags.NonPublic | BindingFlags.Instance);
        if (fieldInfo == null)
        {
            throw new InvalidOperationException(type + " has no field named: " + name);
        }

        return fieldInfo.GetValue(instance);
    }

    /// <summary>
    /// Extracts (placeholder name, SQL literal) pairs from a query info object's <c>parameters</c> field.
    /// </summary>
    /// <param name="queryInfo">A query info object returned by the provider's <c>BuildQuery</c>.</param>
    /// <returns>The pairs, or an empty sequence when the field is null.</returns>
    private static IEnumerable<KeyValuePair<string, string>> GetParameters(object queryInfo)
    {
        var parameters = (IEnumerable)GetFieldValue(queryInfo, "parameters");
        if (parameters == null)
        {
            return Enumerable.Empty<KeyValuePair<string, string>>();
        }

        return parameters.Cast<object>().Select(CreateParameter);
    }

    /// <summary>
    /// Replaces every parameter placeholder in <paramref name="sql"/> with its literal in a single pass.
    /// </summary>
    /// <param name="sql">The SQL text containing placeholders.</param>
    /// <param name="parameters">Placeholder names and their literals; the first occurrence of a name wins.</param>
    /// <returns>The SQL text with placeholders substituted.</returns>
    private static string SubstituteParameters(string sql, IEnumerable<KeyValuePair<string, string>> parameters)
    {
        var literals = new Dictionary<string, string>(StringComparer.Ordinal);
        foreach (var parameter in parameters)
        {
            if (!string.IsNullOrEmpty(parameter.Key) && !literals.ContainsKey(parameter.Key))
            {
                literals.Add(parameter.Key, parameter.Value);
            }
        }

        if (literals.Count == 0)
        {
            return sql;
        }

        // Longest names come first so "@p10" is matched before its prefix "@p1".
        // A single regex pass also stops placeholder-like text inside substituted values from being replaced again.
        var pattern = string.Join(
            "|",
            literals.Keys
                .OrderByDescending(key => key.Length)
                .Select(key => System.Text.RegularExpressions.Regex.Escape(key))
                .ToArray());

        return System.Text.RegularExpressions.Regex.Replace(sql, pattern, match => literals[match.Value]);
    }

    /// <summary>
    /// Invokes the provider's non-public <c>BuildQuery(Expression, SqlNodeAnnotations)</c> for this query.
    /// </summary>
    /// <param name="sqlProvider">The provider instance read from the data context.</param>
    /// <returns>The query info objects returned by the provider.</returns>
    private IEnumerable<object> BuildQuery(object sqlProvider)
    {
        var sqlNodeAnnotations = CreateSqlNodeAnnotations();
        return BindMethod<object[]>(sqlProvider, "BuildQuery", new[] { typeof(Expression), sqlNodeAnnotations.GetType() })
            .Invoke(new[] { this.queryable.Expression, sqlNodeAnnotations });
    }

    /// <summary>
    /// Reads the data context's private <c>provider</c> field and invokes its non-public,
    /// parameterless <c>InitializeProviderMode</c> method before <c>BuildQuery</c> is called.
    /// </summary>
    /// <returns>The provider instance.</returns>
    private object InitializeProviderMode()
    {
        // Look the field up on DataContext itself: ctx is typically a derived context type,
        // and private base-class fields are not visible through the derived type.
        var sqlProvider = GetFieldValue(this.ctx, "provider", typeof(DataContext));
        BindMethod<object>(sqlProvider, "InitializeProviderMode").Invoke(null);
        return sqlProvider;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment