Created
September 12, 2020 05:52
-
-
Save vinaysshenoy/eba388d52c286528b7d8473188a5db8c to your computer and use it in GitHub Desktop.
Composite JUnit5 extension for composing a group of domain specific extensions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.junit.jupiter.api.extension.AfterAllCallback | |
import org.junit.jupiter.api.extension.AfterEachCallback | |
import org.junit.jupiter.api.extension.AfterTestExecutionCallback | |
import org.junit.jupiter.api.extension.BeforeAllCallback | |
import org.junit.jupiter.api.extension.BeforeEachCallback | |
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback | |
import org.junit.jupiter.api.extension.ConditionEvaluationResult | |
import org.junit.jupiter.api.extension.ExecutionCondition | |
import org.junit.jupiter.api.extension.Extension | |
import org.junit.jupiter.api.extension.ExtensionContext | |
import org.junit.jupiter.api.extension.InvocationInterceptor | |
import org.junit.jupiter.api.extension.LifecycleMethodExecutionExceptionHandler | |
import org.junit.jupiter.api.extension.ParameterContext | |
import org.junit.jupiter.api.extension.ParameterResolver | |
import org.junit.jupiter.api.extension.TestExecutionExceptionHandler | |
import org.junit.jupiter.api.extension.TestInstanceFactory | |
import org.junit.jupiter.api.extension.TestInstanceFactoryContext | |
import org.junit.jupiter.api.extension.TestInstancePostProcessor | |
import org.junit.jupiter.api.extension.TestInstancePreDestroyCallback | |
import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider | |
import org.junit.jupiter.api.extension.TestWatcher | |
import java.lang.reflect.Parameter | |
import java.util.Optional | |
/**
 * A custom JUnit 5 [Extension] that groups several other extensions together and forwards
 * every supported callback to the ones that implement it.
 *
 * This is useful for cases where you want certain extensions to be added to different
 * test classes, but want to avoid creating base classes and inheriting from them.
 *
 * Delegation rules implemented by this class:
 * - "before" callbacks, instance post-processors, pre-destroy callbacks and [TestWatcher]
 *   callbacks are forwarded in registration order.
 * - "after" callbacks are forwarded in reverse registration order.
 * - Exception handlers are chained: each handler receives whatever the previous one threw,
 *   and the chain stops as soon as a handler swallows the exception.
 * - At most one [TestInstanceFactory] may be registered.
 *
 * Extensions implementing [InvocationInterceptor] or [TestTemplateInvocationContextProvider]
 * are rejected with an [IllegalStateException].
 *
 * @param extensions the initial extensions to delegate to, in registration order.
 * @throws IllegalStateException if any of the supplied extensions is of an unsupported type.
 **/
class CompositeExtension(
    private var extensions: List<Extension> = emptyList()
) : ExecutionCondition,
    TestInstanceFactory,
    TestInstancePostProcessor,
    TestInstancePreDestroyCallback,
    ParameterResolver,
    TestWatcher,
    BeforeAllCallback,
    BeforeEachCallback,
    BeforeTestExecutionCallback,
    AfterTestExecutionCallback,
    AfterEachCallback,
    AfterAllCallback,
    TestExecutionExceptionHandler,
    LifecycleMethodExecutionExceptionHandler {

    // Must be declared BEFORE the init block: property initializers and init blocks run in
    // textual order, so validating first would read this property while it is still null.
    private val unsupportedExtensionTypes = setOf(
        InvocationInterceptor::class.java,
        TestTemplateInvocationContextProvider::class.java
    )

    // Remembers which delegate claimed a parameter in supportsParameter() so that
    // resolveParameter() can go straight to it. This is a plain mutable map with no
    // synchronization.
    private val parameterResolverLookup = mutableMapOf<Parameter, ParameterResolver>()

    init {
        validateRegisteredExtensions()
    }

    /**
     * Adds [extension] to the end of the delegate list.
     *
     * The extension is validated before it is added, so a rejected extension is never left
     * registered.
     *
     * @return this instance, to allow chained `register` calls.
     * @throws IllegalStateException if [extension] is of an unsupported type.
     */
    fun register(extension: Extension): CompositeExtension {
        validate(extension)
        extensions = extensions + extension
        return this
    }

    /** Validates every currently registered extension. */
    private fun validateRegisteredExtensions() {
        extensions.forEach(::validate)
    }

    /** Throws [IllegalStateException] if [extension] implements one of the unsupported types. */
    private fun validate(extension: Extension) {
        val unsupportedType = unsupportedExtensionTypes.firstOrNull { it.isAssignableFrom(extension::class.java) }
        if (unsupportedType != null) {
            throw IllegalStateException(
                "Extensions of type ${unsupportedType.name} are not supported yet!"
            )
        }
    }

    /**
     * Evaluates every registered [ExecutionCondition] and returns the first disabled result,
     * or the last result when none is disabled. When no condition is registered at all, the
     * container/test is reported as enabled instead of failing on an empty result list.
     */
    override fun evaluateExecutionCondition(context: ExtensionContext): ConditionEvaluationResult {
        val results = findExtensions<ExecutionCondition>().map { it.evaluateExecutionCondition(context) }
        return results.firstOrNull { it.isDisabled }
            ?: results.lastOrNull()
            ?: ConditionEvaluationResult.enabled("No ExecutionCondition extensions are registered")
    }

    /**
     * Creates the test instance through the single registered [TestInstanceFactory], or by
     * reflectively instantiating the test class when no factory is registered.
     *
     * @throws IllegalStateException if more than one [TestInstanceFactory] is registered.
     */
    override fun createTestInstance(
        factoryContext: TestInstanceFactoryContext,
        extensionContext: ExtensionContext
    ): Any {
        val instanceFactories = findExtensions<TestInstanceFactory>()
        return when (instanceFactories.size) {
            // NOTE(review): relies on the test class exposing a usable no-arg constructor - confirm for nested tests.
            0 -> factoryContext.testClass.newInstance()
            1 -> instanceFactories.single().createTestInstance(factoryContext, extensionContext)
            else -> {
                val foundFactories = instanceFactories.joinToString { it.javaClass.name }
                throw IllegalStateException(
                    "Only one `TestInstanceFactory` can be registered! Found: [$foundFactories]"
                )
            }
        }
    }

    /** Forwards to every registered [TestInstancePostProcessor] in registration order. */
    override fun postProcessTestInstance(testInstance: Any, context: ExtensionContext) {
        findExtensions<TestInstancePostProcessor>().forEach { it.postProcessTestInstance(testInstance, context) }
    }

    /** Forwards to every registered [TestInstancePreDestroyCallback] in registration order. */
    override fun preDestroyTestInstance(context: ExtensionContext) {
        findExtensions<TestInstancePreDestroyCallback>().forEach { it.preDestroyTestInstance(context) }
    }

    /**
     * Returns `true` when at least one registered [ParameterResolver] supports the parameter.
     * The first supporting resolver is cached so that [resolveParameter] can reuse it.
     */
    override fun supportsParameter(parameterContext: ParameterContext, extensionContext: ExtensionContext): Boolean {
        val resolver = findSupportingResolver(parameterContext, extensionContext) ?: return false
        parameterResolverLookup[parameterContext.parameter] = resolver
        return true
    }

    /**
     * Resolves the parameter with the resolver cached by [supportsParameter]. If nothing was
     * cached for this parameter, the registered resolvers are queried again before giving up.
     *
     * @throws NoSuchElementException if no registered resolver supports the parameter.
     */
    override fun resolveParameter(parameterContext: ParameterContext, extensionContext: ExtensionContext): Any {
        val resolver = parameterResolverLookup[parameterContext.parameter]
            ?: findSupportingResolver(parameterContext, extensionContext)
            ?: throw NoSuchElementException(
                "No registered ParameterResolver supports parameter [${parameterContext.parameter}]"
            )
        // NOTE(review): the declared return type is non-null; presumably a delegate that resolves
        // a parameter to null would fail here - confirm whether `Any?` is needed.
        return resolver.resolveParameter(parameterContext, extensionContext)
    }

    /** Returns the first registered [ParameterResolver] that supports the parameter, if any. */
    private fun findSupportingResolver(
        parameterContext: ParameterContext,
        extensionContext: ExtensionContext
    ): ParameterResolver? {
        return findExtensions<ParameterResolver>().firstOrNull {
            it.supportsParameter(parameterContext, extensionContext)
        }
    }

    override fun testDisabled(context: ExtensionContext, reason: Optional<String>) {
        findExtensions<TestWatcher>().forEach { it.testDisabled(context, reason) }
    }

    override fun testSuccessful(context: ExtensionContext) {
        findExtensions<TestWatcher>().forEach { it.testSuccessful(context) }
    }

    override fun testAborted(context: ExtensionContext, cause: Throwable?) {
        findExtensions<TestWatcher>().forEach { it.testAborted(context, cause) }
    }

    override fun testFailed(context: ExtensionContext?, cause: Throwable?) {
        findExtensions<TestWatcher>().forEach { it.testFailed(context, cause) }
    }

    override fun beforeAll(context: ExtensionContext) {
        findExtensions<BeforeAllCallback>().forEach { it.beforeAll(context) }
    }

    override fun beforeEach(context: ExtensionContext) {
        findExtensions<BeforeEachCallback>().forEach { it.beforeEach(context) }
    }

    override fun beforeTestExecution(context: ExtensionContext) {
        findExtensions<BeforeTestExecutionCallback>().forEach { it.beforeTestExecution(context) }
    }

    override fun afterTestExecution(context: ExtensionContext) {
        findExtensions<AfterTestExecutionCallback>()
            .reversed() // This is to retain JUnit 5's wrapping behaviour
            .forEach { it.afterTestExecution(context) }
    }

    override fun afterEach(context: ExtensionContext) {
        findExtensions<AfterEachCallback>()
            .reversed() // This is to retain JUnit 5's wrapping behaviour
            .forEach { it.afterEach(context) }
    }

    override fun afterAll(context: ExtensionContext) {
        findExtensions<AfterAllCallback>()
            .reversed() // This is to retain JUnit 5's wrapping behaviour
            .forEach { it.afterAll(context) }
    }

    override fun handleTestExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers<TestExecutionExceptionHandler>(throwable) { handler, current ->
            handler.handleTestExecutionException(context, current)
        }
    }

    override fun handleBeforeAllMethodExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers<LifecycleMethodExecutionExceptionHandler>(throwable) { handler, current ->
            handler.handleBeforeAllMethodExecutionException(context, current)
        }
    }

    override fun handleBeforeEachMethodExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers<LifecycleMethodExecutionExceptionHandler>(throwable) { handler, current ->
            handler.handleBeforeEachMethodExecutionException(context, current)
        }
    }

    override fun handleAfterEachMethodExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers<LifecycleMethodExecutionExceptionHandler>(throwable) { handler, current ->
            handler.handleAfterEachMethodExecutionException(context, current)
        }
    }

    override fun handleAfterAllMethodExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers<LifecycleMethodExecutionExceptionHandler>(throwable) { handler, current ->
            handler.handleAfterAllMethodExecutionException(context, current)
        }
    }

    /** Returns the registered extensions that are instances of [T], in registration order. */
    private inline fun <reified T> findExtensions(): List<T> {
        return extensions.filterIsInstance<T>()
    }

    /**
     * Runs [throwable] through every registered handler of type [H], in registration order.
     *
     * Each handler is given the throwable produced by the previous handler (or the original
     * one for the first handler). If a handler returns normally the throwable counts as
     * swallowed and the remaining handlers are skipped. If the last handler that ran threw,
     * that throwable is rethrown to the caller.
     *
     * @param throwable the throwable to hand to the first handler.
     * @param delegate invokes the relevant handler method with the current throwable.
     */
    private inline fun <reified H : Extension> chainExceptionHandlers(
        throwable: Throwable,
        delegate: (handler: H, current: Throwable) -> Unit
    ) {
        var pending: Throwable? = throwable
        for (handler in findExtensions<H>()) {
            // A null here means the previous handler swallowed the exception.
            val current = pending ?: break
            pending = tryAndReturnException { delegate(handler, current) }
        }
        pending?.let { throw it }
    }

    /*
     *
     * This is really awkward, but couldn't think of a better way to conform to JUnit 5's behaviour.
     *
     * Implementors must perform one of the following.
     *
     * 1. Swallow the supplied throwable, thereby preventing propagation.
     * 2. Rethrow the supplied throwable as is.
     * 3. Throw a new exception, potentially wrapping the supplied throwable.
     *
     * If the supplied throwable is swallowed, subsequent TestExecutionExceptionHandlers will not be invoked;
     * otherwise, the next registered TestExecutionExceptionHandler (if there is one) will be invoked with any
     * Throwable thrown by this handler.
     **/
    /** Runs [block] and returns whatever it threw, or `null` if it completed normally. */
    private inline fun tryAndReturnException(block: () -> Unit): Throwable? {
        return try {
            block()
            null
        } catch (e: Throwable) {
            e
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment