Skip to content

Instantly share code, notes, and snippets.

@oguimbal
Created July 31, 2019 22:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oguimbal/341b9080f57ceb943c9f692ddc5010ab to your computer and use it in GitHub Desktop.
Save oguimbal/341b9080f57ceb943c9f692ddc5010ab to your computer and use it in GitHub Desktop.
Custom C# container demo
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
using System.Text;
using System.Threading.Tasks;
using System.Linq.Expressions;
// Build the container, then register every service it should be able to hand out.
var container = new Container();
container.RegisterSingleton<IMyService, MyImplementation>();

// Let the container construct MyClass; fields marked [Inject] are populated during creation.
var demo = container.Create<MyClass>();
demo.Call();
/// <summary>Demo consumer showing field injection through the [Inject] attribute.</summary>
public class MyClass
{
    // Populated by the container when the instance is created; never assigned by hand.
    [Inject]
    private IMyService myService;

    /// <summary>Forwards the call to the injected service.</summary>
    public void Call() => myService.DoSomething();
}
/// <summary>Contract of the demo service that is registered in the container and injected into MyClass.</summary>
interface IMyService {
/// <summary>Performs the service's single action.</summary>
void DoSomething();
}
/// <summary>Demo implementation of IMyService that writes a fixed line to the console.</summary>
public class MyImplementation : IMyService
{
    /// <summary>Writes "Do my thing" to standard output.</summary>
    public void DoSomething() => Console.WriteLine("Do my thing");
}
// ================ IMPLEMENTATION ============
/// <summary>
/// Marks a field that the container must populate when it builds or injects an object.
/// </summary>
public class Inject : Attribute
{
    /// <summary>When true, a missing service leaves the field null instead of raising an error.</summary>
    public bool Optional { get; private set; }

    /// <summary>When true, the container may instantiate the field's type itself if nothing is registered for it.</summary>
    public bool CreateIfNecessary { get; private set; }

    /// <param name="optional">Value for <see cref="Optional"/>.</param>
    /// <param name="createIfNecessary">Value for <see cref="CreateIfNecessary"/>.</param>
    public Inject(bool optional = false, bool createIfNecessary = false)
        => (Optional, CreateIfNecessary) = (optional, createIfNecessary);
}
/// <summary>
/// Test-only hook for lifting a seal. Container implements it explicitly and resets its sealed flag,
/// so registrations become possible again after Seal() was called.
/// </summary>
public interface IUnsealerFortests
{
/// <summary>Removes the seal from the implementing object.</summary>
void UnsealForTests();
}
/// <summary>
/// Optional callback for container-built objects. Container.CreateInstance invokes
/// <see cref="Injected"/> on a new instance after its constructor ran and its [Inject] fields were set.
/// </summary>
public interface IOnInject
{
/// <summary>Called once the object has been constructed and injected.</summary>
void Injected();
}
/// <summary>Resolution surface of the container, itself registered as a service by Container.SetDefaults.</summary>
public interface IServices
{
/// <summary>Resolves the service registered for <typeparamref name="T"/>.</summary>
/// <param name="throwIfNotFound">Whether an unresolvable service is reported as an error rather than as a null result.</param>
T GetInstance<T>(bool throwIfNotFound = true);
/// <summary>Builds an instance of <typeparamref name="T"/>; Container does this through its constructor/field injection path.</summary>
T Create<T>();
/// <summary>Non-generic variant of <see cref="GetInstance{T}(bool)"/>.</summary>
object GetInstance(Type t);
}
[DebuggerStepThrough]
public partial class Container : IServiceProvider, IServices, IUnsealerFortests
{
// Service type -> factory delegate invoked on each resolution (singleton factories cache their result in a closure).
Dictionary<Type, Func<object>> registrations = new Dictionary<Type, Func<object>>();
// Open generic service definition -> (type of the factory service, generic factory method definition); filled by RegisterSingletonFactory.
Dictionary<Type, (Type owner, MethodInfo method)> genericFactories = new Dictionary<Type, (Type owner, MethodInfo method)>();
// IDisposable singletons tracked so DisposeSingletons can dispose them; also used as the lock object for its own mutations.
List<IDisposable> disposables = new List<IDisposable>();
// Open generic service definition -> open generic implementation; filled by RegisterGeneric.
Dictionary<Type, Type> genericServices = new Dictionary<Type, Type>();
// Fallback provider consulted when this container cannot resolve a service.
IServiceProvider parent;
// The parent as a Container (either the parent itself or the Container it resolves); null when there is no parent.
Container parentContainer;
// Implementation type -> instance built by CreateSingleImplementation; also used as that method's lock object.
Dictionary<Type, object> singletonsImplemCreations = new Dictionary<Type, object>();
// Names of the services currently being resolved; allocated by the outermost GetInstance call and reset to null when it ends.
Queue<string> stack;
// Types currently being resolved, used to detect circular dependencies; same lifetime as 'stack'.
HashSet<Type> typeStack;
// Optional label used by ToString.
public string Name { get; private set; }
/// <summary>Creates an unnamed container without parent that only holds the built-in registrations.</summary>
public Container() => SetDefaults();

/// <summary>Creates a container that falls back to <paramref name="parent"/> for services it does not know.</summary>
/// <param name="parent">Fallback provider; may be null.</param>
/// <param name="name">Optional label shown by <see cref="ToString"/>.</param>
public Container(IServiceProvider parent, string name = null) : this()
{
    this.Name = name;
    this.SetParent(parent);
}
/// <summary>
/// Sets (or clears, when <paramref name="container"/> is null) the provider this container falls back to.
/// </summary>
/// <param name="container">New parent provider, or null to detach from the current parent.</param>
/// <exception cref="ArgumentException">
/// The new parent is this container, resolves to this container, or has this container among its ancestors.
/// </exception>
public void SetParent(IServiceProvider container)
{
    // Validate the ARGUMENT (not the current field) before touching any state, so a rejected
    // parent never leaves this container half-updated.
    if (ReferenceEquals(container, this))
        throw new ArgumentException("Invalid service parent (loop)");

    Container resolved = null;
    if (container != null)
        resolved = (container as Container) ?? (Container)container.GetService(typeof(Container));

    // Walk up the ancestors: a loop anywhere in the chain would make parent lookups recurse forever.
    for (var ancestor = resolved; ancestor != null; ancestor = ancestor.parentContainer)
    {
        if (ReferenceEquals(ancestor, this))
            throw new ArgumentException("Invalid service parent (loop)");
    }

    this.parent = container;
    this.parentContainer = resolved;
}
/// <summary>
/// Describes the container by name and parent chain, e.g. "Anonymous container",
/// "App container" or "Request container &lt;= App container".
/// </summary>
public override string ToString()
{
    if (parent == null)
        return Name == null ? "Anonymous container" : Name + " container";

    // Unnamed child containers are labelled "Request"; the stray ')' the original appended here is gone.
    string label = Name ?? "Request";
    return $"{label} container <= {parent}";
}
// Registers the container itself under the three types through which it can be injected.
void SetDefaults()
{
    RegisterSingleton<IServiceProvider>(instance: this);
    RegisterSingleton<IServices>(instance: this);
    RegisterSingleton<Container>(instance: this);
}
/// <summary>
/// Placeholder type argument: RegisterSingletonFactory expects YourType&lt;Container.Template&gt; and
/// myFactory.MyMethod&lt;Container.Template&gt;() to describe an open generic service and its factory method.
/// </summary>
public class Template
{
}
/// <summary>Registers <typeparamref name="TImpl"/> as a lazily created singleton resolved under its own type.</summary>
/// <param name="ovr">Replace an existing registration instead of failing on a duplicate.</param>
/// <param name="init">Optional callback run on the instance each time the registration's factory produces it.</param>
/// <returns>This container, for chaining.</returns>
public Container RegisterSingleton<TImpl>(bool ovr = false, Action<TImpl> init = null)
    where TImpl : class
{
    return RegisterSingleton<TImpl, TImpl>(ovr, init);
}
/// <summary>
/// Registers <typeparamref name="TService"/> as a lazily created singleton implemented by
/// <typeparamref name="TImpl"/>. The implementation instance comes from CreateSingleImplementation,
/// so several services mapped to the same implementation type share one object.
/// </summary>
/// <param name="ovr">Replace an existing registration instead of failing on a duplicate.</param>
/// <param name="init">Optional callback run on the implementation when the factory produces it.</param>
/// <returns>This container, for chaining.</returns>
public Container RegisterSingleton<TService, TImpl>(bool ovr = false, Action<TImpl> init = null)
    where TImpl : class, TService
    where TService : class
{
    Func<TService> lazyFactory = () =>
    {
        TImpl created = CreateSingleImplementation<TImpl>();
        if (init != null)
            init(created);
        return created;
    };
    return RegisterSingleton<TService>(lazyFactory, ovr);
}
/// <summary>Registers an already built object as the singleton for <typeparamref name="TService"/>.</summary>
/// <param name="instance">Object returned for every resolution of the service.</param>
/// <param name="ovr">Replace an existing registration instead of failing on a duplicate.</param>
/// <returns>This container, for chaining.</returns>
public Container RegisterSingleton<TService>(TService instance, bool ovr = false)
    where TService : class
{
    Func<TService> alwaysSame = () => instance;
    return RegisterSingleton<TService>(alwaysSame, ovr);
}
/// <summary>
/// Registers <typeparamref name="TService"/> as a singleton whose instance is produced by
/// <paramref name="instanceCreator"/> on first resolution and then reused.
/// </summary>
/// <param name="instanceCreator">Factory invoked while no non-null instance has been cached yet.</param>
/// <param name="ovr">Replace an existing registration; when false a duplicate makes Dictionary.Add throw.</param>
/// <returns>This container, for chaining.</returns>
/// <exception cref="InvalidOperationException">The container is sealed.</exception>
public Container RegisterSingleton<TService>(Func<TService> instanceCreator, bool ovr = false)
    where TService : class
{
    CheckSealed();

    TService cached = null;
    Func<object> resolver = () =>
    {
        if (cached != null)
            return cached;
        cached = instanceCreator();
        // Track the fresh instance for disposal if it is IDisposable.
        return WrapSigleton(cached);
    };

    Type key = typeof(TService);
    if (ovr)
        registrations[key] = resolver;
    else
        registrations.Add(key, resolver);
    return this;
}
/// <summary>
/// Declares that every closed version of an open generic service is produced by a generic,
/// parameterless method of the <typeparamref name="TFactory"/> service.
/// </summary>
/// <typeparam name="TServiceTemplate">The service closed over Container.Template, e.g. IRepo&lt;Container.Template&gt;.</typeparam>
/// <typeparam name="TFactory">Service type that declares the factory method.</typeparam>
/// <param name="factory">Expression of the form <c>f => f.MyMethod&lt;Container.Template&gt;()</c>.</param>
/// <returns>This container, for chaining.</returns>
/// <exception cref="ArgumentException">The template type or the expression does not have the expected shape.</exception>
public Container RegisterSingletonFactory<TServiceTemplate, TFactory>(Expression<Func<TFactory, TServiceTemplate>> factory)
{
    Type templateType = typeof(TServiceTemplate);
    if (!templateType.IsGenericType)
        throw new ArgumentException("Expecting generic type as template");
    if (templateType.GetGenericArguments().Single() != typeof(Template))
        throw new ArgumentException("Expecting YourType<Container.Template> as template");

    // The expression body must be a direct call to a generic method declared on TFactory.
    MethodInfo target = (Ext.Unwrap(factory.Body) as MethodCallExpression)?.Method;
    bool wellFormed = target != null
        && target.IsGenericMethod
        && target.DeclaringType == typeof(TFactory)
        && target.GetGenericArguments().Single() == typeof(Template)
        && target.ReturnType == templateType
        && target.GetParameters().Length == 0;
    if (!wellFormed)
        throw new ArgumentException("Expecting myFactory.MyMethod<Container.Template>() method definition");

    // Stored by open generic definition so any closed version can be matched at resolution time.
    genericFactories[templateType.GetGenericTypeDefinition()] = (typeof(TFactory), target.GetGenericMethodDefinition());
    return this;
}
// =====================================================================
/// <summary>Registers <typeparamref name="TImpl"/> under its own type; every resolution builds a new instance.</summary>
/// <param name="ovr">Replace an existing registration instead of failing on a duplicate.</param>
/// <returns>This container, for chaining.</returns>
public Container Register<TImpl>(bool ovr = false)
    where TImpl : class
{
    return Register<TImpl, TImpl>(ovr);
}
/// <summary>
/// Registers <typeparamref name="TService"/> so that every resolution builds a new
/// <typeparamref name="TImpl"/> through <see cref="Create{T}"/>.
/// </summary>
/// <param name="ovr">Replace an existing registration instead of failing on a duplicate.</param>
/// <returns>This container, for chaining.</returns>
public Container Register<TService, TImpl>(bool ovr = false)
    where TImpl : TService
    where TService : class
{
    Func<TService> perCallFactory = () => Create<TImpl>();
    return Register<TService>(perCallFactory, ovr);
}
/// <summary>Registers a factory that is invoked on every resolution of <typeparamref name="TService"/>.</summary>
/// <param name="instanceCreator">Factory stored as the registration.</param>
/// <param name="ovr">Replace an existing registration; when false a duplicate makes Dictionary.Add throw.</param>
/// <returns>This container, for chaining.</returns>
/// <exception cref="InvalidOperationException">The container is sealed.</exception>
public Container Register<TService>(Func<TService> instanceCreator, bool ovr = false)
    where TService : class
{
    CheckSealed();

    Type key = typeof(TService);
    if (!ovr)
    {
        registrations.Add(key, instanceCreator);
        return this;
    }

    registrations[key] = instanceCreator;
    return this;
}
/// <summary>
/// Maps an open generic service definition to an open generic implementation definition.
/// The mapping is consulted by GetInstance when it is allowed to create an abstract generic service.
/// </summary>
/// <param name="type">Open generic service type, e.g. <c>typeof(IRepo&lt;&gt;)</c>.</param>
/// <param name="genericImpl">Open generic implementation with the same number of type parameters.</param>
/// <exception cref="ArgumentNullException">An argument is null.</exception>
/// <exception cref="ArgumentException">An argument is not an open generic definition, or the arities differ.</exception>
public void RegisterGeneric(Type type, Type genericImpl)
{
    // Fail with the offending parameter's name instead of a NullReferenceException.
    if (type == null)
        throw new ArgumentNullException(nameof(type));
    if (genericImpl == null)
        throw new ArgumentNullException(nameof(genericImpl));

    if (!type.IsGenericTypeDefinition)
        throw new ArgumentException($"{type} is not an open generic type definition", nameof(type));
    if (!genericImpl.IsGenericTypeDefinition)
        throw new ArgumentException($"{genericImpl} is not an open generic type definition", nameof(genericImpl));
    if (genericImpl.GetGenericArguments().Length != type.GetGenericArguments().Length)
        throw new ArgumentException($"{genericImpl} and {type} do not have the same number of type parameters", nameof(genericImpl));

    genericServices[type] = genericImpl;
}
// =====================================================================
/// <summary>
/// Starts tracking <paramref name="obj"/> for disposal when it implements IDisposable, then returns it unchanged.
/// </summary>
/// <param name="obj">Singleton instance; may be null, in which case nothing is tracked.</param>
/// <returns><paramref name="obj"/> itself.</returns>
T WrapSigleton<T>(T obj) where T : class
{
    if (obj is IDisposable asDisp)
    {
        lock (disposables)
        {
            // The same object can reach this method twice (RegisterSingleton<TService, TImpl> wraps the
            // result of CreateSingleImplementation, which already wrapped it). Track it once so that it
            // is disposed once and a single removal really detaches it.
            if (!disposables.Any(tracked => ReferenceEquals(tracked, asDisp)))
                disposables.Add(asDisp);
        }
    }
    return obj;
}
/// <summary>
/// Drops every registration and cached singleton, then restores the built-in registrations.
/// </summary>
/// <param name="clearParent">Also clear the parent container (recursively, with the same flags).</param>
/// <param name="dispose">Dispose tracked singletons before forgetting them.</param>
/// <exception cref="InvalidOperationException">This container (or a parent being cleared) is sealed.</exception>
public void Clear(bool clearParent = false, bool dispose = false)
{
    CheckSealed();
    registrations.Clear();
    // Restore the built-ins immediately: they are lazy registrations, and doing it here guarantees the
    // container keeps them even if a Dispose call or the parent's Clear throws further down.
    SetDefaults();
    genericFactories.Clear();
    if (dispose)
        DisposeSingletons();
    disposables.Clear();
    genericServices.Clear();
    singletonsImplemCreations.Clear();
    if (clearParent)
        parentContainer?.Clear(clearParent: clearParent, dispose: dispose);
}
// Once true, the registration methods and Clear refuse to run (see CheckSealed).
bool isSealed = false;

/// <summary>Freezes the container: later registration attempts throw.</summary>
public void Seal()
{
    isSealed = true;
}

// Explicit implementation so that unsealing is only reachable through the test interface.
void IUnsealerFortests.UnsealForTests()
{
    isSealed = false;
}

// Guard used by every operation that mutates the registrations.
private void CheckSealed()
{
    if (!isSealed)
        return;
    throw new InvalidOperationException("This container is sealed");
}
// Re-entrancy guard: set while DisposeSingletons is running.
bool disposing;

/// <summary>
/// Disposes every tracked singleton, most recently tracked first, then forgets them together with the
/// cache of singleton implementations. A call made while disposal is already running is ignored.
/// </summary>
public void DisposeSingletons()
{
    if (disposing)
        return;
    disposing = true;
    try
    {
        // Work on a snapshot taken under the same lock WrapSigleton uses. The tracked list is neither
        // enumerated while it may change nor reversed in place, so if a Dispose call throws the list
        // keeps its original order for a later retry.
        IDisposable[] pending;
        lock (disposables)
            pending = disposables.ToArray();

        for (int i = pending.Length - 1; i >= 0; i--)
            pending[i].Dispose();

        lock (disposables)
            disposables.Clear();
        singletonsImplemCreations.Clear();
    }
    finally
    {
        disposing = false;
    }
}
/// <summary>Resolves the service registered for <typeparamref name="T"/> without creating unregistered types.</summary>
/// <param name="throwIfNotFound">Throw when the service cannot be resolved instead of returning null.</param>
public T GetInstance<T>(bool throwIfNotFound = true)
{
    // 'create' must stay false here: resolution must never instantiate unregistered types.
    object resolved = GetInstance(typeof(T), false /* false IMPORTANT*/, throwIfNotFound);
    return (T)resolved;
}

/// <summary>Builds a new <typeparamref name="T"/>, resolving its constructor arguments and [Inject] fields.</summary>
public T Create<T>()
{
    return (T)CreateInstance(typeof(T));
}

/// <summary>Non-generic variant of <see cref="Create{T}"/>.</summary>
public object Create(Type type)
{
    return CreateInstance(type);
}
/// <summary>
/// Removes the cached singleton implementation of type <typeparamref name="T"/> from the container's
/// implementation cache and from disposal tracking, handing ownership to the caller.
/// </summary>
/// <returns>The detached instance, or default(T) when none was cached.</returns>
public T DettachSingleton<T>()
{
    lock (singletonsImplemCreations)
    {
        object ret;
        if (!singletonsImplemCreations.TryGetValue(typeof(T), out ret))
            return default(T);
        singletonsImplemCreations.Remove(typeof(T));

        if (ret is IDisposable asDisp)
        {
            lock (disposables)
            {
                // Remove every tracked entry of this exact object: the same instance can be tracked more
                // than once, and List.Remove would drop only the first entry, leaving the detached object
                // to be disposed by the container later.
                disposables.RemoveAll(tracked => ReferenceEquals(tracked, asDisp));
            }
        }
        return (T)ret;
    }
}
/// <summary>
/// Returns the single cached instance of implementation type <typeparamref name="T"/>, building it
/// through <see cref="Create{T}"/> on first use and tracking it for disposal.
/// </summary>
public T CreateSingleImplementation<T>()
    where T : class
{
    lock (singletonsImplemCreations)
    {
        Type key = typeof(T);
        if (singletonsImplemCreations.TryGetValue(key, out object cached))
            return (T)cached;

        T fresh = Create<T>();
        singletonsImplemCreations[key] = fresh;
        return WrapSigleton(fresh);
    }
}
/// <summary>Resolves <paramref name="serviceType"/>; throws when it cannot be resolved and never creates unregistered types.</summary>
public object GetInstance(Type serviceType)
{
    // 'create' must stay false: this overload is what emitted constructor injection calls.
    return GetInstance(serviceType, false /* false IMPORTANT */, true);
}
// Serializes every resolution made through GetInstance(Type, bool, bool) on this container.
object locker = new object();
/// <summary>
/// Resolves <paramref name="serviceType"/> from this container, falling back to the parent provider.
/// Lookup order: own registration, generic factory (only when <paramref name="create"/> is false),
/// parent, on-the-fly creation (only when <paramref name="create"/> is true), parent again.
/// </summary>
/// <param name="serviceType">Service to resolve.</param>
/// <param name="create">
/// When true and nothing is registered, a non-abstract type is instantiated through CreateInstance, and an
/// abstract generic type through the implementation mapped by RegisterGeneric. The created instance is then
/// stored as a registration, which requires the container not to be sealed.
/// </param>
/// <param name="throwIfNotFound">Throw DependencyMissingException instead of returning null when nothing resolves.</param>
/// <returns>The resolved instance, or null when not found and <paramref name="throwIfNotFound"/> is false.</returns>
/// <exception cref="DependencyMissingException">The service, or one of its dependencies, cannot be resolved.</exception>
/// <exception cref="Exception">The type is already being resolved further up, or more than 100 resolutions are nested.</exception>
public object GetInstance(Type serviceType, bool create, bool throwIfNotFound)
{
lock (locker)
{
Func<object> creator;
if (!this.registrations.TryGetValue(serviceType, out creator))
{
if (!create && serviceType.IsGenericType)
{
var gen = serviceType.GetGenericTypeDefinition();
(Type owner, MethodInfo method) mi;
if (this.genericFactories.TryGetValue(gen, out mi))
{
// generic type found => register the factory
// (written to 'registrations' without CheckSealed; the closure caches the produced instance)
object instance = null;
creator = this.registrations[serviceType] = () =>
{
if (instance != null)
return instance;
// Resolve the factory service, then call its generic method closed over the requested type arguments.
var provider = GetInstance(mi.owner);
instance = mi.method.MakeGenericMethod(serviceType.GetGenericArguments()).Invoke(provider, null);
instance = WrapSigleton(instance);
return instance;
};
}
}
// ask parent
if (creator == null && parent != null)
{
var found = ProvideParent(serviceType);
if (found != null)
return found;
}
}
// ownsStack: this call allocated the resolution bookkeeping and must reset it in 'finally'.
bool ownsStack = false;
// successAdd: this call added serviceType to typeStack and must remove it in 'finally'.
bool successAdd = false;
try
{
if (stack == null)
{
typeStack = new HashSet<Type>();
stack = new Queue<string>();
ownsStack = true;
}
stack.Enqueue(serviceType.Name);
successAdd = typeStack.Add(serviceType);
// Add() returning false means this type is already being resolved higher up the call chain.
if (!successAdd || stack.Count > 100)
throw new Exception($"Circular dependency instanciating {serviceType}. It contains a reference to a service being instanciated: \n" + string.Join("\n -> ", stack));
if (creator != null)
return creator();
if (create)
{
if (!serviceType.IsAbstract)
{
CheckSealed();
var instance = CreateInstance(serviceType);
// Cache the created object: later resolutions of this type return the same instance.
this.registrations[serviceType] = () => instance;
return instance;
}
else if (serviceType.IsGenericType)
{
Type def;
if (this.genericServices.TryGetValue(serviceType.GetGenericTypeDefinition(), out def))
{
CheckSealed();
var instance = CreateInstance(def.MakeGenericType(serviceType.GetGenericArguments()));
this.registrations[serviceType] = () => instance;
return instance;
}
}
}
}
catch (DependencyMissingException ex)
{
// Re-wrap so the exception records every service on the path to the missing dependency.
throw new DependencyMissingException(serviceType, ex);
}
finally
{
// NOTE(review): 'stack' is a Queue, so Dequeue removes the OLDEST entry rather than the name this call
// enqueued. The element count stays right, but the names left in the queue (printed in the circular
// dependency message) can be wrong after a nested resolution returns. A Stack<string> field would fix it.
if (stack?.Count > 0)
stack.Dequeue();
if (successAdd)
typeStack?.Remove(serviceType);
if (ownsStack)
{
typeStack = null;
stack = null;
}
}
// Reached when nothing above produced an instance: the parent is consulted (again) before giving up.
if (parent != null)
{
var service = parent.GetService(serviceType);
if (service != null)
return service;
}
if (throwIfNotFound)
{
throw new DependencyMissingException(serviceType);
}
return null;
}
}
/// <summary>
/// Raised when a service cannot be resolved. Each resolution level that catches one re-wraps it, so
/// <see cref="DependencyPath"/> accumulates the services that were being built when the lookup failed.
/// </summary>
class DependencyMissingException : Exception
{
    /// <summary>
    /// Missing service at the bottom, outermost requested service on top.
    /// The same stack instance is shared with the wrapped exception.
    /// </summary>
    public Stack<Type> DependencyPath { get; private set; }

    /// <param name="serviceType">The missing service, or the service being built when <paramref name="inner"/> was raised.</param>
    /// <param name="inner">Exception from the nested resolution, if any.</param>
    public DependencyMissingException(Type serviceType, DependencyMissingException inner = null)
        : base(Msg(serviceType, inner))
    {
        if (inner != null && inner.DependencyPath != null)
        {
            // Msg has already pushed serviceType onto this shared stack.
            DependencyPath = inner.DependencyPath;
        }
        else
        {
            var path = new Stack<Type>();
            path.Push(serviceType);
            DependencyPath = path;
        }
    }

    // Builds the message; when wrapping, it also pushes serviceType onto the inner exception's path.
    private static string Msg(Type serviceType, DependencyMissingException inner = null)
    {
        if (inner == null)
            return "Missing dependency " + serviceType.Name;

        Stack<Type> path = inner.DependencyPath;
        path.Push(serviceType);

        // A Stack enumerates top-first: the last element is the missing service, the rest are its consumers.
        string missing = path.Last().Name;
        IEnumerable<string> consumers = path.Take(path.Count - 1).Select(t => t.Name);
        return "\n\nMissing dependency " + missing + "\n\nwhile instanciating " + string.Join(" => ", consumers);
    }
}
/// <summary>
/// IServiceProvider entry point: resolves <paramref name="serviceType"/> without creating unregistered
/// types and returns null, instead of throwing, when the service is unknown.
/// </summary>
public object GetService(Type serviceType)
{
    object resolved = GetInstance(serviceType, false /* false important */, false);
    if (resolved != null || parent == null)
        return resolved;
    return ProvideParent(serviceType);
}
// Asks the parent provider for the service; null when there is no parent or the parent has no such service.
object ProvideParent(Type serviceType)
    => parent?.GetService(serviceType);
// Implementation type -> compiled delegate that constructs and injects an instance (see CreateBuilder).
// Static, so the cache is shared by every Container instance.
static Dictionary<Type, Func<Container, object>> builders = new Dictionary<Type, Func<Container, object>>();
// Type -> compiled delegate that injects [Inject] fields into an existing object (see CreateInjector). Static as well.
static Dictionary<Type, Action<Container, object>> injectors = new Dictionary<Type, Action<Container, object>>();
// Reflected handle on the non-generic GetInstance(Type) overload; called by the emitted constructor-argument code.
static readonly MethodInfo GetInstanceMethod = typeof(Container).GetMethods()
.Single(x => x.Name == nameof(GetInstance) && x.GetGenericArguments().Length == 0 && x.GetParameters().Length == 1);
// Reflected handle on GetInstance(Type, bool, bool); called by the emitted field-injection code.
static readonly MethodInfo GetInstanceDetailedMethod = typeof(Container).GetMethods()
.Single(x => x.Name == nameof(GetInstance) && x.GetGenericArguments().Length == 0 && x.GetParameters().Length == 3);
/// <summary>
/// Compiles a delegate that builds <paramref name="implementationType"/>: it resolves every parameter of
/// the first public constructor from the container, invokes that constructor, then fills the [Inject] fields.
/// </summary>
/// <param name="implementationType">Concrete type to build.</param>
/// <returns>A delegate taking the resolving container and returning the new instance.</returns>
/// <exception cref="ArgumentException">The type has no public constructor, or lives in the same assembly as System.String.</exception>
static Func<Container, object> CreateBuilder(Type implementationType)
{
    ConstructorInfo ctor = implementationType.GetConstructors().FirstOrDefault();
    if (ctor == null)
        throw new ArgumentException($"Cannot instanciate type {implementationType}: it has no constructor");
    if (implementationType.Assembly == typeof(string).Assembly)
        throw new ArgumentException($"Cannot inject {implementationType}");

    var method = new DynamicMethod(
        implementationType.Name + "_Builder",
        typeof(object),
        new[] { typeof(Container) },
        implementationType,
        true);
    ILGenerator gen = method.GetILGenerator();
    MethodInfo typeFromHandle = typeof(Type).GetMethod(nameof(Type.GetTypeFromHandle));

    // Push one resolved service per constructor parameter: container.GetInstance(typeof(P)).
    foreach (ParameterInfo parameter in ctor.GetParameters())
    {
        gen.Emit(OpCodes.Ldarg_0);
        gen.Emit(OpCodes.Ldtoken, parameter.ParameterType);
        gen.Emit(OpCodes.Call, typeFromHandle);
        gen.Emit(OpCodes.Call, GetInstanceMethod);
    }

    // new XXX(args), leaving the instance on the stack for field injection and as the return value.
    gen.Emit(OpCodes.Newobj, ctor);
    _EmitInjectFields(implementationType, gen);
    gen.Emit(OpCodes.Ret);

    return (Func<Container, object>)method.CreateDelegate(typeof(Func<Container, object>));
}
/// <summary>
/// Emits IL that assigns every field marked with [Inject], declared on <paramref name="ctype"/> or any of
/// its base classes, from the container passed as argument 0 of the method being generated.
/// The emitted code expects the target object on top of the evaluation stack and leaves it there.
/// </summary>
/// <param name="ctype">Most derived type whose fields are injected.</param>
/// <param name="il">Generator of the dynamic method being built.</param>
private static void _EmitInjectFields(Type ctype, ILGenerator il)
{
    // DeclaredOnly matters: the loop below walks the base types itself. Without it, GetFields also returns
    // inherited public/protected/internal fields at every derived level, so those fields would be resolved
    // and assigned once per level (building several instances for non-singleton registrations).
    const BindingFlags declaredInstanceFields =
        BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly;
    MethodInfo typeFromHandle = typeof(Type).GetMethod(nameof(Type.GetTypeFromHandle));

    for (; ctype != null && ctype != typeof(Object); ctype = ctype.BaseType)
    {
        foreach (var field in ctype.GetFields(declaredInstanceFields))
        {
            var inj = field.GetCustomAttribute<Inject>();
            if (inj == null)
                continue;

            il.Emit(OpCodes.Dup); // obj
            // container.GetInstance(field.FieldType, create: inj.CreateIfNecessary, throwIfNotFound: !inj.Optional)
            il.Emit(OpCodes.Ldarg_0);
            il.Emit(OpCodes.Ldtoken, field.FieldType);
            il.Emit(OpCodes.Call, typeFromHandle);
            il.Emit(inj.CreateIfNecessary ? OpCodes.Ldc_I4_1 : OpCodes.Ldc_I4_0);
            il.Emit(inj.Optional ? OpCodes.Ldc_I4_0 : OpCodes.Ldc_I4_1);
            il.Emit(OpCodes.Call, GetInstanceDetailedMethod);
            // obj.field = val
            il.Emit(OpCodes.Stfld, field);
        }
    }
}
/// <summary>
/// Fills the [Inject] fields of an object that was not built by the container. The fields considered are
/// those of the static type <typeparamref name="T"/> and its base classes.
/// </summary>
/// <param name="obj">Object to inject into.</param>
public void Inject<T>(T obj) where T : class
{
    Action<Container, object> injector;
    // 'injectors' is static, hence shared by every container on every thread, and Dictionary does not
    // support concurrent writers: guard the lookup/insert. The injector itself runs outside the lock.
    lock (injectors)
        injector = Ext.GetOrAdd(injectors, typeof(T), CreateInjector);
    injector(this, obj);
}
/// <summary>
/// Compiles a delegate that assigns the [Inject] fields of an existing <paramref name="implementationType"/> object.
/// </summary>
/// <returns>A delegate taking the resolving container and the object to inject into.</returns>
/// <exception cref="ArgumentException">The type lives in the same assembly as System.String.</exception>
Action<Container, object> CreateInjector(Type implementationType)
{
    if (implementationType.Assembly == typeof(string).Assembly)
        throw new ArgumentException($"Cannot inject {implementationType}");

    var method = new DynamicMethod(
        implementationType.Name + "_Injector",
        null,
        new[] { typeof(Container), typeof(object) },
        implementationType,
        true);
    ILGenerator gen = method.GetILGenerator();

    // Target object on the stack, inject its fields, then discard it: the delegate returns nothing.
    gen.Emit(OpCodes.Ldarg_1);
    _EmitInjectFields(implementationType, gen);
    gen.Emit(OpCodes.Pop);
    gen.Emit(OpCodes.Ret);

    return (Action<Container, object>)method.CreateDelegate(typeof(Action<Container, object>));
}
/// <summary>
/// Builds a new <paramref name="implementationType"/> with constructor and field injection, using a
/// compiled builder cached per type, then notifies the instance if it implements IOnInject.
/// </summary>
/// <exception cref="DependencyMissingException">
/// A dependency could not be resolved; <paramref name="implementationType"/> is added to the reported path.
/// </exception>
private object CreateInstance(Type implementationType)
{
    object instance;
    try
    {
        Func<Container, object> builder;
        // 'builders' is static and therefore shared by every container, while callers only hold
        // per-instance locks (or none, via Create<T>). Dictionary does not support concurrent writers,
        // so guard the lookup/insert. The builder runs outside the lock because it resolves services.
        lock (builders)
        {
            if (!builders.TryGetValue(implementationType, out builder))
            {
                builder = CreateBuilder(implementationType);
                builders[implementationType] = builder;
            }
        }
        instance = builder(this);
    }
    catch (DependencyMissingException ex)
    {
        throw new DependencyMissingException(implementationType, ex);
    }

    // after object instanciation
    if (instance is IOnInject asInj)
        asInj.Injected();
    return instance;
}
}
/// <summary>Small helpers used by the container.</summary>
public static class Ext
{
    /// <summary>
    /// Returns the value stored under <paramref name="key"/>; when absent, computes it with
    /// <paramref name="getter"/>, stores it and returns it.
    /// </summary>
    public static TValue GetOrAdd<TKey,TValue>(IDictionary<TKey,TValue> dic, TKey key, Func<TKey,TValue> getter)
    {
        if (!dic.TryGetValue(key, out TValue found))
        {
            found = getter(key);
            dic[key] = found;
        }
        return found;
    }

    /// <summary>
    /// Strips enclosing Quote nodes and, when <paramref name="casts"/> is true, Convert nodes,
    /// returning the first operand that is neither.
    /// </summary>
    public static Expression Unwrap(Expression @this, bool casts=true)
    {
        Expression current = @this;
        for (;;)
        {
            bool isQuote = current.NodeType == ExpressionType.Quote;
            bool isStrippableCast = casts && current.NodeType == ExpressionType.Convert;
            if (!isQuote && !isStrippableCast)
                return current;
            current = ((UnaryExpression)current).Operand;
        }
    }
}
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="System.Memory" version="4.5.2" targetFramework="net45" />
<package id="ServiceStack.Text" version="5.5.0" targetFramework="net45" />
<package id="ServiceStack.Client" version="5.5.0" targetFramework="net45" />
<package id="ServiceStack.Interfaces" version="5.5.0" targetFramework="net45" />
</packages>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment