Skip to content

Instantly share code, notes, and snippets.

@cykl
Last active July 5, 2017 21:11
Show Gist options
  • Save cykl/0daeb8c1f17c6eae5feacbebf50d4b8d to your computer and use it in GitHub Desktop.
Save cykl/0daeb8c1f17c6eae5feacbebf50d4b8d to your computer and use it in GitHub Desktop.
Code sample exploring how to check whether the return types of two methods are equal (including generics). See https://github.com/joel-costigliola/assertj-core/issues/1005
package cma.sandox;
import com.google.common.reflect.TypeResolver;
import org.assertj.core.api.WithAssertions;
import org.junit.Test;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.lang.reflect.WildcardType;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
/**
 * Compares the generic return types of methods by binding every type variable they mention to
 * {@link String} and comparing the resolved types with {@code equals}.
 *
 * <p>Because every variable is bound to the same class, this approach cannot distinguish
 * {@code Map<K, V>} from {@code Map<V, K>}, nor {@code Asssert<T>} from {@code Asssert<String>}.
 */
public class TypeResolverTest implements WithAssertions {

    /**
     * Minimal generic type used as the return type of the {@link Api} methods under comparison.
     *
     * <p>Declared here so this test compiles on its own instead of relying on some other class named
     * {@code Asssert} being visible from this package.
     */
    static class Asssert<T> {
    }

    static class Api {
        public static <T> Asssert<T> m(List<? extends T> in) {
            return null;
        }

        public static <T> Asssert<T> mSame(List<? extends T> in) {
            return null;
        }

        public static <T> Asssert<? extends T> mExtends(List<? extends T> in) {
            return null;
        }

        public static <T> Asssert<? super T> mSuper(List<? extends T> in) {
            return null;
        }
    }

    @Test
    public void T_and_T_are_equals() {
        Type m = resolveGenericReturnType(Api.class, "m");
        Type mSame = resolveGenericReturnType(Api.class, "mSame");
        assertThat(m).isEqualTo(mSame);
    }

    @Test
    public void T_and_QUESTION_MARK_extends_T_are_not_equals() {
        assertThat(resolveGenericReturnType(Api.class, "m"))
                .isNotEqualTo(resolveGenericReturnType(Api.class, "mExtends"));
    }

    @Test
    public void T_and_QUESTION_MARK_super_T_are_not_equals() {
        assertThat(resolveGenericReturnType(Api.class, "m"))
                .isNotEqualTo(resolveGenericReturnType(Api.class, "mSuper"));
    }

    /**
     * Looks up the public method (inherited ones included) named {@code methodName} on {@code cls}
     * and returns its resolved generic return type.
     *
     * @throws RuntimeException if {@code cls} has no public method with that name
     * @throws IllegalArgumentException if the name is overloaded: picking one overload would depend on
     *     the unspecified element order of {@link Class#getMethods()}
     */
    private static Type resolveGenericReturnType(Class<?> cls, String methodName) {
        Method[] candidates = Arrays.stream(cls.getMethods())
                .filter(m -> m.getName().equals(methodName))
                .toArray(Method[]::new);
        if (candidates.length == 0) {
            throw new RuntimeException("Method not found: class=" + cls + " name=" + methodName);
        }
        if (candidates.length > 1) {
            throw new IllegalArgumentException(
                    "Method name is overloaded: class=" + cls + " name=" + methodName);
        }
        return resolveGenericReturnType(candidates[0]);
    }

    /**
     * Returns {@code method}'s generic return type with every type variable it mentions replaced by
     * {@link String}. Variables are found at any depth: in type arguments, owner types, wildcard
     * bounds and array component types.
     *
     * <p>Only type variables are passed to {@link TypeResolver#where}. Binding a concrete or
     * parameterized type argument (e.g. {@code Integer} or {@code List<T>}) to {@code String} is
     * rejected by Guava with an {@link IllegalArgumentException}. Binding a wildcard silently binds
     * nothing, which would leave the variable inside {@code ? extends T} unresolved.
     */
    private static Type resolveGenericReturnType(Method method) {
        Type genericReturnType = method.getGenericReturnType();
        TypeResolver typeResolver = new TypeResolver();
        for (TypeVariable<?> typeVariable : collectTypeVariables(genericReturnType)) {
            typeResolver = typeResolver.where(typeVariable, String.class);
        }
        // A type mentioning no variable (e.g. a plain class) is returned unchanged by resolveType.
        return typeResolver.resolveType(genericReturnType);
    }

    /** Returns every type variable mentioned in {@code type}, in order of first appearance. */
    private static Set<TypeVariable<?>> collectTypeVariables(Type type) {
        Set<TypeVariable<?>> found = new LinkedHashSet<>();
        collectTypeVariables(type, found);
        return found;
    }

    /** Depth-first walk of {@code type} adding each mentioned type variable to {@code found}. */
    private static void collectTypeVariables(Type type, Set<TypeVariable<?>> found) {
        if (type instanceof TypeVariable) {
            // The variable is replaced as a whole, so its own bounds need not be visited.
            found.add((TypeVariable<?>) type);
        } else if (type instanceof ParameterizedType) {
            ParameterizedType parameterizedType = (ParameterizedType) type;
            Type ownerType = parameterizedType.getOwnerType();
            if (ownerType != null) {
                collectTypeVariables(ownerType, found);
            }
            for (Type argument : parameterizedType.getActualTypeArguments()) {
                collectTypeVariables(argument, found);
            }
        } else if (type instanceof WildcardType) {
            WildcardType wildcardType = (WildcardType) type;
            for (Type bound : wildcardType.getUpperBounds()) {
                collectTypeVariables(bound, found);
            }
            for (Type bound : wildcardType.getLowerBounds()) {
                collectTypeVariables(bound, found);
            }
        } else if (type instanceof GenericArrayType) {
            collectTypeVariables(((GenericArrayType) type).getGenericComponentType(), found);
        }
        // Plain classes mention no type variable: nothing to collect.
    }
}
package cma.sandox;
import com.google.common.reflect.TypeResolver;
import org.assertj.core.api.AbstractListAssert;
import org.assertj.core.api.ObjectAssert;
import org.assertj.core.api.WithAssertions;
import org.junit.Test;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.*;
/**
 * Compares the generic return types of methods by canonizing them: every type variable is replaced
 * by a placeholder class chosen by its position, and the canonical forms are compared with
 * {@code equals}.
 */
public class TypeResolverTest2 implements WithAssertions {

    static class Asssert<T> {
    }

    static class Api {
        public static <T> Asssert<T> m(List<? extends T> in) {
            return null;
        }

        public static <T> Asssert<T> mSame(List<? extends T> in) {
            return null;
        }

        public static <T> Asssert<? extends T> mExtends(List<? extends T> in) {
            return null;
        }

        public static <T> Asssert<? super T> mSuper(List<? extends T> in) {
            return null;
        }

        public static <T> AbstractListAssert<?, List<? extends T>, T, ObjectAssert<T>> complex1(List<? extends T> in) {
            return null;
        }

        public static <T> AbstractListAssert<?, List<? extends T>, T, ObjectAssert<T>> complex2(List<? extends T> in) {
            return null;
        }

        /** Same shape as {@link #mExtends}: the variable only occurs inside a wildcard bound. */
        public static <T> Asssert<? extends T> mExtendsSame(List<? extends T> in) {
            return null;
        }

        /** Concrete counterpart of {@link #m}; must not canonize equal to it. */
        public static Asssert<String> mString(List<? extends String> in) {
            return null;
        }

        public static <K, V> Map<K, V> twoVariables1() {
            return null;
        }

        public static <A, B> Map<A, B> twoVariables2() {
            return null;
        }
    }

    @Test
    public void T_and_T_are_equals() {
        Type m = resolveGenericReturnType(Api.class, "m");
        Type mSame = resolveGenericReturnType(Api.class, "mSame");
        assertThat(m).isEqualTo(mSame);
    }

    @Test
    public void T_and_QUESTION_MARK_extends_T_are_not_equals() {
        assertThat(resolveGenericReturnType(Api.class, "m"))
                .isNotEqualTo(resolveGenericReturnType(Api.class, "mExtends"));
    }

    @Test
    public void T_and_QUESTION_MARK_super_T_are_not_equals() {
        assertThat(resolveGenericReturnType(Api.class, "m"))
                .isNotEqualTo(resolveGenericReturnType(Api.class, "mSuper"));
    }

    @Test
    public void complex() {
        Type m = resolveGenericReturnType(Api.class, "complex1");
        Type mSame = resolveGenericReturnType(Api.class, "complex2");
        assertThat(m).isEqualTo(mSame);
    }

    @Test
    public void QUESTION_MARK_extends_T_and_QUESTION_MARK_extends_T_are_equals() {
        assertThat(resolveGenericReturnType(Api.class, "mExtends"))
                .isEqualTo(resolveGenericReturnType(Api.class, "mExtendsSame"));
    }

    @Test
    public void T_and_String_are_not_equals() {
        assertThat(resolveGenericReturnType(Api.class, "m"))
                .isNotEqualTo(resolveGenericReturnType(Api.class, "mString"));
    }

    @Test
    public void K_V_and_A_B_are_equals() {
        assertThat(resolveGenericReturnType(Api.class, "twoVariables1"))
                .isEqualTo(resolveGenericReturnType(Api.class, "twoVariables2"));
    }

    /**
     * Looks up the public method (inherited ones included) named {@code methodName} on {@code cls}
     * and returns the canonical form of its generic return type.
     *
     * @throws RuntimeException if {@code cls} has no public method with that name
     * @throws IllegalArgumentException if the name is overloaded: picking one overload would depend on
     *     the unspecified element order of {@link Class#getMethods()}
     */
    private static Type resolveGenericReturnType(Class<?> cls, String methodName) {
        Method[] candidates = Arrays.stream(cls.getMethods())
                .filter(m -> m.getName().equals(methodName))
                .toArray(Method[]::new);
        if (candidates.length == 0) {
            throw new RuntimeException("Method not found: class=" + cls + " name=" + methodName);
        }
        if (candidates.length > 1) {
            throw new IllegalArgumentException(
                    "Method name is overloaded: class=" + cls + " name=" + methodName);
        }
        return TypeCanonizer.canonize(candidates[0].getGenericReturnType());
    }

    static class TypeCanonizer {
        /**
         * Returns a canonical form of {@code initialType} by replacing every {@link TypeVariable} it
         * mentions with a distinct placeholder {@link Class} instance.
         *
         * <p>
         * Variables are numbered in order of first appearance (depth-first, left to right). Two types that
         * differ only in the identity of their type variables therefore get equal canonical forms, e.g.
         * {@code Asssert<T>} returned by two different generic methods. This allows
         * {@link ParameterizedType}s, wildcard types, generic array types and bare type variables to be
         * compared with {@code equals}.
         * </p>
         *
         * @return {@code initialType} itself when it is {@code null} or a plain {@link Class}
         * @throws IllegalStateException if {@code initialType} mentions more distinct type variables than
         *     there are placeholder classes
         */
        public static Type canonize(Type initialType) {
            if (initialType == null || initialType instanceof Class) {
                // Nothing to replace: a plain class (or primitive) mentions no type variable.
                return initialType;
            }
            ReplacementClassSupplier replacementClassSupplier = new ReplacementClassSupplier();
            TypeResolver typeResolver = new TypeResolver();
            for (TypeVariable<?> typeVariable : findAllTypeVariables(initialType)) {
                typeResolver = typeResolver.where(typeVariable, replacementClassSupplier.get());
            }
            return typeResolver.resolveType(initialType);
        }

        /** Returns all of {@code type}'s {@link TypeVariable}s, in order of first appearance. */
        private static Set<TypeVariable<?>> findAllTypeVariables(Type type) {
            // LinkedHashSet, not HashSet: iteration order decides which placeholder each variable receives,
            // so it must follow the structure of the type rather than the variables' hash codes.
            Set<TypeVariable<?>> typeVariables = new LinkedHashSet<>();
            findAllTypeVariables(typeVariables, type);
            return typeVariables;
        }

        /**
         * Adds all of {@code type}'s {@link TypeVariable}s to {@code typeVariables}, visiting owner types,
         * type arguments, wildcard bounds and array component types depth-first.
         */
        private static void findAllTypeVariables(Set<TypeVariable<?>> typeVariables, Type type) {
            if (type instanceof TypeVariable) {
                // The variable is replaced as a whole, so its own bounds need not be visited.
                typeVariables.add((TypeVariable<?>) type);
            } else if (type instanceof ParameterizedType) {
                ParameterizedType parameterizedType = (ParameterizedType) type;
                Type ownerType = parameterizedType.getOwnerType();
                if (ownerType != null) {
                    findAllTypeVariables(typeVariables, ownerType);
                }
                for (Type typeArgument : parameterizedType.getActualTypeArguments()) {
                    findAllTypeVariables(typeVariables, typeArgument);
                }
            } else if (type instanceof java.lang.reflect.WildcardType) {
                // e.g. the T in Asssert<? extends T>, which was previously missed.
                java.lang.reflect.WildcardType wildcardType = (java.lang.reflect.WildcardType) type;
                for (Type bound : wildcardType.getUpperBounds()) {
                    findAllTypeVariables(typeVariables, bound);
                }
                for (Type bound : wildcardType.getLowerBounds()) {
                    findAllTypeVariables(typeVariables, bound);
                }
            } else if (type instanceof java.lang.reflect.GenericArrayType) {
                findAllTypeVariables(typeVariables,
                        ((java.lang.reflect.GenericArrayType) type).getGenericComponentType());
            }
            // Plain classes mention no type variable: nothing to collect.
        }
    }

    static class ReplacementClassSupplier {
        // Dedicated placeholder classes. Unlike String, Integer, ... they never occur in a real signature,
        // so a replaced type variable cannot collide with a concrete type argument
        // (Asssert<T> would otherwise canonize equal to Asssert<String>).
        private static final class Var0 {
        }

        private static final class Var1 {
        }

        private static final class Var2 {
        }

        private static final class Var3 {
        }

        private static final class Var4 {
        }

        private static final class Var5 {
        }

        private static final class Var6 {
        }

        private static final class Var7 {
        }

        static final List<Class> REPLACEMENT_TYPES = Arrays.asList(
                Var0.class, Var1.class, Var2.class, Var3.class,
                Var4.class, Var5.class, Var6.class, Var7.class);

        private final Queue<Class> classPool;
        private final int poolSize;

        ReplacementClassSupplier() {
            this(REPLACEMENT_TYPES);
        }

        ReplacementClassSupplier(Collection<Class> classPool) {
            this.classPool = new ArrayDeque<>(classPool);
            this.poolSize = this.classPool.size();
        }

        /**
         * Returns a class which has not yet been returned, in pool order.
         *
         * @throws IllegalStateException If the replacement class pool is exhausted
         */
        Class get() {
            Class clazz = classPool.poll();
            if (clazz == null) {
                throw new IllegalStateException(
                        "Replacement class pool exhausted: all " + poolSize + " classes already used");
            }
            return clazz;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment