Skip to content

Instantly share code, notes, and snippets.

@kongo2002
Last active December 9, 2021 22:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kongo2002/6c32adf23c7231d20456496eb83b9cc2 to your computer and use it in GitHub Desktop.
Save kongo2002/6c32adf23c7231d20456496eb83b9cc2 to your computer and use it in GitHub Desktop.
circuit breaker
package de.kongo2002
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.isActive
import kotlinx.coroutines.withTimeout
/**
 * Basic circuit breaker implementation based on the given
 * [CircuitBreakerConfig].
 *
 * The breaker's [State] is derived from the number of recorded errors and
 * the time of the most recent one:
 *
 * - [State.Closed]: fewer than [CircuitBreakerConfig.errorThreshold] errors
 *   have been recorded; calls are executed normally.
 * - [State.Open]: the threshold has been reached and the most recent error
 *   happened less than [CircuitBreakerConfig.resetIntervalMillis] ago; every
 *   call fails immediately with a [CircuitBreakerException].
 * - [State.HalfOpen]: the threshold has been reached but the reset interval
 *   has elapsed since the last error; calls are executed again, a successful
 *   one resets the breaker to [State.Closed] while a triggering failure
 *   re-opens it.
 *
 * An error is recorded whenever [CircuitBreakerConfig.triggerPredicate]
 * accepts the exception thrown by a call. The error counter is only cleared
 * by a success that started outside of [State.Closed]; successes while the
 * breaker is closed do not decrease it.
 *
 * The error bookkeeping is guarded by a [ReentrantReadWriteLock], so a single
 * instance may be shared by concurrently running callers.
 *
 * @see CircuitBreakerConfig
 * @see CircuitBreakerException
 */
class CircuitBreaker(private val config: CircuitBreakerConfig) {
    private val lock = ReentrantReadWriteLock()

    // Number of recorded triggering failures; guarded by [lock].
    private var _errors = 0

    // System.nanoTime() of the most recent triggering failure (0 if none); guarded by [lock].
    private var _lastError = 0L

    // Converted once up front. TimeUnit saturates at Long.MAX_VALUE, whereas
    // `resetIntervalMillis * 1_000_000` overflows for very large intervals and
    // would turn the comparison below into one the breaker can never satisfy.
    private val resetIntervalNanos = TimeUnit.MILLISECONDS.toNanos(config.resetIntervalMillis)

    private val state: State
        get() {
            // Snapshot both fields together so they stem from the same update.
            val (errors, lastError) = lock.read { _errors to _lastError }
            return when {
                errors < config.errorThreshold -> State.Closed
                System.nanoTime() - lastError < resetIntervalNanos -> State.Open
                else -> State.HalfOpen
            }
        }

    /**
     * Run an arbitrary function in the scope of the circuit breaker.
     *
     * The function is executed with a timeout of
     * [CircuitBreakerConfig.timeoutMillis]. Any exception it throws, including
     * the [TimeoutCancellationException] raised on timeout, is rethrown
     * unchanged after the error bookkeeping has been updated.
     *
     * @throws CircuitBreakerException thrown as long as the circuit breaker is
     * in [State.Open] state
     */
    suspend fun <T> run(func: suspend CoroutineScope.() -> T): T =
        when (val current = state) {
            State.Open -> throw CircuitBreakerException("circuit breaker is open")
            State.HalfOpen, State.Closed -> timeout(current, func)
        }

    /**
     * Executes [func] under [withTimeout] and updates the error bookkeeping.
     *
     * A success resets the breaker unless it was [State.Closed] when the call
     * started, so the common closed path never takes the write lock.
     * A failure is recorded when [CircuitBreakerConfig.triggerPredicate]
     * accepts it and the calling coroutine is still active: if the caller
     * itself was cancelled, the resulting cancellation is not a failure of the
     * protected call and must not push the breaker towards [State.Open].
     */
    private suspend fun <T> timeout(originalState: State, func: suspend CoroutineScope.() -> T): T {
        try {
            val result = withTimeout(config.timeoutMillis, func)
            if (originalState != State.Closed) {
                reset()
            }
            return result
        } catch (ex: Throwable) {
            if (config.triggerPredicate(ex) && currentCoroutineContext().isActive) {
                fail()
            }
            throw ex
        }
    }

    /** Records one triggering failure and timestamps it. */
    private fun fail() {
        lock.write {
            _errors += 1
            _lastError = System.nanoTime()
        }
    }

    /** Clears all recorded failures, which closes the breaker. */
    private fun reset() {
        lock.write {
            _errors = 0
            _lastError = 0L
        }
    }

    /** The states a [CircuitBreaker] can be in. */
    enum class State {
        /** Calls are executed; triggering failures are being counted. */
        Closed,

        /** The reset interval elapsed after opening; calls are executed again to probe recovery. */
        HalfOpen,

        /** The error threshold was reached recently; calls are rejected. */
        Open,
    }
}
/**
 * Configuration object for the [CircuitBreaker]
 *
 * @see CircuitBreaker
 * @see TimeoutCancellationException
 *
 * @property timeoutMillis duration in milliseconds until a function
 * run via [CircuitBreaker.run] is timed out
 * @property resetIntervalMillis duration (in milliseconds) until
 * the "open" circuit breaker is transitioned into "half-open" state;
 * must not be negative
 * @property errorThreshold number of failed functions run via [CircuitBreaker.run]
 * after which the circuit breaker is opened; must be at least 1
 * @property triggerPredicate predicate function to determine which
 * exceptions shall trigger the circuit breaker (defaults to [TimeoutCancellationException])
 * @throws IllegalArgumentException if [errorThreshold] is less than 1 or
 * [resetIntervalMillis] is negative
 */
data class CircuitBreakerConfig(
    val timeoutMillis: Long,
    val resetIntervalMillis: Long,
    val errorThreshold: Int,
    val triggerPredicate: (Throwable) -> Boolean = ::isTimeout,
) {
    init {
        // A threshold below 1 is reached before any error was recorded; the
        // breaker's state would then hinge on comparing System.nanoTime()
        // against a last-error time of 0, i.e. on the clock's arbitrary origin.
        require(errorThreshold >= 1) { "errorThreshold must be at least 1, was $errorThreshold" }
        require(resetIntervalMillis >= 0) { "resetIntervalMillis must not be negative, was $resetIntervalMillis" }
    }

    companion object {
        /**
         * Default predicate for the [CircuitBreaker] to determine what
         * exceptions should be considered for eventually transitioning
         * into [CircuitBreaker.State.Open]: only [TimeoutCancellationException]s.
         */
        fun isTimeout(ex: Throwable): Boolean = ex is TimeoutCancellationException
    }
}
/**
 * Signals that [CircuitBreaker.run] rejected a call because the breaker is
 * currently in [CircuitBreaker.State.Open].
 *
 * Instances carry no stack trace: [fillInStackTrace] is overridden so that
 * no trace is captured, keeping the rejection of calls cheap.
 */
class CircuitBreakerException(message: String) : RuntimeException(message) {
    override fun fillInStackTrace(): Throwable {
        // Deliberately skip capturing the stack; the exception is a plain
        // "fail fast" signal and a trace would only add cost per rejection.
        return this
    }
}
package de.kongo2002
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe
import java.lang.IllegalStateException
import kotlinx.coroutines.delay
class CircuitBreakerSpec : FunSpec({
    // Runs `attempts` calls of `block` through `breaker` and tallies the outcomes as
    // (exceptions thrown by `block` itself, CircuitBreakerExceptions from an open breaker).
    suspend fun countFailures(
        breaker: CircuitBreaker,
        attempts: Int,
        block: suspend () -> Unit,
    ): Pair<Int, Int> {
        var exceptionCount = 0
        var circuitBreakerExceptions = 0
        repeat(attempts) {
            try {
                breaker.run { block() }
            } catch (ex: CircuitBreakerException) {
                circuitBreakerExceptions++
            } catch (ex: Throwable) {
                exceptionCount++
            }
        }
        return exceptionCount to circuitBreakerExceptions
    }

    test("triggers on timeouts") {
        val breaker =
            CircuitBreaker(CircuitBreakerConfig(timeoutMillis = 50, resetIntervalMillis = 10_000, errorThreshold = 5))
        // five timeouts open the breaker, the remaining two calls are rejected
        countFailures(breaker, attempts = 7) { delay(200) }.shouldBe(5 to 2)
    }

    test("does not trigger on non-timeout exceptions") {
        val breaker =
            CircuitBreaker(CircuitBreakerConfig(timeoutMillis = 50, resetIntervalMillis = 10_000, errorThreshold = 5))
        countFailures(breaker, attempts = 7) { throw IllegalStateException("EXPECTED") }.shouldBe(7 to 0)
    }

    test("triggers on custom exceptions") {
        val breaker =
            CircuitBreaker(
                CircuitBreakerConfig(
                    timeoutMillis = 50,
                    resetIntervalMillis = 10_000,
                    errorThreshold = 5,
                    triggerPredicate = { ex -> ex is IllegalStateException },
                ),
            )
        countFailures(breaker, attempts = 7) { throw IllegalStateException("EXPECTED") }.shouldBe(5 to 2)
    }

    test("correctly passes success results") {
        val breaker =
            CircuitBreaker(CircuitBreakerConfig(timeoutMillis = 50, resetIntervalMillis = 10_000, errorThreshold = 5))
        repeat(10) { i ->
            val expected = i * 2
            breaker.run { expected }.shouldBe(expected)
        }
    }

    test("closes again after a success once the reset interval elapsed") {
        val breaker =
            CircuitBreaker(CircuitBreakerConfig(timeoutMillis = 50, resetIntervalMillis = 250, errorThreshold = 2))
        // two timeouts open the breaker, so the third call is rejected
        countFailures(breaker, attempts = 3) { delay(200) }.shouldBe(2 to 1)
        // after the reset interval the breaker is half-open and lets this call through
        delay(500)
        breaker.run { 42 }.shouldBe(42)
        // the success reset the error count: the threshold has to be reached again
        countFailures(breaker, attempts = 2) { delay(200) }.shouldBe(2 to 0)
    }
})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment