Skip to content

Instantly share code, notes, and snippets.

@davidwhitney
Last active September 6, 2019 06:15
Show Gist options
  • Save davidwhitney/9bcad527168eabd285c0287523c62a6e to your computer and use it in GitHub Desktop.
Save davidwhitney/9bcad527168eabd285c0287523c62a6e to your computer and use it in GitHub Desktop.
Container.java - a feature-rich DI container in 339 LOC

A tiny dependency-injection (DI) container in a single file, only 339 lines long.

Features:

  • Lifecycle management (Transient, Singleton)

  • Fluent bindings

    • By default can create any concrete type without any bindings
    • Can bind Type to self
    • Can bind Interface to implementation
    • Can bind Type or Interface to lambda function factory method
  • Supports multiple bindings

    • Can bind Foo to FooImpl1 and FooImpl2, then constructor inject Foo[] to get an instance of both

Uses non-invasive constructor injection only. Always attempts to use the constructor with the most parameters.

Requires Lombok for `var` on Java versions before 10, and for the `@Data`/`@Getter` annotations. The tests require JUnit (they are written against JUnit 4's `org.junit.Test` and `@Before`) and AssertJ.

I also wasted about 35 lines of this code on a good error message. Thank me later ;)

// Example of the fluent registration API. MyThing, SomeImplementation, Interface, Interface2,
// Interface3, Implementation3, SomeDependencyFromContainer and someFactoryMethod are
// placeholders for your own types and methods.
var somePreCreatedType = new MyThing();
var container = new Container()
.register(SomeImplementation.class)
// Register an already-built object; get(MyThing.class) hands back this same instance.
.register(somePreCreatedType)
// Factory bindings: the lambda receives the container, so it can resolve other dependencies.
.register(Interface.class, (x) -> someFactoryMethod(), Container.Lifecycle.Singleton)
.register(Interface2.class, (x) -> x.get(SomeDependencyFromContainer.class), Container.Lifecycle.Singleton)
// Interface-to-implementation binding; transient unless a lifecycle is given elsewhere.
.register(Interface3.class, Implementation3.class);
// Resolves through the Interface3 -> Implementation3 binding above.
var instance = container.get(Interface3.class);
package com.electricheadsoftware.tinycontainer;
import lombok.Data;
import lombok.Getter;
import lombok.var;
import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.function.Function;
/**
 * A small dependency-injection container that builds objects through constructor injection.
 *
 * <p>Supported bindings:
 * <ul>
 *   <li>none at all: any concrete type is built on demand as a transient;</li>
 *   <li>a type to itself ({@link #register(Class)}, {@link #register(Class, Lifecycle)});</li>
 *   <li>a type/interface to an implementation class ({@link #register(Class, Class)});</li>
 *   <li>a type/interface to a factory function ({@link #register(Class, Function, Lifecycle)});</li>
 *   <li>a pre-built instance ({@link #register(Object)}).</li>
 * </ul>
 *
 * <p>Several bindings may be registered for one type. Requesting {@code Foo} uses the most
 * recently registered binding. Requesting {@code Foo[]} returns one instance per binding, in
 * registration order.
 *
 * <p>Types are created with the declared constructor that has the most parameters, and every
 * parameter is itself resolved from this container. Circular dependencies are not detected, so
 * they recurse until the stack overflows. This class is not synchronized.
 */
public class Container {
    /** Registered bindings, keyed by the requested (source) type, in registration order. */
    private final HashMap<Class, List<Binding>> bindings = new HashMap<>();
    private final HashMap<Lifecycle, LifeCycleStrategy> lifeCycleStrategies = new HashMap<>();
    private final HashMap<ContainerBindingType, CreationStrategy> creationStrategies = new HashMap<>();

    public Container() {
        lifeCycleStrategies.put(Lifecycle.Transient, new FreshInstanceStrategy());
        lifeCycleStrategies.put(Lifecycle.Singleton, new SingletonActivationStrategy());
        creationStrategies.put(ContainerBindingType.Type, new CreateUsingLargestConstructor(this));
        creationStrategies.put(ContainerBindingType.Factory, new DelegateToBindingStrategy());
    }

    /**
     * Resolves an instance of {@code type}, or every bound instance when {@code type} is an array.
     *
     * @throws ContainerResolutionException if the type or one of its dependencies cannot be built
     */
    public <T> T get(Class<T> type) {
        return get(type, new CreationContext());
    }

    /**
     * Resolves {@code type}, recording each requested type in {@code context}. The recorded path
     * is what the {@link ContainerResolutionException} message reports on failure.
     */
    @SuppressWarnings("unchecked")
    public <T> T get(Class<T> type, CreationContext context) {
        return (T) getObject(unpackType(type), context, type.isArray());
    }

    /**
     * Builds one instance per selected binding of {@code type}.
     *
     * @param getAll when true, use every binding and return a typed array of the results;
     *               otherwise use only the last binding and return its instance
     */
    private Object getObject(Class type, final CreationContext context, boolean getAll) {
        List<Binding> selected = getBindings(type, getAll);
        List<Object> created = new ArrayList<>(selected.size());
        for (Binding binding : selected) {
            context.log(type, binding);
            LifeCycleStrategy lifeCycle = lifeCycleStrategies.get(binding.lifecycle);
            CreationStrategy creationStrategy = creationStrategies.get(binding.kind);
            try {
                created.add(lifeCycle.getOrCreateInstance(binding, creationStrategy, context));
            } catch (Exception e) {
                throw new ContainerResolutionException(e, context);
            }
        }
        return getAll ? returnTypedInstance(type, created) : created.get(created.size() - 1);
    }

    /** Registers a pre-built instance under its runtime class; it is returned as-is on request. */
    public Container register(Object instance) {
        Class type = instance.getClass();
        addBinding(type, new Binding(type, new InstanceProvider(instance), ContainerBindingType.Factory));
        return this;
    }

    /**
     * Binds a concrete type to itself with a transient lifecycle.
     *
     * <p>Without this overload, {@code register(Foo.class)} would resolve to
     * {@link #register(Object)} and register the {@code Class} object itself as an instance.
     */
    public Container register(Class toSelf) {
        return register(toSelf, Lifecycle.Transient);
    }

    /** Binds a concrete type to itself with the given lifecycle. */
    public Container register(Class toSelf, Lifecycle lifecycle) {
        addBinding(toSelf, new Binding(toSelf, toSelf, ContainerBindingType.Type, lifecycle));
        return this;
    }

    /** Binds {@code src} (typically an interface) to the implementation class {@code target}, as a transient. */
    public Container register(Class src, Class target) {
        addBinding(src, new Binding(src, target, ContainerBindingType.Type));
        return this;
    }

    /** Binds {@code src} to a transient factory function that receives this container. */
    public Container register(Class src, Function<Container, Object> provider) {
        return register(src, provider, Lifecycle.Transient);
    }

    /** Binds {@code src} to a factory function that receives this container, with the given lifecycle. */
    public Container register(Class src, Function<Container, Object> provider, Lifecycle lifecycle) {
        addBinding(src, new Binding(src, new DelegatingProvider(src, provider, this), ContainerBindingType.Factory, lifecycle));
        return this;
    }

    /**
     * Returns the bindings to use for {@code type}. If nothing is registered, this is a generated
     * transient self-binding. If {@code getAll}, it is a snapshot of every binding, so factories
     * that register new bindings mid-resolution cannot break iteration. Otherwise it is just the
     * last binding.
     */
    private List<Binding> getBindings(Class type, boolean getAll) {
        List<Binding> registered = bindings.get(type);
        if (registered == null) {
            List<Binding> generated = new ArrayList<>();
            generated.add(new Binding(type, type, ContainerBindingType.Type, Lifecycle.Transient));
            return generated;
        }
        if (getAll) {
            return new ArrayList<>(registered);
        }
        List<Binding> lastBinding = new ArrayList<>();
        lastBinding.add(registered.get(registered.size() - 1));
        return lastBinding;
    }

    private void addBinding(Class target, Binding binding) {
        bindings.computeIfAbsent(target, key -> new ArrayList<>()).add(binding);
    }

    /** For an array type returns its component type, otherwise {@code type} itself. */
    private Class unpackType(Class type) {
        return type.isArray() ? type.getComponentType() : type;
    }

    /** Copies {@code created} into a new array whose component type is {@code requestedType}. */
    private Object returnTypedInstance(Class requestedType, List<Object> created) {
        try {
            Object typedArrayInstance = Array.newInstance(requestedType, created.size());
            for (int i = 0; i < created.size(); i++) {
                Array.set(typedArrayInstance, i, created.get(i));
            }
            return typedArrayInstance;
        } catch (IllegalArgumentException e) {
            // Thrown when a binding produced an object not assignable to the requested element type.
            throw new IllegalStateException(
                    "Could not store resolved instances in an array of " + requestedType.getName(), e);
        }
    }

    /**
     * One registration. It maps a requested type ({@code src}) either to a class to construct
     * ({@code kind == Type}, {@code destination} is a {@link Class}) or to a
     * {@link CreationFactory} ({@code kind == Factory}).
     */
    @Data
    public class Binding {
        public ContainerBindingType kind;
        public Lifecycle lifecycle;
        public Class src;
        public Object destination;

        Binding(Class src, Object destination, ContainerBindingType kind) {
            this(src, destination, kind, Lifecycle.Transient);
        }

        Binding(Class src, Object destination, ContainerBindingType kind, Lifecycle lifecycle) {
            this.src = src;
            this.destination = destination;
            this.kind = kind;
            this.lifecycle = lifecycle;
        }
    }

    /** Records, in order, every (type, binding) pair visited while resolving one request. */
    public class CreationContext {
        @Getter
        private final ArrayList<CreationContextEntry> requestedTypes = new ArrayList<>();

        public void log(Class type, Binding binding) {
            requestedTypes.add(new CreationContextEntry(type, binding));
        }
    }

    /** A single step recorded in a {@link CreationContext}. */
    @Data
    public class CreationContextEntry {
        private Class type;
        private Binding binding;

        CreationContextEntry(Class type, Binding binding) {
            this.type = type;
            this.binding = binding;
        }
    }

    public enum ContainerBindingType {
        Type,
        Factory
    }

    /** Produces an instance for a factory-style binding. */
    public interface CreationFactory {
        Object getInstance(Binding binding);
    }

    /**
     * Calls a user-supplied factory function with the owning container. Anything the function
     * throws propagates, so {@code getObject} reports it as a ContainerResolutionException
     * instead of silently yielding {@code null}.
     */
    public class DelegatingProvider implements CreationFactory {
        private final Class creates;
        private final Function<Container, Object> provider;
        private final Container container;

        DelegatingProvider(Class src, Function<Container, Object> provider, Container container) {
            creates = src;
            this.provider = provider;
            this.container = container;
        }

        /** The requested type this factory was registered for. */
        Class getTargetClassType() {
            return creates;
        }

        @Override
        public Object getInstance(Binding binding) {
            return provider.apply(container);
        }
    }

    /** Always returns the same pre-built instance. */
    public class InstanceProvider implements CreationFactory {
        private final Object instance;

        InstanceProvider(Object instance) {
            this.instance = instance;
        }

        @Override
        public Object getInstance(Binding binding) {
            return instance;
        }
    }

    public enum Lifecycle {
        Transient,
        Singleton
    }

    /** Creates a fresh object for a binding. */
    public interface CreationStrategy {
        Object createInstance(Binding binding, CreationContext context) throws IllegalAccessException, InvocationTargetException, InstantiationException;
    }

    /** Decides whether to reuse an existing object or ask the {@link CreationStrategy} for a new one. */
    public interface LifeCycleStrategy {
        Object getOrCreateInstance(Binding binding, CreationStrategy creationStrategyForBinding, CreationContext context) throws IllegalAccessException, InstantiationException, InvocationTargetException;
    }

    /** Transient lifecycle: builds a new instance on every request. */
    public class FreshInstanceStrategy implements LifeCycleStrategy {
        @Override
        public Object getOrCreateInstance(Binding binding, CreationStrategy creationStrategyForBinding, CreationContext context) throws IllegalAccessException, InstantiationException, InvocationTargetException {
            return creationStrategyForBinding.createInstance(binding, context);
        }
    }

    /**
     * Singleton lifecycle: caches the first instance per key. The key is the factory's registered
     * type for factory bindings, and otherwise the binding's destination (the class to construct).
     */
    public class SingletonActivationStrategy implements LifeCycleStrategy {
        private final HashMap<Object, Object> instanceCache = new HashMap<>();

        @Override
        public Object getOrCreateInstance(Binding binding, CreationStrategy creationStrategyForBinding, CreationContext context) throws IllegalAccessException, InstantiationException, InvocationTargetException {
            boolean isFactory = binding.destination instanceof DelegatingProvider;
            Object key = isFactory
                    ? ((DelegatingProvider) binding.destination).getTargetClassType()
                    : binding.destination;
            Object cached = instanceCache.get(key);
            if (cached != null) {
                return cached;
            }
            Object instance = isFactory
                    ? ((CreationFactory) binding.destination).getInstance(binding)
                    : creationStrategyForBinding.createInstance(binding, context);
            // Store under the same key used for lookup. Keying by instance.getClass() never hit
            // when a factory's registered type (e.g. an interface) differed from its result's class.
            if (instance != null) {
                instanceCache.put(key, instance);
            }
            return instance;
        }
    }

    /** Creation strategy for factory bindings: delegates to the binding's {@link CreationFactory}. */
    public class DelegateToBindingStrategy implements CreationStrategy {
        @Override
        public Object createInstance(Binding binding, CreationContext context) {
            return ((CreationFactory) binding.destination).getInstance(binding);
        }
    }

    /**
     * Creation strategy for type bindings. It invokes the declared constructor with the most
     * parameters, resolving each argument from the parent container.
     */
    public class CreateUsingLargestConstructor implements CreationStrategy {
        private final Container parent;

        CreateUsingLargestConstructor(Container parent) {
            this.parent = parent;
        }

        @Override
        public Object createInstance(Binding binding, CreationContext context) throws IllegalAccessException, InvocationTargetException, InstantiationException {
            Class typeToCreate = (Class) binding.destination;
            Constructor ctor = findMostExhaustiveConstructor(typeToCreate);
            Class[] params = ctor.getParameterTypes();
            Object[] args = new Object[params.length];
            for (int i = 0; i < params.length; i++) {
                args[i] = parent.get(params[i], context);
            }
            return ctor.newInstance(args);
        }

        /**
         * Returns the declared constructor with the greatest parameter count, made accessible.
         * On ties, the first one returned by {@link Class#getDeclaredConstructors()} wins.
         *
         * @throws InstantiationException if the type declares no constructors (interfaces,
         *                                primitives, arrays)
         */
        private Constructor findMostExhaustiveConstructor(Class targetType) throws InstantiationException {
            Constructor selected = null;
            for (Constructor candidate : targetType.getDeclaredConstructors()) {
                if (selected == null || candidate.getParameterCount() > selected.getParameterCount()) {
                    selected = candidate;
                }
            }
            if (selected == null) {
                throw new InstantiationException("No constructors declared on " + targetType.getName()
                        + "; bind it to a concrete type or a factory.");
            }
            selected.setAccessible(true);
            return selected;
        }
    }
}
/**
 * Thrown when the container cannot build a requested type.
 *
 * <p>The message lists every type visited during the failed resolution, in request order: the
 * first entry is what the caller asked for, the last is the type whose creation failed. The
 * underlying failure is chained as this exception's cause, so its type, message and stack trace
 * are all preserved.
 */
class ContainerResolutionException extends RuntimeException {
    @Getter
    private final Exception inner;
    @Getter
    private final Container.CreationContext context;

    ContainerResolutionException(Exception inner, Container.CreationContext context) {
        super(generateMessage(context), inner);
        this.inner = inner;
        this.context = context;
    }

    /** Builds the human-readable diagnosis from the recorded resolution path. */
    private static String generateMessage(Container.CreationContext context) {
        List<Container.CreationContextEntry> requested = context.getRequestedTypes();
        if (requested.isEmpty()) {
            return "Failed to get requested instance: no types were recorded.";
        }
        String nl = System.lineSeparator();
        StringBuilder message = new StringBuilder()
                .append("Failed to get requested instance of '")
                .append(requested.get(0).getType().getSimpleName()).append("'.").append(nl)
                .append("Creating type '")
                .append(requested.get(requested.size() - 1).getType().getSimpleName()).append("' failed.").append(nl)
                .append("Types requested: ").append(nl);
        for (Container.CreationContextEntry entry : requested) {
            message.append("\t\tType: ").append(entry.getType()).append(nl)
                    .append("\t\tBinding: ").append(entry.getBinding()).append(nl)
                    .append(nl);
        }
        message.append("The last type in this list failed creating. ").append(nl)
                .append("Likely because one of it's dependencies could not be instantiated.").append(nl)
                .append("Steps to resolve: ").append(nl)
                .append("\t1- If ctor params are abstract or interfaces, ensure at least one binding exists.").append(nl)
                .append("\t2- If bindings are concrete types, ensure types are available to the classLoader.").append(nl)
                .append("\t3- Ensure there are no circular dependencies").append(nl)
                .append("\t4- Avoid injecting primitive types.");
        return message.toString();
    }
}
package com.electricheadsoftware.tinycontainer;
import lombok.Data;
import lombok.var;
import org.junit.Before;
import org.junit.Test;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/** Behavioural tests for {@link Container}. */
public class ContainerTest {
    private Container sut;

    @Before
    public void setUp() {
        sut = new Container();
    }

    @Test
    public void getNoBindingsCreatesType() {
        var instance = sut.get(NoDeps.class);
        assertThat(instance).isNotNull();
    }

    @Test
    public void getWithBindingsCreatesType() {
        sut.register(NoDeps.class);
        var instance = sut.get(NoDeps.class);
        assertThat(instance).isNotNull();
    }

    @Test
    public void getWithFactoryBindingsDelegates() {
        var called = new AtomicBoolean(false);
        sut.register(NoDeps.class, (x) -> {
            called.set(true);
            return new NoDeps();
        });
        // Must request the type the factory was registered for, or the factory is never invoked.
        sut.get(NoDeps.class);
        assertThat(called.get()).isTrue();
    }

    @Test
    public void getWithFactoryBindingsAsSingletonOnlyCalledOnce() {
        var called = new AtomicInteger();
        sut.register(NoDeps.class, (x) -> {
            called.getAndIncrement();
            return new NoDeps();
        }, Container.Lifecycle.Singleton);
        sut.get(NoDeps.class);
        sut.get(NoDeps.class);
        assertThat(called.get()).isEqualTo(1);
    }

    @Test
    public void getWithInstanceReturnsType() {
        var instance = new NoDeps();
        sut.register(instance);
        var instance2 = sut.get(NoDeps.class);
        assertThat(instance2).isSameAs(instance);
    }

    @Test
    public void getWithBindingsForSingletonCreatesTypeOnlyOnce() {
        sut.register(NoDeps.class, Container.Lifecycle.Singleton);
        var instance1 = sut.get(NoDeps.class);
        var instance2 = sut.get(NoDeps.class);
        assertThat(instance1).isSameAs(instance2);
    }

    @Test
    public void getWithBindingsForTransientCreatesTypeEachTime() {
        sut.register(NoDeps.class, Container.Lifecycle.Transient);
        var instance1 = sut.get(NoDeps.class);
        var instance2 = sut.get(NoDeps.class);
        assertThat(instance1).isNotSameAs(instance2);
    }

    @Test
    public void getCantResolveDependenciesErrorIsDescriptive() {
        assertThatThrownBy(() -> sut.get(SomeClass.class))
                .isInstanceOf(ContainerResolutionException.class)
                .hasMessageContaining("Failed to get requested instance of 'SomeClass'.")
                .hasMessageContaining("Creating type 'SomeAbstractDependency' failed.");
    }

    @Test
    public void getCanReturnAllBoundInstancesWhenRequestingArray() {
        sut.register(SomeType.class, (x) -> new SomeType("one"));
        sut.register(SomeType.class, (x) -> new SomeType("two"));
        var instance = sut.get(SomeType[].class);
        // One element per binding, in registration order.
        assertThat(instance).hasSize(2);
        assertThat(instance[0].getValue()).isEqualTo("one");
        assertThat(instance[1].getValue()).isEqualTo("two");
    }
}
/** Test fixture: a concrete type whose only constructor takes no arguments. */
class NoDeps {
    NoDeps() {
        // Intentionally empty; the container should be able to build this with no bindings.
    }
}
/**
 * Test fixture: the root of a dependency chain that ends in an unbound abstract type. Its single
 * constructor requires a {@link SomeDirectDependency}.
 */
class SomeClass {
    SomeClass(SomeDirectDependency direct) {
        // The argument only shapes the constructor signature; it is deliberately not stored.
    }
}
/**
 * Test fixture: the middle link of the chain. Its single constructor requires a
 * {@link SomeAbstractDependency}.
 */
class SomeDirectDependency {
    SomeDirectDependency(SomeAbstractDependency unresolvable) {
        // The argument only shapes the constructor signature; it is deliberately not stored.
    }
}
/** Test fixture: an abstract type, so reflective instantiation of it cannot succeed. */
abstract class SomeAbstractDependency {
    SomeAbstractDependency() {
        // Same access and signature as the implicit default constructor.
    }
}
/**
 * Test fixture: a value holder built by factory bindings. Lombok's {@code @Data} generates
 * {@code getValue()}, {@code equals}, {@code hashCode} and {@code toString}.
 */
@Data
class SomeType {
    private final String value;

    SomeType(final String initialValue) {
        value = initialValue;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment