Skip to content

Instantly share code, notes, and snippets.

@renelink
Created March 18, 2022 07:48
Show Gist options
  • Save renelink/3e2a32dfad40428645368c1b76438f06 to your computer and use it in GitHub Desktop.
Save renelink/3e2a32dfad40428645368c1b76438f06 to your computer and use it in GitHub Desktop.
Inject Spring beans as method parameters in JUnit 5
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import java.lang.reflect.Parameter;
import java.util.Optional;
/**
 * JUnit 5 {@link ParameterResolver} that injects beans from the Spring test
 * {@link ApplicationContext} into test and lifecycle method parameters.
 *
 * <p>Annotate your Spring Boot test with
 * {@code @ExtendWith(SpringMethodParameterResolver.class)} and declare the beans you
 * need as method parameters:
 *
 * <pre>
 * &#64;BeforeEach
 * public void setup(EntityManager entityManager, &#64;Qualifier("test") DataSource ds) {
 *     // ...
 * }
 * </pre>
 *
 * or
 *
 * <pre>
 * &#64;Test
 * public void someTest(EntityManager entityManager) {
 *     // ...
 * }
 * </pre>
 *
 * <p>A parameter annotated with a non-empty {@link Qualifier} is resolved with
 * {@code getBean(qualifierValue, parameterType)}, i.e. the qualifier value is used as the
 * bean name. Every other parameter is resolved with {@code getBean(parameterType)}.
 * The context is obtained from {@link SpringExtension#getApplicationContext(ExtensionContext)},
 * so the test class presumably has to be set up as a Spring test as well
 * (e.g. {@code @SpringBootTest}) — TODO confirm for your setup.
 *
 * <p>An unqualified parameter with no matching bean is declined rather than rejected.
 * That way JUnit's other resolvers (for {@code TestInfo}, {@code TestReporter},
 * {@code @TempDir}, ...) can still supply it.
 */
public class SpringMethodParameterResolver implements ParameterResolver {

    /**
     * Reports whether a Spring bean can be supplied for the given parameter.
     *
     * <p>Returns {@code false} for an unqualified parameter that has no matching bean, so
     * other registered resolvers get a chance. Throwing here instead would fail every test
     * that declares, for example, a {@code TestInfo} parameter.
     *
     * @param parameterContext the parameter to resolve
     * @param extensionContext the current JUnit extension context
     * @return {@code true} if a bean was found for the parameter
     * @throws ParameterResolutionException if the parameter carries a {@link Qualifier} and
     *         no matching bean exists, or if more than one bean matches the parameter
     */
    @Override
    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
            throws ParameterResolutionException {
        try {
            // NOTE(review): this already obtains the bean, and resolveParameter obtains it
            // again. For singletons both calls return the same instance, but a
            // prototype-scoped bean would be created twice.
            return lookupBean(parameterContext, extensionContext) != null;
        } catch (NoSuchBeanDefinitionException e) {
            Parameter parameter = parameterContext.getParameter();
            if (e.getNumberOfBeansFound() > 1) {
                // Ambiguity is a configuration error worth surfacing, not a reason to decline.
                throw new ParameterResolutionException(
                        "More than one bean found for parameter " + parameter, e);
            }
            if (parameterContext.isAnnotated(Qualifier.class)) {
                // The test explicitly asked for a specific bean, so a miss is an error.
                throw new ParameterResolutionException(
                        "No bean definition found for parameter " + parameter, e);
            }
            // Not a bean: leave the parameter to other resolvers.
            return false;
        }
    }

    /**
     * Returns the Spring bean for the given parameter.
     *
     * <p>Spring lookup exceptions (e.g. {@link NoSuchBeanDefinitionException}) propagate
     * unchanged.
     *
     * @param parameterContext the parameter to resolve
     * @param extensionContext the current JUnit extension context
     * @return the bean matching the parameter's qualifier and/or type
     */
    @Override
    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
            throws ParameterResolutionException {
        return lookupBean(parameterContext, extensionContext);
    }

    /**
     * Looks up the bean for a parameter. The lookup uses the qualifier value plus the type
     * when the parameter carries a non-empty {@link Qualifier}, and the type alone otherwise.
     *
     * @throws NoSuchBeanDefinitionException if no matching bean, or more than one, is found
     *         (other Spring {@code BeansException} subtypes may also propagate)
     */
    private Object lookupBean(ParameterContext parameterContext, ExtensionContext extensionContext) {
        Parameter parameter = parameterContext.getParameter();
        Class<?> type = parameter.getType();
        // Use ParameterContext#findAnnotation rather than Parameter#getAnnotation, which JUnit
        // documents as unreliable for some constructor parameters on older JDKs. An empty
        // value (a bare @Qualifier) is not a usable bean name, so fall back to a by-type lookup.
        Optional<String> beanName = parameterContext.findAnnotation(Qualifier.class)
                .map(Qualifier::value)
                .filter(name -> !name.isEmpty());
        ApplicationContext applicationContext = SpringExtension.getApplicationContext(extensionContext);
        Optional<Object> bean = beanName.map(name -> applicationContext.getBean(name, type));
        return bean.orElseGet(() -> applicationContext.getBean(type));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment