-
-
Save afish/dbc1f346c198bc6453b96068a7fdeeb8 to your computer and use it in GitHub Desktop.
This weaver adds traits to C# using Fody.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System;
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil;
using Mono.Cecil.Cil;
using TraitIntroducer;
namespace Weavers
{
    /// <summary>
    /// Fody weaver that implements traits. Types marked with <c>TraitForAttribute</c> ("extenders") name an
    /// interface. Each extender method's IL is copied into every class that introduces that interface. The
    /// interface gains a matching abstract method when it has none. The extender method itself is rewritten
    /// to call through the interface.
    /// </summary>
    public class TraitWeaver
    {
        /// <summary>The module being woven.</summary>
        public ModuleDefinition ModuleDefinition { get; set; }

        /// <summary>Warning sink. Not used by the code in this class.</summary>
        public Action<string> LogWarning { get; set; }

        /// <summary>
        /// Top-level types of <see cref="ModuleDefinition"/> plus the resolvable definitions of every type it
        /// references. Populated by <see cref="Execute"/>.
        /// </summary>
        public IList<TypeDefinition> AllTypes { get; set; }

        /// <summary>Definition of <c>TraitForAttribute</c>; null when the module does not reference it.</summary>
        public TypeDefinition TraitForAttributeTypeDefinition { get; set; }

        /// <summary>
        /// The <c>Base</c> method of <c>TratExtensions</c>. Calls to it inside trait and class bodies are redirected
        /// to the appropriate base implementation. Null when the module never references <c>TratExtensions</c>.
        /// </summary>
        public MethodDefinition BaseMethodExtension { get; set; }

        /// <summary>
        /// Weaver entry point. Collects all known types, then weaves every class and afterwards every interface.
        /// Within each group, types are ordered by hierarchy size. An ancestor's hierarchy is always strictly
        /// smaller than its descendant's, so ancestors are processed first.
        /// </summary>
        public void Execute()
        {
            // Resolve() can return null for references it cannot resolve; such entries would crash later lookups.
            AllTypes = ModuleDefinition.Types
                .Concat(ModuleDefinition.GetTypeReferences().Select(reference => reference.Resolve()))
                .Where(type => type != null)
                .ToList();

            TraitForAttributeTypeDefinition = AllTypes.FirstOrDefault(t => t.FullName == typeof(TraitForAttribute).FullName);
            if (TraitForAttributeTypeDefinition == null)
            {
                // The module uses no traits: nothing to weave.
                return;
            }

            // TratExtensions is only referenced when some code calls Base(); its absence is not an error.
            BaseMethodExtension = AllTypes
                .FirstOrDefault(t => t.FullName == typeof(TratExtensions).FullName)
                ?.Methods.FirstOrDefault(m => m.Name == "Base");

            // Materialized once: the ordering is consumed twice and computing hierarchies is not free.
            var orderedTypes = AllTypes.OrderBy(type => GetTypeHierarchy(type).Count()).ToList();
            foreach (var type in orderedTypes.Where(t => t.IsClass))
            {
                FixClass(type);
            }

            foreach (var type in orderedTypes.Where(t => t.IsInterface))
            {
                FixInterface(type);
            }
        }

        /// <summary>
        /// Weaves every trait that reaches <paramref name="type"/> through the interfaces it introduces.
        /// For each trait method, the method's IL is copied into a new method named
        /// <c>{Method}_{Extender}</c>. Then one of two things happens:
        /// if <paramref name="type"/> already declares a method with the same name, that method is made virtual
        /// (reusing its slot) and its <c>Base()</c> calls are redirected to the copy;
        /// otherwise a public virtual method with the trait method's name is added that forwards to the copy.
        /// </summary>
        private void FixClass(TypeDefinition type)
        {
            // GetTypeHierarchy lists the base-class ancestry first, then interfaces first seen on this type, then
            // the type itself. Reversed and without the type itself, the leading interfaces are the ones this type
            // introduces, and the classes after them are its base classes, nearest first.
            var ancestorsNearestFirst = GetTypeHierarchy(type).Reverse().Skip(1).ToArray();
            var introducedInterfaces = ancestorsNearestFirst.TakeWhile(t => t.IsInterface).Reverse().ToArray();
            var upperHierarchy = ancestorsNearestFirst.SkipWhile(t => t.IsInterface).Where(t => t.IsClass).ToArray();

            // Trait methods with no same-named method on this type: they get a forwarding method at the end.
            var introducedMethods = new Dictionary<string, Tuple<MethodDefinition, MethodDefinition>>();
            // Latest (trait method, injected copy) per name, so a later trait's Base() chains to an earlier copy.
            var lastInjectedMethod = new Dictionary<string, Tuple<MethodDefinition, MethodDefinition>>();
            // Trait methods for which this type declares its own same-named method.
            var methodsToFix = new Dictionary<string, MethodDefinition>();

            foreach (var implementedInterface in introducedInterfaces)
            {
                foreach (var extender in GetExtenders(implementedInterface))
                {
                    foreach (var method in GetTraitMethods(extender))
                    {
                        Tuple<MethodDefinition, MethodDefinition> baseMethodPair;
                        lastInjectedMethod.TryGetValue(method.Name, out baseMethodPair);
                        FixUpperHierarchy(upperHierarchy, method);

                        // Base() inside the copy targets the previous trait's copy if any, else the nearest base class.
                        var baseMethod = baseMethodPair?.Item2 ?? GetMethodByNameFromUpperHierarchy(upperHierarchy, method);
                        var injected = InjectExtensionMethodToInheritor(type, extender, method, baseMethod);
                        var matchingMethod = type.Methods.FirstOrDefault(m => m.Name == method.Name);
                        lastInjectedMethod[method.Name] = Tuple.Create(method, injected);

                        if (matchingMethod != null)
                        {
                            // Make the type's own method virtual without a fresh vtable slot.
                            matchingMethod.Attributes |= MethodAttributes.Virtual;
                            matchingMethod.Attributes &= ~MethodAttributes.NewSlot;
                            methodsToFix[matchingMethod.Name] = method;
                            FixSingleMethodCallToBase(matchingMethod, injected, baseMethod);
                        }
                        else
                        {
                            introducedMethods[method.Name] = Tuple.Create(method, injected);
                        }
                    }
                }
            }

            foreach (var introducedMethod in introducedMethods.Values)
            {
                InjectVirtualMethodToInheritor(type, introducedMethod.Item1, introducedMethod.Item2);
            }

            foreach (var traitMethod in methodsToFix.Values)
            {
                FixUpperHierarchy(upperHierarchy, traitMethod);
            }
        }

        /// <summary>
        /// Returns the first method, scanning <paramref name="upperHierarchy"/> in order, whose name equals
        /// <paramref name="method"/>'s name; null when none matches. Matching is by name only.
        /// </summary>
        private MethodDefinition GetMethodByNameFromUpperHierarchy(TypeDefinition[] upperHierarchy, MethodDefinition method)
        {
            return upperHierarchy.Select(t => t.Methods.FirstOrDefault(m => m.Name == method.Name)).FirstOrDefault(m => m != null);
        }

        /// <summary>
        /// For every type in <paramref name="hierarchy"/> (nearest first), redirects <c>Base()</c> calls in its
        /// method named like <paramref name="method"/>. They are pointed at the same-named method of the closest
        /// type further up the array that declares one.
        /// </summary>
        private void FixUpperHierarchy(TypeDefinition[] hierarchy, MethodDefinition method)
        {
            for (var index = 0; index < hierarchy.Length; index++)
            {
                var matchingMethod = hierarchy[index].Methods.FirstOrDefault(m => m.Name == method.Name);
                var upperMethod = GetMethodByNameFromUpperHierarchy(hierarchy.Skip(index + 1).ToArray(), method);
                FixSingleMethodCallToBase(matchingMethod, upperMethod);
            }
        }

        /// <summary>
        /// Rewrites the operand of every <c>call</c> instruction in <paramref name="existingMethod"/> that targets
        /// <c>Base()</c>, or <paramref name="baseMethod"/> when given, so it calls <paramref name="methodToCall"/>.
        /// Does nothing when either method is null or <paramref name="existingMethod"/> has no IL body
        /// (abstract/extern).
        /// </summary>
        private void FixSingleMethodCallToBase(MethodDefinition existingMethod, MethodDefinition methodToCall, MethodDefinition baseMethod = null)
        {
            if (existingMethod == null || methodToCall == null || !existingMethod.HasBody)
            {
                return;
            }

            var replacement = methodToCall.GetElementMethod();
            foreach (var instruction in existingMethod.Body.Instructions)
            {
                if (IsCallToBaseExtension(instruction)
                    || (baseMethod != null && IsCallToOurBase(instruction, baseMethod.FullName)))
                {
                    instruction.Operand = replacement;
                }
            }
        }

        /// <summary>
        /// Weaves every trait targeting interface <paramref name="type"/>. Each trait method is rewritten to call
        /// the same-named method found first on an inherited interface, then on <paramref name="type"/>. When no
        /// such method exists, an abstract method is first added to <paramref name="type"/>.
        /// </summary>
        private void FixInterface(TypeDefinition type)
        {
            // Interfaces carry no base type, so this is the inherited-interface list without the interface itself.
            var inheritedInterfaces = GetTypeHierarchy(type).Reverse().Skip(1).ToArray();
            foreach (var extender in GetExtenders(type))
            {
                foreach (var method in GetTraitMethods(extender))
                {
                    var existingMethod = GetMethodByNameFromUpperHierarchy(inheritedInterfaces, method)
                        ?? type.Methods.FirstOrDefault(m => m.Name == method.Name);
                    if (existingMethod == null)
                    {
                        InjectMethodToInterface(type, method);
                    }
                    else
                    {
                        FixTraitMethod(method, existingMethod);
                    }
                }
            }
        }

        /// <summary>
        /// Returns the types in <see cref="AllTypes"/> whose <c>TraitForAttribute</c> names
        /// <paramref name="typeToExtend"/> (compared by full name). The sequence is evaluated lazily.
        /// </summary>
        private IEnumerable<TypeDefinition> GetExtenders(TypeDefinition typeToExtend)
        {
            return AllTypes.Where(candidate =>
            {
                var target = GetTraitTarget(candidate);
                return target != null && target.FullName == typeToExtend.FullName;
            });
        }

        /// <summary>
        /// Returns the type passed as the first constructor argument of <paramref name="candidate"/>'s
        /// <c>TraitForAttribute</c>. Returns null when the attribute is absent, has no argument, or the argument
        /// is not a <see cref="TypeDefinition"/>.
        /// </summary>
        private TypeDefinition GetTraitTarget(TypeDefinition candidate)
        {
            var traitForAttribute = candidate.CustomAttributes
                .FirstOrDefault(attribute => attribute.AttributeType.FullName == TraitForAttributeTypeDefinition.FullName);
            if (traitForAttribute == null)
            {
                return null;
            }

            // NOTE(review): presumably only same-module targets arrive as TypeDefinition; others are skipped.
            return traitForAttribute.ConstructorArguments.Select(argument => argument.Value).FirstOrDefault() as TypeDefinition;
        }

        /// <summary>
        /// Methods of <paramref name="extender"/> that are woven as trait methods. Constructors (e.g. a
        /// <c>.cctor</c> emitted for static field initializers) are excluded; they cannot become instance or
        /// interface methods.
        /// </summary>
        private IEnumerable<MethodDefinition> GetTraitMethods(TypeDefinition extender)
        {
            return extender.Methods.Where(m => !m.IsConstructor);
        }

        /// <summary>
        /// Returns <paramref name="type"/>'s ancestry, de-duplicated by full name and keeping first occurrences.
        /// The order is: base-class hierarchy first, then the hierarchies of its interfaces, then the type itself.
        /// Base types and interfaces that are not <see cref="TypeDefinition"/> instances are not followed.
        /// </summary>
        private IEnumerable<TypeDefinition> GetTypeHierarchy(TypeDefinition type)
        {
            if (type == null)
            {
                return Enumerable.Empty<TypeDefinition>();
            }

            // NOTE(review): with Mono.Cecil 0.10+, Interfaces holds InterfaceImplementation objects, so this
            // OfType<TypeDefinition>() filter would match nothing. Confirm the Cecil version in use.
            return GetTypeHierarchy(type.BaseType as TypeDefinition)
                .Concat(type.Interfaces.OfType<TypeDefinition>().SelectMany(GetTypeHierarchy))
                .Concat(new[] { type })
                .GroupBy(t => t.FullName)
                .Select(group => group.First());
        }

        /// <summary>
        /// Replaces <paramref name="method"/>'s body with <c>ldarg.0; callvirt methodToCall; ret</c>.
        /// </summary>
        private void FixTraitMethod(MethodDefinition method, MethodDefinition methodToCall)
        {
            var body = method.Body;
            body.Instructions.Clear();
            // Handlers would otherwise keep pointing at the instructions just removed.
            body.ExceptionHandlers.Clear();
            body.Instructions.Add(Instruction.Create(OpCodes.Ldarg_0));
            body.Instructions.Add(Instruction.Create(OpCodes.Callvirt, methodToCall));
            body.Instructions.Add(Instruction.Create(OpCodes.Ret));
        }

        /// <summary>
        /// Adds a public abstract method named like <paramref name="method"/>, with its return type, to
        /// <paramref name="extendedInterface"/>. Then rewrites <paramref name="method"/> to call it.
        /// </summary>
        private void InjectMethodToInterface(TypeDefinition extendedInterface, MethodDefinition method)
        {
            // Interface methods must be public (ECMA-335); previously the access bits were left at 0.
            var newMethod = new MethodDefinition(
                method.Name,
                MethodAttributes.Public | MethodAttributes.Abstract | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot,
                method.ReturnType);
            extendedInterface.Methods.Add(newMethod);
            FixTraitMethod(method, newMethod);
        }

        /// <summary>Name of the per-class copy of a trait method: <c>{Method}_{Extender}</c>.</summary>
        private string GetInheritorMethodName(TypeDefinition extender, MethodDefinition method)
        {
            return $"{method.Name}_{extender.Name}";
        }

        /// <summary>True when <paramref name="instruction"/> is a <c>call</c> to the method whose full name is <paramref name="name"/>.</summary>
        private bool IsCallToOurBase(Instruction instruction, string name)
        {
            return instruction.OpCode == OpCodes.Call && (instruction.Operand as MethodReference)?.FullName == name;
        }

        /// <summary>True when <paramref name="instruction"/> calls <see cref="BaseMethodExtension"/>; always false when it is unknown.</summary>
        private bool IsCallToBaseExtension(Instruction instruction)
        {
            return BaseMethodExtension != null && IsCallToOurBase(instruction, BaseMethodExtension.FullName);
        }

        /// <summary>
        /// Adds to <paramref name="inheritor"/> a public virtual new-slot, parameterless method named
        /// <c>{Method}_{Extender}</c>. Its IL is <paramref name="method"/>'s body, with <c>Base()</c> calls
        /// replaced by calls to <paramref name="baseMethod"/> when one is given.
        /// Returns the added method.
        /// </summary>
        private MethodDefinition InjectExtensionMethodToInheritor(TypeDefinition inheritor, TypeDefinition extender, MethodDefinition method, MethodDefinition baseMethod)
        {
            var newMethod = new MethodDefinition(
                GetInheritorMethodName(extender, method),
                MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot,
                method.ReturnType);
            var sourceBody = method.Body;
            var targetBody = newMethod.Body;

            // The copied IL keeps using the source's locals, so they must be declared in the copy as well.
            // The same VariableDefinition objects are reused, just as the instructions below are reused; adding
            // them in order gives them the same indices the copied ldloc/stloc instructions expect.
            targetBody.InitLocals = sourceBody.InitLocals;
            foreach (var variable in sourceBody.Variables)
            {
                targetBody.Variables.Add(variable);
            }

            foreach (var instruction in sourceBody.Instructions)
            {
                if (baseMethod != null && IsCallToBaseExtension(instruction))
                {
                    targetBody.Instructions.Add(Instruction.Create(OpCodes.Call, baseMethod));
                }
                else
                {
                    targetBody.Instructions.Add(instruction);
                }
            }

            // try/catch/finally regions refer to the (shared) instructions copied above.
            foreach (var handler in sourceBody.ExceptionHandlers)
            {
                targetBody.ExceptionHandlers.Add(handler);
            }

            inheritor.Methods.Add(newMethod);
            return newMethod;
        }

        /// <summary>
        /// Adds to <paramref name="inheritor"/> a public virtual method named like <paramref name="method"/>.
        /// Its body is <c>ldarg.0; callvirt methodToCall; ret</c>. Returns the added method.
        /// </summary>
        private MethodDefinition InjectVirtualMethodToInheritor(TypeDefinition inheritor, MethodDefinition method, MethodDefinition methodToCall)
        {
            var newMethod = new MethodDefinition(method.Name, MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig, method.ReturnType);
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Ldarg_0));
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Callvirt, methodToCall));
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Ret));
            inheritor.Methods.Add(newMethod);
            return newMethod;
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment