Skip to content

Instantly share code, notes, and snippets.

@jshiell
Created November 27, 2018 15:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jshiell/7ffb24c95440c0fd67d58641824d2473 to your computer and use it in GitHub Desktop.
Save jshiell/7ffb24c95440c0fd67d58641824d2473 to your computer and use it in GitHub Desktop.
Retry extension for JUnit 5
package com.springer.oscar.test
import com.springer.oscar.test.RetryExtension.Companion.TEST_PASSED
import org.junit.AssumptionViolatedException
import org.junit.jupiter.api.TestTemplate
import org.junit.jupiter.api.extension.*
import org.junit.platform.commons.util.AnnotationUtils.findAnnotation
import org.junit.platform.commons.util.AnnotationUtils.isAnnotated
import org.opentest4j.TestAbortedException
import org.opentest4j.TestSkippedException
import java.util.*
import java.util.Spliterators.spliteratorUnknownSize
import java.util.stream.Stream
import java.util.stream.StreamSupport.stream
/**
 * JUnit 5 [TestTemplateInvocationContextProvider] behind [RetryTest].
 *
 * It hands out one invocation per attempt of an annotated test method. It stops once an
 * attempt has been recorded as passed or [RetryTest.maxAttempts] attempts have been made.
 */
class RetryExtension : TestTemplateInvocationContextProvider {

    /**
     * Builds the stream of invocation contexts for a [RetryTest] method.
     *
     * The stream is backed by a [RetryTestTemplateIterator], whose `hasNext()` consults the
     * method's extension store. Each attempt is therefore only produced after the previous
     * outcome has been recorded.
     *
     * @throws IllegalStateException if the method is not annotated with [RetryTest].
     * @throws IllegalArgumentException if [RetryTest.maxAttempts] is less than 1; otherwise
     *   the provider would yield no invocations at all.
     */
    override fun provideTestTemplateInvocationContexts(context: ExtensionContext): Stream<TestTemplateInvocationContext> {
        val annotation = findAnnotation(context.requiredTestMethod, RetryTest::class.java)
            .orElseThrow { IllegalStateException("The test method must be annotated with @RetryTest") }
        val maxAttempts = annotation.maxAttempts
        // Fail fast with a clear message instead of JUnit's generic "no invocation context" error.
        require(maxAttempts >= 1) {
            "@RetryTest.maxAttempts must be at least 1 but was $maxAttempts on ${context.requiredTestMethod}"
        }
        val iterator = RetryTestTemplateIterator(context.displayName, maxAttempts, context.getStore(storeNamespace(context)))
        val spliterator: Spliterator<TestTemplateInvocationContext> = spliteratorUnknownSize(iterator, Spliterator.NONNULL)
        // Non-parallel stream: attempts are drawn from the iterator one at a time.
        return stream(spliterator, false)
    }

    /** Only methods carrying [RetryTest] are handled by this provider. */
    override fun supportsTestTemplate(context: ExtensionContext): Boolean =
        isAnnotated(context.testMethod, RetryTest::class.java)

    companion object {
        /** Extension-store key under which each attempt records whether it passed. */
        internal const val TEST_PASSED: String = "testPassed"
    }
}
/**
 * Returns the extension-store namespace used to share retry state between attempts.
 *
 * The namespace is keyed by the test class name and test method name.
 */
internal fun storeNamespace(context: ExtensionContext): ExtensionContext.Namespace {
    val className = context.requiredTestClass.name
    val methodName = context.requiredTestMethod.name
    return ExtensionContext.Namespace.create(className, methodName)
}
/**
 * Produces one [RetryInvocationContext] per attempt.
 *
 * Iteration stops once the value stored under [TEST_PASSED] in [store] is `true`, or once
 * [maxAttempts] contexts have been handed out. Before each attempt, [next] blocks the calling
 * thread for a linearly growing back-off: attempt `n` (1-based) waits
 * `(n - 1) * baseDelayMillis` milliseconds. The first attempt does not wait.
 *
 * @param displayName base display name of the test template.
 * @param maxAttempts total number of attempts allowed, including the first.
 * @param store extension store shared with the per-attempt callbacks.
 * @param baseDelayMillis back-off step in milliseconds; defaults to [BASE_DELAY]. A value <= 0
 *   disables sleeping.
 */
internal class RetryTestTemplateIterator(private val displayName: String,
                                         private val maxAttempts: Int,
                                         private val store: ExtensionContext.Store,
                                         private val baseDelayMillis: Long = BASE_DELAY) : Iterator<TestTemplateInvocationContext> {

    /** Number of attempts handed out so far. */
    private var currentAttempt = 0

    override fun hasNext(): Boolean = currentAttempt < maxAttempts && store.get(TEST_PASSED) != true

    override fun next(): TestTemplateInvocationContext {
        if (!hasNext()) throw NoSuchElementException()
        val delay = baseDelayMillis * currentAttempt
        if (delay > 0) {
            // Sleep for exactly the delay that is reported in the display name.
            Thread.sleep(delay)
        }
        currentAttempt += 1
        return RetryInvocationContext(displayName, currentAttempt, maxAttempts, delay, store)
    }

    companion object {
        /** Default back-off step between attempts, in milliseconds. */
        private const val BASE_DELAY = 1000L
    }
}
/**
 * Invocation context for a single retry attempt.
 *
 * It labels the attempt in the test report and registers the two per-attempt extensions:
 * one records the outcome in [store], the other decides whether a failure is final.
 */
internal class RetryInvocationContext(private val displayName: String,
                                      private val currentAttempt: Int,
                                      private val maxAttempts: Int,
                                      private val delay: Long,
                                      private val store: ExtensionContext.Store) : TestTemplateInvocationContext {

    override fun getDisplayName(invocationIndex: Int): String {
        // The delay is only mentioned when this attempt actually waited.
        val suffix = if (delay <= 0) ")" else ", delayed ${delay}ms)"
        return "$displayName (attempt $currentAttempt/$maxAttempts$suffix"
    }

    override fun getAdditionalExtensions(): List<Extension> {
        val outcomeRecorder = RetryAfterTestExecutionCallback(store)
        val failureHandler = CheckException(currentAttempt, maxAttempts)
        return listOf(outcomeRecorder, failureHandler)
    }
}
/**
 * Decides how a failure of one attempt is reported.
 *
 * Any attempt before the last is reported as aborted rather than failed. The last attempt's
 * throwable is rethrown unchanged, so it fails the test.
 *
 * @param currentAttempt 1-based number of the attempt this handler belongs to.
 * @param maxAttempts total number of attempts allowed.
 */
internal class CheckException(private val currentAttempt: Int,
                              private val maxAttempts: Int) : TestExecutionExceptionHandler {

    override fun handleTestExecutionException(context: ExtensionContext, throwable: Throwable) {
        if (currentAttempt >= maxAttempts) {
            throw throwable
        }
        // Keep the original failure as the cause; otherwise the reason an intermediate
        // attempt failed disappears from the test report.
        throw TestAbortedException("Test attempt failed (attempt $currentAttempt/$maxAttempts)", throwable)
    }
}
/**
 * Records in [store], under [TEST_PASSED], whether the attempt that just ran passed.
 *
 * An attempt counts as passed when it raised no exception. It also counts as passed when the
 * exception's class is exactly [AssumptionViolatedException]. Any other exception means it
 * did not pass.
 */
internal class RetryAfterTestExecutionCallback(private val store: ExtensionContext.Store) : AfterTestExecutionCallback {

    override fun afterTestExecution(context: ExtensionContext) {
        val failure: Throwable? = context.executionException.orElse(null)
        val passed = failure == null || failure.javaClass == AssumptionViolatedException::class.java
        store.put(TEST_PASSED, passed)
    }
}
/**
 * Runs the annotated test method as a [TestTemplate] driven by [RetryExtension].
 *
 * The method is re-run until one attempt passes or [maxAttempts] attempts have been made.
 *
 * NOTE(review): RetryExtension only reads this annotation from the test method, so the
 * CLASS/FILE targets appear to have no retry effect. Confirm before relying on them.
 *
 * @property maxAttempts total number of attempts, including the first one.
 */
@Target(
    AnnotationTarget.FUNCTION,
    AnnotationTarget.PROPERTY_GETTER,
    AnnotationTarget.PROPERTY_SETTER,
    AnnotationTarget.CLASS,
    AnnotationTarget.FILE
)
@Retention(AnnotationRetention.RUNTIME)
@TestTemplate
@ExtendWith(RetryExtension::class)
annotation class RetryTest(val maxAttempts: Int = 3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment