Skip to content

Instantly share code, notes, and snippets.

@erik-kallen
Created August 30, 2012 12:41
Show Gist options
  • Save erik-kallen/3527712 to your computer and use it in GitHub Desktop.
SqlProxy
<#@ include file="SqlProxy.ttinc" #>
<#+
// Code-generation settings. These statics are presumably consumed by the included
// SqlProxy.ttinc (not visible here); confirm their exact semantics there.
static string connectionString = "Server=.; Integrated Security=true; Database=AdventureWorks";
static string namespaceName = "DAL";
static string databaseClassName = "DatabaseProxy";
static bool includeProcedures = true;
static bool includeViews = true;
static bool includeTables = false;
static int commandTimeout = 600;
static bool generateSqlclr = false;
static SchemaAndName[] procsToInclude = null;
static SchemaAndName[] procsToExclude = new SchemaAndName[0];
static string accessibility = "public";
static string[] schemaNames = new[] { "dbo", "HumanResources" };
static string defaultSchemaName = "dbo";
static SchemaAndName[] viewsToInclude = null;
static SchemaAndName[] viewsToExclude = null;
static Dictionary<SchemaAndName, object[]> customExtractions
= new Dictionary<SchemaAndName, object[]>() {};
// Named ad-hoc queries (name -> SQL text). The "@{name: Type}" placeholders appear to declare
// typed parameters; the syntax is defined by SqlProxy.ttinc (confirm there).
// Declared exactly once: the previous line repeated "static Dictionary<...> namedQueries ="
// inside its own initializer, which does not compile.
static Dictionary<string, string> namedQueries = new Dictionary<string, string> {
    { "EmployeesByBirthDate", "SELECT e.EmployeeID, e.HireDate, c.FirstName, c.LastName, c.Phone FROM HumanResources.Employee e JOIN Person.Contact c ON c.ContactID = e.ContactID WHERE e.BirthDate BETWEEN @{minDate: DateTime} AND @{maxDate: DateTime}" }
};
#>
using System;
using System.Collections.Generic;
using System.Text;
using System.Collections.ObjectModel;
using System.Data.SqlClient;
using System.Collections;
using System.Data.SqlTypes;
using System.Xml;
using System.IO;
#if !SQLPROXY_SQLCLR
using System.Linq;
using System.Linq.Expressions;
#endif
namespace SqlProxy {
#if SQLPROXY_SQLCLR
// Under SQLPROXY_SQLCLR the 0- to 4-argument Func shapes are declared locally, presumably
// because the SQLCLR target framework does not ship System.Func (TODO confirm target framework).
public delegate TResult Func<TResult>();
public delegate TResult Func<T, TResult>(T arg1);
public delegate TResult Func<T1, T2, TResult>(T1 arg1, T2 arg2);
public delegate TResult Func<T1, T2, T3, TResult>(T1 arg1, T2 arg2, T3 arg3);
public delegate TResult Func<T1, T2, T3, T4, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4);
#endif
// Five-argument Func, declared in every configuration. Presumably needed because System.Func
// in .NET 3.5 stops at four arguments (confirm against the targeted framework).
public delegate TResult Func<T1, T2, T3, T4, T5, TResult>(T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5);
#if !SQLPROXY_SQLCLR
/// <summary>
/// Associates a property with a column name. The consumer of this attribute is outside the
/// visible code (presumably the generated row-mapping logic).
/// </summary>
[AttributeUsage(AttributeTargets.Property)]
internal sealed class ColumnNameAttribute : Attribute {
    /// <summary>Creates the attribute with no name; <see cref="Name"/> stays null unless set as a named argument.</summary>
    public ColumnNameAttribute() : this(null) {
    }

    /// <summary>Creates the attribute for the given column name.</summary>
    /// <param name="name">The column name to record.</param>
    public ColumnNameAttribute(string name) {
        Name = name;
    }

    /// <summary>The recorded column name; null when none was supplied.</summary>
    public string Name { get; set; }
}
/// <summary>
/// Marks a method of <c>SqlFunctions</c> as a stand-in for a T-SQL construct inside an expression tree.
/// </summary>
[AttributeUsage(AttributeTargets.Method)]
internal sealed class SqlFunctionAttribute : Attribute {
    /// <summary>
    /// The kind of SQL construct a marked method represents. Only <see cref="Function"/> carries a
    /// function name; the other members name special forms (CASE, IS [NOT] NULL, [NOT] LIKE, [NOT] IN,
    /// and a no-op conversion) whose rendering is decided by the SQL builder.
    /// </summary>
    public enum FunctionType {
        Function,
        Case,
        IsNotNull,
        IsNull,
        NoOp,
        Like,
        NotLike,
        In,
        NotIn
    }

    /// <summary>
    /// Describes a named SQL function. <see cref="ParamList"/> starts out true; usages such as
    /// CURRENT_TIMESTAMP set it to false, presumably to suppress the parenthesised argument list.
    /// </summary>
    /// <param name="name">The SQL function name, e.g. "DATEADD" or "@@LANGUAGE".</param>
    public SqlFunctionAttribute(string name) : this(FunctionType.Function) {
        ParamList = true;
        Name = name;
    }

    /// <summary>
    /// Describes a special form. <see cref="Name"/> stays null and <see cref="ParamList"/> stays false.
    /// </summary>
    /// <param name="type">The construct represented by the marked method.</param>
    public SqlFunctionAttribute(FunctionType type) {
        Type = type;
    }

    /// <summary>The construct this attribute describes.</summary>
    public FunctionType Type { get; private set; }

    /// <summary>The SQL function name; null for special forms.</summary>
    public string Name { get; private set; }

    /// <summary>
    /// Literal text supplied as an implicit first argument (e.g. the "yy" date part for DATEADD);
    /// null when not set.
    /// </summary>
    public string MagicFirstArg { get; set; }

    /// <summary>Whether the function is written with an argument list.</summary>
    public bool ParamList { get; set; }
}
#endif // !SQLPROXY_SQLCLR
/// <summary>
/// Marker methods standing in for T-SQL built-ins inside expression trees (filters, orderings).
/// Each marker carries a <see cref="SqlFunctionAttribute"/>. The evaluator refuses to evaluate such
/// calls locally, so they are left in the tree for the SQL builder to translate. Invoking a marker
/// directly throws.
/// The helpers at the end of the class (DbNullCast, LoadXml, LoadSqlXml) are real methods and are
/// compiled in every configuration.
/// </summary>
public static class SqlFunctions {
#if !SQLPROXY_SQLCLR
    // ---- Date and time ----
    [SqlFunction("SYSDATETIME")] public static DateTime SysDateTime() { throw new Exception("Don't call"); }
    // NOTE(review): SYSDATETIMEOFFSET yields datetimeoffset; the DateTime return type is kept only for caller compatibility.
    [SqlFunction("SYSDATETIMEOFFSET")] public static DateTime SysDateTimeOffset() { throw new Exception("Don't call"); }
    [SqlFunction("SYSUTCDATETIME")] public static DateTime SysUtcDateTime() { throw new Exception("Don't call"); }
    [SqlFunction("CURRENT_TIMESTAMP", ParamList = false)] public static DateTime CurrentTimestamp() { throw new Exception("Don't call"); }
    [SqlFunction("GETDATE")] public static DateTime GetDate() { throw new Exception("Don't call"); }
    [SqlFunction("GETUTCDATE")] public static DateTime GetUtcDate() { throw new Exception("Don't call"); }
    [SqlFunction("DATENAME")] public static string DateName(string datePart, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEPART")] public static int DatePart(string datePart, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DAY")] public static int Day(DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("MONTH")] public static int Month(DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("YEAR")] public static int Year(DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF")] public static int DateDiff(string datePart, DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    // Legacy overload with TimeSpan types, kept for compatibility; prefer the DateTimeOffset overload below.
    [SqlFunction("SWITCHOFFSET")] public static TimeSpan SwitchOffset(TimeSpan dateTimeOffset, string timeZone) { throw new Exception("Don't call"); }
    // SWITCHOFFSET takes and returns datetimeoffset.
    [SqlFunction("SWITCHOFFSET")] public static DateTimeOffset SwitchOffset(DateTimeOffset dateTimeOffset, string timeZone) { throw new Exception("Don't call"); }
    // NOTE(review): TODATETIMEOFFSET returns datetimeoffset; TimeSpan is kept only for caller compatibility.
    [SqlFunction("TODATETIMEOFFSET")] public static TimeSpan ToDateTimeOffset(DateTime dateTime, string timeZone) { throw new Exception("Don't call"); }
    [SqlFunction("@@DATEFIRST", ParamList = false)] public static byte DateFirst() { throw new Exception("Don't call"); }
    [SqlFunction("ISDATE")] public static int IsDate(object expression) { throw new Exception("Don't call"); }

    // ---- Mathematics ----
    [SqlFunction("ABS")] public static T Abs<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("ACOS")] public static T Acos<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("ASIN")] public static T Asin<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("ATAN")] public static T Atan<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("ATN2")] public static T Atn2<T, U>(T expr1, U expr2) { throw new Exception("Don't call"); }
    [SqlFunction("CEILING")] public static T Ceiling<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("DEGREES")] public static T Degrees<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("EXP")] public static T Exp<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("FLOOR")] public static T Floor<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("LOG")] public static T Log<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("LOG10")] public static T Log10<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("PI")] public static double Pi() { throw new Exception("Don't call"); }
    [SqlFunction("POWER")] public static T Power<T, U>(T b, U e) { throw new Exception("Don't call"); }
    [SqlFunction("RADIANS")] public static T Radians<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("RAND")] public static double Rand() { throw new Exception("Don't call"); }
    [SqlFunction("RAND")] public static double Rand(int seed) { throw new Exception("Don't call"); }
    [SqlFunction("ROUND")] public static T Round<T>(T expression, int length) { throw new Exception("Don't call"); }
    [SqlFunction("ROUND")] public static T Round<T>(T expression, int length, int function) { throw new Exception("Don't call"); }
    [SqlFunction("SIGN")] public static T Sign<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("SIN")] public static T Sin<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("SQRT")] public static T Sqrt<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("SQUARE")] public static T Square<T>(T expression) { throw new Exception("Don't call"); }
    [SqlFunction("TAN")] public static T Tan<T>(T expression) { throw new Exception("Don't call"); }

    // ---- Security ----
    [SqlFunction("CURRENT_USER", ParamList = false)] public static string CurrentUser() { throw new Exception("Don't call"); }
    [SqlFunction("USER_NAME")] public static string UserName(int id) { throw new Exception("Don't call"); }
    [SqlFunction("SESSION_USER", ParamList = false)] public static string SessionUser() { throw new Exception("Don't call"); }
    [SqlFunction("IS_MEMBER")] public static int? IsMember(string group) { throw new Exception("Don't call"); }

    // ---- Strings ----
    [SqlFunction("ASCII")] public static int Ascii(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("CHAR")] public static string Char(int expression) { throw new Exception("Don't call"); }
    [SqlFunction("CHARINDEX")] public static int CharIndex(string expression1, string expression2) { throw new Exception("Don't call"); }
    [SqlFunction("CHARINDEX")] public static int CharIndex(string expression1, string expression2, long startLocation) { throw new Exception("Don't call"); }
    [SqlFunction("DIFFERENCE")] public static int Difference(string expression1, string expression2) { throw new Exception("Don't call"); }
    [SqlFunction("LEFT")] public static string Left(string str, long count) { throw new Exception("Don't call"); }
    [SqlFunction("LEN")] public static int Len(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("LOWER")] public static string Lower(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("LTRIM")] public static string LTrim(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("NCHAR")] public static string NChar(int expression) { throw new Exception("Don't call"); }
    [SqlFunction("PATINDEX")] public static long PatIndex(string pattern, string expression) { throw new Exception("Don't call"); }
    [SqlFunction("QUOTENAME")] public static string QuoteName(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("QUOTENAME")] public static string QuoteName(string expression, string quoteChar) { throw new Exception("Don't call"); }
    [SqlFunction("REPLACE")] public static string Replace(string expression, string find, string replace) { throw new Exception("Don't call"); }
    [SqlFunction("REPLICATE")] public static string Replicate(string expression, int count) { throw new Exception("Don't call"); }
    [SqlFunction("REVERSE")] public static string Reverse(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("RIGHT")] public static string Right(string str, long count) { throw new Exception("Don't call"); }
    [SqlFunction("RTRIM")] public static string RTrim(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("SOUNDEX")] public static string SoundEx(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("SPACE")] public static string Space(int count) { throw new Exception("Don't call"); }
    [SqlFunction("STR")] public static string Str(double expression) { throw new Exception("Don't call"); }
    [SqlFunction("STR")] public static string Str(double expression, int length) { throw new Exception("Don't call"); }
    [SqlFunction("STR")] public static string Str(double expression, int length, int dec) { throw new Exception("Don't call"); }
    // Legacy overload (STUFF operates on a character expression); kept for compatibility.
    [SqlFunction("STUFF")] public static string Stuff(double expr1, int start, int length, string expr2) { throw new Exception("Don't call"); }
    // STUFF(character_expression, start, length, replaceWith_expression).
    [SqlFunction("STUFF")] public static string Stuff(string expr1, int start, int length, string expr2) { throw new Exception("Don't call"); }
    [SqlFunction("SUBSTRING")] public static string Substring(string expression, int start, int length) { throw new Exception("Don't call"); }
    [SqlFunction("UNICODE")] public static int Unicode(string expression) { throw new Exception("Don't call"); }
    [SqlFunction("UPPER")] public static string Upper(string expression) { throw new Exception("Don't call"); }

    // ---- System / miscellaneous ----
    [SqlFunction("APP_NAME")] public static string AppName() { throw new Exception("Don't call"); }
    [SqlFunction("COALESCE")] public static T Coalesce<T>(params T[] values) { throw new Exception("Don't call"); }
    [SqlFunction("ISNUMERIC")] public static int IsNumeric(object expression) { throw new Exception("Don't call"); }
    // Legacy overload: NEWID takes no arguments, so this one would render invalid SQL. Kept only for source compatibility.
    [SqlFunction("NEWID")] public static Guid NewId(object expression) { throw new Exception("Don't call"); }
    [SqlFunction("NEWID")] public static Guid NewId() { throw new Exception("Don't call"); }
    [SqlFunction("NULLIF")] public static T? NullIfV<T>(T expr1, T expr2) where T : struct { throw new Exception("Don't call"); }
    [SqlFunction("NULLIF")] public static T NullIfR<T>(T expr1, T expr2) where T : class { throw new Exception("Don't call"); }
    [SqlFunction("@@LANGUAGE", ParamList = false)] public static string Language() { throw new Exception("Don't call"); }
    // NOTE(review): ISNULL(expr, replacement) yields the replacement's type, not bool; the bool return is kept only for caller compatibility.
    [SqlFunction("ISNULL")] public static bool IsNull<T>(T expression, T nullvalue) { throw new Exception("Don't call"); }

    // ---- Special forms ----
    [SqlFunction(SqlFunctionAttribute.FunctionType.Case)] public static object Case(bool test1, object case1, params object[] otherFiltersAndCases) { throw new Exception("Don't call"); }
    [SqlFunction(SqlFunctionAttribute.FunctionType.IsNull)] public static bool IsNull(object expression) { throw new Exception("Don't call"); }
    [SqlFunction(SqlFunctionAttribute.FunctionType.IsNotNull)] public static bool IsNotNull(object expression) { throw new Exception("Don't call"); }

    // ---- DATEADD with a fixed date part ----
    [SqlFunction("DATEADD", MagicFirstArg = "yy")] public static DateTime DateAddYy(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "q")] public static DateTime DateAddQ(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "m")] public static DateTime DateAddM(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "dy")] public static DateTime DateAddDy(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "d")] public static DateTime DateAddD(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "ww")] public static DateTime DateAddWw(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "dw")] public static DateTime DateAddDw(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "hh")] public static DateTime DateAddHh(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "mi")] public static DateTime DateAddMi(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "s")] public static DateTime DateAddS(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "ms")] public static DateTime DateAddMs(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "mcs")] public static DateTime DateAddMcs(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEADD", MagicFirstArg = "ns")] public static DateTime DateAddNs(int number, DateTime date) { throw new Exception("Don't call"); }

    // ---- DATEDIFF with a fixed date part ----
    // Legacy overloads: their signatures were copied from DATEADD and do not match DATEDIFF
    // (startdate, enddate) -> int. Kept only for source compatibility; use the overloads below.
    [SqlFunction("DATEDIFF", MagicFirstArg = "yy")] public static DateTime DateDiffYy(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "q")] public static DateTime DateDiffQ(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "m")] public static DateTime DateDiffM(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "dy")] public static DateTime DateDiffDy(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "d")] public static DateTime DateDiffD(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "ww")] public static DateTime DateDiffWw(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "dw")] public static DateTime DateDiffDw(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "hh")] public static DateTime DateDiffHh(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "mi")] public static DateTime DateDiffMi(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "s")] public static DateTime DateDiffS(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "ms")] public static DateTime DateDiffMs(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "mcs")] public static DateTime DateDiffMcs(int number, DateTime date) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "ns")] public static DateTime DateDiffNs(int number, DateTime date) { throw new Exception("Don't call"); }
    // Correct DATEDIFF(datepart, startdate, enddate) shapes returning the int boundary count.
    [SqlFunction("DATEDIFF", MagicFirstArg = "yy")] public static int DateDiffYy(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "q")] public static int DateDiffQ(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "m")] public static int DateDiffM(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "dy")] public static int DateDiffDy(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "d")] public static int DateDiffD(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "ww")] public static int DateDiffWw(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "dw")] public static int DateDiffDw(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "hh")] public static int DateDiffHh(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "mi")] public static int DateDiffMi(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "s")] public static int DateDiffS(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "ms")] public static int DateDiffMs(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "mcs")] public static int DateDiffMcs(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }
    [SqlFunction("DATEDIFF", MagicFirstArg = "ns")] public static int DateDiffNs(DateTime startDate, DateTime endDate) { throw new Exception("Don't call"); }

    // ---- Conversions ----
    [SqlFunction("CONVERT", MagicFirstArg = "varchar")] public static string ConvertToString(object expression) { throw new Exception("Don't call"); }
    [SqlFunction("CONVERT", MagicFirstArg = "varchar")] public static string ConvertToString(object expression, int style) { throw new Exception("Don't call"); }
    // NOTE(review): CONVERT(int, ...) yields int and CONVERT(datetime, ...) yields datetime; the string
    // return types below are kept only because changing them would break existing callers.
    [SqlFunction("CONVERT", MagicFirstArg = "int")] public static string ConvertToInt(object expression) { throw new Exception("Don't call"); }
    [SqlFunction("CONVERT", MagicFirstArg = "int")] public static string ConvertToInt(object expression, int style) { throw new Exception("Don't call"); }
    [SqlFunction("CONVERT", MagicFirstArg = "datetime")] public static string ConvertToDateTime(object expression) { throw new Exception("Don't call"); }
    [SqlFunction("CONVERT", MagicFirstArg = "datetime")] public static string ConvertToDateTime(object expression, int style) { throw new Exception("Don't call"); }
    [SqlFunction(SqlFunctionAttribute.FunctionType.NoOp)] public static T ImplicitConvert<T>(object o) { throw new Exception("Don't call"); }

    // ---- Predicates ----
    [SqlFunction(SqlFunctionAttribute.FunctionType.Like)] public static bool Like(string expression, string pattern) { throw new Exception("Don't call"); }
    [SqlFunction(SqlFunctionAttribute.FunctionType.NotLike)] public static bool NotLike(string expression, string pattern) { throw new Exception("Don't call"); }
    [SqlFunction(SqlFunctionAttribute.FunctionType.In)] public static bool In<T>(T expr, params T[] values) { throw new Exception("Don't call"); }
    [SqlFunction(SqlFunctionAttribute.FunctionType.NotIn)] public static bool NotIn<T>(T expr, params T[] values) { throw new Exception("Don't call"); }
#endif // !SQLPROXY_SQLCLR

    /// <summary>
    /// Converts a raw database value to <typeparamref name="T"/>, mapping SQL NULL to default(T).
    /// </summary>
    /// <typeparam name="T">
    /// The target type. For <see cref="SqlXml"/>, NULL may arrive as a null reference, as
    /// <see cref="DBNull"/>, or as a SqlXml whose IsNull is true. For every other type only
    /// <see cref="DBNull"/> is treated as NULL.
    /// </typeparam>
    /// <param name="o">The raw value.</param>
    /// <returns>default(T) for NULL; otherwise <paramref name="o"/> cast to T.</returns>
    /// <exception cref="InvalidCastException">A non-NULL <paramref name="o"/> is not a T.</exception>
    public static T DbNullCast<T>(object o) {
        if (typeof(T) == typeof(SqlXml)) {
            // Casting null or DBNull to SqlXml would throw, so treat both as NULL explicitly.
            if (o == null || o is DBNull)
                return default(T);
            return ((SqlXml)o).IsNull ? default(T) : (T)o;
        }
        return o is DBNull ? default(T) : (T)o;
    }

    /// <summary>Wraps an XML string in a <see cref="SqlXml"/>.</summary>
    /// <param name="xml">The XML text; may be null.</param>
    /// <returns>A SqlXml holding the document, or null when <paramref name="xml"/> is null.</returns>
    public static SqlXml LoadXml(string xml) {
        if (xml == null)
            return null;
        // SqlXml(XmlReader) copies the content into its own buffer and leaves closing the reader to
        // the caller, so dispose it here (CloseInput also closes the underlying StringReader).
        using (XmlReader reader = XmlReader.Create(new StringReader(xml), new XmlReaderSettings { CloseInput = true })) {
            return new SqlXml(reader);
        }
    }

    /// <summary>Loads a <see cref="SqlXml"/> value into a new <see cref="XmlDocument"/>.</summary>
    /// <param name="xml">The value to load; may be null or SQL NULL.</param>
    /// <returns>The loaded document, or null when <paramref name="xml"/> is null or SQL NULL.</returns>
    public static XmlDocument LoadSqlXml(SqlXml xml) {
        // Mirror LoadXml: a missing value maps to null instead of throwing from CreateReader.
        if (xml == null || xml.IsNull)
            return null;
        var d = new XmlDocument();
        using (var rdr = xml.CreateReader()) {
            d.Load(rdr);
        }
        return d;
    }
}
#if !SQLPROXY_SQLCLR
/// <summary>
/// Base class for expression-tree rewriters. Every Visit* method visits a node's children and
/// returns the original node when nothing changed, or a rebuilt node when any child was replaced.
/// Subclasses override individual Visit* methods to intercept specific node kinds.
/// </summary>
internal abstract class ExpressionVisitor {
    protected ExpressionVisitor() {
    }

    /// <summary>Dispatches <paramref name="exp"/> to the Visit* method for its node type.</summary>
    /// <param name="exp">The node to visit; null is allowed (e.g. the missing instance of a static call).</param>
    /// <returns>The (possibly rewritten) node, or null when <paramref name="exp"/> is null.</returns>
    /// <exception cref="NotSupportedException">The node type is not handled by this visitor.</exception>
    protected virtual Expression Visit(Expression exp) {
        if (exp == null)
            return exp;
        switch (exp.NodeType) {
            case ExpressionType.Negate:
            case ExpressionType.NegateChecked:
            case ExpressionType.Not:
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.ArrayLength:
            case ExpressionType.Quote:
            case ExpressionType.TypeAs:
            case ExpressionType.UnaryPlus: // produced by C# unary '+', e.g. +capturedVariable
                return this.VisitUnary((UnaryExpression)exp);
            case ExpressionType.Add:
            case ExpressionType.AddChecked:
            case ExpressionType.Subtract:
            case ExpressionType.SubtractChecked:
            case ExpressionType.Multiply:
            case ExpressionType.MultiplyChecked:
            case ExpressionType.Divide:
            case ExpressionType.Modulo:
            case ExpressionType.Power: // e.g. VB's '^' operator
            case ExpressionType.And:
            case ExpressionType.AndAlso:
            case ExpressionType.Or:
            case ExpressionType.OrElse:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.Equal:
            case ExpressionType.NotEqual:
            case ExpressionType.Coalesce:
            case ExpressionType.ArrayIndex:
            case ExpressionType.RightShift:
            case ExpressionType.LeftShift:
            case ExpressionType.ExclusiveOr:
                return this.VisitBinary((BinaryExpression)exp);
            case ExpressionType.TypeIs:
                return this.VisitTypeIs((TypeBinaryExpression)exp);
            case ExpressionType.Conditional:
                return this.VisitConditional((ConditionalExpression)exp);
            case ExpressionType.Constant:
                return this.VisitConstant((ConstantExpression)exp);
            case ExpressionType.Parameter:
                return this.VisitParameter((ParameterExpression)exp);
            case ExpressionType.MemberAccess:
                return this.VisitMemberAccess((MemberExpression)exp);
            case ExpressionType.Call:
                return this.VisitMethodCall((MethodCallExpression)exp);
            case ExpressionType.Lambda:
                return this.VisitLambda((LambdaExpression)exp);
            case ExpressionType.New:
                return this.VisitNew((NewExpression)exp);
            case ExpressionType.NewArrayInit:
            case ExpressionType.NewArrayBounds:
                return this.VisitNewArray((NewArrayExpression)exp);
            case ExpressionType.Invoke:
                return this.VisitInvocation((InvocationExpression)exp);
            case ExpressionType.MemberInit:
                return this.VisitMemberInit((MemberInitExpression)exp);
            case ExpressionType.ListInit:
                return this.VisitListInit((ListInitExpression)exp);
            default:
                throw new NotSupportedException(string.Format("Unhandled expression type: '{0}'", exp.NodeType));
        }
    }

    /// <summary>Dispatches a member binding (object-initializer entry) to the matching Visit* method.</summary>
    /// <exception cref="NotSupportedException">The binding type is not handled.</exception>
    protected virtual MemberBinding VisitBinding(MemberBinding binding) {
        switch (binding.BindingType) {
            case MemberBindingType.Assignment:
                return this.VisitMemberAssignment((MemberAssignment)binding);
            case MemberBindingType.MemberBinding:
                return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
            case MemberBindingType.ListBinding:
                return this.VisitMemberListBinding((MemberListBinding)binding);
            default:
                throw new NotSupportedException(string.Format("Unhandled binding type '{0}'", binding.BindingType));
        }
    }

    /// <summary>Visits the arguments of a collection-initializer Add call.</summary>
    protected virtual ElementInit VisitElementInitializer(ElementInit initializer) {
        ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
        if (arguments != initializer.Arguments) {
            return Expression.ElementInit(initializer.AddMethod, arguments);
        }
        return initializer;
    }

    /// <summary>Visits the operand of a unary node, preserving its type and operator method.</summary>
    protected virtual Expression VisitUnary(UnaryExpression u) {
        Expression operand = this.Visit(u.Operand);
        if (operand != u.Operand) {
            return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
        }
        return u;
    }

    /// <summary>Visits both operands (and any coalesce conversion lambda) of a binary node.</summary>
    protected virtual Expression VisitBinary(BinaryExpression b) {
        Expression left = this.Visit(b.Left);
        Expression right = this.Visit(b.Right);
        Expression conversion = this.Visit(b.Conversion);
        if (left != b.Left || right != b.Right || conversion != b.Conversion) {
            // MakeBinary cannot carry a coalesce conversion, so that case is rebuilt explicitly.
            if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)
                return Expression.Coalesce(left, right, conversion as LambdaExpression);
            else
                return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
        }
        return b;
    }

    /// <summary>Visits the tested expression of a "TypeIs" node.</summary>
    protected virtual Expression VisitTypeIs(TypeBinaryExpression b) {
        Expression expr = this.Visit(b.Expression);
        if (expr != b.Expression) {
            return Expression.TypeIs(expr, b.TypeOperand);
        }
        return b;
    }

    /// <summary>Constants are leaves; returned unchanged.</summary>
    protected virtual Expression VisitConstant(ConstantExpression c) {
        return c;
    }

    /// <summary>Visits the test and both branches of a conditional node.</summary>
    protected virtual Expression VisitConditional(ConditionalExpression c) {
        Expression test = this.Visit(c.Test);
        Expression ifTrue = this.Visit(c.IfTrue);
        Expression ifFalse = this.Visit(c.IfFalse);
        if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse) {
            return Expression.Condition(test, ifTrue, ifFalse);
        }
        return c;
    }

    /// <summary>Parameters are leaves; returned unchanged.</summary>
    protected virtual Expression VisitParameter(ParameterExpression p) {
        return p;
    }

    /// <summary>Visits the instance expression of a field/property access (null for static members).</summary>
    protected virtual Expression VisitMemberAccess(MemberExpression m) {
        Expression exp = this.Visit(m.Expression);
        if (exp != m.Expression) {
            return Expression.MakeMemberAccess(exp, m.Member);
        }
        return m;
    }

    /// <summary>Visits the instance (null for static calls) and arguments of a method call.</summary>
    protected virtual Expression VisitMethodCall(MethodCallExpression m) {
        Expression obj = this.Visit(m.Object);
        IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
        if (obj != m.Object || args != m.Arguments) {
            return Expression.Call(obj, m.Method, args);
        }
        return m;
    }

    /// <summary>
    /// Visits every expression in a list. Returns <paramref name="original"/> itself when no element
    /// changed, so callers can detect changes by reference comparison.
    /// </summary>
    protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original) {
        List<Expression> list = null;
        for (int i = 0, n = original.Count; i < n; i++) {
            Expression p = this.Visit(original[i]);
            if (list != null) {
                list.Add(p);
            }
            else if (p != original[i]) {
                // First change: copy the unchanged prefix, then keep appending.
                list = new List<Expression>(n);
                for (int j = 0; j < i; j++) {
                    list.Add(original[j]);
                }
                list.Add(p);
            }
        }
        if (list != null) {
            return list.AsReadOnly();
        }
        return original;
    }

    /// <summary>Visits the value expression of a "Member = value" binding.</summary>
    protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment) {
        Expression e = this.Visit(assignment.Expression);
        if (e != assignment.Expression) {
            return Expression.Bind(assignment.Member, e);
        }
        return assignment;
    }

    /// <summary>Visits the nested bindings of a "Member = { ... }" binding.</summary>
    protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) {
        IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
        if (bindings != binding.Bindings) {
            return Expression.MemberBind(binding.Member, bindings);
        }
        return binding;
    }

    /// <summary>Visits the element initializers of a "Member = { a, b }" collection binding.</summary>
    protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding) {
        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
        if (initializers != binding.Initializers) {
            return Expression.ListBind(binding.Member, initializers);
        }
        return binding;
    }

    /// <summary>Visits a list of member bindings; returns <paramref name="original"/> when unchanged.</summary>
    protected virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original) {
        List<MemberBinding> list = null;
        for (int i = 0, n = original.Count; i < n; i++) {
            MemberBinding b = this.VisitBinding(original[i]);
            if (list != null) {
                list.Add(b);
            }
            else if (b != original[i]) {
                list = new List<MemberBinding>(n);
                for (int j = 0; j < i; j++) {
                    list.Add(original[j]);
                }
                list.Add(b);
            }
        }
        if (list != null)
            return list;
        return original;
    }

    /// <summary>Visits a list of element initializers; returns <paramref name="original"/> when unchanged.</summary>
    protected virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original) {
        List<ElementInit> list = null;
        for (int i = 0, n = original.Count; i < n; i++) {
            ElementInit init = this.VisitElementInitializer(original[i]);
            if (list != null) {
                list.Add(init);
            }
            else if (init != original[i]) {
                list = new List<ElementInit>(n);
                for (int j = 0; j < i; j++) {
                    list.Add(original[j]);
                }
                list.Add(init);
            }
        }
        if (list != null)
            return list;
        return original;
    }

    /// <summary>Visits a lambda's body; its parameters are reused unchanged.</summary>
    protected virtual Expression VisitLambda(LambdaExpression lambda) {
        Expression body = this.Visit(lambda.Body);
        if (body != lambda.Body) {
            return Expression.Lambda(lambda.Type, body, lambda.Parameters);
        }
        return lambda;
    }

    /// <summary>Visits constructor arguments, preserving the member mapping used by anonymous types.</summary>
    protected virtual NewExpression VisitNew(NewExpression nex) {
        IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
        if (args != nex.Arguments) {
            if (nex.Members != null)
                return Expression.New(nex.Constructor, args, nex.Members);
            else
                return Expression.New(nex.Constructor, args);
        }
        return nex;
    }

    /// <summary>Visits the constructor call and bindings of an object initializer.</summary>
    protected virtual Expression VisitMemberInit(MemberInitExpression init) {
        NewExpression n = this.VisitNew(init.NewExpression);
        IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
        if (n != init.NewExpression || bindings != init.Bindings) {
            return Expression.MemberInit(n, bindings);
        }
        return init;
    }

    /// <summary>Visits the constructor call and element initializers of a collection initializer.</summary>
    protected virtual Expression VisitListInit(ListInitExpression init) {
        NewExpression n = this.VisitNew(init.NewExpression);
        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
        if (n != init.NewExpression || initializers != init.Initializers) {
            return Expression.ListInit(n, initializers);
        }
        return init;
    }

    /// <summary>Visits array elements (NewArrayInit) or bounds (NewArrayBounds).</summary>
    protected virtual Expression VisitNewArray(NewArrayExpression na) {
        IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
        if (exprs != na.Expressions) {
            if (na.NodeType == ExpressionType.NewArrayInit) {
                return Expression.NewArrayInit(na.Type.GetElementType(), exprs);
            }
            else {
                return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);
            }
        }
        return na;
    }

    /// <summary>Visits the arguments and the invoked expression of a delegate invocation.</summary>
    protected virtual Expression VisitInvocation(InvocationExpression iv) {
        IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
        Expression expr = this.Visit(iv.Expression);
        if (args != iv.Arguments || expr != iv.Expression) {
            return Expression.Invoke(expr, args);
        }
        return iv;
    }
}
/// <summary>
/// Partially evaluates expression trees: every maximal sub-tree that can run locally is compiled,
/// executed, and replaced by a constant holding its value.
/// </summary>
internal static class Evaluator {
    /// <summary>
    /// Performs evaluation and replacement of independent sub-trees.
    /// </summary>
    /// <param name="expression">The root of the expression tree.</param>
    /// <param name="fnCanBeEvaluated">A function that decides whether a given expression node can be part of the local function.</param>
    /// <returns>A new tree with sub-trees evaluated and replaced.</returns>
    public static Expression PartialEval(Expression expression, Func<Expression, bool> fnCanBeEvaluated) {
        HashSet<Expression> candidates = new Nominator(fnCanBeEvaluated).Nominate(expression);
        return new SubtreeEvaluator(candidates).Eval(expression);
    }

    /// <summary>
    /// Performs evaluation and replacement of independent sub-trees, using the default rule
    /// (parameters and [SqlFunction]-marked calls stay in the tree).
    /// </summary>
    /// <param name="expression">The root of the expression tree.</param>
    /// <returns>A new tree with sub-trees evaluated and replaced.</returns>
    public static Expression PartialEval(Expression expression) {
        return PartialEval(expression, Evaluator.CanBeEvaluatedLocally);
    }

    /// <summary>
    /// Default nomination rule: parameter references and calls to methods carrying
    /// <see cref="SqlFunctionAttribute"/> cannot be evaluated locally; everything else can.
    /// </summary>
    private static bool CanBeEvaluatedLocally(Expression expression) {
        switch (expression.NodeType) {
            case ExpressionType.Parameter:
                return false;
            case ExpressionType.Call:
                MethodCallExpression m = (MethodCallExpression)expression;
                return Attribute.GetCustomAttribute(m.Method, typeof(SqlFunctionAttribute)) == null;
            default:
                return true;
        }
    }

    /// <summary>
    /// Evaluates and replaces sub-trees when the first candidate is reached (top-down). It also folds
    /// AndAlso/OrElse nodes that have a constant boolean operand.
    /// </summary>
    class SubtreeEvaluator : ExpressionVisitor {
        HashSet<Expression> candidates;

        internal SubtreeEvaluator(HashSet<Expression> candidates) {
            this.candidates = candidates;
        }

        internal Expression Eval(Expression exp) {
            return this.Visit(exp);
        }

        protected override Expression Visit(Expression exp) {
            if (exp == null) {
                return null;
            }
            if (this.candidates.Contains(exp)) {
                return this.Evaluate(exp);
            }
            exp = base.Visit(exp);
            if (exp.NodeType == ExpressionType.AndAlso || exp.NodeType == ExpressionType.OrElse) {
                // If either branch became constant, fold it by hand, e.g. (1 = 1 || r.Whatever = 'Something').
                // SQL Server has no TRUE/FALSE literals, and the generic (1=1)/(1=0) substitutes break
                // when combined with bit parameters, so removing the constant here is simpler.
                return SimplifyShortCircuit((BinaryExpression)exp);
            }
            return exp;
        }

        /// <summary>
        /// Removes a constant boolean operand from an AndAlso/OrElse node. The left operand is
        /// checked first; the node is returned unchanged when neither operand is a bool constant.
        /// </summary>
        private static Expression SimplifyShortCircuit(BinaryExpression be) {
            bool isAndAlso = be.NodeType == ExpressionType.AndAlso;
            bool constant;
            if (TryGetBooleanConstant(be.Left, out constant))
                return Fold(isAndAlso, constant, be.Right);
            if (TryGetBooleanConstant(be.Right, out constant))
                return Fold(isAndAlso, constant, be.Left);
            return be;
        }

        /// <summary>x AndAlso true -> x, x AndAlso false -> false, x OrElse true -> true, x OrElse false -> x.</summary>
        private static Expression Fold(bool isAndAlso, bool constant, Expression other) {
            if (isAndAlso)
                return constant ? other : Expression.Constant(false);
            return constant ? Expression.Constant(true) : other;
        }

        /// <summary>
        /// True when <paramref name="e"/> is a constant holding a non-null bool. Null or non-bool
        /// constants (lifted or user-defined short-circuit operators) are left alone instead of
        /// crashing on unboxing.
        /// </summary>
        private static bool TryGetBooleanConstant(Expression e, out bool value) {
            ConstantExpression ce = e as ConstantExpression;
            if (ce != null && ce.Value is bool) {
                value = (bool)ce.Value;
                return true;
            }
            value = false;
            return false;
        }

        /// <summary>Compiles and runs <paramref name="e"/>, returning its value as a constant of the same type.</summary>
        private Expression Evaluate(Expression e) {
            if (e.NodeType == ExpressionType.Constant) {
                return e;
            }
            LambdaExpression lambda = Expression.Lambda(e);
            Delegate fn = lambda.Compile();
            return Expression.Constant(fn.DynamicInvoke(null), e.Type);
        }
    }

    /// <summary>
    /// Performs bottom-up analysis to determine which nodes can possibly
    /// be part of an evaluated sub-tree.
    /// </summary>
    class Nominator : ExpressionVisitor {
        Func<Expression, bool> fnCanBeEvaluated;
        HashSet<Expression> candidates;
        // Set while visiting a node when any descendant (or the node itself) must stay in the tree.
        bool cannotBeEvaluated;

        internal Nominator(Func<Expression, bool> fnCanBeEvaluated) {
            this.fnCanBeEvaluated = fnCanBeEvaluated;
        }

        internal HashSet<Expression> Nominate(Expression expression) {
            this.candidates = new HashSet<Expression>();
            this.Visit(expression);
            return this.candidates;
        }

        protected override Expression Visit(Expression expression) {
            if (expression != null) {
                // Evaluate children with a fresh flag, then merge it back into the parent's state.
                bool saveCannotBeEvaluated = this.cannotBeEvaluated;
                this.cannotBeEvaluated = false;
                base.Visit(expression);
                if (!this.cannotBeEvaluated) {
                    if (this.fnCanBeEvaluated(expression)) {
                        this.candidates.Add(expression);
                    }
                    else {
                        this.cannotBeEvaluated = true;
                    }
                }
                this.cannotBeEvaluated |= saveCannotBeEvaluated;
            }
            return expression;
        }
    }
}
/// <summary>
/// Translates a LINQ expression over a generated row type into a T-SQL fragment and appends it to a
/// <see cref="SqlCommand"/>. Constants are never spliced into the SQL text: each one becomes a command
/// parameter named <c>@p</c> followed by the number of parameters already on the command.
/// </summary>
/// <remarks>
/// Translatable nodes are unary/binary operators, conditionals, coalesce, constants, member accesses on the
/// lambda parameter that carry a <c>ColumnNameAttribute</c>, and static method calls that carry a
/// <c>SqlFunctionAttribute</c>. Most other node types throw <see cref="NotSupportedException"/>.
/// </remarks>
internal sealed class SqlBuilder : ExpressionVisitor {
    private readonly StringBuilder result = new StringBuilder();
    private readonly SqlCommand command;

    private SqlBuilder(SqlCommand command) {
        this.command = command;
    }

    /// <summary>
    /// Appends a <c>WHERE</c> clause equivalent to <paramref name="filter"/> to the command text and adds
    /// the parameters it needs to <paramref name="command"/>.
    /// </summary>
    /// <param name="command">Command whose text is extended; it also receives the generated parameters.</param>
    /// <param name="filter">Row predicate. <c>Evaluator.PartialEval</c> pre-evaluates what it can before translation.</param>
    /// <exception cref="ArgumentNullException"><paramref name="command"/> or <paramref name="filter"/> is null.</exception>
    /// <exception cref="NotSupportedException">The predicate contains a construct with no SQL translation.</exception>
    public static void AddWhereClause<T>(SqlCommand command, Expression<Func<T, bool>> filter) {
        if (command == null) throw new ArgumentNullException("command");
        if (filter == null) throw new ArgumentNullException("filter");
        Expression e = Evaluator.PartialEval(filter.Body);
        if (e.NodeType == ExpressionType.Constant) {
            // The whole predicate folded to a constant (e.g. r => true). SQL has no boolean literal, so emit a
            // tautology or a contradiction instead.
            command.CommandText += " WHERE " + ((bool)((ConstantExpression)e).Value ? "1=1" : "1=0");
        }
        else {
            SqlBuilder b = new SqlBuilder(command);
            b.Visit(e);
            command.CommandText += " WHERE " + b.result.ToString();
        }
    }

    /// <summary>
    /// Appends an <c>ORDER BY</c> clause built from <paramref name="order"/> to the command text.
    /// </summary>
    /// <remarks>
    /// The body is either one sort key or a <c>new object[] { ... }</c> of keys. A unary minus on a key
    /// (<c>-r.Col</c>) sorts that key descending, and an integral constant is written as a column ordinal
    /// (<c>ORDER BY 1</c>). An empty key array adds no clause at all.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="command"/> or <paramref name="order"/> is null.</exception>
    /// <exception cref="ArgumentException">A constant key is not an integral value that fits in an <c>int</c>.</exception>
    /// <exception cref="NotSupportedException">A key contains a construct with no SQL translation.</exception>
    public static void AddOrderByClause<T>(SqlCommand command, Expression<Func<T, object>> order) {
        if (command == null) throw new ArgumentNullException("command");
        if (order == null) throw new ArgumentNullException("order");
        IEnumerable<Expression> keys = (order.Body.NodeType == ExpressionType.NewArrayInit
            ? (IEnumerable<Expression>)((NewArrayExpression)order.Body).Expressions
            : Enumerable.Repeat(order.Body, 1));
        string orderList = null;
        foreach (Expression key in keys) {
            string item = TranslateOrderKey(command, Evaluator.PartialEval(key));
            orderList = (orderList == null ? item : orderList + ", " + item);
        }
        // Without any key, " ORDER BY " on its own would be a syntax error.
        if (orderList != null)
            command.CommandText += " ORDER BY " + orderList;
    }

    // Translates one (already partially evaluated) sort key into its ORDER BY item.
    private static string TranslateOrderKey(SqlCommand command, Expression e) {
        if (e.NodeType == ExpressionType.Constant) {
            // handle expressions such as ORDER BY 1, 2
            int ordinal = GetColumnOrdinal(((ConstantExpression)e).Value);
            return Convert.ToString(ordinal, System.Globalization.CultureInfo.InvariantCulture);
        }
        // The lambda returns object, so value-typed keys arrive wrapped in boxing conversions; strip them.
        while (e.NodeType == ExpressionType.Convert || e.NodeType == ExpressionType.ConvertChecked) {
            e = ((UnaryExpression)e).Operand;
        }
        bool isDesc = false;
        if (e.NodeType == ExpressionType.Negate || e.NodeType == ExpressionType.NegateChecked) {
            e = ((UnaryExpression)e).Operand;
            isDesc = true;
        }
        SqlBuilder b = new SqlBuilder(command);
        b.Visit(e);
        string item = b.result.ToString();
        return (isDesc ? item + " DESC" : item);
    }

    // Converts the value of an integral constant sort key into the column ordinal written after ORDER BY.
    private static int GetColumnOrdinal(object value) {
        Type t = (value == null ? null : value.GetType());
        bool isIntegral = t == typeof(byte) || t == typeof(sbyte) || t == typeof(short) || t == typeof(ushort)
            || t == typeof(int) || t == typeof(uint) || t == typeof(long) || t == typeof(ulong);
        if (!isIntegral)
            throw new ArgumentException(string.Format("The constant value {0} is not a number.", value));
        try {
            // Convert.ToInt32 is range-checked, so an out-of-range ordinal is reported instead of silently wrapping.
            return Convert.ToInt32(value, System.Globalization.CultureInfo.InvariantCulture);
        }
        catch (OverflowException) {
            throw new ArgumentException(string.Format("The constant value {0} is out of range for a column ordinal.", value));
        }
    }

