// Created November 4, 2019 12:38
// Explore what a Spark join is.
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
/**
 * Explores the join types offered by Spark's `Dataset.join` and `Dataset.crossJoin`.
 *
 * Two tiny in-memory tables are joined on `cid`:
 *  - `customers(cid, name)` with cids 1, 2, 3 — customer 3 has no orders;
 *  - `orders(oid, cid, amount)` with cids 1, 2, 2, 2, 1000 — cid 1000 has no customer.
 *
 * Each test collects the join result to the driver and checks its size and
 * contents; row order is deliberately not asserted.
 */
class JoinSpec extends FlatSpec with Matchers with BeforeAndAfterAll {

  @transient lazy val logger: Logger = Logger.getLogger(getClass)

  /** Local Spark session shared by all tests in this suite; stopped in `afterAll`. */
  lazy val spark: SparkSession = {
    // Keep Spark's INFO-level logging out of the test output.
    Logger.getLogger("org").setLevel(Level.WARN)
    SparkSession.builder()
      .master("local[*]")
      .appName("spark test of join")
      // The default of 200 shuffle partitions makes non-broadcast joins (e.g. the
      // full outer join) needlessly slow on a handful of rows.
      .config("spark.sql.shuffle.partitions", "1")
      .getOrCreate()
  }

  val sc = spark.sparkContext

  /** customers(cid, name): customer 3 ("babaozhou") has no orders. */
  private val customers = spark.createDataFrame(
    sc.parallelize(Seq(
      Row(1, "harbor"),
      Row(2, "mr.wu"),
      Row(3, "babaozhou")
    )),
    schema = StructType(Seq(
      StructField("cid", IntegerType),
      StructField("name", StringType)
    ))
  )

  /** orders(oid, cid, amount): order 5 references cid 1000, which has no customer. */
  private val orders = spark.createDataFrame(
    sc.parallelize(Seq(
      Row(1, 1, 50.0d),
      Row(2, 2, 10d),
      Row(3, 2, 10d),
      Row(4, 2, 10d),
      Row(5, 1000, 19d)
    )),
    schema = StructType(Seq(
      StructField("oid", IntegerType),
      StructField("cid", IntegerType),
      StructField("amount", DoubleType)
    ))
  )

  /** Stops the shared session so the suite does not leak a running SparkContext. */
  override protected def afterAll(): Unit = {
    try spark.stop()
    finally super.afterAll()
  }

  // Inner join: only keys present on both sides (cid 1 and 2) survive. Joining on
  // Seq("cid") emits a single `cid` column first, then the left and right columns.
  "SparkJoin inner" should "collect only matching rows from both sides" in {
    val innerJoinResultDF = orders.join(customers, Seq("cid"), joinType = "inner")
    val innerJoinResult = innerJoinResultDF.collect()
    innerJoinResult should (have size 4 and contain allOf(
      // |cid|oid|amount| name|
      Row(1, 1, 50.0d, "harbor"),
      Row(2, 2, 10d, "mr.wu"),
      Row(2, 3, 10d, "mr.wu"),
      Row(2, 4, 10d, "mr.wu")
    ))
  }

  // Cross join: every order paired with every customer (5 x 3 = 15 rows). No join
  // key is given, so both `cid` columns are kept.
  "SparkJoin cross" should "create a Cartesian Product" in {
    val crossJoinResultDF = orders.crossJoin(customers)
    val crossJoinResult = crossJoinResultDF.collect()
    crossJoinResult should (have size 15 and contain allOf(
      // |oid| cid| amount| cid| name|
      Row(1, 1, 50.0, 1, "harbor"),
      Row(1, 1, 50.0, 2, "mr.wu"),
      Row(1, 1, 50.0, 3, "babaozhou"),
      Row(2, 2, 10.0, 1, "harbor"),
      Row(2, 2, 10.0, 2, "mr.wu"),
      Row(2, 2, 10.0, 3, "babaozhou"),
      Row(3, 2, 10.0, 1, "harbor"),
      Row(3, 2, 10.0, 2, "mr.wu"),
      Row(3, 2, 10.0, 3, "babaozhou"),
      Row(4, 2, 10.0, 1, "harbor"),
      Row(4, 2, 10.0, 2, "mr.wu"),
      Row(4, 2, 10.0, 3, "babaozhou"),
      Row(5, 1000, 19.0, 1, "harbor"),
      Row(5, 1000, 19.0, 2, "mr.wu"),
      Row(5, 1000, 19.0, 3, "babaozhou")
    ))
  }

  // Full outer join: the inner-join rows plus null-padded rows for keys that are
  // unmatched on either side (customer 3 and cid 1000).
  "SparkJoin outer" should "full outer join 相比内连接多了null数据" in {
    // Accepted aliases: "outer", "full", "fullouter", "full_outer"
    val outerJoinResultDF = orders.join(customers, Seq("cid"), joinType = "outer")
    val outerJoinResult = outerJoinResultDF.collect()
    outerJoinResult should (have size 6 and contain allOf(
      // | cid| oid| amount| name|
      Row( 1, 1, 50.0d, "harbor"),
      Row( 2, 2, 10d, "mr.wu"),
      Row( 2, 3, 10d, "mr.wu"),
      Row( 2, 4, 10d, "mr.wu"),
      Row( 3, null, null, "babaozhou"),
      Row( 1000, 5, 19d, null)
    ))
  }

  // Left outer join: compared with the inner join, it adds null-padded rows for
  // left-side keys with no match. Here order 5 (cid 1000) gets a null `name`.
  "SparkJoin left order join customer" should "left outer join 相比内连接多了左边key不为null的null数据" in {
    // Accepted aliases: "leftouter", "left", "left_outer"
    val leftJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left")
    val leftJoinResult = leftJoinResultDF.collect()
    leftJoinResult should (have size 5 and contain allOf(
      // | cid| oid| amount| name|
      Row( 1, 1, 50.0d, "harbor"),
      Row( 2, 2, 10d, "mr.wu"),
      Row( 2, 3, 10d, "mr.wu"),
      Row( 2, 4, 10d, "mr.wu"),
      Row( 1000, 5, 19d, null)
    ))
  }

  // Left outer join with the sides swapped: every customer is kept. Customer 3 has
  // no orders, so its `oid` and `amount` are null.
  "SparkJoin left customer join order" should "left outer join 相比内连接多了左边key不为null的null数据" in {
    // Accepted aliases: "leftouter", "left", "left_outer"
    val leftJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left")
    val leftJoinResult = leftJoinResultDF.collect()
    leftJoinResult should (have size 5 and contain allOf(
      // |cid| name| oid| amount|
      Row(1, "harbor", 1, 50.0d),
      Row(2, "mr.wu", 2, 10d),
      Row(2, "mr.wu", 3, 10d),
      Row(2, "mr.wu", 4, 10d),
      Row(3, "babaozhou", null, null)
    ))
  }

  // Right outer join: `customers RIGHT JOIN orders` yields the same rows as
  // `orders LEFT JOIN customers`; only the column order differs.
  "SparkJoin right customer join order" should "和left order join customer除了列数据顺序不一样之外其他都一样" in {
    // Accepted aliases: "rightouter", "right", "right_outer"
    val rightJoinResultDF = customers.join(orders, Seq("cid"), joinType = "right")
    val rightJoinResult = rightJoinResultDF.collect()
    rightJoinResult should (have size 5 and contain allOf(
      // | cid| name| oid| amount|
      Row( 1, "harbor", 1, 50.0d),
      Row( 2, "mr.wu", 2, 10d),
      Row( 2, "mr.wu", 3, 10d),
      Row( 2, "mr.wu", 4, 10d),
      Row( 1000, null, 5, 19d)
    ))
  }

  // Right outer join: `orders RIGHT JOIN customers` yields the same rows as
  // `customers LEFT JOIN orders`; only the column order differs.
  "SparkJoin right order join customer" should "和left customer join order除了列数据顺序不一样之外其他都一样" in {
    // Accepted aliases: "rightouter", "right", "right_outer"
    val rightJoinResultDF = orders.join(customers, Seq("cid"), joinType = "right")
    val rightJoinResult = rightJoinResultDF.collect()
    rightJoinResult should (have size 5 and contain allOf(
      // |cid| oid| amount| name|
      Row(1, 1, 50.0d, "harbor"),
      Row(2, 2, 10d, "mr.wu"),
      Row(2, 3, 10d, "mr.wu"),
      Row(2, 4, 10d, "mr.wu"),
      Row(3, null, null, "babaozhou")
    ))
  }

  // Left semi join: the same rows as the inner join, but only the left side's
  // columns are returned.
  "SparkJoin left_semi order join customer" should "和inner除了少了右边独有的列之外其他都一样" in {
    // Accepted aliases: "leftsemi", "left_semi"
    val leftSemiJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left_semi")
    val leftSemiJoinResult = leftSemiJoinResultDF.collect()
    leftSemiJoinResult should (have size 4 and contain allOf(
      // |cid|oid|amount|
      Row(1, 1, 50.0d),
      Row(2, 2, 10d),
      Row(2, 3, 10d),
      Row(2, 4, 10d)
    ))
  }

  // Left semi join acts as an existence filter. Right-only columns are dropped,
  // and each matching left row appears once even though customer 2 has three orders.
  "SparkJoin left_semi custom join order" should "和inner除了少了右边独有的列还有去重了之外其他都一样" in {
    // Accepted aliases: "leftsemi", "left_semi"
    val leftSemiJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left_semi")
    val leftSemiJoinResult = leftSemiJoinResultDF.collect()
    leftSemiJoinResult should (have size 2 and contain allOf(
      // |cid |name |
      Row(1, "harbor"),
      Row(2, "mr.wu")
    ))
  }

  // Left anti join: customers with no orders. Together with the left_semi
  // customers-join-orders result above, it reconstructs the full customers table.
  "SparkJoin left_anti custom join order" should "和 left_semi custom join order 组合在一起就是完整的 customs 表" in {
    // Accepted aliases: "leftanti", "left_anti"
    val leftAntiJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left_anti")
    val leftAntiJoinResult = leftAntiJoinResultDF.collect()
    leftAntiJoinResult should (have size 1 and contain (
      // |cid|name|
      Row(3, "babaozhou")
    ))
  }

  // Left anti join: orders whose cid has no customer. Together with the left_semi
  // orders-join-customers result above, it reconstructs the full orders table.
  "SparkJoin left_anti order join customer" should "和 left_semi order join customer 组合在一起就是完整的 orders 表" in {
    // Accepted aliases: "leftanti", "left_anti"
    val leftAntiJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left_anti")
    val leftAntiJoinResult = leftAntiJoinResultDF.collect()
    leftAntiJoinResult should (have size 1 and contain (
      // |cid|oid|amount|
      Row(1000, 5, 19d)
    ))
  }
}
