Created
August 30, 2012 12:41
-
-
Save erik-kallen/3527712 to your computer and use it in GitHub Desktop.
SqlProxy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<#@ include file="SqlProxy.ttinc" #> | |
<#+ | |
// Settings for the included SqlProxy.ttinc template. The template itself is not
// visible here, so the notes below are inferred from the field names only.
// TODO(review): confirm each meaning against SqlProxy.ttinc.

// Database connection used while the template runs, and naming of the generated code.
static string connectionString = "Server=.; Integrated Security=true; Database=AdventureWorks";
static string namespaceName = "DAL";
static string databaseClassName = "DatabaseProxy";

// Which kinds of database objects to generate proxies for.
static bool includeProcedures = true;
static bool includeViews = true;
static bool includeTables = false;

static int commandTimeout = 600; // presumably seconds - confirm
static bool generateSqlclr = false;

// Procedure filters; null presumably means "no filter" - confirm.
static SchemaAndName[] procsToInclude = null;
static SchemaAndName[] procsToExclude = new SchemaAndName[0];

static string accessibility = "public";
static string[] schemaNames = new[] { "dbo", "HumanResources" };
static string defaultSchemaName = "dbo";

// View filters; null presumably means "no filter" - confirm.
static SchemaAndName[] viewsToInclude = null;
static SchemaAndName[] viewsToExclude = null;

static Dictionary<SchemaAndName, object[]> customExtractions
    = new Dictionary<SchemaAndName, object[]>() {};

// Query name -> SQL text. The @{name: Type} placeholders look like typed
// parameter declarations handled by the template - confirm.
static Dictionary<string, string> namedQueries = new Dictionary<string, string> {
    { "EmployeesByBirthDate", "SELECT e.EmployeeID, e.HireDate, c.FirstName, c.LastName, c.Phone FROM HumanResources.Employee e JOIN Person.Contact c ON c.ContactID = e.ContactID WHERE e.BirthDate BETWEEN @{minDate: DateTime} AND @{maxDate: DateTime}" }
};
#> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Text; | |
using System.Collections.ObjectModel; | |
using System.Data.SqlClient; | |
using System.Collections; | |
using System.Data.SqlTypes; | |
using System.Xml; | |
using System.IO; | |
#if !SQLPROXY_SQLCLR | |
using System.Linq; | |
using System.Linq.Expressions; | |
#endif | |
namespace SqlProxy { | |
// Locally declared Func delegate types.
// The zero- to four-argument forms are declared only when SQLPROXY_SQLCLR is defined
// (presumably because the framework's own Func types are not available in that build - confirm).
// The five-argument form is declared for every build.
#if SQLPROXY_SQLCLR | |
public delegate TResult Func<TResult>(); | |
public delegate TResult Func<T, TResult>(T arg1); | |
public delegate TResult Func<T1, T2, TResult>(T1 arg1, T2 arg2); | |
public delegate TResult Func<T1, T2, T3, TResult>(T1 arg1, T2 arg2, T3 arg3); | |
public delegate TResult Func<T1, T2, T3, T4, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4); | |
#endif | |
public delegate TResult Func<T1, T2, T3, T4, T5, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5); | |
#if !SQLPROXY_SQLCLR | |
/// <summary>
/// Property-level attribute that carries a column name.
/// </summary>
[AttributeUsage(AttributeTargets.Property)]
internal sealed class ColumnNameAttribute : Attribute {
    /// <summary>Creates the attribute without a name; <see cref="Name"/> stays null until set.</summary>
    public ColumnNameAttribute() {
    }

    /// <summary>Creates the attribute with the given column name.</summary>
    /// <param name="name">Value stored in <see cref="Name"/>.</param>
    public ColumnNameAttribute(string name) {
        Name = name;
    }

    /// <summary>The column name associated with the decorated property.</summary>
    public string Name { get; set; }
}
/// <summary>
/// Method-level attribute that marks a method as standing for a SQL construct.
/// </summary>
[AttributeUsage(AttributeTargets.Method)]
internal sealed class SqlFunctionAttribute : Attribute {
    /// <summary>The kind of SQL construct a marked method stands for.</summary>
    public enum FunctionType {
        Function,
        Case,
        IsNotNull,
        IsNull,
        NoOp,
        Like,
        NotLike,
        In,
        NotIn
    }

    /// <summary>
    /// Marks a method as a named SQL function: <see cref="Type"/> is
    /// <see cref="FunctionType.Function"/> and <see cref="ParamList"/> starts out true.
    /// </summary>
    /// <param name="name">Value stored in <see cref="Name"/>.</param>
    public SqlFunctionAttribute(string name) {
        Type = FunctionType.Function;
        Name = name;
        ParamList = true;
    }

    /// <summary>
    /// Marks a method as one of the special constructs; <see cref="Name"/> is left null
    /// and <see cref="ParamList"/> is left false.
    /// </summary>
    /// <param name="type">Value stored in <see cref="Type"/>.</param>
    public SqlFunctionAttribute(FunctionType type) {
        Type = type;
    }

    /// <summary>Kind of construct; set only by the constructors.</summary>
    public FunctionType Type { get; private set; }

    /// <summary>SQL function name; set only by the name-taking constructor.</summary>
    public string Name { get; private set; }

    /// <summary>Optional value settable through the attribute's named arguments.</summary>
    public string MagicFirstArg { get; set; }

    /// <summary>Flag settable through the attribute's named arguments.</summary>
    public bool ParamList { get; set; }
}
#endif // !SQLPROXY_SQLCLR | |
public static class SqlFunctions { | |
#if !SQLPROXY_SQLCLR | |
// Marker methods for use inside expression trees. Every body throws: the methods are never
// meant to execute. Evaluator.CanBeEvaluatedLocally refuses to evaluate calls to methods that
// carry SqlFunctionAttribute, so such calls stay in the tree.
// The attribute's string argument is the SQL function name. "ParamList = false" and
// "MagicFirstArg" are presumably consumed by the SQL generator (not visible here) to omit the
// parentheses and to inject a literal first argument respectively - confirm.

// Date and time.
// NOTE(review): SYSDATETIMEOFFSET is declared as DateTime, and SWITCHOFFSET / TODATETIMEOFFSET
// use TimeSpan - these look like they should be DateTimeOffset; confirm before relying on them.
[SqlFunction("SYSDATETIME")] public static DateTime SysDateTime() { throw new Exception("Don't call"); } | |
[SqlFunction("SYSDATETIMEOFFSET")] public static DateTime SysDateTimeOffset() { throw new Exception("Don't call"); } | |
[SqlFunction("SYSUTCDATETIME")] public static DateTime SysUtcDateTime() { throw new Exception("Don't call"); } | |
[SqlFunction("CURRENT_TIMESTAMP", ParamList = false)] public static DateTime CurrentTimestamp() { throw new Exception("Don't call"); } | |
[SqlFunction("GETDATE")] public static DateTime GetDate() { throw new Exception("Don't call"); } | |
[SqlFunction("GETUTCDATE")] public static DateTime GetUtcDate() { throw new Exception("Don't call"); } | |
[SqlFunction("DATENAME")] public static string DateName(string datePart, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEPART")] public static int DatePart(string datePart, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DAY")] public static int Day(DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("MONTH")] public static int Month(DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("YEAR")] public static int Year(DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF")] public static int DateDiff(string datePart, DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); } | |
[SqlFunction("SWITCHOFFSET")] public static TimeSpan SwitchOffset(TimeSpan dateTimeOffset, string timeZone) { throw new Exception("Don't call"); } | |
[SqlFunction("TODATETIMEOFFSET")] public static TimeSpan ToDateTimeOffset(DateTime dateTime, string timeZone) { throw new Exception("Don't call"); } | |
[SqlFunction("@@DATEFIRST", ParamList = false)] public static byte DateFirst() { throw new Exception("Don't call"); } | |
[SqlFunction("ISDATE")] public static int IsDate(object expression) { throw new Exception("Don't call"); } | |
// Mathematical functions.
[SqlFunction("ABS")] public static T Abs<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("ACOS")] public static T Acos<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("ASIN")] public static T Asin<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("ATAN")] public static T Atan<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("ATN2")] public static T Atn2<T, U>(T expr1, U expr2) { throw new Exception("Don't call"); } | |
[SqlFunction("CEILING")] public static T Ceiling<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("DEGREES")] public static T Degrees<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("EXP")] public static T Exp<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("FLOOR")] public static T Floor<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("LOG")] public static T Log<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("LOG10")] public static T Log10<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("PI")] public static double Pi() { throw new Exception("Don't call"); } | |
[SqlFunction("POWER")] public static T Power<T, U>(T b, U e) { throw new Exception("Don't call"); } | |
[SqlFunction("RADIANS")] public static T Radians<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("RAND")] public static double Rand() { throw new Exception("Don't call"); } | |
[SqlFunction("RAND")] public static double Rand(int seed) { throw new Exception("Don't call"); } | |
[SqlFunction("ROUND")] public static T Round<T>(T expression, int length) { throw new Exception("Don't call"); } | |
[SqlFunction("ROUND")] public static T Round<T>(T expression, int length, int function) { throw new Exception("Don't call"); } | |
[SqlFunction("SIGN")] public static T Sign<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("SIN")] public static T Sin<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("SQRT")] public static T Sqrt<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("SQUARE")] public static T Square<T>(T expression) { throw new Exception("Don't call"); } | |
[SqlFunction("TAN")] public static T Tan<T>(T expression) { throw new Exception("Don't call"); } | |
// Security functions.
[SqlFunction("CURRENT_USER", ParamList = false)] public static string CurrentUser() { throw new Exception("Don't call"); } | |
[SqlFunction("USER_NAME")] public static string UserName(int id) { throw new Exception("Don't call"); } | |
[SqlFunction("SESSION_USER", ParamList = false)] public static string SessionUser() { throw new Exception("Don't call"); } | |
[SqlFunction("IS_MEMBER")] public static int? IsMember(string group) { throw new Exception("Don't call"); } | |
// String functions.
// NOTE(review): Stuff takes a double as its first argument although its siblings take strings - confirm.
[SqlFunction("ASCII")] public static int Ascii(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("CHAR")] public static string Char(int expression) { throw new Exception("Don't call"); } | |
[SqlFunction("CHARINDEX")] public static int CharIndex(string expression1, string expression2) { throw new Exception("Don't call"); } | |
[SqlFunction("CHARINDEX")] public static int CharIndex(string expression1, string expression2, long startLocation) { throw new Exception("Don't call"); } | |
[SqlFunction("DIFFERENCE")] public static int Difference(string expression1, string expression2) { throw new Exception("Don't call"); } | |
[SqlFunction("LEFT")] public static string Left(string str, long count) { throw new Exception("Don't call"); } | |
[SqlFunction("LEN")] public static int Len(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("LOWER")] public static string Lower(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("LTRIM")] public static string LTrim(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("NCHAR")] public static string NChar(int expression) { throw new Exception("Don't call"); } | |
[SqlFunction("PATINDEX")] public static long PatIndex(string pattern, string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("QUOTENAME")] public static string QuoteName(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("QUOTENAME")] public static string QuoteName(string expression, string quoteChar) { throw new Exception("Don't call"); } | |
[SqlFunction("REPLACE")] public static string Replace(string expression, string find, string replace) { throw new Exception("Don't call"); } | |
[SqlFunction("REPLICATE")] public static string Replicate(string expression, int count) { throw new Exception("Don't call"); } | |
[SqlFunction("REVERSE")] public static string Reverse(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("RIGHT")] public static string Right(string str, long count) { throw new Exception("Don't call"); } | |
[SqlFunction("RTRIM")] public static string RTrim(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("SOUNDEX")] public static string SoundEx(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("SPACE")] public static string Space(int count) { throw new Exception("Don't call"); } | |
[SqlFunction("STR")] public static string Str(double expression) { throw new Exception("Don't call"); } | |
[SqlFunction("STR")] public static string Str(double expression, int length) { throw new Exception("Don't call"); } | |
[SqlFunction("STR")] public static string Str(double expression, int length, int dec) { throw new Exception("Don't call"); } | |
[SqlFunction("STUFF")] public static string Stuff(double expr1, int start, int length, string expr2) { throw new Exception("Don't call"); } | |
[SqlFunction("SUBSTRING")] public static string Substring(string expression, int start, int length) { throw new Exception("Don't call"); } | |
[SqlFunction("UNICODE")] public static int Unicode(string expression) { throw new Exception("Don't call"); } | |
[SqlFunction("UPPER")] public static string Upper(string expression) { throw new Exception("Don't call"); } | |
// System functions.
// NOTE(review): NewId declares a parameter, and the two-argument ISNULL returns bool rather
// than T - both look unintended; confirm.
[SqlFunction("APP_NAME")] public static string AppName() { throw new Exception("Don't call"); } | |
[SqlFunction("COALESCE")] public static T Coalesce<T>(params T[] values) { throw new Exception("Don't call"); } | |
[SqlFunction("ISNUMERIC")] public static int IsNumeric(object expression) { throw new Exception("Don't call"); } | |
[SqlFunction("NEWID")] public static Guid NewId(object expression) { throw new Exception("Don't call"); } | |
[SqlFunction("NULLIF")] public static T? NullIfV<T>(T expr1, T expr2) where T : struct { throw new Exception("Don't call"); } | |
[SqlFunction("NULLIF")] public static T NullIfR<T>(T expr1, T expr2) where T : class { throw new Exception("Don't call"); } | |
[SqlFunction("@@LANGUAGE", ParamList = false)] public static string Language() { throw new Exception("Don't call"); } | |
[SqlFunction("ISNULL")] public static bool IsNull<T>(T expression, T nullvalue) { throw new Exception("Don't call"); } | |
// Special constructs identified by FunctionType instead of a function name.
[SqlFunction(SqlFunctionAttribute.FunctionType.Case)] public static object Case(bool test1, object case1, params object[] otherFiltersAndCases) { throw new Exception("Don't call"); } | |
[SqlFunction(SqlFunctionAttribute.FunctionType.IsNull)] public static bool IsNull(object expression) { throw new Exception("Don't call"); } | |
[SqlFunction(SqlFunctionAttribute.FunctionType.IsNotNull)] public static bool IsNotNull(object expression) { throw new Exception("Don't call"); } | |
// DATEADD, one method per date part (the part is carried in MagicFirstArg).
[SqlFunction("DATEADD", MagicFirstArg = "yy")] public static DateTime DateAddYy(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "q")] public static DateTime DateAddQ(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "m")] public static DateTime DateAddM(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "dy")] public static DateTime DateAddDy(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "d")] public static DateTime DateAddD(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "ww")] public static DateTime DateAddWw(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "dw")] public static DateTime DateAddDw(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "hh")] public static DateTime DateAddHh(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "mi")] public static DateTime DateAddMi(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "s")] public static DateTime DateAddS(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "ms")] public static DateTime DateAddMs(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "mcs")] public static DateTime DateAddMcs(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEADD", MagicFirstArg = "ns")] public static DateTime DateAddNs(int number, DateTime date) { throw new Exception("Don't call"); } | |
// DATEDIFF, one method per date part.
// NOTE(review): these signatures mirror the DATEADD stubs ((int, DateTime) returning DateTime),
// whereas the generic DateDiff stub above takes two DateTime values and returns int.
// This looks like a copy/paste slip - confirm before changing, since callers bind to these signatures.
[SqlFunction("DATEDIFF", MagicFirstArg = "yy")] public static DateTime DateDiffYy(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "q")] public static DateTime DateDiffQ(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "m")] public static DateTime DateDiffM(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "dy")] public static DateTime DateDiffDy(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "d")] public static DateTime DateDiffD(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "ww")] public static DateTime DateDiffWw(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "dw")] public static DateTime DateDiffDw(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "hh")] public static DateTime DateDiffHh(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "mi")] public static DateTime DateDiffMi(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "s")] public static DateTime DateDiffS(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "ms")] public static DateTime DateDiffMs(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "mcs")] public static DateTime DateDiffMcs(int number, DateTime date) { throw new Exception("Don't call"); } | |
[SqlFunction("DATEDIFF", MagicFirstArg = "ns")] public static DateTime DateDiffNs(int number, DateTime date) { throw new Exception("Don't call"); } | |
// CONVERT to a fixed target type (the type is carried in MagicFirstArg).
// NOTE(review): ConvertToInt and ConvertToDateTime are declared as returning string - confirm.
[SqlFunction("CONVERT", MagicFirstArg = "varchar")] public static string ConvertToString(object expression) { throw new Exception("Don't call"); } | |
[SqlFunction("CONVERT", MagicFirstArg = "varchar")] public static string ConvertToString(object expression, int style) { throw new Exception("Don't call"); } | |
[SqlFunction("CONVERT", MagicFirstArg = "int")] public static string ConvertToInt(object expression) { throw new Exception("Don't call"); } | |
[SqlFunction("CONVERT", MagicFirstArg = "int")] public static string ConvertToInt(object expression, int style) { throw new Exception("Don't call"); } | |
[SqlFunction("CONVERT", MagicFirstArg = "datetime")] public static string ConvertToDateTime(object expression) { throw new Exception("Don't call"); } | |
[SqlFunction("CONVERT", MagicFirstArg = "datetime")] public static string ConvertToDateTime(object expression, int style) { throw new Exception("Don't call"); } | |
// Type-only conversion and the LIKE / IN family.
[SqlFunction(SqlFunctionAttribute.FunctionType.NoOp)] public static T ImplicitConvert<T>(object o) { throw new Exception("Don't call"); } | |
[SqlFunction(SqlFunctionAttribute.FunctionType.Like)] public static bool Like(string expression, string pattern) { throw new Exception("Don't call"); } | |
[SqlFunction(SqlFunctionAttribute.FunctionType.NotLike)] public static bool NotLike(string expression, string pattern) { throw new Exception("Don't call"); } | |
[SqlFunction(SqlFunctionAttribute.FunctionType.In)] public static bool In<T>(T expr, params T[] values) { throw new Exception("Don't call"); } | |
[SqlFunction(SqlFunctionAttribute.FunctionType.NotIn)] public static bool NotIn<T>(T expr, params T[] values) { throw new Exception("Don't call"); } | |
#endif // !SQLPROXY_SQLCLR | |
/// <summary>
/// Casts a value to <typeparamref name="T"/>, mapping "no value" to <c>default(T)</c>.
/// </summary>
/// <typeparam name="T">Target type of the cast.</typeparam>
/// <param name="o">The value to convert; may be null, <see cref="DBNull"/> or a null <see cref="SqlXml"/>.</param>
/// <returns>
/// <c>default(T)</c> when <paramref name="o"/> is null, is <see cref="DBNull"/>, or (for
/// <typeparamref name="T"/> = <see cref="SqlXml"/>) reports <c>IsNull</c>; otherwise <paramref name="o"/> cast to <typeparamref name="T"/>.
/// </returns>
/// <exception cref="InvalidCastException"><paramref name="o"/> holds a value that is not a <typeparamref name="T"/>.</exception>
public static T DbNullCast<T>(object o) {
    // A CLR null or DBNull never carries a value, whatever T is. Checking this first avoids
    // a NullReferenceException on unboxing and an InvalidCastException on the SqlXml path.
    if (o == null || o is DBNull)
        return default(T);
    // SqlXml reports a missing value through its own IsNull flag.
    if (typeof(T) == typeof(SqlXml) && ((SqlXml)o).IsNull)
        return default(T);
    return (T)o;
}
/// <summary>
/// Wraps an XML string in a <see cref="SqlXml"/>.
/// </summary>
/// <param name="xml">The XML text, or null.</param>
/// <returns>Null when <paramref name="xml"/> is null; otherwise a <see cref="SqlXml"/> built from a reader over the text.</returns>
public static SqlXml LoadXml(string xml) {
    if (xml == null)
        return null;

    // CloseInput makes the XmlReader close the underlying StringReader when it is closed.
    XmlReaderSettings settings = new XmlReaderSettings { CloseInput = true };
    StringReader text = new StringReader(xml);
    return new SqlXml(XmlReader.Create(text, settings));
}
/// <summary>
/// Loads the content of a <see cref="SqlXml"/> into a new <see cref="XmlDocument"/>.
/// </summary>
/// <param name="xml">The value to read; its reader is disposed after loading.</param>
/// <returns>The populated document.</returns>
public static XmlDocument LoadSqlXml(SqlXml xml) {
    XmlDocument document = new XmlDocument();
    using (XmlReader reader = xml.CreateReader()) {
        document.Load(reader);
    }
    return document;
}
} | |
#if !SQLPROXY_SQLCLR | |
/// <summary>
/// Base class for visitors that walk an expression tree and rebuild only the nodes
/// whose children changed.
/// </summary>
/// <remarks>
/// Every <c>Visit*</c> method returns the very instance it was given when nothing beneath it
/// was replaced, so a rewrite can be detected with a reference comparison.
/// </remarks>
internal abstract class ExpressionVisitor {
    protected ExpressionVisitor() {
    }

    /// <summary>
    /// Dispatches a node to the <c>Visit*</c> method matching its <see cref="ExpressionType"/>.
    /// </summary>
    /// <param name="exp">The node to visit; null is returned unchanged.</param>
    /// <returns>The original node, or its replacement.</returns>
    /// <exception cref="Exception">The node type is not handled by this visitor.</exception>
    protected virtual Expression Visit(Expression exp) {
        if (exp == null)
            return exp;
        switch (exp.NodeType) {
            case ExpressionType.Negate:
            case ExpressionType.NegateChecked:
            case ExpressionType.UnaryPlus:      // unary node; previously fell through to "Unhandled"
            case ExpressionType.Not:
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.ArrayLength:
            case ExpressionType.Quote:
            case ExpressionType.TypeAs:
                return this.VisitUnary((UnaryExpression)exp);
            case ExpressionType.Add:
            case ExpressionType.AddChecked:
            case ExpressionType.Subtract:
            case ExpressionType.SubtractChecked:
            case ExpressionType.Multiply:
            case ExpressionType.MultiplyChecked:
            case ExpressionType.Divide:
            case ExpressionType.Modulo:
            case ExpressionType.Power:          // binary node; previously fell through to "Unhandled"
            case ExpressionType.And:
            case ExpressionType.AndAlso:
            case ExpressionType.Or:
            case ExpressionType.OrElse:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.Equal:
            case ExpressionType.NotEqual:
            case ExpressionType.Coalesce:
            case ExpressionType.ArrayIndex:
            case ExpressionType.RightShift:
            case ExpressionType.LeftShift:
            case ExpressionType.ExclusiveOr:
                return this.VisitBinary((BinaryExpression)exp);
            case ExpressionType.TypeIs:
                return this.VisitTypeIs((TypeBinaryExpression)exp);
            case ExpressionType.Conditional:
                return this.VisitConditional((ConditionalExpression)exp);
            case ExpressionType.Constant:
                return this.VisitConstant((ConstantExpression)exp);
            case ExpressionType.Parameter:
                return this.VisitParameter((ParameterExpression)exp);
            case ExpressionType.MemberAccess:
                return this.VisitMemberAccess((MemberExpression)exp);
            case ExpressionType.Call:
                return this.VisitMethodCall((MethodCallExpression)exp);
            case ExpressionType.Lambda:
                return this.VisitLambda((LambdaExpression)exp);
            case ExpressionType.New:
                return this.VisitNew((NewExpression)exp);
            case ExpressionType.NewArrayInit:
            case ExpressionType.NewArrayBounds:
                return this.VisitNewArray((NewArrayExpression)exp);
            case ExpressionType.Invoke:
                return this.VisitInvocation((InvocationExpression)exp);
            case ExpressionType.MemberInit:
                return this.VisitMemberInit((MemberInitExpression)exp);
            case ExpressionType.ListInit:
                return this.VisitListInit((ListInitExpression)exp);
            default:
                throw new Exception(string.Format("Unhandled expression type: '{0}'", exp.NodeType));
        }
    }

    /// <summary>Dispatches a member binding to the method matching its binding type.</summary>
    /// <exception cref="Exception">The binding type is not handled.</exception>
    protected virtual MemberBinding VisitBinding(MemberBinding binding) {
        switch (binding.BindingType) {
            case MemberBindingType.Assignment:
                return this.VisitMemberAssignment((MemberAssignment)binding);
            case MemberBindingType.MemberBinding:
                return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
            case MemberBindingType.ListBinding:
                return this.VisitMemberListBinding((MemberListBinding)binding);
            default:
                throw new Exception(string.Format("Unhandled binding type '{0}'", binding.BindingType));
        }
    }

    /// <summary>Visits the arguments of an element initializer, rebuilding it if any changed.</summary>
    protected virtual ElementInit VisitElementInitializer(ElementInit initializer) {
        ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
        if (arguments != initializer.Arguments) {
            return Expression.ElementInit(initializer.AddMethod, arguments);
        }
        return initializer;
    }

    /// <summary>Visits the operand of a unary node, rebuilding the node if the operand changed.</summary>
    protected virtual Expression VisitUnary(UnaryExpression u) {
        Expression operand = this.Visit(u.Operand);
        if (operand != u.Operand) {
            return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
        }
        return u;
    }

    /// <summary>
    /// Visits both operands and the conversion lambda of a binary node, rebuilding the node
    /// if any of them changed.
    /// </summary>
    protected virtual Expression VisitBinary(BinaryExpression b) {
        Expression left = this.Visit(b.Left);
        Expression right = this.Visit(b.Right);
        Expression conversion = this.Visit(b.Conversion);
        if (left != b.Left || right != b.Right || conversion != b.Conversion) {
            // Only a coalesce node is rebuilt with its conversion lambda; every other
            // node goes through MakeBinary, which takes no conversion here.
            if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)
                return Expression.Coalesce(left, right, conversion as LambdaExpression);
            else
                return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
        }
        return b;
    }

    /// <summary>Visits the tested expression of a type test, rebuilding the node if it changed.</summary>
    protected virtual Expression VisitTypeIs(TypeBinaryExpression b) {
        Expression expr = this.Visit(b.Expression);
        if (expr != b.Expression) {
            return Expression.TypeIs(expr, b.TypeOperand);
        }
        return b;
    }

    /// <summary>Returns the constant unchanged.</summary>
    protected virtual Expression VisitConstant(ConstantExpression c) {
        return c;
    }

    /// <summary>Visits test and both branches of a conditional, rebuilding it if any changed.</summary>
    protected virtual Expression VisitConditional(ConditionalExpression c) {
        Expression test = this.Visit(c.Test);
        Expression ifTrue = this.Visit(c.IfTrue);
        Expression ifFalse = this.Visit(c.IfFalse);
        if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse) {
            return Expression.Condition(test, ifTrue, ifFalse);
        }
        return c;
    }

    /// <summary>Returns the parameter unchanged.</summary>
    protected virtual Expression VisitParameter(ParameterExpression p) {
        return p;
    }

    /// <summary>Visits the target of a member access, rebuilding the access if the target changed.</summary>
    protected virtual Expression VisitMemberAccess(MemberExpression m) {
        Expression exp = this.Visit(m.Expression);
        if (exp != m.Expression) {
            return Expression.MakeMemberAccess(exp, m.Member);
        }
        return m;
    }

    /// <summary>Visits the target object and arguments of a call, rebuilding it if any changed.</summary>
    protected virtual Expression VisitMethodCall(MethodCallExpression m) {
        Expression obj = this.Visit(m.Object);
        IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
        if (obj != m.Object || args != m.Arguments) {
            return Expression.Call(obj, m.Method, args);
        }
        return m;
    }

    /// <summary>
    /// Visits every expression in a list.
    /// </summary>
    /// <returns>
    /// <paramref name="original"/> itself when no element changed; otherwise a new read-only
    /// list holding the visited elements.
    /// </returns>
    protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original) {
        // The copy is allocated lazily, on the first element that differs.
        List<Expression> list = null;
        for (int i = 0, n = original.Count; i < n; i++) {
            Expression p = this.Visit(original[i]);
            if (list != null) {
                list.Add(p);
            }
            else if (p != original[i]) {
                list = new List<Expression>(n);
                // Back-fill the unchanged elements that preceded the first change.
                for (int j = 0; j < i; j++) {
                    list.Add(original[j]);
                }
                list.Add(p);
            }
        }
        if (list != null) {
            return list.AsReadOnly();
        }
        return original;
    }

    /// <summary>Visits the assigned expression, rebuilding the assignment if it changed.</summary>
    protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment) {
        Expression e = this.Visit(assignment.Expression);
        if (e != assignment.Expression) {
            return Expression.Bind(assignment.Member, e);
        }
        return assignment;
    }

    /// <summary>Visits the nested bindings, rebuilding the member binding if any changed.</summary>
    protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) {
        IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
        if (bindings != binding.Bindings) {
            return Expression.MemberBind(binding.Member, bindings);
        }
        return binding;
    }

    /// <summary>Visits the element initializers, rebuilding the list binding if any changed.</summary>
    protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding) {
        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
        if (initializers != binding.Initializers) {
            return Expression.ListBind(binding.Member, initializers);
        }
        return binding;
    }

    /// <summary>
    /// Visits every binding in a list.
    /// </summary>
    /// <returns><paramref name="original"/> itself when no binding changed; otherwise a new list.</returns>
    protected virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original) {
        // Same lazy-copy scheme as VisitExpressionList.
        List<MemberBinding> list = null;
        for (int i = 0, n = original.Count; i < n; i++) {
            MemberBinding b = this.VisitBinding(original[i]);
            if (list != null) {
                list.Add(b);
            }
            else if (b != original[i]) {
                list = new List<MemberBinding>(n);
                for (int j = 0; j < i; j++) {
                    list.Add(original[j]);
                }
                list.Add(b);
            }
        }
        if (list != null)
            return list;
        return original;
    }

    /// <summary>
    /// Visits every element initializer in a list.
    /// </summary>
    /// <returns><paramref name="original"/> itself when no initializer changed; otherwise a new list.</returns>
    protected virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original) {
        // Same lazy-copy scheme as VisitExpressionList.
        List<ElementInit> list = null;
        for (int i = 0, n = original.Count; i < n; i++) {
            ElementInit init = this.VisitElementInitializer(original[i]);
            if (list != null) {
                list.Add(init);
            }
            else if (init != original[i]) {
                list = new List<ElementInit>(n);
                for (int j = 0; j < i; j++) {
                    list.Add(original[j]);
                }
                list.Add(init);
            }
        }
        if (list != null)
            return list;
        return original;
    }

    /// <summary>Visits the body of a lambda, rebuilding the lambda (same type and parameters) if it changed.</summary>
    protected virtual Expression VisitLambda(LambdaExpression lambda) {
        Expression body = this.Visit(lambda.Body);
        if (body != lambda.Body) {
            return Expression.Lambda(lambda.Type, body, lambda.Parameters);
        }
        return lambda;
    }

    /// <summary>Visits constructor arguments, rebuilding the node if any changed.</summary>
    protected virtual NewExpression VisitNew(NewExpression nex) {
        IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
        if (args != nex.Arguments) {
            // Members is carried over only when the original node had it.
            if (nex.Members != null)
                return Expression.New(nex.Constructor, args, nex.Members);
            else
                return Expression.New(nex.Constructor, args);
        }
        return nex;
    }

    /// <summary>Visits the constructor call and bindings of a member initializer, rebuilding it if either changed.</summary>
    protected virtual Expression VisitMemberInit(MemberInitExpression init) {
        NewExpression n = this.VisitNew(init.NewExpression);
        IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
        if (n != init.NewExpression || bindings != init.Bindings) {
            return Expression.MemberInit(n, bindings);
        }
        return init;
    }

    /// <summary>Visits the constructor call and element initializers of a list initializer, rebuilding it if either changed.</summary>
    protected virtual Expression VisitListInit(ListInitExpression init) {
        NewExpression n = this.VisitNew(init.NewExpression);
        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
        if (n != init.NewExpression || initializers != init.Initializers) {
            return Expression.ListInit(n, initializers);
        }
        return init;
    }

    /// <summary>Visits the expressions of an array creation, rebuilding it if any changed.</summary>
    protected virtual Expression VisitNewArray(NewArrayExpression na) {
        IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
        if (exprs != na.Expressions) {
            if (na.NodeType == ExpressionType.NewArrayInit) {
                return Expression.NewArrayInit(na.Type.GetElementType(), exprs);
            }
            else {
                return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);
            }
        }
        return na;
    }

    /// <summary>Visits the arguments and then the invoked expression, rebuilding the invocation if either changed.</summary>
    protected virtual Expression VisitInvocation(InvocationExpression iv) {
        IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
        Expression expr = this.Visit(iv.Expression);
        if (args != iv.Arguments || expr != iv.Expression) {
            return Expression.Invoke(expr, args);
        }
        return iv;
    }
}
internal static class Evaluator { | |
/// <summary>
/// Evaluates the independent sub-trees of an expression and replaces each with its result.
/// </summary>
/// <param name="expression">The root of the expression tree.</param>
/// <param name="fnCanBeEvaluated">Predicate passed to the nominator to decide whether a given node may be part of a locally evaluated sub-tree.</param>
/// <returns>The tree produced by the sub-tree evaluator for the nominated candidates.</returns>
public static Expression PartialEval(Expression expression, Func<Expression, bool> fnCanBeEvaluated) {
    // Pass 1: collect the nodes that can be evaluated.
    Nominator nominator = new Nominator(fnCanBeEvaluated);
    HashSet<Expression> nominated = nominator.Nominate(expression);

    // Pass 2: evaluate and replace those nodes.
    SubtreeEvaluator evaluator = new SubtreeEvaluator(nominated);
    return evaluator.Eval(expression);
}
/// <summary>
/// Evaluates the independent sub-trees of an expression and replaces each with its result,
/// using <c>CanBeEvaluatedLocally</c> to decide which nodes qualify.
/// </summary>
/// <param name="expression">The root of the expression tree.</param>
/// <returns>The result of the two-argument overload called with that predicate.</returns>
public static Expression PartialEval(Expression expression) {
    Func<Expression, bool> predicate = CanBeEvaluatedLocally;
    return PartialEval(expression, predicate);
}
/// <summary>
/// Default predicate for local evaluation: everything qualifies except parameters and
/// calls to methods that carry <c>SqlFunctionAttribute</c>.
/// </summary>
/// <param name="expression">The node to classify.</param>
/// <returns>True when the node may be evaluated locally.</returns>
private static bool CanBeEvaluatedLocally(Expression expression) {
    if (expression.NodeType == ExpressionType.Parameter)
        return false;

    if (expression.NodeType != ExpressionType.Call)
        return true;

    // A call is local only when its target method is not marked with SqlFunctionAttribute.
    MethodCallExpression call = (MethodCallExpression)expression;
    Attribute marker = Attribute.GetCustomAttribute(call.Method, typeof(SqlFunctionAttribute));
    return marker == null;
}
/// <summary>
/// Evaluates and replaces sub-trees when the first candidate is reached (top-down).
/// </summary>
class SubtreeEvaluator : ExpressionVisitor {
    HashSet<Expression> candidates;

    /// <param name="candidates">The nodes that are to be evaluated and replaced by constants.</param>
    internal SubtreeEvaluator(HashSet<Expression> candidates) {
        this.candidates = candidates;
    }

    /// <summary>Runs the visitor over <paramref name="exp"/> and returns the resulting tree.</summary>
    internal Expression Eval(Expression exp) {
        return this.Visit(exp);
    }

    /// <summary>
    /// Replaces candidate nodes by their evaluated constant, visits everything else, and then
    /// folds <c>AndAlso</c> / <c>OrElse</c> nodes that have a boolean constant operand.
    /// </summary>
    protected override Expression Visit(Expression exp) {
        if (exp == null) {
            return null;
        }
        if (this.candidates.Contains(exp)) {
            return this.Evaluate(exp);
        }
        exp = base.Visit(exp);
        if (exp.NodeType == ExpressionType.AndAlso || exp.NodeType == ExpressionType.OrElse) {
            // If either branch is constant, we need to perform some manual evaluation, just in case someone writes (1 = 1 || r.Whatever = 'Something')
            // (SQL server does not support TRUE/FALSE literals, and the equivalent (1=1) and (1=0) break when using bit parameters and solving this case is easier than the generic one)
            BinaryExpression be = (BinaryExpression)exp;
            bool isAnd = exp.NodeType == ExpressionType.AndAlso;
            bool value;
            if (TryGetBoolConstant(be.Left, out value)) {
                return Fold(isAnd, value, be.Right);
            }
            if (TryGetBoolConstant(be.Right, out value)) {
                return Fold(isAnd, value, be.Left);
            }
        }
        return exp;
    }

    /// <summary>
    /// Reads a boolean constant operand. Returns false for non-constants and for constants
    /// whose value is null or not a <see cref="bool"/>, so such operands are left alone
    /// instead of failing on the cast.
    /// </summary>
    private static bool TryGetBoolConstant(Expression operand, out bool value) {
        ConstantExpression ce = operand as ConstantExpression;
        if (ce != null && ce.Value is bool) {
            value = (bool)ce.Value;
            return true;
        }
        value = false;
        return false;
    }

    /// <summary>
    /// Folds a logical node with one constant operand: <c>true AND x</c> and <c>false OR x</c>
    /// reduce to <c>x</c>; <c>false AND x</c> and <c>true OR x</c> reduce to the constant.
    /// </summary>
    private static Expression Fold(bool isAnd, bool constant, Expression other) {
        return constant == isAnd ? other : Expression.Constant(constant);
    }

    /// <summary>
    /// Turns a node into a constant by compiling it as a parameterless lambda and invoking it.
    /// Constants are returned unchanged.
    /// </summary>
    private Expression Evaluate(Expression e) {
        if (e.NodeType == ExpressionType.Constant) {
            return e;
        }
        LambdaExpression lambda = Expression.Lambda(e);
        Delegate fn = lambda.Compile();
        // The original static type is kept so the constant can stand in for the replaced node.
        return Expression.Constant(fn.DynamicInvoke(null), e.Type);
    }
}
/// <summary>
/// Walks the tree bottom-up and records every node that, together with all of its
/// descendants, is accepted by the supplied predicate.
/// </summary>
class Nominator : ExpressionVisitor {
    Func<Expression, bool> fnCanBeEvaluated;
    HashSet<Expression> candidates;
    // Set while unwinding whenever the sub-tree just visited contains a rejected node.
    bool cannotBeEvaluated;

    internal Nominator(Func<Expression, bool> fnCanBeEvaluated) {
        this.fnCanBeEvaluated = fnCanBeEvaluated;
    }

    /// <summary>Returns the set of nodes below (and including) <paramref name="expression"/> that can be evaluated.</summary>
    internal HashSet<Expression> Nominate(Expression expression) {
        this.candidates = new HashSet<Expression>();
        this.Visit(expression);
        return this.candidates;
    }

    protected override Expression Visit(Expression expression) {
        if (expression == null) {
            return expression;
        }
        // Remember what the siblings already reported, then start clean for this sub-tree.
        bool blockedBefore = this.cannotBeEvaluated;
        this.cannotBeEvaluated = false;
        base.Visit(expression);
        // The predicate is only consulted when no descendant was rejected.
        bool evaluable = !this.cannotBeEvaluated && this.fnCanBeEvaluated(expression);
        if (evaluable) {
            this.candidates.Add(expression);
        }
        this.cannotBeEvaluated = !evaluable || blockedBefore;
        return expression;
    }
}
} | |
internal sealed class SqlBuilder : ExpressionVisitor { | |
// SQL text accumulated by the Visit* overrides.
private readonly StringBuilder result = new StringBuilder();
// Command that receives one parameter per constant encountered in the tree.
private readonly SqlCommand command;

/// <summary>Creates a builder that adds its parameters to <paramref name="command"/>.</summary>
private SqlBuilder(SqlCommand command) {
    this.command = command;
}
/// <summary>
/// Appends a <c>WHERE</c> clause built from <paramref name="filter"/> to the command text.
/// Constants found in the filter are added to the command as parameters.
/// </summary>
/// <param name="command">The command whose text and parameters are extended.</param>
/// <param name="filter">The row predicate to translate.</param>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
public static void AddWhereClause<T>(SqlCommand command, Expression<Func<T, bool>> filter) {
    if (command == null) {
        throw new ArgumentNullException("command");
    }
    if (filter == null) {
        throw new ArgumentNullException("filter");
    }
    Expression reduced = Evaluator.PartialEval(filter.Body);
    string condition;
    if (reduced.NodeType == ExpressionType.Constant) {
        // The whole predicate folded to true/false (e.g. r => true); SQL has no boolean literal.
        bool always = (bool)((ConstantExpression)reduced).Value;
        condition = always ? "1=1" : "1=0";
    }
    else {
        SqlBuilder builder = new SqlBuilder(command);
        builder.Visit(reduced);
        condition = builder.result.ToString();
    }
    command.CommandText += " WHERE " + condition;
}
/// <summary>
/// Appends an <c>ORDER BY</c> clause built from <paramref name="order"/> to the command text.
/// The body is either a single key or a <c>new[] { ... }</c> of keys. A key that folds to an
/// integer constant is emitted as a column ordinal; a negated key is emitted with <c>DESC</c>.
/// When the key array is empty nothing is appended.
/// </summary>
/// <param name="command">The command whose text and parameters are extended.</param>
/// <param name="order">The ordering expression.</param>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
/// <exception cref="ArgumentException">A constant key is null, not an integer, or does not fit in an <c>int</c>.</exception>
public static void AddOrderByClause<T>(SqlCommand command, Expression<Func<T, object>> order) {
    if (command == null) throw new ArgumentNullException("command");
    if (order == null) throw new ArgumentNullException("order");
    IEnumerable<Expression> keys = (order.Body.NodeType == ExpressionType.NewArrayInit
        ? (IEnumerable<Expression>)((NewArrayExpression)order.Body).Expressions
        : Enumerable.Repeat(order.Body, 1));
    List<string> terms = new List<string>();
    foreach (Expression key in keys) {
        terms.Add(BuildOrderByTerm(command, Evaluator.PartialEval(key)));
    }
    if (terms.Count == 0) {
        // An empty key array would otherwise leave a dangling " ORDER BY " in the command text.
        return;
    }
    command.CommandText += " ORDER BY " + string.Join(", ", terms.ToArray());
}

/// <summary>Translates one (already partially evaluated) ordering key to its SQL text.</summary>
private static string BuildOrderByTerm(SqlCommand command, Expression e) {
    if (e.NodeType == ExpressionType.Constant) {
        // handle expressions such as ORDER BY 1, 2
        return ConstantToOrdinal(((ConstantExpression)e).Value);
    }
    bool isDesc = false;
    // Strip the boxing/conversion nodes the compiler inserts to produce `object`.
    while (e.NodeType == ExpressionType.Convert || e.NodeType == ExpressionType.ConvertChecked) {
        e = ((UnaryExpression)e).Operand;
    }
    if (e.NodeType == ExpressionType.Negate || e.NodeType == ExpressionType.NegateChecked) {
        e = ((UnaryExpression)e).Operand;
        isDesc = true;
    }
    SqlBuilder b = new SqlBuilder(command);
    b.Visit(e);
    string term = b.result.ToString();
    return (isDesc ? term + " DESC" : term);
}

/// <summary>Formats an integer constant as a column ordinal, rejecting null, non-integers and out-of-range values.</summary>
private static string ConstantToOrdinal(object value) {
    bool isInteger = value is byte || value is sbyte || value is short || value is ushort
                  || value is int || value is uint || value is long || value is ulong;
    if (!isInteger) {
        throw new ArgumentException(string.Format("The constant value {0} is not a number.", value));
    }
    int ordinal;
    try {
        ordinal = Convert.ToInt32(value, System.Globalization.CultureInfo.InvariantCulture);
    }
    catch (OverflowException) {
        // Previously such values were silently truncated to an unrelated ordinal.
        throw new ArgumentException(string.Format("The constant value {0} is not a valid column ordinal.", value));
    }
    return Convert.ToString(ordinal, System.Globalization.CultureInfo.InvariantCulture);
}
/// <summary>
/// Emits SQL for a unary node: negation and logical NOT wrap the operand, conversions are
/// transparent, and every other unary node type is rejected.
/// </summary>
/// <exception cref="NotSupportedException">The node type has no SQL translation.</exception>
protected override Expression VisitUnary(UnaryExpression u) {
    ExpressionType kind = u.NodeType;
    if (kind == ExpressionType.Negate || kind == ExpressionType.NegateChecked || kind == ExpressionType.Not) {
        string opening = (kind == ExpressionType.Not ? "NOT (" : "-(");
        result.Append(opening);
        Visit(u.Operand);
        result.Append(")");
        return u;
    }
    if (kind == ExpressionType.Convert || kind == ExpressionType.ConvertChecked || kind == ExpressionType.TypeAs) {
        // No-op in dynamically typed SQL
        Visit(u.Operand);
        return u;
    }
    if (kind == ExpressionType.Quote) {
        throw new NotSupportedException("Quotes are not supported.");
    }
    if (kind == ExpressionType.ArrayLength) {
        throw new NotSupportedException("Array lengths are not supported.");
    }
    throw new NotSupportedException("Unsupported expression type.");
}
/// <summary>
/// Emits SQL for a binary node. Coalesce becomes <c>COALESCE(l, r)</c>; an equality or
/// inequality against a null constant becomes <c>IS NULL</c> / <c>IS NOT NULL</c>;
/// every other supported operator is emitted as <c>(l) op (r)</c>.
/// </summary>
/// <exception cref="NotSupportedException">The operator has no SQL translation.</exception>
protected override Expression VisitBinary(BinaryExpression b) {
    if (b.NodeType == ExpressionType.Coalesce) {
        result.Append("COALESCE(");
        Visit(b.Left);
        result.Append(", ");
        Visit(b.Right);
        result.Append(")");
        return b;
    }
    if (b.NodeType == ExpressionType.Equal || b.NodeType == ExpressionType.NotEqual) {
        // Special handling because `x = NULL` is never true in SQL: a comparison against a
        // null constant has to be spelled IS [NOT] NULL to keep the C# meaning.
        Expression tested = null;
        if (IsNullConstant(b.Left)) {
            tested = b.Right;
        }
        else if (IsNullConstant(b.Right)) {
            tested = b.Left;
        }
        if (tested != null) {
            result.Append("(");
            Visit(tested);
            result.Append(b.NodeType == ExpressionType.Equal ? ") IS NULL" : ") IS NOT NULL");
            return b;
        }
    }
    // Resolve the operator first so that unsupported nodes fail before anything is emitted.
    string infixOper = GetInfixOperator(b.NodeType);
    result.Append("(");
    Visit(b.Left);
    result.Append(") ").Append(infixOper).Append(" (");
    Visit(b.Right);
    result.Append(")");
    return b;
}

/// <summary>Returns true when <paramref name="e"/> is a constant node whose value is null.</summary>
private static bool IsNullConstant(Expression e) {
    return e.NodeType == ExpressionType.Constant && ((ConstantExpression)e).Value == null;
}

/// <summary>Maps an infix-capable node type to its SQL operator text.</summary>
/// <exception cref="NotSupportedException">The node type has no infix SQL operator.</exception>
private static string GetInfixOperator(ExpressionType type) {
    switch (type) {
        case ExpressionType.Add:
        case ExpressionType.AddChecked:
            return "+";
        case ExpressionType.Subtract:
        case ExpressionType.SubtractChecked:
            return "-";
        case ExpressionType.Multiply:
        case ExpressionType.MultiplyChecked:
            return "*";
        case ExpressionType.Divide:
            return "/";
        case ExpressionType.Modulo:
            return "%";
        case ExpressionType.And:
            return "&";
        case ExpressionType.AndAlso:
            return "AND";
        case ExpressionType.Or:
            return "|";
        case ExpressionType.OrElse:
            return "OR";
        case ExpressionType.LessThan:
            return "<";
        case ExpressionType.LessThanOrEqual:
            return "<=";
        case ExpressionType.GreaterThan:
            return ">";
        case ExpressionType.GreaterThanOrEqual:
            return ">=";
        case ExpressionType.Equal:
            return "=";
        case ExpressionType.NotEqual:
            return "<>";
        case ExpressionType.ExclusiveOr:
            return "^";
        case ExpressionType.ArrayIndex:
            throw new NotSupportedException("The array index operator is not supported.");
        case ExpressionType.RightShift:
            throw new NotSupportedException("The right shift operator is not supported.");
        case ExpressionType.LeftShift:
            throw new NotSupportedException("The left shift operator is not supported.");
        default:
            throw new NotSupportedException("Unsupported expression type.");
    }
}
/// <summary>
/// Emits a conditional (<c>test ? a : b</c>) as
/// <c>CASE WHEN (test) THEN (a) ELSE (b) END</c>.
/// </summary>
protected override Expression VisitConditional(ConditionalExpression c) {
    string[] leadIns = { "CASE WHEN (", ") THEN (", ") ELSE (" };
    Expression[] parts = { c.Test, c.IfTrue, c.IfFalse };
    for (int i = 0; i < parts.Length; i++) {
        result.Append(leadIns[i]);
        Visit(parts[i]);
    }
    result.Append(") END");
    return c;
}
/// <summary>
/// Tells whether the last parameter of <paramref name="m"/> is a <c>params</c> array.
/// A method without parameters is never a vararg method.
/// </summary>
private static bool IsVarargMethod(System.Reflection.MethodInfo m) {
    System.Reflection.ParameterInfo[] ps = m.GetParameters();
    if (ps.Length == 0) {
        // Indexing ps[ps.Length - 1] would throw for parameterless methods.
        return false;
    }
    return Attribute.GetCustomAttribute(ps[ps.Length - 1], typeof(ParamArrayAttribute)) != null;
}
/// <summary>
/// Emits SQL for a call to a static method decorated with <c>SqlFunctionAttribute</c>.
/// The attribute's <c>Type</c> selects the shape of the output (function call, CASE,
/// IS [NOT] NULL, [NOT] LIKE, [NOT] IN or a plain parenthesised pass-through).
/// </summary>
/// <exception cref="NotSupportedException">
/// The call targets an instance method, the method is not decorated, or the attribute type is unknown.
/// </exception>
protected override Expression VisitMethodCall(MethodCallExpression m) {
    if (m.Object != null)
        throw new NotSupportedException("Can only call static methods.");
    SqlFunctionAttribute a = (SqlFunctionAttribute)Attribute.GetCustomAttribute(m.Method, typeof(SqlFunctionAttribute));
    if (a == null)
        throw new NotSupportedException(string.Format("The method {0} is not decorated with a SqlFunctionAttribute.", m.Method.Name));
    AppendCallPrefix(a);
    IList<Expression> actualArguments = FlattenArguments(m);
    for (int i = 0, n = actualArguments.Count; i < n; i++) {
        AppendArgumentSeparator(a, i, n);
        this.Visit(actualArguments[i]);
    }
    AppendCallSuffix(a);
    return m;
}

/// <summary>Writes the text that precedes the first argument.</summary>
private void AppendCallPrefix(SqlFunctionAttribute a) {
    switch (a.Type) {
        case SqlFunctionAttribute.FunctionType.Function:
            result.Append(a.Name);
            if (a.ParamList)
                result.Append("(");
            if (!string.IsNullOrEmpty(a.MagicFirstArg))
                result.Append(a.MagicFirstArg).Append(", ");
            break;
        case SqlFunctionAttribute.FunctionType.Case:
            result.Append("CASE");
            break;
        case SqlFunctionAttribute.FunctionType.IsNull:
        case SqlFunctionAttribute.FunctionType.IsNotNull:
        case SqlFunctionAttribute.FunctionType.NoOp:
        case SqlFunctionAttribute.FunctionType.Like:
        case SqlFunctionAttribute.FunctionType.NotLike:
        case SqlFunctionAttribute.FunctionType.In:
        case SqlFunctionAttribute.FunctionType.NotIn:
            result.Append("(");
            break;
        default:
            throw new NotSupportedException("Bad SqlFunctionAttribute attribute.");
    }
}

/// <summary>
/// Returns the call's arguments with a trailing <c>params</c> array expanded into its
/// elements, compensating for the compiler turning varargs into an array.
/// A constant null array is treated as an empty argument list.
/// </summary>
private IList<Expression> FlattenArguments(MethodCallExpression m) {
    if (!IsVarargMethod(m.Method)) {
        return m.Arguments;
    }
    Expression lastArg = m.Arguments[m.Arguments.Count - 1];
    IEnumerable<Expression> expanded;
    if (lastArg.NodeType == ExpressionType.NewArrayInit) {
        expanded = ((NewArrayExpression)lastArg).Expressions;
    }
    else if (lastArg.NodeType == ExpressionType.Constant && ((ConstantExpression)lastArg).Type.IsArray) {
        // The array was built locally and folded into a single constant by partial evaluation.
        ConstantExpression c = (ConstantExpression)lastArg;
        Array arr = (Array)c.Value;
        List<Expression> elements = new List<Expression>();
        if (arr != null) {
            foreach (object o in arr)
                elements.Add(Expression.Constant(o, c.Type.GetElementType()));
        }
        expanded = elements;
    }
    else {
        expanded = Enumerable.Repeat(lastArg, 1);
    }
    return m.Arguments.Take(m.Arguments.Count - 1).Concat(expanded).ToList();
}

/// <summary>Writes the text that goes in front of argument <paramref name="i"/> of <paramref name="n"/>.</summary>
private void AppendArgumentSeparator(SqlFunctionAttribute a, int i, int n) {
    switch (a.Type) {
        case SqlFunctionAttribute.FunctionType.Function:
            if (i > 0) result.Append(", ");
            break;
        case SqlFunctionAttribute.FunctionType.Case:
            // Arguments alternate condition/value; an unpaired last argument is the ELSE branch.
            if (i % 2 == 0)
                result.Append(i < n - 1 ? " WHEN " : " ELSE ");
            else
                result.Append(" THEN ");
            break;
        case SqlFunctionAttribute.FunctionType.Like:
            if (i > 0) result.Append(") LIKE (");
            break;
        case SqlFunctionAttribute.FunctionType.NotLike:
            if (i > 0) result.Append(") NOT LIKE (");
            break;
        case SqlFunctionAttribute.FunctionType.In:
        case SqlFunctionAttribute.FunctionType.NotIn:
            // First argument is the tested value, the remaining ones form the list.
            switch (i) {
                case 0: break;
                case 1: result.Append(a.Type == SqlFunctionAttribute.FunctionType.In ? ") IN(" : ") NOT IN("); break;
                default: result.Append(", "); break;
            }
            break;
    }
}

/// <summary>Writes the text that follows the last argument.</summary>
private void AppendCallSuffix(SqlFunctionAttribute a) {
    switch (a.Type) {
        case SqlFunctionAttribute.FunctionType.Function:
            if (a.ParamList)
                result.Append(")");
            break;
        case SqlFunctionAttribute.FunctionType.Case:
            result.Append(" END");
            break;
        case SqlFunctionAttribute.FunctionType.IsNull:
            result.Append(") IS NULL");
            break;
        case SqlFunctionAttribute.FunctionType.IsNotNull:
            result.Append(") IS NOT NULL");
            break;
        case SqlFunctionAttribute.FunctionType.Like:
        case SqlFunctionAttribute.FunctionType.NotLike:
        case SqlFunctionAttribute.FunctionType.NoOp:
        case SqlFunctionAttribute.FunctionType.In:
        case SqlFunctionAttribute.FunctionType.NotIn:
            result.Append(")");
            break;
    }
}
/// <summary>
/// Emits a constant as a reference to a new command parameter named <c>@pN</c>, where N is
/// the number of parameters already on the command. A null value is sent as <c>DBNull</c>.
/// </summary>
protected override Expression VisitConstant(ConstantExpression c) {
    string paramName = "@p" + command.Parameters.Count;
    object boundValue = c.Value ?? DBNull.Value;
    result.Append(paramName);
    command.Parameters.AddWithValue(paramName, boundValue);
    return c;
}
/// <summary>
/// Emits a member access on the lambda parameter as the bracketed column name taken from
/// the member's <c>ColumnNameAttribute</c>.
/// </summary>
/// <exception cref="NotSupportedException">
/// The member is not accessed on the lambda parameter, or it has no <c>ColumnNameAttribute</c>.
/// </exception>
protected override Expression VisitMemberAccess(MemberExpression m) {
    // The target is visited before it is validated, so a failure inside it is reported first.
    Visit(m.Expression);
    bool onLambdaParameter = (m.Expression.NodeType == ExpressionType.Parameter);
    if (!onLambdaParameter) {
        throw new NotSupportedException("Member accesses are only allowed with the lambda parameter as the object.");
    }
    ColumnNameAttribute column = (ColumnNameAttribute)Attribute.GetCustomAttribute(m.Member, typeof(ColumnNameAttribute));
    if (column == null) {
        throw new NotSupportedException("Member accesses are only allowed for columns which are decorated with a ColumnNameAttribute.");
    }
    result.Append("[");
    result.Append(column.Name);
    result.Append("]");
    return m;
}
#region Unsupported expression types
// None of the node kinds below has a SQL translation; each override rejects the node
// with a NotSupportedException carrying a message specific to that kind.

/// <summary>Creates the exception thrown by the overrides in this region.</summary>
private static NotSupportedException Unsupported(string message) {
    return new NotSupportedException(message);
}

protected override MemberBinding VisitBinding(MemberBinding binding) {
    throw Unsupported("MemberBinding expressions are not supported.");
}

protected override ElementInit VisitElementInitializer(ElementInit initializer) {
    throw Unsupported("Element initializer expressions are not supported.");
}

protected override Expression VisitTypeIs(TypeBinaryExpression b) {
    throw Unsupported("Type is expressions are not supported.");
}

protected override MemberAssignment VisitMemberAssignment(MemberAssignment assignment) {
    throw Unsupported("Member assignment expressions are not supported.");
}

protected override MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) {
    throw Unsupported("Member member binding expressions are not supported.");
}

protected override MemberListBinding VisitMemberListBinding(MemberListBinding binding) {
    throw Unsupported("Member list binding expressions are not supported.");
}

protected override Expression VisitLambda(LambdaExpression lambda) {
    throw Unsupported("Lambda expressions are not supported (except at the top level).");
}

protected override NewExpression VisitNew(NewExpression nex) {
    throw Unsupported("`new´ expressions are not supported.");
}

protected override Expression VisitMemberInit(MemberInitExpression init) {
    throw Unsupported("Member initialization expressions are not supported.");
}

protected override Expression VisitListInit(ListInitExpression init) {
    throw Unsupported("List initialization expressions are not supported.");
}

protected override Expression VisitNewArray(NewArrayExpression na) {
    throw Unsupported("`new[]´ expressions are not supported (except at the top level for order by expressions.");
}

protected override Expression VisitInvocation(InvocationExpression iv) {
    throw Unsupported("Invocations are not supported");
}
#endregion
} | |
#endif // !SQLPROXY_SQLCLR | |
public abstract class ResultSetEnumerator : IDisposable { | |
private SqlDataReader reader; | |
private SqlCommand command; | |
private int currentResultSet = -1; | |
private int AfterLastResultSet = int.MaxValue; | |
private Dictionary<string, object> outputParameters; | |
/// <summary>
/// Takes ownership of <paramref name="command"/> and executes it immediately.
/// If execution fails the command is disposed before the exception propagates.
/// </summary>
/// <param name="command">The fully prepared command to execute.</param>
/// <exception cref="ArgumentNullException"><paramref name="command"/> is null.</exception>
protected ResultSetEnumerator(SqlCommand command) {
    if (command == null)
        throw new ArgumentNullException("command");
    this.command = command;
    try {
        this.reader = command.ExecuteReader();
    }
    catch (Exception) {
        command.Dispose(); // The command can never be used for anything and our dispose method will not be called
        this.command = null; // clear the field (the original cleared only the parameter)
        throw;
    }
}
/// <summary>
/// Moves the reader forward until result set <paramref name="num"/> is current. When the
/// reader runs out of result sets first, the position is set to the "after last" marker.
/// </summary>
private void AdvanceToRS(int num) {
    if (currentResultSet == -1) {
        // Nothing consumed yet: the reader is positioned on the first result set.
        currentResultSet = 0;
    }
    for (; currentResultSet < num; currentResultSet++) {
        if (!reader.NextResult()) {
            currentResultSet = AfterLastResultSet;
            return;
        }
    }
}
/// <summary>
/// Lazily yields one item per row of result set <paramref name="resultSetNumber"/>.
/// This is an iterator: nothing (including the validity checks) runs until enumeration starts.
/// </summary>
/// <param name="resultSetNumber">Zero-based index of the result set to read.</param>
/// <param name="selector">Converts the current reader row to an item.</param>
/// <param name="isOnlyResultSet">
/// When true, output parameters are read after the last row and this object is disposed
/// when the enumeration ends or is abandoned.
/// </param>
/// <exception cref="InvalidOperationException">Result sets are enumerated out of order or more than once.</exception>
/// <exception cref="ObjectDisposedException">The reader has already been disposed.</exception>
protected IEnumerable<T> GetResultSetEnumerator<T>(int resultSetNumber, Func<SqlDataReader, T> selector, bool isOnlyResultSet) {
    if (resultSetNumber <= currentResultSet) {
        throw new InvalidOperationException("The result sets must be enumerated in order and only once per result set.");
    }
    if (reader == null) {
        throw new ObjectDisposedException(null);
    }
    try {
        AdvanceToRS(resultSetNumber);
        while (reader.Read()) {
            T row = selector(reader);
            yield return row;
            // The consumer may have moved on to a later result set, or disposed us,
            // while this iterator was suspended.
            if (resultSetNumber != currentResultSet) {
                throw new InvalidOperationException("It is not possible to continue the enumeration of a result set after enumerating a later one.");
            }
            if (reader == null) {
                throw new ObjectDisposedException(null);
            }
        }
        if (isOnlyResultSet) {
            ReadOutputParameters();
        }
    }
    finally {
        if (isOnlyResultSet) {
            Dispose();
        }
    }
}
/// <summary>
/// Skips any remaining result sets and caches the value of every non-input parameter,
/// mapping <c>DBNull</c> to null. Calling it again after the values are cached does nothing.
/// </summary>
/// <exception cref="ObjectDisposedException">
/// The values were not cached yet and the reader or command has already been disposed.
/// </exception>
protected void ReadOutputParameters() {
    if (outputParameters != null)
        return;
    if (reader == null || command == null)
        throw new ObjectDisposedException(null);
    // The reader may already be past the last result set (a requested result set did not
    // exist); the values must still be collected in that case.
    if (currentResultSet != AfterLastResultSet)
        AdvanceToRS(AfterLastResultSet);
    Dictionary<string, object> values = new Dictionary<string, object>();
    foreach (SqlParameter p in command.Parameters) {
        if (p.Direction != System.Data.ParameterDirection.Input)
            values[p.ParameterName] = (p.Value is DBNull ? null : p.Value);
    }
    outputParameters = values;
}
/// <summary>
/// Returns the cached value of the output parameter <paramref name="name"/>, reading the
/// output parameters first when that has not happened yet.
/// </summary>
/// <param name="name">The parameter name used as key in the cache.</param>
protected T GetOutputParameter<T>(string name) {
    if (outputParameters == null) {
        ReadOutputParameters();
    }
    object raw = outputParameters[name];
    return SqlFunctions.DbNullCast<T>(raw);
}
/// <summary>
/// Disposes the reader and then the command, clearing each field afterwards so that a
/// second call is a no-op.
/// </summary>
public void Dispose() {
    SqlDataReader openReader = reader;
    if (openReader != null) {
        openReader.Dispose();
        reader = null;
    }
    SqlCommand ownedCommand = command;
    if (ownedCommand != null) {
        ownedCommand.Dispose();
        command = null;
    }
}
#if !SQLPROXY_SQLCLR | |
/// <summary>
/// Executes <paramref name="query"/>, optionally extended with a WHERE and/or ORDER BY
/// clause generated from the given expressions, and returns its rows as a result set.
/// </summary>
/// <param name="connection">The connection used to create the command.</param>
/// <param name="transaction">The transaction to enlist in, or null.</param>
/// <param name="query">The base SQL text.</param>
/// <param name="selector">Converts a reader row to an item.</param>
/// <param name="filter">Optional row predicate; null for no WHERE clause.</param>
/// <param name="order">Optional ordering; null for no ORDER BY clause.</param>
/// <param name="commandTimeout">Value assigned to the command's timeout.</param>
/// <exception cref="ArgumentNullException"><paramref name="connection"/> is null.</exception>
public static ViewResultSet<T> ExecuteQuery<T>(SqlConnection connection, SqlTransaction transaction, string query, Func<SqlDataReader, T> selector, Expression<Func<T, bool>> filter, Expression<Func<T, object>> order, int commandTimeout) {
    if (connection == null) throw new ArgumentNullException("connection");
    SqlCommand cmd = connection.CreateCommand();
    try {
        cmd.CommandType = System.Data.CommandType.Text;
        cmd.CommandTimeout = commandTimeout;
        cmd.CommandText = query;
        cmd.Transaction = transaction;
        if (filter != null) SqlBuilder.AddWhereClause(cmd, filter);
        if (order != null) SqlBuilder.AddOrderByClause(cmd, order);
    }
    catch (Exception) {
        // Nobody owns the command yet (e.g. the filter could not be translated): release it here.
        cmd.Dispose();
        throw;
    }
    // From here on the result set owns the command and disposes it itself if execution fails.
    return new ViewResultSet<T>(cmd, selector);
}
#else | |
/// <summary>
/// Executes <paramref name="query"/> as text and returns its rows as a result set.
/// </summary>
/// <param name="connection">The connection used to create the command.</param>
/// <param name="transaction">The transaction to enlist in, or null.</param>
/// <param name="query">The SQL text.</param>
/// <param name="selector">Converts a reader row to an item.</param>
/// <param name="commandTimeout">Value assigned to the command's timeout.</param>
/// <exception cref="ArgumentNullException"><paramref name="connection"/> is null.</exception>
public static ViewResultSet<T> ExecuteQuery<T>(SqlConnection connection, SqlTransaction transaction, string query, Func<SqlDataReader, T> selector, int commandTimeout) {
    if (connection == null) throw new ArgumentNullException("connection");
    SqlCommand cmd = connection.CreateCommand();
    try {
        cmd.CommandType = System.Data.CommandType.Text;
        cmd.CommandTimeout = commandTimeout;
        cmd.CommandText = query;
        cmd.Transaction = transaction;
    }
    catch (Exception) {
        // Nobody owns the command yet (e.g. an invalid timeout was rejected): release it here.
        cmd.Dispose();
        throw;
    }
    // From here on the result set owns the command and disposes it itself if execution fails.
    return new ViewResultSet<T>(cmd, selector);
}
#endif | |
/// <summary>
/// Executes <paramref name="query"/> as text with the given parameters and returns its
/// rows as a result set.
/// </summary>
/// <param name="connection">The connection used to create the command.</param>
/// <param name="transaction">The transaction to enlist in, or null.</param>
/// <param name="query">The SQL text.</param>
/// <param name="parms">Parameters to attach to the command; may be null.</param>
/// <param name="selector">Converts a reader row to an item.</param>
/// <param name="commandTimeout">Value assigned to the command's timeout.</param>
/// <exception cref="ArgumentNullException"><paramref name="connection"/> is null.</exception>
public static ViewResultSet<T> ExecuteQuery<T>(SqlConnection connection, SqlTransaction transaction, string query, IEnumerable<SqlParameter> parms, Func<SqlDataReader, T> selector, int commandTimeout) {
    if (connection == null) throw new ArgumentNullException("connection");
    SqlCommand cmd = connection.CreateCommand();
    try {
        cmd.CommandType = System.Data.CommandType.Text;
        cmd.CommandTimeout = commandTimeout;
        cmd.CommandText = query;
        cmd.Transaction = transaction;
        if (parms != null) {
            foreach (SqlParameter p in parms)
                cmd.Parameters.Add(p);
        }
    }
    catch (Exception) {
        // Nobody owns the command yet (e.g. a parameter was rejected): release it here.
        cmd.Dispose();
        throw;
    }
    // From here on the result set owns the command and disposes it itself if execution fails.
    return new ViewResultSet<T>(cmd, selector);
}
} | |
/// <summary>
/// A single-result-set reader: executes the command on construction and exposes its rows,
/// converted by the selector, as a forward-only sequence.
/// </summary>
public class ViewResultSet<T> : ResultSetEnumerator, IEnumerable<T> {
    private Func<SqlDataReader, T> selector;

    /// <param name="command">The command to execute; ownership passes to this object.</param>
    /// <param name="selector">Converts a reader row to an item.</param>
    public ViewResultSet(SqlCommand command, Func<SqlDataReader, T> selector) : base(command) {
        this.selector = selector;
    }

    /// <summary>The rows of result set 0, read as the only result set of the command.</summary>
    public IEnumerable<T> ResultSet {
        get {
            IEnumerable<T> rows = GetResultSetEnumerator(0, selector, true);
            return rows;
        }
    }

    public IEnumerator<T> GetEnumerator() {
        IEnumerable<T> rows = this.ResultSet;
        return rows.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator() {
        return this.GetEnumerator();
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<#@ template hostspecific="true" debug="true" #> | |
<#@ output extension=".cs" #> | |
<#@ assembly name="System.Data.dll" #> | |
<#@ assembly name="System.Xml.dll" #> | |
<#@ assembly name="System.Core.dll" #> | |
<#@ assembly name="System.Windows.Forms.dll" #> | |
<#@ import namespace="System.Linq" #> | |
<#@ import namespace="System.Linq.Expressions" #> | |
<#@ import namespace="System.Data.SqlClient" #> | |
<#@ import namespace="System.Data" #> | |
<#@ import namespace="System.Data.SqlTypes" #> | |
<#@ import namespace="System.Collections.ObjectModel" #> | |
<#@ import namespace="System.Collections.Generic" #> | |
<#@ import namespace="System.Text" #> | |
<#@ import namespace="System.Text.RegularExpressions" #> | |
using System; | |
using System.Data; | |
using System.Data.SqlClient; | |
using System.Data.SqlTypes; | |
using System.Collections; | |
using System.Collections.Generic; | |
using SqlProxy; | |
namespace <#= namespaceName #> | |
{ | |
<#
    // Discovery phase: connect to the database, collect the procedures, views/tables and
    // named queries selected by the configuration, then emit the row and enumerator types.
    PushIndent("\t");
    using (connection = new SqlConnection(connectionString)) {
        connection.Open();

        // Quoted schema list shared by the metadata queries. Lower-casing is culture-invariant
        // and embedded single quotes are doubled so the list is always a valid SQL literal list.
        string schemaList = string.Join(", ", schemaNames.Select(x => "'" + x.ToLowerInvariant().Replace("'", "''") + "'").ToArray());

        // Runs a two-column (schema, name) metadata query and returns its rows.
        Func<string, List<SchemaAndName>> readNames = sql => {
            var names = new List<SchemaAndName>();
            using (SqlCommand cmd = connection.CreateCommand()) {
                cmd.CommandType = CommandType.Text;
                cmd.CommandText = sql;
                using (var rdr = cmd.ExecuteReader()) {
                    while (rdr.Read()) {
                        names.Add(new SchemaAndName((string)rdr[0], (string)rdr[1]));
                    }
                }
            }
            return names;
        };

        List<SchemaAndName> procNames = new List<SchemaAndName>(), viewNames = new List<SchemaAndName>();
        if (includeViews) {
            viewNames.AddRange(readNames("SELECT table_schema, table_name FROM INFORMATION_SCHEMA.VIEWS WHERE lower(table_schema) IN(" + schemaList + ")"));
        }
        if (includeTables) {
            // INFORMATION_SCHEMA.TABLES lists views as well; restrict to base tables so a view
            // is not picked up twice when includeViews is also set.
            viewNames.AddRange(readNames("SELECT table_schema, table_name FROM INFORMATION_SCHEMA.TABLES WHERE table_type = 'BASE TABLE' AND lower(table_schema) IN(" + schemaList + ")"));
        }
        if (includeProcedures) {
            procNames.AddRange(readNames("SELECT routine_schema, routine_name FROM INFORMATION_SCHEMA.ROUTINES WHERE routine_type = 'PROCEDURE' AND lower(routine_schema) IN(" + schemaList + ")"));
        }

        if (namedQueries != null) {
            // Placeholders have the form @{name: SqlDbType}.
            Regex parameterRegex = new Regex(@"@\{\s*([a-zA-Z0-9_]+)\s*:\s*([a-zA-Z0-9_]+)\s*}");
            foreach (var kvp in namedQueries.OrderBy(q => q.Key)) {
                var parms = new List<Parameter>();
                var seenParameterNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
                bool valid = true;
                foreach (Match m in parameterRegex.Matches(kvp.Value)) {
                    SqlDbType dbType;
                    try {
                        dbType = (SqlDbType)Enum.Parse(typeof(SqlDbType), m.Groups[2].Value, true);
                    }
                    catch (ArgumentException) {
                        System.Windows.Forms.MessageBox.Show("Query " + kvp.Key + ": The SqlDbType " + m.Groups[2].Value + " does not exist.");
                        valid = false;
                        break;
                    }
                    // A placeholder referenced several times must yield a single parameter.
                    if (seenParameterNames.Add(m.Groups[1].Value)) {
                        parms.Add(new Parameter("@" + m.Groups[1].Value, dbType, GetMaxSize(dbType), ParameterDirection.Input));
                    }
                }
                if (!valid) {
                    // The query references an unknown type: skip it, keep processing the others.
                    continue;
                }
                string query = parameterRegex.Replace(kvp.Value, pm => "@" + pm.Groups[1].Value);
                parsedNamedQueries.Add(new NamedQuery(connection, kvp.Key, query, parms, usedNames));
            }
        }

        foreach (var procName in procNames.OrderBy(x => x.Schema).ThenBy(x => x.Name)) {
            if ((procsToInclude == null || procsToInclude.Contains(procName)) &&
                (procsToExclude == null || !procsToExclude.Contains(procName))) {
                procedures.Add(new Procedure(connection, procName, usedNames));
            }
        }
        foreach (var viewName in viewNames.OrderBy(x => x.Schema).ThenBy(x => x.Name)) {
            if ((viewsToInclude == null || viewsToInclude.Contains(viewName)) &&
                (viewsToExclude == null || !viewsToExclude.Contains(viewName))) {
                views.Add(new View(connection, viewName, usedNames));
            }
        }

        foreach (var proc in procedures) {
            foreach (var rs in proc.ResultSets) {
                GenerateResultSetRowType(rs, false);
            }
            if (proc.ResultSets.Count > 0)
                GenerateEnumeratorType(proc.ResultSets, proc.Parameters.Where(p => p.Direction != ParameterDirection.Input).Union(Enumerable.Repeat(Parameter.ReturnValue, 1)).ToList(), proc.ResultTypeName);
        }
        foreach (var view in views) {
            GenerateResultSetRowType(view.ResultSet, !generateSqlclr);
        }
        foreach (NamedQuery query in parsedNamedQueries) {
            GenerateResultSetRowType(query.ResultSet, false);
        }
    }
#>
/// <summary>
/// Abstraction over the generated database proxy: transaction helpers, raw command creation,
/// plus one member per discovered stored procedure, view and named query.
/// </summary>
<#= accessibility #> partial interface I<#= databaseClassName #> : IDisposable { | |
/// <summary>Runs <paramref name="func"/> inside a transaction. The generated class commits when it returns true and can retry when chosen as deadlock victim.</summary>
void ExecuteTransacted(Func<I<#= databaseClassName #>, bool> func, bool retryIfDeadlockVictim); | |
void ExecuteTransacted(Func<I<#= databaseClassName #>, bool> func, bool retryIfDeadlockVictim, IsolationLevel isolationLevel); | |
/// <summary>Runs <paramref name="func"/> inside a transaction. The generated class always commits when it completes without throwing.</summary>
void ExecuteTransacted(Action<I<#= databaseClassName #>> func, bool retryIfDeadlockVictim); | |
void ExecuteTransacted(Action<I<#= databaseClassName #>> func, bool retryIfDeadlockVictim, IsolationLevel isolationLevel); | |
/// <summary>Starts a transaction with the given isolation level.</summary>
IDbTransaction BeginTransaction(IsolationLevel isolationLevel); | |
/// <summary>Creates a command on the proxy's connection.</summary>
IDbCommand CreateCommand(); | |
IDbCommand CreateCommand(CommandType commandType, string commandText, IEnumerable<KeyValuePair<string, object>> parameters); | |
IDbConnection Connection { get; } | |
/// <summary>Creates a savepoint and returns its name.</summary>
string SaveTransaction(); | |
/// <summary>Rolls back to a savepoint previously returned by <c>SaveTransaction</c>.</summary>
void RollbackToSavepoint(string savepointName); | |
<# | |
// Emit one interface member per stored procedure, view and named query.
PushIndent("\t"); | |
foreach (Procedure proc in procedures) { | |
WriteStoredprocSignature(proc); | |
WriteLine(";"); | |
} | |
foreach (View view in views) { | |
WriteViewSignature(view); | |
WriteLine(";"); | |
} | |
foreach (NamedQuery query in parsedNamedQueries) { | |
WriteNamedQuerySignature(query); | |
WriteLine(";"); | |
} | |
PopIndent(); | |
#> | |
} | |
<#= accessibility #> sealed partial class <#= databaseClassName #> : IDisposable, I<#= databaseClassName #> | |
{ | |
/// <summary>
/// Wraps a <c>SqlTransaction</c> and keeps the owning proxy's current-transaction field in
/// sync: set on creation, cleared on commit, rollback and dispose.
/// </summary>
private class TransactionClass : IDbTransaction, IDisposable {
    private SqlTransaction t;
    private <#= databaseClassName #> db;

    public TransactionClass(<#= databaseClassName #> db) {
        EnsureNoOpenTransaction(db);
        this.t = db.connection.BeginTransaction();
        Attach(db);
    }

    public TransactionClass(<#= databaseClassName #> db, IsolationLevel il) {
        EnsureNoOpenTransaction(db);
        this.t = db.connection.BeginTransaction(il);
        Attach(db);
    }

    // Only one transaction may be open per proxy at any time.
    private static void EnsureNoOpenTransaction(<#= databaseClassName #> db) {
        if (db.transaction != null) {
            throw new InvalidOperationException("There is already an open transaction");
        }
    }

    // Registers this instance as the proxy's current transaction.
    private void Attach(<#= databaseClassName #> db) {
        this.db = db;
        db.transaction = this;
    }

    public IDbConnection Connection {
        get { return t.Connection; }
    }

    public IsolationLevel IsolationLevel {
        get { return t.IsolationLevel; }
    }

    public void Commit() {
        t.Commit();
        db.transaction = null;
    }

    public void Rollback() {
        t.Rollback();
        db.transaction = null;
    }

    public void Dispose() {
        t.Dispose();
        db.transaction = null;
    }

    /// <summary>Unwraps the underlying transaction; a null wrapper converts to null.</summary>
    public static implicit operator SqlTransaction(TransactionClass t) {
        if (t == null) {
            return null;
        }
        return t.t;
    }
}
private TransactionClass transaction; | |
private SqlConnection connection; | |
private bool ownsConnection; | |
/// <summary>
/// Creates a proxy that owns a new connection built from <paramref name="connectionString"/>
/// and opens it. If opening fails the connection is disposed before the exception propagates.
/// </summary>
public <#= databaseClassName #>(string connectionString) : this(new SqlConnection(connectionString), true) {
    try {
        this.connection.Open();
    }
    catch {
        // The caller never receives this instance, so nobody else could dispose the connection.
        this.connection.Dispose();
        this.connection = null;
        throw;
    }
}
/// <summary>
/// Creates a proxy on an existing connection.
/// </summary>
/// <param name="connection">The connection to use; it is not opened here.</param>
/// <param name="takeOwnership">When true, disposing the proxy also disposes the connection.</param>
/// <exception cref="ArgumentNullException"><paramref name="connection"/> is null.</exception>
public <#= databaseClassName #>(SqlConnection connection, bool takeOwnership) {
    if (connection == null) {
        throw new ArgumentNullException("connection");
    }
    this.connection = connection;
    this.ownsConnection = takeOwnership;
    // Default timeout taken from the template configuration.
    this.CommandTimeout = <#= Convert.ToString(commandTimeout, System.Globalization.CultureInfo.InvariantCulture) #>;
}
/// <summary>Runs <paramref name="func"/> in a transaction with unspecified isolation level and commits when it completes.</summary>
public void ExecuteTransacted(Action<I<#= databaseClassName #>> func, bool retryIfDeadlockVictim) {
    ExecuteTransacted(func, retryIfDeadlockVictim, IsolationLevel.Unspecified);
}

/// <summary>Runs <paramref name="func"/> in a transaction and commits when it completes.</summary>
public void ExecuteTransacted(Action<I<#= databaseClassName #>> func, bool retryIfDeadlockVictim, IsolationLevel isolationLevel) {
    // Adapt the action to the bool-returning form: completing normally means "commit".
    Func<I<#= databaseClassName #>, bool> alwaysCommit = db => {
        func(db);
        return true;
    };
    ExecuteTransacted(alwaysCommit, retryIfDeadlockVictim, isolationLevel);
}

/// <summary>Runs <paramref name="func"/> in a transaction with unspecified isolation level and commits when it returns true.</summary>
public void ExecuteTransacted(Func<I<#= databaseClassName #>, bool> func, bool retryIfDeadlockVictim) {
    IsolationLevel level = IsolationLevel.Unspecified;
    ExecuteTransacted(func, retryIfDeadlockVictim, level);
}
/// <summary>
/// Runs <paramref name="func"/> inside a new transaction and commits when it returns true;
/// otherwise the transaction is disposed without a commit.
/// </summary>
/// <param name="func">The work to perform; its return value decides whether to commit.</param>
/// <param name="retryIfDeadlockVictim">When true, the whole operation is repeated each time it fails with SQL error 1205.</param>
/// <param name="isolationLevel">Isolation level of the transaction.</param>
public void ExecuteTransacted(Func<I<#= databaseClassName #>, bool> func, bool retryIfDeadlockVictim, IsolationLevel isolationLevel) {
    while (true) {
        try {
            using (var tran = BeginTransaction(isolationLevel)) {
                bool shouldCommit = func(this);
                if (shouldCommit) {
                    tran.Commit();
                }
            }
            return;
        }
        catch (SqlException ex) {
            // 1205 = chosen as deadlock victim; anything else, or retry disabled, propagates.
            if (!retryIfDeadlockVictim || ex.Number != 1205) {
                throw;
            }
        }
    }
}
/// <summary>Starts a transaction on the proxy's connection and makes it the proxy's current transaction.</summary>
/// <exception cref="InvalidOperationException">A transaction is already open on this proxy.</exception>
public IDbTransaction BeginTransaction(IsolationLevel isolationLevel) { | |
return new TransactionClass(this, isolationLevel); | |
} | |
// The currently open transaction, or null when none is open.
private IDbTransaction Transaction { get { return transaction; } } | |
/// <summary>The connection this proxy operates on.</summary>
public IDbConnection Connection { get { return connection; } } | |
// Explicit interface implementations: expose the SqlCommand-returning overloads as IDbCommand.
IDbCommand I<#= databaseClassName #>.CreateCommand() { | |
return CreateCommand(); | |
} | |
IDbCommand I<#= databaseClassName #>.CreateCommand(CommandType commandType, string commandText, IEnumerable<KeyValuePair<string, object>> parameters) { | |
return CreateCommand(commandType, commandText, parameters); | |
} | |
/// <summary>Creates a command on the proxy's connection, enlisted in the current transaction (if any).</summary>
public SqlCommand CreateCommand() {
    SqlCommand created = connection.CreateCommand();
    created.Transaction = transaction;
    return created;
}

/// <summary>
/// Creates a command with the given type and text and one parameter per entry of
/// <paramref name="parameters"/>. Each key is prefixed with "@" and null values are sent as <c>DBNull</c>.
/// </summary>
/// <param name="commandType">How <paramref name="commandText"/> is interpreted.</param>
/// <param name="commandText">The SQL text or procedure name.</param>
/// <param name="parameters">Name/value pairs (names without the "@"); may be null.</param>
public SqlCommand CreateCommand(CommandType commandType, string commandText, IEnumerable<KeyValuePair<string, object>> parameters) {
    SqlCommand created = CreateCommand();
    created.CommandType = commandType;
    created.CommandText = commandText;
    if (parameters == null) {
        return created;
    }
    foreach (KeyValuePair<string, object> entry in parameters) {
        object boundValue = entry.Value ?? DBNull.Value;
        created.Parameters.AddWithValue("@" + entry.Key, boundValue);
    }
    return created;
}
public int CommandTimeout { get; set; } | |
/// <summary>
/// Releases the connection: it is disposed only when this proxy owns it, and the reference
/// is dropped either way. Further calls do nothing.
/// </summary>
public void Dispose() {
    if (connection == null) {
        return;
    }
    if (ownsConnection) {
        connection.Dispose();
    }
    connection = null;
}
/// <summary>
/// Executes SAVE TRANSACTION with a freshly generated name and returns that name.
/// </summary>
/// <returns>The savepoint name: a GUID formatted with "N" (32 hex digits, no dashes).</returns>
public string SaveTransaction() {
	string savepointName = Guid.NewGuid().ToString("N");
	var arguments = new[] { new KeyValuePair<string, object>("name", savepointName) };
	using (var command = CreateCommand(CommandType.Text, "SAVE TRANSACTION @name", arguments)) {
		command.ExecuteNonQuery();
	}
	return savepointName;
}
/// <summary>
/// Rolls back to the named savepoint; the SQL itself skips the rollback when xact_state() is -1.
/// </summary>
/// <param name="savepointName">Savepoint name, e.g. a value returned by SaveTransaction.</param>
public void RollbackToSavepoint(string savepointName) {
	var arguments = new[] { new KeyValuePair<string, object>("name", savepointName) };
	using (var command = CreateCommand(CommandType.Text, "IF xact_state() <> -1 ROLLBACK TRANSACTION @name", arguments)) {
		command.ExecuteNonQuery();
	}
}
<#
// Emit the proxy members: stored procedures first, then views, then named queries.
PushIndent("\t");
if (generateSqlclr) {
	string[] sqlclrPreamble = {
		"#pragma warning disable 414",
		"private bool preventStaticDelegates;",
		"#pragma warning restore 414",
		""
	};
	foreach (string preambleLine in sqlclrPreamble)
		WriteLine(preambleLine);
}
procedures.ForEach(GenerateStoredprocProxy);
views.ForEach(GenerateViewProxy);
parsedNamedQueries.ForEach(GenerateNamedQueryProxy);
// Two pops: one for the PushIndent above, one for an indent level pushed earlier in the template.
PopIndent();
PopIndent();
#>
} | |
} | |
<#+ | |
/// <summary>
/// Returns the size to use for a parameter of the given type when its declared size is -1
/// (see the Parameter constructor, the only visible caller).
/// </summary>
/// <param name="dbType">The SQL type to look up.</param>
/// <returns>
/// A byte width for the fixed-width integer types and Float, the listed maxima for the
/// character types, 0 for Xml, and int.MaxValue for everything else in the table.
/// </returns>
/// <exception cref="ArgumentException">The type is not in the table (e.g. Structured, Udt).</exception>
static int GetMaxSize(SqlDbType dbType) {
	switch (dbType) {
		case SqlDbType.Bit:
		case SqlDbType.TinyInt:
			return 1;
		case SqlDbType.SmallInt:
			return 2;
		case SqlDbType.Int:
			return 4;
		case SqlDbType.BigInt:
		case SqlDbType.Float:
			// Float is an 8-byte double; the table previously reported 4.
			return 8;
		case SqlDbType.NChar:
			return 4000;
		case SqlDbType.Char:
			return 8000;
		case SqlDbType.NText:
		case SqlDbType.NVarChar:
			return 0x3fffffff;
		case SqlDbType.Xml:
			// 0 makes callers skip setting an explicit size.
			return 0;
		case SqlDbType.Binary:
		case SqlDbType.Date:
		case SqlDbType.DateTime:
		case SqlDbType.DateTime2:
		case SqlDbType.DateTimeOffset:
		case SqlDbType.Decimal:
		case SqlDbType.Image:
		case SqlDbType.Money:
		case SqlDbType.Real:
		case SqlDbType.SmallDateTime:
		case SqlDbType.SmallMoney:
		case SqlDbType.Text:
		case SqlDbType.Time:
		case SqlDbType.Timestamp:
		case SqlDbType.UniqueIdentifier:
		case SqlDbType.VarBinary:
		case SqlDbType.VarChar:
		case SqlDbType.Variant:
			return int.MaxValue;
		//case SqlDbType.Structured:
		//case SqlDbType.Udt:
		default:
			// Second argument is the parameter name; the first is the human-readable message.
			throw new ArgumentException("No maximum size is defined for SqlDbType " + dbType + ".", "dbType");
	}
}
/// <summary>
/// Executes "SET FMTONLY ON" or "SET FMTONLY OFF" on the given connection.
/// </summary>
/// <param name="con">Connection to run the statement on.</param>
/// <param name="tran">Transaction to attach to the command; callers in this file pass null when there is none.</param>
/// <param name="value">true sends ON, false sends OFF.</param>
static void SetFmtonly(SqlConnection con, SqlTransaction tran, bool value) {
	string statement = "SET FMTONLY " + (value ? "ON" : "OFF");
	using (SqlCommand toggle = con.CreateCommand()) {
		toggle.CommandType = CommandType.Text;
		toggle.CommandText = statement;
		toggle.Transaction = tran;
		toggle.ExecuteNonQuery();
	}
}
/// <summary>
/// Reserves a name in <paramref name="usedNames"/>: the desired name itself when it is free,
/// otherwise the desired name followed by the first of 1, 2, 3, ... that yields a free name.
/// </summary>
/// <param name="desiredName">Preferred name.</param>
/// <param name="usedNames">Set of taken names; the returned name is added to it.</param>
/// <returns>The name that was reserved.</returns>
static string FindUniqueName(string desiredName, HashSet<string> usedNames) {
	string candidate = desiredName;
	int suffix = 1;
	// HashSet.Add returns false when the value is already present, so a failed Add means "taken".
	while (!usedNames.Add(candidate)) {
		candidate = desiredName + suffix.ToString(System.Globalization.CultureInfo.InvariantCulture);
		suffix++;
	}
	return candidate;
}
/// <summary>
/// Immutable (schema, name) pair identifying a database object. Compares by value with
/// case-sensitive string equality, so it can be used as a dictionary key.
/// </summary>
sealed class SchemaAndName {
	public string Schema { get; private set; }
	public string Name { get; private set; }

	public SchemaAndName(string schema, string name) {
		this.Schema = schema;
		this.Name = name;
	}

	/// <summary>True when <paramref name="obj"/> is a SchemaAndName with equal Schema and Name.</summary>
	public override bool Equals(object obj) {
		var other = obj as SchemaAndName;
		return other != null && this.Schema == other.Schema && this.Name == other.Name;
	}

	/// <summary>
	/// Hash consistent with Equals. Null parts hash as 0 (Equals accepts them, so this must not
	/// throw), and the combine is order-sensitive so (a, b) and (b, a) do not collide by construction.
	/// </summary>
	public override int GetHashCode() {
		unchecked {
			int hash = 17;
			hash = hash * 31 + (Schema != null ? Schema.GetHashCode() : 0);
			hash = hash * 31 + (Name != null ? Name.GetHashCode() : 0);
			return hash;
		}
	}

	/// <summary>Returns the pair as "[Schema].[Name]" (brackets inside the parts are not escaped).</summary>
	public override string ToString() {
		return "[" + Schema + "].[" + Name + "]";
	}
}
/// <summary>
/// Describes one SQL parameter together with the C# names and type name derived from it.
/// </summary>
class Parameter {
	// Shared descriptor for the integer return value of a stored procedure.
	private static Parameter returnValue = new Parameter("@ReturnValue", SqlDbType.Int, 4, ParameterDirection.ReturnValue);
	public static Parameter ReturnValue {
		get { return returnValue; }
	}

	public string DbName { get; private set; }
	public SqlDbType DbType { get; private set; }
	public int DbSize { get; private set; }
	public ParameterDirection Direction { get; private set; }
	public string CSharpNameU { get; private set; }
	public string CSharpNameL { get; private set; }
	public string CSharpTypeName { get; private set; }

	/// <param name="dbName">Parameter name; its first character is dropped when deriving the C# names.</param>
	/// <param name="dbType">SQL type of the parameter.</param>
	/// <param name="dbSize">Declared size; -1 is replaced by GetMaxSize(dbType).</param>
	/// <param name="direction">Parameter direction.</param>
	/// <exception cref="ArgumentException">dbType has no entry in dbTypeMap.</exception>
	public Parameter(string dbName, SqlDbType dbType, int dbSize, ParameterDirection direction) {
		DbName = dbName;
		DbType = dbType;
		DbSize = (dbSize == -1) ? GetMaxSize(dbType) : dbSize;
		Direction = direction;

		string bareName = dbName.Substring(1);
		CSharpNameU = MakeCSharpName(null, bareName, true);
		CSharpNameL = MakeCSharpName(null, bareName, false);

		Type clrType;
		bool known = dbTypeMap.TryGetValue(dbType, out clrType);
		if (!known)
			throw new ArgumentException(string.Format("Unknown SqlDbType {0} used by field {1}.", dbType, dbName));

		// Value types become nullable in generated code, except for the return value.
		string suffix = "";
		if (clrType.IsValueType && Direction != ParameterDirection.ReturnValue)
			suffix = "?";
		CSharpTypeName = clrType.ToString() + suffix;
	}
}
/// <summary>
/// One column of a result set: its database name, nullability, and the C# name and type
/// used for it in generated row classes.
/// </summary>
class ResultSetColumn {
	public string DbName { get; private set; }
	public bool Nullable { get; private set; }
	public string CSharpName { get; private set; }
	public Type CSharpType { get; private set; }
	public string CSharpTypeName { get; private set; }

	/// <param name="dbName">Column name as reported by the database.</param>
	/// <param name="csharpType">CLR type of the column.</param>
	/// <param name="nullable">Whether the column may be null; adds "?" to value types.</param>
	public ResultSetColumn(string dbName, Type csharpType, bool nullable) {
		DbName = dbName;
		Nullable = nullable;
		CSharpName = MakeCSharpName(null, dbName, true);
		CSharpType = csharpType;

		string typeName = csharpType.ToString();
		if (csharpType.IsValueType && nullable)
			typeName += "?";
		CSharpTypeName = typeName;
	}
}
/// <summary>
/// A result set shape: the name of the generated row class plus its ordered columns.
/// </summary>
class ResultSet {
	public string RowTypeName { get; private set; }
	public ReadOnlyCollection<ResultSetColumn> Columns { get; private set; }

	/// <summary>
	/// Builds the column list from a reader schema table (uses its ColumnName, DataType and
	/// AllowDbNull columns) and reserves a unique row type name.
	/// </summary>
	/// <param name="rowTypeName">Preferred row class name.</param>
	/// <param name="schemaTable">Schema table of the result set.</param>
	/// <param name="usedNames">Names already taken; the chosen name is added.</param>
	/// <param name="everythingNullable">When true every column is nullable and AllowDbNull is not read.</param>
	public ResultSet(string rowTypeName, DataTable schemaTable, HashSet<string> usedNames, bool everythingNullable) {
		RowTypeName = FindUniqueName(rowTypeName, usedNames);

		var collected = new List<ResultSetColumn>();
		int nextUnnamed = 1;
		foreach (DataRow schemaRow in schemaTable.Rows) {
			string name = (string)schemaRow["ColumnName"];
			if (string.IsNullOrEmpty(name)) {
				// Columns without a name get Unnamed1, Unnamed2, ...
				name = "Unnamed" + nextUnnamed;
				nextUnnamed++;
			}
			Type clrType = (Type)schemaRow["DataType"];
			bool nullable = everythingNullable || (bool)schemaRow["AllowDbNull"];
			collected.Add(new ResultSetColumn(name, clrType, nullable));
		}
		Columns = collected.AsReadOnly();
	}

	/// <summary>
	/// Source text of a lambda that builds a row object from a data reader, reading the
	/// columns by ordinal through SqlFunctions.DbNullCast.
	/// </summary>
	public string CreatorLambda {
		get {
			string[] arguments = Columns
				.Select((column, ordinal) => "SqlFunctions.DbNullCast<" + column.CSharpTypeName + ">(reader[" + ordinal + "])")
				.ToArray();
			string construction = "new " + RowTypeName + "(" + string.Join(", ", arguments) + ")";
			if (generateSqlclr)
				return "reader => " + "{ preventStaticDelegates = true; return " + construction + "; }";
			return "reader => " + construction;
		}
	}

	/// <summary>
	/// Wraps an existing column list under a newly reserved row type name.
	/// </summary>
	public ResultSet(string rowTypeName, ReadOnlyCollection<ResultSetColumn> columns, HashSet<string> usedNames) {
		RowTypeName = FindUniqueName(rowTypeName, usedNames);
		Columns = columns;
	}
}
/// <summary>
/// Metadata for one database view: its name, the C# name derived from it, and the shape of
/// its rows as discovered by running "SELECT *" against it with FMTONLY switched on.
/// </summary>
class View {
	public SchemaAndName DbName { get; private set; }
	public string CSharpName { get; private set; }
	public string RowTypeName { get; private set; }
	public ResultSet ResultSet { get; private set; }

	/// <param name="connection">Open connection used for the schema probe.</param>
	/// <param name="dbName">Schema-qualified view name.</param>
	/// <param name="usedNames">Type names already taken; the row type name is added.</param>
	public View(SqlConnection connection, SchemaAndName dbName, HashSet<string> usedNames) {
		this.DbName = dbName;
		this.CSharpName = MakeCSharpName(dbName.Schema, dbName.Name, true);
		SetFmtonly(connection, null, true);
		try {
			using (SqlCommand cmd = connection.CreateCommand()) {
				cmd.CommandType = CommandType.Text;
				cmd.CommandText = "SELECT * FROM [" + dbName.Schema + "].[" + dbName.Name + "]";
				using (SqlDataReader rdr = cmd.ExecuteReader()) {
					using (DataTable schemaTable = rdr.GetSchemaTable()) {
						ResultSet = new ResultSet(this.CSharpName + "Row", schemaTable, usedNames, false);
						this.RowTypeName = ResultSet.RowTypeName;
					}
				}
			}
		}
		finally {
			// Always switch FMTONLY back off: the connection is shared, and leaving it on after a
			// failed probe would affect every later command on it.
			SetFmtonly(connection, null, false);
		}
	}
}
/// <summary>
/// Metadata for one named ad-hoc query: its text, its parameters, and the shape of its rows
/// as discovered by executing it with FMTONLY switched on and every parameter bound to DBNull.
/// </summary>
class NamedQuery {
	public string Name { get; private set; }
	public string QueryText { get; private set; }
	public string RowTypeName { get; private set; }
	public ReadOnlyCollection<Parameter> Parameters { get; private set; }
	public ResultSet ResultSet { get; private set; }

	/// <param name="connection">Open connection used for the schema probe.</param>
	/// <param name="name">Name of the query; used as the prefix of the row type name.</param>
	/// <param name="queryText">SQL text to execute.</param>
	/// <param name="parameters">Parameters referenced by the query; copied into a read-only list.</param>
	/// <param name="usedNames">Type names already taken; the row type name is added.</param>
	public NamedQuery(SqlConnection connection, string name, string queryText, IList<Parameter> parameters, HashSet<string> usedNames) {
		this.Name = name;
		this.QueryText = queryText;
		this.Parameters = new List<Parameter>(parameters).AsReadOnly();
		SetFmtonly(connection, null, true);
		try {
			using (SqlCommand cmd = connection.CreateCommand()) {
				cmd.CommandType = CommandType.Text;
				cmd.CommandText = queryText;
				foreach (var p in parameters)
					cmd.Parameters.AddWithValue(p.DbName, DBNull.Value);
				using (SqlDataReader rdr = cmd.ExecuteReader()) {
					using (DataTable schemaTable = rdr.GetSchemaTable()) {
						// Every column is treated as nullable for named queries.
						ResultSet = new ResultSet(this.Name + "Row", schemaTable, usedNames, true);
						this.RowTypeName = ResultSet.RowTypeName;
					}
				}
			}
		}
		finally {
			// Always switch FMTONLY back off: the connection is shared, and leaving it on after a
			// failed probe (e.g. invalid query text) would affect every later command on it.
			SetFmtonly(connection, null, false);
		}
	}
}
/// <summary>
/// Metadata for one stored procedure: its parameters (read from sys.parameters) and the shape
/// of each result set it returns (discovered by executing it inside a transaction that is
/// rolled back afterwards).
/// </summary>
class Procedure {
	public SchemaAndName DbName { get; private set; }
	public string CSharpName { get; private set; }
	public string ResultTypeName { get; private set; }
	public ReadOnlyCollection<Parameter> Parameters { get; private set; }
	public ReadOnlyCollection<ResultSet> ResultSets { get; private set; }

	/// <summary>
	/// Any exception raised while probing is shown in a message box and swallowed; in that
	/// case Parameters and ResultSets are left unassigned (null).
	/// </summary>
	/// <param name="connection">Open connection used for the probe.</param>
	/// <param name="dbName">Schema-qualified procedure name.</param>
	/// <param name="usedNames">Type names already taken; generated type names are added.</param>
	public Procedure(SqlConnection connection, SchemaAndName dbName, HashSet<string> usedNames) {
		DbName = dbName;
		CSharpName = MakeCSharpName(dbName.Schema, dbName.Name, true);
		ResultTypeName = FindUniqueName(CSharpName + "Result", usedNames);
		try {
			using (SqlCommand cmd = connection.CreateCommand()) {
				// Explicit argument values configured for this procedure, if any.
				object[] customParams = null;
				if (customExtractions != null)
					customExtractions.TryGetValue(dbName, out customParams);

				string quotedName = "[" + dbName.Schema + "].[" + dbName.Name + "]";
				List<Parameter> parameters = ReadParameters(cmd, quotedName);
				PrepareProbeCall(cmd, quotedName, parameters, customParams);

				// FMTONLY is only used when no custom arguments were supplied.
				bool useFmtonly = (customParams == null);
				using (SqlTransaction tran = connection.BeginTransaction()) {
					cmd.Transaction = tran;
					if (useFmtonly)
						SetFmtonly(connection, tran, true);
					List<ResultSet> resultSets = new List<ResultSet>();
					try {
						CollectResultSets(cmd, resultSets, usedNames);
					}
					finally {
						if (useFmtonly)
							SetFmtonly(connection, tran, false);
					}
					Parameters = parameters.AsReadOnly();
					ResultSets = resultSets.AsReadOnly();
					tran.Rollback();
				}
			}
		}
		catch (Exception e) {
			System.Windows.Forms.MessageBox.Show(dbName + ": " + e.Message);
		}
	}

	// Queries sys.parameters for the procedure and turns each row into a Parameter.
	private static List<Parameter> ReadParameters(SqlCommand cmd, string quotedName) {
		List<Parameter> found = new List<Parameter>();
		cmd.CommandText = "SELECT p.name, st.name, p.max_length, p.is_output FROM sys.parameters p JOIN sys.types t ON t.user_type_id = p.user_type_id JOIN sys.types st ON st.user_type_id = t.system_type_id WHERE p.object_id = object_id(@name) ORDER BY p.parameter_id";
		cmd.CommandType = CommandType.Text;
		cmd.Parameters.AddWithValue("@name", quotedName);
		using (var rdr = cmd.ExecuteReader()) {
			while (rdr.Read()) {
				SqlDbType dbType;
				bool known = dbTypeNameMap.TryGetValue((string)rdr[1], out dbType);
				if (!known) {
					throw new Exception("Unknown database type " + (string)rdr[1]);
				}
				// is_output parameters are treated as InputOutput.
				found.Add(new Parameter((string)rdr[0], dbType, (short)rdr[2], (bool)rdr[3] ? ParameterDirection.InputOutput : ParameterDirection.Input));
			}
		}
		return found;
	}

	// Re-targets the command at the procedure itself and binds one argument per parameter:
	// the custom value when one is configured, otherwise a null of the appropriate kind.
	private static void PrepareProbeCall(SqlCommand cmd, string quotedName, List<Parameter> parameters, object[] customParams) {
		cmd.Parameters.Clear();
		cmd.CommandText = quotedName;
		cmd.CommandType = CommandType.StoredProcedure;
		int position = 0;
		foreach (Parameter parameter in parameters) {
			object argument = null;
			if (customParams != null && position < customParams.Length)
				argument = customParams[position];
			if (argument == null) {
				if (parameter.DbType == SqlDbType.Xml)
					argument = SqlXml.Null;
				else
					argument = DBNull.Value;
			}
			var bound = cmd.Parameters.Add(parameter.DbName, parameter.DbType);
			if (parameter.DbSize != 0)
				bound.Size = parameter.DbSize;
			bound.Value = argument;
			bound.Direction = parameter.Direction;
			position++;
		}
	}

	// Executes the command and records one ResultSet per result that exposes a schema table.
	// Row types are named <CSharpName>ResultSet<N>Row; a lone result set is renamed to
	// <CSharpName>ResultSetRow.
	private void CollectResultSets(SqlCommand cmd, List<ResultSet> resultSets, HashSet<string> usedNames) {
		using (SqlDataReader rdr = cmd.ExecuteReader()) {
			if (rdr == null)
				return;
			do {
				using (DataTable schemaTable = rdr.GetSchemaTable()) {
					if (schemaTable != null) {
						int ordinal = resultSets.Count + 1;
						resultSets.Add(new ResultSet(CSharpName + "ResultSet" + ordinal + "Row", schemaTable, usedNames, false));
					}
				}
			} while (rdr.NextResult());

			if (resultSets.Count == 1) {
				// Release the numbered name before reserving the unnumbered one.
				ResultSet only = resultSets[0];
				usedNames.Remove(only.RowTypeName);
				resultSets[0] = new ResultSet(CSharpName + "ResultSetRow", only.Columns, usedNames);
			}
		}
	}
}
// State of the template instance; the lists are iterated when the proxy members are emitted.
private SqlConnection connection;
private List<Procedure> procedures = new List<Procedure>();
private List<View> views = new List<View>();
private List<NamedQuery> parsedNamedQueries = new List<NamedQuery>();
// Type names already handed out by FindUniqueName.
private HashSet<string> usedNames = new HashSet<string>();
// Maps SQL Server system type names (case-insensitive) to SqlDbType.
// The date/time/datetime2/datetimeoffset entries were missing, although the other type
// tables in this file already handle those SqlDbType values.
// NOTE(review): "numeric" maps to Float and "money" maps to Decimal. Both look unintended
// (numeric is an exact type; SqlDbType.Money exists) but they are left unchanged here because
// correcting them changes the signatures of already generated proxies - confirm before fixing.
static Dictionary<string, SqlDbType> dbTypeNameMap = new Dictionary<string, SqlDbType>(StringComparer.InvariantCultureIgnoreCase) {
	{ "bigint", SqlDbType.BigInt },
	{ "binary", SqlDbType.Binary },
	{ "bit", SqlDbType.Bit },
	{ "char", SqlDbType.Char },
	{ "date", SqlDbType.Date },
	{ "datetime", SqlDbType.DateTime },
	{ "datetime2", SqlDbType.DateTime2 },
	{ "datetimeoffset", SqlDbType.DateTimeOffset },
	{ "decimal", SqlDbType.Decimal },
	{ "float", SqlDbType.Float },
	{ "image", SqlDbType.Image },
	{ "int", SqlDbType.Int },
	{ "money", SqlDbType.Decimal },
	{ "nchar", SqlDbType.NChar },
	{ "ntext", SqlDbType.NText },
	{ "numeric", SqlDbType.Float },
	{ "nvarchar", SqlDbType.NVarChar },
	{ "real", SqlDbType.Real },
	{ "smalldatetime", SqlDbType.SmallDateTime },
	{ "smallint", SqlDbType.SmallInt },
	{ "smallmoney", SqlDbType.SmallMoney },
	{ "sql_variant", SqlDbType.Variant },
	{ "sysname", SqlDbType.NVarChar },
	{ "text", SqlDbType.Text },
	{ "time", SqlDbType.Time },
	{ "timestamp", SqlDbType.Timestamp },
	{ "tinyint", SqlDbType.TinyInt },
	{ "uniqueidentifier", SqlDbType.UniqueIdentifier },
	{ "varbinary", SqlDbType.VarBinary },
	{ "varchar", SqlDbType.VarChar },
	{ "xml", SqlDbType.Xml },
};

// Maps SqlDbType to the CLR type used for parameters in generated code.
// DateTimeOffset and Time use the CLR types SqlClient exchanges for them (DateTimeOffset and
// TimeSpan); they were previously mapped to TimeSpan and DateTime respectively.
static Dictionary<SqlDbType, Type> dbTypeMap = new Dictionary<SqlDbType, Type>() {
	{ SqlDbType.BigInt, typeof(long) },
	{ SqlDbType.Binary, typeof(byte[]) },
	{ SqlDbType.Bit, typeof(bool) },
	{ SqlDbType.Char, typeof(string) },
	{ SqlDbType.Date, typeof(DateTime) },
	{ SqlDbType.DateTime, typeof(DateTime) },
	{ SqlDbType.DateTime2, typeof(DateTime) },
	{ SqlDbType.DateTimeOffset, typeof(DateTimeOffset) },
	{ SqlDbType.Decimal, typeof(decimal) },
	{ SqlDbType.Float, typeof(double) },
	{ SqlDbType.Image, typeof(byte[]) },
	{ SqlDbType.Int, typeof(int) },
	{ SqlDbType.Money, typeof(decimal) },
	{ SqlDbType.NChar, typeof(string) },
	{ SqlDbType.NText, typeof(string) },
	{ SqlDbType.NVarChar, typeof(string) },
	{ SqlDbType.Real, typeof(float) },
	{ SqlDbType.SmallDateTime, typeof(DateTime) },
	{ SqlDbType.SmallInt, typeof(short) },
	{ SqlDbType.SmallMoney, typeof(decimal) },
	{ SqlDbType.Text, typeof(string) },
	{ SqlDbType.Time, typeof(TimeSpan) },
	{ SqlDbType.Timestamp, typeof(byte[]) },
	{ SqlDbType.TinyInt, typeof(byte) },
	{ SqlDbType.UniqueIdentifier, typeof(System.Guid) },
	{ SqlDbType.VarBinary, typeof(byte[]) },
	{ SqlDbType.VarChar, typeof(string) },
	{ SqlDbType.Variant, typeof(object) },
	{ SqlDbType.Xml, typeof(SqlXml) },
	// { SqlDbType.Structured, null },
	// { SqlDbType.Udt, null },
};
/// <summary>
/// Derives a C# identifier from a SQL name. ASCII letters are kept, ASCII digits are kept
/// unless they would be the first character, and everything else is dropped. The letter that
/// follows a digit, a separator character, '_' or '.' is upper-cased.
/// </summary>
/// <param name="sqlSchema">Schema name; prefixed to the name (as "schema.name") unless null or equal to defaultSchemaName.</param>
/// <param name="sqlName">Object, column or parameter name.</param>
/// <param name="firstUppercase">true to upper-case the first letter, false to lower-case it.</param>
/// <returns>The identifier; empty when the input contains no ASCII letters.</returns>
static string MakeCSharpName(string sqlSchema, string sqlName, bool firstUppercase) {
	if (sqlSchema != null && !sqlSchema.Equals(defaultSchemaName, StringComparison.InvariantCultureIgnoreCase))
		sqlName = sqlSchema + "." + sqlName;
	StringBuilder sb = new StringBuilder();
	bool uppercase = firstUppercase, lowercase = !firstUppercase;
	foreach (char ch in sqlName) {
		if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) {
			// Invariant casing: culture-sensitive ToLower/ToUpper turn 'I'/'i' into the
			// non-ASCII 'ı'/'İ' when the template runs under a Turkish culture.
			sb.Append(lowercase ? char.ToLowerInvariant(ch) : (uppercase ? char.ToUpperInvariant(ch) : ch));
			uppercase = false;
			lowercase = false;
		}
		else if (ch >= '0' && ch <= '9') {
			// An identifier cannot start with a digit, so leading digits are skipped.
			if (sb.Length > 0)
				sb.Append(ch);
			uppercase = true;
		}
		else if (char.IsSeparator(ch) || ch == '_' || ch == '.') {
			uppercase = true;
		}
	}
	// NOTE(review): with firstUppercase == false the result can be a C# keyword
	// (e.g. "@Object" yields "object"); the visible code does not escape such names.
	return sb.ToString();
}
/// <summary>
/// Emits a sealed row class for a result set: one get/private-set property per column and a
/// constructor that takes every column value in order.
/// </summary>
/// <param name="rs">Result set whose row class is emitted.</param>
/// <param name="includeColumnNameAttributes">When true each property gets a [ColumnName("...")] attribute carrying the database column name.</param>
void GenerateResultSetRowType(ResultSet rs, bool includeColumnNameAttributes) {
	string rowType = rs.RowTypeName;

	WriteLine(accessibility + " sealed class " + rowType + " {");
	PushIndent("\t");
	foreach (ResultSetColumn column in rs.Columns) {
		if (includeColumnNameAttributes)
			WriteLine("[ColumnName(\"" + column.DbName + "\")]");
		WriteLine("public " + column.CSharpTypeName + " " + column.CSharpName + " { get; private set; }");
	}
	WriteLine("");

	// Constructor: parameter names are the property names with a lower-case first letter.
	string[] ctorParameters = rs.Columns
		.Select(column => column.CSharpTypeName + " " + MakeCSharpName(null, column.CSharpName, false))
		.ToArray();
	WriteLine("public " + rowType + "(" + string.Join(", ", ctorParameters) + ") {");
	PushIndent("\t");
	foreach (ResultSetColumn column in rs.Columns) {
		string argumentName = MakeCSharpName(null, column.CSharpName, false);
		WriteLine("this." + column.CSharpName + " = " + argumentName + ";");
	}
	PopIndent();
	WriteLine("}");

	PopIndent();
	WriteLine("}");
	WriteLine("");
}
/// <summary>
/// Emits the result class of a stored procedure: a ResultSetEnumerator subclass with one
/// ResultSet property per result set and one property per output parameter. With exactly one
/// result set the class also implements IEnumerable of that row type.
/// </summary>
/// <param name="resultSets">Result sets of the procedure, in order.</param>
/// <param name="outParams">Parameters to expose as read-only properties.</param>
/// <param name="typeName">Name of the class to emit.</param>
void GenerateEnumeratorType(IList<ResultSet> resultSets, IList<Parameter> outParams, string typeName) {
	bool single = (resultSets.Count == 1);

	// Writes "<header>", one indented body line, then the closing brace.
	Action<string, string> writeMember = (header, bodyLine) => {
		WriteLine(header);
		PushIndent("\t");
		WriteLine(bodyLine);
		PopIndent();
		WriteLine("}");
	};

	string declaration = accessibility + " sealed class " + typeName + " : ResultSetEnumerator";
	if (single)
		declaration += ", IEnumerable<" + resultSets[0].RowTypeName + ">";
	WriteLine(declaration + " {");
	PushIndent("\t");

	if (generateSqlclr) {
		WriteLine("#pragma warning disable 414");
		WriteLine("private bool preventStaticDelegates;");
		WriteLine("#pragma warning restore 414");
		WriteLine("");
	}

	WriteLine("public " + typeName + "(SqlCommand command) : base(command) {");
	WriteLine("}");

	if (single) {
		WriteLine("");
		writeMember("public IEnumerator<" + resultSets[0].RowTypeName + "> GetEnumerator() {", "return ResultSet.GetEnumerator();");
		WriteLine("");
		writeMember("IEnumerator IEnumerable.GetEnumerator() {", "return ResultSet.GetEnumerator();");
	}

	for (int index = 0; index < resultSets.Count; index++) {
		ResultSet current = resultSets[index];
		// A lone result set is exposed as "ResultSet"; several are numbered from 1.
		string propertySuffix = single ? "" : Convert.ToString(index + 1);
		WriteLine("");
		writeMember(
			"public IEnumerable<" + current.RowTypeName + "> ResultSet" + propertySuffix + " {",
			"get { return GetResultSetEnumerator(" + index.ToString() + ", " + current.CreatorLambda + ", " + (single ? "true" : "false") + "); }");
	}

	foreach (var param in outParams) {
		WriteLine("");
		writeMember(
			"public " + param.CSharpTypeName + " " + param.CSharpNameU + " {",
			"get { return GetOutputParameter<" + param.CSharpTypeName + ">(\"" + param.DbName + "\"); }");
	}

	PopIndent();
	WriteLine("}");
	WriteLine("");
}
/// <summary>
/// Writes the signature (return type, name, parameter list - no accessibility, no body) of
/// the proxy method for a stored procedure.
/// </summary>
/// <remarks>
/// Without result sets the method returns the procedure's return value type and exposes
/// Output/InputOutput parameters as out/ref. With result sets it returns the result type,
/// omits Output parameters and passes InputOutput parameters by value.
/// </remarks>
/// <param name="proc">Procedure whose signature is written.</param>
void WriteStoredprocSignature(Procedure proc) {
	string returnTypeName = proc.ResultSets.Count == 0 ? Parameter.ReturnValue.CSharpTypeName : proc.ResultTypeName;
	bool onlyInputParams = (proc.ResultSets.Count > 0);
	Write(returnTypeName + " " + proc.CSharpName + "(");
	bool first = true;
	foreach (Parameter param in proc.Parameters) {
		string modifier = "";
		switch (param.Direction) {
			case ParameterDirection.Input:
				break;
			case ParameterDirection.Output:
				// Decide whether to skip BEFORE any separator is written; writing ", " first
				// left a stray comma in the parameter list for every skipped parameter.
				if (onlyInputParams)
					continue;
				modifier = "out ";
				break;
			case ParameterDirection.InputOutput:
				if (!onlyInputParams)
					modifier = "ref ";
				break;
		}
		if (!first)
			Write(", ");
		Write(modifier);
		Write(param.CSharpTypeName + " " + param.CSharpNameL);
		first = false;
	}
	Write(")");
}
/// <summary>
/// Writes the signature of the query method for a view: Query&lt;ViewName&gt;, returning an
/// IEnumerable of the view's row type. Unless generating for SQLCLR, the method takes
/// "filter" and "order" expression parameters.
/// </summary>
void WriteViewSignature(View view) {
	string signature = "IEnumerable<" + view.RowTypeName + "> Query" + view.CSharpName + "(";
	if (!generateSqlclr) {
		string rowType = view.ResultSet.RowTypeName;
		string filterParameter = "System.Linq.Expressions.Expression<Func<" + rowType + ", bool>> filter";
		string orderParameter = "System.Linq.Expressions.Expression<Func<" + rowType + ", object>> order";
		signature += filterParameter + ", " + orderParameter;
	}
	Write(signature + ")");
}
/// <summary>
/// Writes the signature of the method for a named query: the query name, returning an
/// IEnumerable of its row type, with one argument per query parameter.
/// </summary>
void WriteNamedQuerySignature(NamedQuery query) {
	string[] arguments = query.Parameters
		.Select(parameter => parameter.CSharpTypeName + " " + parameter.CSharpNameL)
		.ToArray();
	string returnType = "IEnumerable<" + query.RowTypeName + ">";
	Write(returnType + " " + query.Name + "(" + string.Join(", ", arguments) + ")");
}
/// <summary>
/// Emits the public proxy method for a stored procedure.
/// </summary>
/// <remarks>
/// Without result sets the emitted method executes the command inside a using block, copies
/// non-input parameters back to its ref/out arguments and returns the procedure's return
/// value. With result sets it hands the prepared (undisposed) command to a new instance of
/// the procedure's result type.
/// </remarks>
/// <param name="proc">Procedure to emit the proxy for.</param>
void GenerateStoredprocProxy(Procedure proc) {
	bool returnsRows = (proc.ResultSets.Count > 0);
	Write("public ");
	WriteStoredprocSignature(proc);
	WriteLine(" {");
	PushIndent("\t");
	if (!returnsRows) {
		WriteLine("using (SqlCommand cmd = CreateCommand()) {");
		PushIndent("\t");
	}
	else {
		// The command's lifetime passes to the result object, so no using block here.
		WriteLine("SqlCommand cmd = CreateCommand();");
	}
	WriteLine("cmd.CommandText = \"[" + proc.DbName.Schema + "].[" + proc.DbName.Name + "]\";");
	WriteLine("cmd.CommandType = CommandType.StoredProcedure;");
	WriteLine("cmd.CommandTimeout = CommandTimeout;");
	WriteLine("SqlParameter p;");
	foreach (Parameter param in proc.Parameters) {
		Write("p = cmd.Parameters.Add(\"" + param.DbName + "\", SqlDbType." + param.DbType.ToString());
		if (param.DbSize != 0)
			Write(", " + param.DbSize);
		WriteLine(");");
		// Xml parameters take SqlXml.Null instead of DBNull.Value for a null argument.
		WriteLine("p.Value = (object)" + param.CSharpNameL + " ?? " + (param.DbType == SqlDbType.Xml ? "System.Data.SqlTypes.SqlXml.Null" : "DBNull.Value") + ";");
		if (param.Direction != ParameterDirection.Input)
			WriteLine("p.Direction = ParameterDirection." + param.Direction + ";");
	}
	WriteLine("p = cmd.Parameters.Add(\"" + Parameter.ReturnValue.DbName + "\", SqlDbType." + Parameter.ReturnValue.DbType.ToString() + ");");
	WriteLine("p.Direction = ParameterDirection.ReturnValue;");
	if (!returnsRows) {
		WriteLine("cmd.ExecuteNonQuery();");
		foreach (Parameter param in proc.Parameters.Where(p => p.Direction != ParameterDirection.Input)) {
			WriteLine(param.CSharpNameL + " = SqlFunctions.DbNullCast<" + param.CSharpTypeName + ">(cmd.Parameters[\"" + param.DbName + "\"].Value);");
		}
		WriteLine("return SqlFunctions.DbNullCast<" + Parameter.ReturnValue.CSharpTypeName + ">(cmd.Parameters[\"" + Parameter.ReturnValue.DbName + "\"].Value);");
		PopIndent();
		WriteLine("}");
	}
	else {
		// Must be the same name the signature declares as return type. ResultTypeName may carry
		// a uniqueness suffix, so rebuilding it as CSharpName + "Result" can name another type.
		WriteLine("return new " + proc.ResultTypeName + "(cmd);");
	}
	PopIndent();
	WriteLine("}");
	WriteLine("");
}
/// <summary>
/// Emits the public query method for a view. The emitted body is a single call to
/// ResultSetEnumerator.ExecuteQuery with an explicit column list, the row creator lambda,
/// the filter/order arguments (omitted for SQLCLR) and the command timeout.
/// </summary>
void GenerateViewProxy(View view) {
	string[] bracketedColumns = view.ResultSet.Columns.Select(c => "[" + c.DbName + "]").ToArray();
	string selectStatement = "SELECT " + string.Join(", ", bracketedColumns) + " FROM [" + view.DbName.Schema + "].[" + view.DbName.Name + "]";
	string filterArguments = generateSqlclr ? "" : ", filter, order";

	Write("public ");
	WriteViewSignature(view);
	WriteLine(" {");
	PushIndent("\t");
	WriteLine("return ResultSetEnumerator.ExecuteQuery(connection, transaction, \"" + selectStatement + "\", " + view.ResultSet.CreatorLambda + filterArguments + ", CommandTimeout);");
	PopIndent();
	WriteLine("}");
	WriteLine("");
}
/// <summary>
/// Emits the public method for a named query. The emitted body builds a SqlParameter list
/// (only when the query has parameters) and returns ResultSetEnumerator.ExecuteQuery for the
/// query text, which is embedded as a verbatim string literal.
/// </summary>
void GenerateNamedQueryProxy(NamedQuery query) {
	bool hasParameters = (query.Parameters.Count > 0);

	Write("public ");
	WriteNamedQuerySignature(query);
	WriteLine(" {");
	PushIndent("\t");

	if (hasParameters) {
		WriteLine("List<SqlParameter> parms = new List<SqlParameter>();");
		WriteLine("SqlParameter p;");
		foreach (Parameter param in query.Parameters) {
			string construction = "p = new SqlParameter(\"" + param.DbName + "\", SqlDbType." + param.DbType.ToString();
			if (param.DbSize != 0)
				construction += ", " + param.DbSize;
			WriteLine(construction + ");");
			WriteLine("p.Value = (object)" + param.CSharpNameL + " ?? DBNull.Value;");
			WriteLine("parms.Add(p);");
		}
	}

	// Inside a verbatim literal a double quote is escaped by doubling it.
	string verbatimQuery = "@\"" + query.QueryText.Replace("\"", "\"\"") + "\"";
	string parameterArgument = hasParameters ? "parms" : "null";
	WriteLine("return ResultSetEnumerator.ExecuteQuery(connection, transaction, " + verbatimQuery + ", " + parameterArgument + ", " + query.ResultSet.CreatorLambda + ", CommandTimeout);");

	PopIndent();
	WriteLine("}");
	WriteLine("");
}
#> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment