Skip to content

Instantly share code, notes, and snippets.

@nelanka
Last active January 2, 2018 17:03
Show Gist options
  • Save nelanka/891e9ac82fc83a6ab561 to your computer and use it in GitHub Desktop.
Save nelanka/891e9ac82fc83a6ab561 to your computer and use it in GitHub Desktop.
Shard Region Graceful Shutdown Experiment
package net.nelanka.akka_sandbox.cluster_sharding.experiments
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Failure, Success}
import akka.actor.{Actor, ActorRef, ActorSystem, FSM, Props, Terminated}
import akka.cluster.Cluster
import akka.cluster.sharding.ShardCoordinator.LeastShardAllocationStrategy
import akka.cluster.sharding.{ClusterSharding, ClusterShardingSettings, ShardRegion}
import akka.testkit.SocketUtil
import com.typesafe.config.ConfigFactory
/**
 * Gracefully shuts down a cluster sharding node.
 *
 * Ref: http://doc.akka.io/docs/akka/2.4.2/scala/cluster-sharding.html#Graceful_Shutdown
 *
 *  1. jvm gets the shutdown signal
 *  2. node tells all local shard regions to shut down gracefully
 *  3. node leaves cluster
 *  4. node gives singletons a grace period to migrate
 *  5. actor system is shutdown
 *  6. jvm exits
 *
 * This is the companion of the `NodeShutdownCoordinator` FSM actor: it holds the shutdown
 * options, the messaging protocol, the FSM state/data types and the `register` entry point.
 * It is named `NodeShutdownCoordinator` because that is the name the rest of the file
 * imports from and calls `register` on.
 */
object NodeShutdownCoordinator {

  /**
   * Shutdown options.
   *
   * @param nodeShutdownSingletonMigrationDelay delay between the node being removed from the
   *                                            cluster and the actor system being terminated
   * @param actorSystemShutdownTimeout          how long the JVM shutdown hook blocks waiting
   *                                            for the actor system to terminate
   */
  final case class NodeShutdownOpts(
      nodeShutdownSingletonMigrationDelay: FiniteDuration,
      actorSystemShutdownTimeout: FiniteDuration
  )

  // Messaging Protocol
  sealed trait NodeShutdownProtocol
  final case class RegisterRegions(regions: Set[ActorRef]) extends NodeShutdownProtocol
  // Carries the shard regions to stop: both the shutdown hook below and the FSM's first
  // state construct/match this message with the region set as its argument.
  final case class StartNodeShutdown(shardRegions: Set[ActorRef]) extends NodeShutdownProtocol
  case object NodeLeftCluster extends NodeShutdownProtocol
  case object TerminateNode extends NodeShutdownProtocol

  // FSM State
  sealed trait State
  case object AwaitNodeShutdownInitiation extends State
  case object AwaitShardRegionsShutdown extends State
  case object AwaitClusterExit extends State
  case object AwaitNodeTerminationSignal extends State

  // FSM Data
  sealed trait Data
  final case class ManagedRegions(shardRegions: Set[ActorRef]) extends Data

  /**
   * Registers a JVM shutdown hook that gracefully shuts down the given shard regions.
   *
   * When the hook runs it creates a `NodeShutdownCoordinator` actor, sends it
   * `StartNodeShutdown(shardRegions)` and then blocks the hook thread until the actor
   * system has terminated or `actorSystemShutdownTimeout` has elapsed.
   *
   * @param shutdownOpts delays/timeouts used during the shutdown sequence
   * @param shardRegions local shard regions to stop before leaving the cluster
   * @param system       actor system that hosts the coordinator and is awaited by the hook
   */
  def register(
      shutdownOpts: NodeShutdownOpts,
      shardRegions: Set[ActorRef]
  )(implicit system: ActorSystem): Unit = {
    // 1 - jvm gets the shutdown signal
    sys.addShutdownHook {
      println("NodeShutdownCoordinator> Initiating node shutdown")
      val nodeShutdownActor =
        system.actorOf(Props(new NodeShutdownCoordinator(shutdownOpts)))
      nodeShutdownActor ! StartNodeShutdown(shardRegions)
      println("NodeShutdownCoordinator> Awaiting node shutdown ...")
      Await.result(system.whenTerminated, shutdownOpts.actorSystemShutdownTimeout)
    }
    ()
  }
}
import net.nelanka.akka_sandbox.cluster_sharding.experiments.NodeShutdownCoordinator._
/**
 * FSM actor that drives the graceful shutdown sequence of this node:
 * stop the local shard regions, leave the cluster, wait a grace period, then terminate
 * the actor system and exit the JVM.
 *
 * @param shutdownOpts delays/timeouts used during the shutdown sequence
 * @param system       actor system whose scheduler is used for the grace period and which
 *                     is terminated in the last step
 */
class NodeShutdownCoordinator(
    shutdownOpts: NodeShutdownOpts
)(implicit system: ActorSystem) extends FSM[State, Data] {

  startWith(AwaitNodeShutdownInitiation, ManagedRegions(Set.empty[ActorRef]))

  // Wait for the shutdown request carrying the shard regions to stop
  when(AwaitNodeShutdownInitiation) {
    case Event(StartNodeShutdown(shardRegions), _) if shardRegions.isEmpty =>
      // With nothing to watch no Terminated message would ever arrive, so waiting in
      // AwaitShardRegionsShutdown would stall the shutdown; leave the cluster right away.
      println("NodeShutdownCoordinator> No local shard regions to shut down.")
      leaveCluster()

    case Event(StartNodeShutdown(shardRegions), _) =>
      println("NodeGracefulShutdownCoordinator> Notifying local shard regions to shut down")
      // 2 - node tells all local shard regions to shut down gracefully
      shardRegions.foreach { shardRegion =>
        // Watch first so the region's termination is reported back as Terminated
        context.watch(shardRegion)
        shardRegion ! ShardRegion.GracefulShutdown
      }
      println(
        s"NodeShutdownCoordinator> Waiting for ${shardRegions.size} " +
          "local shard region(s) to shut down ...")
      goto(AwaitShardRegionsShutdown) using ManagedRegions(shardRegions)
  }

  // 3 - node leaves cluster
  when(AwaitShardRegionsShutdown) {
    // Only managed regions are handled here; any other Terminated falls to whenUnhandled.
    case Event(Terminated(actor), ManagedRegions(shardRegions))
        if shardRegions.contains(actor) =>
      println("NodeShutdownCoordinator> Shard region terminated")
      val remainingRegions = shardRegions - actor
      if (remainingRegions.isEmpty) {
        println("NodeShutdownCoordinator> All local shard region terminated.")
        leaveCluster()
      } else {
        println(
          s"NodeShutdownCoordinator> Waiting for ${remainingRegions.size} " +
            "local shard region(s) to shut down ...")
        // Same state, only the data changes
        stay() using ManagedRegions(remainingRegions)
      }
  }

  // 4 - node gives singletons a grace period to migrate - may not have to do this since
  // we wait above for all the regions to exit. Where can the singleton live?
  when(AwaitClusterExit) {
    case Event(NodeLeftCluster, _) =>
      import context.dispatcher
      println("NodeShutdownCoordinator> Waiting on cluster singleton migration ...")
      system.scheduler.scheduleOnce(
        shutdownOpts.nodeShutdownSingletonMigrationDelay, self, TerminateNode)
      goto(AwaitNodeTerminationSignal)
  }

  // 5 - actor system is shutdown
  when(AwaitNodeTerminationSignal) {
    case Event(TerminateNode, _) =>
      println("NodeShutdownCoordinator> Terminating actor system ...")
      // Important: this is NOT an Akka thread-pool (since those we're shutting down)
      val ec = scala.concurrent.ExecutionContext.global
      // NOTE(review): when this runs while a JVM shutdown hook is in progress,
      // System.exit is documented to block its calling thread - confirm that the hook
      // finishing (not this call) is what ends the JVM in that case.
      system.terminate().onComplete {
        case Success(_) =>
          println("NodeShutdownCoordinator> ActorSystem shutdown complete, killing jvm")
          System.exit(0)
        case Failure(ex) =>
          System.err.println(s"NodeShutdownCoordinator> Shutdown failed: $ex")
          System.exit(-1)
      }(ec) // Important: Use a non-Akka EC here
      stop()
  }

  whenUnhandled {
    case Event(msg, state) =>
      println(s"NodeShutdownCoordinator> Received unhandled request $msg in state $stateName/$state")
      stay()
  }

  initialize()

  // Asks the cluster to remove this node, arranges for NodeLeftCluster to be sent to this
  // actor once the member has been removed, and moves the FSM to AwaitClusterExit.
  private def leaveCluster() = {
    println("NodeShutdownCoordinator> Waiting on cluster exit ...")
    val cluster = Cluster(context.system)
    cluster.registerOnMemberRemoved(self ! NodeLeftCluster)
    cluster.leave(cluster.selfAddress)
    goto(AwaitClusterExit)
  }
}
/**
 * Experiment: starts a cluster sharding node that is its own seed node, registers its
 * shard region with `NodeShutdownCoordinator`, sends one message to the region and then
 * exits the JVM so that the registered shutdown hook runs the graceful shutdown sequence.
 */
object ShardRegionGracefulShutdownExperiment extends App {

  object MessageProcessorActor {
    def props: Props = Props[MessageProcessorActor]

    // Every message is routed to the single entity id "Region" ...
    val idExtractor: ShardRegion.ExtractEntityId = {
      case msg => ("Region", msg)
    }

    // ... which lives in the single shard "Shard"
    val shardResolver: ShardRegion.ExtractShardId = {
      case _ => "Shard"
    }

    val shardName = "MessageProcessor"
  }

  /** Sharded entity: logs incoming messages and stops itself on the hand-off stop message. */
  private class MessageProcessorActor extends Actor {
    println(s"MessageProcessorActor> Construction")

    override def receive: Receive = {
      case msg: Message =>
        println(s"MessageProcessorActor> Received $msg")
      case EntityGracefulShutdown =>
        println(s"MessageProcessorActor> Received EntityGracefulShutdown. Initiating graceful shutdown")
        context.stop(self)
    }
  }

  // Test parameters
  val CLUSTER_NAME = "ClusterSystem"
  val PORT = SocketUtil.temporaryServerAddress().getPort

  // Messages
  final case class Message(n: Int)
  case object EntityGracefulShutdown

  /**
   * Builds the node configuration.
   *
   * @param port port the node's remoting binds to; the seed node always uses `PORT`
   * @return the configuration as a HOCON string
   */
  def genConfig(port: Int): String = {
    s"""
       |akka {
       |  actor {
       |    provider = "akka.cluster.ClusterActorRefProvider"
       |    warn-about-java-serializer-usage = off
       |  }
       |
       |  cluster {
       |    seed-nodes = ["akka.tcp://$CLUSTER_NAME@127.0.0.1:$PORT"]
       |  }
       |
       |  remote {
       |    netty.tcp {
       |      hostname = 127.0.0.1
       |      port = $port
       |    }
       |  }
       |
       |  persistence.journal.plugin = "akka.persistence.journal.inmem"
       |  persistence.snapshot-store.plugin = "akka.persistence.snapshot-store.local"
       |}
       |""".stripMargin
  }

  /**
   * Starts the `MessageProcessor` shard region on the given system, using a
   * `LeastShardAllocationStrategy` built from the system's configuration and
   * `EntityGracefulShutdown` as the hand-off stop message.
   *
   * @return the shard region actor
   */
  def startShardRegion(system: ActorSystem): ActorRef = {
    val strategyConfig =
      system.settings.config.getConfig("akka.cluster.sharding.least-shard-allocation-strategy")
    val strategy = new LeastShardAllocationStrategy(
      strategyConfig.getInt("rebalance-threshold"),
      strategyConfig.getInt("max-simultaneous-rebalance"))

    ClusterSharding(system).start(
      typeName = MessageProcessorActor.shardName,
      entityProps = MessageProcessorActor.props,
      settings = ClusterShardingSettings(system),
      extractEntityId = MessageProcessorActor.idExtractor,
      extractShardId = MessageProcessorActor.shardResolver,
      allocationStrategy = strategy,
      handOffStopMessage = EntityGracefulShutdown
    )
  }

  /**
   * Creates an actor system configured by `genConfig(port)` and starts the shard region on it.
   *
   * @param port remoting port for the node
   * @return the actor system and its shard region
   */
  def createClusterShardingActorSystem(port: Int = 0): (ActorSystem, ActorRef) = {
    val nodeConfig = genConfig(port)
    println(s"Creating ActorSystem with config:\n$nodeConfig")
    val config = ConfigFactory.parseString(nodeConfig).withFallback(ConfigFactory.load())
    val system = ActorSystem(CLUSTER_NAME, config)
    val shardRegion = startShardRegion(system)
    (system, shardRegion)
  }

  // Only the ActorSystem is made implicit. Declaring the whole tuple pattern implicit would
  // also put the shard region ActorRef into implicit scope, where `!` would pick it up as
  // the sender of every message sent below.
  private val node = createClusterShardingActorSystem(PORT)
  implicit val system: ActorSystem = node._1
  val shardRegion: ActorRef = node._2

  NodeShutdownCoordinator.register(
    NodeShutdownOpts(
      nodeShutdownSingletonMigrationDelay = 10.seconds,
      actorSystemShutdownTimeout = 3.minutes),
    Set(shardRegion))

  println("Sleeping while waiting for the cluster to start ...")
  Thread.sleep(5000)

  println("Sending Message(1) to ShardRegion ... ")
  shardRegion ! Message(1)
  Thread.sleep(1000)

  println("Simulating shutdown ...")
  // Runtime.exit does not return normally: it runs the registered shutdown hook (which
  // itself waits for the actor system to terminate) and then halts the JVM, so nothing
  // placed after this call would ever execute.
  sys.runtime.exit(0)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment