Created
September 28, 2018 16:38
-
-
Save ahoy-jon/fd0ef59a6e5ddec439f84d002034a934 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package utils | |
import org.scalatest.FunSuite | |
import scalaz.zio.{IO, RTS, Ref} | |
import utils.CircuitBreaker.{Closed, Open, Status} | |
/** Guards effects behind a consecutive-failure budget stored in `ref`.
  *
  * While the stored status is `Closed`, a protected effect is executed: a typed
  * failure decrements the budget (via `CircuitBreakerStatus.incFail`) and is then
  * re-raised unchanged, while a success restores `initStatus` and passes its value
  * through. While the status is `Open`, the protected effect is never run and
  * `whenOpen` is returned in its place.
  *
  * @param ref        cell holding the breaker's current state
  * @param initStatus state written back after every successful protected call
  * @param whenOpen   effect substituted for protected effects while open
  * @tparam BreakingFailure error type produced by `whenOpen`
  */
final class CircuitBreaker[+BreakingFailure](ref: Ref[CircuitBreakerStatus],
                                              initStatus: CircuitBreakerStatus,
                                              whenOpen: IO[BreakingFailure, Nothing]) {

  /** Runs `io` through the breaker.
    *
    * @param io effect to guard
    * @return `io`'s own result while closed (with the breaker state updated
    *         afterwards), or `whenOpen` while open
    */
  def protect[E >: BreakingFailure, A](io: IO[E, A]): IO[E, A] = {
    // Record the failure first, then re-raise the original error untouched.
    def onFailure(err: E): IO[E, A] = ref.update(_.incFail) *> IO.fail(err)

    // A success resets the budget to its initial value before yielding the result.
    def onSuccess(value: A): IO[E, A] = ref.set(initStatus) *> IO.point(value)

    status.flatMap {
      case Closed => io.redeem(onFailure, onSuccess)
      case Open   => whenOpen
    }
  }

  /** Current breaker state, read from `ref`. */
  def status: IO[Nothing, Status] =
    ref.get.map(current => current.status)

  /** Failures still tolerated before the breaker opens, read from `ref`. */
  def nbRemainingFailure: IO[Nothing, Long] =
    ref.get.map(current => current.nbRemainingFailure)
}
/** Immutable snapshot of a circuit breaker's state.
  *
  * @param nbRemainingFailure typed failures still tolerated before the breaker opens;
  *                           `incFail` never drives it below zero
  * @param status             current breaker state, `Closed` by default
  */
final case class CircuitBreakerStatus(nbRemainingFailure: Long, status: Status = Closed) {

  /** Records one failure and returns the resulting state.
    *
    * The budget is decremented but clamped at zero. Without the clamp, a non-positive
    * initial budget, or failures recorded after the breaker has already opened, would
    * report a negative count (and `Long.MinValue` would wrap around). The status becomes
    * `Open` when this failure consumes the last unit of budget, i.e. when the budget was
    * 1 or less before the call; otherwise the current status is kept.
    */
  def incFail: CircuitBreakerStatus = {
    // Compare before subtracting so the decrement can never overflow.
    val remaining  = if (nbRemainingFailure > 0) nbRemainingFailure - 1 else 0L
    val nextStatus = if (nbRemainingFailure > 1) status else Open
    copy(nbRemainingFailure = remaining, status = nextStatus)
  }
}
/** Companion of `CircuitBreaker`: the breaker states and an effectful constructor. */
object CircuitBreaker {

  /** State of a breaker.
    *
    * Extends `Product with Serializable` so that expressions mixing the two case
    * objects infer `Status` instead of a compound type.
    */
  sealed trait Status extends Product with Serializable

  /** Protected effects are executed. */
  case object Closed extends Status

  /** Protected effects are skipped and the breaker's `whenOpen` effect is returned instead. */
  case object Open extends Status

  /** Allocates a new breaker whose state starts `Closed` with the given failure budget.
    *
    * @param nbConsecutiveFailure failure budget; the breaker opens on the typed failure that
    *                             consumes its last unit (a budget of 1 or less opens on the
    *                             first failure)
    * @param whenOpen             effect returned in place of protected effects once open
    * @tparam BreakingFailure     error type of `whenOpen`
    * @return an effect that allocates the breaker's state cell and wraps it in a breaker
    */
  def apply[BreakingFailure](nbConsecutiveFailure: Long,
                             whenOpen: IO[BreakingFailure, Nothing]): IO[Nothing, CircuitBreaker[BreakingFailure]] = {
    // The same value seeds the Ref and is handed to the breaker, which restores it on success.
    val initStatus: CircuitBreakerStatus = CircuitBreakerStatus(nbConsecutiveFailure)
    Ref(initStatus).map(ref => new CircuitBreaker[BreakingFailure](ref, initStatus, whenOpen))
  }
}
/** ScalaTest suite exercising `CircuitBreaker` end to end on the ZIO runtime.
  *
  * NOTE(review): the class name does not describe what it tests; kept unchanged so
  * existing test selectors still match.
  */
class IOtoTryTest extends FunSuite with RTS {

  test("circuit breaker") {
    sealed trait Error
    case object Failed extends Error
    case object Success
    case object CircuitOpen extends Error

    // Only effects run inside unsafeRun. Every observation is returned, and the
    // assertions below run afterwards on the test thread, so a failing assert is
    // reported by ScalaTest directly rather than thrown from inside the fiber.
    val (v, nbR1, w, nbR2, nbR3, nbR4, status, x, nbR5, y) = unsafeRun(for {
      circuit <- CircuitBreaker(2, IO.fail(CircuitOpen))
      v       <- circuit.protect(IO.fail(Failed)).attempt
      nbR1    <- circuit.nbRemainingFailure
      w       <- circuit.protect(IO.point(Success)).attempt
      nbR2    <- circuit.nbRemainingFailure
      _       <- circuit.protect(IO.fail(Failed)).attempt
      nbR3    <- circuit.nbRemainingFailure
      _       <- circuit.protect(IO.fail(Failed)).attempt
      nbR4    <- circuit.nbRemainingFailure
      status  <- circuit.status
      x       <- circuit.protect(IO.point(Success)).attempt
      nbR5    <- circuit.nbRemainingFailure
      y       <- circuit.protect(IO.fail(Failed)).attempt
    } yield (v, nbR1, w, nbR2, nbR3, nbR4, status, x, nbR5, y))

    // Budget of 2: the first failure is passed through and leaves 1 remaining.
    assert(v == Left(Failed), "v")
    assert(nbR1 == 1, "nbR1")
    // A success passes through and resets the budget.
    assert(w == Right(Success), "w")
    assert(nbR2 == 2, "nbR2")
    // Two consecutive failures exhaust the budget and open the breaker.
    assert(nbR3 == 1, "nbR3")
    assert(nbR4 == 0, "nbR4")
    assert(status == Open, "status")
    // While open, protected effects are replaced by whenOpen, whatever they would have done.
    assert(nbR5 == 0, "nbR5")
    assert(x == Left(CircuitOpen), "x")
    assert(y == Left(CircuitOpen), "y")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment