@crakjie
Last active September 5, 2016 15:34
CassandraLeftJoinRDD
package com.datastax.spark.connector.rdd
import org.apache.spark.metrics.InputMetricsUpdater
import com.datastax.driver.core.Session
import com.datastax.spark.connector._
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.rdd.reader._
import com.datastax.spark.connector.util.CqlWhereParser.{EqPredicate, InListPredicate, InPredicate, RangePredicate}
import com.datastax.spark.connector.util.{CountingIterator, CqlWhereParser}
import com.datastax.spark.connector.writer._
import com.datastax.spark.connector.util.Quote._
import org.apache.spark.rdd.RDD
import org.apache.spark.{Partition, TaskContext}
import scala.reflect.ClassTag
/**
* An [[org.apache.spark.rdd.RDD RDD]] that performs a left join between the `left` RDD and the specified
* Cassandra table, pairing each left item with the `Seq` of matching rows (empty when nothing matches).
* This will perform individual selects to retrieve the rows from Cassandra and will take
* advantage of RDDs that have been partitioned with the
* [[com.datastax.spark.connector.rdd.partitioner.ReplicaPartitioner]].
*
* @tparam L item type on the left side of the join (any RDD)
* @tparam R item type on the right side of the join (fetched from Cassandra)
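*
* A minimal construction sketch (hypothetical keyspace/table names; assumes the `private[connector]`
* constructor is visible because this class sits in the connector's
* `com.datastax.spark.connector.rdd` package, and that the usual implicit
* `RowWriterFactory`/`RowReaderFactory` instances are in scope):
* {{{
* val ids: RDD[Tuple1[Int]] = sc.parallelize(1 to 100).map(Tuple1(_))
* val joined: RDD[(Tuple1[Int], Seq[CassandraRow])] =
*   new CassandraLeftJoinRDD[Tuple1[Int], CassandraRow](
*     left = ids,
*     keyspaceName = "test",
*     tableName = "words",
*     connector = CassandraConnector(sc.getConf))
* }}}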
*/
class CassandraLeftJoinRDD[L, R] private[connector](
left: RDD[L],
val keyspaceName: String,
val tableName: String,
val connector: CassandraConnector,
val columnNames: ColumnSelector = AllColumns,
val joinColumns: ColumnSelector = PartitionKeyColumns,
val where: CqlWhereClause = CqlWhereClause.empty,
val limit: Option[Long] = None,
val clusteringOrder: Option[ClusteringOrder] = None,
val readConf: ReadConf = ReadConf())(
implicit
val leftClassTag: ClassTag[L],
val rightClassTag: ClassTag[R],
@transient val rowWriterFactory: RowWriterFactory[L],
@transient val rowReaderFactory: RowReaderFactory[R])
extends CassandraRDD[(L, Seq[R])](left.sparkContext, left.dependencies)
with CassandraTableRowReaderProvider[R] {
override type Self = CassandraLeftJoinRDD[L, R]
override protected val classTag = rightClassTag
override protected def copy(
columnNames: ColumnSelector = columnNames,
where: CqlWhereClause = where,
limit: Option[Long] = limit,
clusteringOrder: Option[ClusteringOrder] = None,
readConf: ReadConf = readConf,
connector: CassandraConnector = connector): Self = {
new CassandraLeftJoinRDD[L, R](
left = left,
keyspaceName = keyspaceName,
tableName = tableName,
connector = connector,
columnNames = columnNames,
joinColumns = joinColumns,
where = where,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf)
}
lazy val joinColumnNames: Seq[ColumnRef] = joinColumns match {
case AllColumns => throw new IllegalArgumentException(
"Unable to join against all columns in a Cassandra Table. Only primary key columns allowed.")
case PartitionKeyColumns =>
tableDef.partitionKey.map(col => col.columnName: ColumnRef)
case SomeColumns(cs @ _*) =>
checkColumnsExistence(cs)
cs.map {
case c: ColumnRef => c
case _ => throw new IllegalArgumentException(
"Unable to join against unnamed columns. No CQL Functions allowed.")
}
}
override def count(): Long = {
columnNames match {
case SomeColumns(_) =>
logWarning("You are about to count rows but an explicit projection has been specified.")
case _ =>
}
val counts =
new CassandraJoinRDD[L, Long](
left = left,
connector = connector,
keyspaceName = keyspaceName,
tableName = tableName,
columnNames = SomeColumns(RowCountRef),
joinColumns = joinColumns,
where = where,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf)
counts.map(_._2).reduce(_ + _)
}
/** This method will create the RowWriter required before the RDD is serialized.
* This is called during getPartitions */
protected def checkValidJoin(): Seq[ColumnRef] = {
val partitionKeyColumnNames = tableDef.partitionKey.map(_.columnName).toSet
val primaryKeyColumnNames = tableDef.primaryKey.map(_.columnName).toSet
val colNames = joinColumnNames.map(_.columnName).toSet
// Initialize RowWriter and Query to be used for accessing Cassandra
rowWriter.columnNames
singleKeyCqlQuery.length
def checkSingleColumn(column: ColumnRef): Unit = {
require(
primaryKeyColumnNames.contains(column.columnName),
s"Can't pushdown join on column $column because it is not part of the PRIMARY KEY")
}
// Make sure we have all of the clustering indexes between the 0th position and the max requested
// in the join:
val chosenClusteringColumns = tableDef.clusteringColumns
.filter(cc => colNames.contains(cc.columnName))
if (!tableDef.clusteringColumns.startsWith(chosenClusteringColumns)) {
val maxCol = chosenClusteringColumns.last
val maxIndex = maxCol.componentIndex.get
val requiredColumns = tableDef.clusteringColumns.takeWhile(_.componentIndex.get <= maxIndex)
val missingColumns = requiredColumns.toSet -- chosenClusteringColumns.toSet
throw new IllegalArgumentException(
s"Can't pushdown join on column $maxCol without also specifying [ $missingColumns ]")
}
val missingPartitionKeys = partitionKeyColumnNames -- colNames
require(
missingPartitionKeys.isEmpty,
s"Can't join without the full partition key. Missing: [ $missingPartitionKeys ]")
joinColumnNames.foreach(checkSingleColumn)
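// For illustration only (hypothetical schema): with PRIMARY KEY ((k1, k2), c1, c2), joining on
// SomeColumns("k1", "k2", "c1") passes these checks, while SomeColumns("k1", "k2", "c2") is
// rejected because the clustering prefix c1 is missing, and SomeColumns("k1", "c1") is rejected
// because the partition key column k2 is missing.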
joinColumnNames
}
lazy val rowWriter = implicitly[RowWriterFactory[L]].rowWriter(
tableDef, joinColumnNames.toIndexedSeq)
def on(joinColumns: ColumnSelector): CassandraLeftJoinRDD[L, R] = {
new CassandraLeftJoinRDD[L, R](
left = left,
connector = connector,
keyspaceName = keyspaceName,
tableName = tableName,
columnNames = columnNames,
joinColumns = joinColumns,
where = where,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf)
}
//We need to make sure we get selectedColumnRefs before serialization so that our RowReader is
//built
lazy val singleKeyCqlQuery: (String) = {
val whereClauses = where.predicates.flatMap(CqlWhereParser.parse)
val joinColumns = joinColumnNames.map(_.columnName)
val joinColumnPredicates = whereClauses.collect {
case EqPredicate(c, _) if joinColumns.contains(c) => c
case InPredicate(c) if joinColumns.contains(c) => c
case InListPredicate(c, _) if joinColumns.contains(c) => c
case RangePredicate(c, _, _) if joinColumns.contains(c) => c
}.toSet
require(
joinColumnPredicates.isEmpty,
s"""Columns specified in both the join on clause and the where clause.
|Partition key columns are always part of the join clause.
|Columns in both: ${joinColumnPredicates.mkString(", ")}""".stripMargin
)
logDebug("Generating Single Key Query Prepared Statement String")
logDebug(s"SelectedColumns : $selectedColumnRefs -- JoinColumnNames : $joinColumnNames")
val columns = selectedColumnRefs.map(_.cql).mkString(", ")
val joinWhere = joinColumnNames.map(_.columnName).map(name => s"${quote(name)} = :$name")
val limitClause = limit.map(limit => s"LIMIT $limit").getOrElse("")
val orderBy = clusteringOrder.map(_.toCql(tableDef)).getOrElse("")
val filter = (where.predicates ++ joinWhere).mkString(" AND ")
val quotedKeyspaceName = quote(keyspaceName)
val quotedTableName = quote(tableName)
val query =
s"SELECT $columns " +
s"FROM $quotedKeyspaceName.$quotedTableName " +
s"WHERE $filter $limitClause $orderBy"
logDebug(s"Query : $query")
query
}
/**
* When computing a CassandraLeftJoinRDD the data is selected via single CQL statements
* from the specified C* Keyspace and Table. This will be performed on whatever data is
* available in the previous RDD in the chain.
*/
override def compute(split: Partition, context: TaskContext): Iterator[(L, Seq[R])] = {
val session = connector.openSession()
implicit val pv = protocolVersion(session)
val stmt = session.prepare(singleKeyCqlQuery).setConsistencyLevel(consistencyLevel)
val bsb = new BoundStatementBuilder[L](rowWriter, stmt, pv, where.values)
val metricsUpdater = InputMetricsUpdater(context, readConf)
val rowIterator = fetchIterator(session, bsb, left.iterator(split, context))
val countingIterator = new CountingIterator(rowIterator, limit)
context.addTaskCompletionListener { (context) =>
val duration = metricsUpdater.finish() / 1000000000d
logDebug(
f"Fetched ${countingIterator.count} rows " +
f"from $keyspaceName.$tableName " +
f"for partition ${split.index} in $duration%.3f s.")
session.close()
}
countingIterator
}
private def fetchIterator(
session: Session,
bsb: BoundStatementBuilder[L],
lastIt: Iterator[L]): Iterator[(L, Seq[R])] = {
val columnNamesArray = selectedColumnRefs.map(_.selectedAs).toArray
implicit val pv = protocolVersion(session)
for (leftSide <- lastIt) yield {
// Execute one bound statement per left element and materialize the matching rows,
// so every left element is kept even when no rows match (empty Seq).
val rightSide = {
val rs = session.execute(bsb.bind(leftSide))
val iterator = new PrefetchingResultSetIterator(rs, fetchSize)
iterator.map(rowReader.read(_, columnNamesArray)).toSeq
}
(leftSide, rightSide)
}
}
override protected def getPartitions: Array[Partition] = {
verify()
checkValidJoin()
left.partitions
}
override def getPreferredLocations(split: Partition): Seq[String] = left.preferredLocations(split)
override def toEmptyCassandraRDD: EmptyCassandraRDD[(L, Seq[R])] =
new EmptyCassandraRDD[(L, Seq[R])](
sc = left.sparkContext,
keyspaceName = keyspaceName,
tableName = tableName,
columnNames = columnNames,
where = where,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf)
}
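// A possible user-facing helper (a hedged sketch, not part of the original gist): mirrors the
// connector's joinWithCassandraTable-style API. If it lives in this same file/package, the
// private[connector] constructor above is accessible and the imports at the top of the file apply.
object CassandraLeftJoinImplicits {
implicit class LeftJoinRDDFunctions[L](rdd: RDD[L]) {
def leftJoinWithCassandraTable[R](
keyspace: String,
table: String,
connector: CassandraConnector = CassandraConnector(rdd.sparkContext.getConf))(
implicit
lct: ClassTag[L],
rct: ClassTag[R],
rwf: RowWriterFactory[L],
rrf: RowReaderFactory[R]): CassandraLeftJoinRDD[L, R] =
new CassandraLeftJoinRDD[L, R](rdd, keyspace, table, connector)
}
}
// Usage sketch (hypothetical table test.words):
//   import CassandraLeftJoinImplicits._
//   val joined = sc.parallelize(Seq(Tuple1("cat"), Tuple1("dog")))
//     .leftJoinWithCassandraTable[CassandraRow]("test", "words")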
@victorvilladev

protocolVersion is not getting recognized as a valid function.
Am I missing a lib?
