Skip to content

Instantly share code, notes, and snippets.

@pabloogc
Created August 23, 2016 08:17
Show Gist options
  • Save pabloogc/a1a39944eb8e2006834ba4027794c0b8 to your computer and use it in GitHub Desktop.
Save pabloogc/a1a39944eb8e2006834ba4027794c0b8 to your computer and use it in GitHub Desktop.
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a field of a test class whose value should stand in for a provision of a Dagger
 * component while a {@code DaggerOverrideRule} is active.
 *
 * <p>The annotation is retained at runtime because the rule discovers annotated fields
 * through reflection.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface DaggerOverride {

    /**
     * The component class whose provision is overridden. The rule matches it against the
     * types of the components registered on the rule.
     */
    Class<?> value();

    /**
     * Name of the provider field to replace inside the component. When left empty (the
     * default) the rule looks the field up by the type of the annotated test field instead.
     */
    String fieldName() default "";
}
import org.junit.Rule;
import org.junit.rules.MethodRule;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.Statement;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import javax.inject.Provider;
/**
 * JUnit {@link MethodRule} that temporarily swaps {@link Provider} fields of a Dagger
 * generated component for providers returning the values held by the test's
 * {@link DaggerOverride}-annotated fields.
 *
 * <p>Register the component(s) to patch with {@link #component(Object)} or
 * {@link #component(Provider)}. When used as a {@link Rule} the overrides are installed
 * before each test and the original providers are put back afterwards, even if the test fails.
 * The rule can also be driven by hand through {@link #apply(Object)}, {@link #before()} and
 * {@link #after()}.
 *
 * <p>Everything is done through reflection on private fields of the component instance, so
 * the rule depends on the layout of the generated code.
 */
public class DaggerOverrideRule implements MethodRule {

    private Object testClassInstance;
    private final List<TypedProvider> componentProviders = new ArrayList<>();
    private final List<OverrideField> overrideFields = new ArrayList<>();

    /**
     * Registers an already-built component. The instance must exist when the rule is
     * configured (typically at test class creation, e.g. a singleton); otherwise use the lazy
     * alternative {@link #component(Provider)}.
     *
     * @param component the component instance whose provider fields may be overridden
     * @return this rule, for chaining
     */
    public <T> DaggerOverrideRule component(final T component) {
        final Provider<T> provider = new Provider<T>() {
            @Override
            public T get() {
                return component;
            }
        };
        componentProviders.add(new TypedProvider(component.getClass(), provider));
        return this;
    }

    /**
     * Registers a lazily resolved component, required when the component is only available
     * once other rules (Espresso, Mockito, ...) have run, unless the rules are chained in the
     * correct order.
     *
     * <p>The component type is taken from the return type of the {@code get()} method declared
     * by the provider's own class, so the provider should be a class that declares
     * {@code get()} with the concrete component type.
     * NOTE(review): a provider whose {@code get()} is only declared as returning
     * {@code Object} (presumably the case for lambdas) will not match any
     * {@link DaggerOverride#value()} other than {@code Object} - confirm before using lambdas.
     *
     * @param component provider invoked in {@link #before()} to obtain the component
     * @return this rule, for chaining
     * @throws RuntimeException if the provider's class does not itself declare {@code get()}
     */
    public <T> DaggerOverrideRule component(Provider<T> component) {
        final Class<?> returnType;
        try {
            returnType = component.getClass().getDeclaredMethod("get").getReturnType();
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
        componentProviders.add(new TypedProvider(returnType, component));
        return this;
    }

    /**
     * Rule implementation, use {@link #apply(Object)} if using it without {@link Rule}.
     */
    @Override
    public Statement apply(final Statement base, FrameworkMethod method, Object target) {
        apply(target);
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                try {
                    before();
                    base.evaluate();
                } finally {
                    // Always put the original providers back, otherwise a failing test (or a
                    // partially failed before()) leaks its overrides into the component.
                    after();
                }
            }
        };
    }

    /**
     * Finds all the {@link DaggerOverride} fields in the test class (testClassInstance).
     * You only need to do this once per test class; calling it again replaces the previously
     * discovered fields instead of accumulating duplicates.
     *
     * <p>After applying the rule call {@link #before()} and {@link #after()} to setup/cleanup
     * dagger mocks.
     *
     * @param testClassInstance the test object holding the annotated fields
     * @throws RuntimeException if an annotated field cannot be matched to a component field or
     *                          to a registered component
     */
    public void apply(Object testClassInstance) {
        this.testClassInstance = testClassInstance;
        // Start from scratch: duplicated entries would record the first override as the
        // "original" value of the second one and restore it in after().
        overrideFields.clear();
        processDaggerOverrideFields(testClassInstance.getClass());
    }

    /**
     * Installs the overrides: for every discovered field, remembers the provider currently
     * stored in the component and replaces it with one returning the test field's value.
     *
     * @throws Exception if the component field cannot be found or accessed reflectively
     */
    public void before() throws Exception {
        for (final OverrideField overrideField : overrideFields) {
            if (overrideField.patchedField != null) {
                // Already installed: only refresh the value, the installed provider reads it
                // on every get(). Re-reading the "original" here would capture our own provider.
                overrideField.overrideValue = overrideField.testClassField.get(testClassInstance);
                continue;
            }
            final Object componentInstance = overrideField.componentProvider.get();
            final Field field = componentInstance
                    .getClass()
                    .getDeclaredField(overrideField.name);
            field.setAccessible(true);
            overrideField.overrideValue = overrideField.testClassField.get(testClassInstance);
            overrideField.originalValue = field.get(componentInstance);
            field.set(componentInstance, new Provider<Object>() {
                @Override
                public Object get() {
                    return overrideField.overrideValue;
                }
            });
            // Remember exactly what was patched so after() restores the same instance.
            overrideField.patchedInstance = componentInstance;
            overrideField.patchedField = field;
        }
    }

    /**
     * Restores the providers replaced by {@link #before()}. Entries that were never installed
     * are skipped, and restoration happens in reverse order so that several overrides of the
     * same component field unwind back to the true original.
     *
     * @throws Exception if a component field cannot be written reflectively
     */
    public void after() throws Exception {
        for (int i = overrideFields.size() - 1; i >= 0; i--) {
            final OverrideField overrideField = overrideFields.get(i);
            if (overrideField.patchedField == null) {
                continue;
            }
            overrideField.patchedField.set(
                    overrideField.patchedInstance, overrideField.originalValue);
            overrideField.patchedField = null;
            overrideField.patchedInstance = null;
            overrideField.originalValue = null;
        }
    }

    /**
     * Collects an {@link OverrideField} for every {@link DaggerOverride} field declared directly
     * by {@code testClass} (inherited fields are not scanned).
     */
    private void processDaggerOverrideFields(Class<?> testClass) {
        for (Field field : testClass.getDeclaredFields()) {
            final DaggerOverride annotation = field.getAnnotation(DaggerOverride.class);
            if (annotation == null) {
                continue;
            }
            field.setAccessible(true);
            final String fieldName = resolveComponentFieldName(annotation, field);
            final Provider<?> provider = findComponentProvider(annotation.value());
            if (provider == null) {
                throw new IllegalStateException("No provider found for: " + annotation.value());
            }
            overrideFields.add(new OverrideField(field, provider, fieldName));
        }
    }

    /**
     * Resolves the name of the component field to patch: by type when the annotation gives no
     * name, otherwise the explicit name after checking that the annotated class declares it.
     */
    private static String resolveComponentFieldName(DaggerOverride annotation, Field testField) {
        if (annotation.fieldName().isEmpty()) {
            final List<Field> provisionFields =
                    findProvisionFieldsInComponent(annotation.value(), testField.getType());
            if (provisionFields.size() != 1) {
                throw new IllegalArgumentException(
                        "Can't match by type, provision methods found: "
                                + provisionFields.size()
                                + " expected exactly 1");
            }
            return provisionFields.get(0).getName();
        }
        final String fieldName = annotation.fieldName();
        try {
            // Fail early if the class named by the annotation does not declare the field.
            annotation.value().getDeclaredField(fieldName);
        } catch (NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
        return fieldName;
    }

    /** Returns the first registered provider whose type is assignable to the given class, or null. */
    private Provider<?> findComponentProvider(Class<?> componentClass) {
        for (TypedProvider componentProvider : componentProviders) {
            if (componentClass.isAssignableFrom(componentProvider.providerType)) {
                return componentProvider.provider;
            }
        }
        return null;
    }

    /** A component provider together with the type it was registered under. */
    private static final class TypedProvider {
        final Class<?> providerType;
        final Provider<?> provider;

        private TypedProvider(Class<?> providerType, Provider<?> provider) {
            this.providerType = providerType;
            this.provider = provider;
        }
    }

    /** Book-keeping for one annotated test field and the component field it overrides. */
    private static final class OverrideField {
        private final Field testClassField;
        private final Provider<?> componentProvider;
        private final String name;
        private Object overrideValue;
        private Object originalValue;
        // Set while the override is installed; null otherwise.
        private Object patchedInstance;
        private Field patchedField;

        private OverrideField(Field testClassField, Provider<?> componentProvider, String name) {
            this.testClassField = testClassField;
            this.componentProvider = componentProvider;
            this.name = name;
        }
    }

    /**
     * Lists the {@code Provider<type>} fields declared by the Dagger implementation generated
     * for {@code componentClass}. Only an exact match of the provider's type argument counts,
     * so a parameterized provision (e.g. {@code Provider<List<String>>}) never matches a
     * {@code Class}.
     *
     * @throws RuntimeException if the generated implementation class cannot be loaded
     */
    private static List<Field> findProvisionFieldsInComponent(Class<?> componentClass, Class<?> type) {
        final Class<?> daggerClass;
        try {
            //Find the dagger implementation
            daggerClass = Class.forName(daggerImplementationName(componentClass));
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        final List<Field> out = new ArrayList<>();
        for (Field componentField : daggerClass.getDeclaredFields()) {
            final Type componentFieldType = componentField.getGenericType();
            if (!(componentFieldType instanceof ParameterizedType)) {
                continue;
            }
            final ParameterizedType parameterizedType = (ParameterizedType) componentFieldType;
            if (parameterizedType.getRawType().equals(Provider.class)
                    && parameterizedType.getActualTypeArguments()[0].equals(type)) {
                //Match!
                out.add(componentField);
            }
        }
        return out;
    }

    /**
     * Builds the name of the generated implementation: same package, {@code Dagger} prefix,
     * and for nested components the enclosing type names joined with underscores
     * ({@code pkg.Outer$Inner} becomes {@code pkg.DaggerOuter_Inner}).
     */
    private static String daggerImplementationName(Class<?> componentClass) {
        final String binaryName = componentClass.getName();
        final int lastDot = binaryName.lastIndexOf('.');
        // lastDot is -1 for the default package, which yields an empty prefix.
        final String packagePrefix = binaryName.substring(0, lastDot + 1);
        final String nestedName = binaryName.substring(lastDot + 1).replace('$', '_');
        return packagePrefix + "Dagger" + nestedName;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment