Skip to content

Instantly share code, notes, and snippets.

@afish
Last active January 23, 2017 13:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save afish/dbc1f346c198bc6453b96068a7fdeeb8 to your computer and use it in GitHub Desktop.
Save afish/dbc1f346c198bc6453b96068a7fdeeb8 to your computer and use it in GitHub Desktop.
This weaver adds traits to C# using Fody.
using Mono.Cecil;
using System;
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil.Cil;
using TraitIntroducer;
namespace Weavers
{
    /// <summary>
    /// Weaver that implements "traits": a type annotated with <c>TraitForAttribute</c> supplies method bodies
    /// for an interface. This weaver copies those bodies into every class implementing that interface,
    /// declares matching abstract members on the interface when they are missing, and rewrites the trait
    /// methods themselves into stubs that dispatch through the interface.
    /// </summary>
    /// <remarks>
    /// Trait bodies are treated as receiving their target object in argument 0. The injected copies are
    /// instance methods, so argument 0 becomes <c>this</c>. Calls to the <c>Base</c> helper inside a trait
    /// body are rewritten to call the previous implementation in the chain: an earlier trait copy, or a
    /// same-named method of a base class.
    /// </remarks>
    public class TraitWeaver
    {
        /// <summary>Module being woven; assigned by the host before <see cref="Execute"/> runs.</summary>
        public ModuleDefinition ModuleDefinition { get; set; }

        /// <summary>Warning sink supplied by the host. Currently not used by this weaver.</summary>
        public Action<string> LogWarning { get; set; }

        /// <summary>
        /// Top-level types of <see cref="ModuleDefinition"/> plus the successfully resolved external type
        /// references. Built by <see cref="Execute"/>.
        /// </summary>
        public IList<TypeDefinition> AllTypes { get; set; }

        /// <summary>Definition of <c>TraitForAttribute</c>, or <c>null</c> when the module does not see it.</summary>
        public TypeDefinition TraitForAttributeTypeDefinition { get; set; }

        /// <summary>The <c>Base</c> helper whose call sites inside trait bodies get redirected.</summary>
        public MethodDefinition BaseMethodExtension { get; set; }

        /// <summary>
        /// Entry point. Weaves every class first and every interface afterwards.
        /// </summary>
        /// <remarks>
        /// Classes must go first: <see cref="FixClass"/> copies the original trait bodies, and
        /// <see cref="FixInterface"/> later replaces those bodies with dispatch stubs. Types are ordered by
        /// hierarchy size, so a base class (whose hierarchy is a subset of its subclass's) is woven before
        /// its subclasses.
        /// </remarks>
        public void Execute()
        {
            // Resolve() can return null for references it cannot find; keeping such entries would make
            // every later `t.FullName` lookup over AllTypes throw NullReferenceException.
            var resolvedReferences = ModuleDefinition.GetTypeReferences()
                .Select(reference => reference.Resolve())
                .Where(resolved => resolved != null);
            AllTypes = ModuleDefinition.Types.Concat(resolvedReferences).ToList();

            TraitForAttributeTypeDefinition = AllTypes.FirstOrDefault(t => t.FullName == typeof(TraitForAttribute).FullName);
            if (TraitForAttributeTypeDefinition == null)
            {
                // The module does not use traits: nothing to weave.
                return;
            }

            BaseMethodExtension = AllTypes.First(t => t.FullName == typeof(TratExtensions).FullName).Methods.First(m => m.Name == "Base");

            // Materialized once: the ordering is consumed by both loops below, and each key computes a hierarchy.
            var orderedTypes = AllTypes.OrderBy(type => GetTypeHierarchy(type).Count()).ToList();
            foreach (var type in orderedTypes.Where(t => t.IsClass))
            {
                FixClass(type);
            }

            foreach (var type in orderedTypes.Where(t => t.IsInterface))
            {
                FixInterface(type);
            }
        }

        /// <summary>
        /// Injects, into <paramref name="type"/>, a copy of every trait method attached to the interfaces the
        /// class contributes to its hierarchy. Each <c>Base()</c> call is chained to the previous implementation,
        /// and pre-existing same-named methods of the class become overrides that call the injected copy.
        /// </summary>
        /// <param name="type">Class being woven.</param>
        private void FixClass(TypeDefinition type)
        {
            // GetTypeHierarchy lists the base-class part, then interfaces, then the type itself. Reversed, it starts
            // with the type followed by the interfaces not already listed via the base-class part; keep that
            // run of interfaces, restored to its original order.
            var hierarchy = GetTypeHierarchy(type).Reverse().Skip(1).TakeWhile(t => t.IsInterface).Reverse().ToArray();
            // What follows that run, restricted to classes: the base classes, nearest first.
            var upperHierarchy = GetTypeHierarchy(type).Reverse().Skip(1).SkipWhile(t => t.IsInterface).Where(t => t.IsClass).ToArray();

            // Pairs are (trait method, injected copy in this class).
            // Trait methods without a pre-existing same-named method in the class; they get a public virtual forwarder.
            var introducedMethods = new Dictionary<string, Tuple<MethodDefinition, MethodDefinition>>();
            // Most recent injected copy per method name, so the next trait's Base() chains to it.
            var lastInjectedMethod = new Dictionary<string, Tuple<MethodDefinition, MethodDefinition>>();
            // Pre-existing class methods that now override a trait method; base classes are re-fixed for them at the end.
            var methodsToFix = new Dictionary<string, MethodDefinition>();

            foreach (var implementedInterface in hierarchy)
            {
                var extenders = GetExtenders(implementedInterface);
                foreach (var extender in extenders)
                {
                    foreach (var method in extender.Methods)
                    {
                        Tuple<MethodDefinition, MethodDefinition> baseMethodPair;
                        lastInjectedMethod.TryGetValue(method.Name, out baseMethodPair);
                        FixUpperHierarchy(upperHierarchy, method);

                        // Chain to the previously injected trait copy, otherwise to the nearest base-class method.
                        var baseMethod = baseMethodPair?.Item2 ?? GetMethodByNameFromUpperHierarchy(upperHierarchy, method);
                        var injected = InjectExtensionMethodToInheritor(type, extender, method, baseMethod);
                        var matchingMethod = type.Methods.FirstOrDefault(m => m.Name == method.Name);
                        lastInjectedMethod[method.Name] = Tuple.Create(method, injected);

                        if (matchingMethod != null)
                        {
                            // Make the class's own method virtual and let it reuse an inherited slot (no NewSlot),
                            // then point its Base() calls at the injected trait copy.
                            matchingMethod.Attributes |= MethodAttributes.Virtual;
                            matchingMethod.Attributes &= ~MethodAttributes.NewSlot;
                            methodsToFix[matchingMethod.Name] = method;
                            FixSingleMethodCallToBase(matchingMethod, injected, baseMethod);
                        }
                        else
                        {
                            introducedMethods[method.Name] = Tuple.Create(method, injected);
                        }
                    }
                }
            }

            foreach (var introducedMethod in introducedMethods.Values)
            {
                InjectVirtualMethodToInheritor(type, introducedMethod.Item1, introducedMethod.Item2);
            }

            foreach (var methodToFix in methodsToFix)
            {
                FixUpperHierarchy(upperHierarchy, methodToFix.Value);
            }
        }

        /// <summary>
        /// Returns the first method named like <paramref name="method"/> found in <paramref name="upperHierarchy"/>
        /// (searched in array order), or <c>null</c> when none has that name.
        /// </summary>
        private MethodDefinition GetMethodByNameFromUpperHierarchy(TypeDefinition[] upperHierarchy, MethodDefinition method)
        {
            return upperHierarchy.Select(t => t.Methods.FirstOrDefault(m => m.Name == method.Name)).FirstOrDefault(m => m != null);
        }

        /// <summary>
        /// Walks <paramref name="hierarchy"/> (base classes, nearest first). For each class, redirects the
        /// <c>Base</c> calls in its method named like <paramref name="method"/> to the next same-named method
        /// further up the chain.
        /// </summary>
        private void FixUpperHierarchy(TypeDefinition[] hierarchy, MethodDefinition method)
        {
            if (!hierarchy.Any())
            {
                return;
            }

            var typeToFix = hierarchy.First();
            var restHierarchy = hierarchy.Skip(1).ToArray();
            var matchingMethod = typeToFix.Methods.FirstOrDefault(m => m.Name == method.Name);
            var upperMethod = GetMethodByNameFromUpperHierarchy(restHierarchy, method);
            FixSingleMethodCallToBase(matchingMethod, upperMethod);
            FixUpperHierarchy(restHierarchy, method);
        }

        /// <summary>
        /// Rewrites, in <paramref name="existingMethod"/>, every <c>call</c> to the <c>Base</c> helper (and, when
        /// given, every <c>call</c> to <paramref name="baseMethod"/>) so that it targets <paramref name="methodToCall"/>.
        /// Does nothing when either of the first two arguments is <c>null</c>.
        /// </summary>
        private void FixSingleMethodCallToBase(MethodDefinition existingMethod, MethodDefinition methodToCall, MethodDefinition baseMethod = null)
        {
            if (existingMethod == null || methodToCall == null)
            {
                return;
            }

            foreach (var instruction in existingMethod
                .Body.Instructions
                .Where(i =>
                    IsCallToOurBase(i, BaseMethodExtension.FullName)
                    || (baseMethod != null && IsCallToOurBase(i, baseMethod.FullName))))
            {
                instruction.Operand = methodToCall.GetElementMethod();
            }
        }

        /// <summary>
        /// For each trait method attached to interface <paramref name="type"/>, either declares a new abstract
        /// member on the interface or reuses a same-named member from its ancestors or itself. It then rewrites
        /// the trait method into a stub that calls that member virtually.
        /// </summary>
        private void FixInterface(TypeDefinition type)
        {
            // All ancestors of the interface, excluding the interface itself.
            var hierarchy = GetTypeHierarchy(type).Reverse().Skip(1).ToArray();
            var extenders = GetExtenders(type);
            foreach (var extender in extenders)
            {
                foreach (var method in extender.Methods)
                {
                    var existingMethod = hierarchy.Select(t => t.Methods.FirstOrDefault(m => m.Name == method.Name)).FirstOrDefault(m => m != null)
                        ?? type.Methods.FirstOrDefault(m => m.Name == method.Name);
                    if (existingMethod == null)
                    {
                        InjectMethodToInterface(type, method);
                    }
                    else
                    {
                        FixTraitMethod(method, existingMethod);
                    }
                }
            }
        }

        /// <summary>
        /// Returns the types in <see cref="AllTypes"/> whose first <c>TraitForAttribute</c> names
        /// <paramref name="typeToExtend"/> as its first constructor argument.
        /// </summary>
        /// <remarks>
        /// Only types listed in <see cref="AllTypes"/> are woven, so a request for any other type (for example a
        /// nested interface) yields no extenders. Attributes whose argument is missing or is not a
        /// <c>TypeDefinition</c> never match; they no longer abort weaving of unrelated interfaces.
        /// </remarks>
        private IEnumerable<TypeDefinition> GetExtenders(TypeDefinition typeToExtend)
        {
            var attributeName = TraitForAttributeTypeDefinition.FullName;
            var targetName = typeToExtend.FullName;

            if (!AllTypes.Any(t => t.FullName == targetName))
            {
                return Enumerable.Empty<TypeDefinition>();
            }

            return AllTypes.Where(type =>
            {
                var traitAttribute = type.CustomAttributes.FirstOrDefault(attribute => attribute.AttributeType.FullName == attributeName);
                if (traitAttribute == null)
                {
                    return false;
                }

                var extendedInterfaceType = traitAttribute.ConstructorArguments
                    .Select(argument => argument.Value)
                    .FirstOrDefault() as TypeDefinition;
                return extendedInterfaceType != null && extendedInterfaceType.FullName == targetName;
            });
        }

        /// <summary>
        /// Returns the hierarchy of <paramref name="type"/>: the base type's hierarchy, then each interface's
        /// hierarchy, then the type itself. Entries sharing a full name are collapsed to their first occurrence.
        /// </summary>
        /// <remarks>
        /// Only base types and interfaces that are <c>TypeDefinition</c> instances are followed. Plain references
        /// (presumably types from other modules — verify) end the walk.
        /// </remarks>
        private IEnumerable<TypeDefinition> GetTypeHierarchy(TypeDefinition type)
        {
            if (type == null)
            {
                return Enumerable.Empty<TypeDefinition>();
            }

            return GetTypeHierarchy(type.BaseType as TypeDefinition)
                .Concat(type.Interfaces.OfType<TypeDefinition>().SelectMany(GetTypeHierarchy))
                .Concat(new[] { type })
                .GroupBy(t => t.FullName)
                .Select(x => x.First());
        }

        /// <summary>
        /// Replaces the body of trait method <paramref name="method"/> with <c>ldarg.0; callvirt methodToCall; ret</c>.
        /// Its argument 0 then becomes the receiver of a virtual call.
        /// </summary>
        private void FixTraitMethod(MethodDefinition method, MethodDefinition methodToCall)
        {
            method.Body.Instructions.Clear();
            method.Body.Instructions.Add(Instruction.Create(OpCodes.Ldarg_0));
            method.Body.Instructions.Add(Instruction.Create(OpCodes.Callvirt, methodToCall));
            method.Body.Instructions.Add(Instruction.Create(OpCodes.Ret));
        }

        /// <summary>
        /// Declares on <paramref name="extendedInterface"/> an abstract member with the trait method's name and
        /// return type, then redirects the trait method to it.
        /// </summary>
        private void InjectMethodToInterface(TypeDefinition extendedInterface, MethodDefinition method)
        {
            // Public, as compilers emit for interface members, so the member is visible like any other interface method.
            var newMethod = new MethodDefinition(
                method.Name,
                MethodAttributes.Public | MethodAttributes.Abstract | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot,
                method.ReturnType);
            extendedInterface.Methods.Add(newMethod);
            FixTraitMethod(method, newMethod);
        }

        /// <summary>Name of the per-trait copy injected into a class: <c>{method}_{extender}</c>.</summary>
        private string GetInheritorMethodName(TypeDefinition extender, MethodDefinition method)
        {
            return $"{method.Name}_{extender.Name}";
        }

        /// <summary>True when <paramref name="instruction"/> is a non-virtual <c>call</c> to the method with full name <paramref name="name"/>.</summary>
        private bool IsCallToOurBase(Instruction instruction, string name)
        {
            return instruction.OpCode == OpCodes.Call && (instruction.Operand as MethodReference)?.FullName == name;
        }

        /// <summary>
        /// Adds to <paramref name="inheritor"/> a public virtual (new-slot) copy of trait method
        /// <paramref name="method"/>, named by <see cref="GetInheritorMethodName"/>. When
        /// <paramref name="baseMethod"/> is given, calls to the <c>Base</c> helper become non-virtual calls to it.
        /// </summary>
        /// <returns>The injected method.</returns>
        /// <remarks>
        /// NOTE(review): the copy declares no parameters of its own, so trait methods taking arguments beyond the
        /// receiver are not supported. Instruction, local-variable and exception-handler objects are shared with
        /// the source body rather than cloned.
        /// </remarks>
        private MethodDefinition InjectExtensionMethodToInheritor(TypeDefinition inheritor, TypeDefinition extender, MethodDefinition method, MethodDefinition baseMethod)
        {
            var newMethod = new MethodDefinition(GetInheritorMethodName(extender, method), MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, method.ReturnType);
            var sourceBody = method.Body;
            var targetBody = newMethod.Body;

            // The copied instructions may use locals (ldloc/stloc) and protected regions (leave/endfinally).
            // Unless those are declared on the new body too, the injected method is invalid IL.
            targetBody.InitLocals = sourceBody.InitLocals;
            foreach (var variable in sourceBody.Variables)
            {
                targetBody.Variables.Add(variable);
            }

            foreach (var instruction in sourceBody.Instructions)
            {
                if (baseMethod != null && IsCallToOurBase(instruction, BaseMethodExtension.FullName))
                {
                    targetBody.Instructions.Add(Instruction.Create(OpCodes.Call, baseMethod));
                }
                else
                {
                    targetBody.Instructions.Add(instruction);
                }
            }

            foreach (var handler in sourceBody.ExceptionHandlers)
            {
                targetBody.ExceptionHandlers.Add(handler);
            }

            inheritor.Methods.Add(newMethod);
            return newMethod;
        }

        /// <summary>
        /// Adds to <paramref name="inheritor"/> a public virtual method named like the trait method. Its body is
        /// <c>ldarg.0; callvirt methodToCall; ret</c>, forwarding to the injected trait copy.
        /// </summary>
        /// <returns>The added forwarder.</returns>
        private MethodDefinition InjectVirtualMethodToInheritor(TypeDefinition inheritor, MethodDefinition method, MethodDefinition methodToCall)
        {
            var newMethod = new MethodDefinition(method.Name, MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig, method.ReturnType);
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Ldarg_0));
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Callvirt, methodToCall));
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Ret));
            inheritor.Methods.Add(newMethod);
            return newMethod;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment