Last active
December 9, 2021 22:46
-
-
Save kongo2002/6c32adf23c7231d20456496eb83b9cc2 to your computer and use it in GitHub Desktop.
circuit breaker
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package de.kongo2002 | |
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.withTimeout
/**
 * Basic circuit breaker implementation based on the given
 * [CircuitBreakerConfig].
 *
 * The circuit breaker will throw [CircuitBreakerException]
 * while it is in [State.Open] state, which is true as long as
 * the configured error threshold has been reached and the most
 * recent error happened less than [CircuitBreakerConfig.resetIntervalMillis] ago.
 *
 * Errors only accumulate while they keep occurring within the reset interval
 * of each other: while still below the threshold, the first error recorded
 * after a longer quiet period starts counting from one again.
 *
 * Once the reset interval has elapsed after the last error, an opened breaker
 * is [State.HalfOpen]. Calls run again; a successful call resets it to
 * [State.Closed], and a failure matched by [CircuitBreakerConfig.triggerPredicate]
 * opens it again. Note that every call arriving while half-open is let
 * through, not just a single probe.
 *
 * The error count and timestamp are guarded by a [ReentrantReadWriteLock].
 * The lock is never held while the wrapped function runs.
 *
 * @see CircuitBreakerConfig
 * @see CircuitBreakerException
 */
class CircuitBreaker(private val config: CircuitBreakerConfig) {
    private val lock = ReentrantReadWriteLock()

    // Saturating conversion: `millis * 1_000_000` would silently overflow for huge intervals.
    private val resetIntervalNanos = TimeUnit.MILLISECONDS.toNanos(config.resetIntervalMillis)

    // Number of counted errors and the System.nanoTime() of the most recent one (0 when reset).
    private var errors = 0
    private var lastErrorNanos = 0L

    private val state: State
        get() = lock.read {
            when {
                errors < config.errorThreshold -> State.Closed
                isWithinResetInterval(lastErrorNanos) -> State.Open
                else -> State.HalfOpen
            }
        }

    /**
     * Run an arbitrary function in the scope of the circuit breaker,
     * bounded by [CircuitBreakerConfig.timeoutMillis].
     *
     * @return the result of [func]
     * @throws CircuitBreakerException thrown as long as the circuit breaker is
     * in [State.Open] state; [func] is not invoked in that case
     * @throws kotlinx.coroutines.TimeoutCancellationException when [func] exceeds the timeout
     */
    suspend fun <T> run(func: suspend CoroutineScope.() -> T): T =
        when (val current = state) {
            State.Open -> throw CircuitBreakerException("circuit breaker is open")
            State.HalfOpen, State.Closed -> timeout(current, func)
        }

    /**
     * Runs [func] under [withTimeout] and records the outcome.
     *
     * A failure matched by [CircuitBreakerConfig.triggerPredicate] is counted. A
     * success resets the breaker unless it was already closed. Every exception is
     * rethrown unchanged.
     */
    private suspend fun <T> timeout(originalState: State, func: suspend CoroutineScope.() -> T): T {
        val result = try {
            withTimeout(config.timeoutMillis, func)
        } catch (ex: Throwable) {
            if (config.triggerPredicate(ex)) {
                fail()
            }
            throw ex
        }
        // A closed breaker has nothing to reset; skipping it keeps the common path free of the write lock.
        if (originalState != State.Closed) {
            reset()
        }
        return result
    }

    /** Records one triggering failure at the current time. */
    private fun fail() {
        lock.write {
            val now = System.nanoTime()
            // Below the threshold, an error arriving after a quiet period of at least the reset
            // interval starts a fresh count instead of piling onto unrelated old errors.
            // At or above the threshold (half-open) the count keeps growing so the breaker re-opens.
            errors = if (errors < config.errorThreshold && !isWithinResetInterval(lastErrorNanos, now)) {
                1
            } else {
                errors + 1
            }
            lastErrorNanos = now
        }
    }

    /** Clears all recorded errors, returning the breaker to [State.Closed]. */
    private fun reset() {
        lock.write {
            errors = 0
            lastErrorNanos = 0L
        }
    }

    /** Whether [sinceNanos] lies less than the reset interval before [nowNanos] (overflow-safe difference). */
    private fun isWithinResetInterval(sinceNanos: Long, nowNanos: Long = System.nanoTime()): Boolean =
        nowNanos - sinceNanos < resetIntervalNanos

    /** States of the breaker; see the class documentation for the transitions. */
    enum class State {
        /** Fewer than `errorThreshold` errors are counted; calls run normally. */
        Closed,

        /** Threshold reached, but the reset interval has elapsed since the last error; calls run again. */
        HalfOpen,

        /** Threshold reached and the last error is more recent than the reset interval; calls are rejected. */
        Open,
    }
}
/**
 * Configuration object for the [CircuitBreaker].
 *
 * All numeric settings are validated on construction (including via [copy]).
 *
 * @see CircuitBreaker
 * @see TimeoutCancellationException
 *
 * @property timeoutMillis duration in milliseconds until a function
 * run via [CircuitBreaker.run] is timed out; must be positive
 * @property resetIntervalMillis duration (in milliseconds) until
 * the "open" circuit breaker is transitioned into "half-open" state; must not be negative
 * @property errorThreshold number of failed [CircuitBreaker.run] calls (as selected by
 * [triggerPredicate]) after which the circuit breaker opens; must be at least 1
 * @property triggerPredicate predicate function to determine which
 * exceptions shall trigger the circuit breaker (defaults to [isTimeout], which
 * only matches [TimeoutCancellationException])
 * @throws IllegalArgumentException if any numeric setting is out of range
 */
data class CircuitBreakerConfig(
    val timeoutMillis: Long,
    val resetIntervalMillis: Long,
    val errorThreshold: Int,
    val triggerPredicate: (Throwable) -> Boolean = ::isTimeout,
) {
    init {
        require(timeoutMillis > 0) { "timeoutMillis must be positive, was $timeoutMillis" }
        require(resetIntervalMillis >= 0) { "resetIntervalMillis must not be negative, was $resetIntervalMillis" }
        // A threshold of 0 would make a fresh breaker count as tripped before any error occurred.
        require(errorThreshold >= 1) { "errorThreshold must be at least 1, was $errorThreshold" }
    }

    companion object {
        /**
         * Default predicate for the [CircuitBreaker] to determine what
         * exceptions should be considered for eventually transitioning
         * into [CircuitBreaker.State.Open].
         *
         * @return `true` only for [TimeoutCancellationException]
         */
        fun isTimeout(ex: Throwable): Boolean = ex is TimeoutCancellationException
    }
}
/**
 * Signals that [CircuitBreaker.run] refused to execute a call because the
 * breaker is in [CircuitBreaker.State.Open] state.
 *
 * No stack trace is captured: [fillInStackTrace] hands back the instance
 * untouched, which keeps constructing and throwing it cheap.
 */
class CircuitBreakerException(message: String) : RuntimeException(message) {
    /** Skips stack-trace capture and returns this exception as-is. */
    override fun fillInStackTrace(): Throwable {
        return this
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package de.kongo2002 | |
import io.kotest.core.spec.style.FunSpec | |
import io.kotest.matchers.shouldBe | |
import java.lang.IllegalStateException | |
import kotlinx.coroutines.delay | |
/**
 * Behavioral tests for [CircuitBreaker].
 *
 * Timeouts are produced with real [delay]s longer than the configured
 * `timeoutMillis`, so the suite takes some hundreds of milliseconds of wall-clock time.
 */
class CircuitBreakerSpec : FunSpec({
    test("triggers on timeouts") {
        val breaker =
            CircuitBreaker(CircuitBreakerConfig(timeoutMillis = 50, resetIntervalMillis = 10_000, errorThreshold = 5))

        // the first five calls time out and open the breaker; the remaining two are rejected
        runRepeatedly(breaker, times = 7) { delay(200) } shouldBe Outcomes(failures = 5, rejections = 2)
    }

    test("does not trigger on non-timeout exceptions") {
        val breaker =
            CircuitBreaker(CircuitBreakerConfig(timeoutMillis = 50, resetIntervalMillis = 10_000, errorThreshold = 5))

        runRepeatedly(breaker, times = 7) {
            throw IllegalStateException("EXPECTED")
        } shouldBe Outcomes(failures = 7, rejections = 0)
    }

    test("triggers on custom exceptions") {
        val breaker = CircuitBreaker(
            CircuitBreakerConfig(
                timeoutMillis = 50,
                resetIntervalMillis = 10_000,
                errorThreshold = 5,
                triggerPredicate = { ex -> ex is IllegalStateException },
            ),
        )

        runRepeatedly(breaker, times = 7) {
            throw IllegalStateException("EXPECTED")
        } shouldBe Outcomes(failures = 5, rejections = 2)
    }

    test("correctly passes success results") {
        val breaker =
            CircuitBreaker(CircuitBreakerConfig(timeoutMillis = 50, resetIntervalMillis = 10_000, errorThreshold = 5))

        repeat(10) { i ->
            val expected = i * 2
            breaker.run { expected } shouldBe expected
        }
    }

    test("successful call in half-open state closes the breaker") {
        val breaker =
            CircuitBreaker(CircuitBreakerConfig(timeoutMillis = 50, resetIntervalMillis = 200, errorThreshold = 2))

        // two timeouts open the breaker, the third call is rejected right away
        runRepeatedly(breaker, times = 3) { delay(200) } shouldBe Outcomes(failures = 2, rejections = 1)

        // wait out the reset interval so the breaker turns half-open, then succeed once
        delay(300)
        breaker.run { 42 } shouldBe 42

        // the error count was reset: a single new timeout stays below the threshold
        runRepeatedly(breaker, times = 1) { delay(200) } shouldBe Outcomes(failures = 1, rejections = 0)
        breaker.run { 7 } shouldBe 7
    }
})

/** Tally of how a batch of [CircuitBreaker.run] invocations ended; successes are not counted. */
private data class Outcomes(val failures: Int, val rejections: Int)

/**
 * Invokes [CircuitBreaker.run] with [func] [times] times in sequence.
 *
 * [CircuitBreakerException]s are counted as rejections, every other exception as a failure.
 */
private suspend fun runRepeatedly(breaker: CircuitBreaker, times: Int, func: suspend () -> Unit): Outcomes {
    var failures = 0
    var rejections = 0
    repeat(times) {
        try {
            breaker.run { func() }
        } catch (ex: CircuitBreakerException) {
            rejections++
        } catch (ex: Throwable) {
            failures++
        }
    }
    return Outcomes(failures = failures, rejections = rejections)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment