Created
May 13, 2022 14:26
-
-
Save ledoyen/72790168f80cb21e7491998b9a506b67 to your computer and use it in GitHub Desktop.
Combinatorial arguments provider for JUnit 5 parameterized tests (junit-jupiter-params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.github.ledoyen; | |
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
@Target(ElementType.METHOD) | |
@Retention(RetentionPolicy.RUNTIME) | |
@Documented | |
@ArgumentsSource(AllCombinations.AllCombinationsProvider.class) | |
public @interface AllCombinations { | |
class AllCombinationsProvider implements ArgumentsProvider { | |
private static final List<TypeSupport> supportedTypes = List.of( | |
new TypeSupport(Class::isEnum, c -> Arrays.stream(c.getEnumConstants())), | |
new TypeSupport(Boolean.class::equals, c -> Stream.of(false, true)), | |
new TypeSupport(boolean.class::equals, c -> Stream.of(false, true)) | |
); | |
@Override | |
public Stream<? extends Arguments> provideArguments(ExtensionContext extensionContext) throws Exception { | |
Class<?>[] parameterTypes = extensionContext.getRequiredTestMethod().getParameterTypes(); | |
validate(parameterTypes); | |
return generateCombinations(Arrays.asList(parameterTypes)).stream(); | |
} | |
private void validate(Class<?>[] parameterTypes) { | |
Function<Class<?>, Boolean> isSupported = supportedTypes.stream().map(ts -> ts.isSupportedFunction) | |
.reduce(c -> false, (f1, f2) -> (c -> f1.apply(c) || f2.apply(c))); | |
Set<Class<?>> unsupportedParameterTypes = Arrays.stream(parameterTypes) | |
.filter(p -> !isSupported.apply(p)) | |
.collect(Collectors.toSet()); | |
assertThat(unsupportedParameterTypes).as("unsupported parameter types").isEmpty(); | |
} | |
private List<Arguments> generateCombinations(List<Class<?>> parameters) { | |
List<Arguments> combinationsAccumulator = new ArrayList<>(); | |
generateCombinationsRecursively(parameters, combinationsAccumulator, new Object[parameters.size()]); | |
return combinationsAccumulator; | |
} | |
private void generateCombinationsRecursively(List<Class<?>> parameters, List<Arguments> accumulator, Object data[]) { | |
if (parameters.isEmpty()) { | |
Object[] combination = data.clone(); | |
accumulator.add(Arguments.arguments(combination)); | |
} else { | |
Class<?> currentParameter = parameters.get(0); | |
Stream<Object> values = supportedTypes.stream() | |
.filter(st -> st.isSupportedFunction.apply(currentParameter)) | |
.findFirst() | |
.orElseThrow(IllegalStateException::new) | |
.valueGenerator.apply(currentParameter); | |
values.forEach(v -> { | |
data[data.length - parameters.size()] = v; | |
generateCombinationsRecursively(parameters.subList(1, parameters.size()), accumulator, data); | |
}); | |
} | |
} | |
private record TypeSupport(Function<Class<?>, Boolean> isSupportedFunction, | |
Function<Class<?>, Stream<Object>> valueGenerator) { | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment