LocalNodeFirstLoadBalancingPolicy
package com.datastax.spark.connector.cql

import java.net.{InetAddress, NetworkInterface}
import java.nio.ByteBuffer
import java.util.{Collection => JCollection, Iterator => JIterator}

import com.datastax.driver.core._
import com.datastax.driver.core.policies.LoadBalancingPolicy
import org.apache.spark.Logging

import scala.collection.JavaConversions._
import scala.util.Random
/** Selects the local node first and then nodes in the local DC in random order. Never selects
  * nodes from other DCs. For writes, if a statement has a routing key set, this LBP is token
  * aware: it prefers the replicas of the computed token over the other nodes. */
class LocalNodeFirstLoadBalancingPolicy(
    contactPoints: Set[InetAddress],
    localDC: Option[String] = None,
    shuffleReplicas: Boolean = true) extends LoadBalancingPolicy with Logging {
  import LocalNodeFirstLoadBalancingPolicy._

  private var nodes = Set.empty[Host]
  private var dcToUse = ""
  private val random = new Random
  private var clusterMetadata: Metadata = _
  override def distance(host: Host): HostDistance =
    if (host.getDatacenter == dcToUse) {
      sameDCHostDistance(host)
    } else {
      // this ensures we keep remote hosts out of our list entirely,
      // even when we are notified of newly joined nodes
      HostDistance.IGNORED
    }
  override def init(cluster: Cluster, hosts: JCollection[Host]) {
    nodes = hosts.toSet

    // use the explicitly set DC if available; otherwise check whether all contact points
    // share the same DC: if so, use that DC, and if not, throw an error
    dcToUse = localDC match {
      case Some(local) => local
      case None =>
        val dcList = dcs(nodesInTheSameDC(contactPoints, hosts.toSet))
        if (dcList.size == 1)
          dcList.head
        else
          throw new IllegalArgumentException(
            s"Contact points contain multiple data centers: ${dcList.mkString(", ")}")
    }
    clusterMetadata = cluster.getMetadata
  }
  private def tokenUnawareQueryPlan(loggedKeyspace: String, statement: Statement): JIterator[Host] = {
    sortNodesByStatusAndProximity(contactPoints, nodes).iterator
  }
  private def findReplicas(keyspace: String, partitionKey: ByteBuffer): Set[Host] = {
    clusterMetadata.getReplicas(Metadata.quote(keyspace), partitionKey).toSet
      .filter(host => host.isUp && distance(host) != HostDistance.IGNORED)
  }
  private def tokenAwareQueryPlan(keyspace: String, statement: Statement): JIterator[Host] = {
    assert(keyspace != null)
    assert(statement.getRoutingKey != null)

    val replicas = findReplicas(keyspace, statement.getRoutingKey)
    val (localReplica, otherReplicas) = replicas.partition(isLocalHost)
    lazy val maybeShuffled =
      if (shuffleReplicas) random.shuffle(otherReplicas.toIndexedSeq) else otherReplicas
    lazy val otherHosts = tokenUnawareQueryPlan(keyspace, statement).toIterator
      .filter(host => !replicas.contains(host) && distance(host) != HostDistance.IGNORED)

    // lazily concatenate: the local replica (if any) first, then the remaining local-DC
    // replicas (shuffled if requested), then all other local-DC hosts that are not replicas
    (localReplica.iterator #:: maybeShuffled.iterator #:: otherHosts #:: Stream.empty).flatten.iterator
  }
  override def newQueryPlan(loggedKeyspace: String, statement: Statement): JIterator[Host] = {
    val keyspace = if (statement.getKeyspace == null) loggedKeyspace else statement.getKeyspace
    if (statement.getRoutingKey == null || keyspace == null)
      tokenUnawareQueryPlan(keyspace, statement)
    else
      tokenAwareQueryPlan(keyspace, statement)
  }
  override def onAdd(host: Host) {
    // The added host might be a "better" version of a host already in the set.
    // The nodes added in the init call don't have DC and rack set,
    // so we replace the stored object here to pick up the full DC information:
    nodes -= host
    nodes += host
    logInfo(s"Added host ${host.getAddress.getHostAddress} (${host.getDatacenter})")
  }
  override def onRemove(host: Host) {
    nodes -= host
    logInfo(s"Removed host ${host.getAddress.getHostAddress} (${host.getDatacenter})")
  }

  override def close() = {}
  override def onUp(host: Host) = {}
  override def onDown(host: Host) = {}
  private def sameDCHostDistance(host: Host) =
    if (isLocalHost(host))
      HostDistance.LOCAL
    else
      HostDistance.REMOTE

  private def dcs(hosts: Set[Host]) =
    hosts.filter(_.getDatacenter != null).map(_.getDatacenter).toSet
}
object LocalNodeFirstLoadBalancingPolicy {

  private val random = new Random
  private val localAddresses =
    NetworkInterface.getNetworkInterfaces.flatMap(_.getInetAddresses).toSet

  /** Returns true if the given host is the local host. */
  def isLocalHost(host: Host): Boolean = {
    val hostAddress = host.getAddress
    hostAddress.isLoopbackAddress || localAddresses.contains(hostAddress)
  }
  /** Finds the DCs of the contact points and returns the hosts from `allHosts` that are in those
    * DCs. It is guaranteed to return at least the hosts pointed to by `contactPoints`, even if
    * their DC information is missing. Other hosts with missing DC information are not considered. */
  def nodesInTheSameDC(contactPoints: Set[InetAddress], allHosts: Set[Host]): Set[Host] = {
    val contactNodes = allHosts.filter(h => contactPoints.contains(h.getAddress))
    val contactDCs = contactNodes.map(_.getDatacenter).filter(_ != null).toSet
    contactNodes ++ allHosts.filter(h => contactDCs.contains(h.getDatacenter))
  }
  /** Sorts nodes in the following order:
    * 1. up nodes in the same DC as `contactPoints`, starting with the local host if it is up
    * 2. down nodes in the same DC as `contactPoints`
    *
    * Nodes within a group are ordered randomly.
    * Nodes from other DCs are not included. */
  def sortNodesByStatusAndProximity(contactPoints: Set[InetAddress], hostsToSort: Set[Host]): Seq[Host] = {
    val nodesInLocalDC = nodesInTheSameDC(contactPoints, hostsToSort)
    val (allUpHosts, downHosts) = nodesInLocalDC.partition(_.isUp)
    val (localHost, upHosts) = allUpHosts.partition(isLocalHost)
    localHost.toSeq ++ random.shuffle(upHosts.toSeq) ++ random.shuffle(downHosts.toSeq)
  }
}
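
For context, here is a minimal sketch of how this policy might be plugged into a Cassandra Java driver 2.x `Cluster`. The contact-point address and the `LocalNodeFirstExample` object are hypothetical placeholders; the builder calls assume the standard driver API of that era.

// Usage sketch (assumptions: Java driver 2.x, a reachable node at 10.0.0.1,
// and the class above on the classpath). Not part of the policy itself.
import java.net.InetAddress
import com.datastax.driver.core.Cluster
import com.datastax.spark.connector.cql.LocalNodeFirstLoadBalancingPolicy

object LocalNodeFirstExample {
  def main(args: Array[String]): Unit = {
    val contactPoints = Set(InetAddress.getByName("10.0.0.1")) // hypothetical address
    val cluster = Cluster.builder()
      .addContactPoints(contactPoints.toSeq: _*)
      .withLoadBalancingPolicy(new LocalNodeFirstLoadBalancingPolicy(contactPoints))
      .build()
    val session = cluster.connect()
    // query plans now prefer the local node, then local-DC replicas, then other local-DC hosts
    session.execute("SELECT release_version FROM system.local")
    cluster.close()
  }
}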