    protected override Expression VisitUnary(UnaryExpression u) {
        switch (u.NodeType) {
            case ExpressionType.Negate:
            case ExpressionType.NegateChecked:
                result.Append("-(");
                Visit(u.Operand);
                result.Append(")");
                break;
            case ExpressionType.Not:
                // C# compiles both "!b" and "~i" to ExpressionType.Not. Only the boolean form is SQL NOT; the
                // integral form is the bitwise complement operator.
                result.Append(u.Type == typeof(bool) || u.Type == typeof(bool?) ? "NOT (" : "~(");
                Visit(u.Operand);
                result.Append(")");
                break;
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.TypeAs:
                Visit(u.Operand); // No-op in dynamically typed SQL
                break;
            case ExpressionType.Quote:
                throw new NotSupportedException("Quotes are not supported.");
            case ExpressionType.ArrayLength:
                throw new NotSupportedException("Array lengths are not supported.");
            default:
                throw new NotSupportedException("Unsupported expression type.");
        }
        return u;
    }

    // True for a constant whose value is null (e.g. the right-hand side of "r.Name == null").
    private static bool IsNullConstant(Expression e) {
        return e.NodeType == ExpressionType.Constant && ((ConstantExpression)e).Value == null;
    }

    protected override Expression VisitBinary(BinaryExpression b) {
        string infixOper;
        switch (b.NodeType) {
            case ExpressionType.Add:
            case ExpressionType.AddChecked:
                infixOper = "+";
                break;
            case ExpressionType.Subtract:
            case ExpressionType.SubtractChecked:
                infixOper = "-";
                break;
            case ExpressionType.Multiply:
            case ExpressionType.MultiplyChecked:
                infixOper = "*";
                break;
            case ExpressionType.Divide:
                infixOper = "/";
                break;
            case ExpressionType.Modulo:
                infixOper = "%";
                break;
            case ExpressionType.And:
                infixOper = "&";
                break;
            case ExpressionType.AndAlso:
                infixOper = "AND";
                break;
            case ExpressionType.Or:
                infixOper = "|";
                break;
            case ExpressionType.OrElse:
                infixOper = "OR";
                break;
            case ExpressionType.LessThan:
                infixOper = "<";
                break;
            case ExpressionType.LessThanOrEqual:
                infixOper = "<=";
                break;
            case ExpressionType.GreaterThan:
                infixOper = ">";
                break;
            case ExpressionType.GreaterThanOrEqual:
                infixOper = ">=";
                break;
            case ExpressionType.Equal:
            case ExpressionType.NotEqual: {
                // With ANSI_NULLS ON, "x = NULL" is never true in SQL, whereas "x == null" in C# tests for null.
                // Comparisons against a null constant are therefore written as IS [NOT] NULL.
                Expression tested = (IsNullConstant(b.Left) ? b.Right : (IsNullConstant(b.Right) ? b.Left : null));
                if (tested != null) {
                    result.Append("(");
                    Visit(tested);
                    result.Append(b.NodeType == ExpressionType.Equal ? ") IS NULL" : ") IS NOT NULL");
                    return b;
                }
                // NOTE: a comparison of two nullable columns still follows SQL three-valued logic.
                infixOper = (b.NodeType == ExpressionType.Equal ? "=" : "<>");
                break;
            }
            case ExpressionType.ExclusiveOr:
                infixOper = "^";
                break;
            case ExpressionType.Coalesce:
                result.Append("COALESCE(");
                Visit(b.Left);
                result.Append(", ");
                Visit(b.Right);
                result.Append(")");
                return b;
            case ExpressionType.ArrayIndex:
                throw new NotSupportedException("The array index operator is not supported.");
            case ExpressionType.RightShift:
                throw new NotSupportedException("The right shift operator is not supported.");
            case ExpressionType.LeftShift:
                throw new NotSupportedException("The left shift operator is not supported.");
            default:
                throw new NotSupportedException("Unsupported expression type.");
        }
        // must be an infix operation; both operands are parenthesized so precedence never depends on SQL's rules
        result.Append("(");
        Visit(b.Left);
        result.Append(") ").Append(infixOper).Append(" (");
        Visit(b.Right);
        result.Append(")");
        return b;
    }

    protected override Expression VisitConditional(ConditionalExpression c) {
        result.Append("CASE WHEN (");
        Visit(c.Test);
        result.Append(") THEN (");
        Visit(c.IfTrue);
        result.Append(") ELSE (");
        Visit(c.IfFalse);
        result.Append(") END");
        return c;
    }

    // True when the method's last parameter is a params array. Parameterless methods are never varargs
    // (the original indexing of the last parameter threw IndexOutOfRangeException for them).
    private bool IsVarargMethod(System.Reflection.MethodInfo m) {
        System.Reflection.ParameterInfo[] ps = m.GetParameters();
        return ps.Length > 0 && Attribute.GetCustomAttribute(ps[ps.Length - 1], typeof(ParamArrayAttribute)) != null;
    }

    // Returns the call's arguments with a trailing params array expanded into individual arguments, so that
    // F(a, b, c) and F(a, new[] { b, c }) produce the same SQL.
    private IList<Expression> FlattenArguments(MethodCallExpression m) {
        if (!IsVarargMethod(m.Method))
            return m.Arguments;
        Expression lastArg = m.Arguments[m.Arguments.Count - 1];
        IEnumerable<Expression> expanded;
        if (lastArg.NodeType == ExpressionType.NewArrayInit) {
            expanded = ((NewArrayExpression)lastArg).Expressions;
        }
        else if (lastArg.NodeType == ExpressionType.Constant && ((ConstantExpression)lastArg).Type.IsArray) {
            // The array was folded into a single constant by partial evaluation; split it back into elements.
            ConstantExpression c = (ConstantExpression)lastArg;
            List<Expression> elements = new List<Expression>();
            foreach (object o in (Array)c.Value)
                elements.Add(Expression.Constant(o, c.Type.GetElementType()));
            expanded = elements;
        }
        else {
            expanded = Enumerable.Repeat(lastArg, 1);
        }
        return m.Arguments.Take(m.Arguments.Count - 1).Concat(expanded).ToList();
    }

    protected override Expression VisitMethodCall(MethodCallExpression m) {
        if (m.Object != null)
            throw new NotSupportedException("Can only call static methods.");
        SqlFunctionAttribute a = (SqlFunctionAttribute)Attribute.GetCustomAttribute(m.Method, typeof(SqlFunctionAttribute));
        if (a == null)
            throw new NotSupportedException(string.Format("The method {0} is not decorated with a SqlFunctionAttribute.", m.Method.Name));

        // Prefix: function name (plus optional fixed first argument), the CASE keyword, or an opening parenthesis.
        switch (a.Type) {
            case SqlFunctionAttribute.FunctionType.Function:
                result.Append(a.Name);
                if (a.ParamList)
                    result.Append("(");
                if (!string.IsNullOrEmpty(a.MagicFirstArg))
                    result.Append(a.MagicFirstArg).Append(", ");
                break;
            case SqlFunctionAttribute.FunctionType.Case:
                result.Append("CASE");
                break;
            case SqlFunctionAttribute.FunctionType.IsNull:
            case SqlFunctionAttribute.FunctionType.IsNotNull:
            case SqlFunctionAttribute.FunctionType.NoOp:
            case SqlFunctionAttribute.FunctionType.Like:
            case SqlFunctionAttribute.FunctionType.NotLike:
            case SqlFunctionAttribute.FunctionType.In:
            case SqlFunctionAttribute.FunctionType.NotIn:
                result.Append("(");
                break;
            default:
                throw new NotSupportedException("Bad SqlFunctionAttribute attribute.");
        }

        IList<Expression> actualArguments = FlattenArguments(m);
        for (int i = 0, n = actualArguments.Count; i < n; i++) {
            // Separator written before argument i.
            switch (a.Type) {
                case SqlFunctionAttribute.FunctionType.Function:
                    if (i > 0) result.Append(", ");
                    break;
                case SqlFunctionAttribute.FunctionType.Case:
                    // Arguments alternate condition/value; with an odd count the last argument is the ELSE value.
                    if (i % 2 == 0)
                        result.Append(i < n - 1 ? " WHEN " : " ELSE ");
                    else
                        result.Append(" THEN ");
                    break;
                case SqlFunctionAttribute.FunctionType.Like:
                    if (i > 0) result.Append(") LIKE (");
                    break;
                case SqlFunctionAttribute.FunctionType.NotLike:
                    if (i > 0) result.Append(") NOT LIKE (");
                    break;
                case SqlFunctionAttribute.FunctionType.In:
                case SqlFunctionAttribute.FunctionType.NotIn:
                    // First argument is the tested value; the rest form the IN list.
                    switch (i) {
                        case 0: break;
                        case 1: result.Append(a.Type == SqlFunctionAttribute.FunctionType.In ? ") IN(" : ") NOT IN("); break;
                        default: result.Append(", "); break;
                    }
                    break;
            }
            this.Visit(actualArguments[i]);
        }

        // Suffix closing whatever the prefix opened.
        switch (a.Type) {
            case SqlFunctionAttribute.FunctionType.Function:
                if (a.ParamList)
                    result.Append(")");
                break;
            case SqlFunctionAttribute.FunctionType.Case:
                result.Append(" END");
                break;
            case SqlFunctionAttribute.FunctionType.IsNull:
                result.Append(") IS NULL");
                break;
            case SqlFunctionAttribute.FunctionType.IsNotNull:
                result.Append(") IS NOT NULL");
                break;
            case SqlFunctionAttribute.FunctionType.Like:
            case SqlFunctionAttribute.FunctionType.NotLike:
            case SqlFunctionAttribute.FunctionType.NoOp:
            case SqlFunctionAttribute.FunctionType.In:
            case SqlFunctionAttribute.FunctionType.NotIn:
                result.Append(")");
                break;
        }
        return m;
    }

    // Every constant becomes a new command parameter; null is sent as DBNull.
    protected override Expression VisitConstant(ConstantExpression c) {
        string paramName = "@p" + command.Parameters.Count;
        result.Append(paramName);
        command.Parameters.AddWithValue(paramName, c.Value ?? DBNull.Value);
        return c;
    }

    protected override Expression VisitMemberAccess(MemberExpression m) {
        // Validate before visiting, so that static members (m.Expression == null) and nested accesses get this
        // error instead of a NullReferenceException or partially written SQL.
        if (m.Expression == null || m.Expression.NodeType != ExpressionType.Parameter)
            throw new NotSupportedException("Member accesses are only allowed with the lambda parameter as the object.");
        Visit(m.Expression);
        ColumnNameAttribute a = (ColumnNameAttribute)Attribute.GetCustomAttribute(m.Member, typeof(ColumnNameAttribute));
        if (a == null)
            throw new NotSupportedException("Member accesses are only allowed for columns which are decorated with a ColumnNameAttribute.");
        result.Append("[").Append(a.Name).Append("]");
        return m;
    }

    #region Unsupported expression types
    protected override MemberBinding VisitBinding(MemberBinding binding) {
        throw new NotSupportedException("MemberBinding expressions are not supported.");
    }
    protected override ElementInit VisitElementInitializer(ElementInit initializer) {
        throw new NotSupportedException("Element initializer expressions are not supported.");
    }
    protected override Expression VisitTypeIs(TypeBinaryExpression b) {
        throw new NotSupportedException("Type is expressions are not supported.");
    }
    protected override MemberAssignment VisitMemberAssignment(MemberAssignment assignment) {
        throw new NotSupportedException("Member assignment expressions are not supported.");
    }
    protected override MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) {
        throw new NotSupportedException("Member member binding expressions are not supported.");
    }
    protected override MemberListBinding VisitMemberListBinding(MemberListBinding binding) {
        throw new NotSupportedException("Member list binding expressions are not supported.");
    }
    protected override Expression VisitLambda(LambdaExpression lambda) {
        throw new NotSupportedException("Lambda expressions are not supported (except at the top level).");
    }
    protected override NewExpression VisitNew(NewExpression nex) {
        throw new NotSupportedException("`new´ expressions are not supported.");
    }
    protected override Expression VisitMemberInit(MemberInitExpression init) {
        throw new NotSupportedException("Member initialization expressions are not supported.");
    }
    protected override Expression VisitListInit(ListInitExpression init) {
        throw new NotSupportedException("List initialization expressions are not supported.");
    }
    protected override Expression VisitNewArray(NewArrayExpression na) {
        throw new NotSupportedException("`new[]´ expressions are not supported (except at the top level for order by expressions).");
    }
    protected override Expression VisitInvocation(InvocationExpression iv) {
        throw new NotSupportedException("Invocations are not supported");
    }
    #endregion
}
#endif // !SQLPROXY_SQLCLR
/// <summary>
/// Base class for the result objects returned by generated proxy methods. It owns a <see cref="SqlCommand"/>
/// and the <see cref="SqlDataReader"/> obtained by executing it, and hands out the result sets strictly in
/// order. It also captures the values of the command's non-input parameters once all result sets have been read.
/// </summary>
public abstract class ResultSetEnumerator : IDisposable {
    // Value of currentResultSet once the reader has run out of result sets.
    private const int AfterLastResultSet = int.MaxValue;

    private SqlDataReader reader;
    private SqlCommand command;
    // Index of the result set the reader is positioned on; -1 before the first one has been requested.
    private int currentResultSet = -1;
    // Values of the non-input parameters, keyed by parameter name; null until ReadOutputParameters has run.
    private Dictionary<string, object> outputParameters;

    /// <summary>
    /// Executes <paramref name="command"/> immediately and takes ownership of it.
    /// </summary>
    /// <remarks>
    /// If execution throws, the command is disposed before the exception propagates, because the caller never
    /// receives an instance it could dispose.
    /// </remarks>
    protected ResultSetEnumerator(SqlCommand command) {
        this.command = command;
        try {
            this.reader = command.ExecuteReader();
        }
        catch (Exception) {
            command.Dispose();
            this.command = null; // clear the field (not just the parameter) so nothing refers to the disposed command
            throw;
        }
    }

    // Moves the reader forward to result set num, or marks it as past the end if there are fewer result sets.
    private void AdvanceToRS(int num) {
        if (currentResultSet == -1)
            currentResultSet = 0;
        while (currentResultSet < num) {
            if (!reader.NextResult()) {
                currentResultSet = AfterLastResultSet;
                return;
            }
            currentResultSet++;
        }
    }

    /// <summary>
    /// Lazily yields the rows of result set <paramref name="resultSetNumber"/> (0-based), converting each row
    /// with <paramref name="selector"/>.
    /// </summary>
    /// <remarks>
    /// Being an iterator, the checks below run on the first MoveNext. Result sets can only be consumed in
    /// increasing order and once each, and moving on to a later result set invalidates an enumeration of an
    /// earlier one. When <paramref name="isOnlyResultSet"/> is true, the output parameters are captured after
    /// the last row, and this object is disposed when the enumeration ends, even when it is abandoned early.
    /// </remarks>
    protected IEnumerable<T> GetResultSetEnumerator<T>(int resultSetNumber, Func<SqlDataReader, T> selector, bool isOnlyResultSet) {
        if (currentResultSet >= resultSetNumber)
            throw new InvalidOperationException("The result sets must be enumerated in order and only once per result set.");
        if (reader == null)
            throw new ObjectDisposedException(null);
        try {
            AdvanceToRS(resultSetNumber);
            while (reader.Read()) {
                yield return selector(reader);
                // The consumer may have advanced to another result set or disposed us while suspended.
                if (currentResultSet != resultSetNumber)
                    throw new InvalidOperationException("It is not possible to continue the enumeration of a result set after enumerating a later one.");
                if (reader == null)
                    throw new ObjectDisposedException(null);
            }
            if (isOnlyResultSet)
                ReadOutputParameters();
        }
        finally {
            if (isOnlyResultSet)
                Dispose();
        }
    }

    /// <summary>
    /// Skips any remaining result sets and records the values of all non-input parameters (DBNull becomes null).
    /// Only the first call does any work.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The values were not captured before this object was disposed.</exception>
    protected void ReadOutputParameters() {
        if (outputParameters != null)
            return;
        if (command == null)
            throw new ObjectDisposedException(null);
        // The reader may already be past the last result set (for instance when a result set beyond the last one
        // was requested). The parameters must still be captured in that case.
        if (currentResultSet != AfterLastResultSet)
            AdvanceToRS(AfterLastResultSet);
        Dictionary<string, object> values = new Dictionary<string, object>();
        foreach (SqlParameter p in command.Parameters) {
            if (p.Direction != System.Data.ParameterDirection.Input)
                values[p.ParameterName] = (p.Value is DBNull ? null : p.Value);
        }
        outputParameters = values;
    }

    /// <summary>
    /// Returns the value of the non-input parameter <paramref name="name"/>, reading the remaining result sets
    /// first if that has not happened yet.
    /// </summary>
    protected T GetOutputParameter<T>(string name) {
        if (outputParameters == null)
            ReadOutputParameters();
        return SqlFunctions.DbNullCast<T>(outputParameters[name]);
    }

    /// <summary>Disposes the reader and the command. Safe to call more than once.</summary>
    public void Dispose() {
        if (reader != null) {
            reader.Dispose();
            reader = null;
        }
        if (command != null) {
            command.Dispose();
            command = null;
        }
    }

    // Creates a text command for query. If configuring it fails, the command is disposed again.
    private static SqlCommand CreateTextCommand(SqlConnection connection, SqlTransaction transaction, string query, int commandTimeout) {
        SqlCommand cmd = connection.CreateCommand();
        try {
            cmd.CommandType = System.Data.CommandType.Text;
            cmd.CommandTimeout = commandTimeout;
            cmd.CommandText = query;
            cmd.Transaction = transaction;
        }
        catch (Exception) {
            cmd.Dispose();
            throw;
        }
        return cmd;
    }

#if !SQLPROXY_SQLCLR
    /// <summary>
    /// Runs <paramref name="query"/> extended with a WHERE clause translated from <paramref name="filter"/> and an
    /// ORDER BY clause translated from <paramref name="order"/> (either may be null), and returns its rows.
    /// </summary>
    public static ViewResultSet<T> ExecuteQuery<T>(SqlConnection connection, SqlTransaction transaction, string query, Func<SqlDataReader, T> selector, Expression<Func<T, bool>> filter, Expression<Func<T, object>> order, int commandTimeout) {
        SqlCommand cmd = CreateTextCommand(connection, transaction, query, commandTimeout);
        try {
            // Translation throws NotSupportedException for untranslatable expressions; don't leak the command then.
            if (filter != null) SqlBuilder.AddWhereClause(cmd, filter);
            if (order != null) SqlBuilder.AddOrderByClause(cmd, order);
        }
        catch (Exception) {
            cmd.Dispose();
            throw;
        }
        return new ViewResultSet<T>(cmd, selector);
    }
#else
    /// <summary>Runs <paramref name="query"/> as-is and returns its rows (no expression filtering under SQLCLR).</summary>
    public static ViewResultSet<T> ExecuteQuery<T>(SqlConnection connection, SqlTransaction transaction, string query, Func<SqlDataReader, T> selector, int commandTimeout) {
        return new ViewResultSet<T>(CreateTextCommand(connection, transaction, query, commandTimeout), selector);
    }
#endif

    /// <summary>
    /// Runs <paramref name="query"/> with the given parameters (may be null) and returns its rows.
    /// </summary>
    public static ViewResultSet<T> ExecuteQuery<T>(SqlConnection connection, SqlTransaction transaction, string query, IEnumerable<SqlParameter> parms, Func<SqlDataReader, T> selector, int commandTimeout) {
        SqlCommand cmd = CreateTextCommand(connection, transaction, query, commandTimeout);
        if (parms != null) {
            try {
                // Add throws if a parameter already belongs to another command's collection.
                foreach (SqlParameter p in parms)
                    cmd.Parameters.Add(p);
            }
            catch (Exception) {
                cmd.Dispose();
                throw;
            }
        }
        return new ViewResultSet<T>(cmd, selector);
    }
}
/// <summary>
/// Enumerable over the single result set produced by a view or ad-hoc query. The command is executed when
/// the object is constructed. The rows can be enumerated once, and the reader and command are disposed
/// when that enumeration ends.
/// </summary>
public class ViewResultSet<T> : ResultSetEnumerator, IEnumerable<T> {
    private Func<SqlDataReader, T> selector;

    public ViewResultSet(SqlCommand command, Func<SqlDataReader, T> selector)
        : base(command) {
        this.selector = selector;
    }

    /// <summary>The rows of result set 0, each converted with the selector.</summary>
    public IEnumerable<T> ResultSet {
        get {
            const int onlyResultSet = 0;
            return GetResultSetEnumerator(onlyResultSet, selector, true);
        }
    }

    public IEnumerator<T> GetEnumerator() {
        IEnumerable<T> rows = this.ResultSet;
        return rows.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator() {
        return this.GetEnumerator();
    }
}
}
<#@ template hostspecific="true" debug="true" #>
<#@ output extension=".cs" #>
<#@ assembly name="System.Data.dll" #>
<#@ assembly name="System.Xml.dll" #>
<#@ assembly name="System.Core.dll" #>
<#@ assembly name="System.Windows.Forms.dll" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Linq.Expressions" #>
<#@ import namespace="System.Data.SqlClient" #>
<#@ import namespace="System.Data" #>
<#@ import namespace="System.Data.SqlTypes" #>
<#@ import namespace="System.Collections.ObjectModel" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Text.RegularExpressions" #>
using System;
using System.Data;
using System.Data.SqlClient;
using System.Data.SqlTypes;
using System.Collections;
using System.Collections.Generic;
using SqlProxy;
namespace <#= namespaceName #>
{
<#
// Discovery phase. Opens the configured database and collects:
//   - the procedures, views and (optionally) base tables in the configured schemas;
//   - the configured named queries.
// It then builds the metadata objects the rest of the template generates code from. The row and enumerator
// types are emitted from here (GenerateResultSetRowType / GenerateEnumeratorType, defined elsewhere),
// inside the namespace body.
PushIndent("\t");
using (connection = new SqlConnection(connectionString)) {
	connection.Open();
	List<SchemaAndName> procNames = new List<SchemaAndName>(), viewNames = new List<SchemaAndName>();
	// IN(...) list shared by the catalog queries below. The schema names come from the template
	// configuration; single quotes are doubled anyway so a name containing one cannot break the SQL.
	// ToLowerInvariant keeps the list independent of the culture of the machine running the template.
	string schemaFilter = (schemaNames == null ? "" : string.Join(", ", schemaNames.Select(x => "'" + x.ToLowerInvariant().Replace("'", "''") + "'").ToArray()));
	if (includeViews) {
		using (SqlCommand cmd = connection.CreateCommand()) {
			cmd.CommandType = CommandType.Text;
			cmd.CommandText = "SELECT table_schema, table_name FROM INFORMATION_SCHEMA.VIEWS WHERE lower(table_schema) IN(" + schemaFilter + ")";
			using (var rdr = cmd.ExecuteReader()) {
				while (rdr.Read()) {
					viewNames.Add(new SchemaAndName((string)rdr[0], (string)rdr[1]));
				}
			}
		}
	}
	if (includeTables) {
		using (SqlCommand cmd = connection.CreateCommand()) {
			cmd.CommandType = CommandType.Text;
			// INFORMATION_SCHEMA.TABLES also lists views (table_type = 'VIEW'). Only base tables are wanted here.
			// Otherwise every view would be added a second time when includeViews is set, and views would be
			// generated even when includeViews is off.
			cmd.CommandText = "SELECT table_schema, table_name FROM INFORMATION_SCHEMA.TABLES WHERE table_type = 'BASE TABLE' AND lower(table_schema) IN(" + schemaFilter + ")";
			using (var rdr = cmd.ExecuteReader()) {
				while (rdr.Read()) {
					viewNames.Add(new SchemaAndName((string)rdr[0], (string)rdr[1]));
				}
			}
		}
	}
	if (includeProcedures) {
		using (SqlCommand cmd = connection.CreateCommand()) {
			cmd.CommandType = CommandType.Text;
			cmd.CommandText = "SELECT routine_schema, routine_name FROM INFORMATION_SCHEMA.ROUTINES WHERE routine_type = 'PROCEDURE' AND lower(routine_schema) IN(" + schemaFilter + ")";
			using (var rdr = cmd.ExecuteReader()) {
				while (rdr.Read()) {
					procNames.Add(new SchemaAndName((string)rdr[0], (string)rdr[1]));
				}
			}
		}
	}
	if (namedQueries != null) {
		// Named-query parameters are written inline as @{name: SqlDbType}. They are collected as Parameters and
		// replaced by a plain @name in the query text.
		Regex parameterRegex = new Regex(@"@\{\s*([a-zA-Z0-9_]+)\s*:\s*([a-zA-Z0-9_]+)\s*}");
		foreach (var kvp in namedQueries.OrderBy(kvp => kvp.Key)) {
			var parms = new List<Parameter>();
			foreach (Match m in parameterRegex.Matches(kvp.Value)) {
				SqlDbType dbType;
				try {
					dbType = (SqlDbType)Enum.Parse(typeof(SqlDbType), m.Groups[2].Value, true);
				}
				catch (ArgumentException) {
					System.Windows.Forms.MessageBox.Show("Query " + kvp.Key + ": The SqlDbType " + m.Groups[2].Value + " does not exist.");
					// Skip this query but keep generating the others.
					goto continueOuter;
				}
				parms.Add(new Parameter("@" + m.Groups[1].Value, dbType, GetMaxSize(dbType), ParameterDirection.Input));
			}
			string query = parameterRegex.Replace(kvp.Value, m => "@" + m.Groups[1].Value);
			parsedNamedQueries.Add(new NamedQuery(connection, kvp.Key, query, parms, usedNames));
			continueOuter:;
		}
	}
	foreach (var procName in procNames.OrderBy(x => x.Schema).ThenBy(x => x.Name)) {
		if ((procsToInclude == null || procsToInclude.Contains(procName)) &&
		    (procsToExclude == null || !procsToExclude.Contains(procName))) {
			procedures.Add(new Procedure(connection, procName, usedNames));
		}
	}
	foreach (var viewName in viewNames.OrderBy(x => x.Schema).ThenBy(x => x.Name)) {
		if ((viewsToInclude == null || viewsToInclude.Contains(viewName)) &&
		    (viewsToExclude == null || !viewsToExclude.Contains(viewName))) {
			views.Add(new View(connection, viewName, usedNames));
		}
	}
	// Row types for every result set; procedures with result sets also get an enumerator type that exposes
	// their non-input parameters and the return value.
	foreach (var proc in procedures) {
		foreach (var rs in proc.ResultSets) {
			GenerateResultSetRowType(rs, false);
		}
		if (proc.ResultSets.Count > 0)
			GenerateEnumeratorType(proc.ResultSets, proc.Parameters.Where(p => p.Direction != ParameterDirection.Input).Union(Enumerable.Repeat(Parameter.ReturnValue, 1)).ToList(), proc.ResultTypeName);
	}
	foreach (var view in views) {
		GenerateResultSetRowType(view.ResultSet, !generateSqlclr);
	}
	foreach (NamedQuery query in parsedNamedQueries) {
		GenerateResultSetRowType(query.ResultSet, false);
	}
}
#>
<#= accessibility #> partial interface I<#= databaseClassName #> : IDisposable {
void ExecuteTransacted(Func<I<#= databaseClassName #>, bool> func, bool retryIfDeadlockVictim);
void ExecuteTransacted(Func<I<#= databaseClassName #>, bool> func, bool retryIfDeadlockVictim, IsolationLevel isolationLevel);
void ExecuteTransacted(Action<I<#= databaseClassName #>> func, bool retryIfDeadlockVictim);
void ExecuteTransacted(Action<I<#= databaseClassName #>> func, bool retryIfDeadlockVictim, IsolationLevel isolationLevel);
IDbTransaction BeginTransaction(IsolationLevel isolationLevel);
IDbCommand CreateCommand();
IDbCommand CreateCommand(CommandType commandType, string commandText, IEnumerable<KeyValuePair<string, object>> parameters);
IDbConnection Connection { get; }
string SaveTransaction();
void RollbackToSavepoint(string savepointName);
<#
// Emits the generated proxy interface: the fixed members above plus one member per discovered procedure, view
// and named query. The Write*Signature helpers are defined elsewhere in the template. Each call is followed by
// ";" so that the emitted member has no body, as an interface member requires.
PushIndent("\t");
foreach (Procedure proc in procedures) {
WriteStoredprocSignature(proc);
WriteLine(";");
}
foreach (View view in views) {
WriteViewSignature(view);
WriteLine(";");
}
foreach (NamedQuery query in parsedNamedQueries) {
WriteNamedQuerySignature(query);
WriteLine(";");
}
PopIndent();
#>
}
<#= accessibility #> sealed partial class <#= databaseClassName #> : IDisposable, I<#= databaseClassName #>
{
private class TransactionClass : IDbTransaction, IDisposable {
private SqlTransaction t;
private <#= databaseClassName #> db;
public TransactionClass(<#= databaseClassName #> db) {
if (db.transaction != null)
throw new InvalidOperationException("There is already an open transaction");
this.t = db.connection.BeginTransaction();
this.db = db;
db.transaction = this;
}
public TransactionClass(<#= databaseClassName #> db, IsolationLevel il) {
if (db.transaction != null)
throw new InvalidOperationException("There is already an open transaction");
this.t = db.connection.BeginTransaction(il);
this.db = db;
db.transaction = this;
}
public void Commit() {
t.Commit();
db.transaction = null;
}
public IDbConnection Connection {
get { return t.Connection; }
}
public IsolationLevel IsolationLevel {
get { return t.IsolationLevel; }
}
public void Rollback() {
t.Rollback();
db.transaction = null;
}
public void Dispose() {
t.Dispose();
db.transaction = null;
}
public static implicit operator SqlTransaction(TransactionClass t) {
return (t != null ? t.t : null);
}
}
private TransactionClass transaction;
private SqlConnection connection;
private bool ownsConnection;
public <#= databaseClassName #>(string connectionString) : this(new SqlConnection(connectionString), true) {
this.connection.Open();
}
public <#= databaseClassName #>(SqlConnection connection, bool takeOwnership) {
if (connection == null)
throw new ArgumentNullException("connection");
this.CommandTimeout = <#= Convert.ToString(commandTimeout, System.Globalization.CultureInfo.InvariantCulture) #>;
this.ownsConnection = takeOwnership;
this.connection = connection;
}
public void ExecuteTransacted(Action<I<#= databaseClassName #>> func, bool retryIfDeadlockVictim) {
ExecuteTransacted(db => { func(db); return true; }, retryIfDeadlockVictim, IsolationLevel.Unspecified);
}
public void ExecuteTransacted(Action<I<#= databaseClassName #>> func, bool retryIfDeadlockVictim, IsolationLevel isolationLevel) {
ExecuteTransacted(db => { func(db); return true; }, retryIfDeadlockVictim, isolationLevel);
}
public void ExecuteTransacted(Func<I<#= databaseClassName #>, bool> func, bool retryIfDeadlockVictim) {
ExecuteTransacted(func, retryIfDeadlockVictim, IsolationLevel.Unspecified);
}
public void ExecuteTransacted(Func<I<#= databaseClassName #>, bool> func, bool retryIfDeadlockVictim, IsolationLevel isolationLevel) {
retry:
try {
using (var tran = BeginTransaction(isolationLevel)) {
if (func(this))
tran.Commit();
}
return;
}
catch (SqlException ex) {
if (retryIfDeadlockVictim && ex.Number == 1205)
goto retry;
throw;
}
}
public IDbTransaction BeginTransaction(IsolationLevel isolationLevel) {
return new TransactionClass(this, isolationLevel);
}
private IDbTransaction Transaction { get { return transaction; } }
public IDbConnection Connection { get { return connection; } }
IDbCommand I<#= databaseClassName #>.CreateCommand() {
return CreateCommand();
}
IDbCommand I<#= databaseClassName #>.CreateCommand(CommandType commandType, string commandText, IEnumerable<KeyValuePair<string, object>> parameters) {
return CreateCommand(commandType, commandText, parameters);
}
public SqlCommand CreateCommand() {
SqlCommand cmd = connection.CreateCommand();
cmd.Transaction = transaction;
return cmd;
}
public SqlCommand CreateCommand(CommandType commandType, string commandText, IEnumerable<KeyValuePair<string, object>> parameters) {
var cmd = CreateCommand();
cmd.CommandType = commandType;
cmd.CommandText = commandText;
if (parameters != null) {
foreach (var p in parameters)
cmd.Parameters.AddWithValue("@" + p.Key, p.Value ?? DBNull.Value);
}
return cmd;
}
public int CommandTimeout { get; set; }
public void Dispose() {
if (connection != null) {
if (ownsConnection)
connection.Dispose();
connection = null;
}
}
public string SaveTransaction() {
string name = Guid.NewGuid().ToString("N");
using (var cmd = CreateCommand(CommandType.Text, "SAVE TRANSACTION @name", new[] { new KeyValuePair<string, object>("name", name) })) {
cmd.ExecuteNonQuery();
}
return name;
}
public void RollbackToSavepoint(string savepointName) {
using (var cmd = CreateCommand(CommandType.Text, "IF xact_state() <> -1 ROLLBACK TRANSACTION @name", new[] { new KeyValuePair<string, object>("name", savepointName) })) {
cmd.ExecuteNonQuery();
}
}
<#
// Emits the proxy method implementations for every discovered procedure, view and named query, one indent
// level deeper. The Generate*Proxy helpers are defined elsewhere in the template.
PushIndent("\t");
if (generateSqlclr) {
// In SQLCLR mode the generated reader lambdas assign this field (see ResultSet.CreatorLambda).
// NOTE(review): presumably this forces each lambda to capture "this", so that the compiler does not cache it in
// a static field, which SQLCLR SAFE assemblies reject. Warning 414 (field assigned but never used) is
// suppressed for it. Confirm.
WriteLine("#pragma warning disable 414");
WriteLine("private bool preventStaticDelegates;");
WriteLine("#pragma warning restore 414");
WriteLine("");
}
foreach (Procedure proc in procedures) {
GenerateStoredprocProxy(proc);
}
foreach (View view in views) {
GenerateViewProxy(view);
}
foreach (NamedQuery query in parsedNamedQueries) {
GenerateNamedQueryProxy(query);
}
PopIndent();
// Also undo the indent pushed at the start of the namespace body, before the discovery code ran.
PopIndent();
#>
}
}
<#+
/// <summary>
/// Returns the parameter size the template uses for <paramref name="dbType"/> when no explicit size is
/// known: named-query parameters, and Parameter instances constructed with a size of -1.
/// </summary>
/// <exception cref="ArgumentException">
/// The type has no size mapping (for example Structured or Udt). ParamName is "dbType".
/// </exception>
// NOTE(review): several entries look questionable. Float => 4 although float is 8 bytes, Real => 0x7fffffff,
// and Binary => int.MaxValue while Char => 8000. SqlParameter.Size is ignored for fixed-size numeric/date types
// but is used for binary and character types. Confirm these values are intended before relying on them.
static int GetMaxSize(SqlDbType dbType) {
    switch (dbType) {
        case SqlDbType.BigInt: return 8;
        case SqlDbType.Binary: return int.MaxValue;
        case SqlDbType.Bit: return 1;
        case SqlDbType.Char: return 8000;
        case SqlDbType.Date: return 0x7fffffff;
        case SqlDbType.DateTime: return 0x7fffffff;
        case SqlDbType.DateTime2: return 0x7fffffff;
        case SqlDbType.DateTimeOffset: return 0x7fffffff;
        case SqlDbType.Decimal: return 0x7fffffff;
        case SqlDbType.Float: return 4;
        case SqlDbType.Image: return 0x7fffffff;
        case SqlDbType.Int: return 4;
        case SqlDbType.Money: return 0x7fffffff;
        case SqlDbType.NChar: return 4000;
        case SqlDbType.NText: return 0x3fffffff;
        case SqlDbType.NVarChar: return 0x3fffffff;
        case SqlDbType.Real: return 0x7fffffff;
        case SqlDbType.SmallDateTime: return 0x7fffffff;
        case SqlDbType.SmallInt: return 2;
        case SqlDbType.SmallMoney: return 0x7fffffff;
        case SqlDbType.Text: return 0x7fffffff;
        case SqlDbType.Time: return 0x7fffffff;
        case SqlDbType.Timestamp: return 0x7fffffff;
        case SqlDbType.TinyInt: return 1;
        case SqlDbType.UniqueIdentifier: return 0x7fffffff;
        case SqlDbType.VarBinary: return 0x7fffffff;
        case SqlDbType.VarChar: return 0x7fffffff;
        case SqlDbType.Variant: return 0x7fffffff;
        case SqlDbType.Xml: return 0;
        // SqlDbType.Structured and SqlDbType.Udt are not supported.
        default:
            // The single-string ArgumentException constructor takes a message, not a parameter name; pass both
            // so the error says which type failed.
            throw new ArgumentException("SqlDbType " + dbType + " is not supported.", "dbType");
    }
}
/// <summary>
/// Executes SET FMTONLY ON or OFF on <paramref name="con"/>, optionally inside <paramref name="tran"/>.
/// The template wraps metadata discovery in FMTONLY ON so the server returns result-set shapes without
/// producing rows.
/// </summary>
static void SetFmtonly(SqlConnection con, SqlTransaction tran, bool value) {
    string statement = value ? "SET FMTONLY ON" : "SET FMTONLY OFF";
    using (SqlCommand cmd = con.CreateCommand()) {
        cmd.CommandType = CommandType.Text;
        cmd.CommandText = statement;
        cmd.Transaction = tran;
        cmd.ExecuteNonQuery();
    }
}
/// <summary>
/// Returns <paramref name="desiredName"/> if it is unused, otherwise the first of desiredName1,
/// desiredName2, ... that is unused. The chosen name is added to <paramref name="usedNames"/>.
/// </summary>
static string FindUniqueName(string desiredName, HashSet<string> usedNames) {
    string candidate = desiredName;
    for (int suffix = 1; usedNames.Contains(candidate); suffix++) {
        candidate = desiredName + suffix.ToString(System.Globalization.CultureInfo.InvariantCulture);
    }
    usedNames.Add(candidate);
    return candidate;
}
/// <summary>
/// Immutable schema-qualified database object name with value equality (ordinal, case-sensitive).
/// Used as the element type of the include/exclude lists in the template configuration.
/// </summary>
sealed class SchemaAndName : IEquatable<SchemaAndName> {
    public string Schema { get; private set; }
    public string Name { get; private set; }

    public SchemaAndName(string schema, string name) {
        this.Schema = schema;
        this.Name = name;
    }

    public bool Equals(SchemaAndName other) {
        return (object)other != null && this.Schema == other.Schema && this.Name == other.Name;
    }

    public override bool Equals(object obj) {
        return Equals(obj as SchemaAndName);
    }

    public override int GetHashCode() {
        // Null-safe, and order-sensitive so that (a, b) and (b, a) do not always collide as they did with a plain XOR.
        unchecked {
            int schemaHash = (Schema == null ? 0 : Schema.GetHashCode());
            int nameHash = (Name == null ? 0 : Name.GetHashCode());
            return (schemaHash * 397) ^ nameHash;
        }
    }

    /// <summary>
    /// Returns the bracket-quoted form [Schema].[Name]. A ']' inside either part is doubled, as T-SQL's
    /// QUOTENAME does, so the result is a valid identifier.
    /// </summary>
    public override string ToString() {
        return "[" + EscapeIdentifierPart(Schema) + "].[" + EscapeIdentifierPart(Name) + "]";
    }

    private static string EscapeIdentifierPart(string part) {
        return (part == null ? "" : part.Replace("]", "]]"));
    }
}
/// <summary>
/// Describes one SQL parameter (of a procedure or a named query): its database name, type, size and
/// direction, plus the C# names and type name used for it in generated code.
/// </summary>
class Parameter {
    // Shared descriptor for a procedure's integer return value.
    private static Parameter returnValue = new Parameter("@ReturnValue", SqlDbType.Int, 4, ParameterDirection.ReturnValue);
    public static Parameter ReturnValue {
        get { return returnValue; }
    }

    public string DbName { get; private set; }
    public SqlDbType DbType { get; private set; }
    public int DbSize { get; private set; }
    public ParameterDirection Direction { get; private set; }
    public string CSharpNameU { get; private set; }
    public string CSharpNameL { get; private set; }
    public string CSharpTypeName { get; private set; }

    /// <param name="dbName">Database parameter name including its leading '@'.</param>
    /// <param name="dbSize">Explicit size, or -1 to use the default from GetMaxSize.</param>
    /// <exception cref="ArgumentException">No C# type is mapped for <paramref name="dbType"/>.</exception>
    public Parameter(string dbName, SqlDbType dbType, int dbSize, ParameterDirection direction) {
        DbName = dbName;
        DbType = dbType;
        DbSize = (dbSize == -1 ? GetMaxSize(dbType) : dbSize);
        Direction = direction;

        // The C# identifiers are derived from the name without its leading '@'.
        string bareName = dbName.Substring(1);
        CSharpNameU = MakeCSharpName(null, bareName, true);
        CSharpNameL = MakeCSharpName(null, bareName, false);

        Type clrType;
        if (!dbTypeMap.TryGetValue(dbType, out clrType))
            throw new ArgumentException(string.Format("Unknown SqlDbType {0} used by field {1}.", dbType, dbName));
        // Value types become nullable, except for the return value.
        bool makeNullable = clrType.IsValueType && Direction != ParameterDirection.ReturnValue;
        CSharpTypeName = clrType.ToString() + (makeNullable ? "?" : "");
    }
}
/// <summary>
/// One column of a result set: its database name and CLR type, plus the C# property name and the type
/// name used for it in generated code. Value types get a '?' suffix when the column is nullable.
/// </summary>
class ResultSetColumn {
    public string DbName { get; private set; }
    public bool Nullable { get; private set; }
    public string CSharpName { get; private set; }
    public Type CSharpType { get; private set; }
    public string CSharpTypeName { get; private set; }

    public ResultSetColumn(string dbName, Type csharpType, bool nullable) {
        DbName = dbName;
        Nullable = nullable;
        CSharpName = MakeCSharpName(null, dbName, true);
        CSharpType = csharpType;
        string typeName = csharpType.ToString();
        if (csharpType.IsValueType && Nullable)
            typeName += "?";
        CSharpTypeName = typeName;
    }
}
/// <summary>
/// Column layout of one result set together with the (unique) name of the row type generated for it.
/// </summary>
class ResultSet {
    public string RowTypeName { get; private set; }
    public ReadOnlyCollection<ResultSetColumn> Columns { get; private set; }

    /// <summary>
    /// Builds the layout from a reader schema table (one row per column). Columns without a name are called
    /// Unnamed1, Unnamed2, ... in order of appearance.
    /// </summary>
    /// <param name="everythingNullable">Treat every column as nullable regardless of its AllowDbNull value.</param>
    public ResultSet(string rowTypeName, DataTable schemaTable, HashSet<string> usedNames, bool everythingNullable) {
        RowTypeName = FindUniqueName(rowTypeName, usedNames);
        List<ResultSetColumn> collected = new List<ResultSetColumn>();
        int nextUnnamed = 1;
        foreach (DataRow schemaRow in schemaTable.Rows) {
            string columnName = (string)schemaRow["ColumnName"];
            if (string.IsNullOrEmpty(columnName)) {
                columnName = "Unnamed" + nextUnnamed;
                nextUnnamed++;
            }
            Type dataType = (Type)schemaRow["DataType"];
            bool nullable = everythingNullable || (bool)schemaRow["AllowDbNull"];
            collected.Add(new ResultSetColumn(columnName, dataType, nullable));
        }
        Columns = collected.AsReadOnly();
    }

    public ResultSet(string rowTypeName, ReadOnlyCollection<ResultSetColumn> columns, HashSet<string> usedNames) {
        RowTypeName = FindUniqueName(rowTypeName, usedNames);
        Columns = columns;
    }

    /// <summary>
    /// C# source of a lambda that builds a RowTypeName instance from the current reader row, passing each
    /// column (read by ordinal) through SqlFunctions.DbNullCast. In SQLCLR mode the lambda also assigns
    /// preventStaticDelegates.
    /// </summary>
    public string CreatorLambda {
        get {
            IEnumerable<string> arguments = Columns.Select((column, ordinal) => "SqlFunctions.DbNullCast<" + column.CSharpTypeName + ">(reader[" + ordinal + "])");
            string construction = "new " + RowTypeName + "(" + string.Join(", ", arguments.ToArray()) + ")";
            if (!generateSqlclr)
                return "reader => " + construction;
            return "reader => { preventStaticDelegates = true; return " + construction + "; }";
        }
    }
}
/// <summary>
/// Describes a database view: its C# name and the shape of its rows, discovered by running SELECT * against
/// the view with SetFmtonly switched on.
/// </summary>
class View {
	public SchemaAndName DbName { get; private set; }
	public string CSharpName { get; private set; }
	public string RowTypeName { get; private set; }
	public ResultSet ResultSet { get; private set; }

	/// <param name="connection">Open connection used to read the view's result metadata.</param>
	/// <param name="dbName">Schema and name of the view.</param>
	/// <param name="usedNames">Type names already taken in the generated code; the row type name is made unique against it.</param>
	public View(SqlConnection connection, SchemaAndName dbName, HashSet<string> usedNames) {
		this.DbName = dbName;
		this.CSharpName = MakeCSharpName(dbName.Schema, dbName.Name, true);
		SetFmtonly(connection, null, true);
		try {
			using (SqlCommand cmd = connection.CreateCommand()) {
				cmd.CommandType = CommandType.Text;
				cmd.CommandText = "SELECT * FROM [" + dbName.Schema + "].[" + dbName.Name + "]";
				using (SqlDataReader rdr = cmd.ExecuteReader()) {
					using (DataTable schemaTable = rdr.GetSchemaTable()) {
						ResultSet = new ResultSet(this.CSharpName + "Row", schemaTable, usedNames, false);
						this.RowTypeName = ResultSet.RowTypeName;
					}
				}
			}
		}
		finally {
			// Always switch the mode back off, even when the query fails; otherwise the connection would stay
			// in that mode for every later command (Procedure pairs the calls the same way).
			SetFmtonly(connection, null, false);
		}
	}
}
/// <summary>
/// A named ad-hoc query: its text, its parameters and the shape of the rows it returns. The shape is
/// discovered by executing the query once with SetFmtonly switched on and every parameter set to NULL.
/// </summary>
class NamedQuery {
	public string Name { get; private set; }
	public string QueryText { get; private set; }
	public string RowTypeName { get; private set; }
	public ReadOnlyCollection<Parameter> Parameters { get; private set; }
	public ResultSet ResultSet { get; private set; }

	/// <param name="connection">Open connection used to read the query's result metadata.</param>
	/// <param name="name">C# name of the query (also the base of its row type name).</param>
	/// <param name="queryText">SQL text that references the parameters by their DbName.</param>
	/// <param name="parameters">Parameters of the query; copied into <see cref="Parameters"/>.</param>
	/// <param name="usedNames">Type names already taken in the generated code.</param>
	public NamedQuery(SqlConnection connection, string name, string queryText, IList<Parameter> parameters, HashSet<string> usedNames) {
		this.Name = name;
		this.QueryText = queryText;
		this.Parameters = new List<Parameter>(parameters).AsReadOnly();
		SetFmtonly(connection, null, true);
		try {
			using (SqlCommand cmd = connection.CreateCommand()) {
				cmd.CommandType = CommandType.Text;
				cmd.CommandText = queryText;
				foreach (Parameter param in parameters) {
					// Declare each parameter with its real SQL type and size (as Procedure does). AddWithValue
					// with a DBNull value would leave SqlClient to infer a type that does not reflect the
					// declared one, which then leaks into the metadata of columns computed from the parameter.
					var sqlParam = cmd.Parameters.Add(param.DbName, param.DbType);
					if (param.DbSize != 0)
						sqlParam.Size = param.DbSize;
					sqlParam.Value = (param.DbType == SqlDbType.Xml ? (object)SqlXml.Null : DBNull.Value);
				}
				using (SqlDataReader rdr = cmd.ExecuteReader()) {
					using (DataTable schemaTable = rdr.GetSchemaTable()) {
						// Every named-query column is generated as nullable.
						ResultSet = new ResultSet(this.Name + "Row", schemaTable, usedNames, true);
						this.RowTypeName = ResultSet.RowTypeName;
					}
				}
			}
		}
		finally {
			// Always switch the mode back off, even when the query fails, so later commands are unaffected.
			SetFmtonly(connection, null, false);
		}
	}
}
/// <summary>
/// Describes a stored procedure for proxy generation: its C# names, its parameters (read from
/// sys.parameters) and the result sets it returns (discovered by executing it inside a transaction that
/// is rolled back).
/// </summary>
class Procedure {
public SchemaAndName DbName { get; private set; }
public string CSharpName { get; private set; }
// Name of the generated result type; made unique against usedNames.
public string ResultTypeName { get; private set; }
public ReadOnlyCollection<Parameter> Parameters { get; private set; }
public ReadOnlyCollection<ResultSet> ResultSets { get; private set; }
/// <summary>Inspects the procedure <paramref name="dbName"/> over <paramref name="connection"/>.</summary>
/// <param name="connection">Open connection; a transaction is started on it and always rolled back.</param>
/// <param name="dbName">Schema and name of the procedure.</param>
/// <param name="usedNames">Type names already taken in the generated code.</param>
/// <remarks>
/// Any exception is caught and reported in a message box; Parameters and ResultSets are then left null.
/// NOTE(review): code that reads proc.ResultSets.Count (e.g. WriteStoredprocSignature) would throw a
/// NullReferenceException for such a procedure unless the caller filters it out — confirm.
/// </remarks>
public Procedure(SqlConnection connection, SchemaAndName dbName, HashSet<string> usedNames) {
this.DbName = dbName;
this.CSharpName = MakeCSharpName(dbName.Schema, dbName.Name, true);
this.ResultTypeName = FindUniqueName(this.CSharpName + "Result", usedNames);
try {
using (SqlCommand cmd = connection.CreateCommand()) {
List<Parameter> parameters = new List<Parameter>();
// Optional sample argument values (by position) configured in customExtractions. When present the
// procedure is really executed with them (inside the rolled-back transaction) instead of in FMTONLY mode.
object[] customParams = (customExtractions != null && customExtractions.ContainsKey(dbName)) ? customExtractions[dbName] : null;
// Step 1: read declared parameters in declaration order, resolving user-defined types to their system type.
// NOTE(review): sys.parameters.max_length is reported in bytes (-1 for MAX), so nchar/nvarchar sizes are
// twice the declared character length here — confirm whether Parameter expects characters.
cmd.CommandText = "SELECT p.name, st.name, p.max_length, p.is_output FROM sys.parameters p JOIN sys.types t ON t.user_type_id = p.user_type_id JOIN sys.types st ON st.user_type_id = t.system_type_id WHERE p.object_id = object_id(@name) ORDER BY p.parameter_id";
cmd.CommandType = CommandType.Text;
// NOTE(review): ']' inside the schema or name is not doubled, so object_id would not resolve such names.
cmd.Parameters.AddWithValue("@name", "[" + dbName.Schema + "].[" + dbName.Name + "]");
using (var rdr = cmd.ExecuteReader()) {
while (rdr.Read()) {
SqlDbType dbType;
if (!dbTypeNameMap.TryGetValue((string)rdr[1], out dbType)) {
throw new Exception("Unknown database type " + (string)rdr[1]);
}
// OUTPUT parameters are treated as InputOutput.
parameters.Add(new Parameter((string)rdr[0], dbType, (short)rdr[2], (bool)rdr[3] ? ParameterDirection.InputOutput : ParameterDirection.Input));
}
}
// Step 2: reuse the command to call the procedure itself, binding each parameter to its custom value
// (if one was supplied for that position) or to a typed NULL (SqlXml.Null for xml, DBNull otherwise).
cmd.Parameters.Clear();
cmd.CommandText = "[" + dbName.Schema + "].[" + dbName.Name + "]";
cmd.CommandType = CommandType.StoredProcedure;
for (int i = 0, n = parameters.Count; i < n; i++) {
object value = (customParams != null && i < customParams.Length ? customParams[i] : null);
if (value == null)
value = (parameters[i].DbType == SqlDbType.Xml ? (object)SqlXml.Null : DBNull.Value);
var p = cmd.Parameters.Add(parameters[i].DbName, parameters[i].DbType);
if (parameters[i].DbSize != 0)
p.Size = parameters[i].DbSize;
p.Value = value;
p.Direction = parameters[i].Direction;
}
// Step 3: execute inside a transaction that is rolled back explicitly on success, and by disposal
// otherwise, so executing the procedure leaves no changes behind.
using (SqlTransaction tran = connection.BeginTransaction()) {
cmd.Transaction = tran;
if (customParams == null)
SetFmtonly(connection, tran, true);
List<ResultSet> resultSets = new List<ResultSet>();
try {
using(SqlDataReader rdr = cmd.ExecuteReader()) {
if (rdr != null) {
// i counts only results that have a schema table; results without one are skipped.
int i = 0;
do {
using (DataTable schemaTable = rdr.GetSchemaTable()) {
if (schemaTable != null) {
i++;
resultSets.Add(new ResultSet(CSharpName + "ResultSet" + i + "Row", schemaTable, usedNames, false));
}
}
} while (rdr.NextResult());
// With exactly one result set, release the numbered "ResultSet1Row" name and use the
// unnumbered "ResultSetRow" name instead.
if (i == 1) {
usedNames.Remove(resultSets[0].RowTypeName);
resultSets[0] = new ResultSet(CSharpName + "ResultSetRow", resultSets[0].Columns, usedNames);
}
}
}
}
finally {
if (customParams == null)
SetFmtonly(connection, tran, false);
}
Parameters = parameters.AsReadOnly();
ResultSets = resultSets.AsReadOnly();
tran.Rollback();
}
}
}
catch (Exception e) {
// Best effort: report and continue with the next object instead of aborting the whole template run.
System.Windows.Forms.MessageBox.Show(dbName + ": " + e.Message);
}
}
}
// Template-level state. NOTE(review): presumably the connection used to read metadata and the objects
// collected for generation; they are populated outside this section.
private SqlConnection connection;
private List<Procedure> procedures = new List<Procedure>();
private List<View> views = new List<View>();
private List<NamedQuery> parsedNamedQueries = new List<NamedQuery>();
// Type names already taken in the generated code (the collection FindUniqueName checks against).
private HashSet<string> usedNames = new HashSet<string>();
/// <summary>
/// Maps SQL Server system type names (sys.types.name, as read by Procedure) to the SqlDbType used to declare
/// parameters of that type. Lookups are case-insensitive; an unlisted type makes Procedure reject the procedure.
/// </summary>
/// <remarks>
/// "time" and "datetimeoffset" are deliberately not listed. Add them only together with dbTypeMap entries that
/// map SqlDbType.Time to TimeSpan and SqlDbType.DateTimeOffset to DateTimeOffset.
/// </remarks>
static Dictionary<string, SqlDbType> dbTypeNameMap = new Dictionary<string, SqlDbType>(StringComparer.InvariantCultureIgnoreCase) {
	{ "bigint", SqlDbType.BigInt },
	{ "binary", SqlDbType.Binary },
	{ "bit", SqlDbType.Bit },
	{ "char", SqlDbType.Char },
	{ "date", SqlDbType.Date },
	{ "datetime", SqlDbType.DateTime },
	{ "datetime2", SqlDbType.DateTime2 },
	{ "decimal", SqlDbType.Decimal },
	{ "float", SqlDbType.Float },
	{ "image", SqlDbType.Image },
	{ "int", SqlDbType.Int },
	// Declared with their exact SQL types (fixed scale of 4), like smallmoney below.
	{ "money", SqlDbType.Money },
	{ "nchar", SqlDbType.NChar },
	{ "ntext", SqlDbType.NText },
	// numeric is a synonym of decimal; mapping it to Float would expose it as double and lose precision.
	{ "numeric", SqlDbType.Decimal },
	{ "nvarchar", SqlDbType.NVarChar },
	{ "real", SqlDbType.Real },
	{ "smalldatetime", SqlDbType.SmallDateTime },
	{ "smallint", SqlDbType.SmallInt },
	{ "smallmoney", SqlDbType.SmallMoney },
	{ "sql_variant", SqlDbType.Variant },
	{ "sysname", SqlDbType.NVarChar },
	{ "text", SqlDbType.Text },
	{ "timestamp", SqlDbType.Timestamp },
	{ "tinyint", SqlDbType.TinyInt },
	{ "uniqueidentifier", SqlDbType.UniqueIdentifier },
	{ "varbinary", SqlDbType.VarBinary },
	{ "varchar", SqlDbType.VarChar },
	{ "xml", SqlDbType.Xml },
};
/// <summary>
/// Maps each supported SqlDbType to the CLR type used for parameters of that type in generated code.
/// Parameter appends '?' to value types, except for return values.
/// </summary>
/// <remarks>
/// Types without an entry (for example Structured/table-valued and Udt) are rejected by the Parameter
/// constructor with an ArgumentException.
/// </remarks>
static Dictionary<SqlDbType, Type> dbTypeMap = new Dictionary<SqlDbType, Type>() {
	{ SqlDbType.BigInt, typeof(long) },
	{ SqlDbType.Binary, typeof(byte[]) },
	{ SqlDbType.Bit, typeof(bool) },
	{ SqlDbType.Char, typeof(string) },
	{ SqlDbType.Date, typeof(DateTime) },
	{ SqlDbType.DateTime, typeof(DateTime) },
	{ SqlDbType.DateTime2, typeof(DateTime) },
	// datetimeoffset values carry an offset and are exchanged by SqlClient as System.DateTimeOffset.
	{ SqlDbType.DateTimeOffset, typeof(DateTimeOffset) },
	{ SqlDbType.Decimal, typeof(decimal) },
	{ SqlDbType.Float, typeof(double) },
	{ SqlDbType.Image, typeof(byte[]) },
	{ SqlDbType.Int, typeof(int) },
	{ SqlDbType.Money, typeof(decimal) },
	{ SqlDbType.NChar, typeof(string) },
	{ SqlDbType.NText, typeof(string) },
	{ SqlDbType.NVarChar, typeof(string) },
	{ SqlDbType.Real, typeof(float) },
	{ SqlDbType.SmallDateTime, typeof(DateTime) },
	{ SqlDbType.SmallInt, typeof(short) },
	{ SqlDbType.SmallMoney, typeof(decimal) },
	{ SqlDbType.Text, typeof(string) },
	// time-of-day values are exchanged by SqlClient as TimeSpan, not DateTime.
	{ SqlDbType.Time, typeof(TimeSpan) },
	{ SqlDbType.Timestamp, typeof(byte[]) },
	{ SqlDbType.TinyInt, typeof(byte) },
	{ SqlDbType.UniqueIdentifier, typeof(System.Guid) },
	{ SqlDbType.VarBinary, typeof(byte[]) },
	{ SqlDbType.VarChar, typeof(string) },
	{ SqlDbType.Variant, typeof(object) },
	{ SqlDbType.Xml, typeof(SqlXml) },
};
/// <summary>
/// Converts a SQL name into a C# identifier.
/// </summary>
/// <remarks>
/// <list type="bullet">
/// <item>Only ASCII letters and digits are kept.</item>
/// <item>The first letter is upper- or lower-cased according to <paramref name="firstUppercase"/>.</item>
/// <item>A later letter that follows a digit, '_', '.' or a separator character is upper-cased; other letters keep their case.</item>
/// <item>Digits before the first letter are dropped, because an identifier cannot start with one.</item>
/// <item>A result that is a reserved C# keyword (possible only in the lower-case form, e.g. "class") is
/// returned with a verbatim '@' prefix so that generated code still compiles.</item>
/// </list>
/// </remarks>
/// <param name="sqlSchema">Schema of the object, or null. When it differs (case-insensitively) from
/// defaultSchemaName it is prepended as "schema.name" before conversion.</param>
/// <param name="sqlName">The SQL name to convert.</param>
/// <param name="firstUppercase">True for PascalCase (types, properties), false for camelCase (parameters).</param>
static string MakeCSharpName(string sqlSchema, string sqlName, bool firstUppercase) {
	// Reserved (non-contextual) C# keywords, space-delimited so whole words can be matched.
	const string reservedWords = " abstract as base bool break byte case catch char checked class const continue"
		+ " decimal default delegate do double else enum event explicit extern false finally fixed float for"
		+ " foreach goto if implicit in int interface internal is lock long namespace new null object operator"
		+ " out override params private protected public readonly ref return sbyte sealed short sizeof"
		+ " stackalloc static string struct switch this throw true try typeof uint ulong unchecked unsafe"
		+ " ushort using virtual void volatile while ";
	if (sqlSchema != null && !sqlSchema.Equals(defaultSchemaName, StringComparison.InvariantCultureIgnoreCase))
		sqlName = sqlSchema + "." + sqlName;
	StringBuilder sb = new StringBuilder();
	bool uppercase = firstUppercase, lowercase = !firstUppercase;
	foreach (char ch in sqlName) {
		if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) {
			// Invariant casing: culture-sensitive ToUpper/ToLower would map 'i'/'I' to the dotted or
			// dotless variants under a Turkish culture, making generated names depend on the machine.
			sb.Append(lowercase ? char.ToLowerInvariant(ch) : (uppercase ? char.ToUpperInvariant(ch) : ch));
			uppercase = false;
			lowercase = false;
		}
		else if (ch >= '0' && ch <= '9') {
			if (sb.Length > 0)
				sb.Append(ch);
			uppercase = true;
		}
		else if (char.IsSeparator(ch) || ch == '_' || ch == '.') {
			uppercase = true;
		}
	}
	string result = sb.ToString();
	if (result.Length > 0 && reservedWords.IndexOf(" " + result + " ", StringComparison.Ordinal) >= 0)
		return "@" + result;
	return result;
}
/// <summary>
/// Emits an immutable row class for <paramref name="rs"/>. It has one auto-property with a private setter
/// per column, and a constructor that takes every column value in column order.
/// </summary>
/// <param name="rs">The result set whose row type is generated.</param>
/// <param name="includeColumnNameAttributes">When true, each property is preceded by a [ColumnName("…")]
/// attribute holding the database column name.</param>
void GenerateResultSetRowType(ResultSet rs, bool includeColumnNameAttributes) {
	string typeName = rs.RowTypeName;
	int columnCount = rs.Columns.Count;

	WriteLine(accessibility + " sealed class " + typeName + " {");
	PushIndent("\t");

	for (int i = 0; i < columnCount; i++) {
		ResultSetColumn column = rs.Columns[i];
		if (includeColumnNameAttributes) {
			WriteLine("[ColumnName(\"" + column.DbName + "\")]");
		}
		WriteLine("public " + column.CSharpTypeName + " " + column.CSharpName + " { get; private set; }");
	}
	WriteLine("");

	// Constructor arguments use the camelCase form of the property names.
	string[] argNames = new string[columnCount];
	string[] argDeclarations = new string[columnCount];
	for (int i = 0; i < columnCount; i++) {
		argNames[i] = MakeCSharpName(null, rs.Columns[i].CSharpName, false);
		argDeclarations[i] = rs.Columns[i].CSharpTypeName + " " + argNames[i];
	}
	Write("public " + typeName + "(");
	Write(string.Join(", ", argDeclarations));
	WriteLine(") {");
	PushIndent("\t");
	for (int i = 0; i < columnCount; i++) {
		WriteLine("this." + rs.Columns[i].CSharpName + " = " + argNames[i] + ";");
	}
	PopIndent();
	WriteLine("}");

	PopIndent();
	WriteLine("}");
	WriteLine("");
}
/// <summary>
/// Emits the class returned by a proxy for a procedure with result sets. The class derives from
/// ResultSetEnumerator.
/// </summary>
/// <param name="resultSets">
/// Result sets in order. With exactly one result set, the class also implements IEnumerable of its row type
/// and exposes the rows as "ResultSet". Otherwise the properties are numbered ResultSet1, ResultSet2, and so on.
/// </param>
/// <param name="outParams">Parameters exposed as read-only properties fetched with GetOutputParameter.</param>
/// <param name="typeName">Name of the generated class.</param>
void GenerateEnumeratorType(IList<ResultSet> resultSets, IList<Parameter> outParams, string typeName) {
	bool single = resultSets.Count == 1;

	string header = accessibility + " sealed class " + typeName + " : ResultSetEnumerator";
	Write(header);
	if (single) {
		Write(", IEnumerable<" + resultSets[0].RowTypeName + ">");
	}
	WriteLine(" {");
	PushIndent("\t");

	if (generateSqlclr) {
		// The SQLCLR creator lambdas assign this field (ResultSet.CreatorLambda); CS0414 ("assigned but its
		// value is never used") is suppressed around the declaration.
		WriteLine("#pragma warning disable 414");
		WriteLine("private bool preventStaticDelegates;");
		WriteLine("#pragma warning restore 414");
		WriteLine("");
	}

	WriteLine("public " + typeName + "(SqlCommand command) : base(command) {");
	WriteLine("}");

	if (single) {
		string rowType = resultSets[0].RowTypeName;
		WriteLine("");
		WriteLine("public IEnumerator<" + rowType + "> GetEnumerator() {");
		PushIndent("\t");
		WriteLine("return ResultSet.GetEnumerator();");
		PopIndent();
		WriteLine("}");
		WriteLine("");
		WriteLine("IEnumerator IEnumerable.GetEnumerator() {");
		PushIndent("\t");
		WriteLine("return ResultSet.GetEnumerator();");
		PopIndent();
		WriteLine("}");
	}

	int position = 0;
	foreach (ResultSet rs in resultSets) {
		string suffix = single ? "" : (position + 1).ToString();
		string onlyResultSet = single ? "true" : "false";
		WriteLine("");
		WriteLine("public IEnumerable<" + rs.RowTypeName + "> ResultSet" + suffix + " {");
		PushIndent("\t");
		WriteLine("get { return GetResultSetEnumerator(" + position.ToString() + ", " + rs.CreatorLambda + ", " + onlyResultSet + "); }");
		PopIndent();
		WriteLine("}");
		position++;
	}

	foreach (Parameter outParam in outParams) {
		string clrType = outParam.CSharpTypeName;
		WriteLine("");
		WriteLine("public " + clrType + " " + outParam.CSharpNameU + " {");
		PushIndent("\t");
		WriteLine("get { return GetOutputParameter<" + clrType + ">(\"" + outParam.DbName + "\"); }");
		PopIndent();
		WriteLine("}");
	}

	PopIndent();
	WriteLine("}");
	WriteLine("");
}
/// <summary>
/// Writes "ReturnType Name(arguments)" for the proxy method of <paramref name="proc"/>.
/// </summary>
/// <remarks>
/// Without result sets the method returns the procedure's return value. Output parameters become 'out'
/// arguments and InputOutput parameters become 'ref' arguments.
/// With result sets the method returns proc.ResultTypeName. Output parameters are omitted and InputOutput
/// parameters are passed by value; their results are presumably read from the returned type instead.
/// </remarks>
void WriteStoredprocSignature(Procedure proc) {
	string returnTypeName = proc.ResultSets.Count == 0 ? Parameter.ReturnValue.CSharpTypeName : proc.ResultTypeName;
	bool onlyInputParams = (proc.ResultSets.Count > 0);
	Write(returnTypeName + " " + proc.CSharpName + "(");
	bool first = true;
	foreach (Parameter param in proc.Parameters) {
		string modifier = "";
		switch (param.Direction) {
			case ParameterDirection.Input:
				break;
			case ParameterDirection.Output:
				// Decide whether to skip the parameter before any separator is written; writing ", " first
				// would leave a dangling comma such as "Foo(a, , b)" or "Foo(a, )".
				if (onlyInputParams)
					continue;
				modifier = "out ";
				break;
			case ParameterDirection.InputOutput:
				if (!onlyInputParams)
					modifier = "ref ";
				break;
		}
		if (!first)
			Write(", ");
		Write(modifier + param.CSharpTypeName + " " + param.CSharpNameL);
		first = false;
	}
	Write(")");
}
/// <summary>
/// Writes "IEnumerable&lt;Row&gt; QueryName(…)" for the proxy method of <paramref name="view"/>. Outside
/// SQLCLR mode the method also takes LINQ "filter" and "order" expressions over the view's row type.
/// </summary>
void WriteViewSignature(View view) {
	Write("IEnumerable<" + view.RowTypeName + "> Query" + view.CSharpName + "(");
	if (!generateSqlclr) {
		string rowType = view.ResultSet.RowTypeName;
		Write(string.Format("System.Linq.Expressions.Expression<Func<{0}, bool>> filter, System.Linq.Expressions.Expression<Func<{0}, object>> order", rowType));
	}
	Write(")");
}
/// <summary>
/// Writes "IEnumerable&lt;Row&gt; Name(type arg, …)" for the proxy method of <paramref name="query"/>,
/// with one argument per query parameter in order.
/// </summary>
void WriteNamedQuerySignature(NamedQuery query) {
	StringBuilder arguments = new StringBuilder();
	bool first = true;
	foreach (Parameter param in query.Parameters) {
		if (!first)
			arguments.Append(", ");
		arguments.Append(param.CSharpTypeName).Append(" ").Append(param.CSharpNameL);
		first = false;
	}
	Write("IEnumerable<" + query.RowTypeName + "> " + query.Name + "(" + arguments.ToString() + ")");
}
/// <summary>
/// Emits the public proxy method for <paramref name="proc"/>.
/// </summary>
/// <remarks>
/// Without result sets, the generated method:
/// <list type="bullet">
/// <item>executes the procedure with ExecuteNonQuery inside a using block;</item>
/// <item>copies non-input parameter values back into the out/ref arguments;</item>
/// <item>returns the procedure's return value.</item>
/// </list>
/// With result sets, it builds the command and passes it, undisposed, to the constructor of
/// proc.ResultTypeName. That is the same type WriteStoredprocSignature declares as the return type.
/// </remarks>
void GenerateStoredprocProxy(Procedure proc) {
	Write("public ");
	WriteStoredprocSignature(proc);
	WriteLine(" {");
	PushIndent("\t");
	if (proc.ResultSets.Count == 0) {
		WriteLine("using (SqlCommand cmd = CreateCommand()) {");
		PushIndent("\t");
	}
	else {
		WriteLine("SqlCommand cmd = CreateCommand();");
	}
	WriteLine("cmd.CommandText = \"[" + proc.DbName.Schema + "].[" + proc.DbName.Name + "]\";");
	WriteLine("cmd.CommandType = CommandType.StoredProcedure;");
	WriteLine("cmd.CommandTimeout = CommandTimeout;");
	WriteLine("SqlParameter p;");
	foreach (Parameter param in proc.Parameters) {
		Write("p = cmd.Parameters.Add(\"" + param.DbName + "\", SqlDbType." + param.DbType.ToString());
		if (param.DbSize != 0)
			Write(", " + param.DbSize);
		WriteLine(");");
		string nullValue = (param.DbType == SqlDbType.Xml ? "System.Data.SqlTypes.SqlXml.Null" : "DBNull.Value");
		if (param.Direction == ParameterDirection.Output) {
			// A pure output parameter has no readable argument: it is declared 'out' (unassigned on entry)
			// or, when result sets are returned, omitted from the signature entirely.
			WriteLine("p.Value = " + nullValue + ";");
		}
		else {
			WriteLine("p.Value = (object)" + param.CSharpNameL + " ?? " + nullValue + ";");
		}
		if (param.Direction != ParameterDirection.Input)
			WriteLine("p.Direction = ParameterDirection." + param.Direction + ";");
	}
	WriteLine("p = cmd.Parameters.Add(\"" + Parameter.ReturnValue.DbName + "\", SqlDbType." + Parameter.ReturnValue.DbType.ToString() + ");");
	WriteLine("p.Direction = ParameterDirection.ReturnValue;");
	if (proc.ResultSets.Count == 0) {
		WriteLine("cmd.ExecuteNonQuery();");
		foreach (Parameter param in proc.Parameters.Where(p => p.Direction != ParameterDirection.Input)) {
			WriteLine(param.CSharpNameL + " = SqlFunctions.DbNullCast<" + param.CSharpTypeName + ">(cmd.Parameters[\"" + param.DbName + "\"].Value);");
		}
		WriteLine("return SqlFunctions.DbNullCast<" + Parameter.ReturnValue.CSharpTypeName + ">(cmd.Parameters[\"" + Parameter.ReturnValue.DbName + "\"].Value);");
		PopIndent();
		WriteLine("}");
	}
	else {
		// Must construct the declared return type. ResultTypeName can differ from CSharpName + "Result"
		// when FindUniqueName had to choose another name.
		WriteLine("return new " + proc.ResultTypeName + "(cmd);");
	}
	PopIndent();
	WriteLine("}");
	WriteLine("");
}
/// <summary>
/// Emits the public query method for <paramref name="view"/>. The method selects every discovered column
/// of the view by name and delegates to ResultSetEnumerator.ExecuteQuery. Outside SQLCLR mode it also
/// passes the caller's filter and order expressions.
/// </summary>
void GenerateViewProxy(View view) {
	Write("public ");
	WriteViewSignature(view);
	WriteLine(" {");
	PushIndent("\t");

	List<string> selectList = new List<string>();
	foreach (ResultSetColumn column in view.ResultSet.Columns)
		selectList.Add("[" + column.DbName + "]");
	string sql = "SELECT " + string.Join(", ", selectList.ToArray()) + " FROM [" + view.DbName.Schema + "].[" + view.DbName.Name + "]";
	string linqArguments = generateSqlclr ? "" : ", filter, order";
	WriteLine("return ResultSetEnumerator.ExecuteQuery(connection, transaction, \"" + sql + "\", " + view.ResultSet.CreatorLambda + linqArguments + ", CommandTimeout);");

	PopIndent();
	WriteLine("}");
	WriteLine("");
}
/// <summary>
/// Emits the public method for <paramref name="query"/>. The method builds one SqlParameter per query
/// parameter (NULL when the argument is null) and runs the query text through ResultSetEnumerator.ExecuteQuery.
/// </summary>
void GenerateNamedQueryProxy(NamedQuery query) {
	Write("public ");
	WriteNamedQuerySignature(query);
	WriteLine(" {");
	PushIndent("\t");

	bool hasParameters = query.Parameters.Count > 0;
	if (hasParameters) {
		WriteLine("List<SqlParameter> parms = new List<SqlParameter>();");
		WriteLine("SqlParameter p;");
		foreach (Parameter param in query.Parameters) {
			string sizeArgument = (param.DbSize != 0) ? ", " + param.DbSize : "";
			WriteLine("p = new SqlParameter(\"" + param.DbName + "\", SqlDbType." + param.DbType.ToString() + sizeArgument + ");");
			WriteLine("p.Value = (object)" + param.CSharpNameL + " ?? DBNull.Value;");
			WriteLine("parms.Add(p);");
		}
	}

	// The query is emitted as a verbatim string literal, in which '"' is escaped by doubling it.
	string queryLiteral = "@\"" + query.QueryText.Replace("\"", "\"\"") + "\"";
	string parameterList = hasParameters ? "parms" : "null";
	WriteLine("return ResultSetEnumerator.ExecuteQuery(connection, transaction, " + queryLiteral + ", " + parameterList + ", " + query.ResultSet.CreatorLambda + ", CommandTimeout);");

	PopIndent();
	WriteLine("}");
	WriteLine("");
}
#>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment