@marmbrus
Created September 10, 2015 21:35
Example of injecting custom planning strategies into Spark SQL.

First, a disclaimer: this is an experimental API that exposes internals that are likely to change between Spark releases. As a result, most data sources should be written against the stable public API in org.apache.spark.sql.sources. We expose this mostly to get feedback on which optimizations we should add to the stable API to get the best performance out of data sources.

We'll start with a simple artificial data source that just returns ranges of consecutive integers.

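import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/** A data source that returns ranges of consecutive integers in a column named `a`. */
case class SimpleRelation(
    start: Int,
    end: Int)(
    @transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  // A single non-nullable integer column named `a`.
  override val schema: StructType = StructType(StructField("a", IntegerType, nullable = false) :: Nil)

  // One row for every integer in [start, end].
  override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(start to end).map(Row(_))
}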

Given this we can create tables:

sqlContext.baseRelationToDataFrame(SimpleRelation(1, 1)(sqlContext)).registerTempTable("smallTable")
sqlContext.baseRelationToDataFrame(SimpleRelation(1, 10000000)(sqlContext)).registerTempTable("bigTable")

However, joining them is slow, because the default plan shuffles all ten million rows of the big table even though at most one of them can match:

sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a").collect()
res3: Array[org.apache.spark.sql.Row] = Array([1,1])
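To see where the time goes, here is a toy, single-process model of a shuffle hash join written with plain Scala collections. It is a sketch for illustration only, not Spark's implementation; `ToyShuffleJoin`, `partitionOf` and `join` are hypothetical names. Both sides are hash-partitioned on the join key before any matching happens, so every row of the big side is moved even though almost none of them can find a partner.

```scala
// Toy model of a shuffle hash join (hypothetical; not Spark's code).
object ToyShuffleJoin {
  /** Assigns a key to one of `numPartitions` buckets, like a hash partitioner. */
  def partitionOf(key: Int, numPartitions: Int): Int = {
    val mod = key.hashCode % numPartitions
    if (mod < 0) mod + numPartitions else mod
  }

  /** Returns the joined pairs and the number of rows that had to be "shuffled". */
  def join(left: Seq[Int], right: Seq[Int], numPartitions: Int): (Seq[(Int, Int)], Int) = {
    // Shuffle phase: every row of both inputs is routed to its target partition.
    val leftParts = left.groupBy(partitionOf(_, numPartitions))
    val rightParts = right.groupBy(partitionOf(_, numPartitions))
    val shuffledRows = left.size + right.size

    // Join phase: per partition, build a hash table on the left and probe it with the right.
    val joined = (0 until numPartitions).flatMap { p =>
      val table = leftParts.getOrElse(p, Seq.empty).groupBy(identity)
      rightParts.getOrElse(p, Seq.empty).flatMap { r =>
        table.getOrElse(r, Seq.empty).map(l => (l, r))
      }
    }
    (joined, shuffledRows)
  }
}
```

Joining `1 to 1` with `1 to 100000` in this model yields the single pair `(1, 1)` but still reports 100,001 shuffled rows, which is the waste the custom strategy below avoids.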

This query takes about 10 seconds on my cluster. Clearly we can do better. So let's define special physical operators for the case where two of these relations are inner joined on equality: one handles the case where the ranges don't overlap, and the other the case where they do. Physical operators must extend SparkPlan and must return an RDD[Row] containing the answer when execute() is called.

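import org.apache.spark.sql.{Row, Strategy}
import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SparkPlan
// LogicalRelation's package depends on the Spark version
// (org.apache.spark.sql.sources in 1.3/1.4, org.apache.spark.sql.execution.datasources in 1.5).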

/** A join that just returns the pre-calculated overlap of two ranges of consecutive integers. */
case class OverlappingRangeJoin(leftOutput: Attribute, rightOutput: Attribute, start: Int, end: Int) extends SparkPlan {
  def output: Seq[Attribute] = leftOutput :: rightOutput :: Nil

  def execute(): org.apache.spark.rdd.RDD[Row] = {
    sqlContext.sparkContext.parallelize(start to end).map(i => Row(i, i))
  }

  def children: Seq[SparkPlan] = Nil
}

/** Used when a join is known to produce no results. */
case class EmptyJoin(output: Seq[Attribute]) extends SparkPlan {  
  def execute(): org.apache.spark.rdd.RDD[Row] = {
    sqlContext.sparkContext.emptyRDD
  }

  def children: Seq[SparkPlan] = Nil
}
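Both operators rest on one observation: joining two ranges of consecutive integers on equality produces exactly the pairs `(i, i)` for every `i` in the intersection of the two ranges, so the answer can be computed from the bounds alone without reading either table. The plain-Scala sketch below computes that intersection and checks it against a naive nested-loop join; `RangeJoinMath`, `overlap` and `naiveJoin` are hypothetical helpers, not part of Spark.

```scala
// Hypothetical helpers illustrating the range-overlap reasoning (no Spark needed).
object RangeJoinMath {
  /** The overlapping inclusive range of [leftStart, leftEnd] and [rightStart, rightEnd], if any. */
  def overlap(leftStart: Int, leftEnd: Int, rightStart: Int, rightEnd: Int): Option[(Int, Int)] =
    if (leftStart <= rightEnd && leftEnd >= rightStart)
      Some((math.max(leftStart, rightStart), math.min(leftEnd, rightEnd)))
    else None // disjoint ranges: the join is empty

  /** Reference answer: compare every pair of rows, keeping the equal ones. */
  def naiveJoin(left: Range, right: Range): Seq[(Int, Int)] =
    for { l <- left; r <- right if l == r } yield (l, r)
}
```

For example, `overlap(3, 8, 5, 12)` is `Some((5, 8))`, and `naiveJoin(3 to 8, 5 to 12)` returns exactly `(5, 5)` through `(8, 8)`, while `overlap(1, 5, 10, 20)` is `None`.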

/** Finds cases where two sets of consecutive integer ranges are inner joined on equality. */
object SmartSimpleJoin extends Strategy with Serializable {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    // Find inner joins between two SimpleRelations where the condition is equality.
    case Join(l @ LogicalRelation(left: SimpleRelation), r @ LogicalRelation(right: SimpleRelation), Inner, Some(EqualTo(a, b))) =>
      // Check if the join condition is comparing `a` from each relation.
      if ((a == l.output.head && b == r.output.head) || (a == r.output.head && b == l.output.head)) {
        if ((left.start <= right.end) && (left.end >= right.start)) {
          OverlappingRangeJoin(
            l.output.head,
            r.output.head,
            math.max(left.start, right.start),
            math.min(left.end, right.end)) :: Nil
        } else {
          // Ranges don't overlap, join will be empty
          EmptyJoin(l.output.head :: r.output.head :: Nil) :: Nil
        }
      } else {
        // The join isn't on the `a` columns of the two relations,
        // so let the built-in strategies plan it.
        Nil
      }
    case _ => Nil // Return an empty list if we don't know how to handle this plan.
  }
}
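Stripped of Catalyst, the strategy is an ordinary Scala pattern match that returns candidate physical plans, or `Nil` to defer to other strategies. Here is a Spark-free sketch with the same shape. Every type in it (`RangeRel`, `EqJoin`, `OtherPlan`, `OverlapJoin`, `EmptyResult`) is a hypothetical stand-in, and columns are modeled as strings such as `"s.a"`.

```scala
// Spark-free sketch of the SmartSimpleJoin match. All types are hypothetical stand-ins.
object ToyStrategy {
  // Logical plans: a range relation whose only column is "<alias>.a", an equi-join, or anything else.
  sealed trait Logical
  case class RangeRel(alias: String, start: Int, end: Int) extends Logical
  case class EqJoin(left: Logical, right: Logical, leftKey: String, rightKey: String) extends Logical
  case class OtherPlan(description: String) extends Logical

  // Physical plans the toy strategy can emit.
  sealed trait Physical
  case class OverlapJoin(start: Int, end: Int) extends Physical
  case object EmptyResult extends Physical

  // True when the condition compares the `a` column of each side, in either order.
  private def isKeyPair(lk: String, rk: String, la: String, ra: String): Boolean =
    (lk == s"$la.a" && rk == s"$ra.a") || (lk == s"$ra.a" && rk == s"$la.a")

  /** Returns candidate physical plans, or Nil to let other strategies handle `plan`. */
  def smartJoin(plan: Logical): Seq[Physical] = plan match {
    case EqJoin(RangeRel(la, ls, le), RangeRel(ra, rs, re), lk, rk) if isKeyPair(lk, rk, la, ra) =>
      if (ls <= re && le >= rs) OverlapJoin(math.max(ls, rs), math.min(le, re)) :: Nil
      else EmptyResult :: Nil // disjoint ranges: the join is empty
    case _ => Nil // not a join we recognize
  }
}
```

Returning `Nil` rather than throwing is what lets a strategy like this coexist with the built-in ones: anything it does not recognize simply falls through.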

We can then add this strategy to the query planner through the experimental hook. Strategies added this way are tried before the built-in ones.
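Conceptually, the planner places the extra strategies in front of the built-in ones and uses the first plan any of them produces, so a strategy that returns `Nil` costs nothing for queries it does not handle. The sketch below models that ordering with strings standing in for plans. `ToyPlanner` and `Rule` are hypothetical, and this is a simplification of the idea rather than Spark's `QueryPlanner` code.

```scala
// Hypothetical model of strategy ordering; plans are plain strings.
object ToyPlanner {
  /** A strategy maps a logical plan to zero or more candidate physical plans. */
  type Rule = String => Seq[String]

  /** Tries `extra` rules before `builtIn` ones and returns the first plan produced. */
  def plan(logical: String, extra: Seq[Rule], builtIn: Seq[Rule]): String =
    (extra ++ builtIn).view // lazy: later rules run only if earlier ones return Nil
      .flatMap(rule => rule(logical))
      .headOption
      .getOrElse(sys.error(s"No plan for $logical"))
}
```

With a rule that only recognizes `"rangeJoin"` in `extra` and a catch-all rule in `builtIn`, planning `"rangeJoin"` picks the custom plan, while planning any other string falls through to the catch-all.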

// Add the strategy to the query planner.
sqlContext.experimental.extraStrategies = SmartSimpleJoin :: Nil

sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a").collect()
res4: Array[org.apache.spark.sql.Row] = Array([1,1])

Now the same join returns in under a second. For more advanced matching of joins and their conditions, look at the extractor patterns in org.apache.spark.sql.catalyst.planning (for example ExtractEquiJoinKeys) and at the built-in join strategies in SparkStrategies. Let me know if you have any questions.
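As an illustration of the kind of matching those patterns enable, here is a hypothetical, Spark-free extractor (`ToyPatterns.EquiKeys`) that splits a join condition into its column-equals-column keys and the remaining predicates. It is loosely in the spirit of Catalyst's `ExtractEquiJoinKeys` but is not its implementation; the `Cond` types are toy stand-ins, and `Lt` represents any non-equality predicate.

```scala
// Hypothetical extractor sketch; none of these types come from Spark.
object ToyPatterns {
  sealed trait Cond
  case class Col(name: String) extends Cond
  case class Eq(left: Cond, right: Cond) extends Cond
  case class And(left: Cond, right: Cond) extends Cond
  case class Lt(left: Cond, right: Cond) extends Cond // stands in for any other predicate

  /** Splits a condition on And into its conjuncts. */
  def conjuncts(c: Cond): List[Cond] = c match {
    case And(l, r) => conjuncts(l) ++ conjuncts(r)
    case other     => List(other)
  }

  /** Matches conditions with at least one column = column predicate,
    * yielding the key pairs and the leftover predicates. */
  object EquiKeys {
    def unapply(c: Cond): Option[(List[(String, String)], List[Cond])] = {
      val (eqs, rest) = conjuncts(c).partition {
        case Eq(Col(_), Col(_)) => true
        case _                  => false
      }
      val keys = eqs.collect { case Eq(Col(l), Col(r)) => (l, r) }
      if (keys.isEmpty) None else Some((keys, rest))
    }
  }

  /** Example of using the extractor inside a strategy-style match. */
  def describe(c: Cond): String = c match {
    case EquiKeys(keys, rest) => s"${keys.size} equi-join key(s), ${rest.size} other predicate(s)"
    case _                    => "no equi-join keys"
  }
}
```

A strategy written against such an extractor no longer cares whether the equality appears alone or alongside other predicates, which is the main advantage over matching `Some(EqualTo(a, b))` directly.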
