-
-
Save bessgeor/44c6837ae8e5404007029d545e312152 to your computer and use it in GitHub Desktop.
open System | |
open System.Collections.Generic | |
open System.Linq | |
open System.Linq.Expressions | |
open System.Reflection | |
open LinqToDB | |
module AST_Conversion = | |
/// Inert `IDataContext` implementation used only as a placeholder value.
/// Every read-only member yields `null`, the mutable properties just hold
/// whatever is assigned to them, and every method is a no-op.
type private StubContext() =
    interface IDataContext with
        // Read-only properties: nothing behind them, always null.
        member _.ContextID = null
        member _.CreateSqlProvider = null
        member _.GetSqlOptimizer = null
        member _.SqlProviderFlags = null
        member _.DataReaderType = null
        member _.MappingSchema = null
        member _.QueryHints = null
        member _.NextQueryHints = null

        // Read/write properties: need a backing field, so they stay auto-properties.
        member val InlineParameters = false with get, set
        member val CloseAfterUse = false with get, set
        member val OnEntityCreated = null with get, set

        // Methods: arguments are ignored.
        member _.GetReaderExpression (_, _, _, _) = null
        member _.IsDBNullAllowed (_, _) = Nullable true
        member _.GetQueryRunner (_, _, _, _, _) = null
        member _.Clone (_) = null
        member _.Close () = ()
        member _.Dispose () = ()

        // OnClosing event accessors: subscriptions are silently dropped.
        member _.add_OnClosing _ = ()
        member _.remove_OnClosing _ = ()
/// Placeholder `IGrouping<'a,'b>`: exposes the key it was constructed with
/// and always enumerates as an empty sequence.
type private StubGrouping<'a,'b>(a: 'a) =
    interface IGrouping<'a,'b> with
        member _.Key = a
        // Generic enumerator (IEnumerable<'b>).
        member _.GetEnumerator() : IEnumerator<'b> =
            Enumerable.Empty<'b>().GetEnumerator()
        // Non-generic enumerator (System.Collections.IEnumerable).
        member _.GetEnumerator() : System.Collections.IEnumerator =
            (Array.empty<'b> :> System.Collections.IEnumerable).GetEnumerator()
/// Primitive types for which a recognisable placeholder value can be generated.
let private acceptableTypes : Type[] =
    [| typeof<int>; typeof<Guid>; typeof<string> |]
/// Works out how an instance of `x` carrying a placeholder value can be built.
///
/// Public constructors are tried from fewest to most parameters; the first one
/// that is parameterless or whose first parameter has an acceptable type wins.
///
/// Returns:
///   * `None` when no constructor qualifies, or when the parameterless
///     constructor was chosen but no readable and writable instance property
///     of an acceptable type exists;
///   * `Some (None, ctor)` when the value goes in through the first
///     constructor argument;
///   * `Some (Some prop, ctor)` when the value goes in through `prop` after
///     calling the parameterless `ctor`.
let private getWritePropertyOfSubstitutableType (x: Type) =
    let isAcceptable (t: Type) = acceptableTypes.Contains t

    let candidate =
        x.GetConstructors()
        |> Seq.map (fun c -> c, c.GetParameters())
        // Seq.sortBy is stable, so constructors of equal arity keep reflection order.
        |> Seq.sortBy (fun (_, ps) -> ps.Length)
        |> Seq.tryFind (fun (_, ps) -> ps.Length = 0 || isAcceptable ps.[0].ParameterType)

    match candidate with
    | None -> None
    | Some (ctor, ps) when ps.Length > 0 -> Some (None, ctor)
    | Some (ctor, _) ->
        // Public properties are preferred over non-public ones.
        let publicProps = x.GetProperties(BindingFlags.Public ||| BindingFlags.Instance)
        let hiddenProps = x.GetProperties(BindingFlags.NonPublic ||| BindingFlags.Instance)
        Seq.append publicProps hiddenProps
        |> Seq.tryFind (fun p -> p.CanRead && p.CanWrite && isAcceptable p.PropertyType)
        |> Option.map (fun p -> Some p, ctor)
/// Finds the instance property of `x` from which a placeholder value can be
/// read back.
///
/// Public properties are considered before non-public ones, except that a
/// property named `Item1` is always tried first. The first readable,
/// non-indexed property of an acceptable type is returned; `None` when there
/// is no such property.
///
/// Indexed properties (for example `Item[int]`) are skipped: the caller reads
/// the value with `prop.GetValue instance`, which throws
/// `TargetParameterCountException` for a property that needs index arguments.
let private getReadPropertyOfSubstitutableType (x: Type) =
    let pub = x.GetProperties(BindingFlags.Public ||| BindingFlags.Instance)
    let pri = x.GetProperties(BindingFlags.NonPublic ||| BindingFlags.Instance)
    Seq.append pub pri
    |> Seq.sortByDescending (fun p -> p.Name = "Item1")
    |> Seq.tryFind (fun p ->
        p.CanRead
        && p.GetIndexParameters().Length = 0
        && acceptableTypes.Contains(p.PropertyType))
/// Boxes `x` and wraps it in a `ConstantExpression`, returned as `Expression`.
let private toConstExpr x : Expression =
    let boxed = box x
    upcast Expression.Constant(boxed)
/// Identifies a placeholder value that temporarily stands in for a lambda
/// parameter. The key is created together with the placeholder expression
/// (`NewFromType`) and recovered later from a runtime value
/// (`ExtractFromInstance`), so a placeholder can be mapped back to the
/// parameter it replaced.
type private SubstitutionKey =
    | S_Int of int
    | S_Guid of Guid
    | S_String of string
    | S_DataContext
    with
        /// Builds a placeholder expression for a parameter of type `x` together
        /// with the key that identifies it.
        ///
        /// `counter` distinguishes int placeholders from one another; Guid and
        /// string placeholders are freshly generated GUIDs. `IDataContext` is
        /// replaced by a new `StubContext`, `IGrouping<_,_>` by a
        /// `StubGrouping<_,_>` with the same type arguments. Any other reference
        /// type is instantiated through the constructor/property chosen by
        /// `getWritePropertyOfSubstitutableType`, carrying a nested placeholder.
        ///
        /// Fails (via `failwithf`) for value types other than int/Guid and for
        /// types that offer no usable constructor or property.
        static member NewFromType (x: Type, counter: int) : Expression * SubstitutionKey =
            match x with
            | x when x = typeof<int> ->
                // Count down from Int32.MaxValue instead of using the raw counter:
                // the raw values 0, 1, 2... are indistinguishable from ordinary
                // literals in the query, which would then be replaced by a
                // parameter when placeholders are substituted back.
                // NOTE(review): a wrapper type that range-checks its int argument
                // may reject such a large value - confirm for the types in use.
                let marker = Int32.MaxValue - counter
                toConstExpr marker, S_Int marker
            | x when x = typeof<Guid> ->
                let guid = Guid.NewGuid()
                toConstExpr guid, S_Guid guid
            | x when x = typeof<string> ->
                let guid = Guid.NewGuid().ToString()
                toConstExpr guid, S_String guid
            | x when x = typeof<IDataContext> ->
                let ctor = typeof<StubContext>.GetConstructor([||])
                (Expression.New(ctor) :> Expression), S_DataContext
            | x when x.IsGenericType && x.GetGenericTypeDefinition() = typedefof<IGrouping<_,_>> ->
                let stubType = typedefof<StubGrouping<_,_>>.MakeGenericType(x.GetGenericArguments())
                SubstitutionKey.NewFromType(stubType, counter)
            | x ->
                match getWritePropertyOfSubstitutableType x with
                | Some (prop, ctor) when not x.IsValueType ->
                    let ctorParams = ctor.GetParameters()
                    // No property means the value is passed as the first constructor argument.
                    let targetType =
                        match prop with
                        | Some p -> p.PropertyType
                        | None -> ctorParams.[0].ParameterType
                    let expr, key = SubstitutionKey.NewFromType(targetType, counter)
                    if key = S_DataContext then
                        failwith "impossibru"
                    match prop with
                    | None ->
                        // may become a leg-shooting if this won't provide
                        // an instance with readable property
                        // which would be chosen by `getReadPropertyOfSubstitutableType`
                        let remaining =
                            ctorParams
                            |> Array.skip 1
                            |> Array.map (fun p -> Expression.Default p.ParameterType :> Expression)
                        let newParams = Array.append [| expr |] remaining
                        (Expression.New(ctor, newParams) :> Expression), key
                    | Some p ->
                        let assign = Expression.Bind(p, expr)
                        let create = Expression.MemberInit(Expression.New ctor, [| assign :> MemberBinding |])
                        (create :> Expression), key
                | _ ->
                    failwithf "unexpected compiling query lambda parameter type (%s)" x.Name

        /// Recovers the key carried by a runtime value, if any.
        ///
        /// int/Guid/string values map directly to their key; any `IDataContext`
        /// maps to `S_DataContext`. For other reference types the property
        /// chosen by `getReadPropertyOfSubstitutableType` is read and examined
        /// recursively. Returns `None` for `null`, for other value types and
        /// for objects without a suitable property.
        static member ExtractFromInstance (instance: obj) : SubstitutionKey option =
            match instance with
            | null -> None
            | :? int as i -> Some (S_Int i)
            | :? Guid as g -> Some (S_Guid g)
            | :? string as s -> Some (S_String s)
            | :? IDataContext -> Some S_DataContext
            | other ->
                let t = other.GetType()
                match getReadPropertyOfSubstitutableType t with
                | Some prop when not t.IsValueType ->
                    SubstitutionKey.ExtractFromInstance(prop.GetValue instance)
                | _ -> None
/// Creates the placeholder expression and key for parameter `p`, passing
/// `count` through as the counter for `SubstitutionKey.NewFromType`.
let private getParameterSubstitution (count: int) (p: ParameterExpression) =
    let parameterType = p.Type
    SubstitutionKey.NewFromType(parameterType, count)
/// Replaces constants that carry a known placeholder value with the
/// `ParameterExpression` registered for that placeholder in `substitutions`.
/// Constants without a key, or with an unregistered key, are handled by the
/// base visitor.
type private BackwardSubstitutingVisitor(substitutions: Dictionary<SubstitutionKey, ParameterExpression>) =
    inherit ExpressionVisitor()

    override _.VisitConstant(expr) =
        match SubstitutionKey.ExtractFromInstance expr.Value with
        | Some key ->
            match substitutions.TryGetValue key with
            | true, parameter -> parameter :> Expression
            | _ -> base.VisitConstant(expr)
        | None -> base.VisitConstant(expr)
/// Replaces every `ParameterExpression` with a placeholder expression.
/// A parameter seen before reuses its entry in `forward`; a new one gets a
/// fresh placeholder from `SubstitutionKey.NewFromType`, which is recorded in
/// both `forward` (parameter -> placeholder) and `backward` (key -> parameter).
type private ForwardSubstitutingVisitor(forward: Dictionary<ParameterExpression, Expression>, backward: Dictionary<SubstitutionKey, ParameterExpression>) =
    inherit ExpressionVisitor()

    override _.VisitParameter(param) =
        match forward.TryGetValue param with
        | true, known -> known
        | _ ->
            // The current number of registered keys serves as the counter.
            let substitute, key = SubstitutionKey.NewFromType(param.Type, backward.Count)
            forward.Add(param, substitute)
            backward.Add(key, param)
            substitute
/// Expression visitor that evaluates `QuotationToLambdaExpression(SubstHelper(...))`
/// calls embedded in an expression tree and splices the resulting expression
/// back in.
///
/// The third `SubstHelper` argument may reference lambda parameters, which are
/// not in scope when the call is compiled on its own. They are therefore
/// replaced by placeholder values first (`ForwardSubstitutingVisitor`), the call
/// is compiled and invoked, and the placeholders found in the result are then
/// replaced by the original parameters again (`BackwardSubstitutingVisitor`).
///
/// `forward` maps parameters to placeholder expressions and `backward` maps
/// placeholder keys to parameters; each lambda scope works on its own copies.
type FSharpLambdaVisitor private(forward: Dictionary<ParameterExpression, Expression>, backward: Dictionary<SubstitutionKey, ParameterExpression>) =
    inherit ExpressionVisitor()

    /// Creates a visitor with no parameters registered.
    new() = FSharpLambdaVisitor(Dictionary(), Dictionary())

    override _.VisitMethodCall(call) =
        if call.Method.Name = "QuotationToLambdaExpression" && call.Arguments.Any() then
            match call.Arguments.[0] with
            | :? MethodCallExpression as innerCall ->
                // Arguments 0, 1 and 2 are all read below, so exactly three are required.
                let isSubstHelperCall =
                    innerCall.Method.Name = "SubstHelper"
                    && innerCall.Arguments.Count = 3
                    && innerCall.Arguments.[0].NodeType = ExpressionType.Constant
                if isSubstHelperCall then
                    let forwardSubstituter = ForwardSubstitutingVisitor(forward, backward)
                    // Only the third argument gets its parameters replaced by placeholders.
                    let substitutedInner =
                        innerCall.Update(
                            innerCall.Object,
                            [|
                                innerCall.Arguments.[0]
                                innerCall.Arguments.[1]
                                forwardSubstituter.VisitAndConvert(innerCall.Arguments.[2], "FSharpLambdaVisitor.VisitMethodCall")
                            |]
                        )
                    let substitutedCall = call.Update(call.Object, [| substitutedInner |])
                    // Run the conversion now; its result is itself an expression.
                    let lambda = Expression.Lambda<Func<Expression>>(substitutedCall, [||])
                    let compiled = lambda.Compile().Invoke()
                    let converted = compiled |> Expression.Quote
                    let backSubstituter = BackwardSubstitutingVisitor backward
                    let backSubstituted = backSubstituter.VisitAndConvert(converted, "FSharpLambdaVisitor.VisitMethodCall")
                    // The converted tree may contain further conversion calls.
                    FSharpLambdaVisitor(Dictionary(forward), Dictionary(backward)).Visit backSubstituted
                else
                    // Keep the outer call in place. Returning the visited inner call
                    // instead would put a node of a different type where the
                    // QuotationToLambdaExpression call used to be.
                    base.VisitMethodCall call
            | _ -> base.VisitMethodCall call
        else
            base.VisitMethodCall call

    override _.VisitLambda(lambda) =
        // Copy the maps so registrations made for this lambda stay in its scope.
        let forward = Dictionary(forward)
        let backward = Dictionary(backward)
        let forwardSubstituter = ForwardSubstitutingVisitor(forward, backward)
        for p in lambda.Parameters do
            forwardSubstituter.Visit(p) |> ignore // just fill the dicts
        let lambdaScopeVisitor = FSharpLambdaVisitor(forward, backward)
        upcast lambda.Update(lambdaScopeVisitor.Visit(lambda.Body), lambda.Parameters)
I think this visitor replaces all your visitors and does the same, doesn't it?
type FSharpLambdaVisitor() =
    inherit ExpressionVisitor()
    override _.VisitMethodCall(call) =
        if call.Method.Name <> "QuotationToLambdaExpression" then
            base.VisitMethodCall(call)
        else
            let lambda = Expression.Lambda(call, Array.empty)
            let newNode = lambda.Compile().DynamicInvoke() :?> Expression
            base.Visit(newNode)
Thank you for your feedback.
The code you provided looks like it works (I haven't tried to run it, though), but I believe it would work only for lambdas without closures. The FSharp-generated conversion accepts all the closure values (and substitutes closure references with closure values during conversion), which would make compilation of the expression tree fail because the referenced closures won't be in scope for the generated LambdaExpression. That's why I use forward/backward substituting visitors: they allow the conversion lambda to compile, and they also make it possible to find the substitution values in the expression tree converted from FSharp and substitute them back with the original expressions, which is the way C# closures are shown in expression trees.
But there seems to be no way to create such a temporary substitution for an arbitrary type, so I generate substitution values based on properties of known types, instances of which I know to be captured into closures in my special case. Generation of such special values is kinda ad-hoc, and is to be replaced/extended for each case handled, but I can't find any better solution.
I think this visitor replaces all your visitors and does the same, doesn't it?