Skip to content

Instantly share code, notes, and snippets.

@jamesrcounts
Created August 11, 2014 21:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesrcounts/c40badcb2b85098affeb to your computer and use it in GitHub Desktop.
Save jamesrcounts/c40badcb2b85098affeb to your computer and use it in GitHub Desktop.
LINQ to SQL adapter for ApprovalTests
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Common;
using System.Data.Linq;
using System.Data.Linq.SqlClient;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using ApprovalUtilities.Persistence.Database;
/// <summary>
/// LINQ to SQL adapter. Exposes the connection of a <see cref="DataContext"/> and the SQL text
/// of a query, with parameter placeholders replaced by quoted literal values.
/// </summary>
/// <remarks>
/// The parameter values are read by reflecting over non-public members of the LINQ to SQL
/// provider (the "provider" field of <see cref="DataContext"/>, and the provider's
/// "InitializeProviderMode" and "BuildQuery" methods). A missing member is reported as an
/// <see cref="InvalidOperationException"/> naming the member that could not be found.
/// </remarks>
public class DataQueryAdaptor : IDatabaseToExecuteableQueryAdaptor
{
    private readonly DataContext ctx;
    private readonly IQueryable queryable;

    /// <summary>
    /// Creates an adapter for the given query and the data context it belongs to.
    /// </summary>
    /// <param name="queryable">The query whose SQL text is produced by <see cref="GetQuery"/>.</param>
    /// <param name="ctx">The data context supplying the connection and the SQL provider.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public DataQueryAdaptor(IQueryable queryable, DataContext ctx)
    {
        if (queryable == null)
        {
            throw new ArgumentNullException("queryable");
        }

        if (ctx == null)
        {
            throw new ArgumentNullException("ctx");
        }

        this.queryable = queryable;
        this.ctx = ctx;
    }

    /// <summary>
    /// Returns the connection of the underlying data context.
    /// </summary>
    public DbConnection GetConnection()
    {
        return this.ctx.Connection;
    }

    /// <summary>
    /// Returns the query text (<c>queryable.ToString()</c>) with each parameter name replaced
    /// by its value wrapped in single quotes.
    /// </summary>
    /// <remarks>Only the first query description returned by the provider is used.</remarks>
    /// <exception cref="InvalidOperationException">
    /// The provider returned no query description, or an expected non-public member is missing.
    /// </exception>
    public string GetQuery()
    {
        var sqlProvider = this.InitializeProviderMode();
        var queryInfo = this.BuildQuery(sqlProvider).FirstOrDefault();
        if (queryInfo == null)
        {
            throw new InvalidOperationException(
                "The LINQ to SQL provider did not return a query for the supplied expression.");
        }

        // Substitute longer names first: "@p1" is a prefix of "@p10", so replacing in the
        // provider's order would corrupt queries with more than ten parameters.
        // NOTE(review): a parameter VALUE that itself contains a parameter name can still be
        // rewritten by a later replacement; a single-pass substitution would be needed to rule
        // that out.
        var parameters = GetParameters(queryInfo)
            .OrderByDescending(p => p.Key.Length)
            .ToList();

        return parameters.Aggregate(this.queryable.ToString(), (sql, p) => sql.Replace(p.Key, p.Value));
    }

    /// <summary>
    /// Builds a delegate that invokes the non-public instance method <paramref name="name"/>
    /// with the given parameter types on <paramref name="instance"/>. The method is looked up
    /// each time the delegate is invoked.
    /// </summary>
    /// <exception cref="InvalidOperationException">No matching method exists (thrown on invocation).</exception>
    private static Func<object[], TReturn> BindMethod<TReturn>(object instance, string name, params Type[] types)
    {
        return objects =>
        {
            var type = instance.GetType();
            var method = type.GetMethod(name, BindingFlags.NonPublic | BindingFlags.Instance, null, types, null);
            if (method == null)
            {
                throw new InvalidOperationException(type + " has no method named: " + name);
            }

            return (TReturn)method.Invoke(instance, objects);
        };
    }

    /// <summary>
    /// Converts a provider parameter object into a (name, quoted value) pair by reading its
    /// non-public "parameter" (then "name") and "value" fields.
    /// </summary>
    private static KeyValuePair<string, string> CreateParameter(object parameter)
    {
        var sql = GetFieldValue(parameter, "parameter");
        var name = GetFieldValue(sql, "name");
        var value = GetFieldValue(parameter, "value");

        // Double embedded single quotes so a value such as O'Brien stays one SQL string literal.
        var literal = string.Format("{0}", value).Replace("'", "''");
        return new KeyValuePair<string, string>(name + "", "'" + literal + "'");
    }

    /// <summary>
    /// Creates an instance of the type System.Data.Linq.SqlClient.SqlNodeAnnotations, located
    /// in the assembly that contains <see cref="SqlProvider"/>, via its parameterless constructor.
    /// </summary>
    /// <exception cref="InvalidOperationException">The type or its constructor cannot be found.</exception>
    private static object CreateSqlNodeAnnotations()
    {
        const string TypeName = "System.Data.Linq.SqlClient.SqlNodeAnnotations";

        var assembly = Assembly.GetAssembly(typeof(SqlProvider));
        var annotationType = assembly.GetType(TypeName);
        if (annotationType == null)
        {
            throw new InvalidOperationException(assembly + " has no type named: " + TypeName);
        }

        var constructor = annotationType.GetConstructor(Type.EmptyTypes);
        if (constructor == null)
        {
            throw new InvalidOperationException(annotationType + " has no parameterless constructor");
        }

        return constructor.Invoke(null);
    }

    /// <summary>
    /// Reads the non-public instance field <paramref name="name"/> declared on the runtime
    /// type of <paramref name="instance"/>.
    /// </summary>
    private static object GetFieldValue(object instance, string name)
    {
        return GetFieldValue(instance, name, instance.GetType());
    }

    /// <summary>
    /// Reads the non-public instance field <paramref name="name"/>, looked up on
    /// <paramref name="type"/>, from <paramref name="instance"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The type has no such field.</exception>
    private static object GetFieldValue(object instance, string name, IReflect type)
    {
        var fieldInfo = type.GetField(name, BindingFlags.NonPublic | BindingFlags.Instance);
        if (fieldInfo == null)
        {
            throw new InvalidOperationException(type + " has no field named: " + name);
        }

        return fieldInfo.GetValue(instance);
    }

    /// <summary>
    /// Reads the non-public "parameters" field of a query description and converts each entry
    /// into a (name, quoted value) pair. The sequence is evaluated lazily.
    /// </summary>
    private static IEnumerable<KeyValuePair<string, string>> GetParameters(object queryInfo)
    {
        var parameters = (IEnumerable)GetFieldValue(queryInfo, "parameters");
        return parameters.Cast<object>().Select(CreateParameter);
    }

    /// <summary>
    /// Invokes the provider's non-public "BuildQuery" method with the query's expression and a
    /// fresh annotations object, returning the resulting query descriptions.
    /// </summary>
    private IEnumerable<object> BuildQuery(object sqlProvider)
    {
        var sqlNodeAnnotations = CreateSqlNodeAnnotations();
        var buildQuery = BindMethod<object[]>(
            sqlProvider,
            "BuildQuery",
            new[] { typeof(Expression), sqlNodeAnnotations.GetType() });

        return buildQuery.Invoke(new[] { this.queryable.Expression, sqlNodeAnnotations });
    }

    /// <summary>
    /// Reads the data context's non-public "provider" field, invokes the provider's non-public
    /// "InitializeProviderMode" method, and returns the provider.
    /// </summary>
    private object InitializeProviderMode()
    {
        var sqlProvider = GetFieldValue(this.ctx, "provider", typeof(DataContext));
        BindMethod<object>(sqlProvider, "InitializeProviderMode").Invoke(null);
        return sqlProvider;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment