@kisimple
Created August 9, 2017 02:30
org.apache.spark.examples.sql.DefaultSource
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.sql

import java.util.Locale

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
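
// The class name "DefaultSource" is what Spark's data source resolution looks
// for: passing this package to DataFrameReader.format() instantiates
// org.apache.spark.examples.sql.DefaultSource as the RelationProvider.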
class DefaultSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    ComplicatedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext.sparkSession)
  }
}
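
// ComplicatedScan serves a synthetic table of rows numbered `from` to `to`
// and demonstrates two pushdown paths: column pruning with filter pushdown
// (PrunedFilteredScan) and aggregate pushdown (AggregatedFilteredScan, a
// non-standard trait assumed to be available in this Spark build).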
case class ComplicatedScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
  extends BaseRelation
  with Logging
  with PrunedFilteredScan
  with AggregatedFilteredScan {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override def schema: StructType =
    StructType(
      StructField("a", IntegerType, nullable = true) ::
      StructField("b", LongType, nullable = false) ::
      StructField("c", StringType, nullable = false) ::
      StructField("d", DoubleType, nullable = false) ::
      StructField("e", DataTypes.createDecimalType(), nullable = false) ::
      StructField("g", IntegerType, nullable = false) ::
      StructField("f", FloatType, nullable = false) ::
      StructField("i", ByteType, nullable = false) ::
      StructField("j", ShortType, nullable = false) :: Nil)
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
    def unhandled(filter: Filter): Boolean = {
      filter match {
        case EqualTo(col, v) => col == "b"
        case EqualNullSafe(col, v) => col == "b"
        case LessThan(col, v: Int) => col == "b"
        case LessThanOrEqual(col, v: Int) => col == "b"
        case GreaterThan(col, v: Int) => col == "b"
        case GreaterThanOrEqual(col, v: Int) => col == "b"
        case In(col, values) => col == "b"
        case IsNull(col) => col == "b"
        case IsNotNull(col) => col == "b"
        case Not(pred) => unhandled(pred)
        case And(left, right) => unhandled(left) || unhandled(right)
        case Or(left, right) => unhandled(left) || unhandled(right)
        case _ => false
      }
    }
    filters.filter(unhandled)
  }
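
  // PrunedFilteredScan path: build an RDD that carries only the requested
  // columns (this example only knows how to materialize "a", "b" and "c"),
  // evaluating the pushed-down filters locally before emitting rows.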
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val rowBuilders = requiredColumns.map {
      case "a" => (i: Int) => Seq(i)
      case "b" => (i: Int) => Seq(i * 2)
      case "c" => (i: Int) =>
        val c = (i - 1 + 'a').toChar.toString
        Seq(c * 5 + c.toUpperCase(Locale.ROOT) * 5)
    }

    // Predicate test on integer column
    def translateFilterOnA(filter: Filter): Int => Boolean = filter match {
      case EqualTo("a", v) => (a: Int) => a == v
      case EqualNullSafe("a", v) => (a: Int) => a == v
      case LessThan("a", v: Int) => (a: Int) => a < v
      case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
      case GreaterThan("a", v: Int) => (a: Int) => a > v
      case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
      case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
      case IsNull("a") => (a: Int) => false // Int can't be null
      case IsNotNull("a") => (a: Int) => true
      case Not(pred) => (a: Int) => !translateFilterOnA(pred)(a)
      case And(left, right) => (a: Int) =>
        translateFilterOnA(left)(a) && translateFilterOnA(right)(a)
      case Or(left, right) => (a: Int) =>
        translateFilterOnA(left)(a) || translateFilterOnA(right)(a)
      case _ => (a: Int) => true
    }

    // Predicate test on string column
    def translateFilterOnC(filter: Filter): String => Boolean = filter match {
      case StringStartsWith("c", v) => _.startsWith(v)
      case StringEndsWith("c", v) => _.endsWith(v)
      case StringContains("c", v) => _.contains(v)
      case EqualTo("c", v: String) => _.equals(v)
      case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters")
      case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
      case _ => (c: String) => true
    }
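
    // eval regenerates the values of columns "a" and "c" for row index a and
    // checks every pushed filter against them.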
    def eval(a: Int) = {
      val c = (a - 1 + 'a').toChar.toString * 5 +
        (a - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5
      filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c))
    }

    sparkSession.sparkContext.parallelize(from to to).filter(eval).map(i =>
      Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
  }
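
  // AggregatedFilteredScan path: a planner that supports this trait pushes
  // down grouping columns, aggregate functions and filters, and the source
  // returns rows that are already grouped and aggregated.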
  override def buildScan(
      groupingColumns: Array[String],
      aggregateFunctions: Array[AggregateFunc],
      filters: Array[Filter]): RDD[Row] = {
    val rowBuilders = Array("a", "b", "c", "d", "e", "g", "f", "i", "j").map {
      case "a" => (i: Int) => Seq(i)
      case "b" => (i: Int) => Seq(i)
      case "c" => (i: Int) =>
        val c = (i % 2 + 'a').toChar.toString
        Seq(c * 5 + c.toUpperCase(Locale.ROOT) * 5)
      case "d" => (i: Int) => Seq(i)
      case "e" => (i: Int) => Seq(i)
      case "g" => (i: Int) => Seq(i % 2)
      case "f" => (i: Int) => Seq(i)
      case "i" => (i: Int) => Seq(i)
      case "j" => (i: Int) => Seq(i)
    }

    // Predicate test on integer column
    def translateFilterOnA(filter: Filter): Int => Boolean = filter match {
      case EqualTo("a", v) => (a: Int) => a == v
      case EqualNullSafe("a", v) => (a: Int) => a == v
      case LessThan("a", v: Int) => (a: Int) => a < v
      case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
      case GreaterThan("a", v: Int) => (a: Int) => a > v
      case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
      case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
      case IsNull("a") => (a: Int) => a == 7 // use 7 as NULL
      case IsNotNull("a") => (a: Int) => a != 7
      case Not(pred) => (a: Int) => !translateFilterOnA(pred)(a)
      case And(left, right) => (a: Int) =>
        translateFilterOnA(left)(a) && translateFilterOnA(right)(a)
      case Or(left, right) => (a: Int) =>
        translateFilterOnA(left)(a) || translateFilterOnA(right)(a)
      case _ => (a: Int) => true
    }

    // Predicate test on string column
    def translateFilterOnC(filter: Filter): String => Boolean = filter match {
      case StringStartsWith("c", v) => _.startsWith(v)
      case StringEndsWith("c", v) => _.endsWith(v)
      case StringContains("c", v) => _.contains(v)
      case EqualTo("c", v: String) => _.equals(v)
      case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters")
      case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
      case _ => (c: String) => true
    }

    def eval(a: Int) = {
      val c = (a % 2 + 'a').toChar.toString * 5 +
        (a % 2 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5
      filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c))
    }

    def columnIndex(c: String): Int = c match {
      case "a" => 0
      case "b" => 1
      case "c" => 2
      case "d" => 3
      case "e" => 4
      case "g" => 5
      case "f" => 6
      case "i" => 7
      case "j" => 8
    }
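
    // Evaluate the pushed filters row by row, then group by the concatenated
    // grouping-key string and fold each group into one aggregated output row.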
    val filtered = sparkSession.sparkContext.parallelize(from to to).filter(eval).map { i =>
      rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)
    }

    val grouped =
      if (groupingColumns.isEmpty) {
        filtered.map(r => ("NoSuchKey", r))
      } else {
        filtered.map(r => (groupingColumns.map(c => r(columnIndex(c))).mkString("+"),
          r ++ groupingColumns.map(c => r(columnIndex(c)))))
      }

    val l = groupingColumns.length
    val aggregated = grouped.groupByKey()
      .map { case (k, it) =>
        val ar = new ArrayBuffer[Any]
        if (l > 0) {
          for (i <- 0 until l) {
            // grouping columns
            ar += it.head(i + schema.fields.length)
          }
        }
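
        // Compute each pushed aggregate naively over the buffered group and
        // box the result to the Java type expected for that output column.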
        aggregateFunctions.foreach {
          case Sum(c, t) =>
            val i = columnIndex(c)
            var sum = 0
            it.foreach { r => sum += r(i).asInstanceOf[Int] }
            t match {
              case LongType => ar += java.lang.Long.valueOf(sum)
              case DoubleType => ar += java.lang.Double.valueOf(sum)
              case dt: DecimalType => ar += java.math.BigDecimal.valueOf(sum)
            }
          case Count(c) => c match {
            case "a" =>
              var count = 0
              it.foreach { r => if (r(0) != 7) count += 1 } // use 7 as NULL
              ar += java.lang.Long.valueOf(count)
            case _ =>
              ar += java.lang.Long.valueOf(it.size)
          }
          case CountStar() =>
            ar += java.lang.Long.valueOf(it.size)
          case Max(c) =>
            val i = columnIndex(c)
            var max = java.lang.Integer.MIN_VALUE
            it.foreach { r => if (r(i).asInstanceOf[Int] > max) max = r(i).asInstanceOf[Int] }
            c match {
              case "a" => ar += java.lang.Integer.valueOf(max)
              case "b" => ar += java.lang.Long.valueOf(max)
              case "d" => ar += java.lang.Double.valueOf(max)
              case "e" => ar += java.math.BigDecimal.valueOf(max)
              case "f" => ar += java.lang.Float.valueOf(max)
              case "i" => ar += java.lang.Byte.valueOf(max.toByte)
              case "j" => ar += java.lang.Short.valueOf(max.toShort)
            }
          case Min(c) =>
            val i = columnIndex(c)
            var min = java.lang.Integer.MAX_VALUE
            it.foreach { r => if (r(i).asInstanceOf[Int] < min) min = r(i).asInstanceOf[Int] }
            c match {
              case "a" => ar += java.lang.Integer.valueOf(min)
              case "b" => ar += java.lang.Long.valueOf(min)
              case "d" => ar += java.lang.Double.valueOf(min)
              case "e" => ar += java.math.BigDecimal.valueOf(min)
              case "f" => ar += java.lang.Float.valueOf(min)
              case "i" => ar += java.lang.Byte.valueOf(min.toByte)
              case "j" => ar += java.lang.Short.valueOf(min.toShort)
            }
        }
        (k, ar)
      }

    aggregated.map { case (_, aggResult) =>
      Row.fromSeq(aggResult)
    }
  }
}
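
A minimal usage sketch, not part of the original gist: how this source could be exercised through the standard DataFrameReader API. The driver class name and the local master setting are illustrative; only the "from"/"to" options and the column names come from the code above, and the aggregate pushdown path assumes a Spark build that actually ships the AggregatedFilteredScan trait.

package org.apache.spark.examples.sql

import org.apache.spark.sql.SparkSession

object ComplicatedScanExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ComplicatedScanExample")
      .master("local[*]")
      .getOrCreate()

    // Spark resolves "org.apache.spark.examples.sql" to the DefaultSource
    // class in that package, which builds a ComplicatedScan over rows 1..10.
    val df = spark.read
      .format("org.apache.spark.examples.sql")
      .option("from", "1")
      .option("to", "10")
      .load()

    // Column pruning and the filter on "a" are pushed into
    // buildScan(requiredColumns, filters); a filter on "b" would be reported
    // as unhandled and re-applied by Spark.
    df.filter("a > 3").select("a", "c").show()

    // With the assumed aggregate-pushdown planner rule, this query could be
    // served by buildScan(groupingColumns, aggregateFunctions, filters).
    df.groupBy("g").count().show()

    spark.stop()
  }
}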