Last active
July 5, 2017 21:11
-
-
Save cykl/0daeb8c1f17c6eae5feacbebf50d4b8d to your computer and use it in GitHub Desktop.
Code sample to figure out how to check whether the return types of two methods are equal (including generics). See https://github.com/joel-costigliola/assertj-core/issues/1005
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cma.sandox; | |
import com.google.common.reflect.TypeResolver; | |
import org.assertj.core.api.WithAssertions; | |
import org.junit.Test; | |
import java.lang.reflect.Method; | |
import java.lang.reflect.ParameterizedType; | |
import java.lang.reflect.Type; | |
import java.util.Arrays; | |
import java.util.List; | |
public class TypeResolverTest implements WithAssertions { | |
static class Api { | |
public static <T> Asssert<T> m(List<? extends T> in) { | |
return null; | |
} | |
public static <T> Asssert<T> mSame(List<? extends T> in) { | |
return null; | |
} | |
public static <T> Asssert<? extends T> mExtends(List<? extends T> in) { | |
return null; | |
} | |
public static <T> Asssert<? super T> mSuper(List<? extends T> in) { | |
return null; | |
} | |
} | |
@Test | |
public void T_and_T_are_equals() { | |
Type m = resolveGenericReturnType(Api.class, "m"); | |
Type mSame = resolveGenericReturnType(Api.class, "mSame"); | |
assertThat(m).isEqualTo(mSame); | |
} | |
@Test | |
public void T_and_QUESTION_MARK_extends_T_are_not_equals() { | |
assertThat(resolveGenericReturnType(Api.class, "m")) | |
.isNotEqualTo(resolveGenericReturnType(Api.class, "mExtends")); | |
} | |
@Test | |
public void T_and_QUESTION_MARK_super_T_are_not_equals() { | |
assertThat(resolveGenericReturnType(Api.class, "m")) | |
.isNotEqualTo(resolveGenericReturnType(Api.class, "mSuper")); | |
} | |
private static Type resolveGenericReturnType(Class<?> cls, String methodName) { | |
Method method = Arrays.stream(cls.getMethods()) | |
.filter(m -> m.getName().equals(methodName)) | |
.findFirst() | |
.orElseThrow(() -> new RuntimeException("Method not found: class=" + cls + " name=" + methodName)); | |
return resolveGenericReturnType(method); | |
} | |
private static Type resolveGenericReturnType(Method method) { | |
Type genericReturnType = method.getGenericReturnType(); | |
if (genericReturnType instanceof ParameterizedType) { | |
ParameterizedType parameterizedType = (ParameterizedType) genericReturnType; | |
TypeResolver typeResolver = new TypeResolver(); | |
for (Type type : parameterizedType.getActualTypeArguments()) { | |
typeResolver = typeResolver.where(type, String.class); | |
} | |
return typeResolver.resolveType(parameterizedType); | |
} else { | |
return genericReturnType; | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cma.sandox; | |
import com.google.common.reflect.TypeResolver;
import org.assertj.core.api.AbstractListAssert;
import org.assertj.core.api.ObjectAssert;
import org.assertj.core.api.WithAssertions;
import org.junit.Test;
import java.io.InputStream;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.lang.reflect.WildcardType;
import java.util.*;
public class TypeResolverTest2 implements WithAssertions { | |
static class Asssert<T> { | |
} | |
static class Api { | |
public static <T> Asssert<T> m(List<? extends T> in) { | |
return null; | |
} | |
public static <T> Asssert<T> mSame(List<? extends T> in) { | |
return null; | |
} | |
public static <T> Asssert<? extends T> mExtends(List<? extends T> in) { | |
return null; | |
} | |
public static <T> Asssert<? super T> mSuper(List<? extends T> in) { | |
return null; | |
} | |
public static <T> AbstractListAssert<?, List<? extends T>, T, ObjectAssert<T>> complex1(List<? extends T> in) { | |
return null; | |
} | |
public static <T> AbstractListAssert<?, List<? extends T>, T, ObjectAssert<T>> complex2(List<? extends T> in) { | |
return null; | |
} | |
} | |
@Test | |
public void T_and_T_are_equals() { | |
Type m = resolveGenericReturnType(Api.class, "m"); | |
Type mSame = resolveGenericReturnType(Api.class, "mSame"); | |
assertThat(m).isEqualTo(mSame); | |
} | |
@Test | |
public void T_and_QUESTION_MARK_extends_T_are_not_equals() { | |
assertThat(resolveGenericReturnType(Api.class, "m")) | |
.isNotEqualTo(resolveGenericReturnType(Api.class, "mExtends")); | |
} | |
@Test | |
public void T_and_QUESTION_MARK_super_T_are_not_equals() { | |
assertThat(resolveGenericReturnType(Api.class, "m")) | |
.isNotEqualTo(resolveGenericReturnType(Api.class, "mSuper")); | |
} | |
@Test | |
public void complex() { | |
Type m = resolveGenericReturnType(Api.class, "complex1"); | |
Type mSame = resolveGenericReturnType(Api.class, "complex2"); | |
assertThat(m).isEqualTo(mSame); | |
} | |
private static Type resolveGenericReturnType(Class<?> cls, String methodName) { | |
Method method = Arrays.stream(cls.getMethods()) | |
.filter(m -> m.getName().equals(methodName)) | |
.findFirst() | |
.orElseThrow(() -> new RuntimeException("Method not found: class=" + cls + " name=" + methodName)); | |
return TypeCanonizer.canonize(method.getGenericReturnType()); | |
} | |
static class TypeCanonizer { | |
/** | |
* Returns a canonical form of {@code initialType} by replacing all {@link TypeVariable} by {@link Class} | |
* instances. | |
* | |
* <p> | |
* Such a canonical form allows to compare {@link ParameterizedType}s. | |
* </p> | |
*/ | |
public static Type canonize(Type initialType) { | |
if (!(initialType instanceof ParameterizedType)) { | |
return initialType; | |
} | |
ReplacementClassSupplier replacementClassSupplier = new ReplacementClassSupplier(); | |
TypeResolver typeResolver = new TypeResolver(); | |
for (TypeVariable typeVariable : findAllTypeVariables(initialType)) { | |
typeResolver = typeResolver.where(typeVariable, replacementClassSupplier.get()); | |
} | |
return typeResolver.resolveType(initialType); | |
} | |
/** Returns all {@code type}'s {@link TypeVariable} */ | |
private static Set<TypeVariable> findAllTypeVariables(Type type) { | |
Set<TypeVariable> typeVariables = new HashSet<>(); | |
findAllTypeVariables(typeVariables, type); | |
return typeVariables; | |
} | |
/** Adds all {@code type}'s {@link TypeVariable} to {@code typeVariables} */ | |
private static void findAllTypeVariables(Set<TypeVariable> typeVariables, Type type) { | |
if (type instanceof ParameterizedType) { | |
for (Type typeArgument : ((ParameterizedType) type).getActualTypeArguments()) { | |
if (typeArgument instanceof TypeVariable) { | |
typeVariables.add((TypeVariable) typeArgument); | |
} else if (typeArgument instanceof ParameterizedType) { | |
findAllTypeVariables(typeVariables, typeArgument); | |
} | |
} | |
} | |
} | |
} | |
static class ReplacementClassSupplier { | |
static List<Class> REPLACEMENT_TYPES = Arrays.asList( | |
String.class, Integer.class, Exception.class, InputStream.class, System.class); | |
private final Queue<Class> classPool; | |
ReplacementClassSupplier() { | |
this(REPLACEMENT_TYPES); | |
} | |
ReplacementClassSupplier(Collection<Class> classPool) { | |
this.classPool = new ArrayDeque<>(classPool); | |
} | |
/** | |
* Returns a class which has not yet been returned | |
* | |
* @throws IllegalStateException If the replacement class poll is exhausted | |
*/ | |
Class get() { | |
Class clazz = classPool.poll(); | |
if (clazz == null) { | |
throw new IllegalStateException(); | |
} | |
return clazz; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment