Skip to content

Instantly share code, notes, and snippets.

@afish
Last active Aug 14, 2016
Embed
What would you like to do?
This gist shows a simple implementation of traits in C# using Fody, with override resolution.
using Mono.Cecil;
using System;
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil.Cil;
using TraitIntroducer;
namespace Weavers
{
    /// <summary>
    /// Weaves "traits": static classes marked with <see cref="TraitForAttribute"/> whose
    /// single-argument extension methods are turned into virtual members of the extended
    /// interface and of the classes implementing it.
    /// </summary>
    /// <remarks>
    /// <see cref="ModuleDefinition"/> (and optionally <see cref="LogWarning"/>) must be set by the
    /// host — presumably Fody — before <see cref="Execute"/> is called.
    /// </remarks>
    public class TraitWeaver
    {
        // Memoized GetTraitMethods results, so each skipped method is reported only once.
        private readonly Dictionary<TypeDefinition, List<MethodDefinition>> _traitMethodsByExtender =
            new Dictionary<TypeDefinition, List<MethodDefinition>>();

        /// <summary>The module being woven.</summary>
        public ModuleDefinition ModuleDefinition { get; set; }

        /// <summary>Sink for weaver warnings; when left null, warnings are dropped.</summary>
        public Action<string> LogWarning { get; set; }

        /// <summary>
        /// Types declared in <see cref="ModuleDefinition"/> plus every referenced type that could be
        /// resolved. Populated by <see cref="Execute"/>.
        /// </summary>
        public IList<TypeDefinition> AllTypes { get; set; }

        /// <summary>
        /// Definition of <see cref="TraitForAttribute"/> found in <see cref="AllTypes"/>, or null when
        /// the module does not reference it. Populated by <see cref="Execute"/>.
        /// </summary>
        public TypeDefinition TraitForAttributeTypeDefinition { get; set; }

        /// <summary>
        /// Weaves the module. Classes are processed before interfaces because the class pass copies
        /// the original trait bodies, which the interface pass then replaces with forwarding calls.
        /// Within each pass, types with smaller hierarchies (e.g. base classes) come first.
        /// </summary>
        public void Execute()
        {
            _traitMethodsByExtender.Clear();

            // Resolve() can return null when a referenced type cannot be found; drop those entries
            // instead of failing later with a NullReferenceException.
            AllTypes = ModuleDefinition.Types
                .Concat(ModuleDefinition.GetTypeReferences().Select(t => t.Resolve()))
                .Where(t => t != null)
                .ToList();
            TraitForAttributeTypeDefinition = AllTypes.FirstOrDefault(t => t.FullName == typeof(TraitForAttribute).FullName);
            if (TraitForAttributeTypeDefinition == null)
            {
                return;
            }

            // Materialized once: the ordering is consumed by both passes below.
            var orderedTypes = AllTypes.OrderBy(type => GetTypeHierarchy(type).Count()).ToList();
            foreach (var type in orderedTypes.Where(t => t.IsClass))
            {
                FixClass(type);
            }
            foreach (var type in orderedTypes.Where(t => t.IsInterface))
            {
                FixInterface(type);
            }
        }

        /// <summary>
        /// Applies the traits of the interfaces that <paramref name="type"/> introduces into its
        /// hierarchy. Each trait method body is copied into the class as <c>{Method}_{TraitClass}</c>.
        /// If the class already declares a method with the trait method's name, that method is made
        /// virtual with its NewSlot flag cleared. Otherwise a public virtual method with the trait
        /// method's name is added that forwards to the copied body. When several traits supply the
        /// same name, the one processed last is used for that forwarder.
        /// </summary>
        private void FixClass(TypeDefinition type)
        {
            // Drop the type itself, then keep the trailing run of interfaces: the ones this type adds
            // beyond what its base class already brought in.
            var hierarchy = GetTypeHierarchy(type).Reverse().Skip(1).TakeWhile(t => t.IsInterface).Reverse().ToArray();
            var introducedMethods = new Dictionary<string, Tuple<MethodDefinition, MethodDefinition>>();
            foreach (var implementedInterface in hierarchy)
            {
                foreach (var extender in GetExtenders(implementedInterface))
                {
                    foreach (var method in GetTraitMethods(extender))
                    {
                        var injected = InjectExtensionMethodToInheritor(type, extender, method);
                        var matchingMethod = type.Methods.FirstOrDefault(m => m.Name == method.Name);
                        if (matchingMethod != null)
                        {
                            matchingMethod.Attributes |= MethodAttributes.Virtual;
                            matchingMethod.Attributes &= ~MethodAttributes.NewSlot;
                        }
                        else
                        {
                            // Later traits overwrite earlier ones with the same method name.
                            introducedMethods[method.Name] = Tuple.Create(method, injected);
                        }
                    }
                }
            }
            foreach (var introducedMethod in introducedMethods.Values)
            {
                InjectVirtualMethodToInheritor(type, introducedMethod.Item1, introducedMethod.Item2);
            }
        }

        /// <summary>
        /// Makes every trait method written for <paramref name="type"/> dispatch through the
        /// interface. If an inherited interface, or the interface itself, already declares a method
        /// with the same name, the trait method is rewritten to call it. Otherwise an abstract method
        /// with that name is first added to <paramref name="type"/>.
        /// </summary>
        private void FixInterface(TypeDefinition type)
        {
            var hierarchy = GetTypeHierarchy(type).Reverse().Skip(1).ToArray();
            foreach (var extender in GetExtenders(type))
            {
                foreach (var method in GetTraitMethods(extender))
                {
                    // Declarations on inherited interfaces take precedence over one on the interface itself.
                    var existingMethod = hierarchy
                        .Select(t => t.Methods.FirstOrDefault(m => m.Name == method.Name))
                        .FirstOrDefault(m => m != null)
                        ?? type.Methods.FirstOrDefault(m => m.Name == method.Name);
                    if (existingMethod == null)
                    {
                        InjectMethodToInterface(type, method);
                    }
                    else
                    {
                        FixTraitMethod(method, existingMethod);
                    }
                }
            }
        }

        /// <summary>
        /// Returns the methods of <paramref name="extender"/> that can be woven as trait methods:
        /// static, non-constructor methods with a body and exactly one parameter (the extended
        /// instance). The generated forwarders pass only that instance. Anything else, such as a
        /// static constructor or a helper with more parameters, would be woven into invalid IL, so
        /// it is skipped: constructors silently, the rest with a warning.
        /// </summary>
        private List<MethodDefinition> GetTraitMethods(TypeDefinition extender)
        {
            List<MethodDefinition> traitMethods;
            if (_traitMethodsByExtender.TryGetValue(extender, out traitMethods))
            {
                return traitMethods;
            }

            traitMethods = new List<MethodDefinition>();
            foreach (var method in extender.Methods)
            {
                if (method.IsConstructor)
                {
                    continue;
                }
                if (method.IsStatic && method.HasBody && method.Parameters.Count == 1)
                {
                    traitMethods.Add(method);
                }
                else
                {
                    LogWarning?.Invoke($"Trait method {method.FullName} was skipped: only static methods taking just the extended instance are supported.");
                }
            }
            _traitMethodsByExtender.Add(extender, traitMethods);
            return traitMethods;
        }

        /// <summary>
        /// Returns every type in <see cref="AllTypes"/> whose first <c>[TraitFor(typeof(X))]</c>
        /// attribute names <paramref name="typeToExtend"/>, compared by full name.
        /// </summary>
        private IEnumerable<TypeDefinition> GetExtenders(TypeDefinition typeToExtend)
        {
            return AllTypes.Where(type =>
            {
                var traitFor = type.CustomAttributes.FirstOrDefault(attribute => attribute.AttributeType.FullName == TraitForAttributeTypeDefinition.FullName);
                if (traitFor == null || traitFor.ConstructorArguments.Count == 0)
                {
                    return false;
                }
                // The typeof(X) argument may be a plain TypeReference rather than a TypeDefinition,
                // so match on TypeReference to avoid a null here.
                var extendedType = traitFor.ConstructorArguments[0].Value as TypeReference;
                return extendedType != null && extendedType.FullName == typeToExtend.FullName;
            });
        }

        /// <summary>
        /// Returns the hierarchy of <paramref name="type"/> in this order: the base-class hierarchy,
        /// then the hierarchies of its directly listed interfaces, then the type itself. Entries are
        /// de-duplicated by full name, keeping the first occurrence. Base types and interfaces are
        /// followed only when they are already <see cref="TypeDefinition"/>s. A null
        /// <paramref name="type"/> yields an empty sequence.
        /// </summary>
        private IEnumerable<TypeDefinition> GetTypeHierarchy(TypeDefinition type)
        {
            if (type == null)
            {
                return Enumerable.Empty<TypeDefinition>();
            }
            return GetTypeHierarchy(type.BaseType as TypeDefinition)
                .Concat(type.Interfaces.OfType<TypeDefinition>().SelectMany(GetTypeHierarchy))
                .Concat(new[] { type })
                .GroupBy(t => t.FullName)
                .Select(x => x.First());
        }

        /// <summary>
        /// Replaces the body of trait method <paramref name="method"/> with
        /// <c>ldarg.0; callvirt methodToCall; ret</c>, so it forwards its first argument to
        /// <paramref name="methodToCall"/>.
        /// </summary>
        private void FixTraitMethod(MethodDefinition method, MethodDefinition methodToCall)
        {
            var body = method.Body;
            body.Instructions.Clear();
            // Any handlers would otherwise reference instructions that are no longer in this body.
            body.ExceptionHandlers.Clear();
            body.Instructions.Add(Instruction.Create(OpCodes.Ldarg_0));
            body.Instructions.Add(Instruction.Create(OpCodes.Callvirt, methodToCall));
            body.Instructions.Add(Instruction.Create(OpCodes.Ret));
        }

        /// <summary>
        /// Adds to <paramref name="extendedInterface"/> a public abstract virtual, parameterless method
        /// named and typed like <paramref name="method"/>. The trait method is then rewritten to
        /// call it.
        /// </summary>
        private void InjectMethodToInterface(TypeDefinition extendedInterface, MethodDefinition method)
        {
            // Public matches how C# emits interface members.
            var newMethod = new MethodDefinition(method.Name, MethodAttributes.Public | MethodAttributes.Abstract | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, method.ReturnType);
            extendedInterface.Methods.Add(newMethod);
            FixTraitMethod(method, newMethod);
        }

        /// <summary>Name of the per-trait copy of <paramref name="method"/>: <c>{Method}_{TraitClass}</c>.</summary>
        private string GetInheritorMethodName(TypeDefinition extender, MethodDefinition method)
        {
            return $"{method.Name}_{extender.Name}";
        }

        /// <summary>
        /// Adds to <paramref name="inheritor"/> a public virtual (NewSlot) instance method that runs
        /// the original body of the static trait method. The static method's argument 0 (the
        /// instance) becomes <c>this</c>.
        /// </summary>
        /// <returns>The newly added method.</returns>
        private MethodDefinition InjectExtensionMethodToInheritor(TypeDefinition inheritor, TypeDefinition extender, MethodDefinition method)
        {
            var newMethod = new MethodDefinition(GetInheritorMethodName(extender, method), MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig | MethodAttributes.NewSlot, method.ReturnType);
            var source = method.Body;
            var target = newMethod.Body;

            // Copied ldloc/stloc instructions address locals by position, so recreate them in order.
            foreach (var variable in source.Variables)
            {
                target.Variables.Add(new VariableDefinition(variable.VariableType));
            }
            target.InitLocals = source.InitLocals;

            foreach (var instruction in source.Instructions)
            {
                target.Instructions.Add(instruction);
            }

            // The instruction objects are shared with the source body, so the handler boundaries
            // still point into this body.
            foreach (var handler in source.ExceptionHandlers)
            {
                target.ExceptionHandlers.Add(new ExceptionHandler(handler.HandlerType)
                {
                    TryStart = handler.TryStart,
                    TryEnd = handler.TryEnd,
                    FilterStart = handler.FilterStart,
                    HandlerStart = handler.HandlerStart,
                    HandlerEnd = handler.HandlerEnd,
                    CatchType = handler.CatchType,
                });
            }

            inheritor.Methods.Add(newMethod);
            return newMethod;
        }

        /// <summary>
        /// Adds to <paramref name="inheritor"/> a public virtual method named after
        /// <paramref name="method"/>, without NewSlot, whose body is
        /// <c>ldarg.0; callvirt methodToCall; ret</c>.
        /// </summary>
        /// <returns>The newly added method.</returns>
        private MethodDefinition InjectVirtualMethodToInheritor(TypeDefinition inheritor, MethodDefinition method, MethodDefinition methodToCall)
        {
            var newMethod = new MethodDefinition(method.Name, MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.HideBySig, method.ReturnType);
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Ldarg_0));
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Callvirt, methodToCall));
            newMethod.Body.Instructions.Add(Instruction.Create(OpCodes.Ret));
            inheritor.Methods.Add(newMethod);
            return newMethod;
        }
    }
}
using System;
using TraitIntroducer;
namespace TraitsDemo
{
// Interfaces extended by the traits below. They declare no members in source; the trait
// weaver adds an abstract Print to each of them because each has a [TraitFor] class.
public interface IA
{
}
public interface IB
{
}
public interface IC
{
}
// Trait for IA. The weaver copies this body into classes that introduce IA, then rewrites
// this extension method to call IA.Print on the instance.
[TraitFor(typeof(IA))]
public static class IA_Implementation
{
public static void Print(this IA instance)
{
Console.WriteLine("I'm IA");
// Base() is not declared in this file (presumably an extension from TraitIntroducer);
// by its name it is meant to continue to the next implementation in the chain — TODO confirm.
instance.Base();
}
}
// Trait for IB; woven the same way as IA_Implementation.
[TraitFor(typeof(IB))]
public static class IB_Implementation
{
public static void Print(this IB instance)
{
Console.WriteLine("I'm IB");
instance.Base();
}
}
// First trait for IC.
[TraitFor(typeof(IC))]
public static class IC_Implementation
{
public static void Print(this IC instance)
{
Console.WriteLine("I'm IC");
instance.Base();
}
}
// Second trait for IC. When a class gets Print only from traits, the trait processed last
// supplies it (presumably this one, by declaration order), which is why D prints "I'm IC2".
[TraitFor(typeof(IC))]
public static class IC_Implementation2
{
public static void Print(this IC instance)
{
Console.WriteLine("I'm IC2");
instance.Base();
}
}
// A declares its own Print, so the weaver keeps it (made virtual, NewSlot cleared) rather
// than introducing a trait forwarder.
public class A : IA
{
public virtual void Print()
{
Console.WriteLine("I'm A");
this.Base();
}
}
// B adds no interfaces of its own, so the weaver leaves it unchanged.
public class B : A
{
}
// C overrides A's Print and adds no interfaces, so the weaver leaves it unchanged.
public class C : B
{
public override void Print()
{
Console.WriteLine("I'm C");
this.Base();
}
}
// D adds IB and IC without declaring Print. The weaver copies each trait's Print into D
// (Print_IB_Implementation, Print_IC_Implementation, Print_IC_Implementation2) and adds a
// virtual Print that forwards to the copy from the last trait processed.
public class D : C, IB, IC
{
}
class Program
{
static void Main(string[] args)
{
IA a = new D();
// IA declares no Print in source, so this compiles to the IA_Implementation.Print
// extension method; after weaving, that dispatches virtually through IA.Print.
a.Print();
// Output with the current weaver:
// I'm IC2
// Intended rest of the chain (not yet working):
// I'm IC
// I'm IB
// I'm C
// I'm A
// I'm IA
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment