@zhouyuan
Created April 13, 2023 10:09
This file has been truncated.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
index 171e93c1bf..53662c6560 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.time.{Duration, Period}
+
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
@@ -58,4 +60,30 @@ class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSpa
}
}
}
+
+ test("SPARK-37138: Support Ansi Interval type in ApproxCountDistinctForIntervals") {
+ val table = "approx_count_distinct_for_ansi_intervals_tbl"
+ withTable(table) {
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+ (Period.ofMonths(200), Duration.ofSeconds(200L)),
+ (Period.ofMonths(300), Duration.ofSeconds(300L)))
+ .toDF("col1", "col2").createOrReplaceTempView(table)
+ val endpoints = (0 to 5).map(_ / 10)
+
+ val relation = spark.table(table).logicalPlan
+ val ymAttr = relation.output.find(_.name == "col1").get
+ val ymAggFunc =
+ ApproxCountDistinctForIntervals(ymAttr, CreateArray(endpoints.map(Literal(_))))
+ val ymAggExpr = ymAggFunc.toAggregateExpression()
+ val ymNamedExpr = Alias(ymAggExpr, ymAggExpr.toString)()
+
+ val dtAttr = relation.output.find(_.name == "col2").get
+ val dtAggFunc =
+ ApproxCountDistinctForIntervals(dtAttr, CreateArray(endpoints.map(Literal(_))))
+ val dtAggExpr = dtAggFunc.toAggregateExpression()
+ val dtNamedExpr = Alias(dtAggExpr, dtAggExpr.toString)()
+ val result = Dataset.ofRows(spark, Aggregate(Nil, Seq(ymNamedExpr, dtNamedExpr), relation))
+ checkAnswer(result, Row(Array(1, 1, 1, 1, 1), Array(1, 1, 1, 1, 1)))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index 5ff15c9710..9237c9e948 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql
import java.sql.{Date, Timestamp}
-import java.time.LocalDateTime
+import java.time.{Duration, LocalDateTime, Period}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
@@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession
class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession {
import testImplicits._
- private val table = "percentile_test"
+ private val table = "percentile_approx"
test("percentile_approx, single percentile value") {
withTempView(table) {
@@ -319,4 +319,22 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
Row(18, 17, 17, 17))
}
}
+
+ test("SPARK-37138: Support Ansi Interval type in ApproximatePercentile") {
+ withTempView(table) {
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+ (Period.ofMonths(200), Duration.ofSeconds(200L)),
+ (Period.ofMonths(300), Duration.ofSeconds(300L)))
+ .toDF("col1", "col2").createOrReplaceTempView(table)
+ checkAnswer(
+ spark.sql(
+ s"""SELECT
+ | percentile_approx(col1, 0.5),
+ | SUM(null),
+ | percentile_approx(col2, 0.5)
+ |FROM $table
+ """.stripMargin),
+ Row(Period.ofMonths(200).normalized(), null, Duration.ofSeconds(200L)))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
new file mode 100644
index 0000000000..05513cddcc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Query tests for the Bloom filter aggregate and filter function.
+ */
+class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+ val funcId_might_contain = new FunctionIdentifier("might_contain")
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ // Register 'bloom_filter_agg' to builtin.
+ spark.sessionState.functionRegistry.registerFunction(funcId_bloom_filter_agg,
+ new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+ (children: Seq[Expression]) => children.size match {
+ case 1 => new BloomFilterAggregate(children.head)
+ case 2 => new BloomFilterAggregate(children.head, children(1))
+ case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+ })
+
+ // Register 'might_contain' to builtin.
+ spark.sessionState.functionRegistry.registerFunction(funcId_might_contain,
+ new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+ (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+ }
+
+ override def afterAll(): Unit = {
+ spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+ spark.sessionState.functionRegistry.dropFunction(funcId_might_contain)
+ super.afterAll()
+ }
+
+ test("Test bloom_filter_agg and might_contain") {
+ val conf = SQLConf.get
+ val table = "bloom_filter_test"
+ for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
+ conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) {
+ for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
+ conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))) {
+ val sqlString = s"""
+ |SELECT every(might_contain(
+ | (SELECT bloom_filter_agg(col,
+ | cast($numEstimatedItems as long),
+ | cast($numBits as long))
+ | FROM $table),
+ | col)) positive_membership_test,
+ | every(might_contain(
+ | (SELECT bloom_filter_agg(col,
+ | cast($numEstimatedItems as long),
+ | cast($numBits as long))
+ | FROM values (-1L), (100001L), (20000L) as t(col)),
+ | col)) negative_membership_test
+ |FROM $table
+ """.stripMargin
+ withTempView(table) {
+ (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 10000L))
+ .toDF("col").createOrReplaceTempView(table)
+ // Validate error messages as well as answers when there's no error.
+ if (numEstimatedItems <= 0) {
+ val exception = intercept[AnalysisException] {
+ spark.sql(sqlString)
+ }
+ assert(exception.getMessage.contains(
+ "The estimated number of items must be a positive value"))
+ } else if (numBits <= 0) {
+ val exception = intercept[AnalysisException] {
+ spark.sql(sqlString)
+ }
+ assert(exception.getMessage.contains("The number of bits must be a positive value"))
+ } else {
+ checkAnswer(spark.sql(sqlString), Row(true, false))
+ }
+ }
+ }
+ }
+ }
+
+ test("Test that bloom_filter_agg errors out disallowed input value types") {
+ val exception1 = intercept[AnalysisException] {
+ spark.sql("""
+ |SELECT bloom_filter_agg(a)
+ |FROM values (1.2), (2.5) as t(a)"""
+ .stripMargin)
+ }
+ assert(exception1.getMessage.contains(
+ "Input to function bloom_filter_agg should have been a bigint value"))
+
+ val exception2 = intercept[AnalysisException] {
+ spark.sql("""
+ |SELECT bloom_filter_agg(a, 2)
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+ .stripMargin)
+ }
+ assert(exception2.getMessage.contains(
+ "function bloom_filter_agg should have been a bigint value followed with two bigint"))
+
+ val exception3 = intercept[AnalysisException] {
+ spark.sql("""
+ |SELECT bloom_filter_agg(a, cast(2 as long), 5)
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+ .stripMargin)
+ }
+ assert(exception3.getMessage.contains(
+ "function bloom_filter_agg should have been a bigint value followed with two bigint"))
+
+ val exception4 = intercept[AnalysisException] {
+ spark.sql("""
+ |SELECT bloom_filter_agg(a, null, 5)
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+ .stripMargin)
+ }
+ assert(exception4.getMessage.contains("Null typed values cannot be used as size arguments"))
+
+ val exception5 = intercept[AnalysisException] {
+ spark.sql("""
+ |SELECT bloom_filter_agg(a, 5, null)
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
+ .stripMargin)
+ }
+ assert(exception5.getMessage.contains("Null typed values cannot be used as size arguments"))
+ }
+
+ test("Test that might_contain errors out disallowed input value types") {
+ val exception1 = intercept[AnalysisException] {
+ spark.sql("""|SELECT might_contain(1.0, 1L)"""
+ .stripMargin)
+ }
+ assert(exception1.getMessage.contains(
+ "Input to function might_contain should have been binary followed by a value with bigint"))
+
+ val exception2 = intercept[AnalysisException] {
+ spark.sql("""|SELECT might_contain(NULL, 0.1)"""
+ .stripMargin)
+ }
+ assert(exception2.getMessage.contains(
+ "Input to function might_contain should have been binary followed by a value with bigint"))
+ }
+
+ test("Test that might_contain errors out non-constant Bloom filter") {
+ val exception1 = intercept[AnalysisException] {
+ spark.sql("""
+ |SELECT might_contain(cast(a as binary), cast(5 as long))
+ |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)"""
+ .stripMargin)
+ }
+ assert(exception1.getMessage.contains(
+ "The Bloom filter binary input to might_contain should be either a constant value or " +
+ "a scalar subquery expression"))
+
+ val exception2 = intercept[AnalysisException] {
+ spark.sql("""
+ |SELECT might_contain((select cast(a as binary)), cast(5 as long))
+ |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)"""
+ .stripMargin)
+ }
+ assert(exception2.getMessage.contains(
+ "The Bloom filter binary input to might_contain should be either a constant value or " +
+ "a scalar subquery expression"))
+ }
+
+ test("Test that might_contain can take a constant value input") {
+ checkAnswer(spark.sql(
+ """SELECT might_contain(
+ |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267',
+ |cast(201 as long))""".stripMargin),
+ Row(false))
+ }
+
+ test("Test that bloom_filter_agg produces a NULL with empty input") {
+ checkAnswer(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""),
+ Row(null))
+ }
+
+ test("Test NULL inputs for might_contain") {
+ checkAnswer(spark.sql(
+ s"""
+ |SELECT might_contain(null, null) both_null,
+ | might_contain(null, 1L) null_bf,
+ | might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)),
+ | null) null_value
+ """.stripMargin),
+ Row(null, null, null))
+ }
+
+ test("Test that a query with bloom_filter_agg has partial aggregates") {
+ assert(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""")
+ .queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
+ .collect({case agg: BaseAggregateExec => agg}).size == 2)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEHintSuite.scala
index 13039bbbf6..a596ebc6b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CTEHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEHintSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.log4j.Level
+import org.apache.logging.log4j.Level
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.test.SharedSparkSession
@@ -65,7 +65,7 @@ class CTEHintSuite extends QueryTest with SharedSparkSession {
}
val warningMessages = logAppender.loggingEvents
.filter(_.getLevel == Level.WARN)
- .map(_.getRenderedMessage)
+ .map(_.getMessage.getFormattedMessage)
.filter(_.contains("hint"))
assert(warningMessages.size == warnings.size)
warnings.foreach { w =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
index 7ee533ac26..e758c6f8df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.plans.logical.WithCTE
+import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, LessThan, Literal, Or}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.internal.SQLConf
@@ -42,7 +43,7 @@ abstract class CTEInlineSuiteBase
""".stripMargin)
checkAnswer(df, Nil)
assert(
- df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).nonEmpty,
+ df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]),
"Non-deterministic With-CTE with multiple references should be not inlined.")
}
}
@@ -59,7 +60,7 @@ abstract class CTEInlineSuiteBase
""".stripMargin)
checkAnswer(df, Nil)
assert(
- df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).nonEmpty,
+ df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]),
"Non-deterministic With-CTE with multiple references should be not inlined.")
}
}
@@ -76,10 +77,10 @@ abstract class CTEInlineSuiteBase
""".stripMargin)
checkAnswer(df, Row(0, 1) :: Row(1, 2) :: Nil)
assert(
- df.queryExecution.analyzed.find(_.isInstanceOf[WithCTE]).nonEmpty,
+ df.queryExecution.analyzed.exists(_.isInstanceOf[WithCTE]),
"With-CTE should not be inlined in analyzed plan.")
assert(
- df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).isEmpty,
+ !df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]),
"With-CTE with one reference should be inlined in optimized plan.")
}
}
@@ -107,8 +108,8 @@ abstract class CTEInlineSuiteBase
"With-CTE should contain 2 CTE defs after analysis.")
assert(
df.queryExecution.optimizedPlan.collect {
- case WithCTE(_, cteDefs) => cteDefs
- }.head.length == 2,
+ case r: RepartitionOperation => r
+ }.length == 6,
"With-CTE should contain 2 CTE def after optimization.")
}
}
@@ -136,8 +137,8 @@ abstract class CTEInlineSuiteBase
"With-CTE should contain 2 CTE defs after analysis.")
assert(
df.queryExecution.optimizedPlan.collect {
- case WithCTE(_, cteDefs) => cteDefs
- }.head.length == 1,
+ case r: RepartitionOperation => r
+ }.length == 4,
"One CTE def should be inlined after optimization.")
}
}
@@ -163,7 +164,7 @@ abstract class CTEInlineSuiteBase
"With-CTE should contain 2 CTE defs after analysis.")
assert(
df.queryExecution.optimizedPlan.collect {
- case WithCTE(_, cteDefs) => cteDefs
+ case r: RepartitionOperation => r
}.isEmpty,
"CTEs with one reference should all be inlined after optimization.")
}
@@ -248,7 +249,7 @@ abstract class CTEInlineSuiteBase
"With-CTE should contain 2 CTE defs after analysis.")
assert(
df.queryExecution.optimizedPlan.collect {
- case WithCTE(_, cteDefs) => cteDefs
+ case r: RepartitionOperation => r
}.isEmpty,
"Deterministic CTEs should all be inlined after optimization.")
}
@@ -272,6 +273,372 @@ abstract class CTEInlineSuiteBase
assert(ex.message.contains("Table or view not found: v1"))
}
}
+
+ test("CTE Predicate push-down and column pruning") {
+ withView("t") {
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+ val df = sql(
+ s"""with
+ |v as (
+ | select c1, c2, 's' c3, rand() c4 from t
+ |),
+ |vv as (
+ | select v1.c1, v1.c2, rand() c5 from v v1, v v2
+ | where v1.c1 > 0 and v1.c3 = 's' and v1.c2 = v2.c2
+ |)
+ |select vv1.c1, vv1.c2, vv2.c1, vv2.c2 from vv vv1, vv vv2
+ |where vv1.c2 > 0 and vv2.c2 > 0 and vv1.c1 = vv2.c1
+ """.stripMargin)
+ checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
+ assert(
+ df.queryExecution.analyzed.collect {
+ case WithCTE(_, cteDefs) => cteDefs
+ }.head.length == 2,
+ "With-CTE should contain 2 CTE defs after analysis.")
+ val cteRepartitions = df.queryExecution.optimizedPlan.collect {
+ case r: RepartitionOperation => r
+ }
+ assert(cteRepartitions.length == 6,
+ "CTE should not be inlined after optimization.")
+ val distinctCteRepartitions = cteRepartitions.map(_.canonicalized).distinct
+ // Check column pruning and predicate push-down.
+ assert(distinctCteRepartitions.length == 2)
+ assert(distinctCteRepartitions(1).collectFirst {
+ case p: Project if p.projectList.length == 3 => p
+ }.isDefined, "CTE columns should be pruned.")
+ assert(distinctCteRepartitions(1).collectFirst {
+ case f: Filter if f.condition.semanticEquals(GreaterThan(f.output(1), Literal(0))) => f
+ }.isDefined, "Predicate 'c2 > 0' should be pushed down to the CTE def 'v'.")
+ assert(distinctCteRepartitions(0).collectFirst {
+ case f: Filter if f.condition.find(_.semanticEquals(f.output(0))).isDefined => f
+ }.isDefined, "CTE 'vv' definition contains predicate 'c1 > 0'.")
+ assert(distinctCteRepartitions(1).collectFirst {
+ case f: Filter if f.condition.find(_.semanticEquals(f.output(0))).isDefined => f
+ }.isEmpty, "Predicate 'c1 > 0' should be not pushed down to the CTE def 'v'.")
+ // Check runtime repartition reuse.
+ assert(
+ collectWithSubqueries(df.queryExecution.executedPlan) {
+ case r: ReusedExchangeExec => r
+ }.length == 2,
+ "CTE repartition is reused.")
+ }
+ }
+
+ test("CTE Predicate push-down and column pruning - combined predicate") {
+ withView("t") {
+ Seq((0, 1, 2), (1, 2, 3)).toDF("c1", "c2", "c3").createOrReplaceTempView("t")
+ val df = sql(
+ s"""with
+ |v as (
+ | select c1, c2, c3, rand() c4 from t
+ |),
+ |vv as (
+ | select v1.c1, v1.c2, rand() c5 from v v1, v v2
+ | where v1.c1 > 0 and v2.c3 < 5 and v1.c2 = v2.c2
+ |)
+ |select vv1.c1, vv1.c2, vv2.c1, vv2.c2 from vv vv1, vv vv2
+ |where vv1.c2 > 0 and vv2.c2 > 0 and vv1.c1 = vv2.c1
+ """.stripMargin)
+ checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
+ assert(
+ df.queryExecution.analyzed.collect {
+ case WithCTE(_, cteDefs) => cteDefs
+ }.head.length == 2,
+ "With-CTE should contain 2 CTE defs after analysis.")
+ val cteRepartitions = df.queryExecution.optimizedPlan.collect {
+ case r: RepartitionOperation => r
+ }
+ assert(cteRepartitions.length == 6,
+ "CTE should not be inlined after optimization.")
+ val distinctCteRepartitions = cteRepartitions.map(_.canonicalized).distinct
+ // Check column pruning and predicate push-down.
+ assert(distinctCteRepartitions.length == 2)
+ assert(distinctCteRepartitions(1).collectFirst {
+ case p: Project if p.projectList.length == 3 => p
+ }.isDefined, "CTE columns should be pruned.")
+ assert(
+ distinctCteRepartitions(1).collectFirst {
+ case f: Filter
+ if f.condition.semanticEquals(
+ And(
+ GreaterThan(f.output(1), Literal(0)),
+ Or(
+ GreaterThan(f.output(0), Literal(0)),
+ LessThan(f.output(2), Literal(5))))) =>
+ f
+ }.isDefined,
+ "Predicate 'c2 > 0 AND (c1 > 0 OR c3 < 5)' should be pushed down to the CTE def 'v'.")
+ // Check runtime repartition reuse.
+ assert(
+ collectWithSubqueries(df.queryExecution.executedPlan) {
+ case r: ReusedExchangeExec => r
+ }.length == 2,
+ "CTE repartition is reused.")
+ }
+ }
+
+ test("Views with CTEs - 1 temp view") {
+ withView("t", "t2") {
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+ sql(
+ s"""with
+ |v as (
+ | select c1 + c2 c3 from t
+ |)
+ |select sum(c3) s from v
+ """.stripMargin).createOrReplaceTempView("t2")
+ val df = sql(
+ s"""with
+ |v as (
+ | select c1 * c2 c3 from t
+ |)
+ |select sum(c3) from v except select s from t2
+ """.stripMargin)
+ checkAnswer(df, Row(2) :: Nil)
+ }
+ }
+
+ test("Views with CTEs - 2 temp views") {
+ withView("t", "t2", "t3") {
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+ sql(
+ s"""with
+ |v as (
+ | select c1 + c2 c3 from t
+ |)
+ |select sum(c3) s from v
+ """.stripMargin).createOrReplaceTempView("t2")
+ sql(
+ s"""with
+ |v as (
+ | select c1 * c2 c3 from t
+ |)
+ |select sum(c3) s from v
+ """.stripMargin).createOrReplaceTempView("t3")
+ val df = sql("select s from t3 except select s from t2")
+ checkAnswer(df, Row(2) :: Nil)
+ }
+ }
+
+ test("Views with CTEs - temp view + sql view") {
+ withTable("t") {
+ withView ("t2", "t3") {
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").write.saveAsTable("t")
+ sql(
+ s"""with
+ |v as (
+ | select c1 + c2 c3 from t
+ |)
+ |select sum(c3) s from v
+ """.stripMargin).createOrReplaceTempView("t2")
+ sql(
+ s"""create view t3 as
+ |with
+ |v as (
+ | select c1 * c2 c3 from t
+ |)
+ |select sum(c3) s from v
+ """.stripMargin)
+ val df = sql("select s from t3 except select s from t2")
+ checkAnswer(df, Row(2) :: Nil)
+ }
+ }
+ }
+
+ test("Union of Dataframes with CTEs") {
+ val a = spark.sql("with t as (select 1 as n) select * from t ")
+ val b = spark.sql("with t as (select 2 as n) select * from t ")
+ val df = a.union(b)
+ checkAnswer(df, Row(1) :: Row(2) :: Nil)
+ }
+
+ test("CTE definitions out of original order when not inlined") {
+ withView("t1", "t2") {
+ Seq((1, 2, 10, 100), (2, 3, 20, 200)).toDF("workspace_id", "issue_id", "shard_id", "field_id")
+ .createOrReplaceTempView("issue_current")
+ withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+ "org.apache.spark.sql.catalyst.optimizer.InlineCTE") {
+ val df = sql(
+ """
+ |WITH cte_0 AS (
+ | SELECT workspace_id, issue_id, shard_id, field_id FROM issue_current
+ |),
+ |cte_1 AS (
+ | WITH filtered_source_table AS (
+ | SELECT * FROM cte_0 WHERE shard_id in ( 10 )
+ | )
+ | SELECT source_table.workspace_id, field_id FROM cte_0 source_table
+ | INNER JOIN (
+ | SELECT workspace_id, issue_id FROM filtered_source_table GROUP BY 1, 2
+ | ) target_table
+ | ON source_table.issue_id = target_table.issue_id
+ | AND source_table.workspace_id = target_table.workspace_id
+ | WHERE source_table.shard_id IN ( 10 )
+ |)
+ |SELECT * FROM cte_1
+ """.stripMargin)
+ checkAnswer(df, Row(1, 100) :: Nil)
+ }
+ }
+ }
+
+ test("Make sure CTESubstitution places WithCTE back in the plan correctly.") {
+ withView("t") {
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+
+ // CTE on both sides of join - WithCTE placed over first common parent, i.e., the join.
+ val df1 = sql(
+ s"""
+ |select count(v1.c3), count(v2.c3) from (
+ | with
+ | v1 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v1
+ |) v1 join (
+ | with
+ | v2 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v2
+ |) v2 on v1.c1 = v2.c1
+ """.stripMargin)
+ checkAnswer(df1, Row(2, 2) :: Nil)
+ df1.queryExecution.analyzed match {
+ case Aggregate(_, _, WithCTE(_, cteDefs)) => assert(cteDefs.length == 2)
+ case other => fail(s"Expect pattern Aggregate(WithCTE(_)) but got $other")
+ }
+
+ // CTE on one side of join - WithCTE placed back where it was.
+ val df2 = sql(
+ s"""
+ |select count(v1.c3), count(v2.c3) from (
+ | select c1, c2, rand() c3 from t
+ |) v1 join (
+ | with
+ | v2 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v2
+ |) v2 on v1.c1 = v2.c1
+ """.stripMargin)
+ checkAnswer(df2, Row(2, 2) :: Nil)
+ df2.queryExecution.analyzed match {
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_, cteDefs)), _, _, _)) =>
+ assert(cteDefs.length == 1)
+ case other => fail(s"Expect pattern Aggregate(Join(_, WithCTE(_))) but got $other")
+ }
+
+ // CTE on one side of join and both sides of union - WithCTE placed on first common parent.
+ val df3 = sql(
+ s"""
+ |select count(v1.c3), count(v2.c3) from (
+ | select c1, c2, rand() c3 from t
+ |) v1 join (
+ | select * from (
+ | with
+ | v1 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v1
+ | )
+ | union all
+ | select * from (
+ | with
+ | v2 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v2
+ | )
+ |) v2 on v1.c1 = v2.c1
+ """.stripMargin)
+ checkAnswer(df3, Row(4, 4) :: Nil)
+ df3.queryExecution.analyzed match {
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_: Union, cteDefs)), _, _, _)) =>
+ assert(cteDefs.length == 2)
+ case other => fail(
+ s"Expect pattern Aggregate(Join(_, (WithCTE(Union(_, _))))) but got $other")
+ }
+
+ // CTE on one side of join and one side of union - WithCTE placed back where it was.
+ val df4 = sql(
+ s"""
+ |select count(v1.c3), count(v2.c3) from (
+ | select c1, c2, rand() c3 from t
+ |) v1 join (
+ | select * from (
+ | with
+ | v1 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v1
+ | )
+ | union all
+ | select c1, c2, rand() c3 from t
+ |) v2 on v1.c1 = v2.c1
+ """.stripMargin)
+ checkAnswer(df4, Row(4, 4) :: Nil)
+ df4.queryExecution.analyzed match {
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, Union(children, _, _)), _, _, _))
+ if children.head.find(_.isInstanceOf[WithCTE]).isDefined =>
+ assert(
+ children.head.collect {
+ case w: WithCTE => w
+ }.head.cteDefs.length == 1)
+ case other => fail(
+ s"Expect pattern Aggregate(Join(_, (WithCTE(Union(_, _))))) but got $other")
+ }
+
+ // CTE on both sides of join and one side of union - WithCTE placed on first common parent.
+ val df5 = sql(
+ s"""
+ |select count(v1.c3), count(v2.c3) from (
+ | with
+ | v1 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v1
+ |) v1 join (
+ | select c1, c2, rand() c3 from t
+ | union all
+ | select * from (
+ | with
+ | v2 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v2
+ | )
+ |) v2 on v1.c1 = v2.c1
+ """.stripMargin)
+ checkAnswer(df5, Row(4, 4) :: Nil)
+ df5.queryExecution.analyzed match {
+ case Aggregate(_, _, WithCTE(_, cteDefs)) => assert(cteDefs.length == 2)
+ case other => fail(s"Expect pattern Aggregate(WithCTE(_)) but got $other")
+ }
+
+ // CTE as root node - WithCTE placed back where it was.
+ val df6 = sql(
+ s"""
+ |with
+ |v1 as (
+ | select c1, c2, rand() c3 from t
+ |)
+ |select count(v1.c3), count(v2.c3) from
+ |v1 join (
+ | with
+ | v2 as (
+ | select c1, c2, rand() c3 from t
+ | )
+ | select * from v2
+ |) v2 on v1.c1 = v2.c1
+ """.stripMargin)
+ checkAnswer(df6, Row(2, 2) :: Nil)
+ df6.queryExecution.analyzed match {
+ case WithCTE(_, cteDefs) => assert(cteDefs.length == 2)
+ case other => fail(s"Expect pattern WithCTE(_) but got $other")
+ }
+ }
+ }
}
class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with DisableAdaptiveExecutionSuite
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index bad28152e4..4de409f56d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join, JoinStrategyHint, SHUFFLE_HASH}
import org.apache.spark.sql.catalyst.util.DateTimeConstants
-import org.apache.spark.sql.execution.{ExecSubqueryExpression, RDDScanExec, SparkPlan}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -604,9 +604,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
uncacheTable("t2")
}
- // One side of join is not partitioned in the desired way. Since the number of partitions of
- // the side that has already partitioned is smaller than the side that is not partitioned,
- // we shuffle both side.
+ // One side of join is not partitioned in the desired way. We'll only shuffle this side.
withTempView("t1", "t2") {
testData.repartition(6, $"value").createOrReplaceTempView("t1")
testData2.repartition(3, $"a").createOrReplaceTempView("t2")
@@ -614,7 +612,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
spark.catalog.cacheTable("t2")
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
- verifyNumExchanges(query, 2)
+ verifyNumExchanges(query, 1)
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
@@ -881,8 +879,10 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
test("SPARK-23312: vectorized cache reader can be disabled") {
Seq(true, false).foreach { vectorized =>
withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
- val df = spark.range(10).cache()
- df.queryExecution.executedPlan.foreach {
+ val df1 = spark.range(10).cache()
+ val df2 = spark.range(10).cache()
+ val union = df1.union(df2)
+ union.queryExecution.executedPlan.foreach {
case i: InMemoryTableScanExec =>
assert(i.supportsColumnar == vectorized)
case _ =>
@@ -891,6 +891,19 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
}
}
+ test("SPARK-37369: Avoid redundant ColumnarToRow transition on InMemoryTableScan") {
+ Seq(true, false).foreach { vectorized =>
+ withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
+ val cache = spark.range(10).cache()
+ val df = cache.filter($"id" > 0)
+ val columnarToRow = df.queryExecution.executedPlan.collect {
+ case c: ColumnarToRowExec => c
+ }
+ assert(columnarToRow.isEmpty)
+ }
+ }
+ }
+
private def checkIfNoJobTriggered[T](f: => T): T = {
var numJobTriggered = 0
val jobListener = new SparkListener {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index 7be54d49a9..978e3f8d36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -100,6 +100,19 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
}
}
+ test("char type values should not be padded when charVarcharAsString is true") {
+ withSQLConf(SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key -> "true") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(a STRING, b CHAR(5), c CHAR(5)) USING $format partitioned by (c)")
+ sql("INSERT INTO t VALUES ('abc', 'abc', 'abc')")
+ checkAnswer(sql("SELECT b FROM t WHERE b='abc'"), Row("abc"))
+ checkAnswer(sql("SELECT b FROM t WHERE b in ('abc')"), Row("abc"))
+ checkAnswer(sql("SELECT c FROM t WHERE c='abc'"), Row("abc"))
+ checkAnswer(sql("SELECT c FROM t WHERE c in ('abc')"), Row("abc"))
+ }
+ }
+ }
+
test("varchar type values length check and trim: partitioned columns") {
(0 to 5).foreach { n =>
// SPARK-34192: we need to create a new table for each round of test because of
@@ -332,8 +345,8 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql(s"CREATE TABLE t(c STRUCT<c: $typeName(5)>) USING $format")
sql("INSERT INTO t SELECT struct(null)")
checkAnswer(spark.table("t"), Row(Row(null)))
- val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')"))
- assert(e.getCause.getMessage.contains(s"Exceeds char/varchar type length limitation: 5"))
+ val e = intercept[RuntimeException](sql("INSERT INTO t SELECT struct('123456')"))
+ assert(e.getMessage.contains(s"Exceeds char/varchar type length limitation: 5"))
}
}
@@ -843,27 +856,6 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
}
}
- // TODO(SPARK-33875): Move these tests to super after DESCRIBE COLUMN v2 implemented
- test("SPARK-33892: DESCRIBE COLUMN w/ char/varchar") {
- withTable("t") {
- sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format")
- checkAnswer(sql("desc t v").selectExpr("info_value").where("info_value like '%char%'"),
- Row("varchar(3)"))
- checkAnswer(sql("desc t c").selectExpr("info_value").where("info_value like '%char%'"),
- Row("char(5)"))
- }
- }
-
- // TODO(SPARK-33898): Move these tests to super after SHOW CREATE TABLE for v2 implemented
- test("SPARK-33892: SHOW CREATE TABLE w/ char/varchar") {
- withTable("t") {
- sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format")
- val rest = sql("SHOW CREATE TABLE t").head().getString(0)
- assert(rest.contains("VARCHAR(3)"))
- assert(rest.contains("CHAR(5)"))
- }
- }
-
test("SPARK-34114: should not trim right for read-side length check and char padding") {
Seq("char", "varchar").foreach { typ =>
withTempPath { dir =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index e7ca431726..1f8dc6f80d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -137,6 +137,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
}
+ test("SPARK-34805: as propagates metadata from nested column") {
+ val metadata = new MetadataBuilder
+ metadata.putString("key", "value")
+ val df = spark.createDataFrame(sparkContext.emptyRDD[Row],
+ StructType(Seq(
+ StructField("parent", StructType(Seq(
+ StructField("child", StringType, metadata = metadata.build())
+ ))))
+ ))
+ val newCol = df("parent.child")
+ assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
+ }
+
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
@@ -281,9 +294,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
testData.select(isnan($"a"), isnan($"b")),
Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil)
- checkAnswer(
- sql("select isnan(15), isnan('invalid')"),
- Row(false, false))
+ if (!conf.ansiEnabled) {
+ checkAnswer(
+ sql("select isnan(15), isnan('invalid')"),
+ Row(false, false))
+ }
}
test("nanvl") {
@@ -932,7 +947,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
}
+ test("SPARK-37646: lit") {
+ assert(lit($"foo") == $"foo")
+ assert(lit(Symbol("foo")) == $"foo")
+ assert(lit(1) == Column(Literal(1)))
+ assert(lit(null) == Column(Literal(null, NullType)))
+ }
+
test("typedLit") {
+ assert(typedLit($"foo") == $"foo")
+ assert(typedLit(Symbol("foo")) == $"foo")
+ assert(typedLit(1) == Column(Literal(1)))
+ assert(typedLit[String](null) == Column(Literal(null, StringType)))
+
val df = Seq(Tuple1(0)).toDF("a")
// Only check the types `lit` cannot handle
checkAnswer(
@@ -1017,17 +1044,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should throw an exception if any intermediate structs don't exist") {
intercept[AnalysisException] {
- structLevel2.withColumn("a", 'a.withField("x.b", lit(2)))
+ structLevel2.withColumn("a", Symbol("a").withField("x.b", lit(2)))
}.getMessage should include("No such struct field x in a")
intercept[AnalysisException] {
- structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2)))
+ structLevel3.withColumn("a", Symbol("a").withField("a.x.b", lit(2)))
}.getMessage should include("No such struct field x in a")
}
test("withField should throw an exception if intermediate field is not a struct") {
intercept[AnalysisException] {
- structLevel1.withColumn("a", 'a.withField("b.a", lit(2)))
+ structLevel1.withColumn("a", Symbol("a").withField("b.a", lit(2)))
}.getMessage should include("struct argument should be struct type, got: int")
}
@@ -1041,7 +1068,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
StructField("a", structType, nullable = false))),
nullable = false))))
- structLevel2.withColumn("a", 'a.withField("a.b", lit(2)))
+ structLevel2.withColumn("a", Symbol("a").withField("a.b", lit(2)))
}.getMessage should include("Ambiguous reference to fields")
}
@@ -1060,7 +1087,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should add field to struct") {
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("d", lit(4))),
+ structLevel1.withColumn("a", Symbol("a").withField("d", lit(4))),
Row(Row(1, null, 3, 4)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1101,7 +1128,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should add null field to struct") {
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))),
+ structLevel1.withColumn("a", Symbol("a").withField("d", lit(null).cast(IntegerType))),
Row(Row(1, null, 3, null)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1114,7 +1141,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should add multiple fields to struct") {
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))),
+ structLevel1.withColumn("a", Symbol("a").withField("d", lit(4)).withField("e", lit(5))),
Row(Row(1, null, 3, 4, 5)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1128,7 +1155,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should add multiple fields to nullable struct") {
checkAnswer(
- nullableStructLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))),
+ nullableStructLevel1.withColumn("a", Symbol("a")
+ .withField("d", lit(4)).withField("e", lit(5))),
Row(null) :: Row(Row(1, null, 3, 4, 5)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1142,8 +1170,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should add field to nested struct") {
Seq(
- structLevel2.withColumn("a", 'a.withField("a.d", lit(4))),
- structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4))))
+ structLevel2.withColumn("a", Symbol("a").withField("a.d", lit(4))),
+ structLevel2.withColumn("a", Symbol("a").withField("a", $"a.a".withField("d", lit(4))))
).foreach { df =>
checkAnswer(
df,
@@ -1204,7 +1232,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should add field to deeply nested struct") {
checkAnswer(
- structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))),
+ structLevel3.withColumn("a", Symbol("a").withField("a.a.d", lit(4))),
Row(Row(Row(Row(1, null, 3, 4)))) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1221,7 +1249,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace field in struct") {
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ structLevel1.withColumn("a", Symbol("a").withField("b", lit(2))),
Row(Row(1, 2, 3)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1233,7 +1261,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace field in nullable struct") {
checkAnswer(
- nullableStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))),
+ nullableStructLevel1.withColumn("a", Symbol("a").withField("b", lit("foo"))),
Row(null) :: Row(Row(1, "foo", 3)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1259,7 +1287,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace field with null value in struct") {
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))),
+ structLevel1.withColumn("a", Symbol("a").withField("c", lit(null).cast(IntegerType))),
Row(Row(1, null, null)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1271,7 +1299,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace multiple fields in struct") {
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+ structLevel1.withColumn("a", Symbol("a").withField("a", lit(10)).withField("b", lit(20))),
Row(Row(10, 20, 3)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1283,7 +1311,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace multiple fields in nullable struct") {
checkAnswer(
- nullableStructLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+ nullableStructLevel1.withColumn("a", Symbol("a").withField("a", lit(10))
+ .withField("b", lit(20))),
Row(null) :: Row(Row(10, 20, 3)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1296,7 +1325,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace field in nested struct") {
Seq(
structLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
- structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2))))
+ structLevel2.withColumn("a", Symbol("a").withField("a", $"a.a".withField("b", lit(2))))
).foreach { df =>
checkAnswer(
df,
@@ -1377,7 +1406,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("b", lit(100))),
+ structLevel1.withColumn("a", Symbol("a").withField("b", lit(100))),
Row(Row(1, 100, 100)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1389,7 +1418,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace fields in struct in given order") {
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))),
+ structLevel1.withColumn("a", Symbol("a").withField("b", lit(2)).withField("b", lit(20))),
Row(Row(1, 20, 3)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1401,7 +1430,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should add field and then replace same field in struct") {
checkAnswer(
- structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))),
+ structLevel1.withColumn("a", Symbol("a").withField("d", lit(4)).withField("d", lit(5))),
Row(Row(1, null, 3, 5)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1425,7 +1454,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))),
+ df.withColumn("a", Symbol("a").withField("`a.b`.`e.f`", lit(2))),
Row(Row(Row(1, 2, 3))) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1437,7 +1466,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
intercept[AnalysisException] {
- df.withColumn("a", 'a.withField("a.b.e.f", lit(2)))
+ df.withColumn("a", Symbol("a").withField("a.b.e.f", lit(2)))
}.getMessage should include("No such struct field a in a.b")
}
@@ -1452,7 +1481,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace field in struct even if casing is different") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
checkAnswer(
- mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("A", lit(2))),
Row(Row(2, 1)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1461,7 +1490,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("b", lit(2))),
Row(Row(1, 2)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1474,7 +1503,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should add field to struct because casing is different") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
checkAnswer(
- mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("A", lit(2))),
Row(Row(1, 1, 2)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1484,7 +1513,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("b", lit(2))),
Row(Row(1, 1, 2)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1512,7 +1541,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should replace nested field in struct even if casing is different") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
checkAnswer(
- mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))),
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("A.a", lit(2))),
Row(Row(Row(2, 1), Row(1, 1))) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1527,7 +1556,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))),
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("b.a", lit(2))),
Row(Row(Row(1, 1), Row(2, 1))) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1546,11 +1575,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("withField should throw an exception because casing is different") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
intercept[AnalysisException] {
- mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2)))
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("A.a", lit(2)))
}.getMessage should include("No such struct field A in a, B")
intercept[AnalysisException] {
- mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2)))
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("b.a", lit(2)))
}.getMessage should include("No such struct field b in a, B")
}
}
@@ -1757,17 +1786,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should throw an exception if any intermediate structs don't exist") {
intercept[AnalysisException] {
- structLevel2.withColumn("a", 'a.dropFields("x.b"))
+ structLevel2.withColumn("a", Symbol("a").dropFields("x.b"))
}.getMessage should include("No such struct field x in a")
intercept[AnalysisException] {
- structLevel3.withColumn("a", 'a.dropFields("a.x.b"))
+ structLevel3.withColumn("a", Symbol("a").dropFields("a.x.b"))
}.getMessage should include("No such struct field x in a")
}
test("dropFields should throw an exception if intermediate field is not a struct") {
intercept[AnalysisException] {
- structLevel1.withColumn("a", 'a.dropFields("b.a"))
+ structLevel1.withColumn("a", Symbol("a").dropFields("b.a"))
}.getMessage should include("struct argument should be struct type, got: int")
}
@@ -1781,13 +1810,13 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
StructField("a", structType, nullable = false))),
nullable = false))))
- structLevel2.withColumn("a", 'a.dropFields("a.b"))
+ structLevel2.withColumn("a", Symbol("a").dropFields("a.b"))
}.getMessage should include("Ambiguous reference to fields")
}
test("dropFields should drop field in struct") {
checkAnswer(
- structLevel1.withColumn("a", 'a.dropFields("b")),
+ structLevel1.withColumn("a", Symbol("a").dropFields("b")),
Row(Row(1, 3)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1810,7 +1839,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should drop multiple fields in struct") {
Seq(
structLevel1.withColumn("a", $"a".dropFields("b", "c")),
- structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c"))
+ structLevel1.withColumn("a", Symbol("a").dropFields("b").dropFields("c"))
).foreach { df =>
checkAnswer(
df,
@@ -1824,7 +1853,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should throw an exception if no fields will be left in struct") {
intercept[AnalysisException] {
- structLevel1.withColumn("a", 'a.dropFields("a", "b", "c"))
+ structLevel1.withColumn("a", Symbol("a").dropFields("a", "b", "c"))
}.getMessage should include("cannot drop all fields in struct")
}
@@ -1848,7 +1877,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should drop field in nested struct") {
checkAnswer(
- structLevel2.withColumn("a", 'a.dropFields("a.b")),
+ structLevel2.withColumn("a", Symbol("a").dropFields("a.b")),
Row(Row(Row(1, 3))) :: Nil,
StructType(
Seq(StructField("a", StructType(Seq(
@@ -1861,7 +1890,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should drop multiple fields in nested struct") {
checkAnswer(
- structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")),
+ structLevel2.withColumn("a", Symbol("a").dropFields("a.b", "a.c")),
Row(Row(Row(1))) :: Nil,
StructType(
Seq(StructField("a", StructType(Seq(
@@ -1898,7 +1927,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should drop field in deeply nested struct") {
checkAnswer(
- structLevel3.withColumn("a", 'a.dropFields("a.a.b")),
+ structLevel3.withColumn("a", Symbol("a").dropFields("a.a.b")),
Row(Row(Row(Row(1, 3)))) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1922,7 +1951,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- structLevel1.withColumn("a", 'a.dropFields("b")),
+ structLevel1.withColumn("a", Symbol("a").dropFields("b")),
Row(Row(1)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1933,7 +1962,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should drop field in struct even if casing is different") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
checkAnswer(
- mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("A")),
Row(Row(1)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1941,7 +1970,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("b")),
Row(Row(1)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1953,7 +1982,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should not drop field in struct because casing is different") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
checkAnswer(
- mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("A")),
Row(Row(1, 1)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1962,7 +1991,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("b")),
Row(Row(1, 1)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1975,7 +2004,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should drop nested field in struct even if casing is different") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
checkAnswer(
- mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")),
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("A.a")),
Row(Row(Row(1), Row(1, 1))) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -1989,7 +2018,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")),
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("b.a")),
Row(Row(Row(1, 1), Row(1))) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -2007,18 +2036,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
test("dropFields should throw an exception because casing is different") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
intercept[AnalysisException] {
- mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a"))
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("A.a"))
}.getMessage should include("No such struct field A in a, B")
intercept[AnalysisException] {
- mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a"))
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("b.a"))
}.getMessage should include("No such struct field b in a, B")
}
}
test("dropFields should drop only fields that exist") {
checkAnswer(
- structLevel1.withColumn("a", 'a.dropFields("d")),
+ structLevel1.withColumn("a", Symbol("a").dropFields("d")),
Row(Row(1, null, 3)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -2028,7 +2057,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
nullable = false))))
checkAnswer(
- structLevel1.withColumn("a", 'a.dropFields("b", "d")),
+ structLevel1.withColumn("a", Symbol("a").dropFields("b", "d")),
Row(Row(1, 3)) :: Nil,
StructType(Seq(
StructField("a", StructType(Seq(
@@ -2737,19 +2766,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
- assert(e.getMessage.contains("divide by zero"))
+ assert(e.getMessage.contains("Division by zero"))
val e2 = intercept[SparkException] {
Seq((Period.ofYears(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e2.isInstanceOf[ArithmeticException])
- assert(e2.getMessage.contains("divide by zero"))
+ assert(e2.getMessage.contains("Division by zero"))
val e3 = intercept[SparkException] {
Seq((Period.ofYears(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e3.isInstanceOf[ArithmeticException])
- assert(e3.getMessage.contains("divide by zero"))
+ assert(e3.getMessage.contains("Division by zero"))
}
test("SPARK-34875: divide day-time interval by numeric") {
@@ -2784,19 +2813,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Duration.ofDays(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
- assert(e.getMessage.contains("divide by zero"))
+ assert(e.getMessage.contains("Division by zero"))
val e2 = intercept[SparkException] {
Seq((Duration.ofDays(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e2.isInstanceOf[ArithmeticException])
- assert(e2.getMessage.contains("divide by zero"))
+ assert(e2.getMessage.contains("Division by zero"))
val e3 = intercept[SparkException] {
Seq((Duration.ofDays(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e3.isInstanceOf[ArithmeticException])
- assert(e3.getMessage.contains("divide by zero"))
+ assert(e3.getMessage.contains("Division by zero"))
}
test("SPARK-34896: return day-time interval from dates subtraction") {
@@ -2928,4 +2957,47 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
}
}
}
+
+ test("SPARK-36778: add ilike API for scala") {
+ // scalastyle:off
+ // non-ASCII characters are not allowed in the code, so we disable the scalastyle check here.
+ // null handling
+ val nullDf = Seq("a", null).toDF("src")
+ checkAnswer(nullDf.filter($"src".ilike("A")), Row("a"))
+ checkAnswer(nullDf.filter($"src".ilike(null)), spark.emptyDataFrame)
+ // simple pattern
+ val simpleDf = Seq("a", "A", "abdef", "a_%b", "addb", "abC", "a\nb").toDF("src")
+ checkAnswer(simpleDf.filter($"src".ilike("a")), Seq("a", "A").toDF())
+ checkAnswer(simpleDf.filter($"src".ilike("A")), Seq("a", "A").toDF())
+ checkAnswer(simpleDf.filter($"src".ilike("b")), spark.emptyDataFrame)
+ checkAnswer(simpleDf.filter($"src".ilike("aBdef")), Seq("abdef").toDF())
+ checkAnswer(simpleDf.filter($"src".ilike("a\\__b")), Seq("a_%b").toDF())
+ checkAnswer(simpleDf.filter($"src".ilike("A_%b")), Seq("a_%b", "addb", "a\nb").toDF())
+ checkAnswer(simpleDf.filter($"src".ilike("a%")), simpleDf)
+ checkAnswer(simpleDf.filter($"src".ilike("a_b")), Seq("a\nb").toDF())
+ // double-escaping backslash
+ val dEscDf = Seq("""\__""", """\\__""").toDF("src")
+ checkAnswer(dEscDf.filter($"src".ilike("""\\\__""")), Seq("""\__""").toDF())
+ checkAnswer(dEscDf.filter($"src".ilike("""%\\%\%""")), spark.emptyDataFrame)
+ // unicode
+ val uncDf = Seq("a\u20ACA", "A€a", "a€AA", "a\u20ACaz", "ЀЁЂѺΏỀ").toDF("src")
+ checkAnswer(uncDf.filter($"src".ilike("_\u20AC_")), Seq("a\u20ACA", "A€a").toDF())
+ checkAnswer(uncDf.filter($"src".ilike("_€_")), Seq("a\u20ACA", "A€a").toDF())
+ checkAnswer(uncDf.filter($"src".ilike("_\u20AC_a")), Seq("a€AA").toDF())
+ checkAnswer(uncDf.filter($"src".ilike("_€_Z")), Seq("a\u20ACaz").toDF())
+ checkAnswer(uncDf.filter($"src".ilike("ѐёђѻώề")), Seq("ЀЁЂѺΏỀ").toDF())
+ // scalastyle:on
+ }
+
+ test("SPARK-39093: divide period by integral expression") {
+ val df = Seq(((Period.ofMonths(10)), 2)).toDF("pd", "num")
+ checkAnswer(df.select($"pd" / ($"num" + 3)),
+ Seq((Period.ofMonths(2))).toDF)
+ }
+
+ test("SPARK-39093: divide duration by integral expression") {
+ val df = Seq(((Duration.ofDays(10)), 2)).toDF("dd", "num")
+ checkAnswer(df.select($"dd" / ($"num" + 3)),
+ Seq((Duration.ofDays(2))).toDF)
+ }
}
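Note on the ilike coverage above: Column.ilike behaves like LIKE but matches case-insensitively; '_' matches exactly one character, '%' matches any run of characters (including the empty one), and a backslash escapes the wildcards. A minimal stand-alone sketch of the same behaviour, assuming an active SparkSession; the data and column name below are illustrative only:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("ilike-sketch").getOrCreate()
import spark.implicits._

// '_' consumes one character and '%' any run of characters, case-insensitively,
// so the pattern keeps "a_%b" and "addb" but not "Ab" (too short for both wildcards).
val words = Seq("a_%b", "addb", "Ab").toDF("src")
words.filter($"src".ilike("A_%b")).show()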
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
index 711a4bc3fd..b683f3573b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
@@ -82,16 +82,16 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
test("schema_of_csv - infers schemas") {
checkAnswer(
spark.range(1).select(schema_of_csv(lit("0.1,1"))),
- Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>")))
+ Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>")))
checkAnswer(
spark.range(1).select(schema_of_csv("0.1,1")),
- Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>")))
+ Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>")))
}
test("schema_of_csv - infers schemas using options") {
val df = spark.range(1)
.select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava))
- checkAnswer(df, Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>")))
+ checkAnswer(df, Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>")))
}
test("to_csv - struct") {
@@ -220,7 +220,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
val input = concat_ws(",", lit(0.1), lit(1))
checkAnswer(
spark.range(1).select(schema_of_csv(input)),
- Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>")))
+ Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>")))
}
test("optional datetime parser does not affect csv time formatting") {
@@ -353,4 +353,22 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
}
}
}
+
+ test("SPARK-37326: Handle incorrectly formatted timestamp_ntz values in from_csv") {
+ val fromCsvDF = Seq("2021-08-12T15:16:23.000+11:00").toDF("csv")
+ .select(
+ from_csv(
+ $"csv",
+ StructType(StructField("a", TimestampNTZType) :: Nil),
+ Map.empty[String, String]) as "value")
+ .selectExpr("value.a")
+ checkAnswer(fromCsvDF, Row(null))
+ }
+
+ test("SPARK-38955: disable lineSep option in from_csv and schema_of_csv") {
+ val df = Seq[String]("1,2\n2").toDF("csv")
+ val actual = df.select(from_csv(
+ $"csv", schema_of_csv("1,2\n2"), Map.empty[String, String].asJava))
+ checkAnswer(actual, Row(Row(1, "2\n2")))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index c074a207af..c460f1e43b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -875,6 +875,11 @@ class DataFrameAggregateSuite extends QueryTest
sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)
+ checkAnswer(
+ courseSales.groupBy("course").agg(max_by(col("year"), col("earnings"))),
+ Row("dotNET", 2013) :: Row("Java", 2013) :: Nil
+ )
+
checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
Row("b") :: Nil
@@ -931,6 +936,11 @@ class DataFrameAggregateSuite extends QueryTest
sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)
+ checkAnswer(
+ courseSales.groupBy("course").agg(min_by(col("year"), col("earnings"))),
+ Row("dotNET", 2012) :: Row("Java", 2012) :: Nil
+ )
+
checkAnswer(
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
Row("a") :: Nil
@@ -1015,10 +1025,15 @@ class DataFrameAggregateSuite extends QueryTest
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"),
Nil)
- val error = intercept[AnalysisException] {
- sql("SELECT COUNT_IF(x) FROM tempView")
+ // When ANSI mode is on, it will implicitly cast the string to boolean and throw a runtime
+ // error. Here we simply test with ANSI mode off.
+ if (!conf.ansiEnabled) {
+ val error = intercept[AnalysisException] {
+ sql("SELECT COUNT_IF(x) FROM tempView")
+ }
+ assert(error.message.contains("cannot resolve 'count_if(tempview.x)' due to data type " +
+ "mismatch: argument 1 requires boolean type, however, 'tempview.x' is of string type"))
}
- assert(error.message.contains("function count_if requires boolean type"))
}
}
@@ -1125,9 +1140,11 @@ class DataFrameAggregateSuite extends QueryTest
val mapDF = Seq(Tuple1(Map("a" -> "a"))).toDF("col")
checkAnswer(mapDF.groupBy(struct($"col.a")).count().select("count"), Row(1))
- val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col")
- // Spark implicit casts string literal "a" to int to match the key type.
- checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1))
+ if (!conf.ansiEnabled) {
+ val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col")
+ // Spark implicitly casts the string literal "a" to int to match the key type.
+ checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1))
+ }
val arrayDF = Seq(Tuple1(Seq(1))).toDF("col")
val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count())
@@ -1261,12 +1278,12 @@ class DataFrameAggregateSuite extends QueryTest
val error = intercept[SparkException] {
checkAnswer(df2.select(sum($"year-month")), Nil)
}
- assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
+ assert(error.toString contains "SparkArithmeticException: integer overflow")
val error2 = intercept[SparkException] {
checkAnswer(df2.select(sum($"day")), Nil)
}
- assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+ assert(error2.toString contains "SparkArithmeticException: long overflow")
}
test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
@@ -1395,12 +1412,12 @@ class DataFrameAggregateSuite extends QueryTest
val error = intercept[SparkException] {
checkAnswer(df2.select(avg($"year-month")), Nil)
}
- assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
+ assert(error.toString contains "SparkArithmeticException: integer overflow")
val error2 = intercept[SparkException] {
checkAnswer(df2.select(avg($"day")), Nil)
}
- assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+ assert(error2.toString contains "SparkArithmeticException: long overflow")
val df3 = intervalData.filter($"class" > 4)
val avgDF3 = df3.select(avg($"year-month"), avg($"day"))
@@ -1432,6 +1449,16 @@ class DataFrameAggregateSuite extends QueryTest
val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id")
checkAnswer(df, Row(2, 3, 1))
}
+
+ test("SPARK-41035: Reuse of literal in distinct aggregations should work") {
+ val res = sql(
+ """select a, count(distinct 100), count(distinct b, 100)
+ |from values (1, 2), (4, 5), (4, 6) as data(a, b)
+ |group by a;
+ |""".stripMargin
+ )
+ checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil)
+ }
}
case class B(c: Option[Double])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala
new file mode 100644
index 0000000000..749efe95c5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class DataFrameAsOfJoinSuite extends QueryTest
+ with SharedSparkSession
+ with AdaptiveSparkPlanHelper {
+
+ def prepareForAsOfJoin(): (DataFrame, DataFrame) = {
+ val schema1 = StructType(
+ StructField("a", IntegerType, false) ::
+ StructField("b", StringType, false) ::
+ StructField("left_val", StringType, false) :: Nil)
+ val rowSeq1: List[Row] = List(Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c"))
+ val df1 = spark.createDataFrame(rowSeq1.asJava, schema1)
+
+ val schema2 = StructType(
+ StructField("a", IntegerType) ::
+ StructField("b", StringType) ::
+ StructField("right_val", IntegerType) :: Nil)
+ val rowSeq2: List[Row] = List(Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3),
+ Row(6, "y", 6), Row(7, "z", 7))
+ val df2 = spark.createDataFrame(rowSeq2.asJava, schema2)
+
+ (df1, df2)
+ }
+
+ test("as-of join - simple") {
+ val (df1, df2) = prepareForAsOfJoin()
+ checkAnswer(
+ df1.joinAsOf(
+ df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+ joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"),
+ Seq(
+ Row(1, "x", "a", 1, "v", 1),
+ Row(5, "y", "b", 3, "x", 3),
+ Row(10, "z", "c", 7, "z", 7)
+ )
+ )
+ }
+
+ test("as-of join - usingColumns") {
+ val (df1, df2) = prepareForAsOfJoin()
+ checkAnswer(
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"),
+ joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"),
+ Seq(
+ Row(10, "z", "c", 7, "z", 7)
+ )
+ )
+ }
+
+ test("as-of join - usingColumns, left outer") {
+ val (df1, df2) = prepareForAsOfJoin()
+ checkAnswer(
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"),
+ joinType = "left", tolerance = null, allowExactMatches = true, direction = "backward"),
+ Seq(
+ Row(1, "x", "a", null, null, null),
+ Row(5, "y", "b", null, null, null),
+ Row(10, "z", "c", 7, "z", 7)
+ )
+ )
+ }
+
+ test("as-of join - tolerance = 1") {
+ val (df1, df2) = prepareForAsOfJoin()
+ checkAnswer(
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+ joinType = "inner", tolerance = lit(1), allowExactMatches = true, direction = "backward"),
+ Seq(
+ Row(1, "x", "a", 1, "v", 1)
+ )
+ )
+ }
+
+ test("as-of join - tolerance should be a constant") {
+ val (df1, df2) = prepareForAsOfJoin()
+ val errMsg = intercept[AnalysisException] {
+ df1.joinAsOf(
+ df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+ joinType = "inner", tolerance = df1.col("b"), allowExactMatches = true,
+ direction = "backward")
+ }.getMessage
+ assert(errMsg.contains("Input argument tolerance must be a constant."))
+ }
+
+ test("as-of join - tolerance should be non-negative") {
+ val (df1, df2) = prepareForAsOfJoin()
+ val errMsg = intercept[AnalysisException] {
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+ joinType = "inner", tolerance = lit(-1), allowExactMatches = true, direction = "backward")
+ }.getMessage
+ assert(errMsg.contains("Input argument tolerance must be non-negative."))
+ }
+
+ test("as-of join - allowExactMatches = false") {
+ val (df1, df2) = prepareForAsOfJoin()
+ checkAnswer(
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+ joinType = "inner", tolerance = null, allowExactMatches = false, direction = "backward"),
+ Seq(
+ Row(5, "y", "b", 3, "x", 3),
+ Row(10, "z", "c", 7, "z", 7)
+ )
+ )
+ }
+
+ test("as-of join - direction = \"forward\"") {
+ val (df1, df2) = prepareForAsOfJoin()
+ checkAnswer(
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+ joinType = "inner", tolerance = null, allowExactMatches = true, direction = "forward"),
+ Seq(
+ Row(1, "x", "a", 1, "v", 1),
+ Row(5, "y", "b", 6, "y", 6)
+ )
+ )
+ }
+
+ test("as-of join - direction = \"nearest\"") {
+ val (df1, df2) = prepareForAsOfJoin()
+ checkAnswer(
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+ joinType = "inner", tolerance = null, allowExactMatches = true, direction = "nearest"),
+ Seq(
+ Row(1, "x", "a", 1, "v", 1),
+ Row(5, "y", "b", 6, "y", 6),
+ Row(10, "z", "c", 7, "z", 7)
+ )
+ )
+ }
+
+ test("as-of join - self") {
+ val (df1, _) = prepareForAsOfJoin()
+ checkAnswer(
+ df1.joinAsOf(
+ df1, df1.col("a"), df1.col("a"), usingColumns = Seq.empty,
+ joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest"),
+ Seq(
+ Row(1, "x", "a", 5, "y", "b"),
+ Row(5, "y", "b", 1, "x", "a"),
+ Row(10, "z", "c", 5, "y", "b")
+ )
+ )
+ }
+}
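The suite above pins down the joinAsOf matching rule: for every left row, at most one right row is picked, namely the one whose key is the nearest smaller-or-equal key for direction "backward", the nearest greater-or-equal key for "forward", or simply the nearest key for "nearest"; tolerance caps the allowed distance and allowExactMatches = false excludes equal keys. joinAsOf is invoked here from inside the org.apache.spark.sql package, so treat the sketch below as a mirror of the tests rather than documented public API usage; it assumes the df1/df2 pair built by prepareForAsOfJoin():

// Mirrors the "backward" case above: left a = 5 pairs with right a = 3, the
// largest right key not exceeding 5. With direction = "forward" it would pair
// with 6, and with "nearest" also with 6 (|5 - 6| < |5 - 3|).
val (df1, df2) = prepareForAsOfJoin()
df1.joinAsOf(
  df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
  joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward")
  .show()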
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 38b9a75dfb..697cce9b50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -247,6 +247,93 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
Row(2743272264L, 2180413220L))
}
+ test("misc aes function") {
+ val key16 = "abcdefghijklmnop"
+ val key24 = "abcdefghijklmnop12345678"
+ val key32 = "abcdefghijklmnop12345678ABCDEFGH"
+ val encryptedText16 = "4Hv0UKCx6nfUeAoPZo1z+w=="
+ val encryptedText24 = "NeTYNgA+PCQBN50DA//O2w=="
+ val encryptedText32 = "9J3iZbIxnmaG+OIA9Amd+A=="
+ val encryptedEmptyText16 = "jmTOhz8XTbskI/zYFFgOFQ=="
+ val encryptedEmptyText24 = "9RDK70sHNzqAFRcpfGM5gQ=="
+ val encryptedEmptyText32 = "j9IDsCvlYXtcVJUf4FAjQQ=="
+
+ val df1 = Seq("Spark", "").toDF
+
+ // Successful encryption
+ Seq(
+ (key16, encryptedText16, encryptedEmptyText16),
+ (key24, encryptedText24, encryptedEmptyText24),
+ (key32, encryptedText32, encryptedEmptyText32)).foreach {
+ case (key, encryptedText, encryptedEmptyText) =>
+ checkAnswer(
+ df1.selectExpr(s"base64(aes_encrypt(value, '$key', 'ECB'))"),
+ Seq(Row(encryptedText), Row(encryptedEmptyText)))
+ checkAnswer(
+ df1.selectExpr(s"base64(aes_encrypt(binary(value), '$key', 'ECB'))"),
+ Seq(Row(encryptedText), Row(encryptedEmptyText)))
+ }
+
+ // Encryption failure - input or key is null
+ Seq(key16, key24, key32).foreach { key =>
+ checkAnswer(
+ df1.selectExpr(s"aes_encrypt(cast(null as string), '$key')"),
+ Seq(Row(null), Row(null)))
+ checkAnswer(
+ df1.selectExpr(s"aes_encrypt(cast(null as binary), '$key')"),
+ Seq(Row(null), Row(null)))
+ checkAnswer(
+ df1.selectExpr(s"aes_encrypt(cast(null as string), binary('$key'))"),
+ Seq(Row(null), Row(null)))
+ checkAnswer(
+ df1.selectExpr(s"aes_encrypt(cast(null as binary), binary('$key'))"),
+ Seq(Row(null), Row(null)))
+ }
+ checkAnswer(
+ df1.selectExpr("aes_encrypt(value, cast(null as string))"),
+ Seq(Row(null), Row(null)))
+ checkAnswer(
+ df1.selectExpr("aes_encrypt(value, cast(null as binary))"),
+ Seq(Row(null), Row(null)))
+
+ val df2 = Seq(
+ (encryptedText16, encryptedText24, encryptedText32),
+ (encryptedEmptyText16, encryptedEmptyText24, encryptedEmptyText32)
+ ).toDF("value16", "value24", "value32")
+
+ // Successful decryption
+ Seq(
+ ("value16", key16),
+ ("value24", key24),
+ ("value32", key32)).foreach {
+ case (colName, key) =>
+ checkAnswer(
+ df2.selectExpr(s"cast(aes_decrypt(unbase64($colName), '$key', 'ECB') as string)"),
+ Seq(Row("Spark"), Row("")))
+ checkAnswer(
+ df2.selectExpr(s"cast(aes_decrypt(unbase64($colName), binary('$key'), 'ECB') as string)"),
+ Seq(Row("Spark"), Row("")))
+ }
+
+ // Decryption failure - input or key is null
+ Seq(key16, key24, key32).foreach { key =>
+ checkAnswer(
+ df2.selectExpr(s"aes_decrypt(cast(null as binary), '$key')"),
+ Seq(Row(null), Row(null)))
+ checkAnswer(
+ df2.selectExpr(s"aes_decrypt(cast(null as binary), binary('$key'))"),
+ Seq(Row(null), Row(null)))
+ }
+ Seq("value16", "value24", "value32").foreach { colName =>
+ checkAnswer(
+ df2.selectExpr(s"aes_decrypt($colName, cast(null as string))"),
+ Seq(Row(null), Row(null)))
+ checkAnswer(
+ df2.selectExpr(s"aes_decrypt($colName, cast(null as binary))"),
+ Seq(Row(null), Row(null)))
+ }
+ }
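Two details make the expected ciphertexts above stable: the key length selects the AES variant (16, 24 and 32 bytes give AES-128, AES-192 and AES-256), and 'ECB' mode is deterministic for a fixed key, unlike GCM (the default), which injects a random IV and would change the base64 output on every run. A minimal round-trip sketch, assuming an active SparkSession named spark; the key below is illustrative only:

// Encrypt and immediately decrypt with a 16-byte key (AES-128) in ECB mode.
spark.sql("SELECT CAST(aes_decrypt(aes_encrypt('Spark', 'abcdefghijklmnop', 'ECB'), " +
  "'abcdefghijklmnop', 'ECB') AS STRING) AS roundtrip").show()
// prints "Spark"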
+
test("string function find_in_set") {
val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b")
@@ -394,6 +481,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
spark.sql("drop temporary function fStringLength")
}
+ test("SPARK-38130: array_sort with lambda of non-orderable items") {
+ val df6 = Seq((Array[Map[String, Int]](Map("a" -> 1), Map("b" -> 2, "c" -> 3),
+ Map()), "x")).toDF("a", "b")
+ checkAnswer(
+ df6.selectExpr("array_sort(a, (x, y) -> cardinality(x) - cardinality(y))"),
+ Seq(
+ Row(Seq[Map[String, Int]](Map(), Map("a" -> 1), Map("b" -> 2, "c" -> 3))))
+ )
+ }
+
test("sort_array/array_sort functions") {
val df = Seq(
(Array[Int](2, 1, 3), Array("b", "c", "a")),
@@ -482,8 +579,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
}
test("array size function - legacy") {
- withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
- testSizeOfArray(sizeOfNull = -1)
+ if (!conf.ansiEnabled) {
+ withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
+ testSizeOfArray(sizeOfNull = -1)
+ }
}
}
@@ -622,6 +721,70 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
}
}
+ test("SPARK-40292: arrays_zip should retain field names in nested structs") {
+ val df = spark.sql("""
+ select
+ named_struct(
+ 'arr_1', array(named_struct('a', 1, 'b', 2)),
+ 'arr_2', array(named_struct('p', 1, 'q', 2)),
+ 'field', named_struct(
+ 'arr_3', array(named_struct('x', 1, 'y', 2))
+ )
+ ) as obj
+ """)
+
+ val res = df.selectExpr("arrays_zip(obj.arr_1, obj.arr_2, obj.field.arr_3) as arr")
+
+ val fieldNames = res.schema.head.dataType.asInstanceOf[ArrayType]
+ .elementType.asInstanceOf[StructType].fieldNames
+ assert(fieldNames.toSeq === Seq("arr_1", "arr_2", "arr_3"))
+ }
+
+ test("SPARK-40470: array_zip should return field names in GetArrayStructFields") {
+ val df = spark.read.json(Seq(
+ """
+ {
+ "arr": [
+ {
+ "obj": {
+ "nested": {
+ "field1": [1],
+ "field2": [2]
+ }
+ }
+ }
+ ]
+ }
+ """).toDS())
+
+ val res = df
+ .selectExpr("arrays_zip(arr.obj.nested.field1, arr.obj.nested.field2) as arr")
+ .select(col("arr.field1"), col("arr.field2"))
+
+ val fieldNames = res.schema.fieldNames
+ assert(fieldNames.toSeq === Seq("field1", "field2"))
+
+ checkAnswer(res, Row(Seq(Seq(1)), Seq(Seq(2))) :: Nil)
+ }
+
+ test("SPARK-40470: arrays_zip should return field names in GetMapValue") {
+ val df = spark.sql("""
+ select
+ map(
+ 'arr_1', array(1, 2),
+ 'arr_2', array(3, 4)
+ ) as map_obj
+ """)
+
+ val res = df.selectExpr("arrays_zip(map_obj.arr_1, map_obj.arr_2) as arr")
+
+ val fieldNames = res.schema.head.dataType.asInstanceOf[ArrayType]
+ .elementType.asInstanceOf[StructType].fieldNames
+ assert(fieldNames.toSeq === Seq("arr_1", "arr_2"))
+
+ checkAnswer(res, Row(Seq(Row(1, 3), Row(2, 4))))
+ }
+
def testSizeOfMap(sizeOfNull: Any): Unit = {
val df = Seq(
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),
@@ -635,8 +798,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
}
test("map size function - legacy") {
- withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
- testSizeOfMap(sizeOfNull = -1: Int)
+ if (!conf.ansiEnabled) {
+ withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
+ testSizeOfMap(sizeOfNull = -1: Int)
+ }
}
}
@@ -930,15 +1095,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
Seq(Row(false))
)
- val e1 = intercept[AnalysisException] {
- OneRowRelation().selectExpr("array_contains(array(1), .01234567890123456790123456780)")
- }
- val errorMsg1 =
- s"""
- |Input to function array_contains should have been array followed by a
- |value with same element type, but it's [array<int>, decimal(38,29)].
+ if (!conf.ansiEnabled) {
+ val e1 = intercept[AnalysisException] {
+ OneRowRelation().selectExpr("array_contains(array(1), .01234567890123456790123456780)")
+ }
+ val errorMsg1 =
+ s"""
+ |Input to function array_contains should have been array followed by a
+ |value with same element type, but it's [array<int>, decimal(38,29)].
""".stripMargin.replace("\n", " ").trim()
- assert(e1.message.contains(errorMsg1))
+ assert(e1.message.contains(errorMsg1))
+ }
val e2 = intercept[AnalysisException] {
OneRowRelation().selectExpr("array_contains(array(1), 'foo')")
@@ -1367,41 +1534,43 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
Seq(Row(null), Row(null), Row(null))
)
}
- checkAnswer(
- df.select(element_at(df("a"), 4)),
- Seq(Row(null), Row(null), Row(null))
- )
- checkAnswer(
- df.select(element_at(df("a"), df("b"))),
- Seq(Row("1"), Row(""), Row(null))
- )
- checkAnswer(
- df.selectExpr("element_at(a, b)"),
- Seq(Row("1"), Row(""), Row(null))
- )
+ if (!conf.ansiEnabled) {
+ checkAnswer(
+ df.select(element_at(df("a"), 4)),
+ Seq(Row(null), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(element_at(df("a"), df("b"))),
+ Seq(Row("1"), Row(""), Row(null))
+ )
+ checkAnswer(
+ df.selectExpr("element_at(a, b)"),
+ Seq(Row("1"), Row(""), Row(null))
+ )
- checkAnswer(
- df.select(element_at(df("a"), 1)),
- Seq(Row("1"), Row(null), Row(null))
- )
- checkAnswer(
- df.select(element_at(df("a"), -1)),
- Seq(Row("3"), Row(""), Row(null))
- )
+ checkAnswer(
+ df.select(element_at(df("a"), 1)),
+ Seq(Row("1"), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(element_at(df("a"), -1)),
+ Seq(Row("3"), Row(""), Row(null))
+ )
- checkAnswer(
- df.selectExpr("element_at(a, 4)"),
- Seq(Row(null), Row(null), Row(null))
- )
+ checkAnswer(
+ df.selectExpr("element_at(a, 4)"),
+ Seq(Row(null), Row(null), Row(null))
+ )
- checkAnswer(
- df.selectExpr("element_at(a, 1)"),
- Seq(Row("1"), Row(null), Row(null))
- )
- checkAnswer(
- df.selectExpr("element_at(a, -1)"),
- Seq(Row("3"), Row(""), Row(null))
- )
+ checkAnswer(
+ df.selectExpr("element_at(a, 1)"),
+ Seq(Row("1"), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.selectExpr("element_at(a, -1)"),
+ Seq(Row("3"), Row(""), Row(null))
+ )
+ }
val e1 = intercept[AnalysisException] {
Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)")
@@ -1463,10 +1632,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
Seq(Row("a"))
)
- checkAnswer(
- OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"),
- Seq(Row(null))
- )
+ if (!conf.ansiEnabled) {
+ checkAnswer(
+ OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"),
+ Seq(Row(null))
+ )
+ }
val e3 = intercept[AnalysisException] {
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')")
@@ -1541,10 +1712,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
// Simple test cases
def simpleTest(): Unit = {
- checkAnswer (
- df.select(concat($"i1", $"s1")),
- Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a")))
- )
+ if (!conf.ansiEnabled) {
+ checkAnswer(
+ df.select(concat($"i1", $"s1")),
+ Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a")))
+ )
+ }
checkAnswer(
df.select(concat($"i1", $"i2", $"i3")),
Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
@@ -2328,7 +2501,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
val ex3 = intercept[AnalysisException] {
df.selectExpr("transform(a, x -> x)")
}
- assert(ex3.getMessage.contains("cannot resolve 'a'"))
+ assert(ex3.getErrorClass == "MISSING_COLUMN")
+ assert(ex3.messageParameters.head == "a")
}
test("map_filter") {
@@ -2399,7 +2573,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
val ex4 = intercept[AnalysisException] {
df.selectExpr("map_filter(a, (k, v) -> k > v)")
}
- assert(ex4.getMessage.contains("cannot resolve 'a'"))
+ assert(ex4.getErrorClass == "MISSING_COLUMN")
+ assert(ex4.messageParameters.head == "a")
}
test("filter function - array for primitive type not containing null") {
@@ -2558,7 +2733,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
val ex4 = intercept[AnalysisException] {
df.selectExpr("filter(a, x -> x)")
}
- assert(ex4.getMessage.contains("cannot resolve 'a'"))
+ assert(ex4.getErrorClass == "MISSING_COLUMN")
+ assert(ex4.messageParameters.head == "a")
}
test("exists function - array for primitive type not containing null") {
@@ -2690,7 +2866,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
val ex4 = intercept[AnalysisException] {
df.selectExpr("exists(a, x -> x)")
}
- assert(ex4.getMessage.contains("cannot resolve 'a'"))
+ assert(ex4.getErrorClass == "MISSING_COLUMN")
+ assert(ex4.messageParameters.head == "a")
}
test("forall function - array for primitive type not containing null") {
@@ -2836,12 +3013,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
val ex4 = intercept[AnalysisException] {
df.selectExpr("forall(a, x -> x)")
}
- assert(ex4.getMessage.contains("cannot resolve 'a'"))
+ assert(ex4.getErrorClass == "MISSING_COLUMN")
+ assert(ex4.messageParameters.head == "a")
val ex4a = intercept[AnalysisException] {
df.select(forall(col("a"), x => x))
}
- assert(ex4a.getMessage.contains("cannot resolve 'a'"))
+ assert(ex4a.getErrorClass == "MISSING_COLUMN")
+ assert(ex4a.messageParameters.head == "a")
}
test("aggregate function - array for primitive type not containing null") {
@@ -3018,7 +3197,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
val ex5 = intercept[AnalysisException] {
df.selectExpr("aggregate(a, 0, (acc, x) -> x)")
}
- assert(ex5.getMessage.contains("cannot resolve 'a'"))
+ assert(ex5.getErrorClass == "MISSING_COLUMN")
+ assert(ex5.messageParameters.head == "a")
}
test("map_zip_with function - map of primitive types") {
@@ -3571,7 +3751,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
val ex4 = intercept[AnalysisException] {
df.selectExpr("zip_with(a1, a, (acc, x) -> x)")
}
- assert(ex4.getMessage.contains("cannot resolve 'a'"))
+ assert(ex4.getErrorClass == "MISSING_COLUMN")
+ assert(ex4.messageParameters.head == "a")
}
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index a803fa88ed..4298d503b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -288,6 +288,24 @@ class DataFrameJoinSuite extends QueryTest
}
}
+ Seq("left_semi", "left_anti").foreach { joinType =>
+ test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
+ // aggregated dataframe
+ val ids = Seq(1, 2, 3).toDF("id").distinct()
+
+ // self-joined via joinType
+ val result = ids.withColumn("id", $"id" + 1)
+ .join(ids, usingColumns = Seq("id"), joinType = joinType).collect()
+
+ val expected = joinType match {
+ case "left_semi" => 2
+ case "left_anti" => 1
+ case _ => -1 // unsupported test type, test will always fail
+ }
+ assert(result.length == expected)
+ }
+ }
+
def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
case Filter(_, child) => extractLeftDeepInnerJoins(child)
@@ -499,4 +517,26 @@ class DataFrameJoinSuite extends QueryTest
)
}
}
+
+ test("SPARK-39376: Hide duplicated columns in star expansion of subquery alias from USING JOIN") {
+ val joinDf = testData2.as("testData2").join(
+ testData3.as("testData3"), usingColumns = Seq("a"), joinType = "fullouter")
+ val equivalentQueries = Seq(
+ joinDf.select($"*"),
+ joinDf.as("r").select($"*"),
+ joinDf.as("r").select($"r.*")
+ )
+ equivalentQueries.foreach { query =>
+ checkAnswer(query,
+ Seq(
+ Row(1, 1, null),
+ Row(1, 2, null),
+ Row(2, 1, 2),
+ Row(2, 2, 2),
+ Row(3, 1, null),
+ Row(3, 2, null)
+ )
+ )
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 20ae995af6..8dbc57c042 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -444,21 +444,25 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
}
test("replace float with nan") {
- checkAnswer(
- createNaNDF().na.replace("*", Map(
- 1.0f -> Float.NaN
- )),
- Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
- Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
+ if (!conf.ansiEnabled) {
+ checkAnswer(
+ createNaNDF().na.replace("*", Map(
+ 1.0f -> Float.NaN
+ )),
+ Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
+ Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
+ }
}
test("replace double with nan") {
- checkAnswer(
- createNaNDF().na.replace("*", Map(
- 1.0 -> Double.NaN
- )),
- Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
- Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
+ if (!conf.ansiEnabled) {
+ checkAnswer(
+ createNaNDF().na.replace("*", Map(
+ 1.0 -> Double.NaN
+ )),
+ Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
+ Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
+ }
}
test("SPARK-34417: test fillMap() for column with a dot in the name") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 32cbb8b457..1a0c95beb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import java.time.LocalDateTime
import java.util.Locale
import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst
@@ -323,17 +324,6 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession {
checkAnswer(df, expected)
}
- test("pivoting column list") {
- val exception = intercept[RuntimeException] {
- trainingSales
- .groupBy($"sales.year")
- .pivot(struct(lower($"sales.course"), $"training"))
- .agg(sum($"sales.earnings"))
- .collect()
- }
- assert(exception.getMessage.contains("Unsupported literal type"))
- }
-
test("SPARK-26403: pivoting by array column") {
val df = Seq(
(2, Seq.empty[String]),
@@ -352,4 +342,16 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession {
percentile_approx(col("value"), array(lit(0.5)), lit(10000)))
checkAnswer(actual, Row(Array(2.5), Array(3.0)))
}
+
+ test("SPARK-38133: Grouping by TIMESTAMP_NTZ should not corrupt results") {
+ checkAnswer(
+ courseSales.withColumn("ts", $"year".cast("string").cast("timestamp_ntz"))
+ .groupBy("ts")
+ .pivot("course", Seq("dotNET", "Java"))
+ .agg(sum($"earnings"))
+ .select("ts", "dotNET", "Java"),
+ Row(LocalDateTime.of(2012, 1, 1, 0, 0, 0, 0), 15000.0, 20000.0) ::
+ Row(LocalDateTime.of(2013, 1, 1, 0, 0, 0, 0), 48000.0, 30000.0) :: Nil
+ )
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index fc549e307c..917f80e581 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -63,13 +63,15 @@ class DataFrameRangeSuite extends QueryTest with SharedSparkSession with Eventua
val res7 = spark.range(-10, -9, -20, 1).select("id")
assert(res7.count == 0)
- val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
- assert(res8.count == 3)
- assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
-
- val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
- assert(res9.count == 2)
- assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+ if (!conf.ansiEnabled) {
+ val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+ assert(res8.count == 3)
+ assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+ val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+ assert(res9.count == 2)
+ assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+ }
// only end provided as argument
val res10 = spark.range(10).select("id")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index e520cbea48..4d0dd46b95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -481,7 +481,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
val ex = intercept[AnalysisException](
df3.join(df1, year($"df1.timeStr") === year($"df3.tsStr"))
)
- assert(ex.message.contains("cannot resolve 'df1.timeStr'"))
+ assert(ex.message.contains("Column 'df1.timeStr' does not exist."))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index 7a0cd420d4..a5414f3e80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -17,12 +17,16 @@
package org.apache.spark.sql
+import java.time.LocalDateTime
+
import org.scalatest.BeforeAndAfterEach
-import org.apache.spark.sql.catalyst.plans.logical.Expand
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types._
class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
with BeforeAndAfterEach {
@@ -79,7 +83,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
// key "b" => (19:39:27 ~ 19:39:37)
checkAnswer(
- df.groupBy(session_window($"time", "10 seconds"), 'id)
+ df.groupBy(session_window($"time", "10 seconds"), Symbol("id"))
.agg(count("*").as("counts"), sum("value").as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
@@ -109,7 +113,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
// key "b" => (19:39:27 ~ 19:39:37)
checkAnswer(
- df.groupBy(session_window($"time", "10 seconds"), 'id)
+ df.groupBy(session_window($"time", "10 seconds"), Symbol("id"))
.agg(count("*").as("counts"), sum_distinct(col("value")).as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
@@ -138,7 +142,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
// key "b" => (19:39:27 ~ 19:39:37)
checkAnswer(
- df.groupBy(session_window($"time", "10 seconds"), 'id)
+ df.groupBy(session_window($"time", "10 seconds"), Symbol("id"))
.agg(sum_distinct(col("value")).as("sum"), sum_distinct(col("value2")).as("sum2"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
@@ -167,7 +171,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
// b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55)
checkAnswer(
- df.groupBy(session_window($"time", "10 seconds"), 'id)
+ df.groupBy(session_window($"time", "10 seconds"), Symbol("id"))
.agg(count("*").as("counts"), sum("value").as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
@@ -377,4 +381,118 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
)
}
}
+
+ test("SPARK-36465: filter out events with invalid gap duration.") {
+ val df = Seq(
+ ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id")
+
+ checkAnswer(
+ df.groupBy(session_window($"time", "x sec"))
+ .agg(count("*").as("counts"))
+ .orderBy($"session_window.start".asc)
+ .select($"session_window.start".cast("string"), $"session_window.end".cast("string"),
+ $"counts"),
+ Seq()
+ )
+
+ withTempTable { table =>
+ checkAnswer(
+ spark.sql("select session_window(time, " +
+ """case when value = 1 then "2 seconds" when value = 2 then "invalid gap duration" """ +
+ s"""else "20 seconds" end), value from $table""")
+ .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
+ $"value"),
+ Seq(
+ Row("2016-03-27 19:39:27", "2016-03-27 19:39:47", 4),
+ Row("2016-03-27 19:39:34", "2016-03-27 19:39:36", 1)
+ )
+ )
+ }
+ }
+
+ test("SPARK-36724: Support timestamp_ntz as a type of time column for SessionWindow") {
+ val df = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a"),
+ (LocalDateTime.parse("2016-03-27T19:39:25"), 2, "a")).toDF("time", "value", "id")
+ val aggDF =
+ df.groupBy(session_window($"time", "10 seconds"))
+ .agg(count("*").as("counts"))
+ .orderBy($"session_window.start".asc)
+ .select($"session_window.start".cast("string"),
+ $"session_window.end".cast("string"), $"counts")
+
+ val aggregate = aggDF.queryExecution.analyzed.children(0).children(0)
+ assert(aggregate.isInstanceOf[Aggregate])
+
+ val timeWindow = aggregate.asInstanceOf[Aggregate].groupingExpressions(0)
+ assert(timeWindow.isInstanceOf[AttributeReference])
+
+ val attributeReference = timeWindow.asInstanceOf[AttributeReference]
+ assert(attributeReference.name == "session_window")
+
+ val expectedSchema = StructType(
+ Seq(StructField("start", TimestampNTZType), StructField("end", TimestampNTZType)))
+ assert(attributeReference.dataType == expectedSchema)
+
+ checkAnswer(aggDF, Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 2)))
+ }
+
+ test("SPARK-38227: 'start' and 'end' fields should be nullable") {
+ // We expect the fields in the window struct to be nullable since the dataType of SessionWindow
+ // defines them as nullable. The rule 'SessionWindowing' should respect the dataType.
+ val df1 = Seq(
+ ("hello", "2016-03-27 09:00:05", 1),
+ ("structured", "2016-03-27 09:00:32", 2)).toDF("id", "time", "value")
+ val df2 = Seq(
+ ("world", LocalDateTime.parse("2016-03-27T09:00:05"), 1),
+ ("spark", LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("id", "time", "value")
+
+ val udf = spark.udf.register("gapDuration", (s: String) => {
+ if (s == "hello") {
+ "1 second"
+ } else if (s == "structured") {
+ // zero gap duration will be filtered out from aggregation
+ "0 second"
+ } else if (s == "world") {
+ // negative gap duration will be filtered out from aggregation
+ "-10 seconds"
+ } else {
+ "10 seconds"
+ }
+ })
+
+ def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = {
+ schema.find(_.name == colName) match {
+ case Some(StructField(_, st: StructType, _, _)) =>
+ assertFieldInWindowStruct(st, "start")
+ assertFieldInWindowStruct(st, "end")
+
+ case _ => fail("Failed to find suitable window column from DataFrame!")
+ }
+ }
+
+ def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = {
+ val field = windowType.fields.find(_.name == fieldName)
+ assert(field.isDefined, s"'$fieldName' field should exist in window struct")
+ assert(field.get.nullable, s"'$fieldName' field should be nullable")
+ }
+
+ for {
+ df <- Seq(df1, df2)
+ nullable <- Seq(true, false)
+ } {
+ val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder(
+ StructType(df.schema.fields.map(_.copy(nullable = nullable)))))
+ // session window without dynamic gap
+ val windowedProject = dfWithDesiredNullability
+ .select(session_window($"time", "10 seconds").as("session"), $"value")
+ val schema = windowedProject.queryExecution.optimizedPlan.schema
+ validateWindowColumnInSchema(schema, "session")
+
+ // session window with dynamic gap
+ val windowedProject2 = dfWithDesiredNullability
+ .select(session_window($"time", udf($"id")).as("session"), $"value")
+ val schema2 = windowedProject2.queryExecution.optimizedPlan.schema
+ validateWindowColumnInSchema(schema2, "session")
+ }
+ }
}
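The additions above establish two behaviours of session_window: rows whose gap duration expression is unparsable, zero, or negative are dropped before aggregation, and a TIMESTAMP_NTZ time column is accepted, yielding a session struct whose nullable start/end fields carry the same type. A minimal sketch of the TIMESTAMP_NTZ path, assuming an active SparkSession named spark:

import java.time.LocalDateTime
import org.apache.spark.sql.functions._
import spark.implicits._

// Two events five seconds apart land in a single 10-second session window.
val events = Seq(
  (LocalDateTime.parse("2016-03-27T19:39:25"), 1),
  (LocalDateTime.parse("2016-03-27T19:39:30"), 2)).toDF("time", "value")
events
  .groupBy(session_window($"time", "10 seconds"))
  .agg(count("*").as("counts"))
  .show(truncate = false)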
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index f8e0cfc32a..ca04adf642 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -21,6 +21,8 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.optimizer.RemoveNoopUnion
import org.apache.spark.sql.catalyst.plans.logical.Union
+import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession, SQLTestData}
@@ -339,7 +341,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
).toDF("date", "timestamp", "decimal")
val widenTypedRows = Seq(
- (new Timestamp(2), 10.5D, "string")
+ (new Timestamp(2), 10.5D, "2021-01-01 00:00:00")
).toDF("date", "timestamp", "decimal")
dates.union(widenTypedRows).collect()
@@ -536,24 +538,25 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
}
test("union by name - type coercion") {
- var df1 = Seq((1, "a")).toDF("c0", "c1")
- var df2 = Seq((3, 1L)).toDF("c1", "c0")
- checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil)
-
- df1 = Seq((1, 1.0)).toDF("c0", "c1")
- df2 = Seq((8L, 3.0)).toDF("c1", "c0")
+ var df1 = Seq((1, 1.0)).toDF("c0", "c1")
+ var df2 = Seq((8L, 3.0)).toDF("c1", "c0")
checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil)
-
- df1 = Seq((2.0f, 7.4)).toDF("c0", "c1")
- df2 = Seq(("a", 4.0)).toDF("c1", "c0")
- checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil)
-
- df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2")
- df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1")
- val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0")
- checkAnswer(df1.unionByName(df2.unionByName(df3)),
- Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil
- )
+ if (!conf.ansiEnabled) {
+ df1 = Seq((1, "a")).toDF("c0", "c1")
+ df2 = Seq((3, 1L)).toDF("c1", "c0")
+ checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil)
+
+ df1 = Seq((2.0f, 7.4)).toDF("c0", "c1")
+ df2 = Seq(("a", 4.0)).toDF("c1", "c0")
+ checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil)
+
+ df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2")
+ df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1")
+ val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0")
+ checkAnswer(df1.unionByName(df2.unionByName(df3)),
+ Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil
+ )
+ }
}
test("union by name - check case sensitivity") {
@@ -802,7 +805,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
StructType(Seq(StructField("topLevelCol", nestedStructType2))))
val union = df1.unionByName(df2, allowMissingColumns = true)
- assert(union.schema.toDDL == "`topLevelCol` STRUCT<`b`: STRING, `a`: STRING>")
+ assert(union.schema.toDDL == "topLevelCol STRUCT<b: STRING, a: STRING>")
checkAnswer(union, Row(Row("b", null)) :: Row(Row("b", "a")) :: Nil)
}
@@ -834,15 +837,15 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
StructType(Seq(StructField("topLevelCol", nestedStructType2))))
var unionDf = df1.unionByName(df2, true)
- assert(unionDf.schema.toDDL == "`topLevelCol` " +
- "STRUCT<`b`: STRUCT<`ba`: STRING, `bb`: STRING>, `a`: STRUCT<`aa`: STRING>>")
+ assert(unionDf.schema.toDDL == "topLevelCol " +
+ "STRUCT<b: STRUCT<ba: STRING, bb: STRING>, a: STRUCT<aa: STRING>>")
checkAnswer(unionDf,
Row(Row(Row("ba", null), null)) ::
Row(Row(Row(null, "bb"), Row("aa"))) :: Nil)
unionDf = df2.unionByName(df1, true)
- assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRUCT<`aa`: STRING>, " +
- "`b`: STRUCT<`bb`: STRING, `ba`: STRING>>")
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<a: STRUCT<aa: STRING>, " +
+ "b: STRUCT<bb: STRING, ba: STRING>>")
checkAnswer(unionDf,
Row(Row(null, Row(null, "ba"))) ::
Row(Row(Row("aa"), Row("bb", null))) :: Nil)
@@ -999,8 +1002,9 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
}.getMessage
assert(errMsg.contains("Union can only be performed on tables with" +
" the compatible column types." +
- " struct<c1:int,c2:int,c3:struct<c3:int,c5:int>> <> struct<c1:int,c2:int,c3:struct<c3:int>>" +
- " at the third column of the second table"))
+ " The third column of the second table is struct<c1:int,c2:int,c3:struct<c3:int,c5:int>>" +
+ " type which is not compatible with struct<c1:int,c2:int,c3:struct<c3:int>> at same" +
+ " column of first table"))
// diff Case sensitive attributes names and diff sequence scenario for unionByName
df1 = Seq((1, 2, UnionClass1d(1, 2, Struct3(1)))).toDF("a", "b", "c")
@@ -1039,26 +1043,355 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
}
}
- test("SPARK-36673: Union of structs with different orders") {
+ test("SPARK-36797: Union should resolve nested columns as top-level columns") {
+ // Different nested field names, but same nested field types. Union resolves columns by position.
val df1 = spark.range(2).withColumn("nested",
struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2"))))
val df2 = spark.range(2).withColumn("nested",
struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1"))))
- val err1 = intercept[AnalysisException](df1.union(df2).collect())
-
- assert(err1.message
- .contains("Union can only be performed on tables with the compatible column types"))
+ checkAnswer(df1.union(df2),
+ Row(0, Row(0, Row(0))) :: Row(0, Row(0, Row(0))) :: Row(1, Row(5, Row(10))) ::
+ Row(1, Row(5, Row(10))) :: Nil)
- val df3 = spark.range(2).withColumn("nested",
- struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2").cast("string"))))
+ val df3 = spark.range(2).withColumn("nested array",
+ array(struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2")))))
val df4 = spark.range(2).withColumn("nested",
+ array(struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1")))))
+
+ checkAnswer(df3.union(df4),
+ Row(0, Seq(Row(0, Row(0)))) :: Row(0, Seq(Row(0, Row(0)))) :: Row(1, Seq(Row(5, Row(10)))) ::
+ Row(1, Seq(Row(5, Row(10)))) :: Nil)
+
+ val df5 = spark.range(2).withColumn("nested array",
+ map(struct(expr("id * 5 AS key1")),
+ struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2")))))
+ val df6 = spark.range(2).withColumn("nested",
+ map(struct(expr("id * 5 AS key2")),
+ struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1")))))
+
+ checkAnswer(df5.union(df6),
+ Row(0, Map(Row(0) -> Row(0, Row(0)))) ::
+ Row(0, Map(Row(0) -> Row(0, Row(0)))) ::
+ Row(1, Map(Row(5) -> Row(5, Row(10)))) ::
+ Row(1, Map(Row(5) -> Row(5, Row(10)))) :: Nil)
+
+ // Different nested field names, and different nested field types.
+ val df7 = spark.range(2).withColumn("nested",
+ struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2").cast("string"))))
+ val df8 = spark.range(2).withColumn("nested",
struct(expr("id * 5 AS inner2").cast("string"), struct(expr("id * 10 AS inner1"))))
- val err2 = intercept[AnalysisException](df3.union(df4).collect())
- assert(err2.message
+ val err = intercept[AnalysisException](df7.union(df8).collect())
+ assert(err.message
.contains("Union can only be performed on tables with the compatible column types"))
}
+
+ test("SPARK-36546: Add unionByName support to arrays of structs") {
+ val arrayType1 = ArrayType(
+ StructType(Seq(
+ StructField("ba", StringType),
+ StructField("bb", StringType)
+ ))
+ )
+ val arrayValues1 = Seq(Row("ba", "bb"))
+
+ val arrayType2 = ArrayType(
+ StructType(Seq(
+ StructField("bb", StringType),
+ StructField("ba", StringType)
+ ))
+ )
+ val arrayValues2 = Seq(Row("bb", "ba"))
+
+ val df1 = spark.createDataFrame(
+ sparkContext.parallelize(Row(arrayValues1) :: Nil),
+ StructType(Seq(StructField("arr", arrayType1))))
+
+ val df2 = spark.createDataFrame(
+ sparkContext.parallelize(Row(arrayValues2) :: Nil),
+ StructType(Seq(StructField("arr", arrayType2))))
+
+ var unionDf = df1.unionByName(df2)
+ assert(unionDf.schema.toDDL == "arr ARRAY<STRUCT<ba: STRING, bb: STRING>>")
+ checkAnswer(unionDf,
+ Row(Seq(Row("ba", "bb"))) ::
+ Row(Seq(Row("ba", "bb"))) :: Nil)
+
+ unionDf = df2.unionByName(df1)
+ assert(unionDf.schema.toDDL == "arr ARRAY<STRUCT<bb: STRING, ba: STRING>>")
+ checkAnswer(unionDf,
+ Row(Seq(Row("bb", "ba"))) ::
+ Row(Seq(Row("bb", "ba"))) :: Nil)
+
+ val arrayType3 = ArrayType(
+ StructType(Seq(
+ StructField("ba", StringType)
+ ))
+ )
+ val arrayValues3 = Seq(Row("ba"))
+
+ val arrayType4 = ArrayType(
+ StructType(Seq(
+ StructField("bb", StringType)
+ ))
+ )
+ val arrayValues4 = Seq(Row("bb"))
+
+ val df3 = spark.createDataFrame(
+ sparkContext.parallelize(Row(arrayValues3) :: Nil),
+ StructType(Seq(StructField("arr", arrayType3))))
+
+ val df4 = spark.createDataFrame(
+ sparkContext.parallelize(Row(arrayValues4) :: Nil),
+ StructType(Seq(StructField("arr", arrayType4))))
+
+ assertThrows[AnalysisException] {
+ df3.unionByName(df4)
+ }
+
+ unionDf = df3.unionByName(df4, true)
+ assert(unionDf.schema.toDDL == "arr ARRAY<STRUCT<ba: STRING, bb: STRING>>")
+ checkAnswer(unionDf,
+ Row(Seq(Row("ba", null))) ::
+ Row(Seq(Row(null, "bb"))) :: Nil)
+
+ assertThrows[AnalysisException] {
+ df4.unionByName(df3)
+ }
+
+ unionDf = df4.unionByName(df3, true)
+ assert(unionDf.schema.toDDL == "arr ARRAY<STRUCT<bb: STRING, ba: STRING>>")
+ checkAnswer(unionDf,
+ Row(Seq(Row("bb", null))) ::
+ Row(Seq(Row(null, "ba"))) :: Nil)
+ }
+
+ test("SPARK-36546: Add unionByName support to nested arrays of structs") {
+ val nestedStructType1 = StructType(Seq(
+ StructField("b", ArrayType(
+ StructType(Seq(
+ StructField("ba", StringType),
+ StructField("bb", StringType)
+ ))
+ ))
+ ))
+ val nestedStructValues1 = Row(Seq(Row("ba", "bb")))
+
+ val nestedStructType2 = StructType(Seq(
+ StructField("b", ArrayType(
+ StructType(Seq(
+ StructField("bb", StringType),
+ StructField("ba", StringType)
+ ))
+ ))
+ ))
+ val nestedStructValues2 = Row(Seq(Row("bb", "ba")))
+
+ val df1 = spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
+ StructType(Seq(StructField("topLevelCol", nestedStructType1))))
+
+ val df2 = spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
+ StructType(Seq(StructField("topLevelCol", nestedStructType2))))
+
+ var unionDf = df1.unionByName(df2)
+ assert(unionDf.schema.toDDL == "topLevelCol " +
+ "STRUCT<b: ARRAY<STRUCT<ba: STRING, bb: STRING>>>")
+ checkAnswer(unionDf,
+ Row(Row(Seq(Row("ba", "bb")))) ::
+ Row(Row(Seq(Row("ba", "bb")))) :: Nil)
+
+ unionDf = df2.unionByName(df1)
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" +
+ "b: ARRAY<STRUCT<bb: STRING, ba: STRING>>>")
+ checkAnswer(unionDf,
+ Row(Row(Seq(Row("bb", "ba")))) ::
+ Row(Row(Seq(Row("bb", "ba")))) :: Nil)
+
+ val nestedStructType3 = StructType(Seq(
+ StructField("b", ArrayType(
+ StructType(Seq(
+ StructField("ba", StringType)
+ ))
+ ))
+ ))
+ val nestedStructValues3 = Row(Seq(Row("ba")))
+
+ val nestedStructType4 = StructType(Seq(
+ StructField("b", ArrayType(
+ StructType(Seq(
+ StructField("bb", StringType)
+ ))
+ ))
+ ))
+ val nestedStructValues4 = Row(Seq(Row("bb")))
+
+ val df3 = spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedStructValues3) :: Nil),
+ StructType(Seq(StructField("topLevelCol", nestedStructType3))))
+
+ val df4 = spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedStructValues4) :: Nil),
+ StructType(Seq(StructField("topLevelCol", nestedStructType4))))
+
+ assertThrows[AnalysisException] {
+ df3.unionByName(df4)
+ }
+
+ unionDf = df3.unionByName(df4, true)
+ assert(unionDf.schema.toDDL == "topLevelCol " +
+ "STRUCT<b: ARRAY<STRUCT<ba: STRING, bb: STRING>>>")
+ checkAnswer(unionDf,
+ Row(Row(Seq(Row("ba", null)))) ::
+ Row(Row(Seq(Row(null, "bb")))) :: Nil)
+
+ assertThrows[AnalysisException] {
+ df4.unionByName(df3)
+ }
+
+ unionDf = df4.unionByName(df3, true)
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" +
+ "b: ARRAY<STRUCT<bb: STRING, ba: STRING>>>")
+ checkAnswer(unionDf,
+ Row(Row(Seq(Row("bb", null)))) ::
+ Row(Row(Seq(Row(null, "ba")))) :: Nil)
+ }
+
+ test("SPARK-36546: Add unionByName support to multiple levels of nested arrays of structs") {
+ val nestedStructType1 = StructType(Seq(
+ StructField("b", ArrayType(
+ ArrayType(
+ StructType(Seq(
+ StructField("ba", StringType),
+ StructField("bb", StringType)
+ ))
+ )
+ ))
+ ))
+ val nestedStructValues1 = Row(Seq(Seq(Row("ba", "bb"))))
+
+ val nestedStructType2 = StructType(Seq(
+ StructField("b", ArrayType(
+ ArrayType(
+ StructType(Seq(
+ StructField("bb", StringType),
+ StructField("ba", StringType)
+ ))
+ )
+ ))
+ ))
+ val nestedStructValues2 = Row(Seq(Seq(Row("bb", "ba"))))
+
+ val df1 = spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
+ StructType(Seq(StructField("topLevelCol", nestedStructType1))))
+
+ val df2 = spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
+ StructType(Seq(StructField("topLevelCol", nestedStructType2))))
+
+ var unionDf = df1.unionByName(df2)
+ assert(unionDf.schema.toDDL == "topLevelCol " +
+ "STRUCT<b: ARRAY<ARRAY<STRUCT<ba: STRING, bb: STRING>>>>")
+ checkAnswer(unionDf,
+ Row(Row(Seq(Seq(Row("ba", "bb"))))) ::
+ Row(Row(Seq(Seq(Row("ba", "bb"))))) :: Nil)
+
+ unionDf = df2.unionByName(df1)
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" +
+ "b: ARRAY<ARRAY<STRUCT<bb: STRING, ba: STRING>>>>")
+ checkAnswer(unionDf,
+ Row(Row(Seq(Seq(Row("bb", "ba"))))) ::
+ Row(Row(Seq(Seq(Row("bb", "ba"))))) :: Nil)
+
+ val nestedStructType3 = StructType(Seq(
+ StructField("b", ArrayType(
+ ArrayType(
+ StructType(Seq(
+ StructField("ba", StringType)
+ ))
+ )
+ ))
+ ))
+ val nestedStructValues3 = Row(Seq(Seq(Row("ba"))))
+
+ val nestedStructType4 = StructType(Seq(
+ StructField("b", ArrayType(
+ ArrayType(
+ StructType(Seq(
+ StructField("bb", StringType)
+ ))
+ )
+ ))
+ ))
+ val nestedStructValues4 = Row(Seq(Seq(Row("bb"))))
+
+ val df3 = spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedStructValues3) :: Nil),
+ StructType(Seq(StructField("topLevelCol", nestedStructType3))))
+
+ val df4 = spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedStructValues4) :: Nil),
+ StructType(Seq(StructField("topLevelCol", nestedStructType4))))
+
+ assertThrows[AnalysisException] {
+ df3.unionByName(df4)
+ }
+
+ unionDf = df3.unionByName(df4, true)
+ assert(unionDf.schema.toDDL == "topLevelCol " +
+ "STRUCT<b: ARRAY<ARRAY<STRUCT<ba: STRING, bb: STRING>>>>")
+ checkAnswer(unionDf,
+ Row(Row(Seq(Seq(Row("ba", null))))) ::
+ Row(Row(Seq(Seq(Row(null, "bb"))))) :: Nil)
+
+ assertThrows[AnalysisException] {
+ df4.unionByName(df3)
+ }
+
+ unionDf = df4.unionByName(df3, true)
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" +
+ "b: ARRAY<ARRAY<STRUCT<bb: STRING, ba: STRING>>>>")
+ checkAnswer(unionDf,
+ Row(Row(Seq(Seq(Row("bb", null))))) ::
+ Row(Row(Seq(Seq(Row(null, "ba"))))) :: Nil)
+ }
+
+ test("SPARK-37371: UnionExec should support columnar if all children support columnar") {
+ def checkIfColumnar(
+ plan: SparkPlan,
+ targetPlan: (SparkPlan) => Boolean,
+ isColumnar: Boolean): Unit = {
+ val target = plan.collect {
+ case p if targetPlan(p) => p
+ }
+ assert(target.nonEmpty)
+ assert(target.forall(_.supportsColumnar == isColumnar))
+ }
+
+ Seq(true, false).foreach { supported =>
+ withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> supported.toString) {
+ val df1 = Seq(1, 2, 3).toDF("i").cache()
+ val df2 = Seq(4, 5, 6).toDF("j").cache()
+
+ val union = df1.union(df2)
+ checkIfColumnar(union.queryExecution.executedPlan,
+ _.isInstanceOf[InMemoryTableScanExec], supported)
+ checkIfColumnar(union.queryExecution.executedPlan,
+ _.isInstanceOf[InMemoryTableScanExec], supported)
+ checkIfColumnar(union.queryExecution.executedPlan, _.isInstanceOf[UnionExec], supported)
+ checkAnswer(union, Row(1) :: Row(2) :: Row(3) :: Row(4) :: Row(5) :: Row(6) :: Nil)
+
+ val nonColumnarUnion = df1.union(Seq(7, 8, 9).toDF("k"))
+ checkIfColumnar(nonColumnarUnion.queryExecution.executedPlan,
+ _.isInstanceOf[UnionExec], false)
+ checkAnswer(nonColumnarUnion,
+ Row(1) :: Row(2) :: Row(3) :: Row(7) :: Row(8) :: Row(9) :: Nil)
+ }
+ }
+ }
}
case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d9b75c7794..a696c3fd49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils
@@ -86,7 +86,9 @@ class DataFrameSuite extends QueryTest
test("access complex data") {
assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
- assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1)
+ if (!conf.ansiEnabled) {
+ assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1)
+ }
assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1)
}
@@ -480,7 +482,7 @@ class DataFrameSuite extends QueryTest
testData.select("key").coalesce(1).select("key"),
testData.select("key").collect().toSeq)
- assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 0)
+ assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1)
}
test("convert $\"attribute name\" into unresolved attribute") {
@@ -631,7 +633,19 @@ class DataFrameSuite extends QueryTest
assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
}
- test("withColumns") {
+ test("withColumns: public API, with Map input") {
+ val df = testData.toDF().withColumns(Map(
+ "newCol1" -> (col("key") + 1), "newCol2" -> (col("key") + 2)
+ ))
+ checkAnswer(
+ df,
+ testData.collect().map { case Row(key: Int, value: String) =>
+ Row(key, value, key + 1, key + 2)
+ }.toSeq)
+ assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
+ }
+
+ test("withColumns: internal method") {
val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
Seq(col("key") + 1, col("key") + 2))
checkAnswer(
@@ -655,7 +669,7 @@ class DataFrameSuite extends QueryTest
assert(err2.getMessage.contains("Found duplicate column(s)"))
}
- test("withColumns: case sensitive") {
+ test("withColumns: internal method, case sensitive") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
Seq(col("key") + 1, col("key") + 2))
@@ -674,7 +688,7 @@ class DataFrameSuite extends QueryTest
}
}
- test("withColumns: given metadata") {
+ test("withColumns: internal method, given metadata") {
def buildMetadata(num: Int): Seq[Metadata] = {
(0 until num).map { n =>
val builder = new MetadataBuilder
@@ -702,6 +716,18 @@ class DataFrameSuite extends QueryTest
"The size of column names: 2 isn't equal to the size of metadata elements: 1"))
}
+ test("SPARK-36642: withMetadata: replace metadata of a column") {
+ val metadata = new MetadataBuilder().putLong("key", 1L).build()
+ val df1 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
+ val df2 = df1.withMetadata("x", metadata)
+ assert(df2.schema(0).metadata === metadata)
+
+ val err = intercept[AnalysisException] {
+ df1.withMetadata("x1", metadata)
+ }
+ assert(err.getMessage.contains("Cannot resolve column name"))
+ }
+
test("replace column using withColumn") {
val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
val df3 = df2.withColumn("x", df2("x") + 1)
@@ -834,6 +860,56 @@ class DataFrameSuite extends QueryTest
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}
+ test("SPARK-20384: Value class filter") {
+ val df = spark.sparkContext
+ .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
+ .toDF()
+ val filtered = df.where("s = \"a\"")
+ checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF)
+ }
+
+ test("SPARK-20384: Tuple2 of value class filter") {
+ val df = spark.sparkContext
+ .parallelize(Seq(
+ (StringWrapper("a1"), StringWrapper("a2")),
+ (StringWrapper("b1"), StringWrapper("b2"))))
+ .toDF()
+ val filtered = df.where("_2.s = \"a2\"")
+ checkAnswer(filtered,
+ spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF)
+ }
+
+ test("SPARK-20384: Tuple3 of value class filter") {
+ val df = spark.sparkContext
+ .parallelize(Seq(
+ (StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")),
+ (StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3"))))
+ .toDF()
+ val filtered = df.where("_3.s = \"a3\"")
+ checkAnswer(filtered,
+ spark.sparkContext.parallelize(
+ Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF)
+ }
+
+ test("SPARK-20384: Array value class filter") {
+ val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
+ val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))
+
+ val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF
+ val filtered = df.where(array_contains(col("wrappers.s"), "b"))
+ checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF)
+ }
+
+ test("SPARK-20384: Nested value class filter") {
+ val a = ContainerStringWrapper(StringWrapper("a"))
+ val b = ContainerStringWrapper(StringWrapper("b"))
+
+ val df = spark.sparkContext.parallelize(Seq(a, b)).toDF
+ // flat value class, `s` field is not in schema
+ val filtered = df.where("wrapper = \"a\"")
+ checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF)
+ }
+
private lazy val person2: DataFrame = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
@@ -1489,7 +1565,9 @@ class DataFrameSuite extends QueryTest
test("SPARK-7133: Implement struct, array, and map field accessor") {
assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
- assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
+ if (!conf.ansiEnabled) {
+ assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
+ }
assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1)
assert(complexData.filter(complexData("a")(complexData("s")("key")) === 1).count() == 1)
@@ -2384,8 +2462,10 @@ class DataFrameSuite extends QueryTest
val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
checkAnswer(aggPlusSort1, aggPlusSort2.collect())
- val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0)
- val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0)
+ val aggPlusFilter1 =
+ df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === "test1")
+ val aggPlusFilter2 =
+ df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === "test1")
checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
}
}
@@ -2569,7 +2649,8 @@ class DataFrameSuite extends QueryTest
val err = intercept[AnalysisException] {
df.groupBy($"d", $"b").as[GroupByKey, Row]
}
- assert(err.getMessage.contains("cannot resolve 'd'"))
+ assert(err.getErrorClass == "MISSING_COLUMN")
+ assert(err.messageParameters.head == "d")
}
test("emptyDataFrame should be foldable") {
@@ -2852,6 +2933,25 @@ class DataFrameSuite extends QueryTest
checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
}
+ test("SPARK-39293: The accumulator of ArrayAggregate to handle complex types properly") {
+ val reverse = udf((s: String) => s.reverse)
+
+ val df = Seq(Array("abc", "def")).toDF("array")
+ val testArray = df.select(
+ aggregate(
+ col("array"),
+ array().cast("array<string>"),
+ (acc, s) => concat(acc, array(reverse(s)))))
+ checkAnswer(testArray, Row(Array("cba", "fed")) :: Nil)
+
+ val testMap = df.select(
+ aggregate(
+ col("array"),
+ map().cast("map<string, string>"),
+ (acc, s) => map_concat(acc, map(s, reverse(s)))))
+ checkAnswer(testMap, Row(Map("abc" -> "cba", "def" -> "fed")) :: Nil)
+ }
+
test("SPARK-34882: Aggregate with multiple distinct null sensitive aggregators") {
withUserDefinedFunction(("countNulls", true)) {
spark.udf.register("countNulls", udaf(new Aggregator[JLong, JLong, JLong] {
@@ -2952,6 +3052,33 @@ class DataFrameSuite extends QueryTest
assert(ids.toSet === Range(0, 10).toSet)
}
+ test("SPARK-35320: Reading JSON with key type different to String in a map should fail") {
+ Seq(
+ (MapType(IntegerType, StringType), """{"1": "test"}"""),
+ (StructType(Seq(StructField("test", MapType(IntegerType, StringType)))),
+ """"test": {"1": "test"}"""),
+ (ArrayType(MapType(IntegerType, StringType)), """[{"1": "test"}]"""),
+ (MapType(StringType, MapType(IntegerType, StringType)), """{"key": {"1" : "test"}}""")
+ ).foreach { case (schema, jsonData) =>
+ withTempDir { dir =>
+ val colName = "col"
+ val msg = "can only contain STRING as a key type for a MAP"
+
+ val thrown1 = intercept[AnalysisException](
+ spark.read.schema(StructType(Seq(StructField(colName, schema))))
+ .json(Seq(jsonData).toDS()).collect())
+ assert(thrown1.getMessage.contains(msg))
+
+ val jsonDir = new File(dir, "json").getCanonicalPath
+ Seq(jsonData).toDF(colName).write.json(jsonDir)
+ val thrown2 = intercept[AnalysisException](
+ spark.read.schema(StructType(Seq(StructField(colName, schema))))
+ .json(jsonDir).collect())
+ assert(thrown2.getMessage.contains(msg))
+ }
+ }
+ }
+
test("SPARK-37855: IllegalStateException when transforming an array inside a nested struct") {
def makeInput(): DataFrame = {
val innerElement1 = Row(3, 3.12)
@@ -3088,6 +3215,79 @@ class DataFrameSuite extends QueryTest
}
}
}
+
+ test("SPARK-39612: exceptAll with following count should work") {
+ val d1 = Seq("a").toDF
+ assert(d1.exceptAll(d1).count() === 0)
+ }
+
+ test("SPARK-39887: RemoveRedundantAliases should keep attributes of a Union's first child") {
+ val df = sql(
+ """
+ |SELECT a, b AS a FROM (
+ | SELECT a, a AS b FROM (SELECT a FROM VALUES (1) AS t(a))
+ | UNION ALL
+ | SELECT a, b FROM (SELECT a, b FROM VALUES (1, 2) AS t(a, b))
+ |)
+ |""".stripMargin)
+ val stringCols = df.logicalPlan.output.map(Column(_).cast(StringType))
+ val castedDf = df.select(stringCols: _*)
+ checkAnswer(castedDf, Row("1", "1") :: Row("1", "2") :: Nil)
+ }
+
+ test("SPARK-39887: RemoveRedundantAliases should keep attributes of a Union's first child 2") {
+ val df = sql(
+ """
+ |SELECT
+ | to_date(a) a,
+ | to_date(b) b
+ |FROM
+ | (
+ | SELECT
+ | a,
+ | a AS b
+ | FROM
+ | (
+ | SELECT
+ | to_date(a) a
+ | FROM
+ | VALUES
+ | ('2020-02-01') AS t1(a)
+ | GROUP BY
+ | to_date(a)
+ | ) t3
+ | UNION ALL
+ | SELECT
+ | a,
+ | b
+ | FROM
+ | (
+ | SELECT
+ | to_date(a) a,
+ | to_date(b) b
+ | FROM
+ | VALUES
+ | ('2020-01-01', '2020-01-02') AS t1(a, b)
+ | GROUP BY
+ | to_date(a),
+ | to_date(b)
+ | ) t4
+ | ) t5
+ |GROUP BY
+ | to_date(a),
+ | to_date(b);
+ |""".stripMargin)
+ checkAnswer(df,
+ Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) ::
+ Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil)
+ }
+
+ test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2)
+ assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
+ }
+ }
}
case class GroupByKey(a: Int, b: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index c385d9f58c..bd39453f51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql
import java.time.LocalDateTime
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Filter}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -490,4 +491,88 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
assert(attributeReference.dataType == tuple._2)
}
}
+
+ test("No need to filter windows when windowDuration is multiple of slideDuration") {
+ val df1 = Seq(
+ ("2022-02-15 19:39:34", 1, "a"),
+ ("2022-02-15 19:39:56", 2, "a"),
+ ("2022-02-15 19:39:27", 4, "b")).toDF("time", "value", "id")
+ .select(window($"time", "9 seconds", "3 seconds", "0 second"), $"value")
+ .orderBy($"window.start".asc, $"value".desc).select("value")
+ val df2 = Seq(
+ (LocalDateTime.parse("2022-02-15T19:39:34"), 1, "a"),
+ (LocalDateTime.parse("2022-02-15T19:39:56"), 2, "a"),
+ (LocalDateTime.parse("2022-02-15T19:39:27"), 4, "b")).toDF("time", "value", "id")
+ .select(window($"time", "9 seconds", "3 seconds", "0 second"), $"value")
+ .orderBy($"window.start".asc, $"value".desc).select("value")
+
+ val df3 = Seq(
+ ("2022-02-15 19:39:34", 1, "a"),
+ ("2022-02-15 19:39:56", 2, "a"),
+ ("2022-02-15 19:39:27", 4, "b")).toDF("time", "value", "id")
+ .select(window($"time", "9 seconds", "3 seconds", "-2 second"), $"value")
+ .orderBy($"window.start".asc, $"value".desc).select("value")
+ val df4 = Seq(
+ (LocalDateTime.parse("2022-02-15T19:39:34"), 1, "a"),
+ (LocalDateTime.parse("2022-02-15T19:39:56"), 2, "a"),
+ (LocalDateTime.parse("2022-02-15T19:39:27"), 4, "b")).toDF("time", "value", "id")
+ .select(window($"time", "9 seconds", "3 seconds", "2 second"), $"value")
+ .orderBy($"window.start".asc, $"value".desc).select("value")
+
+ Seq(df1, df2, df3, df4).foreach { df =>
+ val filter = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Filter])
+ assert(filter.isDefined)
+ val exist = filter.get.constraints.filter(e =>
+ e.toString.contains(">=") || e.toString.contains("<"))
+ assert(exist.isEmpty, "No need to filter windows " +
+ "when windowDuration is multiple of slideDuration")
+ }
+ }
+
+ test("SPARK-38227: 'start' and 'end' fields should be nullable") {
+ // We expect the fields in window struct as nullable since the dataType of TimeWindow defines
+ // them as nullable. The rule 'TimeWindowing' should respect the dataType.
+ val df1 = Seq(
+ ("2016-03-27 09:00:05", 1),
+ ("2016-03-27 09:00:32", 2)).toDF("time", "value")
+ val df2 = Seq(
+ (LocalDateTime.parse("2016-03-27T09:00:05"), 1),
+ (LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("time", "value")
+
+ def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = {
+ schema.find(_.name == colName) match {
+ case Some(StructField(_, st: StructType, _, _)) =>
+ assertFieldInWindowStruct(st, "start")
+ assertFieldInWindowStruct(st, "end")
+
+ case _ => fail("Failed to find suitable window column from DataFrame!")
+ }
+ }
+
+ def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = {
+ val field = windowType.fields.find(_.name == fieldName)
+ assert(field.isDefined, s"'$fieldName' field should exist in window struct")
+ assert(field.get.nullable, s"'$fieldName' field should be nullable")
+ }
+
+ for {
+ df <- Seq(df1, df2)
+ nullable <- Seq(true, false)
+ } {
+ val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder(
+ StructType(df.schema.fields.map(_.copy(nullable = nullable)))))
+ // tumbling windows
+ val windowedProject = dfWithDesiredNullability
+ .select(window($"time", "10 seconds").as("window"), $"value")
+ val schema = windowedProject.queryExecution.optimizedPlan.schema
+ validateWindowColumnInSchema(schema, "window")
+
+ // sliding windows
+ val windowedProject2 = dfWithDesiredNullability
+ .select(window($"time", "10 seconds", "3 seconds").as("window"),
+ $"value")
+ val schema2 = windowedProject2.queryExecution.optimizedPlan.schema
+ validateWindowColumnInSchema(schema2, "window")
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 666bf739ca..e57650ff62 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -20,9 +20,12 @@ package org.apache.spark.sql
import org.scalatest.matchers.must.Matchers.the
import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -94,7 +97,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest
}
test("corr, covar_pop, stddev_pop functions in specific window") {
- withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") {
+ withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true",
+ SQLConf.ANSI_ENABLED.key -> "false") {
val df = Seq(
("a", "p1", 10.0, 20.0),
("b", "p1", 20.0, 10.0),
@@ -147,7 +151,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest
test("SPARK-13860: " +
"corr, covar_pop, stddev_pop functions in specific window " +
"LEGACY_STATISTICAL_AGGREGATE off") {
- withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false") {
+ withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "false") {
val df = Seq(
("a", "p1", 10.0, 20.0),
("b", "p1", 20.0, 10.0),
@@ -399,26 +404,29 @@ class DataFrameWindowFunctionsSuite extends QueryTest
val df = Seq((1, "1")).toDF("key", "value")
val e = intercept[AnalysisException](
df.select($"key", count("invalid").over()))
- assert(e.message.contains("cannot resolve 'invalid' given input columns: [key, value]"))
+ assert(e.getErrorClass == "MISSING_COLUMN")
+ assert(e.messageParameters.sameElements(Array("invalid", "value, key")))
}
test("numerical aggregate functions on string column") {
- val df = Seq((1, "a", "b")).toDF("key", "value1", "value2")
- checkAnswer(
- df.select($"key",
- var_pop("value1").over(),
- variance("value1").over(),
- stddev_pop("value1").over(),
- stddev("value1").over(),
- sum("value1").over(),
- mean("value1").over(),
- avg("value1").over(),
- corr("value1", "value2").over(),
- covar_pop("value1", "value2").over(),
- covar_samp("value1", "value2").over(),
- skewness("value1").over(),
- kurtosis("value1").over()),
- Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null)))
+ if (!conf.ansiEnabled) {
+ val df = Seq((1, "a", "b")).toDF("key", "value1", "value2")
+ checkAnswer(
+ df.select($"key",
+ var_pop("value1").over(),
+ variance("value1").over(),
+ stddev_pop("value1").over(),
+ stddev("value1").over(),
+ sum("value1").over(),
+ mean("value1").over(),
+ avg("value1").over(),
+ corr("value1", "value2").over(),
+ covar_pop("value1", "value2").over(),
+ covar_samp("value1", "value2").over(),
+ skewness("value1").over(),
+ kurtosis("value1").over()),
+ Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null)))
+ }
}
test("statistical functions") {
@@ -702,6 +710,50 @@ class DataFrameWindowFunctionsSuite extends QueryTest
Row("a", 4, "x", "x", "y", "x", "x", "y"),
Row("b", 1, null, null, null, null, null, null),
Row("b", 2, null, null, null, null, null, null)))
+
+ val df2 = Seq(
+ ("a", 1, "x"),
+ ("a", 2, "y"),
+ ("a", 3, "z")).
+ toDF("key", "order", "value")
+ checkAnswer(
+ df2.select(
+ $"key",
+ $"order",
+ nth_value($"value", 2).over(window1),
+ nth_value($"value", 2, ignoreNulls = true).over(window1),
+ nth_value($"value", 2).over(window2),
+ nth_value($"value", 2, ignoreNulls = true).over(window2),
+ nth_value($"value", 3).over(window1),
+ nth_value($"value", 3, ignoreNulls = true).over(window1),
+ nth_value($"value", 3).over(window2),
+ nth_value($"value", 3, ignoreNulls = true).over(window2),
+ nth_value($"value", 4).over(window1),
+ nth_value($"value", 4, ignoreNulls = true).over(window1),
+ nth_value($"value", 4).over(window2),
+ nth_value($"value", 4, ignoreNulls = true).over(window2)),
+ Seq(
+ Row("a", 1, "y", "y", null, null, "z", "z", null, null, null, null, null, null),
+ Row("a", 2, "y", "y", "y", "y", "z", "z", null, null, null, null, null, null),
+ Row("a", 3, "y", "y", "y", "y", "z", "z", "z", "z", null, null, null, null)))
+
+ val df3 = Seq(
+ ("a", 1, "x"),
+ ("a", 2, nullStr),
+ ("a", 3, "z")).
+ toDF("key", "order", "value")
+ checkAnswer(
+ df3.select(
+ $"key",
+ $"order",
+ nth_value($"value", 3).over(window1),
+ nth_value($"value", 3, ignoreNulls = true).over(window1),
+ nth_value($"value", 3).over(window2),
+ nth_value($"value", 3, ignoreNulls = true).over(window2)),
+ Seq(
+ Row("a", 1, "z", null, null, null),
+ Row("a", 2, "z", null, null, null),
+ Row("a", 3, "z", null, "z", null)))
}
test("nth_value on descending ordered window") {
@@ -1070,4 +1122,98 @@ class DataFrameWindowFunctionsSuite extends QueryTest
Row("a", 1, "x", "x"),
Row("b", 0, null, null)))
}
+
+ test("SPARK-38237: require all cluster keys for child required distribution for window query") {
+ def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = {
+ expressions.flatMap {
+ case ref: AttributeReference => Some(ref.name)
+ }
+ }
+
+ def isShuffleExecByRequirement(
+ plan: ShuffleExchangeExec,
+ desiredClusterColumns: Seq[String]): Boolean = plan match {
+ case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS) =>
+ partitionExpressionsColumns(op.expressions) === desiredClusterColumns
+ case _ => false
+ }
+
+ val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
+ val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")
+
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") {
+
+ val windowed = df
+ // repartition by subset of window partitionBy keys which satisfies ClusteredDistribution
+ .repartition($"key1")
+ .select(
+ lead($"key1", 1).over(windowSpec),
+ lead($"value", 1).over(windowSpec))
+
+ checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), Row(null, null)))
+
+ val shuffleByRequirement = windowed.queryExecution.executedPlan.exists {
+ case w: WindowExec =>
+ w.child.exists {
+ case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2"))
+ case _ => false
+ }
+ case _ => false
+ }
+
+ assert(shuffleByRequirement, "Can't find desired shuffle node from the query plan")
+ }
+ }
+
+ test("SPARK-38308: Properly handle Stream of window expressions") {
+ val df = Seq(
+ (1, 2, 3),
+ (1, 3, 4),
+ (2, 4, 5),
+ (2, 5, 6)
+ ).toDF("a", "b", "c")
+
+ val w = Window.partitionBy("a").orderBy("b")
+ val selectExprs = Stream(
+ sum("c").over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow)).as("sumc"),
+ avg("c").over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow)).as("avgc")
+ )
+ checkAnswer(
+ df.select(selectExprs: _*),
+ Seq(
+ Row(3, 3),
+ Row(7, 3.5),
+ Row(5, 5),
+ Row(11, 5.5)
+ )
+ )
+ }
+
+ test("SPARK-38614: percent_rank should apply before limit") {
+ val df = Seq.tabulate(101)(identity).toDF("id")
+ val w = Window.orderBy("id")
+ checkAnswer(
+ df.select($"id", percent_rank().over(w)).limit(3),
+ Seq(
+ Row(0, 0.0d),
+ Row(1, 0.01d),
+ Row(2, 0.02d)
+ )
+ )
+ }
+
+ test("SPARK-40002: ntile should apply before limit") {
+ val df = Seq.tabulate(101)(identity).toDF("id")
+ val w = Window.orderBy("id")
+ checkAnswer(
+ df.select($"id", ntile(10).over(w)).limit(3),
+ Seq(
+ Row(0, 1),
+ Row(1, 1),
+ Row(2, 1)
+ )
+ )
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index 8aef27a1b6..86108a81da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -23,12 +23,15 @@ import scala.collection.JavaConverters._
import org.scalatest.BeforeAndAfter
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.connector.InMemoryV1Provider
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, InMemoryTableCatalog, TableCatalog}
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.FakeSourceOne
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType}
@@ -531,6 +534,23 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
assert(table.properties === (Map("provider" -> "foo") ++ defaultOwnership).asJava)
}
+ test("SPARK-39543 writeOption should be passed to storage properties when fallback to v1") {
+ val provider = classOf[InMemoryV1Provider].getName
+
+ withSQLConf((SQLConf.USE_V1_SOURCE_LIST.key, provider)) {
+ spark.range(10)
+ .writeTo("table_name")
+ .option("compression", "zstd").option("name", "table_name")
+ .using(provider)
+ .create()
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("table_name"))
+
+ assert(table.identifier === TableIdentifier("table_name", Some("default")))
+ assert(table.storage.properties.contains("compression"))
+ assert(table.storage.properties.getOrElse("compression", "foo") == "zstd")
+ }
+ }
+
test("Replace: basic behavior") {
spark.sql(
"CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 009ccb9a45..2f4098d7cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -250,7 +250,7 @@ class DatasetCacheSuite extends QueryTest
case i: InMemoryRelation => i.cacheBuilder.cachedPlan
}
assert(df2LimitInnerPlan.isDefined &&
- df2LimitInnerPlan.get.find(_.isInstanceOf[InMemoryTableScanExec]).isEmpty)
+ !df2LimitInnerPlan.get.exists(_.isInstanceOf[InMemoryTableScanExec]))
}
test("SPARK-27739 Save stats from optimized plan") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 347e9fc08a..f5e736621e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,11 +20,14 @@ package org.apache.spark.sql
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}
+import org.apache.hadoop.fs.{Path, PathFilter}
import org.scalatest.Assertions._
import org.scalatest.exceptions.TestFailedException
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.TestUtils.withListener
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
@@ -322,23 +325,25 @@ class DatasetSuite extends QueryTest
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") {
var e = intercept[AnalysisException] {
ds.select(expr("`(_1)?+.+`").as[Int])
- }.getMessage
- assert(e.contains("cannot resolve '`(_1)?+.+`'"))
+ }
+ assert(e.getErrorClass == "MISSING_COLUMN")
+ assert(e.messageParameters.head == "`(_1)?+.+`")
e = intercept[AnalysisException] {
ds.select(expr("`(_1|_2)`").as[Int])
- }.getMessage
- assert(e.contains("cannot resolve '`(_1|_2)`'"))
+ }
+ assert(e.getErrorClass == "MISSING_COLUMN")
+ assert(e.messageParameters.head == "`(_1|_2)`")
e = intercept[AnalysisException] {
ds.select(ds("`(_1)?+.+`"))
- }.getMessage
- assert(e.contains("Cannot resolve column name \"`(_1)?+.+`\""))
+ }
+ assert(e.getMessage.contains("Cannot resolve column name \"`(_1)?+.+`\""))
e = intercept[AnalysisException] {
ds.select(ds("`(_1|_2)`"))
- }.getMessage
- assert(e.contains("Cannot resolve column name \"`(_1|_2)`\""))
+ }
+ assert(e.getMessage.contains("Cannot resolve column name \"`(_1|_2)`\""))
}
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") {
@@ -703,6 +708,72 @@ class DatasetSuite extends QueryTest
1 -> "a", 2 -> "bc", 3 -> "d")
}
+ test("SPARK-34806: observation on datasets") {
+ val namedObservation = Observation("named")
+ val unnamedObservation = Observation()
+
+ val df = spark.range(100)
+ val observed_df = df
+ .observe(
+ namedObservation,
+ min($"id").as("min_val"),
+ max($"id").as("max_val"),
+ sum($"id").as("sum_val"),
+ count(when($"id" % 2 === 0, 1)).as("num_even")
+ )
+ .observe(
+ unnamedObservation,
+ avg($"id").cast("int").as("avg_val")
+ )
+
+ def checkMetrics(namedMetric: Observation, unnamedMetric: Observation): Unit = {
+ assert(namedMetric.get === Map(
+ "min_val" -> 0L, "max_val" -> 99L, "sum_val" -> 4950L, "num_even" -> 50L)
+ )
+ assert(unnamedMetric.get === Map("avg_val" -> 49))
+ }
+
+ observed_df.collect()
+ // we can get the result multiple times
+ checkMetrics(namedObservation, unnamedObservation)
+ checkMetrics(namedObservation, unnamedObservation)
+
+ // an observation can be used only once
+ val err = intercept[IllegalArgumentException] {
+ df.observe(namedObservation, sum($"id").as("sum_val"))
+ }
+ assert(err.getMessage.contains("An Observation can be used with a Dataset only once"))
+
+ // streaming datasets are not supported
+ val streamDf = new MemoryStream[Int](0, sqlContext).toDF()
+ val streamObservation = Observation("stream")
+ val streamErr = intercept[IllegalArgumentException] {
+ streamDf.observe(streamObservation, avg($"value").cast("int").as("avg_val"))
+ }
+ assert(streamErr.getMessage.contains("Observation does not support streaming Datasets"))
+
+ // an observation cannot have an empty name
+ val err2 = intercept[IllegalArgumentException] {
+ Observation("")
+ }
+ assert(err2.getMessage.contains("Name must not be empty"))
+ }
+
+ test("SPARK-37203: Fix NotSerializableException when observe with TypedImperativeAggregate") {
+ def observe[T](df: Dataset[T], expected: Map[String, _]): Unit = {
+ val namedObservation = Observation("named")
+ val observed_df = df.observe(
+ namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
+ observed_df.collect()
+ assert(namedObservation.get === expected)
+ }
+
+ observe(spark.range(100), Map("percentile_approx_val" -> 49))
+ observe(spark.range(0), Map("percentile_approx_val" -> null))
+ observe(spark.range(1, 10), Map("percentile_approx_val" -> 5))
+ observe(spark.range(1, 10, 1, 11), Map("percentile_approx_val" -> 5))
+ }
+
test("sample with replacement") {
val n = 100
val data = sparkContext.parallelize(1 to n, 2).toDS()
@@ -866,7 +937,8 @@ class DatasetSuite extends QueryTest
val e = intercept[AnalysisException] {
ds.as[ClassData2]
}
- assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
+ assert(e.getErrorClass == "MISSING_COLUMN")
+ assert(e.messageParameters.sameElements(Array("c", "a, b")))
}
test("runtime nullability check") {
@@ -933,24 +1005,6 @@ class DatasetSuite extends QueryTest
checkDataset(cogrouped, "a13", "b24")
}
- test("give nice error message when the real number of fields doesn't match encoder schema") {
- val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
-
- val message = intercept[AnalysisException] {
- ds.as[(String, Int, Long)]
- }.message
- assert(message ==
- "Try to map struct<a:string,b:int> to Tuple3, " +
- "but failed as the number of fields does not line up.")
-
- val message2 = intercept[AnalysisException] {
- ds.as[Tuple1[String]]
- }.message
- assert(message2 ==
- "Try to map struct<a:string,b:int> to Tuple1, " +
- "but failed as the number of fields does not line up.")
- }
-
test("SPARK-13440: Resolving option fields") {
val df = Seq(1, 2, 3).toDS()
val ds = df.as[Option[Int]]
@@ -1392,8 +1446,29 @@ class DatasetSuite extends QueryTest
}
testCheckpointing("basic") {
- val ds = spark.range(10).repartition($"id" % 2).filter($"id" > 5).orderBy($"id".desc)
- val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)
+ val ds = spark
+ .range(10)
+ // Num partitions is set to 1 to avoid a RangePartitioner in the orderBy below
+ .repartition(1, $"id" % 2)
+ .filter($"id" > 5)
+ .orderBy($"id".desc)
+ @volatile var jobCounter = 0
+ val listener = new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ jobCounter += 1
+ }
+ }
+ var cp = ds
+ withListener(spark.sparkContext, listener) { _ =>
+ // AQE adds a job per shuffle. The expression above does multiple shuffles and
+ // that screws up the job counting
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)
+ }
+ }
+ if (eager) {
+ assert(jobCounter === 1)
+ }
val logicalRDD = cp.logicalPlan match {
case plan: LogicalRDD => plan
@@ -1861,7 +1936,7 @@ class DatasetSuite extends QueryTest
.map(b => b - 1)
.collect()
}
- assert(thrownException.message.contains("Cannot up cast id from bigint to tinyint"))
+ assert(thrownException.message.contains("""Cannot up cast id from "BIGINT" to "TINYINT""""))
}
test("SPARK-26690: checkpoints should be executed with an execution id") {
@@ -2060,6 +2135,23 @@ class DatasetSuite extends QueryTest
(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13))
}
+
+ test("SPARK-40407: repartition should not result in severe data skew") {
+ val df = spark.range(0, 100, 1, 50).repartition(4)
+ val result = df.mapPartitions(iter => Iterator.single(iter.length)).collect()
+ assert(result.sorted.toSeq === Seq(23, 25, 25, 27))
+ }
+
+ test("SPARK-40660: Switch to XORShiftRandom to distribute elements") {
+ withTempDir { dir =>
+ spark.range(10).repartition(10).write.mode(SaveMode.Overwrite).parquet(dir.getCanonicalPath)
+ val fs = new Path(dir.getAbsolutePath).getFileSystem(spark.sessionState.newHadoopConf())
+ val parquetFiles = fs.listStatus(new Path(dir.getAbsolutePath), new PathFilter {
+ override def accept(path: Path): Boolean = path.getName.endsWith("parquet")
+ })
+ assert(parquetFiles.size === 10)
+ }
+ }
}
case class Bar(a: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
index 543f845aff..fa246fa79b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
@@ -23,7 +23,7 @@ import java.time.{Instant, LocalDateTime, ZoneId}
import java.util.{Locale, TimeZone}
import java.util.concurrent.TimeUnit
-import org.apache.spark.{SparkException, SparkUpgradeException}
+import org.apache.spark.{SparkConf, SparkException, SparkUpgradeException}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{CEST, LA}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions._
@@ -35,6 +35,10 @@ import org.apache.spark.unsafe.types.CalendarInterval
class DateFunctionsSuite extends QueryTest with SharedSparkSession {
import testImplicits._
+ // The test cases which throw exceptions under ANSI mode are covered by date.sql and
+ // datetime-parsing-invalid.sql in org.apache.spark.sql.SQLQueryTestSuite.
+ override def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED.key, "false")
+
test("function current_date") {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
val d0 = DateTimeUtils.currentDate(ZoneId.systemDefault())
@@ -512,7 +516,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession {
Seq(Row(null), Row(null), Row(null)))
val e = intercept[SparkUpgradeException](df.select(to_date(col("s"), "yyyy-dd-aa")).collect())
assert(e.getCause.isInstanceOf[IllegalArgumentException])
- assert(e.getMessage.contains("You may get a different result due to the upgrading of Spark"))
+ assert(e.getMessage.contains("You may get a different result due to the upgrading to Spark"))
// February
val x1 = "2016-02-29"
@@ -695,7 +699,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession {
val e = intercept[SparkUpgradeException](invalid.collect())
assert(e.getCause.isInstanceOf[IllegalArgumentException])
assert(
- e.getMessage.contains("You may get a different result due to the upgrading of Spark"))
+ e.getMessage.contains("You may get a different result due to the upgrading to Spark"))
}
// February
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index 1033d43a1b..d5498c469c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -31,14 +31,14 @@ import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
/**
* Test suite for the filtering ratio policy used to trigger dynamic partition pruning (DPP).
*/
abstract class DynamicPartitionPruningSuiteBase
extends QueryTest
- with SharedSparkSession
+ with SQLTestUtils
with GivenWhenThen
with AdaptiveSparkPlanHelper {
@@ -49,7 +49,7 @@ abstract class DynamicPartitionPruningSuiteBase
protected def initState(): Unit = {}
protected def runAnalyzeColumnCommands: Boolean = true
- override def beforeAll(): Unit = {
+ override protected def beforeAll(): Unit = {
super.beforeAll()
initState()
@@ -107,6 +107,10 @@ abstract class DynamicPartitionPruningSuiteBase
(6, 60)
)
+ if (tableFormat == "hive") {
+ spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
+ }
+
spark.range(1000)
.select($"id" as "product_id", ($"id" % 10) as "store_id", ($"id" + 1) as "code")
.write
@@ -150,11 +154,12 @@ abstract class DynamicPartitionPruningSuiteBase
if (runAnalyzeColumnCommands) {
sql("ANALYZE TABLE fact_stats COMPUTE STATISTICS FOR COLUMNS store_id")
sql("ANALYZE TABLE dim_stats COMPUTE STATISTICS FOR COLUMNS store_id")
+ sql("ANALYZE TABLE dim_store COMPUTE STATISTICS FOR COLUMNS store_id")
sql("ANALYZE TABLE code_stats COMPUTE STATISTICS FOR COLUMNS store_id")
}
}
- override def afterAll(): Unit = {
+ override protected def afterAll(): Unit = {
try {
sql("DROP TABLE IF EXISTS fact_np")
sql("DROP TABLE IF EXISTS fact_sk")
@@ -183,11 +188,11 @@ abstract class DynamicPartitionPruningSuiteBase
val plan = df.queryExecution.executedPlan
val dpExprs = collectDynamicPruningExpressions(plan)
val hasSubquery = dpExprs.exists {
- case InSubqueryExec(_, _: SubqueryExec, _, _) => true
+ case InSubqueryExec(_, _: SubqueryExec, _, _, _, _) => true
case _ => false
}
val subqueryBroadcast = dpExprs.collect {
- case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _) => b
+ case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) => b
}
val hasFilter = if (withSubquery) "Should" else "Shouldn't"
@@ -202,10 +207,10 @@ abstract class DynamicPartitionPruningSuiteBase
case _: ReusedExchangeExec => // reuse check ok.
case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => // reuse check ok.
case b: BroadcastExchangeLike =>
- val hasReuse = plan.find {
+ val hasReuse = plan.exists {
case ReusedExchangeExec(_, e) => e eq b
case _ => false
- }.isDefined
+ }
assert(hasReuse, s"$s\nshould have been reused in\n$plan")
case a: AdaptiveSparkPlanExec =>
val broadcastQueryStage = collectFirst(a) {
@@ -229,7 +234,7 @@ abstract class DynamicPartitionPruningSuiteBase
case r: ReusedSubqueryExec => r.child
case o => o
}
- assert(subquery.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined == isMainQueryAdaptive)
+ assert(subquery.exists(_.isInstanceOf[AdaptiveSparkPlanExec]) == isMainQueryAdaptive)
}
}
@@ -240,7 +245,7 @@ abstract class DynamicPartitionPruningSuiteBase
df.collect()
val buf = collectDynamicPruningExpressions(df.queryExecution.executedPlan).collect {
- case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _) =>
+ case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) =>
b.index
}
assert(buf.distinct.size == n)
@@ -249,7 +254,7 @@ abstract class DynamicPartitionPruningSuiteBase
/**
* Collect the children of all correctly pushed down dynamic pruning expressions in a spark plan.
*/
- private def collectDynamicPruningExpressions(plan: SparkPlan): Seq[Expression] = {
+ protected def collectDynamicPruningExpressions(plan: SparkPlan): Seq[Expression] = {
flatMap(plan) {
case s: FileSourceScanExec => s.partitionFilters.collect {
case d: DynamicPruningExpression => d.child
@@ -339,12 +344,12 @@ abstract class DynamicPartitionPruningSuiteBase
| )
""".stripMargin)
- val found = df.queryExecution.executedPlan.find {
+ val found = df.queryExecution.executedPlan.exists {
case BroadcastHashJoinExec(_, _, p: ExistenceJoin, _, _, _, _, _) => true
case _ => false
}
- assert(found.isEmpty)
+ assert(!found)
}
}
@@ -467,31 +472,70 @@ abstract class DynamicPartitionPruningSuiteBase
Given("no stats and selective predicate with the size of dim too large")
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
- SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") {
- sql(
- """
- |SELECT f.date_id, f.product_id, f.units_sold, f.store_id
- |FROM fact_sk f WHERE store_id < 5
- """.stripMargin)
- .write
- .partitionBy("store_id")
- .saveAsTable("fact_aux")
+ SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true",
+ SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0.02") {
+ withTable("fact_aux") {
+ sql(
+ """
+ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id
+ |FROM fact_sk f WHERE store_id < 5
+ """.stripMargin)
+ .write
+ .partitionBy("store_id")
+ .saveAsTable("fact_aux")
+
+ val df = sql(
+ """
+ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id
+ |FROM fact_aux f JOIN dim_store s
+ |ON f.store_id = s.store_id WHERE s.country = 'US'
+ """.stripMargin)
+
+ checkPartitionPruningPredicate(df, false, false)
+
+ checkAnswer(df,
+ Row(1070, 2, 10, 4) ::
+ Row(1080, 3, 20, 4) ::
+ Row(1090, 3, 10, 4) ::
+ Row(1100, 3, 10, 4) :: Nil
+ )
+ }
+ }
- val df = sql(
- """
- |SELECT f.date_id, f.product_id, f.units_sold, f.store_id
- |FROM fact_aux f JOIN dim_store s
- |ON f.store_id = s.store_id WHERE s.country = 'US'
- """.stripMargin)
+ Given("no stats and selective predicate with the size of dim too large but cached")
+ withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+ SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") {
+ withTable("fact_aux") {
+ withTempView("cached_dim_store") {
+ sql(
+ """
+ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id
+ |FROM fact_sk f WHERE store_id < 5
+ """.stripMargin)
+ .write
+ .partitionBy("store_id")
+ .saveAsTable("fact_aux")
- checkPartitionPruningPredicate(df, false, false)
+ spark.table("dim_store").cache()
+ .createOrReplaceTempView("cached_dim_store")
- checkAnswer(df,
- Row(1070, 2, 10, 4) ::
- Row(1080, 3, 20, 4) ::
- Row(1090, 3, 10, 4) ::
- Row(1100, 3, 10, 4) :: Nil
- )
+ val df = sql(
+ """
+ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id
+ |FROM fact_aux f JOIN cached_dim_store s
+ |ON f.store_id = s.store_id WHERE s.country = 'US'
+ """.stripMargin)
+
+ checkPartitionPruningPredicate(df, true, false)
+
+ checkAnswer(df,
+ Row(1070, 2, 10, 4) ::
+ Row(1080, 3, 20, 4) ::
+ Row(1090, 3, 10, 4) ::
+ Row(1100, 3, 10, 4) :: Nil
+ )
+ }
+ }
}
Given("no stats and selective predicate with the size of dim small")
@@ -983,41 +1027,6 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
- test("no partition pruning when the build side is a stream") {
- withTable("fact") {
- val input = MemoryStream[Int]
- val stream = input.toDF.select($"value" as "one", ($"value" * 3) as "code")
- spark.range(100).select(
- $"id",
- ($"id" + 1).as("one"),
- ($"id" + 2).as("two"),
- ($"id" + 3).as("three"))
- .write.partitionBy("one")
- .format(tableFormat).mode("overwrite").saveAsTable("fact")
- val table = sql("SELECT * from fact f")
-
- // join a partitioned table with a stream
- val joined = table.join(stream, Seq("one")).where("code > 40")
- val query = joined.writeStream.format("memory").queryName("test").start()
- input.addData(1, 10, 20, 40, 50)
- try {
- query.processAllAvailable()
- } finally {
- query.stop()
- }
- // search dynamic pruning predicates on the executed plan
- val plan = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan
- val ret = plan.find {
- case s: FileSourceScanExec => s.partitionFilters.exists {
- case _: DynamicPruningExpression => true
- case _ => false
- }
- case _ => false
- }
- assert(ret.isDefined == false)
- }
- }
-
test("avoid reordering broadcast join keys to match input hash partitioning") {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
@@ -1144,7 +1153,8 @@ abstract class DynamicPartitionPruningSuiteBase
test("join key with multiple references on the filtering plan") {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
- SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName,
+ SQLConf.ANSI_ENABLED.key -> "false" // ANSI mode doesn't support "String + String"
) {
// when enable AQE, the reusedExchange is inserted when executed.
withTable("fact", "dim") {
@@ -1473,9 +1483,124 @@ abstract class DynamicPartitionPruningSuiteBase
checkAnswer(df, Row(1150, 1) :: Row(1130, 4) :: Row(1140, 4) :: Nil)
}
}
+
+ test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " +
+ "pruning") {
+ withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
+ Seq(
+ "f.store_id = 1" -> false,
+ "1 = f.store_id" -> false,
+ "f.store_id <=> 1" -> false,
+ "1 <=> f.store_id" -> false,
+ "f.store_id > 1" -> true,
+ "5 > f.store_id" -> true).foreach { case (condition, hasDPP) =>
+ // partitioned table at left side
+ val df1 = sql(
+ s"""
+ |SELECT /*+ broadcast(s) */ * FROM fact_sk f
+ |JOIN dim_store s ON f.store_id = s.store_id AND $condition
+ """.stripMargin)
+ checkPartitionPruningPredicate(df1, false, withBroadcast = hasDPP)
+
+ val df2 = sql(
+ s"""
+ |SELECT /*+ broadcast(s) */ * FROM fact_sk f
+ |JOIN dim_store s ON f.store_id = s.store_id
+ |WHERE $condition
+ """.stripMargin)
+ checkPartitionPruningPredicate(df2, false, withBroadcast = hasDPP)
+
+ // partitioned table at right side
+ val df3 = sql(
+ s"""
+ |SELECT /*+ broadcast(s) */ * FROM dim_store s
+ |JOIN fact_sk f ON f.store_id = s.store_id AND $condition
+ """.stripMargin)
+ checkPartitionPruningPredicate(df3, false, withBroadcast = hasDPP)
+
+ val df4 = sql(
+ s"""
+ |SELECT /*+ broadcast(s) */ * FROM dim_store s
+ |JOIN fact_sk f ON f.store_id = s.store_id
+ |WHERE $condition
+ """.stripMargin)
+ checkPartitionPruningPredicate(df4, false, withBroadcast = hasDPP)
+ }
+ }
+ }
+
+ test("SPARK-38570: Fix incorrect DynamicPartitionPruning caused by Literal") {
+ withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
+ val df = sql(
+ """
+ |SELECT f.store_id,
+ | f.date_id,
+ | s.state_province
+ |FROM (SELECT 4 AS store_id,
+ | date_id,
+ | product_id
+ | FROM fact_sk
+ | WHERE date_id >= 1300
+ | UNION ALL
+ | SELECT 5 AS store_id,
+ | date_id,
+ | product_id
+ | FROM fact_stats
+ | WHERE date_id <= 1000) f
+ |JOIN dim_store s
+ |ON f.store_id = s.store_id
+ |WHERE s.country = 'US'
+ |""".stripMargin)
+
+ checkPartitionPruningPredicate(df, withSubquery = false, withBroadcast = false)
+ checkAnswer(df, Row(4, 1300, "California") :: Row(5, 1000, "Texas") :: Nil)
+ }
+ }
}
-abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningSuiteBase {
+abstract class DynamicPartitionPruningDataSourceSuiteBase
+ extends DynamicPartitionPruningSuiteBase
+ with SharedSparkSession {
+
+ import testImplicits._
+
+ test("no partition pruning when the build side is a stream") {
+ withTable("fact") {
+ val input = MemoryStream[Int]
+ val stream = input.toDF.select($"value" as "one", ($"value" * 3) as "code")
+ spark.range(100).select(
+ $"id",
+ ($"id" + 1).as("one"),
+ ($"id" + 2).as("two"),
+ ($"id" + 3).as("three"))
+ .write.partitionBy("one")
+ .format(tableFormat).mode("overwrite").saveAsTable("fact")
+ val table = sql("SELECT * from fact f")
+
+ // join a partitioned table with a stream
+ val joined = table.join(stream, Seq("one")).where("code > 40")
+ val query = joined.writeStream.format("memory").queryName("test").start()
+ input.addData(1, 10, 20, 40, 50)
+ try {
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ // search dynamic pruning predicates on the executed plan
+ val plan = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan
+ val ret = plan.exists {
+ case s: FileSourceScanExec => s.partitionFilters.exists {
+ case _: DynamicPruningExpression => true
+ case _ => false
+ }
+ case _ => false
+ }
+ assert(!ret)
+ }
+ }
+}
+
+abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDataSourceSuiteBase {
import testImplicits._
@@ -1510,10 +1635,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningSui
val scanOption =
find(plan) {
case s: FileSourceScanExec =>
- s.output.exists(_.find(_.argString(maxFields = 100).contains("fid")).isDefined)
+ s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid")))
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
- s.output.exists(_.find(_.argString(maxFields = 100).contains("f1")).isDefined)
+ s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1")))
case _ => false
}
assert(scanOption.isDefined)
@@ -1569,6 +1694,25 @@ class DynamicPartitionPruningV1SuiteAEOff extends DynamicPartitionPruningV1Suite
class DynamicPartitionPruningV1SuiteAEOn extends DynamicPartitionPruningV1Suite
with EnableAdaptiveExecutionSuite {
+ test("SPARK-39447: Avoid AssertionError in AdaptiveSparkPlanExec.doExecuteBroadcast") {
+ val df = sql(
+ """
+ |WITH empty_result AS (
+ | SELECT * FROM fact_stats WHERE product_id < 0
+ |)
+ |SELECT *
+ |FROM (SELECT /*+ SHUFFLE_MERGE(fact_sk) */ empty_result.store_id
+ | FROM fact_sk
+ | JOIN empty_result
+ | ON fact_sk.product_id = empty_result.product_id) t2
+ | JOIN empty_result
+ | ON t2.store_id = empty_result.store_id
+ """.stripMargin)
+
+ checkPartitionPruningPredicate(df, false, false)
+ checkAnswer(df, Nil)
+ }
+
test("SPARK-37995: PlanAdaptiveDynamicPruningFilters should use prepareExecutedPlan " +
"rather than createSparkPlan to re-plan subquery") {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
@@ -1587,7 +1731,7 @@ class DynamicPartitionPruningV1SuiteAEOn extends DynamicPartitionPruningV1Suite
}
}
-abstract class DynamicPartitionPruningV2Suite extends DynamicPartitionPruningSuiteBase {
+abstract class DynamicPartitionPruningV2Suite extends DynamicPartitionPruningDataSourceSuiteBase {
override protected def runAnalyzeColumnCommands: Boolean = false
override protected def initState(): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index aef39a24e9..d637283446 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -106,7 +106,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)")
}
- test("optimized plan should show the rewritten aggregate expression") {
+ test("optimized plan should show the rewritten expression") {
withTempView("test_agg") {
sql(
"""
@@ -125,6 +125,13 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
"Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, " +
"any(v#x) AS any(v)#x]")
}
+
+ withTable("t") {
+ sql("CREATE TABLE t(col TIMESTAMP) USING parquet")
+ val df = sql("SELECT date_part('month', col) FROM t")
+ checkKeywordsExistsInExplain(df,
+ "Project [month(cast(col#x as date)) AS date_part(month, col)#x]")
+ }
}
test("explain inline tables cross-joins") {
@@ -217,8 +224,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
// AND conjunction
// OR disjunction
// ---------------------------------------------------------------------------------------
- checkKeywordsExistsInExplain(sql("select 'a' || 1 + 2"),
- "Project [null AS (concat(a, 1) + 2)#x]")
+ checkKeywordsExistsInExplain(sql("select '1' || 1 + 2"),
+ "Project [13", " AS (concat(1, 1) + 2)#x")
checkKeywordsExistsInExplain(sql("select 1 - 2 || 'b'"),
"Project [-1b AS concat((1 - 2), b)#x]")
checkKeywordsExistsInExplain(sql("select 2 * 4 + 3 || 'b'"),
@@ -232,12 +239,11 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
}
test("explain for these functions; use range to avoid constant folding") {
- val df = sql("select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') " +
+ val df = sql("select ifnull(id, 1), nullif(id, 1), nvl(id, 1), nvl2(id, 1, 2) " +
"from range(2)")
checkKeywordsExistsInExplain(df,
- "Project [coalesce(cast(id#xL as string), x) AS ifnull(id, x)#x, " +
- "id#xL AS nullif(id, x)#xL, coalesce(cast(id#xL as string), x) AS nvl(id, x)#x, " +
- "x AS nvl2(id, x, y)#x]")
+ "Project [id#xL AS ifnull(id, 1)#xL, if ((id#xL = 1)) null " +
+ "else id#xL AS nullif(id, 1)#xL, id#xL AS nvl(id, 1)#xL, 1 AS nvl2(id, 1, 2)#x]")
}
test("SPARK-26659: explain of DataWritingCommandExec should not contain duplicate cmd.nodeName") {
@@ -522,6 +528,21 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
"== Analyzed Logical Plan ==\nCreateViewCommand")
}
}
+
+ test("SPARK-39112: UnsupportedOperationException if explain cost command using v2 command") {
+ withTempDir { dir =>
+ sql("EXPLAIN COST CREATE DATABASE tmp")
+ sql("EXPLAIN COST DESC DATABASE tmp")
+ sql(s"EXPLAIN COST ALTER DATABASE tmp SET LOCATION '${dir.toURI.toString}'")
+ sql("EXPLAIN COST USE tmp")
+ sql("EXPLAIN COST CREATE TABLE t(c1 int) USING PARQUET")
+ sql("EXPLAIN COST SHOW TABLES")
+ sql("EXPLAIN COST SHOW CREATE TABLE t")
+ sql("EXPLAIN COST SHOW TBLPROPERTIES t")
+ sql("EXPLAIN COST DROP TABLE t")
+ sql("EXPLAIN COST DROP DATABASE tmp")
+ }
+ }
}
class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite {
@@ -594,7 +615,7 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
}
test("SPARK-35884: Explain should only display one plan before AQE takes effect") {
- val df = (0 to 10).toDF("id").where('id > 5)
+ val df = (0 to 10).toDF("id").where(Symbol("id") > 5)
val modes = Seq(SimpleMode, ExtendedMode, CostMode, FormattedMode)
modes.foreach { mode =>
checkKeywordsExistsInExplain(df, mode, "AdaptiveSparkPlan")
@@ -609,7 +630,8 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
test("SPARK-35884: Explain formatted with subquery") {
withTempView("t1", "t2") {
- spark.range(100).select('id % 10 as "key", 'id as "value").createOrReplaceTempView("t1")
+ spark.range(100).select(Symbol("id") % 10 as "key", Symbol("id") as "value")
+ .createOrReplaceTempView("t1")
spark.range(10).createOrReplaceTempView("t2")
val query =
"""
@@ -678,6 +700,33 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
}
}
+ test("SPARK-32986: Bucketed scan info should be a part of explain string") {
+ withTable("t1", "t2") {
+ Seq((1, 2), (2, 3)).toDF("i", "j").write.bucketBy(8, "i").saveAsTable("t1")
+ Seq(2, 3).toDF("i").write.bucketBy(8, "i").saveAsTable("t2")
+ val df1 = spark.table("t1")
+ val df2 = spark.table("t2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+ checkKeywordsExistsInExplain(
+ df1.join(df2, df1("i") === df2("i")),
+ "Bucketed: true")
+ }
+
+ withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
+ checkKeywordsExistsInExplain(
+ df1.join(df2, df1("i") === df2("i")),
+ "Bucketed: false (disabled by configuration)")
+ }
+
+ checkKeywordsExistsInExplain(df1, "Bucketed: false (disabled by query planner)" )
+
+ checkKeywordsExistsInExplain(
+ df1.select("j"),
+ "Bucketed: false (bucket column(s) not read)")
+ }
+ }
+
test("SPARK-36795: Node IDs should not be duplicated when InMemoryRelation present") {
withTempView("t1", "t2") {
Seq(1).toDF("k").write.saveAsTable("t1")
@@ -696,6 +745,35 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
assert(inMemoryRelationNodeId != columnarToRowNodeId)
}
}
+
+ test("SPARK-38232: Explain formatted does not collect subqueries under query stage in AQE") {
+ withTable("t") {
+ sql("CREATE TABLE t USING PARQUET AS SELECT 1 AS c")
+ val expected =
+ "Subquery:1 Hosting operator id = 2 Hosting Expression = Subquery subquery#x, [id=#x]"
+ val df = sql("SELECT count(s) FROM (SELECT (SELECT c FROM t) as s)")
+ df.collect()
+ withNormalizedExplain(df, FormattedMode) { output =>
+ assert(output.contains(expected))
+ }
+ }
+ }
+
+ test("SPARK-38322: Support query stage show runtime statistics in formatted explain mode") {
+ val df = Seq(1, 2).toDF("c").distinct()
+ val statistics = "Statistics(sizeInBytes=32.0 B, rowCount=2)"
+
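+ // Runtime statistics are attached only after the query stages execute, so they should be
+ // absent from the formatted plan before collect() and present afterwards.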
+ checkKeywordsNotExistsInExplain(
+ df,
+ FormattedMode,
+ statistics)
+
+ df.collect()
+ checkKeywordsExistsInExplain(
+ df,
+ FormattedMode,
+ statistics)
+ }
}
case class ExplainSingleData(id: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 001b6a00af..17dfde65ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -382,7 +382,9 @@ class FileBasedDataSourceSuite extends QueryTest
msg.toLowerCase(Locale.ROOT).contains(msg2))
}
- withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1List) {
+ withSQLConf(
+ SQLConf.USE_V1_SOURCE_LIST.key -> useV1List,
+ SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") {
// write path
Seq("csv", "json", "parquet", "orc").foreach { format =>
val msg = intercept[AnalysisException] {
@@ -532,6 +534,64 @@ class FileBasedDataSourceSuite extends QueryTest
}
}
+ test("SPARK-30362: test input metrics for DSV2") {
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ Seq("json", "orc", "parquet").foreach { format =>
+ withTempPath { path =>
+ val dir = path.getCanonicalPath
+ spark.range(0, 10).write.format(format).save(dir)
+ val df = spark.read.format(format).load(dir)
+ val bytesReads = new mutable.ArrayBuffer[Long]()
+ val recordsRead = new mutable.ArrayBuffer[Long]()
+ val bytesReadListener = new SparkListener() {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead
+ recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+ }
+ }
+ sparkContext.addSparkListener(bytesReadListener)
+ try {
+ df.collect()
+ sparkContext.listenerBus.waitUntilEmpty()
+ assert(bytesReads.sum > 0)
+ assert(recordsRead.sum == 10)
+ } finally {
+ sparkContext.removeSparkListener(bytesReadListener)
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-37585: test input metrics for DSV2 with output limits") {
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ Seq("json", "orc", "parquet").foreach { format =>
+ withTempPath { path =>
+ val dir = path.getCanonicalPath
+ spark.range(0, 100).write.format(format).save(dir)
+ val df = spark.read.format(format).load(dir)
+ val bytesReads = new mutable.ArrayBuffer[Long]()
+ val recordsRead = new mutable.ArrayBuffer[Long]()
+ val bytesReadListener = new SparkListener() {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead
+ recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+ }
+ }
+ sparkContext.addSparkListener(bytesReadListener)
+ try {
+ df.limit(10).collect()
+ sparkContext.listenerBus.waitUntilEmpty()
+ assert(bytesReads.sum > 0)
+ assert(recordsRead.sum > 0)
+ } finally {
+ sparkContext.removeSparkListener(bytesReadListener)
+ }
+ }
+ }
+ }
+ }
+
test("Do not use cache on overwrite") {
Seq("", "orc").foreach { useV1SourceReaderList =>
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1SourceReaderList) {
@@ -695,7 +755,7 @@ class FileBasedDataSourceSuite extends QueryTest
test("SPARK-22790,SPARK-27668: spark.sql.sources.compressionFactor takes effect") {
Seq(1.0, 0.5).foreach { compressionFactor =>
withSQLConf(SQLConf.FILE_COMPRESSION_FACTOR.key -> compressionFactor.toString,
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "350") {
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "457") {
withTempPath { workDir =>
// the file size is 504 bytes
val workDirPath = workDir.getAbsolutePath
@@ -731,6 +791,28 @@ class FileBasedDataSourceSuite extends QueryTest
}
}
+ test("SPARK-36568: FileScan statistics estimation takes read schema into account") {
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ withTempDir { dir =>
+ spark.range(1000).map(x => (x / 100, x, x)).toDF("k", "v1", "v2").
+ write.partitionBy("k").mode(SaveMode.Overwrite).orc(dir.toString)
+ val dfAll = spark.read.orc(dir.toString)
+ val dfK = dfAll.select("k")
+ val dfV1 = dfAll.select("v1")
+ val dfV2 = dfAll.select("v2")
+ val dfV1V2 = dfAll.select("v1", "v2")
+
+ def sizeInBytes(df: DataFrame): BigInt = df.queryExecution.optimizedPlan.stats.sizeInBytes
+
+ assert(sizeInBytes(dfAll) === BigInt(getLocalDirSize(dir)))
+ assert(sizeInBytes(dfK) < sizeInBytes(dfAll))
+ assert(sizeInBytes(dfV1) < sizeInBytes(dfAll))
+ assert(sizeInBytes(dfV2) === sizeInBytes(dfV1))
+ assert(sizeInBytes(dfV1V2) < sizeInBytes(dfAll))
+ }
+ }
+ }
+
test("File source v2: support partition pruning") {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
allFileBasedDataSources.foreach { format =>
@@ -754,12 +836,13 @@ class FileBasedDataSourceSuite extends QueryTest
}
assert(filterCondition.isDefined)
// The partitions filters should be pushed down and no need to be reevaluated.
- assert(filterCondition.get.collectFirst {
- case a: AttributeReference if a.name == "p1" || a.name == "p2" => a
- }.isEmpty)
+ assert(!filterCondition.get.exists {
+ case a: AttributeReference => a.name == "p1" || a.name == "p2"
+ case _ => false
+ })
val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: FileScan, _) => f
+ case BatchScanExec(_, f: FileScan, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -799,7 +882,7 @@ class FileBasedDataSourceSuite extends QueryTest
assert(filterCondition.isDefined)
val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: FileScan, _) => f
+ case BatchScanExec(_, f: FileScan, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -885,52 +968,57 @@ class FileBasedDataSourceSuite extends QueryTest
// cases when value == MAX
var v = Short.MaxValue
- checkPushedFilters(format, df.where('id > v.toInt), Array(), noScan = true)
- checkPushedFilters(format, df.where('id >= v.toInt), Array(sources.IsNotNull("id"),
- sources.EqualTo("id", v)))
- checkPushedFilters(format, df.where('id === v.toInt), Array(sources.IsNotNull("id"),
- sources.EqualTo("id", v)))
- checkPushedFilters(format, df.where('id <=> v.toInt),
+ checkPushedFilters(format, df.where(Symbol("id") > v.toInt), Array(), noScan = true)
+ checkPushedFilters(format, df.where(Symbol("id") >= v.toInt),
+ Array(sources.IsNotNull("id"), sources.EqualTo("id", v)))
+ checkPushedFilters(format, df.where(Symbol("id") === v.toInt),
+ Array(sources.IsNotNull("id"), sources.EqualTo("id", v)))
+ checkPushedFilters(format, df.where(Symbol("id") <=> v.toInt),
Array(sources.EqualNullSafe("id", v)))
- checkPushedFilters(format, df.where('id <= v.toInt), Array(sources.IsNotNull("id")))
- checkPushedFilters(format, df.where('id < v.toInt), Array(sources.IsNotNull("id"),
- sources.Not(sources.EqualTo("id", v))))
+ checkPushedFilters(format, df.where(Symbol("id") <= v.toInt),
+ Array(sources.IsNotNull("id")))
+ checkPushedFilters(format, df.where(Symbol("id") < v.toInt),
+ Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v))))
// cases when value > MAX
var v1: Int = positiveInt
- checkPushedFilters(format, df.where('id > v1), Array(), noScan = true)
- checkPushedFilters(format, df.where('id >= v1), Array(), noScan = true)
- checkPushedFilters(format, df.where('id === v1), Array(), noScan = true)
- checkPushedFilters(format, df.where('id <=> v1), Array(), noScan = true)
- checkPushedFilters(format, df.where('id <= v1), Array(sources.IsNotNull("id")))
- checkPushedFilters(format, df.where('id < v1), Array(sources.IsNotNull("id")))
+ checkPushedFilters(format, df.where(Symbol("id") > v1), Array(), noScan = true)
+ checkPushedFilters(format, df.where(Symbol("id") >= v1), Array(), noScan = true)
+ checkPushedFilters(format, df.where(Symbol("id") === v1), Array(), noScan = true)
+ checkPushedFilters(format, df.where(Symbol("id") <=> v1), Array(), noScan = true)
+ checkPushedFilters(format, df.where(Symbol("id") <= v1), Array(sources.IsNotNull("id")))
+ checkPushedFilters(format, df.where(Symbol("id") < v1), Array(sources.IsNotNull("id")))
// cases when value = MIN
v = Short.MinValue
- checkPushedFilters(format, df.where(lit(v.toInt) < 'id), Array(sources.IsNotNull("id"),
- sources.Not(sources.EqualTo("id", v))))
- checkPushedFilters(format, df.where(lit(v.toInt) <= 'id), Array(sources.IsNotNull("id")))
- checkPushedFilters(format, df.where(lit(v.toInt) === 'id), Array(sources.IsNotNull("id"),
+ checkPushedFilters(format, df.where(lit(v.toInt) < Symbol("id")),
+ Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v))))
+ checkPushedFilters(format, df.where(lit(v.toInt) <= Symbol("id")),
+ Array(sources.IsNotNull("id")))
+ checkPushedFilters(format, df.where(lit(v.toInt) === Symbol("id")),
+ Array(sources.IsNotNull("id"),
sources.EqualTo("id", v)))
- checkPushedFilters(format, df.where(lit(v.toInt) <=> 'id),
+ checkPushedFilters(format, df.where(lit(v.toInt) <=> Symbol("id")),
Array(sources.EqualNullSafe("id", v)))
- checkPushedFilters(format, df.where(lit(v.toInt) >= 'id), Array(sources.IsNotNull("id"),
- sources.EqualTo("id", v)))
- checkPushedFilters(format, df.where(lit(v.toInt) > 'id), Array(), noScan = true)
+ checkPushedFilters(format, df.where(lit(v.toInt) >= Symbol("id")),
+ Array(sources.IsNotNull("id"), sources.EqualTo("id", v)))
+ checkPushedFilters(format, df.where(lit(v.toInt) > Symbol("id")), Array(), noScan = true)
// cases when value < MIN
v1 = negativeInt
- checkPushedFilters(format, df.where(lit(v1) < 'id), Array(sources.IsNotNull("id")))
- checkPushedFilters(format, df.where(lit(v1) <= 'id), Array(sources.IsNotNull("id")))
- checkPushedFilters(format, df.where(lit(v1) === 'id), Array(), noScan = true)
- checkPushedFilters(format, df.where(lit(v1) >= 'id), Array(), noScan = true)
- checkPushedFilters(format, df.where(lit(v1) > 'id), Array(), noScan = true)
+ checkPushedFilters(format, df.where(lit(v1) < Symbol("id")),
+ Array(sources.IsNotNull("id")))
+ checkPushedFilters(format, df.where(lit(v1) <= Symbol("id")),
+ Array(sources.IsNotNull("id")))
+ checkPushedFilters(format, df.where(lit(v1) === Symbol("id")), Array(), noScan = true)
+ checkPushedFilters(format, df.where(lit(v1) >= Symbol("id")), Array(), noScan = true)
+ checkPushedFilters(format, df.where(lit(v1) > Symbol("id")), Array(), noScan = true)
// cases when value is within range (MIN, MAX)
- checkPushedFilters(format, df.where('id > 30), Array(sources.IsNotNull("id"),
+ checkPushedFilters(format, df.where(Symbol("id") > 30), Array(sources.IsNotNull("id"),
sources.GreaterThan("id", 30)))
- checkPushedFilters(format, df.where(lit(100) >= 'id), Array(sources.IsNotNull("id"),
- sources.LessThanOrEqual("id", 100)))
+ checkPushedFilters(format, df.where(lit(100) >= Symbol("id")),
+ Array(sources.IsNotNull("id"), sources.LessThanOrEqual("id", 100)))
}
}
}
@@ -967,28 +1055,6 @@ class FileBasedDataSourceSuite extends QueryTest
checkAnswer(df, Row("v1", "v2"))
}
}
-
- test("SPARK-36271: V1 insert should check schema field name too") {
- withView("v") {
- spark.range(1).createTempView("v")
- withTempDir { dir =>
- val e = intercept[AnalysisException] {
- sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite)
- .format("parquet").save(dir.getCanonicalPath)
- }.getMessage
- assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
- }
-
- withTempDir { dir =>
- val e = intercept[AnalysisException] {
- sql("SELECT NAMED_STRUCT('(IF((ID = 1), 1, 0))', IF(ID=1,ID,0)) AS col1 FROM v")
- .write.mode(SaveMode.Overwrite)
- .format("parquet").save(dir.getCanonicalPath)
- }.getMessage
- assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
- }
- }
- }
}
object TestingUDT {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
index 4e7fe8455f..ce98fd2735 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -85,10 +85,11 @@ trait FileScanSuiteBase extends SharedSparkSession {
val options = new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap))
val optionsNotEqual =
new CaseInsensitiveStringMap(ImmutableMap.copyOf(ImmutableMap.of("key2", "value2")))
- val partitionFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
- val partitionFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1)))
- val dataFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
- val dataFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1)))
+ val partitionFilters = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 0)))
+ val partitionFiltersNotEqual = Seq(And(IsNull(Symbol("data").int),
+ LessThan(Symbol("data").int, 1)))
+ val dataFilters = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 0)))
+ val dataFiltersNotEqual = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 1)))
scanBuilders.foreach { case (name, scanBuilder, exclusions) =>
test(s"SPARK-33482: Test $name equals") {
@@ -354,11 +355,11 @@ class FileScanSuite extends FileScanSuiteBase {
val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
("ParquetScan",
(s, fi, ds, rds, rps, f, o, pf, df) =>
- ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df),
+ ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df),
Seq.empty),
("OrcScan",
(s, fi, ds, rds, rps, f, o, pf, df) =>
- OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df),
+ OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, None, f, pf, df),
Seq.empty),
("CSVScan",
(s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df),
@@ -367,7 +368,7 @@ class FileScanSuite extends FileScanSuiteBase {
(s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, f, pf, df),
Seq.empty),
("TextScan",
- (s, fi, _, rds, rps, _, o, pf, df) => TextScan(s, fi, rds, rps, o, pf, df),
+ (s, fi, ds, rds, rps, _, o, pf, df) => TextScan(s, fi, ds, rds, rps, o, pf, df),
Seq("dataSchema", "pushedFilters")))
run(scanBuilders)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index d5c2d93055..49cdc80241 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.LeafLike
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -331,7 +332,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
val msg1 = intercept[AnalysisException] {
sql("select 1 + explode(array(min(c2), max(c2))) from t1 group by c1")
}.getMessage
- assert(msg1.contains("Generators are not supported when it's nested in expressions"))
+ assert(msg1.contains("The generator is not supported: nested in expressions"))
val msg2 = intercept[AnalysisException] {
sql(
@@ -341,7 +342,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
|from t1 group by c1
""".stripMargin)
}.getMessage
- assert(msg2.contains("Only one generator allowed per aggregate clause"))
+ assert(msg2.contains("The generator is not supported: " +
+ "only one generator allowed per aggregate clause"))
}
}
@@ -349,14 +351,99 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
val errMsg = intercept[AnalysisException] {
sql("SELECT array(array(1, 2), array(3)) v").select(explode(explode($"v"))).collect
}.getMessage
- assert(errMsg.contains("Generators are not supported when it's nested in expressions, " +
- "but got: explode(explode(v))"))
+ assert(errMsg.contains("The generator is not supported: " +
+ "nested in expressions \"explode(explode(v))\""))
}
test("SPARK-30997: generators in aggregate expressions for dataframe") {
val df = Seq(1, 2, 3).toDF("v")
checkAnswer(df.select(explode(array(min($"v"), max($"v")))), Row(1) :: Row(3) :: Nil)
}
+
+ test("SPARK-38528: generator in stream of aggregate expressions") {
+ val df = Seq(1, 2, 3).toDF("v")
+ checkAnswer(
+ df.select(Stream(explode(array(min($"v"), max($"v"))), sum($"v")): _*),
+ Row(1, 6) :: Row(3, 6) :: Nil)
+ }
+
+ test("SPARK-37947: lateral view <func>_outer()") {
+ checkAnswer(
+ sql("select * from values 1, 2 lateral view explode_outer(array()) a as b"),
+ Row(1, null) :: Row(2, null) :: Nil)
+
+ checkAnswer(
+ sql("select * from values 1, 2 lateral view outer explode_outer(array()) a as b"),
+ Row(1, null) :: Row(2, null) :: Nil)
+
+ withTempView("t1") {
+ sql(
+ """select * from values
+ |array(struct(0, 1), struct(3, 4)),
+ |array(struct(6, 7)),
+ |array(),
+ |null
+ |as tbl(arr)
+ """.stripMargin).createOrReplaceTempView("t1")
+ checkAnswer(
+ sql("select f1, f2 from t1 lateral view inline_outer(arr) as f1, f2"),
+ Row(0, 1) :: Row(3, 4) :: Row(6, 7) :: Row(null, null) :: Row(null, null) :: Nil)
+ }
+ }
+
+ def testNullStruct(): Unit = {
+ val df = sql(
+ """select * from values
+ |(
+ | 1,
+ | array(
+ | named_struct('c1', 0, 'c2', 1),
+ | null,
+ | named_struct('c1', 2, 'c2', 3),
+ | null
+ | )
+ |)
+ |as tbl(a, b)
+ """.stripMargin)
+ df.createOrReplaceTempView("t1")
+
+ checkAnswer(
+ sql("select inline(b) from t1"),
+ Row(0, 1) :: Row(null, null) :: Row(2, 3) :: Row(null, null) :: Nil)
+
+ checkAnswer(
+ sql("select a, inline(b) from t1"),
+ Row(1, 0, 1) :: Row(1, null, null) :: Row(1, 2, 3) :: Row(1, null, null) :: Nil)
+ }
+
+ test("SPARK-39061: inline should handle null struct") {
+ testNullStruct
+ }
+
+ test("SPARK-39496: inline eval path should handle null struct") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+ testNullStruct
+ }
+ }
+
+ test("SPARK-40963: generator output has correct nullability") {
+ // This test does not check nullability directly. Before SPARK-40963,
+ // the below query got wrong results due to incorrect nullability.
+ val df = sql(
+ """select c1, explode(c4) as c5 from (
+ | select c1, array(c3) as c4 from (
+ | select c1, explode_outer(c2) as c3
+ | from values
+ | (1, array(1, 2)),
+ | (2, array(2, 3)),
+ | (3, null)
+ | as data(c1, c2)
+ | )
+ |)
+ |""".stripMargin)
+ checkAnswer(df,
+ Row(1, 1) :: Row(1, 2) :: Row(2, 2) :: Row(2, 3) :: Row(3, null) :: Nil)
+ }
}
case class EmptyGenerator() extends Generator with LeafLike[Expression] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
new file mode 100644
index 0000000000..6c6bd1799e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -0,0 +1,579 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
+import org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries
+import org.apache.spark.sql.catalyst.plans.LeftSemi
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan}
+import org.apache.spark.sql.execution.{ReusedSubqueryExec, SubqueryExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession
+ with AdaptiveSparkPlanHelper {
+
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ val schema = new StructType().add("a1", IntegerType, nullable = true)
+ .add("b1", IntegerType, nullable = true)
+ .add("c1", IntegerType, nullable = true)
+ .add("d1", IntegerType, nullable = true)
+ .add("e1", IntegerType, nullable = true)
+ .add("f1", IntegerType, nullable = true)
+
+ val data1 = Seq(Seq(null, 47, null, 4, 6, 48),
+ Seq(73, 63, null, 92, null, null),
+ Seq(76, 10, 74, 98, 37, 5),
+ Seq(0, 63, null, null, null, null),
+ Seq(15, 77, null, null, null, null),
+ Seq(null, 57, 33, 55, null, 58),
+ Seq(4, 0, 86, null, 96, 14),
+ Seq(28, 16, 58, null, null, null),
+ Seq(1, 88, null, 8, null, 79),
+ Seq(59, null, null, null, 20, 25),
+ Seq(1, 50, null, 94, 94, null),
+ Seq(null, null, null, 67, 51, 57),
+ Seq(77, 50, 8, 90, 16, 21),
+ Seq(34, 28, null, 5, null, 64),
+ Seq(null, null, 88, 11, 63, 79),
+ Seq(92, 94, 23, 1, null, 64),
+ Seq(57, 56, null, 83, null, null),
+ Seq(null, 35, 8, 35, null, 70),
+ Seq(null, 8, null, 35, null, 87),
+ Seq(9, null, null, 60, null, 5),
+ Seq(null, 15, 66, null, 83, null))
+ val rdd1 = spark.sparkContext.parallelize(data1)
+ val rddRow1 = rdd1.map(s => Row.fromSeq(s))
+ spark.createDataFrame(rddRow1, schema).write.saveAsTable("bf1")
+
+ val schema2 = new StructType().add("a2", IntegerType, nullable = true)
+ .add("b2", IntegerType, nullable = true)
+ .add("c2", IntegerType, nullable = true)
+ .add("d2", IntegerType, nullable = true)
+ .add("e2", IntegerType, nullable = true)
+ .add("f2", IntegerType, nullable = true)
+
+
+ val data2 = Seq(Seq(67, 17, 45, 91, null, null),
+ Seq(98, 63, 0, 89, null, 40),
+ Seq(null, 76, 68, 75, 20, 19),
+ Seq(8, null, null, null, 78, null),
+ Seq(48, 62, null, null, 11, 98),
+ Seq(84, null, 99, 65, 66, 51),
+ Seq(98, null, null, null, 42, 51),
+ Seq(10, 3, 29, null, 68, 8),
+ Seq(85, 36, 41, null, 28, 71),
+ Seq(89, null, 94, 95, 67, 21),
+ Seq(44, null, 24, 33, null, 6),
+ Seq(null, 6, 78, 31, null, 69),
+ Seq(59, 2, 63, 9, 66, 20),
+ Seq(5, 23, 10, 86, 68, null),
+ Seq(null, 63, 99, 55, 9, 65),
+ Seq(57, 62, 68, 5, null, 0),
+ Seq(75, null, 15, null, 81, null),
+ Seq(53, null, 6, 68, 28, 13),
+ Seq(null, null, null, null, 89, 23),
+ Seq(36, 73, 40, null, 8, null),
+ Seq(24, null, null, 40, null, null))
+ val rdd2 = spark.sparkContext.parallelize(data2)
+ val rddRow2 = rdd2.map(s => Row.fromSeq(s))
+ spark.createDataFrame(rddRow2, schema2).write.saveAsTable("bf2")
+
+ val schema3 = new StructType().add("a3", IntegerType, nullable = true)
+ .add("b3", IntegerType, nullable = true)
+ .add("c3", IntegerType, nullable = true)
+ .add("d3", IntegerType, nullable = true)
+ .add("e3", IntegerType, nullable = true)
+ .add("f3", IntegerType, nullable = true)
+
+ val data3 = Seq(Seq(67, 17, 45, 91, null, null),
+ Seq(98, 63, 0, 89, null, 40),
+ Seq(null, 76, 68, 75, 20, 19),
+ Seq(8, null, null, null, 78, null),
+ Seq(48, 62, null, null, 11, 98),
+ Seq(84, null, 99, 65, 66, 51),
+ Seq(98, null, null, null, 42, 51),
+ Seq(10, 3, 29, null, 68, 8),
+ Seq(85, 36, 41, null, 28, 71),
+ Seq(89, null, 94, 95, 67, 21),
+ Seq(44, null, 24, 33, null, 6),
+ Seq(null, 6, 78, 31, null, 69),
+ Seq(59, 2, 63, 9, 66, 20),
+ Seq(5, 23, 10, 86, 68, null),
+ Seq(null, 63, 99, 55, 9, 65),
+ Seq(57, 62, 68, 5, null, 0),
+ Seq(75, null, 15, null, 81, null),
+ Seq(53, null, 6, 68, 28, 13),
+ Seq(null, null, null, null, 89, 23),
+ Seq(36, 73, 40, null, 8, null),
+ Seq(24, null, null, 40, null, null))
+ val rdd3 = spark.sparkContext.parallelize(data3)
+ val rddRow3 = rdd3.map(s => Row.fromSeq(s))
+ spark.createDataFrame(rddRow3, schema3).write.saveAsTable("bf3")
+
+
+ val schema4 = new StructType().add("a4", IntegerType, nullable = true)
+ .add("b4", IntegerType, nullable = true)
+ .add("c4", IntegerType, nullable = true)
+ .add("d4", IntegerType, nullable = true)
+ .add("e4", IntegerType, nullable = true)
+ .add("f4", IntegerType, nullable = true)
+
+ val data4 = Seq(Seq(67, 17, 45, 91, null, null),
+ Seq(98, 63, 0, 89, null, 40),
+ Seq(null, 76, 68, 75, 20, 19),
+ Seq(8, null, null, null, 78, null),
+ Seq(48, 62, null, null, 11, 98),
+ Seq(84, null, 99, 65, 66, 51),
+ Seq(98, null, null, null, 42, 51),
+ Seq(10, 3, 29, null, 68, 8),
+ Seq(85, 36, 41, null, 28, 71),
+ Seq(89, null, 94, 95, 67, 21),
+ Seq(44, null, 24, 33, null, 6),
+ Seq(null, 6, 78, 31, null, 69),
+ Seq(59, 2, 63, 9, 66, 20),
+ Seq(5, 23, 10, 86, 68, null),
+ Seq(null, 63, 99, 55, 9, 65),
+ Seq(57, 62, 68, 5, null, 0),
+ Seq(75, null, 15, null, 81, null),
+ Seq(53, null, 6, 68, 28, 13),
+ Seq(null, null, null, null, 89, 23),
+ Seq(36, 73, 40, null, 8, null),
+ Seq(24, null, null, 40, null, null))
+ val rdd4 = spark.sparkContext.parallelize(data4)
+ val rddRow4 = rdd4.map(s => Row.fromSeq(s))
+ spark.createDataFrame(rddRow4, schema4).write.saveAsTable("bf4")
+
+ val schema5part = new StructType().add("a5", IntegerType, nullable = true)
+ .add("b5", IntegerType, nullable = true)
+ .add("c5", IntegerType, nullable = true)
+ .add("d5", IntegerType, nullable = true)
+ .add("e5", IntegerType, nullable = true)
+ .add("f5", IntegerType, nullable = true)
+
+ val data5part = Seq(Seq(67, 17, 45, 91, null, null),
+ Seq(98, 63, 0, 89, null, 40),
+ Seq(null, 76, 68, 75, 20, 19),
+ Seq(8, null, null, null, 78, null),
+ Seq(48, 62, null, null, 11, 98),
+ Seq(84, null, 99, 65, 66, 51),
+ Seq(98, null, null, null, 42, 51),
+ Seq(10, 3, 29, null, 68, 8),
+ Seq(85, 36, 41, null, 28, 71),
+ Seq(89, null, 94, 95, 67, 21),
+ Seq(44, null, 24, 33, null, 6),
+ Seq(null, 6, 78, 31, null, 69),
+ Seq(59, 2, 63, 9, 66, 20),
+ Seq(5, 23, 10, 86, 68, null),
+ Seq(null, 63, 99, 55, 9, 65),
+ Seq(57, 62, 68, 5, null, 0),
+ Seq(75, null, 15, null, 81, null),
+ Seq(53, null, 6, 68, 28, 13),
+ Seq(null, null, null, null, 89, 23),
+ Seq(36, 73, 40, null, 8, null),
+ Seq(24, null, null, 40, null, null))
+ val rdd5part = spark.sparkContext.parallelize(data5part)
+ val rddRow5part = rdd5part.map(s => Row.fromSeq(s))
+ spark.createDataFrame(rddRow5part, schema5part).write.partitionBy("f5")
+ .saveAsTable("bf5part")
+ spark.createDataFrame(rddRow5part, schema5part).filter("a5 > 30")
+ .write.partitionBy("f5")
+ .saveAsTable("bf5filtered")
+
+ sql("analyze table bf1 compute statistics for columns a1, b1, c1, d1, e1, f1")
+ sql("analyze table bf2 compute statistics for columns a2, b2, c2, d2, e2, f2")
+ sql("analyze table bf3 compute statistics for columns a3, b3, c3, d3, e3, f3")
+ sql("analyze table bf4 compute statistics for columns a4, b4, c4, d4, e4, f4")
+ sql("analyze table bf5part compute statistics for columns a5, b5, c5, d5, e5, f5")
+ sql("analyze table bf5filtered compute statistics for columns a5, b5, c5, d5, e5, f5")
+
+ // `MergeScalarSubqueries` can duplicate subqueries in the optimized plan and would make testing
+ // complicated.
+ conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, MergeScalarSubqueries.ruleName)
+ }
+
+ protected override def afterAll(): Unit = try {
+ conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.defaultValueString)
+
+ sql("DROP TABLE IF EXISTS bf1")
+ sql("DROP TABLE IF EXISTS bf2")
+ sql("DROP TABLE IF EXISTS bf3")
+ sql("DROP TABLE IF EXISTS bf4")
+ sql("DROP TABLE IF EXISTS bf5part")
+ sql("DROP TABLE IF EXISTS bf5filtered")
+ } finally {
+ super.afterAll()
+ }
+
+ private def ensureLeftSemiJoinExists(plan: LogicalPlan): Unit = {
+ assert(
+ plan.find {
+ case j: Join if j.joinType == LeftSemi => true
+ case _ => false
+ }.isDefined
+ )
+ }
+
+ def checkWithAndWithoutFeatureEnabled(query: String, testSemiJoin: Boolean,
+ shouldReplace: Boolean): Unit = {
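+ // Capture the optimized plan and expected rows with runtime filtering disabled, then re-run
+ // the query with either the semi-join or the bloom-filter rewrite enabled and verify whether
+ // the optimized plan was rewritten as expected.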
+ var planDisabled: LogicalPlan = null
+ var planEnabled: LogicalPlan = null
+ var expectedAnswer: Array[Row] = null
+
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+ planDisabled = sql(query).queryExecution.optimizedPlan
+ expectedAnswer = sql(query).collect()
+ }
+
+ if (testSemiJoin) {
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "true",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+ planEnabled = sql(query).queryExecution.optimizedPlan
+ checkAnswer(sql(query), expectedAnswer)
+ }
+ if (shouldReplace) {
+ val normalizedEnabled = normalizePlan(normalizeExprIds(planEnabled))
+ val normalizedDisabled = normalizePlan(normalizeExprIds(planDisabled))
+ ensureLeftSemiJoinExists(planEnabled)
+ assert(normalizedEnabled != normalizedDisabled)
+ } else {
+ comparePlans(planDisabled, planEnabled)
+ }
+ } else {
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+ planEnabled = sql(query).queryExecution.optimizedPlan
+ checkAnswer(sql(query), expectedAnswer)
+ if (shouldReplace) {
+ assert(!columnPruningTakesEffect(planEnabled))
+ assert(getNumBloomFilters(planEnabled) > getNumBloomFilters(planDisabled))
+ } else {
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled))
+ }
+ }
+ }
+ }
+
+ def getNumBloomFilters(plan: LogicalPlan): Integer = {
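+ // Count injected bloom filters: each BloomFilterAggregate built in a scalar subquery must be
+ // matched by a BloomFilterMightContain predicate in a Filter.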
+ val numBloomFilterAggs = plan.collect {
+ case Filter(condition, _) => condition.collect {
+ case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+ => subquery.plan.collect {
+ case Aggregate(_, aggregateExpressions, _) =>
+ aggregateExpressions.map {
+ case Alias(AggregateExpression(bfAgg : BloomFilterAggregate, _, _, _, _),
+ _) =>
+ assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal])
+ assert(bfAgg.numBitsExpression.isInstanceOf[Literal])
+ 1
+ }.sum
+ }.sum
+ }.sum
+ }.sum
+ val numMightContains = plan.collect {
+ case Filter(condition, _) => condition.collect {
+ case BloomFilterMightContain(_, _) => 1
+ }.sum
+ }.sum
+ assert(numBloomFilterAggs == numMightContains)
+ numMightContains
+ }
+
+ def columnPruningTakesEffect(plan: LogicalPlan): Boolean = {
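+ // Returns true if applying ColumnPruning to an injected scalar subquery would still change it,
+ // i.e. the creation-side plan has not been fully column-pruned.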
+ def takesEffect(plan: LogicalPlan): Boolean = {
+ val result = org.apache.spark.sql.catalyst.optimizer.ColumnPruning.apply(plan)
+ !result.fastEquals(plan)
+ }
+
+ plan.collectFirst {
+ case Filter(condition, _) if condition.collectFirst {
+ case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+ if takesEffect(subquery.plan) => true
+ }.nonEmpty => true
+ }.nonEmpty
+ }
+
+ def assertRewroteSemiJoin(query: String): Unit = {
+ checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true)
+ }
+
+ def assertDidNotRewriteSemiJoin(query: String): Unit = {
+ checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = false)
+ }
+
+ def assertRewroteWithBloomFilter(query: String): Unit = {
+ checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = true)
+ }
+
+ def assertDidNotRewriteWithBloomFilter(query: String): Unit = {
+ checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = false)
+ }
+
+ test("Runtime semi join reduction: simple") {
+ // Filter creation side is 3409 bytes
+ // Filter application side scan is 3362 bytes
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ assertRewroteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62")
+ assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2")
+ }
+ }
+
+ test("Runtime semi join reduction: two joins") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " +
+ "and bf3.c3 = bf2.c2 where bf2.a2 = 5")
+ }
+ }
+
+ test("Runtime semi join reduction: three joins") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 join bf4 on " +
+ "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5")
+ }
+ }
+
+ test("Runtime semi join reduction: simple expressions only") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ val squared = (s: Long) => {
+ s * s
+ }
+ spark.udf.register("square", squared)
+ assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " +
+ "bf1.c1 = bf2.c2 where square(bf2.a2) = 62")
+ assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " +
+ "bf1.c1 = square(bf2.c2) where bf2.a2= 62")
+ }
+ }
+
+ test("Runtime bloom filter join: simple") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+ "where bf2.a2 = 62")
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2")
+ }
+ }
+
+ test("Runtime bloom filter join: two filters single join") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ var planDisabled: LogicalPlan = null
+ var planEnabled: LogicalPlan = null
+ var expectedAnswer: Array[Row] = null
+
+ val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+ "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+ planDisabled = sql(query).queryExecution.optimizedPlan
+ expectedAnswer = sql(query).collect()
+ }
+
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+ planEnabled = sql(query).queryExecution.optimizedPlan
+ checkAnswer(sql(query), expectedAnswer)
+ }
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
+ }
+ }
+
+ test("Runtime bloom filter join: test the number of filter threshold") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ var planDisabled: LogicalPlan = null
+ var planEnabled: LogicalPlan = null
+ var expectedAnswer: Array[Row] = null
+
+ val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+ "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+ planDisabled = sql(query).queryExecution.optimizedPlan
+ expectedAnswer = sql(query).collect()
+ }
+
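+ // RUNTIME_FILTER_NUMBER_THRESHOLD caps how many runtime filters may be injected; this query
+ // can use at most 2.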
+ for (numFilterThreshold <- 0 to 3) {
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true",
+ SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD.key -> numFilterThreshold.toString) {
+ planEnabled = sql(query).queryExecution.optimizedPlan
+ checkAnswer(sql(query), expectedAnswer)
+ }
+ if (numFilterThreshold < 3) {
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled)
+ + numFilterThreshold)
+ } else {
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
+ }
+ }
+ }
+ }
+
+ test("Runtime bloom filter join: insert one bloom filter per column") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ var planDisabled: LogicalPlan = null
+ var planEnabled: LogicalPlan = null
+ var expectedAnswer: Array[Row] = null
+
+ val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+ "bf1.c1 = bf2.b2 where bf2.a2 = 62"
+
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+ planDisabled = sql(query).queryExecution.optimizedPlan
+ expectedAnswer = sql(query).collect()
+ }
+
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+ planEnabled = sql(query).queryExecution.optimizedPlan
+ checkAnswer(sql(query), expectedAnswer)
+ }
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 1)
+ }
+ }
+
+ test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " +
+ "on the same column") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ assertDidNotRewriteWithBloomFilter("select * from bf5part join bf2 on " +
+ "bf5part.f5 = bf2.c2 where bf2.a2 = 62")
+ }
+ }
+
+ test("Runtime bloom filter join: add bloom filter if dpp filter exists on " +
+ "a different column") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ assertRewroteWithBloomFilter("select * from bf5part join bf2 on " +
+ "bf5part.c5 = bf2.c2 and bf5part.f5 = bf2.f2 where bf2.a2 = 62")
+ }
+ }
+
+ test("Runtime bloom filter join: BF rewrite triggering threshold test") {
+ // Filter creation side data size is 3409 bytes. On the filter application side, an individual
+ // scan's byte size is 3362.
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000",
+ SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000"
+ ) {
+ assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+ "where bf2.a2 = 62")
+ }
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
+ SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "50"
+ ) {
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+ "where bf2.a2 = 62")
+ }
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000",
+ SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000"
+ ) {
+ // Rewrite should not be triggered as the Bloom filter application side scan size is small.
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 "
+ + "where bf2.a2 = 62")
+ }
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32",
+ SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000") {
+ // Test that the max scan size rather than an individual scan size on the filter
+ // application side matters. `bf5filtered` has 14168 bytes and `bf2` has 3409 bytes.
+ withSQLConf(
+ SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000") {
+ assertRewroteWithBloomFilter("select * from " +
+ "(select * from bf5filtered union all select * from bf2) t " +
+ "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5")
+ }
+ withSQLConf(
+ SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "15000") {
+ assertDidNotRewriteWithBloomFilter("select * from " +
+ "(select * from bf5filtered union all select * from bf2) t " +
+ "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5")
+ }
+ }
+ }
+
+ test("Runtime bloom filter join: simple expressions only") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+ val squared = (s: Long) => {
+ s * s
+ }
+ spark.udf.register("square", squared)
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " +
+ "bf1.c1 = bf2.c2 where square(bf2.a2) = 62" )
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " +
+ "bf1.c1 = square(bf2.c2) where bf2.a2 = 62" )
+ }
+ }
+
+ test("Support Left Semi join in row level runtime filters") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32") {
+ assertRewroteWithBloomFilter(
+ """
+ |SELECT *
+ |FROM bf1 LEFT SEMI
+ |JOIN (SELECT * FROM bf2 WHERE bf2.a2 = 62) tmp
+ |ON bf1.c1 = tmp.c2
+ """.stripMargin)
+ }
+ }
+
+ test("Merge runtime bloom filters") {
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000",
+ SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true",
+ // Re-enable `MergeScalarSubqueries`
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "",
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
+
+ val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+ "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+ val df = sql(query)
+ df.collect()
+ val plan = df.queryExecution.executedPlan
+
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id }
+ val reusedSubqueryIds = collectWithSubqueries(plan) {
+ case rs: ReusedSubqueryExec => rs.child.id
+ }
+
+ assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan")
+ assert(reusedSubqueryIds.size == 1,
+ "Missing or unexpected reused ReusedSubqueryExec in the plan")
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index a090eba430..86c8c8261e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -27,19 +27,20 @@ import org.scalatest.Assertions._
import org.apache.spark.TestUtils
import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunction, PythonUtils}
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
/**
- * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF and
- * Scalar Pandas UDFs can be tested in SBT & Maven tests.
+ * This object integrates various UDF test cases so that Scalar UDF, Python UDF,
+ * Scalar Pandas UDF and Grouped Aggregate Pandas UDF can be tested in SBT & Maven tests.
*
- * The available UDFs are special. It defines an UDF wrapped by cast. So, the input column is
- * casted into string, UDF returns strings as are, and then output column is casted back to
- * the input column. In this way, UDF is virtually no-op.
+ * The available UDFs are special. For Scalar UDF, Python UDF and Scalar Pandas UDF,
+ * each UDF is wrapped by casts: the input column is cast to string, the UDF returns the
+ * strings as they are, and the output column is cast back to the type of the input column.
+ * In this way, the UDF is virtually a no-op.
*
* Note that, due to this implementation limitation, complex types such as map, array and struct
* types do not work with this UDFs because they cannot be same after the cast roundtrip.
@@ -69,6 +70,28 @@ import org.apache.spark.sql.types.StringType
* df.select(expr("udf_name(id)")
* df.select(pandasTestUDF(df("id")))
* }}}
+ *
+ * For Grouped Aggregate Pandas UDF, it defines a UDF that calculates the count using pandas.
+ * The UDF returns the count of the given column, so unlike the other UDFs it is not a no-op.
+ *
+ * To register Grouped Aggregate Pandas UDF in SQL:
+ * {{{
+ * val groupedAggPandasTestUDF = TestGroupedAggPandasUDF(name = "udf_name")
+ * registerTestUDF(groupedAggPandasTestUDF, spark)
+ * }}}
+ *
+ * To use it in Scala API and SQL:
+ * {{{
+ * sql("SELECT udf_name(1)")
+ * val df = Seq(
+ * (536361, "85123A", 2, 17850),
+ * (536362, "85123B", 4, 17850),
+ * (536363, "86123A", 6, 17851)
+ * ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID")
+ *
+ * df.groupBy("CustomerID").agg(expr("udf_name(Quantity)"))
+ * df.groupBy("CustomerID").agg(groupedAggPandasTestUDF(df("Quantity")))
+ * }}}
*/
object IntegratedUDFTestUtils extends SQLHelper {
import scala.sys.process._
@@ -190,6 +213,28 @@ object IntegratedUDFTestUtils extends SQLHelper {
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
}
+ private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) {
+ var binaryPandasFunc: Array[Byte] = null
+ withTempPath { path =>
+ Process(
+ Seq(
+ pythonExec,
+ "-c",
+ "from pyspark.sql.types import IntegerType; " +
+ "from pyspark.serializers import CloudPickleSerializer; " +
+ s"f = open('$path', 'wb');" +
+ "f.write(CloudPickleSerializer().dumps((" +
+ "lambda x: x.agg('count'), IntegerType())))"),
+ None,
+ "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+ binaryPandasFunc = Files.readAllBytes(path.toPath)
+ }
+ assert(binaryPandasFunc != null)
+ binaryPandasFunc
+ } else {
+ throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
+ }
+
// Make sure this map stays mutable - this map gets updated later in Python runners.
private val workerEnv = new java.util.HashMap[String, String]()
workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
@@ -209,6 +254,8 @@ object IntegratedUDFTestUtils extends SQLHelper {
lazy val shouldTestScalarPandasUDFs: Boolean =
isPythonAvailable && isPandasAvailable && isPyArrowAvailable
+ lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs
+
/**
* A base trait for various UDFs defined in this object.
*/
@@ -218,6 +265,29 @@ object IntegratedUDFTestUtils extends SQLHelper {
val prettyName: String
}
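+ // A PythonUDF whose toString drops the result-id suffix, so the rendered UDF call stays
+ // deterministic across test runs.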
+ class PythonUDFWithoutId(
+ name: String,
+ func: PythonFunction,
+ dataType: DataType,
+ children: Seq[Expression],
+ evalType: Int,
+ udfDeterministic: Boolean,
+ resultId: ExprId)
+ extends PythonUDF(name, func, dataType, children, evalType, udfDeterministic, resultId) {
+
+ def this(pudf: PythonUDF) = {
+ this(pudf.name, pudf.func, pudf.dataType, pudf.children,
+ pudf.evalType, pudf.udfDeterministic, pudf.resultId)
+ }
+
+ override def toString: String = s"$name(${children.mkString(", ")})"
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): PythonUDFWithoutId = {
+ new PythonUDFWithoutId(super.withNewChildrenInternal(newChildren))
+ }
+ }
+
/**
* A Python UDF that takes one column, casts into string, executes the Python native function,
* and casts back to the type of input column.
@@ -253,7 +323,9 @@ object IntegratedUDFTestUtils extends SQLHelper {
val expr = e.head
assert(expr.resolved, "column should be resolved to use the same type " +
"as input. Try df(name) or df.col(name)")
- Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType)
+ val pythonUDF = new PythonUDFWithoutId(
+ super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF])
+ Cast(pythonUDF, expr.dataType)
}
}
@@ -297,7 +369,9 @@ object IntegratedUDFTestUtils extends SQLHelper {
val expr = e.head
assert(expr.resolved, "column should be resolved to use the same type " +
"as input. Try df(name) or df.col(name)")
- Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType)
+ val pythonUDF = new PythonUDFWithoutId(
+ super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF])
+ Cast(pythonUDF, expr.dataType)
}
}
@@ -306,6 +380,46 @@ object IntegratedUDFTestUtils extends SQLHelper {
val prettyName: String = "Scalar Pandas UDF"
}
+ /**
+ * A Grouped Aggregate Pandas UDF that takes one column, executes the
+ * Python native function calculating the count of the column using pandas.
+ *
+ * Virtually equivalent to:
+ *
+ * {{{
+ * import pandas as pd
+ * from pyspark.sql.functions import pandas_udf
+ *
+ * df = spark.createDataFrame(
+ * [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+ *
+ * @pandas_udf("double")
+ * def pandas_count(v: pd.Series) -> int:
+ * return v.count()
+ *
+ * count_col = pandas_count(df['v'])
+ * }}}
+ */
+ case class TestGroupedAggPandasUDF(name: String) extends TestUDF {
+ private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
+ name = name,
+ func = PythonFunction(
+ command = pandasGroupedAggFunc,
+ envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+ pythonIncludes = List.empty[String].asJava,
+ pythonExec = pythonExec,
+ pythonVer = pythonVer,
+ broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+ accumulator = null),
+ dataType = IntegerType,
+ pythonEvalType = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ udfDeterministic = true)
+
+ def apply(exprs: Column*): Column = udf(exprs: _*)
+
+ val prettyName: String = "Grouped Aggregate Pandas UDF"
+ }
+
/**
* A Scala UDF that takes one column, casts into string, executes the
* Scala native function, and casts back to the type of input column.
@@ -360,6 +474,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
def registerTestUDF(testUDF: TestUDF, session: SparkSession): Unit = testUDF match {
case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf)
case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf)
+ case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf)
case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf)
case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index 9f4c24b46a..1792b4c32e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.log4j.Level
+import org.apache.logging.log4j.Level
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, EliminateResolvedHint}
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -55,7 +55,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
}
val warningMessages = logAppender.loggingEvents
.filter(_.getLevel == Level.WARN)
- .map(_.getRenderedMessage)
+ .map(_.getMessage.getFormattedMessage)
.filter(_.contains("hint"))
assert(warningMessages.size == warnings.size)
warnings.foreach { w =>
@@ -597,4 +597,115 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
assert(df9.collect().size == df10.collect().size)
}
}
+
+ test("SPARK-35221: Add join hint build side check") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
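+ // Building the hash/broadcast side on the left is unsupported for left outer/semi/anti joins,
+ // so the hint is ignored with a warning and sort-merge join is used instead.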
+ Seq("left_outer", "left_semi", "left_anti").foreach { joinType =>
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType")
+ withLogAppender(hintAppender, level = Some(Level.WARN)) {
+ assertShuffleMergeJoin(
+ df1.hint("BROADCAST").join(df2, $"a1" === $"b1", joinType))
+ assertShuffleMergeJoin(
+ df1.hint("SHUFFLE_HASH").join(df2, $"a1" === $"b1", joinType))
+ }
+
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ .filter(_.contains("is not supported in the query:"))
+ assert(logs.size === 2)
+ logs.foreach(log =>
+ assert(log.contains(s"build left for ${joinType.split("_").mkString(" ")} join.")))
+ }
+
+ Seq("left_outer", "left_semi", "left_anti").foreach { joinType =>
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType")
+ withLogAppender(hintAppender, level = Some(Level.WARN)) {
+ assertBroadcastHashJoin(
+ df1.join(df2.hint("BROADCAST"), $"a1" === $"b1", joinType), BuildRight)
+ assertShuffleHashJoin(
+ df1.join(df2.hint("SHUFFLE_HASH"), $"a1" === $"b1", joinType), BuildRight)
+ }
+
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ .filter(_.contains("is not supported in the query:"))
+ assert(logs.isEmpty)
+ }
+
+ Seq("right_outer").foreach { joinType =>
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType")
+ withLogAppender(hintAppender, level = Some(Level.WARN)) {
+ assertShuffleMergeJoin(
+ df1.join(df2.hint("BROADCAST"), $"a1" === $"b1", joinType))
+ assertShuffleMergeJoin(
+ df1.join(df2.hint("SHUFFLE_HASH"), $"a1" === $"b1", joinType))
+ }
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ .filter(_.contains("is not supported in the query:"))
+ assert(logs.size === 2)
+ logs.foreach(log =>
+ assert(log.contains(s"build right for ${joinType.split("_").mkString(" ")} join.")))
+ }
+
+ Seq("right_outer").foreach { joinType =>
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType")
+ withLogAppender(hintAppender, level = Some(Level.WARN)) {
+ assertBroadcastHashJoin(
+ df1.hint("BROADCAST").join(df2, $"a1" === $"b1", joinType), BuildLeft)
+ assertShuffleHashJoin(
+ df1.hint("SHUFFLE_HASH").join(df2, $"a1" === $"b1", joinType), BuildLeft)
+ }
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ .filter(_.contains("is not supported in the query:"))
+ assert(logs.isEmpty)
+ }
+
+ Seq("inner", "cross").foreach { joinType =>
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType")
+ withLogAppender(hintAppender, level = Some(Level.WARN)) {
+ assertBroadcastHashJoin(
+ df1.hint("BROADCAST").join(df2, $"a1" === $"b1", joinType), BuildLeft)
+ assertBroadcastHashJoin(
+ df1.join(df2.hint("BROADCAST"), $"a1" === $"b1", joinType), BuildRight)
+
+ assertShuffleHashJoin(
+ df1.hint("SHUFFLE_HASH").join(df2, $"a1" === $"b1", joinType), BuildLeft)
+ assertShuffleHashJoin(
+ df1.join(df2.hint("SHUFFLE_HASH"), $"a1" === $"b1", joinType), BuildRight)
+ }
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ .filter(_.contains("is not supported in the query:"))
+ assert(logs.isEmpty)
+ }
+ }
+ }
+
+ test("SPARK-35221: Add join hint non equi-join check") {
+ val hintAppender = new LogAppender(s"join hint check for equi-join")
+ withLogAppender(hintAppender, level = Some(Level.WARN)) {
+ assertBroadcastNLJoin(
+ df1.hint("SHUFFLE_HASH").join(df2, $"a1" !== $"b1"), BuildRight)
+ }
+ withLogAppender(hintAppender, level = Some(Level.WARN)) {
+ assertBroadcastNLJoin(
+ df1.join(df2.hint("MERGE"), $"a1" !== $"b1"), BuildRight)
+ }
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ .filter(_.contains("is not supported in the query:"))
+ assert(logs.size === 2)
+ logs.foreach(log => assert(log.contains("no equi-join keys")))
+ }
+
+ test("SPARK-36652: AQE dynamic join selection should not apply to non-equi join") {
+ val hintAppender = new LogAppender(s"join hint check for equi-join")
+ withLogAppender(hintAppender, level = Some(Level.WARN)) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "64MB") {
+ df1.join(df2.repartition($"b1"), $"a1" =!= $"b1").collect()
+ }
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ .filter(_.contains("is not supported in the query: no equi-join keys"))
+ assert(logs.isEmpty)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index abfc19ac6d..4a8421a221 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -183,7 +183,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
test("inner join where, one match per row") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
checkAnswer(
- upperCaseData.join(lowerCaseData).where('n === 'N),
+ upperCaseData.join(lowerCaseData).where(Symbol("n") === 'N),
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
@@ -404,8 +404,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
test("full outer join") {
withTempView("`left`", "`right`") {
- upperCaseData.where('N <= 4).createOrReplaceTempView("`left`")
- upperCaseData.where('N >= 3).createOrReplaceTempView("`right`")
+ upperCaseData.where(Symbol("N") <= 4).createOrReplaceTempView("`left`")
+ upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`")
val left = UnresolvedRelation(TableIdentifier("left"))
val right = UnresolvedRelation(TableIdentifier("right"))
@@ -623,7 +623,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
testData.createOrReplaceTempView("B")
testData2.createOrReplaceTempView("C")
testData3.createOrReplaceTempView("D")
- upperCaseData.where('N >= 3).createOrReplaceTempView("`right`")
+ upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`")
val cartesianQueries = Seq(
/** The following should error out since there is no explicit cross join */
"SELECT * FROM testData inner join testData2",
@@ -1074,8 +1074,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val df = left.crossJoin(right).where(pythonTestUDF(left("a")) === right.col("c"))
// Before optimization, there is a logical Filter operator.
- val filterInAnalysis = df.queryExecution.analyzed.find(_.isInstanceOf[Filter])
- assert(filterInAnalysis.isDefined)
+ val filterInAnalysis = df.queryExecution.analyzed.exists(_.isInstanceOf[Filter])
+ assert(filterInAnalysis)
// Filter predicate was pushdown as join condition. So there is no Filter exec operator.
val filterExec = find(df.queryExecution.executedPlan)(_.isInstanceOf[FilterExec])
@@ -1402,4 +1402,42 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
assertJoin(sql, classOf[ShuffledHashJoinExec])
}
}
+
+ test("SPARK-36794: Ignore duplicated key when building relation for semi/anti hash join") {
+ withTable("t1", "t2") {
+ spark.range(10).map(i => (i.toString, i + 1)).toDF("c1", "c2").write.saveAsTable("t1")
+ spark.range(10).map(i => ((i % 5).toString, i % 3)).toDF("c1", "c2").write.saveAsTable("t2")
+
+ val semiJoinQueries = Seq(
+ // No join condition, ignore duplicated key.
+ (s"SELECT /*+ SHUFFLE_HASH(t2) */ t1.c1 FROM t1 LEFT SEMI JOIN t2 ON t1.c1 = t2.c1",
+ true),
+ // Have join condition on build join key only, ignore duplicated key.
+ (s"""
+ |SELECT /*+ SHUFFLE_HASH(t2) */ t1.c1 FROM t1 LEFT SEMI JOIN t2
+ |ON t1.c1 = t2.c1 AND CAST(t1.c2 * 2 AS STRING) != t2.c1
+ """.stripMargin,
+ true),
+ // Have join condition on another build-side attribute besides the join key, do not ignore
+ // duplicated key.
+ (s"""
+ |SELECT /*+ SHUFFLE_HASH(t2) */ t1.c1 FROM t1 LEFT SEMI JOIN t2
+ |ON t1.c1 = t2.c1 AND t1.c2 * 100 != t2.c2
+ """.stripMargin,
+ false)
+ )
+ semiJoinQueries.foreach {
+ case (query, ignoreDuplicatedKey) =>
+ val semiJoinDF = sql(query)
+ val antiJoinDF = sql(query.replaceAll("SEMI", "ANTI"))
+ checkAnswer(semiJoinDF, Seq(Row("0"), Row("1"), Row("2"), Row("3"), Row("4")))
+ checkAnswer(antiJoinDF, Seq(Row("5"), Row("6"), Row("7"), Row("8"), Row("9")))
+ Seq(semiJoinDF, antiJoinDF).foreach { df =>
+ assert(collect(df.queryExecution.executedPlan) {
+ case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true
+ }.size == 1)
+ }
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 82cca2b737..1c6bbc5a09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -390,16 +390,16 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
test("SPARK-24027: from_json of a map with unsupported key type") {
val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType)
-
- checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)),
- Row(null))
- checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)),
- Row(null))
+ val startMsg = "cannot resolve 'entries' due to data type mismatch:"
+ val exception = intercept[AnalysisException] {
+ Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema))
+ }.getMessage
+ assert(exception.contains(startMsg))
}
test("SPARK-24709: infers schemas of json strings and pass them to from_json") {
val in = Seq("""{"a": [1, 2, 3]}""").toDS()
- val out = in.select(from_json('value, schema_of_json("""{"a": [1]}""")) as "parsed")
+ val out = in.select(from_json(Symbol("value"), schema_of_json("""{"a": [1]}""")) as "parsed")
val expected = StructType(StructField(
"parsed",
StructType(StructField(
@@ -413,7 +413,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
test("infers schemas using options") {
val df = spark.range(1)
.select(schema_of_json(lit("{a:1}"), Map("allowUnquotedFieldNames" -> "true").asJava))
- checkAnswer(df, Seq(Row("STRUCT<`a`: BIGINT>")))
+ checkAnswer(df, Seq(Row("STRUCT<a: BIGINT>")))
}
test("from_json - array of primitive types") {
@@ -595,6 +595,31 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
}
}
+ test("SPARK-36069: from_json invalid json schema - check field name and field value") {
+ withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") {
+ val schema = new StructType()
+ .add("a", IntegerType)
+ .add("b", IntegerType)
+ .add("_unparsed", StringType)
+ val badRec = """{"a": "1", "b": 11}"""
+ val df = Seq(badRec, """{"a": 2, "b": 12}""").toDS()
+
+ checkAnswer(
+ df.select(from_json($"value", schema, Map("mode" -> "PERMISSIVE"))),
+ Row(Row(null, 11, badRec)) :: Row(Row(2, 12, null)) :: Nil)
+
+ val errMsg = intercept[SparkException] {
+ df.select(from_json($"value", schema, Map("mode" -> "FAILFAST"))).collect()
+ }.getMessage
+
+ assert(errMsg.contains(
+ "Malformed records are detected in record parsing. Parse Mode: FAILFAST."))
+ assert(errMsg.contains(
+ "Failed to parse field name a, field value 1, " +
+ "[VALUE_STRING] to target spark data type [IntegerType]."))
+ }
+ }
+
test("corrupt record column in the middle") {
val schema = new StructType()
.add("a", IntegerType)
@@ -668,14 +693,14 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
val input = regexp_replace(lit("""{"item_id": 1, "item_price": 0.1}"""), "item_", "")
checkAnswer(
spark.range(1).select(schema_of_json(input)),
- Seq(Row("STRUCT<`id`: BIGINT, `price`: DOUBLE>")))
+ Seq(Row("STRUCT<id: BIGINT, price: DOUBLE>")))
}
test("SPARK-31065: schema_of_json - null and empty strings as strings") {
Seq("""{"id": null}""", """{"id": ""}""").foreach { input =>
checkAnswer(
spark.range(1).select(schema_of_json(input)),
- Seq(Row("STRUCT<`id`: STRING>")))
+ Seq(Row("STRUCT<id: STRING>")))
}
}
@@ -687,7 +712,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
schema_of_json(
lit("""{"id": "a", "drop": {"drop": null}}"""),
options.asJava)),
- Seq(Row("STRUCT<`id`: STRING>")))
+ Seq(Row("STRUCT<id: STRING>")))
// Array of structs
checkAnswer(
@@ -695,7 +720,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
schema_of_json(
lit("""[{"id": "a", "drop": {"drop": null}}]"""),
options.asJava)),
- Seq(Row("ARRAY<STRUCT<`id`: STRING>>")))
+ Seq(Row("ARRAY<STRUCT<id: STRING>>")))
// Other types are not affected.
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index 1e7e8c1c5e..88071da293 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import java.nio.charset.StandardCharsets
+import java.time.Period
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{log => logarithm}
@@ -46,12 +47,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
c: Column => Column,
f: T => U): Unit = {
checkAnswer(
- doubleData.select(c('a)),
+ doubleData.select(c(Symbol("a"))),
(1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T])))
)
checkAnswer(
- doubleData.select(c('b)),
+ doubleData.select(c(Symbol("b"))),
(1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T])))
)
@@ -64,13 +65,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit =
{
checkAnswer(
- nnDoubleData.select(c('a)),
+ nnDoubleData.select(c(Symbol("a"))),
(1 to 10).map(n => Row(f(n * 0.1)))
)
if (f(-1) === StrictMath.log1p(-1)) {
checkAnswer(
- nnDoubleData.select(c('b)),
+ nnDoubleData.select(c(Symbol("b"))),
(1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null)
)
}
@@ -86,12 +87,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
d: (Column, Double) => Column,
f: (Double, Double) => Double): Unit = {
checkAnswer(
- nnDoubleData.select(c('a, 'a)),
+ nnDoubleData.select(c('a, Symbol("a"))),
nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0))))
)
checkAnswer(
- nnDoubleData.select(c('a, 'b)),
+ nnDoubleData.select(c('a, Symbol("b"))),
nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1))))
)
@@ -108,7 +109,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null)
checkAnswer(
- nullDoubles.select(c('a, 'a)).orderBy('a.asc),
+ nullDoubles.select(c('a, Symbol("a"))).orderBy(Symbol("a").asc),
Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0))))
)
}
@@ -117,6 +118,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
testOneToOneMathFunction(sin, math.sin)
}
+ test("csc") {
+ testOneToOneMathFunction(csc,
+ (x: Double) => (1 / math.sin(x)) )
+ }
+
test("asin") {
testOneToOneMathFunction(asin, math.asin)
}
@@ -134,6 +140,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
testOneToOneMathFunction(cos, math.cos)
}
+ test("sec") {
+ testOneToOneMathFunction(sec,
+ (x: Double) => (1 / math.cos(x)) )
+ }
+
test("acos") {
testOneToOneMathFunction(acos, math.acos)
}
@@ -151,6 +162,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
testOneToOneMathFunction(tan, math.tan)
}
+ test("cot") {
+ testOneToOneMathFunction(cot,
+ (x: Double) => (1 / math.tan(x)) )
+ }
+
test("atan") {
testOneToOneMathFunction(atan, math.atan)
}
@@ -186,6 +202,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
test("ceil and ceiling") {
testOneToOneMathFunction(ceil, (d: Double) => math.ceil(d).toLong)
+ // testOneToOneMathFunction does not validate the resulting data type
+ assert(
+ spark.range(1).select(ceil(col("id")).alias("a")).schema ==
+ types.StructType(Seq(types.StructField("a", types.LongType))))
+ assert(
+ spark.range(1).select(ceil(col("id"), lit(0)).alias("a")).schema ==
+ types.StructType(Seq(types.StructField("a", types.DecimalType(21, 0)))))
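+ // For the LONG `id` column, ceil(col) keeps LONG, while ceil(col, scale) yields
+ // DECIMAL(21, 0), wide enough to hold any LONG value, as the asserts above show.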
checkAnswer(
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
Row(0L, 1L, 2L))
@@ -234,12 +257,19 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
test("floor") {
testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong)
+ // testOneToOneMathFunction does not validate the resulting data type
+ assert(
+ spark.range(1).select(floor(col("id")).alias("a")).schema ==
+ types.StructType(Seq(types.StructField("a", types.LongType))))
+ assert(
+ spark.range(1).select(floor(col("id"), lit(0)).alias("a")).schema ==
+ types.StructType(Seq(types.StructField("a", types.DecimalType(21, 0)))))
}
test("factorial") {
val df = (0 to 5).map(i => (i, i)).toDF("a", "b")
checkAnswer(
- df.select(factorial('a)),
+ df.select(factorial(Symbol("a"))),
Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120))
)
checkAnswer(
@@ -252,16 +282,24 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
testOneToOneMathFunction(rint, math.rint)
}
- test("round/bround") {
+ test("round/bround/ceil/floor") {
val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
checkAnswer(
- df.select(round('a), round('a, -1), round('a, -2)),
+ df.select(round(Symbol("a")), round('a, -1), round('a, -2)),
Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
)
checkAnswer(
- df.select(bround('a), bround('a, -1), bround('a, -2)),
+ df.select(bround(Symbol("a")), bround('a, -1), bround('a, -2)),
Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600))
)
+ checkAnswer(
+ df.select(ceil('a), ceil('a, lit(-1)), ceil('a, lit(-2))),
+ Seq(Row(5, 10, 100), Row(55, 60, 100), Row(555, 560, 600))
+ )
+ checkAnswer(
+ df.select(floor('a), floor('a, lit(-1)), floor('a, lit(-2))),
+ Seq(Row(5, 0, 0), Row(55, 50, 0), Row(555, 550, 500))
+ )
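+ // With a negative scale, rounding happens at the 10^|scale| boundary: ceil rounds up
+ // (e.g. ceil(5, -2) = 100) and floor rounds down (e.g. floor(55, -1) = 50), matching
+ // the rows asserted above.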
withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
val pi = "3.1415"
@@ -277,6 +315,18 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
)
+ checkAnswer(
+ sql(s"SELECT ceil($pi), ceil($pi, -3), ceil($pi, -2), ceil($pi, -1), " +
+ s"ceil($pi, 0), ceil($pi, 1), ceil($pi, 2), ceil($pi, 3)"),
+ Seq(Row(BigDecimal(4), BigDecimal("1E3"), BigDecimal("1E2"), BigDecimal("1E1"),
+ BigDecimal(4), BigDecimal("3.2"), BigDecimal("3.15"), BigDecimal("3.142")))
+ )
+ checkAnswer(
+ sql(s"SELECT floor($pi), floor($pi, -3), floor($pi, -2), floor($pi, -1), " +
+ s"floor($pi, 0), floor($pi, 1), floor($pi, 2), floor($pi, 3)"),
+ Seq(Row(BigDecimal(3), BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"),
+ BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.141")))
+ )
}
val bdPi: BigDecimal = BigDecimal(31415925L, 7)
@@ -291,21 +341,46 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"),
Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null))
)
+ checkAnswer(
+ sql(s"SELECT ceil($bdPi, 7), ceil($bdPi, 8), ceil($bdPi, 9), ceil($bdPi, 10), " +
+ s"ceil($bdPi, 100), ceil($bdPi, 6), ceil(null, 8)"),
+ Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null))
+ )
+
+ checkAnswer(
+ sql(s"SELECT floor($bdPi, 7), floor($bdPi, 8), floor($bdPi, 9), floor($bdPi, 10), " +
+ s"floor($bdPi, 100), floor($bdPi, 6), floor(null, 8)"),
+ Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null))
+ )
}
- test("round/bround with data frame from a local Seq of Product") {
+ test("round/bround/ceil/floor with data frame from a local Seq of Product") {
val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value")
checkAnswer(
- df.withColumn("value_rounded", round('value)),
+ df.withColumn("value_rounded", round(Symbol("value"))),
Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
)
checkAnswer(
- df.withColumn("value_brounded", bround('value)),
+ df.withColumn("value_brounded", bround(Symbol("value"))),
Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
)
+ checkAnswer(
+ df
+ .withColumn("value_ceil", ceil('value))
+ .withColumn("value_ceil1", ceil('value, lit(0)))
+ .withColumn("value_ceil2", ceil('value, lit(1))),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6"), BigDecimal("6"), BigDecimal("5.9")))
+ )
+ checkAnswer(
+ df
+ .withColumn("value_floor", floor('value))
+ .withColumn("value_floor1", floor('value, lit(0)))
+ .withColumn("value_floor2", floor('value, lit(1))),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5"), BigDecimal("5"), BigDecimal("5.9")))
+ )
}
- test("round/bround with table columns") {
+ test("round/bround/ceil/floor with table columns") {
withTable("t") {
Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t")
checkAnswer(
@@ -314,6 +389,24 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(
sql("select i, bround(i) from t"),
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
+ checkAnswer(
+ sql("select i, ceil(i) from t"),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
+ checkAnswer(
+ sql("select i, ceil(i, 0) from t"),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
+ checkAnswer(
+ sql("select i, ceil(i, 1) from t"),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5.9"))))
+ checkAnswer(
+ sql("select i, floor(i) from t"),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5"))))
+ checkAnswer(
+ sql("select i, floor(i, 0) from t"),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5"))))
+ checkAnswer(
+ sql("select i, floor(i, 1) from t"),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5.9"))))
}
}
@@ -344,10 +437,10 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
test("hex") {
val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d")
- checkAnswer(data.select(hex('a)), Seq(Row("1C")))
- checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4")))
- checkAnswer(data.select(hex('c)), Seq(Row("177828FED4")))
- checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F")))
+ checkAnswer(data.select(hex(Symbol("a"))), Seq(Row("1C")))
+ checkAnswer(data.select(hex(Symbol("b"))), Seq(Row("FFFFFFFFFFFFFFE4")))
+ checkAnswer(data.select(hex(Symbol("c"))), Seq(Row("177828FED4")))
+ checkAnswer(data.select(hex(Symbol("d"))), Seq(Row("68656C6C6F")))
checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C")))
checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4")))
checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4")))
@@ -357,8 +450,8 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
test("unhex") {
val data = Seq(("1C", "737472696E67")).toDF("a", "b")
- checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte)))
- checkAnswer(data.select(unhex('b)), Row("string".getBytes(StandardCharsets.UTF_8)))
+ checkAnswer(data.select(unhex(Symbol("a"))), Row(Array[Byte](28.toByte)))
+ checkAnswer(data.select(unhex(Symbol("b"))), Row("string".getBytes(StandardCharsets.UTF_8)))
checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte)))
checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes(StandardCharsets.UTF_8)))
checkAnswer(data.selectExpr("""unhex("##")"""), Row(null))
@@ -505,4 +598,20 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.selectExpr("positive(a)"), Row(1))
checkAnswer(df.selectExpr("positive(b)"), Row(-1))
}
+
+ test("SPARK-35926: Support YearMonthIntervalType in width-bucket function") {
+ Seq(
+ (Period.ofMonths(-1), Period.ofYears(0), Period.ofYears(10), 10) -> 0,
+ (Period.ofMonths(0), Period.ofYears(0), Period.ofYears(10), 10) -> 1,
+ (Period.ofMonths(13), Period.ofYears(0), Period.ofYears(10), 10) -> 2,
+ (Period.ofYears(1), Period.ofYears(0), Period.ofYears(10), 10) -> 2,
+ (Period.ofYears(1), Period.ofYears(0), Period.ofYears(1), 10) -> 11,
+ (Period.ofMonths(Int.MaxValue), Period.ofYears(0), Period.ofYears(1), 10) -> 11,
+ (Period.ofMonths(0), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 6,
+ (Period.ofMonths(-1), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 5
+ ).foreach { case ((value, start, end, num), expected) =>
+ val df = Seq((value, start, end, num)).toDF("v", "s", "e", "n")
+ checkAnswer(df.selectExpr("width_bucket(v, s, e, n)"), Row(expected))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
index b166b9b684..37ba52023d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.BinaryType
class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
import testImplicits._
@@ -49,13 +50,32 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
val df = sql("select current_user(), current_user")
checkAnswer(df, Row(user, user))
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true",
+ SQLConf.ENFORCE_RESERVED_KEYWORDS.key -> "true") {
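+ // With reserved keywords enforced, CURRENT_USER is parsed as a keyword, so the
+ // parenthesized function-call form is rejected by the parser below.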
val df = sql("select current_user")
checkAnswer(df, Row(spark.sparkContext.sparkUser))
val e = intercept[ParseException](sql("select current_user()"))
assert(e.getMessage.contains("current_user"))
}
}
+
+ test("SPARK-37591: AES functions - GCM mode") {
+ Seq(
+ ("abcdefghijklmnop", ""),
+ ("abcdefghijklmnop", "abcdefghijklmnop"),
+ ("abcdefghijklmnop12345678", "Spark"),
+ ("abcdefghijklmnop12345678ABCDEFGH", "GCM mode")
+ ).foreach { case (key, input) =>
+ val df = Seq((key, input)).toDF("key", "input")
+ val encrypted = df.selectExpr("aes_encrypt(input, key, 'GCM', 'NONE') AS enc", "input", "key")
+ assert(encrypted.schema("enc").dataType === BinaryType)
+ assert(encrypted.filter($"enc" === $"input").isEmpty)
+ val result = encrypted.selectExpr(
+ "CAST(aes_decrypt(enc, key, 'GCM', 'NONE') AS STRING) AS res", "input")
+ assert(!result.filter($"res" === $"input").isEmpty &&
+ result.filter($"res" =!= $"input").isEmpty)
+ }
+ }
}
object ReflectClass {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala
new file mode 100644
index 0000000000..823c1375de
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * End-to-end tests for percentile aggregate function.
+ */
+class PercentileQuerySuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ private val table = "percentile_test"
+
+ test("SPARK-37138, SPARK-39427: Disable Ansi Interval type in Percentile") {
+ withTempView(table) {
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+ (Period.ofMonths(200), Duration.ofSeconds(200L)),
+ (Period.ofMonths(300), Duration.ofSeconds(300L)))
+ .toDF("col1", "col2").createOrReplaceTempView(table)
+ val e = intercept[AnalysisException] {
+ spark.sql(
+ s"""SELECT
+ | CAST(percentile(col1, 0.5) AS STRING),
+ | SUM(null),
+ | CAST(percentile(col2, 0.5) AS STRING)
+ |FROM $table""".stripMargin).collect()
+ }
+ assert(e.getMessage.contains("data type mismatch"))
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
index 69a8c72153..8cbb841e7d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec, ValidateRequirements}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.tags.ExtendedSQLTest
// scalastyle:off line.size.limit
/**
@@ -47,38 +48,28 @@ import org.apache.spark.sql.internal.SQLConf
*
* To run the entire test suite:
* {{{
- * build/sbt "sql/testOnly *PlanStability[WithStats]Suite"
+ * build/sbt "sql/testOnly *PlanStability*Suite"
* }}}
*
* To run a single test file upon change:
* {{{
- * build/sbt "sql/testOnly *PlanStability[WithStats]Suite -- -z (tpcds-v1.4/q49)"
+ * build/sbt "sql/testOnly *PlanStability*Suite -- -z (tpcds-v1.4/q49)"
* }}}
*
* To re-generate golden files for entire suite, run:
* {{{
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *PlanStability[WithStats]Suite"
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *PlanStability*Suite"
+ * SPARK_GENERATE_GOLDEN_FILES=1 SPARK_ANSI_SQL_MODE=true build/sbt "sql/testOnly *PlanStability*Suite"
* }}}
*
* To re-generate golden file for a single test, run:
* {{{
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *PlanStability[WithStats]Suite -- -z (tpcds-v1.4/q49)"
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *PlanStability*Suite -- -z (tpcds-v1.4/q49)"
+ * SPARK_GENERATE_GOLDEN_FILES=1 SPARK_ANSI_SQL_MODE=true build/sbt "sql/testOnly *PlanStability*Suite -- -z (tpcds-v1.4/q49)"
* }}}
*/
// scalastyle:on line.size.limit
-trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite {
-
- private val originalMaxToStringFields = conf.maxToStringFields
-
- override def beforeAll(): Unit = {
- conf.setConf(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue)
- super.beforeAll()
- }
-
- override def afterAll(): Unit = {
- super.afterAll()
- conf.setConf(SQLConf.MAX_TO_STRING_FIELDS, originalMaxToStringFields)
- }
+trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite {
private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
@@ -89,13 +80,24 @@ trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite {
private val referenceRegex = "#\\d+".r
private val normalizeRegex = "#\\d+L?".r
+ private val planIdRegex = "plan_id=\\d+".r
private val clsName = this.getClass.getCanonicalName
def goldenFilePath: String
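+ // Queries whose plans legitimately differ under ANSI mode keep a separate "<name>.ansi"
+ // golden directory, selected in getDirForTest below.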
+ private val approvedAnsiPlans: Seq[String] = Seq(
+ "q83",
+ "q83.sf100"
+ )
+
private def getDirForTest(name: String): File = {
- new File(goldenFilePath, name)
+ val goldenFileName = if (SQLConf.get.ansiEnabled && approvedAnsiPlans.contains(name)) {
+ name + ".ansi"
+ } else {
+ name
+ }
+ new File(goldenFilePath, goldenFileName)
}
private def isApproved(
@@ -233,7 +235,15 @@ trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite {
val map = new mutable.HashMap[String, String]()
normalizeRegex.findAllMatchIn(plan).map(_.toString)
.foreach(map.getOrElseUpdate(_, (map.size + 1).toString))
- normalizeRegex.replaceAllIn(plan, regexMatch => s"#${map(regexMatch.toString)}")
+ val exprIdNormalized = normalizeRegex.replaceAllIn(
+ plan, regexMatch => s"#${map(regexMatch.toString)}")
+
+ // Normalize the plan id in Exchange nodes. See `Exchange.stringArgs`.
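+ // e.g. each distinct "plan_id=1234" is renumbered to "plan_id=1", "plan_id=2", ... in order
+ // of first appearance, mirroring the expression-id normalization of "#123" above.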
+ val planIdMap = new mutable.HashMap[String, String]()
+ planIdRegex.findAllMatchIn(exprIdNormalized).map(_.toString)
+ .foreach(planIdMap.getOrElseUpdate(_, (planIdMap.size + 1).toString))
+ planIdRegex.replaceAllIn(
+ exprIdNormalized, regexMatch => s"plan_id=${planIdMap(regexMatch.toString)}")
}
private def normalizeLocation(plan: String): String = {
@@ -262,7 +272,8 @@ trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite {
}
}
-class TPCDSV1_4_PlanStabilitySuite extends PlanStabilitySuite {
+@ExtendedSQLTest
+class TPCDSV1_4_PlanStabilitySuite extends PlanStabilitySuite with TPCDSBase {
override val goldenFilePath: String =
new File(baseResourcePath, s"approved-plans-v1_4").getAbsolutePath
@@ -273,7 +284,8 @@ class TPCDSV1_4_PlanStabilitySuite extends PlanStabilitySuite {
}
}
-class TPCDSV1_4_PlanStabilityWithStatsSuite extends PlanStabilitySuite {
+@ExtendedSQLTest
+class TPCDSV1_4_PlanStabilityWithStatsSuite extends PlanStabilitySuite with TPCDSBase {
override def injectStats: Boolean = true
override val goldenFilePath: String =
@@ -286,7 +298,8 @@ class TPCDSV1_4_PlanStabilityWithStatsSuite extends PlanStabilitySuite {
}
}
-class TPCDSV2_7_PlanStabilitySuite extends PlanStabilitySuite {
+@ExtendedSQLTest
+class TPCDSV2_7_PlanStabilitySuite extends PlanStabilitySuite with TPCDSBase {
override val goldenFilePath: String =
new File(baseResourcePath, s"approved-plans-v2_7").getAbsolutePath
@@ -297,7 +310,8 @@ class TPCDSV2_7_PlanStabilitySuite extends PlanStabilitySuite {
}
}
-class TPCDSV2_7_PlanStabilityWithStatsSuite extends PlanStabilitySuite {
+@ExtendedSQLTest
+class TPCDSV2_7_PlanStabilityWithStatsSuite extends PlanStabilitySuite with TPCDSBase {
override def injectStats: Boolean = true
override val goldenFilePath: String =
@@ -310,7 +324,8 @@ class TPCDSV2_7_PlanStabilityWithStatsSuite extends PlanStabilitySuite {
}
}
-class TPCDSModifiedPlanStabilitySuite extends PlanStabilitySuite {
+@ExtendedSQLTest
+class TPCDSModifiedPlanStabilitySuite extends PlanStabilitySuite with TPCDSBase {
override val goldenFilePath: String =
new File(baseResourcePath, s"approved-plans-modified").getAbsolutePath
@@ -321,7 +336,8 @@ class TPCDSModifiedPlanStabilitySuite extends PlanStabilitySuite {
}
}
-class TPCDSModifiedPlanStabilityWithStatsSuite extends PlanStabilitySuite {
+@ExtendedSQLTest
+class TPCDSModifiedPlanStabilityWithStatsSuite extends PlanStabilitySuite with TPCDSBase {
override def injectStats: Boolean = true
override val goldenFilePath: String =
@@ -333,3 +349,15 @@ class TPCDSModifiedPlanStabilityWithStatsSuite extends PlanStabilitySuite {
}
}
}
+
+@ExtendedSQLTest
+class TPCHPlanStabilitySuite extends PlanStabilitySuite with TPCHBase {
+ override def goldenFilePath: String = getWorkspaceFilePath(
+ "sql", "core", "src", "test", "resources", "tpch-plan-stability").toFile.getAbsolutePath
+
+ tpchQueries.foreach { q =>
+ test(s"check simplified (tpch/$q)") {
+ testQuery("tpch", q)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 8469216901..06f94c62d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -207,11 +207,12 @@ abstract class QueryTest extends PlanTest {
*/
def assertCached(query: Dataset[_], cachedName: String, storageLevel: StorageLevel): Unit = {
val planWithCaching = query.queryExecution.withCachedData
- val matched = planWithCaching.collectFirst { case cached: InMemoryRelation =>
- val cacheBuilder = cached.asInstanceOf[InMemoryRelation].cacheBuilder
- cachedName == cacheBuilder.tableName.get &&
- (storageLevel == cacheBuilder.storageLevel)
- }.getOrElse(false)
+ val matched = planWithCaching.exists {
+ case cached: InMemoryRelation =>
+ val cacheBuilder = cached.cacheBuilder
+ cachedName == cacheBuilder.tableName.get && (storageLevel == cacheBuilder.storageLevel)
+ case _ => false
+ }
assert(matched, s"Expected query plan to hit cache $cachedName with storage " +
s"level $storageLevel, but it doesn't.")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala
index 739b4052ee..8883e9be19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala
@@ -59,13 +59,13 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared
val q5 = df1.selectExpr("IF(l > 1 AND null, 5, 1) AS out")
checkAnswer(q5, Row(1) :: Row(1) :: Nil)
q5.queryExecution.executedPlan.foreach { p =>
- assert(p.expressions.forall(e => e.find(_.isInstanceOf[If]).isEmpty))
+ assert(p.expressions.forall(e => !e.exists(_.isInstanceOf[If])))
}
val q6 = df1.selectExpr("CASE WHEN (l > 2 AND null) THEN 3 ELSE 2 END")
checkAnswer(q6, Row(2) :: Row(2) :: Nil)
q6.queryExecution.executedPlan.foreach { p =>
- assert(p.expressions.forall(e => e.find(_.isInstanceOf[CaseWhen]).isEmpty))
+ assert(p.expressions.forall(e => !e.exists(_.isInstanceOf[CaseWhen])))
}
checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true))
@@ -75,10 +75,10 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared
test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") {
def assertNoLiteralNullInPlan(df: DataFrame): Unit = {
df.queryExecution.executedPlan.foreach { p =>
- assert(p.expressions.forall(_.find {
+ assert(p.expressions.forall(!_.exists {
case Literal(null, BooleanType) => true
case _ => false
- }.isEmpty))
+ }))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
index 2f56fbaf7f..f30465203d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
@@ -264,7 +264,7 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
val e = intercept[AnalysisException] {
sql("INSERT OVERWRITE t PARTITION (c='2', C='3') VALUES (1)")
}
- assert(e.getMessage.contains("Found duplicate keys 'c'"))
+ assert(e.getMessage.contains("Found duplicate keys `c`"))
}
// The following code is skipped for Hive because columns stored in Hive Metastore is always
// case insensitive and we cannot create such table in Hive Metastore.
@@ -286,20 +286,44 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
} else {
SQLConf.StoreAssignmentPolicy.values
}
+
+ def shouldThrowException(policy: SQLConf.StoreAssignmentPolicy.Value): Boolean = policy match {
+ case SQLConf.StoreAssignmentPolicy.ANSI | SQLConf.StoreAssignmentPolicy.STRICT =>
+ true
+ case SQLConf.StoreAssignmentPolicy.LEGACY =>
+ false
+ }
+
testingPolicies.foreach { policy =>
- withSQLConf(
- SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) {
+ withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) {
withTable("t") {
sql("create table t(a int, b string) using parquet partitioned by (a)")
- policy match {
- case SQLConf.StoreAssignmentPolicy.ANSI | SQLConf.StoreAssignmentPolicy.STRICT =>
- val errorMsg = intercept[NumberFormatException] {
- sql("insert into t partition(a='ansi') values('ansi')")
- }.getMessage
- assert(errorMsg.contains("invalid input syntax for type numeric: ansi"))
- case SQLConf.StoreAssignmentPolicy.LEGACY =>
+ if (shouldThrowException(policy)) {
+ val errorMsg = intercept[NumberFormatException] {
sql("insert into t partition(a='ansi') values('ansi')")
- checkAnswer(sql("select * from t"), Row("ansi", null) :: Nil)
+ }.getMessage
+ assert(errorMsg.contains(
+ """The value 'ansi' of the type "STRING" cannot be cast to "INT""""))
+ } else {
+ sql("insert into t partition(a='ansi') values('ansi')")
+ checkAnswer(sql("select * from t"), Row("ansi", null) :: Nil)
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-38228: legacy store assignment should not fail on error under ANSI mode") {
+ // DS v2 doesn't support the legacy policy
+ if (format != "foo") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(
+ SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString,
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ withTable("t") {
+ sql("create table t(a int) using parquet")
+ sql("insert into t values('ansi')")
+ checkAnswer(spark.table("t"), Row(null))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1515abb052..66f9700e8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,34 +24,40 @@ import java.time.{Duration, Period}
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
+import scala.collection.mutable
+
import org.apache.commons.io.FileUtils
import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
-import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Hex}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite}
import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, RepartitionByExpression, Sort}
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.{CommandResultExec, UnionExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.command.{DataWritingCommandExec, FunctionsCommand}
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.tags.ExtendedSQLTest
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.ResetSystemProperties
+@ExtendedSQLTest
class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper
with ResetSystemProperties {
import testImplicits._
@@ -65,8 +71,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ")
val queryCoalesce = sql("select coalesce(null, 1, '1') from src ")
- checkAnswer(queryCaseWhen, Row("1.0") :: Nil)
- checkAnswer(queryCoalesce, Row("1") :: Nil)
+ if (!conf.ansiEnabled) {
+ checkAnswer(queryCaseWhen, Row("1.0") :: Nil)
+ checkAnswer(queryCoalesce, Row("1") :: Nil)
+ }
}
}
@@ -74,7 +82,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
def getFunctions(pattern: String): Seq[Row] = {
StringUtils.filterPattern(
spark.sessionState.catalog.listFunctions("default").map(_._1.funcName)
- ++ FunctionsCommand.virtualOperators, pattern)
+ ++ FunctionRegistry.builtinOperators.keys, pattern)
.map(Row(_))
}
@@ -123,7 +131,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkKeywordsNotExist(sql("describe functioN Upper"), "Extended Usage")
- checkKeywordsExist(sql("describe functioN abcadf"), "Function: abcadf not found.")
+ val e = intercept[AnalysisException](sql("describe functioN abcadf"))
+ assert(e.message.contains("Undefined function: abcadf. This function is neither a " +
+ "built-in/temporary function, nor a persistent function"))
}
test("SPARK-34678: describe functions for table-valued functions") {
@@ -387,10 +397,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
testCodeGen(
"SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
Row(100, 1, 50.5, 300, 100) :: Nil)
- // Aggregate with Code generation handling all null values
- testCodeGen(
- "SELECT sum('a'), avg('a'), count(null) FROM testData",
- Row(null, null, 0) :: Nil)
+ // Aggregate with Code generation handling all null values.
+ // If ANSI mode is on, there will be an error since 'a' cannot be converted to a numeric type.
+ // Here we simply test it when ANSI mode is off.
+ if (!conf.ansiEnabled) {
+ testCodeGen(
+ "SELECT sum('a'), avg('a'), count(null) FROM testData",
+ Row(null, null, 0) :: Nil)
+ }
} finally {
spark.catalog.dropTempView("testData3x")
}
@@ -482,9 +496,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
Seq(Row(Timestamp.valueOf("1969-12-31 16:00:00.001")),
Row(Timestamp.valueOf("1969-12-31 16:00:00.002"))))
- checkAnswer(sql(
- "SELECT time FROM timestamps WHERE time='123'"),
- Nil)
+ if (!conf.ansiEnabled) {
+ checkAnswer(sql(
+ "SELECT time FROM timestamps WHERE time='123'"),
+ Nil)
+ }
}
}
@@ -579,6 +595,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
|select * from q1 union all select * from q2""".stripMargin),
Row(5, "5") :: Row(4, "4") :: Nil)
+ // inner CTE relation refers to outer CTE relation.
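+ // Under the CORRECTED policy the inner temp1 resolves against the outer temp1, so the
+ // nested query evaluates to (1 + 1) + 1 = 3, matching the expected Row(3) below.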
+ withSQLConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY.key -> "CORRECTED") {
+ checkAnswer(
+ sql(
+ """
+ |with temp1 as (select 1 col),
+ |temp2 as (
+ | with temp1 as (select col + 1 AS col from temp1),
+ | temp3 as (select col + 1 from temp1)
+ | select * from temp3
+ |)
+ |select * from temp2
+ |""".stripMargin),
+ Row(3))
+ }
}
test("Allow only a single WITH clause per query") {
@@ -933,9 +964,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") ::
Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil)
// Column type mismatches are not allowed, forcing a type coercion.
- checkAnswer(
- sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"),
- ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_)))
+ // When ANSI mode is on, the String input will be cast to Int in the following Union, which will
+ // cause a runtime error. Here we simply test the case when ANSI mode is off.
+ if (!conf.ansiEnabled) {
+ checkAnswer(
+ sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"),
+ ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_)))
+ }
// Column type mismatches where a coercion is not possible, in this case between integer
// and array types, trigger a TreeNodeException.
intercept[AnalysisException] {
@@ -1032,32 +1067,35 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
Row(Row(3, true), Map("C3" -> null)) ::
Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil)
- checkAnswer(
- sql("SELECT f1.f11, f2['D4'] FROM applySchema2"),
- Row(1, null) ::
- Row(2, null) ::
- Row(3, null) ::
- Row(4, 2147483644) :: Nil)
-
- // The value of a MapType column can be a mutable map.
- val rowRDD3 = unparsedStrings.map { r =>
- val values = r.split(",").map(_.trim)
- val v4 = try values(3).toInt catch {
- case _: NumberFormatException => null
+ // If ANSI mode is on, there will be an error "Key D4 does not exist".
+ if (!conf.ansiEnabled) {
+ checkAnswer(
+ sql("SELECT f1.f11, f2['D4'] FROM applySchema2"),
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, 2147483644) :: Nil)
+
+ // The value of a MapType column can be a mutable map.
+ val rowRDD3 = unparsedStrings.map { r =>
+ val values = r.split(",").map(_.trim)
+ val v4 = try values(3).toInt catch {
+ case _: NumberFormatException => null
+ }
+ Row(Row(values(0).toInt, values(2).toBoolean),
+ scala.collection.mutable.Map(values(1) -> v4))
}
- Row(Row(values(0).toInt, values(2).toBoolean),
- scala.collection.mutable.Map(values(1) -> v4))
- }
- val df3 = spark.createDataFrame(rowRDD3, schema2)
- df3.createOrReplaceTempView("applySchema3")
+ val df3 = spark.createDataFrame(rowRDD3, schema2)
+ df3.createOrReplaceTempView("applySchema3")
- checkAnswer(
- sql("SELECT f1.f11, f2['D4'] FROM applySchema3"),
- Row(1, null) ::
- Row(2, null) ::
- Row(3, null) ::
- Row(4, 2147483644) :: Nil)
+ checkAnswer(
+ sql("SELECT f1.f11, f2['D4'] FROM applySchema3"),
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, 2147483644) :: Nil)
+ }
}
}
@@ -1098,7 +1136,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
|order by struct.a, struct.b
|""".stripMargin)
}
- assert(error.message contains "cannot resolve 'struct.a' given input columns: [a, b]")
+ assert(error.getErrorClass == "MISSING_COLUMN")
+ assert(error.messageParameters.sameElements(Array("struct.a", "a, b")))
}
@@ -1396,22 +1435,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-7952: fix the equality check between boolean and numeric types") {
- withTempView("t") {
- // numeric field i, boolean field j, result of i = j, result of i <=> j
- Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
- (1, true, true, true),
- (0, false, true, true),
- (2, true, false, false),
- (2, false, false, false),
- (null, true, null, false),
- (null, false, null, false),
- (0, null, null, false),
- (1, null, null, false),
- (null, null, null, true)
- ).toDF("i", "b", "r1", "r2").createOrReplaceTempView("t")
-
- checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
- checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
+ // If ANSI mode is on, Spark disallows comparing Int with Boolean.
+ if (!conf.ansiEnabled) {
+ withTempView("t") {
+ // numeric field i, boolean field j, result of i = j, result of i <=> j
+ Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
+ (1, true, true, true),
+ (0, false, true, true),
+ (2, true, false, false),
+ (2, false, false, false),
+ (null, true, null, false),
+ (null, false, null, false),
+ (0, null, null, false),
+ (1, null, null, false),
+ (null, null, null, true)
+ ).toDF("i", "b", "r1", "r2").createOrReplaceTempView("t")
+
+ checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
+ checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
+ }
}
}
@@ -1445,32 +1487,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
val ymDF = sql("select interval 3 years -3 month")
checkAnswer(ymDF, Row(Period.of(2, 9, 0)))
- withTempPath(f => {
- val e = intercept[AnalysisException] {
- ymDF.write.json(f.getCanonicalPath)
- }
- e.message.contains("Cannot save interval data type into external storage")
- })
val dtDF = sql("select interval 5 days 8 hours 12 minutes 50 seconds")
checkAnswer(dtDF, Row(Duration.ofDays(5).plusHours(8).plusMinutes(12).plusSeconds(50)))
- withTempPath(f => {
- val e = intercept[AnalysisException] {
- dtDF.write.json(f.getCanonicalPath)
- }
- e.message.contains("Cannot save interval data type into external storage")
- })
withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") {
val df = sql("select interval 3 years -3 month 7 week 123 microseconds")
checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7 * 7, 123)))
- withTempPath(f => {
- // Currently we don't yet support saving out values of interval data type.
- val e = intercept[AnalysisException] {
- df.write.json(f.getCanonicalPath)
- }
- e.message.contains("Cannot save interval data type into external storage")
- })
}
}
@@ -2689,10 +2712,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
- val m1 = intercept[AnalysisException] {
- sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))")
- }.message
- assert(m1.contains("Union can only be performed on tables with the compatible column types"))
+ // Union resolves nested columns by position too.
+ checkAnswer(sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))"),
+ Row(Row(1)) :: Row(Row(2)) :: Nil)
val m2 = intercept[AnalysisException] {
sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))")
@@ -2720,8 +2742,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(sql("SELECT i from (SELECT i FROM v)"), Row(1))
val e = intercept[AnalysisException](sql("SELECT v.i from (SELECT i FROM v)"))
- assert(e.message ==
- "cannot resolve 'v.i' given input columns: [__auto_generated_subquery_name.i]")
+ assert(e.getErrorClass == "MISSING_COLUMN")
+ assert(e.messageParameters.sameElements(Array("v.i", "__auto_generated_subquery_name.i")))
checkAnswer(sql("SELECT __auto_generated_subquery_name.i from (SELECT i FROM v)"), Row(1))
}
@@ -2809,15 +2831,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("Non-deterministic aggregate functions should not be deduplicated") {
- val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a"
- val df = sql(query)
- val physical = df.queryExecution.sparkPlan
- val aggregateExpressions = physical.collectFirst {
- case agg : HashAggregateExec => agg.aggregateExpressions
- case agg : SortAggregateExec => agg.aggregateExpressions
+ withUserDefinedFunction("sumND" -> true) {
+ spark.udf.register("sumND", udaf(new Aggregator[Long, Long, Long] {
+ def zero: Long = 0L
+ def reduce(b: Long, a: Long): Long = b + a
+ def merge(b1: Long, b2: Long): Long = b1 + b2
+ def finish(r: Long): Long = r
+ def bufferEncoder: Encoder[Long] = Encoders.scalaLong
+ def outputEncoder: Encoder[Long] = Encoders.scalaLong
+ }).asNondeterministic())
+
+ val query = "SELECT a, sumND(b), sumND(b) + 1 FROM testData2 GROUP BY a"
+ val df = sql(query)
+ val physical = df.queryExecution.sparkPlan
+ val aggregateExpressions = physical.collectFirst {
+ case agg: BaseAggregateExec => agg.aggregateExpressions
+ }
+ assert(aggregateExpressions.isDefined)
+ assert(aggregateExpressions.get.size == 2)
}
- assert (aggregateExpressions.isDefined)
- assert (aggregateExpressions.get.size == 2)
}
test("SPARK-22356: overlapped columns between data and partition schema in data source tables") {
@@ -3051,15 +3083,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
val df = spark.read.format(format).load(dir.getCanonicalPath)
checkPushedFilters(
format,
- df.where(('id < 2 and 's.contains("foo")) or ('id > 10 and 's.contains("bar"))),
+ df.where((Symbol("id") < 2 and Symbol("s").contains("foo")) or
+ (Symbol("id") > 10 and Symbol("s").contains("bar"))),
Array(sources.Or(sources.LessThan("id", 2), sources.GreaterThan("id", 10))))
checkPushedFilters(
format,
- df.where('s.contains("foo") or ('id > 10 and 's.contains("bar"))),
+ df.where(Symbol("s").contains("foo") or
+ (Symbol("id") > 10 and Symbol("s").contains("bar"))),
Array.empty)
checkPushedFilters(
format,
- df.where('id < 2 and not('id > 10 and 's.contains("bar"))),
+ df.where(Symbol("id") < 2 and not(Symbol("id") > 10 and Symbol("s").contains("bar"))),
Array(sources.IsNotNull("id"), sources.LessThan("id", 2)))
}
}
@@ -3140,16 +3174,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(sql("select * from t1 where d >= '2000-01-01'"), Row(result))
checkAnswer(sql("select * from t1 where d >= '2000-01-02'"), Nil)
checkAnswer(sql("select * from t1 where '2000' >= d"), Row(result))
- checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil)
+ if (!conf.ansiEnabled) {
+ checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil)
+ }
withSQLConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING.key -> "true") {
checkAnswer(sql("select * from t1 where d < '2000'"), Nil)
checkAnswer(sql("select * from t1 where d < '2001'"), Row(result))
- checkAnswer(sql("select * from t1 where d < '2000-1-1'"), Row(result))
checkAnswer(sql("select * from t1 where d <= '1999'"), Nil)
checkAnswer(sql("select * from t1 where d >= '2000'"), Row(result))
- checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result))
- checkAnswer(sql("select to_date('2000-01-01') > '1'"), Row(true))
+ if (!conf.ansiEnabled) {
+ checkAnswer(sql("select * from t1 where d < '2000-1-1'"), Row(result))
+ checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result))
+ checkAnswer(sql("select to_date('2000-01-01') > '1'"), Row(true))
+ }
}
}
}
@@ -3182,17 +3220,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(sql("select * from t1 where d >= '2000-01-01 01:10:00.000'"), Row(result))
checkAnswer(sql("select * from t1 where d >= '2000-01-02 01:10:00.000'"), Nil)
checkAnswer(sql("select * from t1 where '2000' >= d"), Nil)
- checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil)
+ if (!conf.ansiEnabled) {
+ checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil)
+ }
withSQLConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING.key -> "true") {
checkAnswer(sql("select * from t1 where d < '2000'"), Nil)
checkAnswer(sql("select * from t1 where d < '2001'"), Row(result))
- checkAnswer(sql("select * from t1 where d <= '2000-1-1'"), Row(result))
checkAnswer(sql("select * from t1 where d <= '2000-01-02'"), Row(result))
checkAnswer(sql("select * from t1 where d <= '1999'"), Nil)
checkAnswer(sql("select * from t1 where d >= '2000'"), Row(result))
- checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result))
- checkAnswer(sql("select to_timestamp('2000-01-01 01:10:00') > '1'"), Row(true))
+ if (!conf.ansiEnabled) {
+ checkAnswer(sql("select * from t1 where d <= '2000-1-1'"), Row(result))
+ checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result))
+ checkAnswer(sql("select to_timestamp('2000-01-01 01:10:00') > '1'"), Row(true))
+ }
}
sql("DROP VIEW t1")
}
@@ -3257,28 +3299,31 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("SPARK-29213: FilterExec should not throw NPE") {
- withTempView("t1", "t2", "t3") {
- sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")
- sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)")
- .as[java.lang.Long]
- .map(identity)
- .toDF("x")
- .createOrReplaceTempView("t2")
- sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3")
- sql(
- """
- |SELECT t1.x
- |FROM t1
- |LEFT JOIN (
- | SELECT x FROM (
- | SELECT x FROM t2
- | UNION ALL
- | SELECT SUBSTR(x,5) x FROM t3
- | ) a
- | WHERE LENGTH(x)>0
- |) t3
- |ON t1.x=t3.x
+ // Under ANSI mode, casting string '' as numeric will cause runtime error
+ if (!conf.ansiEnabled) {
+ withTempView("t1", "t2", "t3") {
+ sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")
+ sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)")
+ .as[java.lang.Long]
+ .map(identity)
+ .toDF("x")
+ .createOrReplaceTempView("t2")
+ sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3")
+ sql(
+ """
+ |SELECT t1.x
+ |FROM t1
+ |LEFT JOIN (
+ | SELECT x FROM (
+ | SELECT x FROM t2
+ | UNION ALL
+ | SELECT SUBSTR(x,5) x FROM t3
+ | ) a
+ | WHERE LENGTH(x)>0
+ |) t3
+ |ON t1.x=t3.x
""".stripMargin).collect()
+ }
}
}
@@ -3298,7 +3343,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
sql("CREATE TEMPORARY VIEW tc AS SELECT * FROM VALUES(CAST(1 AS DOUBLE)) AS tc(id)")
sql("CREATE TEMPORARY VIEW td AS SELECT * FROM VALUES(CAST(1 AS FLOAT)) AS td(id)")
sql("CREATE TEMPORARY VIEW te AS SELECT * FROM VALUES(CAST(1 AS BIGINT)) AS te(id)")
- sql("CREATE TEMPORARY VIEW tf AS SELECT * FROM VALUES(CAST(1 AS DECIMAL(38, 38))) AS tf(id)")
val df1 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tb)")
checkAnswer(df1, Row(new java.math.BigDecimal(1)))
val df2 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tc)")
@@ -3307,8 +3351,12 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(df3, Row(new java.math.BigDecimal(1)))
val df4 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM te)")
checkAnswer(df4, Row(new java.math.BigDecimal(1)))
- val df5 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tf)")
- checkAnswer(df5, Array.empty[Row])
+ if (!conf.ansiEnabled) {
+ sql(
+ "CREATE TEMPORARY VIEW tf AS SELECT * FROM VALUES(CAST(1 AS DECIMAL(38, 38))) AS tf(id)")
+ val df5 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tf)")
+ checkAnswer(df5, Array.empty[Row])
+ }
}
}
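The hunks above all wrap assertions in if (!conf.ansiEnabled), because the default-mode behavior (an invalid cast quietly becoming null) does not hold under ANSI mode. A hedged sketch of the date case specifically, assuming a local SparkSession; the inline VALUES relation is illustrative only:

import org.apache.spark.sql.SparkSession

object AnsiDateCompareSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", "false")
    // '2000-13' is not a valid date; with ANSI off the cast yields null, the predicate
    // evaluates to null, and the query returns no rows instead of failing.
    spark.sql("SELECT * FROM VALUES (DATE'2000-01-01') AS t(d) WHERE d > '2000-13'").show()
    // With spark.sql.ansi.enabled=true the same cast raises a runtime error,
    // which is why the tests above skip these assertions under ANSI mode.
    spark.stop()
  }
}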
@@ -4037,6 +4085,31 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
+ test("SPARK-40247: Fix BitSet equals") {
+ withTable("td") {
+ testData
+ .withColumn("bucket", $"key" % 3)
+ .write
+ .mode(SaveMode.Overwrite)
+ .bucketBy(2, "bucket")
+ .format("parquet")
+ .saveAsTable("td")
+ val df = sql(
+ """
+ |SELECT t1.key, t2.key, t3.key
+ |FROM td AS t1
+ |JOIN td AS t2 ON t2.key = t1.key
+ |JOIN td AS t3 ON t3.key = t2.key
+ |WHERE t1.bucket = 1 AND t2.bucket = 1 AND t3.bucket = 1
+ |""".stripMargin)
+ df.collect()
+ val reusedExchanges = collect(df.queryExecution.executedPlan) {
+ case r: ReusedExchangeExec => r
+ }
+ assert(reusedExchanges.size == 1)
+ }
+ }
+
test("SPARK-35331: Fix resolving original expression in RepartitionByExpression after aliased") {
Seq("CLUSTER", "DISTRIBUTE").foreach { keyword =>
Seq("a", "substr(a, 0, 3)").foreach { expr =>
@@ -4215,12 +4288,329 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
+ test("SPARK-36371: Support raw string literal") {
+ checkAnswer(sql("""SELECT r'a\tb\nc'"""), Row("""a\tb\nc"""))
+ checkAnswer(sql("""SELECT R'a\tb\nc'"""), Row("""a\tb\nc"""))
+ checkAnswer(sql("""SELECT r"a\tb\nc""""), Row("""a\tb\nc"""))
+ checkAnswer(sql("""SELECT R"a\tb\nc""""), Row("""a\tb\nc"""))
+ checkAnswer(sql("""SELECT from_json(r'{"a": "\\"}', 'a string')"""), Row(Row("\\")))
+ checkAnswer(sql("""SELECT from_json(R'{"a": "\\"}', 'a string')"""), Row(Row("\\")))
+ }
+
test("SPARK-36979: Add RewriteLateralSubquery rule into nonExcludableRules") {
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
"org.apache.spark.sql.catalyst.optimizer.RewriteLateralSubquery") {
sql("SELECT * FROM testData, LATERAL (SELECT * FROM testData)").collect()
}
}
+
+ test("TABLE SAMPLE") {
+ withTable("test") {
+ sql("CREATE TABLE test(c int) USING PARQUET")
+ for (i <- 0 to 20) {
+ sql(s"INSERT INTO test VALUES ($i)")
+ }
+ val df1 = sql("SELECT * FROM test TABLESAMPLE (20 PERCENT) REPEATABLE (12345)")
+ val df2 = sql("SELECT * FROM test TABLESAMPLE (20 PERCENT) REPEATABLE (12345)")
+ checkAnswer(df1, df2)
+
+ val df3 = sql("SELECT * FROM test TABLESAMPLE (BUCKET 4 OUT OF 10) REPEATABLE (6789)")
+ val df4 = sql("SELECT * FROM test TABLESAMPLE (BUCKET 4 OUT OF 10) REPEATABLE (6789)")
+ checkAnswer(df3, df4)
+ }
+ }
+
+ test("SPARK-27442: Spark support read/write parquet file with invalid char in field name") {
+ withTempDir { dir =>
+ Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (2, 4, 6, 8, 10, 12, 14, 16, 18, 20))
+ .toDF("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a")
+ .repartition(1)
+ .write.mode(SaveMode.Overwrite).parquet(dir.getAbsolutePath)
+ val df = spark.read.parquet(dir.getAbsolutePath)
+ checkAnswer(df,
+ Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) ::
+ Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20) :: Nil)
+ assert(df.schema.names.sameElements(
+ Array("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a")))
+ checkAnswer(df.select("`max(t)`", "`a b`", "`{`", "`.`", "`a.b`"),
+ Row(1, 6, 7, 8, 9) :: Row(2, 12, 14, 16, 18) :: Nil)
+ checkAnswer(df.where("`a.b` > 10"),
+ Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20) :: Nil)
+ }
+ }
+
+ test("SPARK-37965: Spark support read/write orc file with invalid char in field name") {
+ withTempDir { dir =>
+ Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), (2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22))
+ .toDF("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a", ",")
+ .repartition(1)
+ .write.mode(SaveMode.Overwrite).orc(dir.getAbsolutePath)
+ val df = spark.read.orc(dir.getAbsolutePath)
+ checkAnswer(df,
+ Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) ::
+ Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22) :: Nil)
+ assert(df.schema.names.sameElements(
+ Array("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a", ",")))
+ checkAnswer(df.select("`max(t)`", "`a b`", "`{`", "`.`", "`a.b`"),
+ Row(1, 6, 7, 8, 9) :: Row(2, 12, 14, 16, 18) :: Nil)
+ checkAnswer(df.where("`a.b` > 10"),
+ Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22) :: Nil)
+ }
+ }
+
+ test("SPARK-38173: Quoted column cannot be recognized correctly " +
+ "when quotedRegexColumnNames is true") {
+ withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT `(C3)?+.+`,T.`C1` * `C2` AS CC
+ |FROM (SELECT 3 AS C1,2 AS C2,1 AS C3) T
+ |""".stripMargin),
+ Row(3, 2, 6) :: Nil)
+ }
+ }
+
+ test("SPARK-38548: try_sum should return null if overflow happens before merging") {
+ val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+ val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+ .map(Period.ofMonths)
+ .toDF("v")
+ val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+ .map(Duration.ofDays)
+ .toDF("v")
+ Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
+ checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"), Row(null))
+ }
+ }
+
+ test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" +
+ " when WSCG is off") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+ sql("create table t(i int, j int) using parquet")
+ sql("insert into t values(2147483647, 10)")
+ Seq(
+ "select i + j from t",
+ "select -i - j from t",
+ "select i * j from t",
+ "select i / (j - 10) from t").foreach { query =>
+ val msg = intercept[SparkException] {
+ sql(query).collect()
+ }.getMessage
+ assert(msg.contains(query))
+ }
+ }
+ }
+ }
+
+ test("SPARK-39175: Query context of Cast should be serialized to executors" +
+ " when WSCG is off") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+ sql("create table t(s string) using parquet")
+ sql("insert into t values('a')")
+ Seq(
+ "select cast(s as int) from t",
+ "select cast(s as long) from t",
+ "select cast(s as double) from t",
+ "select cast(s as decimal(10, 2)) from t",
+ "select cast(s as date) from t",
+ "select cast(s as timestamp) from t",
+ "select cast(s as boolean) from t").foreach { query =>
+ val msg = intercept[SparkException] {
+ sql(query).collect()
+ }.getMessage
+ assert(msg.contains(query))
+ }
+ }
+ }
+ }
+
+ test("SPARK-39177: Query context of getting map value should be serialized to executors" +
+ " when WSCG is off") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+ sql("create table t(m map<string, string>) using parquet")
+ sql("insert into t values map('a', 'b')")
+ Seq(
+ "select m['foo'] from t",
+ "select element_at(m, 'foo') from t").foreach { query =>
+ val msg = intercept[SparkException] {
+ sql(query).collect()
+ }.getMessage
+ assert(msg.contains(query))
+ }
+ }
+ }
+ }
+
+ test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " +
+ "be serialized to executors when WSCG is off") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+ sql("create table t(d decimal(38, 0)) using parquet")
+ sql("insert into t values (6e37BD),(6e37BD)")
+ Seq(
+ "select d / 0.1 from t",
+ "select sum(d) from t",
+ "select avg(d) from t").foreach { query =>
+ val msg = intercept[SparkException] {
+ sql(query).collect()
+ }.getMessage
+ assert(msg.contains(query))
+ }
+ }
+ }
+ }
+
+ test("SPARK-38589: try_avg should return null if overflow happens before merging") {
+ val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+ .map(Period.ofMonths)
+ .toDF("v")
+ val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+ .map(Duration.ofDays)
+ .toDF("v")
+ Seq(yearMonthDf, dayTimeDf).foreach { df =>
+ checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_avg(v)"), Row(null))
+ }
+ }
+
+ test("SPARK-39012: SparkSQL cast partition value does not support all data types") {
+ withTempDir { dir =>
+ val binary1 = Hex.hex(UTF8String.fromString("Spark").getBytes).getBytes
+ val binary2 = Hex.hex(UTF8String.fromString("SQL").getBytes).getBytes
+ val data = Seq[(Int, Boolean, Array[Byte])](
+ (1, false, binary1),
+ (2, true, binary2)
+ )
+ data.toDF("a", "b", "c")
+ .write
+ .mode("overwrite")
+ .partitionBy("b", "c")
+ .parquet(dir.getCanonicalPath)
+ val res = spark.read
+ .schema("a INT, b BOOLEAN, c BINARY")
+ .parquet(dir.getCanonicalPath)
+ checkAnswer(res,
+ Seq(
+ Row(1, false, mutable.WrappedArray.make(binary1)),
+ Row(2, true, mutable.WrappedArray.make(binary2))
+ ))
+ }
+ }
+
+ test("SPARK-39216: Don't collapse projects in CombineUnions if it hasCorrelatedSubquery") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT (SELECT IF(x, 1, 0)) AS a
+ |FROM (SELECT true) t(x)
+ |UNION
+ |SELECT 1 AS a
+ """.stripMargin),
+ Seq(Row(1)))
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT x + 1
+ |FROM (SELECT id
+ | + (SELECT Max(id)
+ | FROM range(2)) AS x
+ | FROM range(1)) t
+ |UNION
+ |SELECT 1 AS a
+ """.stripMargin),
+ Seq(Row(2), Row(1)))
+ }
+
+ test("SPARK-39548: CreateView will make queries go into inline CTE code path thus" +
+ "trigger a mis-clarified `window definition not found` issue") {
+ sql(
+ """
+ |create or replace temporary view test_temp_view as
+ |with step_1 as (
+ |select * , min(a) over w2 as min_a_over_w2 from
+ |(select 1 as a, 2 as b, 3 as c) window w2 as (partition by b order by c)) , step_2 as
+ |(
+ |select *, max(e) over w1 as max_a_over_w1
+ |from (select 1 as e, 2 as f, 3 as g)
+ |join step_1 on true
+ |window w1 as (partition by f order by g)
+ |)
+ |select *
+ |from step_2
+ |""".stripMargin)
+
+ checkAnswer(
+ sql("select * from test_temp_view"),
+ Row(1, 2, 3, 1, 2, 3, 1, 1))
+ }
+
+ test("SPARK-40389: Don't eliminate a cast which can cause overflow") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("dt") {
+ sql("create table dt using parquet as select 9000000000BD as d")
+ val msg = intercept[SparkException] {
+ sql("select cast(cast(d as int) as long) from dt").collect()
+ }.getCause.getMessage
+ assert(msg.contains("The value 9000000000BD of the type \"DECIMAL(10,0)\" " +
+ "cannot be cast to \"INT\" due to an overflow"))
+ }
+ }
+ }
+
+ test("SPARK-41144: Unresolved hint should not cause query failure") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1(c1 bigint) USING PARQUET")
+ sql("CREATE TABLE t2(c2 bigint) USING PARQUET")
+ sql("SELECT /*+ hash(t2) */ * FROM t1 join t2 on c1 = c2")
+ }
+ }
+
+ test("SPARK-41538: Metadata column should be appended at the end of project") {
+ val tableName = "table_1"
+ val viewName = "view_1"
+ withTable(tableName) {
+ withView(viewName) {
+ sql(s"CREATE TABLE $tableName (a ARRAY<STRING>, s STRUCT<id: STRING>) USING parquet")
+ val id = "id1"
+ sql(s"INSERT INTO $tableName values(ARRAY('a'), named_struct('id', '$id'))")
+ sql(
+ s"""
+ |CREATE VIEW $viewName (id)
+ |AS WITH source AS (
+ | SELECT * FROM $tableName
+ |),
+ |renamed AS (
+ | SELECT s.id FROM source
+ |)
+ |SELECT id FROM renamed
+ |""".stripMargin)
+ val query =
+ s"""
+ |with foo AS (
+ | SELECT '$id' as id
+ |),
+ |bar AS (
+ | SELECT '$id' as id
+ |)
+ |SELECT
+ | 1
+ |FROM foo
+ |FULL OUTER JOIN bar USING(id)
+ |FULL OUTER JOIN $viewName USING(id)
+ |WHERE foo.id IS NOT NULL
+ |""".stripMargin
+ checkAnswer(sql(query), Row(1))
+ }
+ }
+ }
}
case class Foo(bar: Option[String])
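Among the many tests added to SQLQuerySuite above, the TABLE SAMPLE one relies on REPEATABLE seeds producing identical samples. A hedged sketch of that usage outside the suite, assuming a local SparkSession and a temp view in place of the parquet table used by the test:

import org.apache.spark.sql.SparkSession

object TableSampleSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.range(100).createOrReplaceTempView("t")
    // The same seed yields the same sample, which is what the test asserts via checkAnswer.
    val a = spark.sql("SELECT * FROM t TABLESAMPLE (20 PERCENT) REPEATABLE (12345)").collect()
    val b = spark.sql("SELECT * FROM t TABLESAMPLE (20 PERCENT) REPEATABLE (12345)").collect()
    assert(a.toSeq == b.toSeq)
    spark.stop()
  }
}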
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
index b9ca2a0f03..987e09adb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
@@ -35,6 +35,7 @@ trait SQLQueryTestHelper {
protected def replaceNotIncludedMsg(line: String): String = {
line.replaceAll("#\\d+", "#x")
+ .replaceAll("plan_id=\\d+", "plan_id=x")
.replaceAll(
s"Location.*$clsName/",
s"Location $notIncludedMsg/{warehouse_dir}/")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index d5a34ae64a..d6a7c69018 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.tags.ExtendedSQLTest
import org.apache.spark.util.Utils
+// scalastyle:off line.size.limit
/**
* End-to-end test cases for SQL queries.
*
@@ -44,22 +45,22 @@ import org.apache.spark.util.Utils
*
* To run the entire test suite:
* {{{
- * build/sbt "sql/testOnly *SQLQueryTestSuite"
+ * build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite"
* }}}
*
* To run a single test file upon change:
* {{{
- * build/sbt "~sql/testOnly *SQLQueryTestSuite -- -z inline-table.sql"
+ * build/sbt "~sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z inline-table.sql"
* }}}
*
* To re-generate golden files for entire suite, run:
* {{{
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *SQLQueryTestSuite"
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite"
* }}}
*
* To re-generate golden file for a single test, run:
* {{{
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *SQLQueryTestSuite -- -z describe.sql"
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z describe.sql"
* }}}
*
* The format for input files is simple:
@@ -120,6 +121,7 @@ import org.apache.spark.util.Utils
* Therefore, UDF test cases should have single input and output files but executed by three
* different types of UDFs. See 'udf/udf-inner-join.sql' as an example.
*/
+// scalastyle:on line.size.limit
@ExtendedSQLTest
class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
with SQLQueryTestHelper {
@@ -386,6 +388,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
localSparkSession.conf.set(SQLConf.TIMESTAMP_TYPE.key,
TimestampTypes.TIMESTAMP_NTZ.toString)
case _ =>
+ localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false)
}
if (configSet.nonEmpty) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala
deleted file mode 100644
index 6839294348..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala
+++ /dev/null
@@ -1,241 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.CatalogTable
-import org.apache.spark.sql.sources.SimpleInsertSource
-import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
-import org.apache.spark.util.Utils
-
-class SimpleShowCreateTableSuite extends ShowCreateTableSuite with SharedSparkSession
-
-abstract class ShowCreateTableSuite extends QueryTest with SQLTestUtils {
- import testImplicits._
-
- test("data source table with user specified schema") {
- withTable("ddl_test") {
- val jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile
-
- sql(
- s"""CREATE TABLE ddl_test (
- | a STRING,
- | b STRING,
- | `extra col` ARRAY<INT>,
- | `<another>` STRUCT<x: INT, y: ARRAY<BOOLEAN>>
- |)
- |USING json
- |OPTIONS (
- | PATH '$jsonFilePath'
- |)
- """.stripMargin
- )
-
- checkCreateTable("ddl_test")
- }
- }
-
- test("data source table CTAS") {
- withTable("ddl_test") {
- sql(
- s"""CREATE TABLE ddl_test
- |USING json
- |AS SELECT 1 AS a, "foo" AS b
- """.stripMargin
- )
-
- checkCreateTable("ddl_test")
- }
- }
-
- test("partitioned data source table") {
- withTable("ddl_test") {
- sql(
- s"""CREATE TABLE ddl_test
- |USING json
- |PARTITIONED BY (b)
- |AS SELECT 1 AS a, "foo" AS b
- """.stripMargin
- )
-
- checkCreateTable("ddl_test")
- }
- }
-
- test("bucketed data source table") {
- withTable("ddl_test") {
- sql(
- s"""CREATE TABLE ddl_test
- |USING json
- |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS
- |AS SELECT 1 AS a, "foo" AS b
- """.stripMargin
- )
-
- checkCreateTable("ddl_test")
- }
- }
-
- test("partitioned bucketed data source table") {
- withTable("ddl_test") {
- sql(
- s"""CREATE TABLE ddl_test
- |USING json
- |PARTITIONED BY (c)
- |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS
- |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c
- """.stripMargin
- )
-
- checkCreateTable("ddl_test")
- }
- }
-
- test("data source table with a comment") {
- withTable("ddl_test") {
- sql(
- s"""CREATE TABLE ddl_test
- |USING json
- |COMMENT 'This is a comment'
- |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c
- """.stripMargin
- )
-
- checkCreateTable("ddl_test")
- }
- }
-
- test("data source table with table properties") {
- withTable("ddl_test") {
- sql(
- s"""CREATE TABLE ddl_test
- |USING json
- |TBLPROPERTIES ('a' = '1')
- |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c
- """.stripMargin
- )
-
- checkCreateTable("ddl_test")
- }
- }
-
- test("data source table using Dataset API") {
- withTable("ddl_test") {
- spark
- .range(3)
- .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd, 'id as 'e)
- .write
- .mode("overwrite")
- .partitionBy("a", "b")
- .bucketBy(2, "c", "d")
- .saveAsTable("ddl_test")
-
- checkCreateTable("ddl_test")
- }
- }
-
- test("temp view") {
- val viewName = "spark_28383"
- withTempView(viewName) {
- sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT 1 AS a")
- val ex = intercept[AnalysisException] {
- sql(s"SHOW CREATE TABLE $viewName")
- }
- assert(ex.getMessage.contains(
- s"$viewName is a temp view. 'SHOW CREATE TABLE' expects a table or permanent view."))
- }
-
- withGlobalTempView(viewName) {
- sql(s"CREATE GLOBAL TEMPORARY VIEW $viewName AS SELECT 1 AS a")
- val globalTempViewDb = spark.sessionState.catalog.globalTempViewManager.database
- val ex = intercept[AnalysisException] {
- sql(s"SHOW CREATE TABLE $globalTempViewDb.$viewName")
- }
- assert(ex.getMessage.contains(
- s"$globalTempViewDb.$viewName is a temp view. " +
- "'SHOW CREATE TABLE' expects a table or permanent view."))
- }
- }
-
- test("SPARK-24911: keep quotes for nested fields") {
- withTable("t1") {
- val createTable = "CREATE TABLE `t1` (`a` STRUCT<`b`: STRING>)"
- sql(s"$createTable USING json")
- val shownDDL = getShowDDL("SHOW CREATE TABLE t1")
- assert(shownDDL == "CREATE TABLE `default`.`t1` ( `a` STRUCT<`b`: STRING>) USING json")
-
- checkCreateTable("t1")
- }
- }
-
- test("SPARK-36012: Add NULL flag when SHOW CREATE TABLE") {
- val t = "SPARK_36012"
- withTable(t) {
- sql(
- s"""
- |CREATE TABLE $t (
- | a bigint NOT NULL,
- | b bigint
- |)
- |USING ${classOf[SimpleInsertSource].getName}
- """.stripMargin)
- val showDDL = getShowDDL(s"SHOW CREATE TABLE $t")
- assert(showDDL == s"CREATE TABLE `default`.`$t` ( `a` BIGINT NOT NULL," +
- s" `b` BIGINT) USING ${classOf[SimpleInsertSource].getName}")
- }
- }
-
- protected def getShowDDL(showCreateTableSql: String): String = {
- sql(showCreateTableSql).head().getString(0).split("\n").map(_.trim).mkString(" ")
- }
-
- protected def checkCreateTable(table: String, serde: Boolean = false): Unit = {
- checkCreateTableOrView(TableIdentifier(table, Some("default")), "TABLE", serde)
- }
-
- protected def checkCreateView(table: String, serde: Boolean = false): Unit = {
- checkCreateTableOrView(TableIdentifier(table, Some("default")), "VIEW", serde)
- }
-
- protected def checkCreateTableOrView(
- table: TableIdentifier,
- checkType: String,
- serde: Boolean): Unit = {
- val db = table.database.getOrElse("default")
- val expected = spark.sharedState.externalCatalog.getTable(db, table.table)
- val shownDDL = if (serde) {
- sql(s"SHOW CREATE TABLE ${table.quotedString} AS SERDE").head().getString(0)
- } else {
- sql(s"SHOW CREATE TABLE ${table.quotedString}").head().getString(0)
- }
-
- sql(s"DROP $checkType ${table.quotedString}")
-
- try {
- sql(shownDDL)
- val actual = spark.sharedState.externalCatalog.getTable(db, table.table)
- checkCatalogTables(expected, actual)
- } finally {
- sql(s"DROP $checkType IF EXISTS ${table.table}")
- }
- }
-
- protected def checkCatalogTables(expected: CatalogTable, actual: CatalogTable): Unit = {
- assert(CatalogTable.normalize(actual) == CatalogTable.normalize(expected))
- }
-}
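The deleted suite above exercised SHOW CREATE TABLE round-trips. A minimal, hedged sketch of the same round-trip idea outside the test harness, assuming a local SparkSession; the table name is a throwaway used only for illustration:

import org.apache.spark.sql.SparkSession

object ShowCreateTableSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.sql("CREATE TABLE ddl_demo (a STRING, b INT) USING json")
    // SHOW CREATE TABLE returns a single row whose first column is the DDL text,
    // which is what the removed getShowDDL helper flattened into one line.
    val ddl = spark.sql("SHOW CREATE TABLE ddl_demo").head().getString(0)
    println(ddl.split("\n").map(_.trim).mkString(" "))
    spark.sql("DROP TABLE ddl_demo")
    spark.stop()
  }
}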
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index 9d4e57093c..0a7c684a68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
+import org.apache.logging.log4j.Level
import org.scalatest.BeforeAndAfterEach
import org.scalatest.concurrent.Eventually
import org.scalatest.time.SpanSugar._
@@ -431,7 +432,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach wit
.getOrCreate()
.sharedState
}
- assert(logAppender.loggingEvents.exists(_.getRenderedMessage.contains(msg)))
+ assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg)))
}
test("SPARK-33944: no warning setting spark.sql.warehouse.dir using session options") {
@@ -444,7 +445,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach wit
.getOrCreate()
.sharedState
}
- assert(!logAppender.loggingEvents.exists(_.getRenderedMessage.contains(msg)))
+ assert(!logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg)))
}
Seq(".", "..", "dir0", "dir0/dir1", "/dir0/dir1", "./dir0").foreach { pathStr =>
@@ -484,6 +485,89 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach wit
.getOrCreate()
session.sql("SELECT 1").collect()
}
- assert(logAppender.loggingEvents.exists(_.getRenderedMessage.contains(msg)))
+ assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg)))
+ }
+
+ test("SPARK-37727: Show ignored configurations in debug level logs") {
+ // Create one existing SparkSession to check following logs.
+ SparkSession.builder().master("local").getOrCreate()
+
+ val logAppender = new LogAppender
+ logAppender.setThreshold(Level.DEBUG)
+ withLogAppender(logAppender, level = Some(Level.DEBUG)) {
+ SparkSession.builder()
+ .config("spark.sql.warehouse.dir", "2")
+ .config("spark.abc", "abcb")
+ .config("spark.abcd", "abcb4")
+ .getOrCreate()
+ }
+
+ val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ Seq(
+ "Ignored static SQL configurations",
+ "spark.sql.warehouse.dir=2",
+ "Configurations that might not take effect",
+ "spark.abcd=abcb4",
+ "spark.abc=abcb").foreach { msg =>
+ assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}")
+ }
+ }
+
+ test("SPARK-37727: Hide the same configuration already explicitly set in logs") {
+ // Create one existing SparkSession to check following logs.
+ SparkSession.builder().master("local").config("spark.abc", "abc").getOrCreate()
+
+ val logAppender = new LogAppender
+ logAppender.setThreshold(Level.DEBUG)
+ withLogAppender(logAppender, level = Some(Level.DEBUG)) {
+ // Ignore logs because it's already set.
+ SparkSession.builder().config("spark.abc", "abc").getOrCreate()
+ // Show logs for only configuration newly set.
+ SparkSession.builder().config("spark.abc.new", "abc").getOrCreate()
+ // Ignore logs because it's set ^.
+ SparkSession.builder().config("spark.abc.new", "abc").getOrCreate()
+ }
+
+ val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ Seq(
+ "Using an existing Spark session; only runtime SQL configurations will take effect",
+ "Configurations that might not take effect",
+ "spark.abc.new=abc").foreach { msg =>
+ assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}")
+ }
+
+ assert(
+ !logs.exists(_.contains("spark.abc=abc")),
+ s"'spark.abc=abc' existed in:\n${logs.mkString("\n")}")
+ }
+
+ test("SPARK-37727: Hide runtime SQL configurations in logs") {
+ // Create one existing SparkSession to check following logs.
+ SparkSession.builder().master("local").getOrCreate()
+
+ val logAppender = new LogAppender
+ logAppender.setThreshold(Level.DEBUG)
+ withLogAppender(logAppender, level = Some(Level.DEBUG)) {
+ // Ignore logs for runtime SQL configurations
+ SparkSession.builder().config("spark.sql.ansi.enabled", "true").getOrCreate()
+ // Show logs for Spark core configuration
+ SparkSession.builder().config("spark.buffer.size", "1234").getOrCreate()
+ // Show logs for custom runtime options
+ SparkSession.builder().config("spark.sql.source.abc", "abc").getOrCreate()
+ // Show logs for static SQL configurations
+ SparkSession.builder().config("spark.sql.warehouse.dir", "xyz").getOrCreate()
+ }
+
+ val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ Seq(
+ "spark.buffer.size=1234",
+ "spark.sql.source.abc=abc",
+ "spark.sql.warehouse.dir=xyz").foreach { msg =>
+ assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}")
+ }
+
+ assert(
+ !logs.exists(_.contains("spark.sql.ansi.enabled\"")),
+ s"'spark.sql.ansi.enabled' existed in:\n${logs.mkString("\n")}")
}
}
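The builder-suite changes above assert on which options are ignored once a session already exists. A hedged sketch of that reuse behavior (illustrative config values; the DEBUG-log reporting of ignored static configurations is the behavior the SPARK-37727 tests above check):

import org.apache.spark.sql.SparkSession

object BuilderReuseSketch {
  def main(args: Array[String]): Unit = {
    // The first builder call creates the underlying SparkContext and shared state.
    val first = SparkSession.builder().master("local[*]").getOrCreate()
    // A second getOrCreate() reuses the existing session; only runtime SQL configurations
    // passed here take effect, while static ones such as the warehouse dir are ignored.
    val second = SparkSession.builder()
      .config("spark.sql.warehouse.dir", "/tmp/ignored-warehouse")
      .config("spark.sql.shuffle.partitions", "8")
      .getOrCreate()
    assert(first eq second)
    println(second.conf.get("spark.sql.shuffle.partitions")) // 8
    first.stop()
  }
}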
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index f8a155ec1b..5c8eec6b10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -27,7 +27,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -45,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* Test cases for the [[SparkSessionExtensions]].
*/
-class SparkSessionExtensionSuite extends SparkFunSuite {
+class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
private def create(
builder: SparkSessionExtensionsProvider): Seq[SparkSessionExtensionsProvider] = Seq(builder)
@@ -171,7 +172,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
withSession(extensions) { session =>
session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
- assert(session.sessionState.queryStagePrepRules.contains(MyQueryStagePrepRule()))
+ assert(session.sessionState.adaptiveRulesHolder.queryStagePrepRules
+ .contains(MyQueryStagePrepRule()))
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(MyNewQueryStageRule(), MyNewQueryStageRule())))
import session.sqlContext.implicits._
@@ -406,6 +408,26 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
session.sql("SELECT * FROM v")
}
}
+
+ test("SPARK-38697: Extend SparkSessionExtensions to inject rules into AQE Optimizer") {
+ def executedPlan(df: Dataset[java.lang.Long]): SparkPlan = {
+ assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
+ df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+ }
+ val extensions = create { extensions =>
+ extensions.injectRuntimeOptimizerRule(_ => AddLimit)
+ }
+ withSession(extensions) { session =>
+ assert(session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules.contains(AddLimit))
+
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val df = session.range(2).repartition()
+ assert(!executedPlan(df).isInstanceOf[CollectLimitExec])
+ df.collect()
+ assert(executedPlan(df).isInstanceOf[CollectLimitExec])
+ }
+ }
+ }
}
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -441,6 +463,9 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
override def parseDataType(sqlText: String): DataType =
delegate.parseDataType(sqlText)
+
+ override def parseQuery(sqlText: String): LogicalPlan =
+ delegate.parseQuery(sqlText)
}
object MyExtensions {
@@ -546,7 +571,7 @@ case class NoCloseColumnVector(wrapped: ColumnVector) extends ColumnVector(wrapp
override def getBinary(rowId: Int): Array[Byte] = wrapped.getBinary(rowId)
- override protected def getChild(ordinal: Int): ColumnVector = wrapped.getChild(ordinal)
+ override def getChild(ordinal: Int): ColumnVector = wrapped.getChild(ordinal)
}
trait ColumnarExpression extends Expression with Serializable {
@@ -722,37 +747,32 @@ class BrokenColumnarAdd(
lhs = left.columnarEval(batch)
rhs = right.columnarEval(batch)
- if (lhs == null || rhs == null) {
- ret = null
- } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) {
- val l = lhs.asInstanceOf[ColumnVector]
- val r = rhs.asInstanceOf[ColumnVector]
- val result = new OnHeapColumnVector(batch.numRows(), dataType)
- ret = result
-
- for (i <- 0 until batch.numRows()) {
- result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
- }
- } else if (rhs.isInstanceOf[ColumnVector]) {
- val l = lhs.asInstanceOf[Long]
- val r = rhs.asInstanceOf[ColumnVector]
- val result = new OnHeapColumnVector(batch.numRows(), dataType)
- ret = result
-
- for (i <- 0 until batch.numRows()) {
- result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
- }
- } else if (lhs.isInstanceOf[ColumnVector]) {
- val l = lhs.asInstanceOf[ColumnVector]
- val r = rhs.asInstanceOf[Long]
- val result = new OnHeapColumnVector(batch.numRows(), dataType)
- ret = result
-
- for (i <- 0 until batch.numRows()) {
- result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
- }
- } else {
- ret = nullSafeEval(lhs, rhs)
+ (lhs, rhs) match {
+ case (null, null) =>
+ ret = null
+ case (l: ColumnVector, r: ColumnVector) =>
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
+ }
+ case (l: Long, r: ColumnVector) =>
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
+ }
+ case (l: ColumnVector, r: Long) =>
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
+ }
+ case (l, r) =>
+ ret = nullSafeEval(l, r)
}
} finally {
if (lhs != null && lhs.isInstanceOf[ColumnVector]) {
@@ -1026,3 +1046,10 @@ class YourExtensions extends SparkSessionExtensionsProvider {
v1.injectFunction(getAppName)
}
}
+
+object AddLimit extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ case Limit(_, _) => plan
+ case _ => Limit(Literal(1), plan)
+ }
+}
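The extension-suite hunks above inject AddLimit into the AQE optimizer through the new injectRuntimeOptimizerRule hook. For context, a hedged sketch of the general injection pattern using the long-standing injectOptimizerRule hook and a deliberately no-op rule (the rule and object names are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// A no-op rule used only to show where a custom optimization would hook in.
case class NoOpRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

object ExtensionInjectionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .withExtensions(_.injectOptimizerRule(NoOpRule))
      .getOrCreate()
    spark.range(5).selectExpr("id * 2").show()
    spark.stop()
  }
}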
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 9f8000a08f..c37309d97a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -27,8 +27,9 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneUTC
+import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, PST, UTC}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, TimeZoneUTC}
import org.apache.spark.sql.functions.timestamp_seconds
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -406,9 +407,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
withTable("TBL1", "TBL") {
import org.apache.spark.sql.functions._
val df = spark.range(1000L).select('id,
- 'id * 2 as "FLD1",
- 'id * 12 as "FLD2",
- lit("aaa") + 'id as "fld3")
+ Symbol("id") * 2 as "FLD1",
+ Symbol("id") * 12 as "FLD2",
+ lit(null).cast(DoubleType) + Symbol("id") as "fld3")
df.write
.mode(SaveMode.Overwrite)
.bucketBy(10, "id", "FLD1", "FLD2")
@@ -424,7 +425,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
|WHERE t1.fld3 IN (-123.23,321.23)
""".stripMargin)
df2.createTempView("TBL2")
- sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan
+ sql("SELECT * FROM tbl2 WHERE fld3 IN (0,1) ").queryExecution.executedPlan
}
}
}
@@ -470,7 +471,89 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
}
}
- def getStatAttrNames(tableName: String): Set[String] = {
+ private def checkDescTimestampColStats(
+ tableName: String,
+ timestampColumn: String,
+ expectedMinTimestamp: String,
+ expectedMaxTimestamp: String): Unit = {
+
+ def extractColumnStatsFromDesc(statsName: String, rows: Array[Row]): String = {
+ rows.collect {
+ case r: Row if r.getString(0) == statsName =>
+ r.getString(1)
+ }.head
+ }
+
+ val descTsCol = sql(s"DESC FORMATTED $tableName $timestampColumn").collect()
+ assert(extractColumnStatsFromDesc("min", descTsCol) == expectedMinTimestamp)
+ assert(extractColumnStatsFromDesc("max", descTsCol) == expectedMaxTimestamp)
+ }
+
+ test("SPARK-38140: describe column stats (min, max) for timestamp column: desc results should " +
+ "be consistent with the written value if writing and desc happen in the same time zone") {
+
+ val zoneIdAndOffsets =
+ Seq((UTC, "+0000"), (PST, "-0800"), (getZoneId("Asia/Hong_Kong"), "+0800"))
+
+ zoneIdAndOffsets.foreach { case (zoneId, offset) =>
+ withDefaultTimeZone(zoneId) {
+ val table = "insert_desc_same_time_zone"
+ val tsCol = "timestamp_typed_col"
+ withTable(table) {
+ val minTimestamp = "make_timestamp(2022, 1, 1, 0, 0, 1.123456)"
+ val maxTimestamp = "make_timestamp(2022, 1, 3, 0, 0, 2.987654)"
+ sql(s"CREATE TABLE $table ($tsCol Timestamp) USING parquet")
+ sql(s"INSERT INTO $table VALUES $minTimestamp, $maxTimestamp")
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR ALL COLUMNS")
+
+ checkDescTimestampColStats(
+ tableName = table,
+ timestampColumn = tsCol,
+ expectedMinTimestamp = "2022-01-01 00:00:01.123456 " + offset,
+ expectedMaxTimestamp = "2022-01-03 00:00:02.987654 " + offset)
+ }
+ }
+ }
+ }
+
+ test("SPARK-38140: describe column stats (min, max) for timestamp column: desc should show " +
+ "different results if writing in UTC and desc in other time zones") {
+
+ val table = "insert_desc_diff_time_zones"
+ val tsCol = "timestamp_typed_col"
+
+ withDefaultTimeZone(UTC) {
+ withTable(table) {
+ val minTimestamp = "make_timestamp(2022, 1, 1, 0, 0, 1.123456)"
+ val maxTimestamp = "make_timestamp(2022, 1, 3, 0, 0, 2.987654)"
+ sql(s"CREATE TABLE $table ($tsCol Timestamp) USING parquet")
+ sql(s"INSERT INTO $table VALUES $minTimestamp, $maxTimestamp")
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR ALL COLUMNS")
+
+ checkDescTimestampColStats(
+ tableName = table,
+ timestampColumn = tsCol,
+ expectedMinTimestamp = "2022-01-01 00:00:01.123456 +0000",
+ expectedMaxTimestamp = "2022-01-03 00:00:02.987654 +0000")
+
+ TimeZone.setDefault(DateTimeUtils.getTimeZone("PST"))
+ checkDescTimestampColStats(
+ tableName = table,
+ timestampColumn = tsCol,
+ expectedMinTimestamp = "2021-12-31 16:00:01.123456 -0800",
+ expectedMaxTimestamp = "2022-01-02 16:00:02.987654 -0800")
+
+ TimeZone.setDefault(DateTimeUtils.getTimeZone("Asia/Hong_Kong"))
+ checkDescTimestampColStats(
+ tableName = table,
+ timestampColumn = tsCol,
+ expectedMinTimestamp = "2022-01-01 08:00:01.123456 +0800",
+ expectedMaxTimestamp = "2022-01-03 08:00:02.987654 +0800")
+ }
+ }
+ }
+
+ private def getStatAttrNames(tableName: String): Set[String] = {
val queryStats = spark.table(tableName).queryExecution.optimizedPlan.stats.attributeStats
queryStats.map(_._1.name).toSet
}
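The statistics-suite additions above read per-column stats back through DESC FORMATTED. A hedged, standalone sketch of that read path, assuming a local SparkSession; the table and column names are arbitrary:

import org.apache.spark.sql.SparkSession

object DescColumnStatsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.sql("CREATE TABLE stats_demo (ts TIMESTAMP) USING parquet")
    spark.sql("INSERT INTO stats_demo VALUES (make_timestamp(2022, 1, 1, 0, 0, 1.123456))")
    spark.sql("ANALYZE TABLE stats_demo COMPUTE STATISTICS FOR ALL COLUMNS")
    // DESC FORMATTED <table> <column> returns (info_name, info_value) rows;
    // the tests above pick out the 'min' and 'max' rows the same way.
    val rows = spark.sql("DESC FORMATTED stats_demo ts").collect()
    val min = rows.collectFirst { case r if r.getString(0) == "min" => r.getString(1) }
    println(min)
    spark.sql("DROP TABLE stats_demo")
    spark.stop()
  }
}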
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index dd8a1a8478..2f118f236e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -112,9 +112,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
val df = Seq[(String, String, String, Int)](("hello", "world", null, 15))
.toDF("a", "b", "c", "d")
- checkAnswer(
- df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"),
- Row(null, "hello", null))
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ checkAnswer(
+ df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"),
+ Row(null, "hello", null))
+ }
// check implicit type cast
checkAnswer(
@@ -383,9 +385,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
Row("host", "/file;param", "query;p2", null, "http", "/file;param?query;p2",
"user:pass@host", "user:pass", null))
- testUrl(
- "inva lid://user:pass@host/file;param?query;p2",
- Row(null, null, null, null, null, null, null, null, null))
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ testUrl(
+ "inva lid://user:pass@host/file;param?query;p2",
+ Row(null, null, null, null, null, null, null, null, null))
+ }
}
@@ -486,6 +490,58 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
)
}
+ test("SPARK-36751: add octet length api for scala") {
+ // scalastyle:off
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08"))
+ .toDF("a", "b", "c", "d", "e", "f")
+ // string and binary input
+ checkAnswer(
+ df.select(octet_length($"a"), octet_length($"b")),
+ Row(3, 4))
+ // string and binary input
+ checkAnswer(
+ df.selectExpr("octet_length(a)", "octet_length(b)"),
+ Row(3, 4))
+ // integer, float and double input
+ checkAnswer(
+ df.selectExpr("octet_length(c)", "octet_length(d)", "octet_length(e)"),
+ Row(3, 3, 5)
+ )
+ // multi-byte character input
+ checkAnswer(
+ df.selectExpr("octet_length(f)"),
+ Row(4)
+ )
+ // scalastyle:on
+ }
+
+ test("SPARK-36751: add bit length api for scala") {
+ // scalastyle:off
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08"))
+ .toDF("a", "b", "c", "d", "e", "f")
+ // string and binary input
+ checkAnswer(
+ df.select(bit_length($"a"), bit_length($"b")),
+ Row(24, 32))
+ // string and binary input
+ checkAnswer(
+ df.selectExpr("bit_length(a)", "bit_length(b)"),
+ Row(24, 32))
+ // integer, float and double input
+ checkAnswer(
+ df.selectExpr("bit_length(c)", "bit_length(d)", "bit_length(e)"),
+ Row(24, 24, 40)
+ )
+ // multi-byte character input
+ checkAnswer(
+ df.selectExpr("bit_length(f)"),
+ Row(32)
+ )
+ // scalastyle:on
+ }
+
test("initcap function") {
val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z")
checkAnswer(
@@ -598,4 +654,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
)
}
+
+ test("SPARK-36148: check input data types of regexp_replace") {
+ val m = intercept[AnalysisException] {
+ sql("select regexp_replace(collect_list(1), '1', '2')")
+ }.getMessage
+ assert(m.contains("data type mismatch: argument 1 requires string type"))
+ }
}
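The string-function additions above exercise the Scala octet_length and bit_length column functions. A hedged minimal usage sketch, assuming a local SparkSession with spark.implicits in scope:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{bit_length, octet_length}

object LengthFunctionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._
    val df = Seq(("abc", Array[Byte](1, 2, 3, 4))).toDF("s", "b")
    // octet_length counts bytes; bit_length counts bits (8x the octet length).
    df.select(octet_length($"s"), bit_length($"s"), octet_length($"b"), bit_length($"b"))
      .show() // 3, 24, 4, 32
    spark.stop()
  }
}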
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 277cb1bceb..b3aefac05b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project, Sort}
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, FileSourceScanExec, InputAdapter, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.FileScanRDD
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -145,12 +146,12 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
test("runtime error when the number of rows is greater than 1") {
- val error2 = intercept[RuntimeException] {
+ val e = intercept[IllegalStateException] {
sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
}
- assert(error2.getMessage.contains(
- "more than one row returned by a subquery used as an expression")
- )
+ // TODO(SPARK-39167): Throw an exception w/ an error class for multiple rows from a subquery
+ assert(e.getMessage.contains(
+ "more than one row returned by a subquery used as an expression"))
}
test("uncorrelated scalar subquery on a DataFrame generated query") {
@@ -895,7 +896,8 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
withTempView("t") {
Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("t")
val e = intercept[AnalysisException](sql("SELECT (SELECT count(*) FROM t WHERE a = 1)"))
- assert(e.message.contains("cannot resolve 'a' given input columns: [t.i, t.j]"))
+ assert(e.getErrorClass == "MISSING_COLUMN")
+ assert(e.messageParameters.sameElements(Array("a", "t.i, t.j")))
}
}
@@ -1407,7 +1409,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
test("Scalar subquery name should start with scalar-subquery#") {
val df = sql("SELECT a FROM l WHERE a = (SELECT max(c) FROM r WHERE c = 1)".stripMargin)
- var subqueryExecs: ArrayBuffer[SubqueryExec] = ArrayBuffer.empty
+ val subqueryExecs: ArrayBuffer[SubqueryExec] = ArrayBuffer.empty
df.queryExecution.executedPlan.transformAllExpressions {
case s @ ScalarSubquery(p: SubqueryExec, _) =>
subqueryExecs += p
@@ -1878,6 +1880,30 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
+ test("SPARK-36280: Remove redundant aliases after RewritePredicateSubquery") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1 USING parquet AS SELECT id AS a, id AS b, id AS c FROM range(10)")
+ sql("CREATE TABLE t2 USING parquet AS SELECT id AS x, id AS y FROM range(8)")
+ val df = sql(
+ """
+ |SELECT *
+ |FROM t1
+ |WHERE a IN (SELECT x
+ | FROM (SELECT x AS x,
+ | RANK() OVER (PARTITION BY x ORDER BY SUM(y) DESC) AS ranking
+ | FROM t2
+ | GROUP BY x) tmp1
+ | WHERE ranking <= 5)
+ |""".stripMargin)
+
+ df.collect()
+ val exchanges = collect(df.queryExecution.executedPlan) {
+ case s: ShuffleExchangeExec => s
+ }
+ assert(exchanges.size === 1)
+ }
+ }
+
test("SPARK-36747: should not combine Project with Aggregate") {
withTempView("t") {
Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
@@ -1896,6 +1922,77 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
+ test("SPARK-36656: Do not collapse projects with correlate scalar subqueries") {
+ withTempView("t1", "t2") {
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
+ Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2")
+ val correctAnswer = Row(0, 2, 20) :: Row(1, null, null) :: Nil
+ checkAnswer(
+ sql(
+ """
+ |SELECT c1, s, s * 10 FROM (
+ | SELECT c1, (SELECT FIRST(c2) FROM t2 WHERE t1.c1 = t2.c1) s FROM t1)
+ |""".stripMargin),
+ correctAnswer)
+ checkAnswer(
+ sql(
+ """
+ |SELECT c1, s, s * 10 FROM (
+ | SELECT c1, SUM((SELECT FIRST(c2) FROM t2 WHERE t1.c1 = t2.c1)) s
+ | FROM t1 GROUP BY c1
+ |)
+ |""".stripMargin),
+ correctAnswer)
+ }
+ }
+
+ test("SPARK-37199: deterministic in QueryPlan considers subquery") {
+ val deterministicQueryPlan = sql("select (select 1 as b) as b")
+ .queryExecution.executedPlan
+ assert(deterministicQueryPlan.deterministic)
+
+ val nonDeterministicQueryPlan = sql("select (select rand(1) as b) as b")
+ .queryExecution.executedPlan
+ assert(!nonDeterministicQueryPlan.deterministic)
+ }
+
+ test("SPARK-38132: Not IN subquery correctness checks") {
+ val t = "test_table"
+ withTable(t) {
+ Seq[(Integer, Integer)](
+ (1, 1),
+ (2, 2),
+ (3, 3),
+ (4, null),
+ (null, 0))
+ .toDF("c1", "c2").write.saveAsTable(t)
+ val df = spark.table(t)
+
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) = true"), Seq.empty)
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) = true"),
+ Row(4, null) :: Nil)
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) <=> true"), Seq.empty)
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) <=> true"),
+ Row(4, null) :: Nil)
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) != false"), Seq.empty)
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) != false"),
+ Row(4, null) :: Nil)
+ checkAnswer(df.where(s"NOT((c1 NOT IN (SELECT c2 FROM $t)) <=> false)"), Seq.empty)
+ checkAnswer(df.where(s"NOT((c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) <=> false)"),
+ Row(4, null) :: Nil)
+ }
+ }
+
+ test("SPARK-38155: disallow distinct aggregate in lateral subqueries") {
+ withTempView("t1", "t2") {
+ Seq((0, 1)).toDF("c1", "c2").createOrReplaceTempView("t1")
+ Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
+ assert(intercept[AnalysisException] {
+ sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)")
+ }.getMessage.contains("Correlated column is not allowed in predicate"))
+ }
+ }
+
test("SPARK-38180: allow safe cast expressions in correlated equality conditions") {
withTempView("t1", "t2") {
Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
@@ -1921,4 +2018,302 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}.getMessage.contains("Correlated column is not allowed in predicate"))
}
}
+
+ test("Merge non-correlated scalar subqueries") {
+ Seq(false, true).foreach { enableAQE =>
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
+ val df = sql(
+ """
+ |SELECT
+ | (SELECT avg(key) FROM testData),
+ | (SELECT sum(key) FROM testData),
+ | (SELECT count(distinct key) FROM testData)
+ """.stripMargin)
+
+ checkAnswer(df, Row(50.5, 5050, 100) :: Nil)
+
+ val plan = df.queryExecution.executedPlan
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id }
+ val reusedSubqueryIds = collectWithSubqueries(plan) {
+ case rs: ReusedSubqueryExec => rs.child.id
+ }
+
+ assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan")
+ assert(reusedSubqueryIds.size == 2,
+ "Missing or unexpected reused ReusedSubqueryExec in the plan")
+ }
+ }
+ }
+
+ test("Merge non-correlated scalar subqueries in a subquery") {
+ Seq(false, true).foreach { enableAQE =>
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
+ val df = sql(
+ """
+ |SELECT (
+ | SELECT
+ | SUM(
+ | (SELECT avg(key) FROM testData) +
+ | (SELECT sum(key) FROM testData) +
+ | (SELECT count(distinct key) FROM testData))
+ | FROM testData
+ |)
+ """.stripMargin)
+
+ checkAnswer(df, Row(520050.0) :: Nil)
+
+ val plan = df.queryExecution.executedPlan
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id }
+ val reusedSubqueryIds = collectWithSubqueries(plan) {
+ case rs: ReusedSubqueryExec => rs.child.id
+ }
+
+ if (enableAQE) {
+ assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan")
+ assert(reusedSubqueryIds.size == 4,
+ "Missing or unexpected reused ReusedSubqueryExec in the plan")
+ } else {
+ assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
+ assert(reusedSubqueryIds.size == 5,
+ "Missing or unexpected reused ReusedSubqueryExec in the plan")
+ }
+ }
+ }
+ }
+
+ test("Merge non-correlated scalar subqueries from different levels") {
+ Seq(false, true).foreach { enableAQE =>
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
+ val df = sql(
+ """
+ |SELECT
+ | (SELECT avg(key) FROM testData),
+ | (
+ | SELECT
+ | SUM(
+ | (SELECT sum(key) FROM testData)
+ | )
+ | FROM testData
+ | )
+ """.stripMargin)
+
+ checkAnswer(df, Row(50.5, 505000) :: Nil)
+
+ val plan = df.queryExecution.executedPlan
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id }
+ val reusedSubqueryIds = collectWithSubqueries(plan) {
+ case rs: ReusedSubqueryExec => rs.child.id
+ }
+
+ assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
+ assert(reusedSubqueryIds.size == 2,
+ "Missing or unexpected reused ReusedSubqueryExec in the plan")
+ }
+ }
+ }
+
+ test("Merge non-correlated scalar subqueries from different parent plans") {
+ Seq(false, true).foreach { enableAQE =>
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
+ val df = sql(
+ """
+ |SELECT
+ | (
+ | SELECT
+ | SUM(
+ | (SELECT avg(key) FROM testData)
+ | )
+ | FROM testData
+ | ),
+ | (
+ | SELECT
+ | SUM(
+ | (SELECT sum(key) FROM testData)
+ | )
+ | FROM testData
+ | )
+ """.stripMargin)
+
+ checkAnswer(df, Row(5050.0, 505000) :: Nil)
+
+ val plan = df.queryExecution.executedPlan
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id }
+ val reusedSubqueryIds = collectWithSubqueries(plan) {
+ case rs: ReusedSubqueryExec => rs.child.id
+ }
+
+ if (enableAQE) {
+ assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan")
+ assert(reusedSubqueryIds.size == 3,
+ "Missing or unexpected reused ReusedSubqueryExec in the plan")
+ } else {
+ assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
+ assert(reusedSubqueryIds.size == 4,
+ "Missing or unexpected reused ReusedSubqueryExec in the plan")
+ }
+ }
+ }
+ }
+
+ test("Merge non-correlated scalar subqueries with conflicting names") {
+ Seq(false, true).foreach { enableAQE =>
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
+ val df = sql(
+ """
+ |SELECT
+ | (SELECT avg(key) AS key FROM testData),
+ | (SELECT sum(key) AS key FROM testData),
+ | (SELECT count(distinct key) AS key FROM testData)
+ """.stripMargin)
+
+ checkAnswer(df, Row(50.5, 5050, 100) :: Nil)
+
+ val plan = df.queryExecution.executedPlan
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id }
+ val reusedSubqueryIds = collectWithSubqueries(plan) {
+ case rs: ReusedSubqueryExec => rs.child.id
+ }
+
+ assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan")
+ assert(reusedSubqueryIds.size == 2,
+ "Missing or unexpected reused ReusedSubqueryExec in the plan")
+ }
+ }
+ }
+
+ test("SPARK-39355: Single column uses quoted to construct UnresolvedAttribute") {
+ checkAnswer(
+ sql("""
+ |SELECT *
+ |FROM (
+ | SELECT '2022-06-01' AS c1
+ |) a
+ |WHERE c1 IN (
+ | SELECT date_add('2022-06-01', 0)
+ |)
+ |""".stripMargin),
+ Row("2022-06-01"))
+ checkAnswer(
+ sql("""
+ |SELECT *
+ |FROM (
+ | SELECT '2022-06-01' AS c1
+ |) a
+ |WHERE c1 IN (
+ | SELECT date_add(a.c1.k1, 0)
+ | FROM (
+ | SELECT named_struct('k1', '2022-06-01') AS c1
+ | ) a
+ |)
+ |""".stripMargin),
+ Row("2022-06-01"))
+ }
+
+ test("SPARK-39672: Fix removing project before filter with correlated subquery") {
+ withTempView("v1", "v2") {
+ Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c").createTempView("v1")
+ Seq((1, 3, 5), (4, 5, 6)).toDF("a", "b", "c").createTempView("v2")
+
+ def findProject(df: DataFrame): Seq[Project] = {
+ df.queryExecution.optimizedPlan.collect {
+ case p: Project => p
+ }
+ }
+
+ // project before filter cannot be removed since subquery has conflicting attributes
+ // with outer reference
+ val df1 = sql(
+ """
+ |select * from
+ |(
+ |select
+ |v1.a,
+ |v1.b,
+ |v2.c
+ |from v1
+ |inner join v2
+ |on v1.a=v2.a) t3
+ |where not exists (
+ | select 1
+ | from v2
+ | where t3.a=v2.a and t3.b=v2.b and t3.c=v2.c
+ |)
+ |""".stripMargin)
+ checkAnswer(df1, Row(1, 2, 5))
+ assert(findProject(df1).size == 4)
+
+ // project before filter can be removed when there are no conflicting attributes
+ val df2 = sql(
+ """
+ |select * from
+ |(
+ |select
+ |v1.b,
+ |v2.c
+ |from v1
+ |inner join v2
+ |on v1.b=v2.c) t3
+ |where not exists (
+ | select 1
+ | from v2
+ | where t3.b=v2.b and t3.c=v2.c
+ |)
+ |""".stripMargin)
+
+ checkAnswer(df2, Row(5, 5))
+ assert(findProject(df2).size == 3)
+ }
+ }
+
+ test("SPARK-42346: Rewrite distinct aggregates after merging subqueries") {
+ withTempView("t1") {
+ Seq((1, 2), (3, 4)).toDF("c1", "c2").createOrReplaceTempView("t1")
+
+ checkAnswer(sql(
+ """
+ |SELECT
+ | (SELECT count(distinct c1) FROM t1),
+ | (SELECT count(distinct c2) FROM t1)
+ |""".stripMargin),
+ Row(2, 2))
+
+ // In this case we don't merge the subqueries as `RewriteDistinctAggregates` kicks off for the
+ // 2 subqueries first but `MergeScalarSubqueries` is not prepared for the `Expand` nodes that
+ // are inserted by the rewrite.
+ checkAnswer(sql(
+ """
+ |SELECT
+ | (SELECT count(distinct c1) + sum(distinct c2) FROM t1),
+ | (SELECT count(distinct c2) + sum(distinct c1) FROM t1)
+ |""".stripMargin),
+ Row(8, 6))
+ }
+ }
+
+ test("SPARK-42937: Outer join with subquery in condition") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+ withTempView("t2") {
+ // this is the same as the view t created in beforeAll, but that gets dropped by
+ // one of the tests above
+ r.filter($"c".isNotNull && $"d".isNotNull).createOrReplaceTempView("t2")
+ val expected = Row(1, 2.0d, null, null) :: Row(1, 2.0d, null, null) ::
+ Row(3, 3.0d, 3, 2.0d) :: Row(null, 5.0d, null, null) :: Nil
+ checkAnswer(sql(
+ """
+ |select *
+ |from l
+ |left outer join r
+ |on a = c
+ |and a in (select c from t2 where d in (1.0, 2.0))
+ |where b > 1.0""".stripMargin),
+ expected)
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCBase.scala
new file mode 100644
index 0000000000..1764584922
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCBase.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+trait TPCBase extends SharedSparkSession {
+
+ protected def injectStats: Boolean = false
+
+ override protected def sparkConf: SparkConf = {
+ if (injectStats) {
+ super.sparkConf
+ .set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue)
+ .set(SQLConf.CBO_ENABLED, true)
+ .set(SQLConf.PLAN_STATS_ENABLED, true)
+ .set(SQLConf.JOIN_REORDER_ENABLED, true)
+ } else {
+ super.sparkConf.set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue)
+ }
+ }
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ createTables()
+ }
+
+ override def afterAll(): Unit = {
+ dropTables()
+ super.afterAll()
+ }
+
+ protected def createTables(): Unit
+
+ protected def dropTables(): Unit
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSBase.scala
index 20cfcecc22..39587ce063 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSBase.scala
@@ -18,10 +18,8 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
-trait TPCDSBase extends SharedSparkSession with TPCDSSchema {
+trait TPCDSBase extends TPCBase with TPCDSSchema {
// The TPCDS queries below are based on v1.4
private val tpcdsAllQueries: Seq[String] = Seq(
@@ -79,19 +77,7 @@ trait TPCDSBase extends SharedSparkSession with TPCDSSchema {
""".stripMargin)
}
- private val originalCBCEnabled = conf.cboEnabled
- private val originalPlanStatsEnabled = conf.planStatsEnabled
- private val originalJoinReorderEnabled = conf.joinReorderEnabled
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- if (injectStats) {
- // Sets configurations for enabling the optimization rules that
- // exploit data statistics.
- conf.setConf(SQLConf.CBO_ENABLED, true)
- conf.setConf(SQLConf.PLAN_STATS_ENABLED, true)
- conf.setConf(SQLConf.JOIN_REORDER_ENABLED, true)
- }
+ override def createTables(): Unit = {
tableNames.foreach { tableName =>
createTable(spark, tableName)
if (injectStats) {
@@ -102,15 +88,9 @@ trait TPCDSBase extends SharedSparkSession with TPCDSSchema {
}
}
- override def afterAll(): Unit = {
- conf.setConf(SQLConf.CBO_ENABLED, originalCBCEnabled)
- conf.setConf(SQLConf.PLAN_STATS_ENABLED, originalPlanStatsEnabled)
- conf.setConf(SQLConf.JOIN_REORDER_ENABLED, originalJoinReorderEnabled)
+ override def dropTables(): Unit = {
tableNames.foreach { tableName =>
spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true)
}
- super.afterAll()
}
-
- protected def injectStats: Boolean = false
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
index 22e1b838f3..8c4d25a7eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.tags.ExtendedSQLTest
@ExtendedSQLTest
class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSBase {
- tpcdsQueries.foreach { name =>
+ // q72 is skipped due to GitHub Actions' memory limit.
+ tpcdsQueries.filterNot(sys.env.contains("GITHUB_ACTIONS") && _ == "q72").foreach { name =>
val queryString = resourceToString(s"tpcds/$name.sql",
classLoader = Thread.currentThread().getContextClassLoader)
test(name) {
@@ -39,7 +40,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSBase {
}
}
- tpcdsQueriesV2_7_0.foreach { name =>
+ // q72 is skipped due to GitHub Actions' memory limit.
+ tpcdsQueriesV2_7_0.filterNot(sys.env.contains("GITHUB_ACTIONS") && _ == "q72").foreach { name =>
val queryString = resourceToString(s"tpcds-v2.7.0/$name.sql",
classLoader = Thread.currentThread().getContextClassLoader)
test(s"$name-v2.7") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala
index 952e896802..8019fc98a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala
@@ -20,10 +20,13 @@ package org.apache.spark.sql
import java.io.File
import java.nio.file.{Files, Paths}
+import scala.collection.JavaConverters._
+
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.tags.ExtendedSQLTest
/**
* End-to-end tests to check TPCDS query results.
@@ -51,13 +54,14 @@ import org.apache.spark.sql.test.TestSparkSession
* build/sbt "sql/testOnly *TPCDSQueryTestSuite -- -z q79"
* }}}
*/
+@ExtendedSQLTest
class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelper {
private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA")
private val regenerateGoldenFiles = sys.env.get("SPARK_GENERATE_GOLDEN_FILES").exists(_ == "1")
// To make output results deterministic
- protected override def sparkConf: SparkConf = super.sparkConf
+ override protected def sparkConf: SparkConf = super.sparkConf
.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")
protected override def createSparkSession: TestSparkSession = {
@@ -97,50 +101,111 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp
""".stripMargin)
}
- private def runQuery(query: String, goldenFile: File): Unit = {
- val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
- val queryString = query.trim
- val outputString = output.mkString("\n").replaceAll("\\s+$", "")
- if (regenerateGoldenFiles) {
- val goldenOutput = {
- s"-- Automatically generated by ${getClass.getSimpleName}\n\n" +
- s"-- !query schema\n" +
- schema + "\n" +
- s"-- !query output\n" +
- outputString +
- "\n"
- }
- val parent = goldenFile.getParentFile
- if (!parent.exists()) {
- assert(parent.mkdirs(), "Could not create directory: " + parent)
+ private def runQuery(
+ query: String,
+ goldenFile: File,
+ conf: Map[String, String]): Unit = {
+ val shouldSortResults = sortMergeJoinConf != conf // Sort for other joins
+ withSQLConf(conf.toSeq: _*) {
+ try {
+ val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
+ val queryString = query.trim
+ val outputString = output.mkString("\n").replaceAll("\\s+$", "")
+ if (regenerateGoldenFiles) {
+ val goldenOutput = {
+ s"-- Automatically generated by ${getClass.getSimpleName}\n\n" +
+ s"-- !query schema\n" +
+ schema + "\n" +
+ s"-- !query output\n" +
+ outputString +
+ "\n"
+ }
+ val parent = goldenFile.getParentFile
+ if (!parent.exists()) {
+ assert(parent.mkdirs(), "Could not create directory: " + parent)
+ }
+ stringToFile(goldenFile, goldenOutput)
+ }
+
+ // Read back the golden file.
+ val (expectedSchema, expectedOutput) = {
+ val goldenOutput = fileToString(goldenFile)
+ val segments = goldenOutput.split("-- !query.*\n")
+
+ // query has 3 segments, plus the header
+ assert(segments.size == 3,
+ s"Expected 3 blocks in result file but got ${segments.size}. " +
+            "Try regenerating the result files.")
+
+ (segments(1).trim, segments(2).replaceAll("\\s+$", ""))
+ }
+
+ assertResult(expectedSchema, s"Schema did not match\n$queryString") {
+ schema
+ }
+ if (shouldSortResults) {
+ val expectSorted = expectedOutput.split("\n").sorted.map(_.trim)
+ .mkString("\n").replaceAll("\\s+$", "")
+ val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "")
+ assertResult(expectSorted, s"Result did not match\n$queryString") {
+ outputSorted
+ }
+ } else {
+ assertResult(expectedOutput, s"Result did not match\n$queryString") {
+ outputString
+ }
+ }
+ } catch {
+ case e: Throwable =>
+ val configs = conf.map {
+ case (k, v) => s"$k=$v"
+ }
+ throw new Exception(s"${e.getMessage}\nError using configs:\n${configs.mkString("\n")}")
}
- stringToFile(goldenFile, goldenOutput)
}
+ }
- // Read back the golden file.
- val (expectedSchema, expectedOutput) = {
- val goldenOutput = fileToString(goldenFile)
- val segments = goldenOutput.split("-- !query.*\n")
+ val sortMergeJoinConf = Map(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "true")
- // query has 3 segments, plus the header
- assert(segments.size == 3,
- s"Expected 3 blocks in result file but got ${segments.size}. " +
- "Try regenerate the result files.")
+ val broadcastHashJoinConf = Map(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760")
- (segments(1).trim, segments(2).replaceAll("\\s+$", ""))
- }
+ val shuffledHashJoinConf = Map(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ "spark.sql.join.forceApplyShuffledHashJoin" -> "true")
- assertResult(expectedSchema, s"Schema did not match\n$queryString") { schema }
- assertResult(expectedOutput, s"Result did not match\n$queryString") { outputString }
+ val allJoinConfCombinations = Seq(
+ sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf)
+
+ val joinConfs: Seq[Map[String, String]] = if (regenerateGoldenFiles) {
+ require(
+ !sys.env.contains("SPARK_TPCDS_JOIN_CONF"),
+ "'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'")
+ Seq(sortMergeJoinConf)
+ } else {
+ sys.env.get("SPARK_TPCDS_JOIN_CONF").map { s =>
+ val p = new java.util.Properties()
+ p.load(new java.io.StringReader(s))
+ Seq(p.asScala.toMap)
+ }.getOrElse(allJoinConfCombinations)
}
+ assert(joinConfs.nonEmpty)
+ joinConfs.foreach(conf => require(
+ allJoinConfCombinations.contains(conf),
+ s"Join configurations [$conf] should be one of $allJoinConfCombinations"))
+
if (tpcdsDataPath.nonEmpty) {
tpcdsQueries.foreach { name =>
val queryString = resourceToString(s"tpcds/$name.sql",
classLoader = Thread.currentThread().getContextClassLoader)
test(name) {
val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out")
- runQuery(queryString, goldenFile)
+ joinConfs.foreach { conf =>
+ System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368
+ runQuery(queryString, goldenFile, conf)
+ }
}
}
@@ -149,7 +214,10 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp
classLoader = Thread.currentThread().getContextClassLoader)
test(s"$name-v2.7") {
val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out")
- runQuery(queryString, goldenFile)
+ joinConfs.foreach { conf =>
+ System.gc() // SPARK-37368
+ runQuery(queryString, goldenFile, conf)
+ }
}
}
} else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCHBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCHBase.scala
new file mode 100644
index 0000000000..e7edd1090b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCHBase.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+trait TPCHBase extends TPCBase {
+
+ override def createTables(): Unit = {
+ tpchCreateTable.values.foreach { sql =>
+ spark.sql(sql)
+ }
+ }
+
+ override def dropTables(): Unit = {
+ tpchCreateTable.keys.foreach { tableName =>
+ spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true)
+ }
+ }
+
+ val tpchCreateTable = Map(
+ "orders" ->
+ """
+ |CREATE TABLE `orders` (
+ |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING,
+ |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING,
+ |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING)
+ |USING parquet
+ """.stripMargin,
+ "nation" ->
+ """
+ |CREATE TABLE `nation` (
+ |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING)
+ |USING parquet
+ """.stripMargin,
+ "region" ->
+ """
+ |CREATE TABLE `region` (
+ |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING)
+ |USING parquet
+ """.stripMargin,
+ "part" ->
+ """
+ |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING,
+ |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING,
+ |`p_retailprice` DECIMAL(10,0), `p_comment` STRING)
+ |USING parquet
+ """.stripMargin,
+ "partsupp" ->
+ """
+ |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT,
+ |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING)
+ |USING parquet
+ """.stripMargin,
+ "customer" ->
+ """
+ |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING,
+ |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0),
+ |`c_mktsegment` STRING, `c_comment` STRING)
+ |USING parquet
+ """.stripMargin,
+ "supplier" ->
+ """
+ |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING,
+ |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING)
+ |USING parquet
+ """.stripMargin,
+ "lineitem" ->
+ """
+ |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT,
+ |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0),
+ |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING,
+ |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE,
+ |`l_shipinstruct` STRING, `l_shipmode` STRING, `l_comment` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ val tpchQueries = Seq(
+ "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
+ "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", "q22")
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala
index ba99e18714..89dfcefb1b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala
@@ -23,79 +23,7 @@ import org.apache.spark.sql.catalyst.util.resourceToString
* This test suite ensures all the TPC-H queries can be successfully analyzed, optimized
* and compiled without hitting the max iteration threshold.
*/
-class TPCHQuerySuite extends BenchmarkQueryTest {
-
- override def beforeAll(): Unit = {
- super.beforeAll()
-
- sql(
- """
- |CREATE TABLE `orders` (
- |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING,
- |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING,
- |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING)
- |USING parquet
- """.stripMargin)
-
- sql(
- """
- |CREATE TABLE `nation` (
- |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING)
- |USING parquet
- """.stripMargin)
-
- sql(
- """
- |CREATE TABLE `region` (
- |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING)
- |USING parquet
- """.stripMargin)
-
- sql(
- """
- |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING,
- |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING,
- |`p_retailprice` DECIMAL(10,0), `p_comment` STRING)
- |USING parquet
- """.stripMargin)
-
- sql(
- """
- |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT,
- |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING)
- |USING parquet
- """.stripMargin)
-
- sql(
- """
- |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING,
- |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0),
- |`c_mktsegment` STRING, `c_comment` STRING)
- |USING parquet
- """.stripMargin)
-
- sql(
- """
- |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING,
- |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING)
- |USING parquet
- """.stripMargin)
-
- sql(
- """
- |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT,
- |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0),
- |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING,
- |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE,
- |`l_shipinstruct` STRING, `l_shipmode` STRING, `l_comment` STRING)
- |USING parquet
- """.stripMargin)
- }
-
- val tpchQueries = Seq(
- "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
- "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", "q22")
-
+class TPCHQuerySuite extends BenchmarkQueryTest with TPCHBase {
tpchQueries.foreach { name =>
val queryString = resourceToString(s"tpch/$name.sql",
classLoader = Thread.currentThread().getContextClassLoader)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index aad1060f11..912811bfda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -424,7 +424,7 @@ class UDFSuite extends QueryTest with SharedSparkSession {
("N", Integer.valueOf(3), null)).toDF("a", "b", "c")
val udf1 = udf((a: String, b: Int, c: Any) => a + b + c)
- val df = input.select(udf1('a, 'b, 'c))
+ val df = input.select(udf1(Symbol("a"), 'b, 'c))
checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
// test Java UDF. Java UDF can't have primitive inputs, as it's generic typed.
@@ -554,7 +554,7 @@ class UDFSuite extends QueryTest with SharedSparkSession {
spark.udf.register("buildLocalDateInstantType",
udf((d: LocalDate, i: Instant) => LocalDateInstantType(d, i)))
checkAnswer(df.selectExpr(s"buildLocalDateInstantType(d, i) as di")
- .select('di.cast(StringType)),
+ .select(Symbol("di").cast(StringType)),
Row(s"{$expectedDate, $expectedInstant}") :: Nil)
// test null cases
@@ -584,7 +584,7 @@ class UDFSuite extends QueryTest with SharedSparkSession {
spark.udf.register("buildTimestampInstantType",
udf((t: Timestamp, i: Instant) => TimestampInstantType(t, i)))
checkAnswer(df.selectExpr("buildTimestampInstantType(t, i) as ti")
- .select('ti.cast(StringType)),
+ .select(Symbol("ti").cast(StringType)),
Row(s"{$expectedTimestamp, $expectedInstant}"))
// test null cases
@@ -730,7 +730,8 @@ class UDFSuite extends QueryTest with SharedSparkSession {
.select(lit(50).as("a"))
.select(struct("a").as("col"))
val error = intercept[AnalysisException](df.select(myUdf(Column("col"))))
- assert(error.getMessage.contains("cannot resolve 'b' given input columns: [a]"))
+ assert(error.getErrorClass == "MISSING_COLUMN")
+ assert(error.messageParameters.sameElements(Array("b", "a")))
}
test("wrong order of input fields for case class") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
index e6f0426428..1d7af84ef6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
@@ -190,5 +190,55 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess
}
}
+ test("SPARK-36607: Support BooleanType in UnwrapCastInBinaryComparison") {
+ // If ANSI mode is on, Spark disallows comparing Int with Boolean.
+ if (!conf.ansiEnabled) {
+ withTable(t) {
+ Seq(Some(true), Some(false), None).toDF().write.saveAsTable(t)
+ val df = spark.table(t)
+
+ checkAnswer(df.where("value = -1"), Seq.empty)
+ checkAnswer(df.where("value = 0"), Row(false))
+ checkAnswer(df.where("value = 1"), Row(true))
+ checkAnswer(df.where("value = 2"), Seq.empty)
+ checkAnswer(df.where("value <=> -1"), Seq.empty)
+ checkAnswer(df.where("value <=> 0"), Row(false))
+ checkAnswer(df.where("value <=> 1"), Row(true))
+ checkAnswer(df.where("value <=> 2"), Seq.empty)
+ }
+ }
+ }
+
+ test("SPARK-39476: Should not unwrap cast from Long to Double/Float") {
+ withTable(t) {
+ Seq((6470759586864300301L))
+ .toDF("c1").write.saveAsTable(t)
+ val df = spark.table(t)
+
+ checkAnswer(
+ df.where("cast(c1 as double) == cast(6470759586864300301L as double)")
+ .select("c1"),
+ Row(6470759586864300301L))
+
+ checkAnswer(
+ df.where("cast(c1 as float) == cast(6470759586864300301L as float)")
+ .select("c1"),
+ Row(6470759586864300301L))
+ }
+ }
+
+ test("SPARK-39476: Should not unwrap cast from Integer to Float") {
+ withTable(t) {
+ Seq((33554435))
+ .toDF("c1").write.saveAsTable(t)
+ val df = spark.table(t)
+
+ checkAnswer(
+ df.where("cast(c1 as float) == cast(33554435 as float)")
+ .select("c1"),
+ Row(33554435))
+ }
+ }
+
private def decimal(v: BigDecimal): Decimal = Decimal(v, 5, 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index cc52b6d8a1..729312c3e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -82,14 +82,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
}
test("register user type: MyDenseVector for MyLabeledPoint") {
- val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
+ val labels: RDD[Double] = pointsRDD.select(Symbol("label")).rdd.map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))
val features: RDD[TestUDT.MyDenseVector] =
- pointsRDD.select('features).rdd.map { case Row(v: TestUDT.MyDenseVector) => v }
+ pointsRDD.select(Symbol("features")).rdd.map { case Row(v: TestUDT.MyDenseVector) => v }
val featuresArrays: Array[TestUDT.MyDenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.1, 1.0))))
@@ -137,8 +137,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
val df = Seq((1, vec)).toDF("int", "vec")
assert(vec === df.collect()(0).getAs[TestUDT.MyDenseVector](1))
assert(vec === df.take(1)(0).getAs[TestUDT.MyDenseVector](1))
- checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec))
- checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec))
+ checkAnswer(df.limit(1).groupBy(Symbol("int")).agg(first(Symbol("vec"))), Row(1, vec))
+ checkAnswer(df.orderBy(Symbol("int")).limit(1).groupBy(Symbol("int"))
+ .agg(first(Symbol("vec"))), Row(1, vec))
}
test("UDTs with JSON") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
index 1b0898fbc1..19f3f86c94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
@@ -1070,7 +1070,7 @@ trait AlterTableTests extends SharedSparkSession {
}
}
- test("AlterTable: drop column must exist") {
+ test("AlterTable: drop column must exist if required") {
val t = s"${catalogAndNamespace}table_name"
withTable(t) {
sql(s"CREATE TABLE $t (id int) USING $v2Format")
@@ -1080,10 +1080,15 @@ trait AlterTableTests extends SharedSparkSession {
}
assert(exc.getMessage.contains("Missing field data"))
+
+ // with if exists it should pass
+ sql(s"ALTER TABLE $t DROP COLUMN IF EXISTS data")
+ val table = getTableMetadata(fullTableName(t))
+ assert(table.schema == new StructType().add("id", IntegerType))
}
}
- test("AlterTable: nested drop column must exist") {
+ test("AlterTable: nested drop column must exist if required") {
val t = s"${catalogAndNamespace}table_name"
withTable(t) {
sql(s"CREATE TABLE $t (id int) USING $v2Format")
@@ -1093,6 +1098,27 @@ trait AlterTableTests extends SharedSparkSession {
}
assert(exc.getMessage.contains("Missing field point.x"))
+
+ // with if exists it should pass
+ sql(s"ALTER TABLE $t DROP COLUMN IF EXISTS point.x")
+ val table = getTableMetadata(fullTableName(t))
+ assert(table.schema == new StructType().add("id", IntegerType))
+
+ }
+ }
+
+ test("AlterTable: drop mixed existing/non-existing columns using IF EXISTS") {
+ val t = s"${catalogAndNamespace}table_name"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id int, name string, points array<struct<x: double, y: double>>) " +
+ s"USING $v2Format")
+
+ // with if exists it should pass
+ sql(s"ALTER TABLE $t DROP COLUMNS IF EXISTS " +
+ s"names, name, points.element.z, id, points.element.x")
+ val table = getTableMetadata(fullTableName(t))
+ assert(table.schema == new StructType()
+ .add("points", ArrayType(StructType(Seq(StructField("y", DoubleType))))))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
index 91ac7db335..98d95e48f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
@@ -83,10 +81,10 @@ class DataSourceV2DataFrameSessionCatalogSuite
test("saveAsTable passes path and provider information properly") {
val t1 = "prop_table"
withTable(t1) {
- spark.range(20).write.format(v2Format).option("path", "abc").saveAsTable(t1)
+ spark.range(20).write.format(v2Format).option("path", "/abc").saveAsTable(t1)
val cat = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val tableInfo = cat.loadTable(Identifier.of(Array("default"), t1))
- assert(tableInfo.properties().get("location") === "abc")
+ assert(tableInfo.properties().get("location") === "file:/abc")
assert(tableInfo.properties().get("provider") === v2Format)
}
}
@@ -97,7 +95,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): InMemoryTable = {
+ properties: java.util.Map[String, String]): InMemoryTable = {
new InMemoryTable(name, schema, partitions, properties)
}
@@ -210,7 +208,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
verifyTable(t1, df)
// Check that appends are by name
- df.select('data, 'id).write.format(v2Format).mode("append").saveAsTable(t1)
+ df.select(Symbol("data"), Symbol("id")).write.format(v2Format).mode("append").saveAsTable(t1)
verifyTable(t1, df.union(df))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 951d787571..03dcfcf7dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -21,7 +21,7 @@ import java.util.Collections
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -93,7 +93,7 @@ class DataSourceV2DataFrameSuite
assert(spark.table(t1).count() === 0)
// appends are by name not by position
- df.select('data, 'id).write.mode("append").saveAsTable(t1)
+ df.select(Symbol("data"), Symbol("id")).write.mode("append").saveAsTable(t1)
checkAnswer(spark.table(t1), df)
}
}
@@ -207,4 +207,50 @@ class DataSourceV2DataFrameSuite
assert(options.get(optionName) === "false")
}
}
+
+ test("CTAS and RTAS should take write options") {
+
+ var plan: LogicalPlan = null
+ val listener = new QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ plan = qe.analyzed
+ }
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+ }
+
+ try {
+ spark.listenerManager.register(listener)
+
+ val t1 = "testcat.ns1.ns2.tbl"
+
+ val df1 = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+ df1.write.option("option1", "20").saveAsTable(t1)
+
+ sparkContext.listenerBus.waitUntilEmpty()
+ plan match {
+ case o: CreateTableAsSelect =>
+ assert(o.writeOptions == Map("option1" -> "20"))
+ case other =>
+          fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query, " +
+            s"got ${other.getClass.getName}: $plan")
+ }
+ checkAnswer(spark.table(t1), df1)
+
+ val df2 = Seq((1L, "d"), (2L, "e"), (3L, "f")).toDF("id", "data")
+ df2.write.option("option2", "30").mode("overwrite").saveAsTable(t1)
+
+ sparkContext.listenerBus.waitUntilEmpty()
+ plan match {
+ case o: ReplaceTableAsSelect =>
+ assert(o.writeOptions == Map("option2" -> "30"))
+ case other =>
+          fail(s"Expected to parse ${classOf[ReplaceTableAsSelect].getName} from query, " +
+            s"got ${other.getClass.getName}: $plan")
+ }
+
+ checkAnswer(spark.table(t1), df2)
+ } finally {
+ spark.listenerManager.unregister(listener)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index ace66199f3..92a5c55210 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
import java.util.Collections
import test.org.apache.spark.sql.connector.catalog.functions._
@@ -37,7 +36,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
- private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
+ private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String]
private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = {
catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn)
@@ -53,10 +52,73 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) {
assert(intercept[AnalysisException](
sql("SELECT testcat.strlen('abc')").collect()
- ).getMessage.contains("is not a FunctionCatalog"))
+ ).getMessage.contains("Catalog testcat does not support functions"))
}
}
+ test("DESCRIBE FUNCTION: only support session catalog") {
+ addFunction(Identifier.of(Array.empty, "abc"), new JavaStrLen(new JavaStrLenNoImpl))
+
+ val e = intercept[AnalysisException] {
+ sql("DESCRIBE FUNCTION testcat.abc")
+ }
+ assert(e.message.contains("Catalog testcat does not support functions"))
+
+ val e1 = intercept[AnalysisException] {
+ sql("DESCRIBE FUNCTION default.ns1.ns2.fun")
+ }
+ assert(e1.message.contains("requires a single-part namespace"))
+ }
+
+ test("SHOW FUNCTIONS: only support session catalog") {
+ addFunction(Identifier.of(Array.empty, "abc"), new JavaStrLen(new JavaStrLenNoImpl))
+
+ val e = intercept[AnalysisException] {
+ sql(s"SHOW FUNCTIONS LIKE testcat.abc")
+ }
+ assert(e.message.contains("Catalog testcat does not support functions"))
+ }
+
+ test("DROP FUNCTION: only support session catalog") {
+ addFunction(Identifier.of(Array.empty, "abc"), new JavaStrLen(new JavaStrLenNoImpl))
+
+ val e = intercept[AnalysisException] {
+ sql("DROP FUNCTION testcat.abc")
+ }
+ assert(e.message.contains("Catalog testcat does not support DROP FUNCTION"))
+
+ val e1 = intercept[AnalysisException] {
+ sql("DROP FUNCTION default.ns1.ns2.fun")
+ }
+ assert(e1.message.contains("requires a single-part namespace"))
+ }
+
+ test("CREATE FUNCTION: only support session catalog") {
+ val e = intercept[AnalysisException] {
+ sql("CREATE FUNCTION testcat.ns1.ns2.fun as 'f'")
+ }
+ assert(e.message.contains("Catalog testcat does not support CREATE FUNCTION"))
+
+ val e1 = intercept[AnalysisException] {
+ sql("CREATE FUNCTION default.ns1.ns2.fun as 'f'")
+ }
+ assert(e1.message.contains("requires a single-part namespace"))
+ }
+
+ test("REFRESH FUNCTION: only support session catalog") {
+ addFunction(Identifier.of(Array.empty, "abc"), new JavaStrLen(new JavaStrLenNoImpl))
+
+ val e = intercept[AnalysisException] {
+ sql("REFRESH FUNCTION testcat.abc")
+ }
+ assert(e.message.contains("Catalog testcat does not support REFRESH FUNCTION"))
+
+ val e1 = intercept[AnalysisException] {
+ sql("REFRESH FUNCTION default.ns1.ns2.fun")
+ }
+ assert(e1.message.contains("requires a single-part namespace"))
+ }
+
test("built-in with non-function catalog should still work") {
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat",
"spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
index 44fbc639a5..95624f3f61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connector
-import org.apache.spark.sql.{DataFrame, Row, SaveMode}
+import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, Table, TableCatalog}
class DataSourceV2SQLSessionCatalogSuite
@@ -64,22 +64,6 @@ class DataSourceV2SQLSessionCatalogSuite
}
}
- test("SPARK-31624: SHOW TBLPROPERTIES working with V2 tables and the session catalog") {
- val t1 = "tbl"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format TBLPROPERTIES " +
- "(key='v', key2='v2')")
-
- checkAnswer(sql(s"SHOW TBLPROPERTIES $t1"), Seq(Row("key", "v"), Row("key2", "v2")))
-
- checkAnswer(sql(s"SHOW TBLPROPERTIES $t1('key')"), Row("key", "v"))
-
- checkAnswer(
- sql(s"SHOW TBLPROPERTIES $t1('keyX')"),
- Row("keyX", s"Table default.$t1 does not have property: keyX"))
- }
- }
-
test("SPARK-33651: allow CREATE EXTERNAL TABLE without LOCATION") {
withTable("t") {
val prop = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY + "=true"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 68a0f8b282..7470911c9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -18,15 +18,16 @@
package org.apache.spark.sql.connector
import java.sql.Timestamp
-import java.time.LocalDate
+import java.time.{Duration, LocalDate, Period}
import scala.collection.JavaConverters._
+import scala.concurrent.duration.MICROSECONDS
-import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException}
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
@@ -125,7 +126,7 @@ class DataSourceV2SQLSuite
" PARTITIONED BY (id)" +
" TBLPROPERTIES ('bar'='baz')" +
" COMMENT 'this is a test table'" +
- " LOCATION '/tmp/testcat/table_name'")
+ " LOCATION 'file:/tmp/testcat/table_name'")
val descriptionDf = spark.sql("DESCRIBE TABLE EXTENDED testcat.table_name")
assert(descriptionDf.schema.map(field => (field.name, field.dataType))
=== Seq(
@@ -148,7 +149,7 @@ class DataSourceV2SQLSuite
Array("# Detailed Table Information", "", ""),
Array("Name", "testcat.table_name", ""),
Array("Comment", "this is a test table", ""),
- Array("Location", "/tmp/testcat/table_name", ""),
+ Array("Location", "file:/tmp/testcat/table_name", ""),
Array("Provider", "foo", ""),
Array(TableCatalog.PROP_OWNER.capitalize, defaultUser, ""),
Array("Table Properties", "[bar=baz]", "")))
@@ -173,9 +174,10 @@ class DataSourceV2SQLSuite
Row("data_type", "string"),
Row("comment", "hello")))
- assertAnalysisError(
+ assertAnalysisErrorClass(
s"DESCRIBE $t invalid_col",
- "cannot resolve 'invalid_col' given input columns: [testcat.tbl.data, testcat.tbl.id]")
+ "MISSING_COLUMN",
+ Array("invalid_col", "testcat.tbl.id, testcat.tbl.data"))
}
}
@@ -339,13 +341,15 @@ class DataSourceV2SQLSuite
}
test("CTAS/RTAS: invalid schema if has interval type") {
- Seq("CREATE", "REPLACE").foreach { action =>
- val e1 = intercept[AnalysisException](
- sql(s"$action TABLE table_name USING $v2Format as select interval 1 day"))
- assert(e1.getMessage.contains(s"Cannot use interval type in the table schema."))
- val e2 = intercept[AnalysisException](
- sql(s"$action TABLE table_name USING $v2Format as select array(interval 1 day)"))
- assert(e2.getMessage.contains(s"Cannot use interval type in the table schema."))
+ withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") {
+ Seq("CREATE", "REPLACE").foreach { action =>
+ val e1 = intercept[AnalysisException](
+ sql(s"$action TABLE table_name USING $v2Format as select interval 1 day"))
+ assert(e1.getMessage.contains(s"Cannot use interval type in the table schema."))
+ val e2 = intercept[AnalysisException](
+ sql(s"$action TABLE table_name USING $v2Format as select array(interval 1 day)"))
+ assert(e2.getMessage.contains(s"Cannot use interval type in the table schema."))
+ }
}
}
@@ -404,6 +408,83 @@ class DataSourceV2SQLSuite
}
}
+ test("SPARK-36850: CreateTableAsSelect partitions can be specified using " +
+ "PARTITIONED BY and/or CLUSTERED BY") {
+ val identifier = "testcat.table_name"
+ val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"),
+ (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4")
+ df.createOrReplaceTempView("source_table")
+ withTable(identifier) {
+ spark.sql(s"CREATE TABLE $identifier USING foo PARTITIONED BY (id) " +
+ s"CLUSTERED BY (data1, data2, data3, data4) INTO 4 BUCKETS AS SELECT * FROM source_table")
+ val describe = spark.sql(s"DESCRIBE $identifier")
+ val part1 = describe
+ .filter("col_name = 'Part 0'")
+ .select("data_type").head.getString(0)
+ assert(part1 === "id")
+ val part2 = describe
+ .filter("col_name = 'Part 1'")
+ .select("data_type").head.getString(0)
+ assert(part2 === "bucket(4, data1, data2, data3, data4)")
+ }
+ }
+
+ test("SPARK-36850: ReplaceTableAsSelect partitions can be specified using " +
+ "PARTITIONED BY and/or CLUSTERED BY") {
+ val identifier = "testcat.table_name"
+ val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"),
+ (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4")
+ df.createOrReplaceTempView("source_table")
+ withTable(identifier) {
+ spark.sql(s"CREATE TABLE $identifier USING foo " +
+ "AS SELECT id FROM source")
+ spark.sql(s"REPLACE TABLE $identifier USING foo PARTITIONED BY (id) " +
+ s"CLUSTERED BY (data1, data2) SORTED by (data3, data4) INTO 4 BUCKETS " +
+ s"AS SELECT * FROM source_table")
+ val describe = spark.sql(s"DESCRIBE $identifier")
+ val part1 = describe
+ .filter("col_name = 'Part 0'")
+ .select("data_type").head.getString(0)
+ assert(part1 === "id")
+ val part2 = describe
+ .filter("col_name = 'Part 1'")
+ .select("data_type").head.getString(0)
+ assert(part2 === "sorted_bucket(data1, data2, 4, data3, data4)")
+ }
+ }
+
+ test("SPARK-37545: CreateTableAsSelect should store location as qualified") {
+ val basicIdentifier = "testcat.table_name"
+ val atomicIdentifier = "testcat_atomic.table_name"
+ Seq(basicIdentifier, atomicIdentifier).foreach { identifier =>
+ withTable(identifier) {
+ spark.sql(s"CREATE TABLE $identifier USING foo LOCATION '/tmp/foo' " +
+ "AS SELECT id FROM source")
+ val location = spark.sql(s"DESCRIBE EXTENDED $identifier")
+ .filter("col_name = 'Location'")
+ .select("data_type").head.getString(0)
+ assert(location === "file:/tmp/foo")
+ }
+ }
+ }
+
+ test("SPARK-37546: ReplaceTableAsSelect should store location as qualified") {
+ val basicIdentifier = "testcat.table_name"
+ val atomicIdentifier = "testcat_atomic.table_name"
+ Seq(basicIdentifier, atomicIdentifier).foreach { identifier =>
+ withTable(identifier) {
+ spark.sql(s"CREATE TABLE $identifier USING foo LOCATION '/tmp/foo' " +
+ "AS SELECT id, data FROM source")
+ spark.sql(s"REPLACE TABLE $identifier USING foo LOCATION '/tmp/foo' " +
+ "AS SELECT id FROM source")
+ val location = spark.sql(s"DESCRIBE EXTENDED $identifier")
+ .filter("col_name = 'Location'")
+ .select("data_type").head.getString(0)
+ assert(location === "file:/tmp/foo")
+ }
+ }
+ }
+
test("ReplaceTableAsSelect: basic v2 implementation.") {
val basicCatalog = catalog("testcat").asTableCatalog
val atomicCatalog = catalog("testcat_atomic").asTableCatalog
@@ -962,7 +1043,8 @@ class DataSourceV2SQLSuite
val ex = intercept[AnalysisException] {
sql(s"SELECT ns1.ns2.ns3.tbl.id from $t")
}
- assert(ex.getMessage.contains("cannot resolve 'ns1.ns2.ns3.tbl.id"))
+ assert(ex.getErrorClass == "MISSING_COLUMN")
+ assert(ex.messageParameters.head == "ns1.ns2.ns3.tbl.id")
}
}
@@ -1015,76 +1097,7 @@ class DataSourceV2SQLSuite
sql("SHOW VIEWS FROM testcat")
}
- assert(exception.getMessage.contains("Catalog testcat doesn't support SHOW VIEWS," +
- " only SessionCatalog supports this command."))
- }
-
- test("CreateNameSpace: basic tests") {
- // Session catalog is used.
- withNamespace("ns") {
- sql("CREATE NAMESPACE ns")
- testShowNamespaces("SHOW NAMESPACES", Seq("default", "ns"))
- }
-
- // V2 non-session catalog is used.
- withNamespace("testcat.ns1.ns2") {
- sql("CREATE NAMESPACE testcat.ns1.ns2")
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1"))
- testShowNamespaces("SHOW NAMESPACES IN testcat.ns1", Seq("ns1.ns2"))
- }
-
- withNamespace("testcat.test") {
- withTempDir { tmpDir =>
- val path = tmpDir.getCanonicalPath
- sql(s"CREATE NAMESPACE testcat.test LOCATION '$path'")
- val metadata =
- catalog("testcat").asNamespaceCatalog.loadNamespaceMetadata(Array("test")).asScala
- val catalogPath = metadata(SupportsNamespaces.PROP_LOCATION)
- assert(catalogPath.equals(catalogPath))
- }
- }
- }
-
- test("CreateNameSpace: test handling of 'IF NOT EXIST'") {
- withNamespace("testcat.ns1") {
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1")
-
- // The 'ns1' namespace already exists, so this should fail.
- val exception = intercept[NamespaceAlreadyExistsException] {
- sql("CREATE NAMESPACE testcat.ns1")
- }
- assert(exception.getMessage.contains("Namespace 'ns1' already exists"))
-
- // The following will be no-op since the namespace already exists.
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1")
- }
- }
-
- test("CreateNameSpace: reserved properties") {
- import SupportsNamespaces._
- withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) {
- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key =>
- val exception = intercept[ParseException] {
- sql(s"CREATE NAMESPACE testcat.reservedTest WITH DBPROPERTIES('$key'='dummyVal')")
- }
- assert(exception.getMessage.contains(s"$key is a reserved namespace property"))
- }
- }
- withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "true")) {
- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key =>
- withNamespace("testcat.reservedTest") {
- sql(s"CREATE NAMESPACE testcat.reservedTest WITH DBPROPERTIES('$key'='foo')")
- assert(sql("DESC NAMESPACE EXTENDED testcat.reservedTest")
- .toDF("k", "v")
- .where("k='Properties'")
- .isEmpty, s"$key is a reserved namespace property and ignored")
- val meta =
- catalog("testcat").asNamespaceCatalog.loadNamespaceMetadata(Array("reservedTest"))
- assert(meta.get(key) == null || !meta.get(key).contains("foo"),
- "reserved properties should not have side effects")
- }
- }
- }
+ assert(exception.getMessage.contains("Catalog testcat does not support views"))
}
test("create/replace/alter table - reserved properties") {
@@ -1155,8 +1168,9 @@ class DataSourceV2SQLSuite
s" ('path'='bar', 'Path'='noop')")
val tableCatalog = catalog("testcat").asTableCatalog
val identifier = Identifier.of(Array(), "reservedTest")
- assert(tableCatalog.loadTable(identifier).properties()
- .get(TableCatalog.PROP_LOCATION) == "foo",
+ val location = tableCatalog.loadTable(identifier).properties()
+ .get(TableCatalog.PROP_LOCATION)
+ assert(location.startsWith("file:") && location.endsWith("foo"),
"path as a table property should not have side effects")
assert(tableCatalog.loadTable(identifier).properties().get("path") == "bar",
"path as a table property should not have side effects")
@@ -1168,151 +1182,6 @@ class DataSourceV2SQLSuite
}
}
- test("DropNamespace: basic tests") {
- // Session catalog is used.
- sql("CREATE NAMESPACE ns")
- testShowNamespaces("SHOW NAMESPACES", Seq("default", "ns"))
- sql("DROP NAMESPACE ns")
- testShowNamespaces("SHOW NAMESPACES", Seq("default"))
-
- // V2 non-session catalog is used.
- sql("CREATE NAMESPACE testcat.ns1")
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1"))
- sql("DROP NAMESPACE testcat.ns1")
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq())
- }
-
- test("DropNamespace: drop non-empty namespace with a non-cascading mode") {
- sql("CREATE TABLE testcat.ns1.table (id bigint) USING foo")
- sql("CREATE TABLE testcat.ns1.ns2.table (id bigint) USING foo")
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1"))
- testShowNamespaces("SHOW NAMESPACES IN testcat.ns1", Seq("ns1.ns2"))
-
- def assertDropFails(): Unit = {
- val e = intercept[SparkException] {
- sql("DROP NAMESPACE testcat.ns1")
- }
- assert(e.getMessage.contains("Cannot drop a non-empty namespace: ns1"))
- }
-
- // testcat.ns1.table is present, thus testcat.ns1 cannot be dropped.
- assertDropFails()
- sql("DROP TABLE testcat.ns1.table")
-
- // testcat.ns1.ns2.table is present, thus testcat.ns1 cannot be dropped.
- assertDropFails()
- sql("DROP TABLE testcat.ns1.ns2.table")
-
- // testcat.ns1.ns2 namespace is present, thus testcat.ns1 cannot be dropped.
- assertDropFails()
- sql("DROP NAMESPACE testcat.ns1.ns2")
-
- // Now that testcat.ns1 is empty, it can be dropped.
- sql("DROP NAMESPACE testcat.ns1")
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq())
- }
-
- test("DropNamespace: drop non-empty namespace with a cascade mode") {
- sql("CREATE TABLE testcat.ns1.table (id bigint) USING foo")
- sql("CREATE TABLE testcat.ns1.ns2.table (id bigint) USING foo")
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1"))
- testShowNamespaces("SHOW NAMESPACES IN testcat.ns1", Seq("ns1.ns2"))
-
- sql("DROP NAMESPACE testcat.ns1 CASCADE")
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq())
- }
-
- test("DropNamespace: test handling of 'IF EXISTS'") {
- sql("DROP NAMESPACE IF EXISTS testcat.unknown")
-
- val exception = intercept[NoSuchNamespaceException] {
- sql("DROP NAMESPACE testcat.ns1")
- }
- assert(exception.getMessage.contains("Namespace 'ns1' not found"))
- }
-
- test("DescribeNamespace using v2 catalog") {
- withNamespace("testcat.ns1.ns2") {
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1.ns2 COMMENT " +
- "'test namespace' LOCATION '/tmp/ns_test'")
- val descriptionDf = sql("DESCRIBE NAMESPACE testcat.ns1.ns2")
- assert(descriptionDf.schema.map(field => (field.name, field.dataType)) ===
- Seq(
- ("info_name", StringType),
- ("info_value", StringType)
- ))
- val description = descriptionDf.collect()
- assert(description === Seq(
- Row("Namespace Name", "ns2"),
- Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"),
- Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test"),
- Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser))
- )
- }
- }
-
- test("ALTER NAMESPACE .. SET PROPERTIES using v2 catalog") {
- withNamespace("testcat.ns1.ns2") {
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1.ns2 COMMENT " +
- "'test namespace' LOCATION '/tmp/ns_test' WITH PROPERTIES ('a'='a','b'='b','c'='c')")
- sql("ALTER NAMESPACE testcat.ns1.ns2 SET PROPERTIES ('a'='b','b'='a')")
- val descriptionDf = sql("DESCRIBE NAMESPACE EXTENDED testcat.ns1.ns2")
- assert(descriptionDf.collect() === Seq(
- Row("Namespace Name", "ns2"),
- Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"),
- Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test"),
- Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser),
- Row("Properties", "((a,b),(b,a),(c,c))"))
- )
- }
- }
-
- test("ALTER NAMESPACE .. SET PROPERTIES reserved properties") {
- import SupportsNamespaces._
- withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) {
- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key =>
- withNamespace("testcat.reservedTest") {
- sql("CREATE NAMESPACE testcat.reservedTest")
- val exception = intercept[ParseException] {
- sql(s"ALTER NAMESPACE testcat.reservedTest SET PROPERTIES ('$key'='dummyVal')")
- }
- assert(exception.getMessage.contains(s"$key is a reserved namespace property"))
- }
- }
- }
- withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "true")) {
- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key =>
- withNamespace("testcat.reservedTest") {
- sql(s"CREATE NAMESPACE testcat.reservedTest")
- sql(s"ALTER NAMESPACE testcat.reservedTest SET PROPERTIES ('$key'='foo')")
- assert(sql("DESC NAMESPACE EXTENDED testcat.reservedTest")
- .toDF("k", "v")
- .where("k='Properties'")
- .isEmpty, s"$key is a reserved namespace property and ignored")
- val meta =
- catalog("testcat").asNamespaceCatalog.loadNamespaceMetadata(Array("reservedTest"))
- assert(meta.get(key) == null || !meta.get(key).contains("foo"),
- "reserved properties should not have side effects")
- }
- }
- }
- }
-
- test("ALTER NAMESPACE .. SET LOCATION using v2 catalog") {
- withNamespace("testcat.ns1.ns2") {
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1.ns2 COMMENT " +
- "'test namespace' LOCATION '/tmp/ns_test_1'")
- sql("ALTER NAMESPACE testcat.ns1.ns2 SET LOCATION '/tmp/ns_test_2'")
- val descriptionDf = sql("DESCRIBE NAMESPACE EXTENDED testcat.ns1.ns2")
- assert(descriptionDf.collect() === Seq(
- Row("Namespace Name", "ns2"),
- Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"),
- Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test_2"),
- Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser))
- )
- }
- }
-
private def testShowNamespaces(
sqlText: String,
expected: Seq[String]): Unit = {
@@ -1335,6 +1204,7 @@ class DataSourceV2SQLSuite
sql("CREATE TABLE testcat2.ns2.ns2_2.table (id bigint) USING foo")
sql("CREATE TABLE testcat2.ns3.ns3_3.table (id bigint) USING foo")
sql("CREATE TABLE testcat2.testcat.table (id bigint) USING foo")
+ sql("CREATE TABLE testcat2.testcat.ns1.ns1_1.table (id bigint) USING foo")
// Catalog is resolved to 'testcat'.
sql("USE testcat.ns1.ns1_1")
@@ -1356,6 +1226,11 @@ class DataSourceV2SQLSuite
assert(catalogManager.currentCatalog.name() == "testcat2")
assert(catalogManager.currentNamespace === Array("testcat"))
+ // Only the namespace is changed (explicit).
+ sql("USE NAMESPACE testcat.ns1.ns1_1")
+ assert(catalogManager.currentCatalog.name() == "testcat2")
+ assert(catalogManager.currentNamespace === Array("testcat", "ns1", "ns1_1"))
+
// Catalog is resolved to `testcat`.
sql("USE testcat")
assert(catalogManager.currentCatalog.name() == "testcat")
@@ -1609,6 +1484,27 @@ class DataSourceV2SQLSuite
}
}
+ test("create table using - with sorted bucket") {
+ val identifier = "testcat.table_name"
+ withTable(identifier) {
+ sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" +
+ s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS")
+ val describe = spark.sql(s"DESCRIBE $identifier")
+ val part1 = describe
+ .filter("col_name = 'Part 0'")
+ .select("data_type").head.getString(0)
+ assert(part1 === "a")
+ val part2 = describe
+ .filter("col_name = 'Part 1'")
+ .select("data_type").head.getString(0)
+ assert(part2 === "b")
+ val part3 = describe
+ .filter("col_name = 'Part 2'")
+ .select("data_type").head.getString(0)
+ assert(part3 === "sorted_bucket(c, d, 4, e, f)")
+ }
+ }
+
test("REFRESH TABLE: v2 table") {
val t = "testcat.ns1.ns2.tbl"
withTable(t) {
@@ -1805,12 +1701,20 @@ class DataSourceV2SQLSuite
"Table or view not found")
// UPDATE non-existing column
- assertAnalysisError(
+ assertAnalysisErrorClass(
s"UPDATE $t SET dummy='abc'",
- "cannot resolve")
- assertAnalysisError(
+ "MISSING_COLUMN",
+ Array(
+ "dummy",
+ "testcat.ns1.ns2.tbl.p, testcat.ns1.ns2.tbl.id, " +
+ "testcat.ns1.ns2.tbl.age, testcat.ns1.ns2.tbl.name"))
+ assertAnalysisErrorClass(
s"UPDATE $t SET name='abc' WHERE dummy=1",
- "cannot resolve")
+ "MISSING_COLUMN",
+ Array(
+ "dummy",
+ "testcat.ns1.ns2.tbl.p, testcat.ns1.ns2.tbl.id, " +
+ "testcat.ns1.ns2.tbl.age, testcat.ns1.ns2.tbl.name"))
// UPDATE is not implemented yet.
val e = intercept[UnsupportedOperationException] {
@@ -1961,109 +1865,6 @@ class DataSourceV2SQLSuite
}
}
- test("SPARK-33898: SHOW CREATE TABLE AS SERDE") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo")
- val e = intercept[AnalysisException] {
- sql(s"SHOW CREATE TABLE $t AS SERDE")
- }
- assert(e.message.contains(s"SHOW CREATE TABLE AS SERDE is not supported for v2 tables."))
- }
- }
-
- test("SPARK-33898: SHOW CREATE TABLE") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(
- s"""
- |CREATE TABLE $t (
- | a bigint NOT NULL,
- | b bigint,
- | c bigint,
- | `extra col` ARRAY<INT>,
- | `<another>` STRUCT<x: INT, y: ARRAY<BOOLEAN>>
- |)
- |USING foo
- |OPTIONS (
- | from = 0,
- | to = 1,
- | via = 2)
- |COMMENT 'This is a comment'
- |TBLPROPERTIES ('prop1' = '1', 'prop2' = '2', 'prop3' = 3, 'prop4' = 4)
- |PARTITIONED BY (a)
- |LOCATION '/tmp'
- """.stripMargin)
- val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t")
- assert(showDDL === Array(
- "CREATE TABLE testcat.ns1.ns2.tbl (",
- "`a` BIGINT NOT NULL,",
- "`b` BIGINT,",
- "`c` BIGINT,",
- "`extra col` ARRAY<INT>,",
- "`<another>` STRUCT<`x`: INT, `y`: ARRAY<BOOLEAN>>)",
- "USING foo",
- "OPTIONS(",
- "'from' = '0',",
- "'to' = '1',",
- "'via' = '2')",
- "PARTITIONED BY (a)",
- "COMMENT 'This is a comment'",
- "LOCATION '/tmp'",
- "TBLPROPERTIES(",
- "'prop1' = '1',",
- "'prop2' = '2',",
- "'prop3' = '3',",
- "'prop4' = '4')"
- ))
- }
- }
-
- test("SPARK-33898: SHOW CREATE TABLE WITH AS SELECT") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(
- s"""
- |CREATE TABLE $t
- |USING foo
- |AS SELECT 1 AS a, "foo" AS b
- """.stripMargin)
- val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t")
- assert(showDDL === Array(
- "CREATE TABLE testcat.ns1.ns2.tbl (",
- "`a` INT,",
- "`b` STRING)",
- "USING foo"
- ))
- }
- }
-
- test("SPARK-33898: SHOW CREATE TABLE PARTITIONED BY Transforms") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(
- s"""
- |CREATE TABLE $t (a INT, b STRING, ts TIMESTAMP) USING foo
- |PARTITIONED BY (
- | a,
- | bucket(16, b),
- | years(ts),
- | months(ts),
- | days(ts),
- | hours(ts))
- """.stripMargin)
- val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t")
- assert(showDDL === Array(
- "CREATE TABLE testcat.ns1.ns2.tbl (",
- "`a` INT,",
- "`b` STRING,",
- "`ts` TIMESTAMP)",
- "USING foo",
- "PARTITIONED BY (a, bucket(16, b), years(ts), months(ts), days(ts), hours(ts))"
- ))
- }
- }
-
test("CACHE/UNCACHE TABLE") {
val t = "testcat.ns1.ns2.tbl"
withTable(t) {
@@ -2117,120 +1918,7 @@ class DataSourceV2SQLSuite
val e = intercept[AnalysisException] {
sql(s"CREATE VIEW $v AS SELECT 1")
}
- assert(e.message.contains("CREATE VIEW is only supported with v1 tables"))
- }
-
- test("SHOW TBLPROPERTIES: v2 table") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- val user = "andrew"
- val status = "new"
- val provider = "foo"
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING $provider " +
- s"TBLPROPERTIES ('user'='$user', 'status'='$status')")
-
- val properties = sql(s"SHOW TBLPROPERTIES $t")
-
- val schema = new StructType()
- .add("key", StringType, nullable = false)
- .add("value", StringType, nullable = false)
-
- val expected = Seq(
- Row("status", status),
- Row("user", user))
-
- assert(properties.schema === schema)
- assert(expected === properties.collect())
- }
- }
-
- test("SHOW TBLPROPERTIES(key): v2 table") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- val user = "andrew"
- val status = "new"
- val provider = "foo"
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING $provider " +
- s"TBLPROPERTIES ('user'='$user', 'status'='$status')")
-
- val properties = sql(s"SHOW TBLPROPERTIES $t ('status')")
-
- val expected = Seq(Row("status", status))
-
- assert(expected === properties.collect())
- }
- }
-
- test("SHOW TBLPROPERTIES(key): v2 table, key not found") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- val nonExistingKey = "nonExistingKey"
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo " +
- s"TBLPROPERTIES ('user'='andrew', 'status'='new')")
-
- val properties = sql(s"SHOW TBLPROPERTIES $t ('$nonExistingKey')")
-
- val expected = Seq(Row(nonExistingKey, s"Table $t does not have property: $nonExistingKey"))
-
- assert(expected === properties.collect())
- }
- }
-
- test("DESCRIBE FUNCTION: only support session catalog") {
- val e = intercept[AnalysisException] {
- sql("DESCRIBE FUNCTION testcat.ns1.ns2.fun")
- }
- assert(e.message.contains("function is only supported in v1 catalog"))
-
- val e1 = intercept[AnalysisException] {
- sql("DESCRIBE FUNCTION default.ns1.ns2.fun")
- }
- assert(e1.message.contains("requires a single-part namespace"))
- }
-
- test("SHOW FUNCTIONS not valid v1 namespace") {
- val function = "testcat.ns1.ns2.fun"
-
- val e = intercept[AnalysisException] {
- sql(s"SHOW FUNCTIONS LIKE $function")
- }
- assert(e.message.contains("function is only supported in v1 catalog"))
- }
-
- test("DROP FUNCTION: only support session catalog") {
- val e = intercept[AnalysisException] {
- sql("DROP FUNCTION testcat.ns1.ns2.fun")
- }
- assert(e.message.contains("function is only supported in v1 catalog"))
-
- val e1 = intercept[AnalysisException] {
- sql("DROP FUNCTION default.ns1.ns2.fun")
- }
- assert(e1.message.contains("requires a single-part namespace"))
- }
-
- test("CREATE FUNCTION: only support session catalog") {
- val e = intercept[AnalysisException] {
- sql("CREATE FUNCTION testcat.ns1.ns2.fun as 'f'")
- }
- assert(e.message.contains("function is only supported in v1 catalog"))
-
- val e1 = intercept[AnalysisException] {
- sql("CREATE FUNCTION default.ns1.ns2.fun as 'f'")
- }
- assert(e1.message.contains("requires a single-part namespace"))
- }
-
- test("REFRESH FUNCTION: only support session catalog") {
- val e = intercept[AnalysisException] {
- sql("REFRESH FUNCTION testcat.ns1.ns2.fun")
- }
- assert(e.message.contains("function is only supported in v1 catalog"))
-
- val e1 = intercept[AnalysisException] {
- sql("REFRESH FUNCTION default.ns1.ns2.fun")
- }
- assert(e1.message.contains("requires a single-part namespace"))
+ assert(e.message.contains("Catalog testcat does not support views"))
}
test("global temp view should not be masked by v2 catalog") {
@@ -2457,14 +2145,6 @@ class DataSourceV2SQLSuite
.head().getString(1) === expectedComment)
}
- test("SPARK-30799: temp view name can't contain catalog name") {
- val sessionCatalogName = CatalogManager.SESSION_CATALOG_NAME
- val e2 = intercept[AnalysisException] {
- sql(s"CREATE TEMP VIEW $sessionCatalogName.v AS SELECT 1")
- }
- assert(e2.message.contains("It is not allowed to add database prefix"))
- }
-
test("SPARK-31015: star expression should work for qualified column names for v2 tables") {
val t = "testcat.ns1.ns2.tbl"
withTable(t) {
@@ -2524,100 +2204,6 @@ class DataSourceV2SQLSuite
}
}
- test("SPARK-31255: Project a metadata column") {
- val t1 = s"${catalogAndNamespace}table"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
- val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1")
- val dfQuery = spark.table(t1).select("id", "data", "index", "_partition")
-
- Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
- }
- }
- }
-
- test("SPARK-31255: Projects data column when metadata column has the same name") {
- val t1 = s"${catalogAndNamespace}table"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (index bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, index), index)")
- sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
-
- val sqlQuery = spark.sql(s"SELECT index, data, _partition FROM $t1")
- val dfQuery = spark.table(t1).select("index", "data", "_partition")
-
- Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
- }
- }
- }
-
- test("SPARK-31255: * expansion does not include metadata columns") {
- val t1 = s"${catalogAndNamespace}table"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
-
- val sqlQuery = spark.sql(s"SELECT * FROM $t1")
- val dfQuery = spark.table(t1)
-
- Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(3, "c"), Row(2, "b"), Row(1, "a")))
- }
- }
- }
-
- test("SPARK-31255: metadata column should only be produced when necessary") {
- val t1 = s"${catalogAndNamespace}table"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
-
- val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0")
- val dfQuery = spark.table(t1).filter("index = 0")
-
- Seq(sqlQuery, dfQuery).foreach { query =>
- assert(query.schema.fieldNames.toSeq == Seq("id", "data"))
- }
- }
- }
-
- test("SPARK-34547: metadata columns are resolved last") {
- val t1 = s"${catalogAndNamespace}tableOne"
- val t2 = "t2"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
- withTempView(t2) {
- sql(s"CREATE TEMPORARY VIEW $t2 AS SELECT * FROM " +
- s"VALUES (1, -1), (2, -2), (3, -3) AS $t2(id, index)")
-
- val sqlQuery = spark.sql(s"SELECT $t1.id, $t2.id, data, index, $t1.index, $t2.index FROM " +
- s"$t1 JOIN $t2 WHERE $t1.id = $t2.id")
- val t1Table = spark.table(t1)
- val t2Table = spark.table(t2)
- val dfQuery = t1Table.join(t2Table, t1Table.col("id") === t2Table.col("id"))
- .select(s"$t1.id", s"$t2.id", "data", "index", s"$t1.index", s"$t2.index")
-
- Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query,
- Seq(
- Row(1, 1, "a", -1, 0, -1),
- Row(2, 2, "b", -2, 0, -2),
- Row(3, 3, "c", -3, 0, -3)
- )
- )
- }
- }
- }
- }
-
test("SPARK-33505: insert into partitioned table") {
val t = "testpart.ns1.ns2.tbl"
withTable(t) {
@@ -2702,27 +2288,6 @@ class DataSourceV2SQLSuite
}
}
- test("SPARK-34555: Resolve DataFrame metadata column") {
- val tbl = s"${catalogAndNamespace}table"
- withTable(tbl) {
- sql(s"CREATE TABLE $tbl (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
- val table = spark.table(tbl)
- val dfQuery = table.select(
- table.col("id"),
- table.col("data"),
- table.col("index"),
- table.col("_partition")
- )
-
- checkAnswer(
- dfQuery,
- Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
- )
- }
- }
-
test("SPARK-34561: drop/add columns to a dataset of `DESCRIBE TABLE`") {
val tbl = s"${catalogAndNamespace}tbl"
withTable(tbl) {
@@ -2766,125 +2331,265 @@ class DataSourceV2SQLSuite
}
}
- test("SPARK-34577: drop/add columns to a dataset of `DESCRIBE NAMESPACE`") {
- withNamespace("ns") {
- sql("CREATE NAMESPACE ns")
- val description = sql(s"DESCRIBE NAMESPACE ns")
- val noCommentDataset = description.drop("info_name")
- val expectedSchema = new StructType()
- .add(
- name = "info_value",
- dataType = StringType,
- nullable = true,
- metadata = new MetadataBuilder()
- .putString("comment", "value of the namespace info").build())
- assert(noCommentDataset.schema === expectedSchema)
- val isNullDataset = noCommentDataset
- .withColumn("is_null", noCommentDataset("info_value").isNull)
- assert(isNullDataset.schema === expectedSchema.add("is_null", BooleanType, false))
- }
+ test("SPARK-36481: Test for SET CATALOG statement") {
+ val catalogManager = spark.sessionState.catalogManager
+ assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME)
+
+ sql("SET CATALOG testcat")
+ assert(catalogManager.currentCatalog.name() == "testcat")
+
+ sql("SET CATALOG testcat2")
+ assert(catalogManager.currentCatalog.name() == "testcat2")
+
+ val errMsg = intercept[CatalogNotFoundException] {
+ sql("SET CATALOG not_exist_catalog")
+ }.getMessage
+ assert(errMsg.contains("Catalog 'not_exist_catalog' plugin class not found"))
}
- test("SPARK-34923: do not propagate metadata columns through Project") {
- val t1 = s"${catalogAndNamespace}table"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ test("SPARK-35973: ShowCatalogs") {
+ val schema = new StructType()
+ .add("catalog", StringType, nullable = false)
- assertThrows[AnalysisException] {
- sql(s"SELECT index, _partition from (SELECT id, data FROM $t1)")
- }
- assertThrows[AnalysisException] {
- spark.table(t1).select("id", "data").select("index", "_partition")
- }
- }
+ val df = sql("SHOW CATALOGS")
+ assert(df.schema === schema)
+ assert(df.collect === Array(Row("spark_catalog")))
+
+ sql("use testcat")
+ sql("use testpart")
+ sql("use testcat2")
+ assert(sql("SHOW CATALOGS").collect === Array(
+ Row("spark_catalog"), Row("testcat"), Row("testcat2"), Row("testpart")))
+
+ assert(sql("SHOW CATALOGS LIKE 'test*'").collect === Array(
+ Row("testcat"), Row("testcat2"), Row("testpart")))
+
+ assert(sql("SHOW CATALOGS LIKE 'testcat*'").collect === Array(
+ Row("testcat"), Row("testcat2")))
}
- test("SPARK-34923: do not propagate metadata columns through View") {
- val t1 = s"${catalogAndNamespace}table"
- val view = "view"
+ test("CREATE INDEX should fail") {
+ val t = "testcat.tbl"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id bigint, data string COMMENT 'hello') USING foo")
+ val e1 = intercept[AnalysisException] {
+ sql(s"CREATE index i1 ON $t(non_exist)")
+ }
+ assert(e1.getMessage.contains(s"Missing field non_exist in table $t"))
- withTable(t1) {
- withTempView(view) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
- sql(s"CACHE TABLE $view AS SELECT * FROM $t1")
- assertThrows[AnalysisException] {
- sql(s"SELECT index, _partition FROM $view")
+ val e2 = intercept[AnalysisException] {
+ sql(s"CREATE index i1 ON $t(id)")
+ }
+ assert(e2.getMessage.contains(s"CreateIndex is not supported in this table $t."))
+ }
+ }
+
+ test("SPARK-37294: insert ANSI intervals into a table partitioned by the interval columns") {
+ val tbl = "testpart.interval_table"
+ Seq(PartitionOverwriteMode.DYNAMIC, PartitionOverwriteMode.STATIC).foreach { mode =>
+ withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> mode.toString) {
+ withTable(tbl) {
+ sql(
+ s"""
+ |CREATE TABLE $tbl (i INT, part1 INTERVAL YEAR, part2 INTERVAL DAY) USING $v2Format
+ |PARTITIONED BY (part1, part2)
+ """.stripMargin)
+
+ sql(
+ s"""ALTER TABLE $tbl ADD PARTITION (
+ |part1 = INTERVAL '2' YEAR,
+ |part2 = INTERVAL '3' DAY)""".stripMargin)
+ sql(s"INSERT OVERWRITE TABLE $tbl SELECT 1, INTERVAL '2' YEAR, INTERVAL '3' DAY")
+ sql(s"INSERT INTO TABLE $tbl SELECT 4, INTERVAL '5' YEAR, INTERVAL '6' DAY")
+ sql(
+ s"""
+ |INSERT INTO $tbl
+ | PARTITION (part1 = INTERVAL '8' YEAR, part2 = INTERVAL '9' DAY)
+ |SELECT 7""".stripMargin)
+
+ checkAnswer(
+ spark.table(tbl),
+ Seq(Row(1, Period.ofYears(2), Duration.ofDays(3)),
+ Row(4, Period.ofYears(5), Duration.ofDays(6)),
+ Row(7, Period.ofYears(8), Duration.ofDays(9))))
}
}
}
}
- test("SPARK-34923: propagate metadata columns through Filter") {
- val t1 = s"${catalogAndNamespace}table"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ test("Check HasPartitionKey from InMemoryPartitionTable") {
+ val t = "testpart.tbl"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id string) USING foo PARTITIONED BY (key int)")
+ val table = catalog("testpart").asTableCatalog
+ .loadTable(Identifier.of(Array(), "tbl"))
+ .asInstanceOf[InMemoryPartitionTable]
+
+ sql(s"INSERT INTO $t VALUES ('a', 1), ('b', 2), ('c', 3)")
+ var partKeys = table.data.map(_.partitionKey().getInt(0))
+ assert(partKeys.length == 3)
+ assert(partKeys.toSet == Set(1, 2, 3))
+
+ sql(s"ALTER TABLE $t DROP PARTITION (key=3)")
+ partKeys = table.data.map(_.partitionKey().getInt(0))
+ assert(partKeys.length == 2)
+ assert(partKeys.toSet == Set(1, 2))
+
+ sql(s"ALTER TABLE $t ADD PARTITION (key=4)")
+ partKeys = table.data.map(_.partitionKey().getInt(0))
+ assert(partKeys.length == 3)
+ assert(partKeys.toSet == Set(1, 2, 4))
+
+ sql(s"INSERT INTO $t VALUES ('c', 3), ('e', 5)")
+ partKeys = table.data.map(_.partitionKey().getInt(0))
+ assert(partKeys.length == 5)
+ assert(partKeys.toSet == Set(1, 2, 3, 4, 5))
+ }
+ }
+
+ test("time travel") {
+ sql("use testcat")
+ // The testing in-memory table simply appends the version/timestamp to the table name when
+ // looking up tables.
+ val t1 = "testcat.tSnapshot123456789"
+ val t2 = "testcat.t2345678910"
+ withTable(t1, t2) {
+ sql(s"CREATE TABLE $t1 (id int) USING foo")
+ sql(s"CREATE TABLE $t2 (id int) USING foo")
+
+ sql(s"INSERT INTO $t1 VALUES (1)")
+ sql(s"INSERT INTO $t1 VALUES (2)")
+ sql(s"INSERT INTO $t2 VALUES (3)")
+ sql(s"INSERT INTO $t2 VALUES (4)")
+
+ assert(sql("SELECT * FROM t VERSION AS OF 'Snapshot123456789'").collect
+ === Array(Row(1), Row(2)))
+ assert(sql("SELECT * FROM t VERSION AS OF 2345678910").collect
+ === Array(Row(3), Row(4)))
+ }
+
+ val ts1 = DateTimeUtils.stringToTimestampAnsi(
+ UTF8String.fromString("2019-01-29 00:37:58"),
+ DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
+ val ts2 = DateTimeUtils.stringToTimestampAnsi(
+ UTF8String.fromString("2021-01-29 00:00:00"),
+ DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
+ val ts1InSeconds = MICROSECONDS.toSeconds(ts1).toString
+ val ts2InSeconds = MICROSECONDS.toSeconds(ts2).toString
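+ // The table names embed ts1/ts2 in microseconds, matching the timestamp suffix the testing catalog appends when resolving TIMESTAMP AS OF.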
+ val t3 = s"testcat.t$ts1"
+ val t4 = s"testcat.t$ts2"
+
+ withTable(t3, t4) {
+ sql(s"CREATE TABLE $t3 (id int) USING foo")
+ sql(s"CREATE TABLE $t4 (id int) USING foo")
+
+ sql(s"INSERT INTO $t3 VALUES (5)")
+ sql(s"INSERT INTO $t3 VALUES (6)")
+ sql(s"INSERT INTO $t4 VALUES (7)")
+ sql(s"INSERT INTO $t4 VALUES (8)")
+
+ assert(sql("SELECT * FROM t TIMESTAMP AS OF '2019-01-29 00:37:58'").collect
+ === Array(Row(5), Row(6)))
+ assert(sql("SELECT * FROM t TIMESTAMP AS OF '2021-01-29 00:00:00'").collect
+ === Array(Row(7), Row(8)))
+ assert(sql(s"SELECT * FROM t TIMESTAMP AS OF $ts1InSeconds").collect
+ === Array(Row(5), Row(6)))
+ assert(sql(s"SELECT * FROM t TIMESTAMP AS OF $ts2InSeconds").collect
+ === Array(Row(7), Row(8)))
+ assert(sql(s"SELECT * FROM t FOR SYSTEM_TIME AS OF $ts1InSeconds").collect
+ === Array(Row(5), Row(6)))
+ assert(sql(s"SELECT * FROM t FOR SYSTEM_TIME AS OF $ts2InSeconds").collect
+ === Array(Row(7), Row(8)))
+ assert(sql("SELECT * FROM t TIMESTAMP AS OF make_date(2021, 1, 29)").collect
+ === Array(Row(7), Row(8)))
+ assert(sql("SELECT * FROM t TIMESTAMP AS OF to_timestamp('2021-01-29 00:00:00')").collect
+ === Array(Row(7), Row(8)))
- val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 WHERE id > 1")
- val dfQuery = spark.table(t1).where("id > 1").select("id", "data", "index", "_partition")
+ val e1 = intercept[AnalysisException](
+ sql("SELECT * FROM t TIMESTAMP AS OF INTERVAL 1 DAY").collect()
+ )
+ assert(e1.message.contains("is not a valid timestamp expression for time travel"))
- Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
- }
- }
- }
+ val e2 = intercept[AnalysisException](
+ sql("SELECT * FROM t TIMESTAMP AS OF 'abc'").collect()
+ )
+ assert(e2.message.contains("is not a valid timestamp expression for time travel"))
- test("SPARK-34923: propagate metadata columns through Sort") {
- val t1 = s"${catalogAndNamespace}table"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ val e3 = intercept[AnalysisException](
+ sql("SELECT * FROM t TIMESTAMP AS OF current_user()").collect()
+ )
+ assert(e3.message.contains("is not a valid timestamp expression for time travel"))
- val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 ORDER BY id")
- val dfQuery = spark.table(t1).orderBy("id").select("id", "data", "index", "_partition")
+ val e4 = intercept[AnalysisException](
+ sql("SELECT * FROM t TIMESTAMP AS OF CAST(rand() AS STRING)").collect()
+ )
+ assert(e4.message.contains("is not a valid timestamp expression for time travel"))
- Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
- }
- }
- }
+ val e5 = intercept[AnalysisException](
+ sql("SELECT * FROM t TIMESTAMP AS OF abs(true)").collect()
+ )
+ assert(e5.message.contains("cannot resolve 'abs(true)' due to data type mismatch"))
- test("SPARK-34923: propagate metadata columns through RepartitionBy") {
- val t1 = s"${catalogAndNamespace}table"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
- val sqlQuery = spark.sql(
- s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $t1")
- val tbl = spark.table(t1)
- val dfQuery = tbl.repartitionByRange(3, tbl.col("id"))
- .select("id", "data", "index", "_partition")
-
- Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
- }
+ val e6 = intercept[AnalysisException](
+ sql("SELECT * FROM parquet.`/the/path` VERSION AS OF 1")
+ )
+ assert(e6.message.contains("Cannot time travel path-based tables"))
+
+ val e7 = intercept[AnalysisException](
+ sql("WITH x AS (SELECT 1) SELECT * FROM x VERSION AS OF 1")
+ )
+ assert(e7.message.contains("Cannot time travel subqueries from WITH clause"))
}
}
- test("SPARK-34923: propagate metadata columns through SubqueryAlias") {
- val t1 = s"${catalogAndNamespace}table"
- val sbq = "sbq"
- withTable(t1) {
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-
- val sqlQuery = spark.sql(
- s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $t1 as $sbq")
- val dfQuery = spark.table(t1).as(sbq).select(
- s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
+ test("SPARK-37827: put build-in properties into V1Table.properties to adapt v2 command") {
+ val t = "tbl"
+ withTable(t) {
+ sql(
+ s"""
+ |CREATE TABLE $t (
+ | a bigint,
+ | b bigint
+ |)
+ |using parquet
+ |OPTIONS (
+ | from = 0,
+ | to = 1)
+ |COMMENT 'This is a comment'
+ |TBLPROPERTIES ('prop1' = '1', 'prop2' = '2')
+ |PARTITIONED BY (a)
+ |LOCATION '/tmp'
+ """.stripMargin)
- Seq(sqlQuery, dfQuery).foreach { query =>
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
- }
+ val table = spark.sessionState.catalogManager.v2SessionCatalog.asTableCatalog
+ .loadTable(Identifier.of(Array("default"), t))
+ val properties = table.properties
+ assert(properties.get(TableCatalog.PROP_PROVIDER) == "parquet")
+ assert(properties.get(TableCatalog.PROP_COMMENT) == "This is a comment")
+ assert(properties.get(TableCatalog.PROP_LOCATION) == "file:///tmp")
+ assert(properties.containsKey(TableCatalog.PROP_OWNER))
+ assert(properties.get(TableCatalog.PROP_EXTERNAL) == "true")
+ assert(properties.get(s"${TableCatalog.OPTION_PREFIX}from") == "0")
+ assert(properties.get(s"${TableCatalog.OPTION_PREFIX}to") == "1")
+ assert(properties.get("prop1") == "1")
+ assert(properties.get("prop2") == "2")
+ }
+ }
+
+ test("SPARK-41154: Incorrect relation caching for queries with time travel spec") {
+ sql("use testcat")
+ val t1 = "testcat.t1"
+ val t2 = "testcat.t2"
+ withTable(t1, t2) {
+ sql(s"CREATE TABLE $t1 USING foo AS SELECT 1 as c")
+ sql(s"CREATE TABLE $t2 USING foo AS SELECT 2 as c")
+ assert(
+ sql("""
+ |SELECT * FROM t VERSION AS OF '1'
+ |UNION ALL
+ |SELECT * FROM t VERSION AS OF '2'
+ |""".stripMargin
+ ).collect() === Array(Row(1), Row(2)))
}
}
@@ -2895,15 +2600,24 @@ class DataSourceV2SQLSuite
assert(e.message.contains(s"$sqlCommand is not supported for v2 tables"))
}
- private def assertAnalysisError(sqlStatement: String, expectedError: String): Unit = {
- val errMsg = intercept[AnalysisException] {
+ private def assertAnalysisError(
+ sqlStatement: String,
+ expectedError: String): Unit = {
+ val ex = intercept[AnalysisException] {
sql(sqlStatement)
- }.getMessage
- assert(errMsg.contains(expectedError))
+ }
+ assert(ex.getMessage.contains(expectedError))
}
- private def getShowCreateDDL(showCreateTableSql: String): Array[String] = {
- sql(showCreateTableSql).head().getString(0).split("\n").map(_.trim)
+ private def assertAnalysisErrorClass(
+ sqlStatement: String,
+ expectedErrorClass: String,
+ expectedErrorMessageParameters: Array[String]): Unit = {
+ val ex = intercept[AnalysisException] {
+ sql(sqlStatement)
+ }
+ assert(ex.getErrorClass == expectedErrorClass)
+ assert(ex.messageParameters.sameElements(expectedErrorMessageParameters))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index b42d48d873..491d27e546 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -18,11 +18,8 @@
package org.apache.spark.sql.connector
import java.io.File
-import java.util
import java.util.OptionalLong
-import scala.collection.JavaConverters._
-
import test.org.apache.spark.sql.connector._
import org.apache.spark.SparkException
@@ -30,14 +27,16 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
-import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.expressions.{FieldReference, Literal, Transform}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read._
-import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{Filter, GreaterThan}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -54,6 +53,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}.head
}
+ private def getBatchWithV2Filter(query: DataFrame): AdvancedBatchWithV2Filter = {
+ query.queryExecution.executedPlan.collect {
+ case d: BatchScanExec =>
+ d.batch.asInstanceOf[AdvancedBatchWithV2Filter]
+ }.head
+ }
+
private def getJavaBatch(query: DataFrame): JavaAdvancedDataSourceV2.AdvancedBatch = {
query.queryExecution.executedPlan.collect {
case d: BatchScanExec =>
@@ -61,13 +67,21 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}.head
}
+ private def getJavaBatchWithV2Filter(
+ query: DataFrame): JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter = {
+ query.queryExecution.executedPlan.collect {
+ case d: BatchScanExec =>
+ d.batch.asInstanceOf[JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter]
+ }.head
+ }
+
test("simplest implementation") {
Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
- checkAnswer(df.select('j), (0 until 10).map(i => Row(-i)))
- checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i)))
+ checkAnswer(df.select(Symbol("j")), (0 until 10).map(i => Row(-i)))
+ checkAnswer(df.filter(Symbol("i") > 5), (6 until 10).map(i => Row(i, -i)))
}
}
}
@@ -78,7 +92,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val df = spark.read.format(cls.getName).load()
checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
- val q1 = df.select('j)
+ val q1 = df.select(Symbol("j"))
checkAnswer(q1, (0 until 10).map(i => Row(-i)))
if (cls == classOf[AdvancedDataSourceV2]) {
val batch = getBatch(q1)
@@ -90,7 +104,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
assert(batch.requiredSchema.fieldNames === Seq("j"))
}
- val q2 = df.filter('i > 3)
+ val q2 = df.filter(Symbol("i") > 3)
checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
if (cls == classOf[AdvancedDataSourceV2]) {
val batch = getBatch(q2)
@@ -102,7 +116,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
assert(batch.requiredSchema.fieldNames === Seq("i", "j"))
}
- val q3 = df.select('i).filter('i > 6)
+ val q3 = df.select(Symbol("i")).filter(Symbol("i") > 6)
checkAnswer(q3, (7 until 10).map(i => Row(i)))
if (cls == classOf[AdvancedDataSourceV2]) {
val batch = getBatch(q3)
@@ -114,16 +128,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
assert(batch.requiredSchema.fieldNames === Seq("i"))
}
- val q4 = df.select('j).filter('j < -10)
+ val q4 = df.select(Symbol("j")).filter(Symbol("j") < -10)
checkAnswer(q4, Nil)
if (cls == classOf[AdvancedDataSourceV2]) {
val batch = getBatch(q4)
- // 'j < 10 is not supported by the testing data source.
+ // Symbol("j") < 10 is not supported by the testing data source.
assert(batch.filters.isEmpty)
assert(batch.requiredSchema.fieldNames === Seq("j"))
} else {
val batch = getJavaBatch(q4)
- // 'j < 10 is not supported by the testing data source.
+ // Symbol("j") < 10 is not supported by the testing data source.
assert(batch.filters.isEmpty)
assert(batch.requiredSchema.fieldNames === Seq("j"))
}
@@ -131,13 +145,73 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}
}
+ test("advanced implementation with V2 Filter") {
+ Seq(classOf[AdvancedDataSourceV2WithV2Filter], classOf[JavaAdvancedDataSourceV2WithV2Filter])
+ .foreach { cls =>
+ withClue(cls.getName) {
+ val df = spark.read.format(cls.getName).load()
+ checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
+
+ val q1 = df.select(Symbol("j"))
+ checkAnswer(q1, (0 until 10).map(i => Row(-i)))
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) {
+ val batch = getBatchWithV2Filter(q1)
+ assert(batch.predicates.isEmpty)
+ assert(batch.requiredSchema.fieldNames === Seq("j"))
+ } else {
+ val batch = getJavaBatchWithV2Filter(q1)
+ assert(batch.predicates.isEmpty)
+ assert(batch.requiredSchema.fieldNames === Seq("j"))
+ }
+
+ val q2 = df.filter(Symbol("i") > 3)
+ checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) {
+ val batch = getBatchWithV2Filter(q2)
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i"))
+ assert(batch.requiredSchema.fieldNames === Seq("i", "j"))
+ } else {
+ val batch = getJavaBatchWithV2Filter(q2)
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i"))
+ assert(batch.requiredSchema.fieldNames === Seq("i", "j"))
+ }
+
+ val q3 = df.select(Symbol("i")).filter(Symbol("i") > 6)
+ checkAnswer(q3, (7 until 10).map(i => Row(i)))
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) {
+ val batch = getBatchWithV2Filter(q3)
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i"))
+ assert(batch.requiredSchema.fieldNames === Seq("i"))
+ } else {
+ val batch = getJavaBatchWithV2Filter(q3)
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i"))
+ assert(batch.requiredSchema.fieldNames === Seq("i"))
+ }
+
+ val q4 = df.select(Symbol("j")).filter(Symbol("j") < -10)
+ checkAnswer(q4, Nil)
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) {
+ val batch = getBatchWithV2Filter(q4)
+ // Symbol("j") < 10 is not supported by the testing data source.
+ assert(batch.predicates.isEmpty)
+ assert(batch.requiredSchema.fieldNames === Seq("j"))
+ } else {
+ val batch = getJavaBatchWithV2Filter(q4)
+ // Symbol("j") < 10 is not supported by the testing data source.
+ assert(batch.predicates.isEmpty)
+ assert(batch.requiredSchema.fieldNames === Seq("j"))
+ }
+ }
+ }
+ }
+
test("columnar batch scan implementation") {
Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
checkAnswer(df, (0 until 90).map(i => Row(i, -i)))
- checkAnswer(df.select('j), (0 until 90).map(i => Row(-i)))
- checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i)))
+ checkAnswer(df.select(Symbol("j")), (0 until 90).map(i => Row(-i)))
+ checkAnswer(df.filter(Symbol("i") > 50), (51 until 90).map(i => Row(i, -i)))
}
}
}
@@ -161,45 +235,47 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
"supports external metadata") {
withTempDir { dir =>
val cls = classOf[SupportsExternalMetadataWritableDataSource].getName
- spark.range(10).select('id as 'i, -'id as 'j).write.format(cls)
- .option("path", dir.getCanonicalPath).mode("append").save()
+ spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j"))
+ .write.format(cls).option("path", dir.getCanonicalPath).mode("append").save()
val schema = new StructType().add("i", "long").add("j", "long")
checkAnswer(
spark.read.format(cls).option("path", dir.getCanonicalPath).schema(schema).load(),
- spark.range(10).select('id, -'id))
+ spark.range(10).select(Symbol("id"), -Symbol("id")))
}
}
test("partitioning reporting") {
import org.apache.spark.sql.functions.{count, sum}
- Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls =>
- withClue(cls.getName) {
- val df = spark.read.format(cls.getName).load()
- checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2)))
-
- val groupByColA = df.groupBy('i).agg(sum('j))
- checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4)))
- assert(collectFirst(groupByColA.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
- }.isEmpty)
-
- val groupByColAB = df.groupBy('i, 'j).agg(count("*"))
- checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2)))
- assert(collectFirst(groupByColAB.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
- }.isEmpty)
-
- val groupByColB = df.groupBy('j).agg(sum('i))
- checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
- assert(collectFirst(groupByColB.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
- }.isDefined)
-
- val groupByAPlusB = df.groupBy('i + 'j).agg(count("*"))
- checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
- assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) {
- case e: ShuffleExchangeExec => e
- }.isDefined)
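+ // The shuffle-elimination assertions below rely on Spark honoring the reported KeyGroupedPartitioning, which requires V2 bucketing to be enabled.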
+ withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "true") {
+ Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls =>
+ withClue(cls.getName) {
+ val df = spark.read.format(cls.getName).load()
+ checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2)))
+
+ val groupByColA = df.groupBy(Symbol("i")).agg(sum(Symbol("j")))
+ checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4)))
+ assert(collectFirst(groupByColA.queryExecution.executedPlan) {
+ case e: ShuffleExchangeExec => e
+ }.isEmpty)
+
+ val groupByColAB = df.groupBy(Symbol("i"), Symbol("j")).agg(count("*"))
+ checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2)))
+ assert(collectFirst(groupByColAB.queryExecution.executedPlan) {
+ case e: ShuffleExchangeExec => e
+ }.isEmpty)
+
+ val groupByColB = df.groupBy(Symbol("j")).agg(sum(Symbol("i")))
+ checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
+ assert(collectFirst(groupByColB.queryExecution.executedPlan) {
+ case e: ShuffleExchangeExec => e
+ }.isDefined)
+
+ val groupByAPlusB = df.groupBy(Symbol("i") + Symbol("j")).agg(count("*"))
+ checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
+ assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) {
+ case e: ShuffleExchangeExec => e
+ }.isDefined)
+ }
}
}
}
@@ -233,37 +309,43 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val path = file.getCanonicalPath
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
- spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName)
+ spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j"))
+ .write.format(cls.getName)
.option("path", path).mode("append").save()
checkAnswer(
spark.read.format(cls.getName).option("path", path).load(),
- spark.range(10).select('id, -'id))
+ spark.range(10).select(Symbol("id"), -Symbol("id")))
// default save mode is ErrorIfExists
intercept[AnalysisException] {
- spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName)
+ spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j"))
+ .write.format(cls.getName)
.option("path", path).save()
}
- spark.range(10).select('id as 'i, -'id as 'j).write.mode("append").format(cls.getName)
+ spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j"))
+ .write.mode("append").format(cls.getName)
.option("path", path).save()
checkAnswer(
spark.read.format(cls.getName).option("path", path).load(),
- spark.range(10).union(spark.range(10)).select('id, -'id))
+ spark.range(10).union(spark.range(10)).select(Symbol("id"), -Symbol("id")))
- spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName)
+ spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j"))
+ .write.format(cls.getName)
.option("path", path).mode("overwrite").save()
checkAnswer(
spark.read.format(cls.getName).option("path", path).load(),
- spark.range(5).select('id, -'id))
+ spark.range(5).select(Symbol("id"), -Symbol("id")))
val e = intercept[AnalysisException] {
- spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName)
+ spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j"))
+ .write.format(cls.getName)
.option("path", path).mode("ignore").save()
}
assert(e.message.contains("please use Append or Overwrite modes instead"))
val e2 = intercept[AnalysisException] {
- spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName)
+ spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j"))
+ .write.format(cls.getName)
.option("path", path).mode("error").save()
}
assert(e2.getMessage.contains("please use Append or Overwrite modes instead"))
@@ -280,7 +362,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}
}
// this input data will fail to read midway.
- val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j)
+ val input = spark.range(15).select(failingUdf(Symbol("id")).as(Symbol("i")))
+ .select(Symbol("i"), -Symbol("i") as Symbol("j"))
val e3 = intercept[SparkException] {
input.write.format(cls.getName).option("path", path).mode("overwrite").save()
}
@@ -300,11 +383,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
val numPartition = 6
- spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName)
+ spark.range(0, 10, 1, numPartition)
+ .select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j"))
+ .write.format(cls.getName)
.mode("append").option("path", path).save()
checkAnswer(
spark.read.format(cls.getName).option("path", path).load(),
- spark.range(10).select('id, -'id))
+ spark.range(10).select(Symbol("id"), -Symbol("id")))
assert(SimpleCounter.getCounter == numPartition,
"method onDataWriterCommit should be called as many as the number of partitions")
@@ -321,7 +406,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
test("SPARK-23301: column pruning with arbitrary expressions") {
val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
- val q1 = df.select('i + 1)
+ val q1 = df.select(Symbol("i") + 1)
checkAnswer(q1, (1 until 11).map(i => Row(i)))
val batch1 = getBatch(q1)
assert(batch1.requiredSchema.fieldNames === Seq("i"))
@@ -332,14 +417,14 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
assert(batch2.requiredSchema.isEmpty)
// 'j === -1 can't be pushed down, but we should still be able to do column pruning
- val q3 = df.filter('j === -1).select('j * 2)
+ val q3 = df.filter(Symbol("j") === -1).select(Symbol("j") * 2)
checkAnswer(q3, Row(-2))
val batch3 = getBatch(q3)
assert(batch3.filters.isEmpty)
assert(batch3.requiredSchema.fieldNames === Seq("j"))
// column pruning should work with other operators.
- val q4 = df.sort('i).limit(1).select('i + 1)
+ val q4 = df.sort(Symbol("i")).limit(1).select(Symbol("i") + 1)
checkAnswer(q4, Row(1))
val batch4 = getBatch(q4)
assert(batch4.requiredSchema.fieldNames === Seq("i"))
@@ -361,7 +446,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
checkCanonicalizedOutput(df, 2, 2)
- checkCanonicalizedOutput(df.select('i), 2, 1)
+ checkCanonicalizedOutput(df.select(Symbol("i")), 2, 1)
}
test("SPARK-25425: extra options should override sessions options during reading") {
@@ -400,7 +485,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
withTempView("t1") {
val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
Seq(2, 3).toDF("a").createTempView("t1")
- val df = t2.where("i < (select max(a) from t1)").select('i)
+ val df = t2.where("i < (select max(a) from t1)").select(Symbol("i"))
val subqueries = stripAQEPlan(df.queryExecution.executedPlan).collect {
case p => p.subqueries
}.flatten
@@ -419,8 +504,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
- val q1 = df.select('i).filter('i > 6)
- val q2 = df.select('i).filter('i > 5)
+ val q1 = df.select(Symbol("i")).filter(Symbol("i") > 6)
+ val q2 = df.select(Symbol("i")).filter(Symbol("i") > 5)
val scan1 = getScanExec(q1)
val scan2 = getScanExec(q2)
assert(!scan1.equals(scan2))
@@ -433,7 +518,19 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
// before SPARK-33267 the query below just threw an NPE
- df.select('i).where("i in (1, null)").collect()
+ df.select(Symbol("i")).where("i in (1, null)").collect()
+ }
+ }
+ }
+
+ test("SPARK-35803: Support datasorce V2 in CREATE VIEW USING") {
+ Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls =>
+ withClue(cls.getName) {
+ sql(s"CREATE or REPLACE TEMPORARY VIEW s1 USING ${cls.getName}")
+ checkAnswer(sql("select * from s1"), (0 until 10).map(i => Row(i, -i)))
+ checkAnswer(sql("select j from s1"), (0 until 10).map(i => Row(-i)))
+ checkAnswer(sql("select * from s1 where i > 5"),
+ (6 until 10).map(i => Row(i, -i)))
}
}
}
@@ -466,7 +563,7 @@ abstract class SimpleBatchTable extends Table with SupportsRead {
override def name(): String = this.getClass.toString
- override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava
+ override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ)
}
abstract class SimpleScanBuilder extends ScanBuilder
@@ -489,7 +586,7 @@ trait TestingV2Source extends TableProvider {
override def getTable(
schema: StructType,
partitioning: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
getTable(new CaseInsensitiveStringMap(properties))
}
@@ -597,6 +694,75 @@ class AdvancedBatch(val filters: Array[Filter], val requiredSchema: StructType)
}
}
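+// Variant of AdvancedDataSourceV2 that exercises pushdown of V2 Predicates instead of V1 Filters.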
+class AdvancedDataSourceV2WithV2Filter extends TestingV2Source {
+
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new AdvancedScanBuilderWithV2Filter()
+ }
+ }
+}
+
+class AdvancedScanBuilderWithV2Filter extends ScanBuilder
+ with Scan with SupportsPushDownV2Filters with SupportsPushDownRequiredColumns {
+
+ var requiredSchema = TestingV2Source.schema
+ var predicates = Array.empty[Predicate]
+
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ this.requiredSchema = requiredSchema
+ }
+
+ override def readSchema(): StructType = requiredSchema
+
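+ // Only ">" predicates are accepted for pushdown; the rest are returned so Spark evaluates them after the scan.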
+ override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
+ val (supported, unsupported) = predicates.partition {
+ case p: Predicate if p.name() == ">" => true
+ case _ => false
+ }
+ this.predicates = supported
+ unsupported
+ }
+
+ override def pushedPredicates(): Array[Predicate] = predicates
+
+ override def build(): Scan = this
+
+ override def toBatch: Batch = new AdvancedBatchWithV2Filter(predicates, requiredSchema)
+}
+
+class AdvancedBatchWithV2Filter(
+ val predicates: Array[Predicate],
+ val requiredSchema: StructType) extends Batch {
+
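+ // Plans partitions over the range [0, 10), pruning values at or below the pushed-down lower bound on `i` when one was pushed.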
+ override def planInputPartitions(): Array[InputPartition] = {
+ val lowerBound = predicates.collectFirst {
+ case p: Predicate if p.name().equals(">") =>
+ val value = p.children()(1)
+ assert(value.isInstanceOf[Literal[_]])
+ value.asInstanceOf[Literal[_]]
+ }
+
+ val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition]
+
+ if (lowerBound.isEmpty) {
+ res.append(RangeInputPartition(0, 5))
+ res.append(RangeInputPartition(5, 10))
+ } else if (lowerBound.get.value.asInstanceOf[Integer] < 4) {
+ res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 5))
+ res.append(RangeInputPartition(5, 10))
+ } else if (lowerBound.get.value.asInstanceOf[Integer] < 9) {
+ res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 10))
+ }
+
+ res.toArray
+ }
+
+ override def createReaderFactory(): PartitionReaderFactory = {
+ new AdvancedReaderFactory(requiredSchema)
+ }
+}
+
class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
@@ -640,7 +806,7 @@ class SchemaRequiredDataSource extends TableProvider {
override def getTable(
schema: StructType,
partitioning: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val userGivenSchema = schema
new SimpleBatchTable {
override def schema(): StructType = userGivenSchema
@@ -733,7 +899,8 @@ class PartitionAwareDataSource extends TestingV2Source {
SpecificReaderFactory
}
- override def outputPartitioning(): Partitioning = new MyPartitioning
+ override def outputPartitioning(): Partitioning =
+ new KeyGroupedPartitioning(Array(FieldReference("i")), 2)
}
override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
@@ -741,18 +908,13 @@ class PartitionAwareDataSource extends TestingV2Source {
new MyScanBuilder()
}
}
-
- class MyPartitioning extends Partitioning {
- override def numPartitions(): Int = 2
-
- override def satisfy(distribution: Distribution): Boolean = distribution match {
- case c: ClusteredDistribution => c.clusteredColumns.contains("i")
- case _ => false
- }
- }
}
-case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition
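+// Exposes the first value of `i` as the partition key so the scan can report a KeyGroupedPartitioning on `i`.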
+case class SpecificInputPartition(
+ i: Array[Int],
+ j: Array[Int]) extends InputPartition with HasPartitionKey {
+ override def partitionKey(): InternalRow = InternalRow.fromSeq(Seq(i(0)))
+}
object SpecificReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala
new file mode 100644
index 0000000000..a2cfdde267
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala
@@ -0,0 +1,629 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import java.util.Collections
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog}
+import org.apache.spark.sql.connector.expressions.LogicalExpressions._
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.QueryExecutionListener
+
+abstract class DeleteFromTableSuiteBase
+ extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper {
+
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+ import testImplicits._
+
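+ // Register the row-level-operation catalog as "cat" before each test; it is reset again afterwards.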
+ before {
+ spark.conf.set("spark.sql.catalog.cat", classOf[InMemoryRowLevelOperationTableCatalog].getName)
+ }
+
+ after {
+ spark.sessionState.catalogManager.reset()
+ spark.sessionState.conf.unsetConf("spark.sql.catalog.cat")
+ }
+
+ private val namespace = Array("ns1")
+ private val ident = Identifier.of(namespace, "test_table")
+ private val tableNameAsString = "cat." + ident.toString
+
+ private def catalog: InMemoryRowLevelOperationTableCatalog = {
+ val catalog = spark.sessionState.catalogManager.catalog("cat")
+ catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog]
+ }
+
+ test("EXPLAIN only delete") {
+ createAndInitTable("id INT, dep STRING", """{ "id": 1, "dep": "hr" }""")
+
+ sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE id <= 10")
+
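+ // EXPLAIN alone must not execute the delete, so the row is still present.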
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Nil)
+ }
+
+ test("delete from empty tables") {
+ createTable("id INT, dep STRING")
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1")
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+
+ test("delete with basic filters") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "software") :: Row(3, "hr") :: Nil)
+ }
+
+ test("delete with aliases") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString AS t WHERE t.id <= 1 OR t.dep = 'hr'")
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "software") :: Nil)
+ }
+
+ test("delete with IN predicates") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id IN (1, null)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "software") :: Row(null, "hr") :: Nil)
+ }
+
+ test("delete with NOT IN predicates") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (null, 1)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(2, "software") :: Row(null, "hr") :: Nil)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (1, 10)")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(null, "hr") :: Nil)
+ }
+
+ test("delete with conditions on nested columns") {
+ createAndInitTable("id INT, complex STRUCT<c1:INT,c2:STRING>, dep STRING",
+ """{ "id": 1, "complex": { "c1": 3, "c2": "v1" }, "dep": "hr" }
+ |{ "id": 2, "complex": { "c1": 2, "c2": "v2" }, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"DELETE FROM $tableNameAsString WHERE complex.c1 = id + 2")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, Row(2, "v2"), "software") :: Nil)
+
+ sql(s"DELETE FROM $tableNameAsString t WHERE t.complex.c1 = id")
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+
+ test("delete with IN subqueries") {
+ withTempView("deleted_id", "deleted_dep") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(0), Some(1), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val deletedDepDF = Seq("software", "hr").toDF()
+ deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id IN (SELECT * FROM deleted_id)
+ | AND
+ | dep IN (SELECT * FROM deleted_dep)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ append("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": -1, "dep": "hr" }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(-1, "hr") :: Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id IS NULL
+ | OR
+ | id IN (SELECT value + 2 FROM deleted_id)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(-1, "hr") :: Row(1, "hr") :: Nil)
+
+ append("id INT, dep STRING",
+ """{ "id": null, "dep": "hr" }
+ |{ "id": 2, "dep": "hr" }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(-1, "hr") :: Row(1, "hr") :: Row(2, "hr") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id IN (SELECT value + 2 FROM deleted_id)
+ | AND
+ | dep = 'hr'
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(-1, "hr") :: Row(1, "hr") :: Row(null, "hr") :: Nil)
+ }
+ }
+
+ test("delete with multi-column IN subqueries") {
+ withTempView("deleted_employee") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedEmployeeDF = Seq((None, "hr"), (Some(1), "hr")).toDF()
+ deletedEmployeeDF.createOrReplaceTempView("deleted_employee")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | (id, dep) IN (SELECT * FROM deleted_employee)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(null, "hr") :: Nil)
+ }
+ }
+
+ test("delete with NOT IN subqueries") {
+ withTempView("deleted_id", "deleted_dep") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val deletedDepDF = Seq("software", "hr").toDF()
+ deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id NOT IN (SELECT * FROM deleted_id)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(null, "hr") :: Nil)
+
+ append("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id NOT IN (SELECT * FROM deleted_id)
+ | OR
+ | dep IN ('software', 'hr')
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "hardware") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE
+ | id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)
+ | AND
+ | EXISTS (SELECT 1 FROM deleted_dep WHERE dep = deleted_dep.value)
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "hardware") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | t.id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)
+ | OR
+ | EXISTS (SELECT 1 FROM deleted_dep WHERE t.dep = deleted_dep.value)
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+ }
+
+ test("delete with EXISTS subquery") {
+ withTempView("deleted_id", "deleted_dep") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val deletedDepDF = Seq("software", "hr").toDF()
+ deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value)
+ | AND
+ | EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Nil)
+ }
+ }
+
+ test("delete with NOT EXISTS subquery") {
+ withTempView("deleted_id", "deleted_dep") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val deletedDepDF = Seq("software", "hr").toDF()
+ deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2)
+ | AND
+ | NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(1, "hr") :: Row(null, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(1, "hr") :: Nil)
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)
+ | OR
+ | t.id = 1
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+ }
+
+ test("delete with a scalar subquery") {
+ withTempView("deleted_id") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": null, "dep": "hr" }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(1), Some(100), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ sql(
+ s"""DELETE FROM $tableNameAsString t
+ |WHERE
+ | id <= (SELECT min(value) FROM deleted_id)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(null, "hr") :: Nil)
+ }
+ }
+
+ test("delete refreshes relation cache") {
+ withTempView("temp") {
+ withCache("temp") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 1, "dep": "hardware" }
+ |{ "id": 2, "dep": "hardware" }
+ |{ "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ // define a view on top of the table
+ val query = sql(s"SELECT * FROM $tableNameAsString WHERE id = 1")
+ query.createOrReplaceTempView("temp")
+
+ // cache the view
+ sql("CACHE TABLE temp")
+
+ // verify the view returns expected results
+ checkAnswer(
+ sql("SELECT * FROM temp"),
+ Row(1, "hr") :: Row(1, "hardware") :: Nil)
+
+ // delete some records from the table
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1")
+
+ // verify the delete was successful
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, "hardware") :: Row(3, "hr") :: Nil)
+
+ // verify the view reflects the changes in the table
+ checkAnswer(sql("SELECT * FROM temp"), Nil)
+ }
+ }
+ }
+
+ test("delete with nondeterministic conditions") {
+ createAndInitTable("id INT, dep STRING",
+ """{ "id": 1, "dep": "hr" }
+ |{ "id": 2, "dep": "software" }
+ |{ "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ val e = intercept[AnalysisException] {
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1 AND rand() > 0.5")
+ }
+ assert(e.message.contains("nondeterministic expressions are only allowed"))
+ }
+
+ test("delete without condition executed as delete with filters") {
+ createAndInitTable("id INT, dep INT",
+ """{ "id": 1, "dep": 100 }
+ |{ "id": 2, "dep": 200 }
+ |{ "id": 3, "dep": 100 }
+ |""".stripMargin)
+
+ executeDeleteWithFilters(s"DELETE FROM $tableNameAsString")
+
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil)
+ }
+
+ test("delete with supported predicates gets converted into delete with filters") {
+ createAndInitTable("id INT, dep INT",
+ """{ "id": 1, "dep": 100 }
+ |{ "id": 2, "dep": 200 }
+ |{ "id": 3, "dep": 100 }
+ |""".stripMargin)
+
+ executeDeleteWithFilters(s"DELETE FROM $tableNameAsString WHERE dep = 100")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 200) :: Nil)
+ }
+
+ test("delete with unsupported predicates cannot be converted into delete with filters") {
+ createAndInitTable("id INT, dep INT",
+ """{ "id": 1, "dep": 100 }
+ |{ "id": 2, "dep": 200 }
+ |{ "id": 3, "dep": 100 }
+ |""".stripMargin)
+
+ executeDeleteWithRewrite(s"DELETE FROM $tableNameAsString WHERE dep = 100 OR dep < 200")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 200) :: Nil)
+ }
+
+ test("delete with subquery cannot be converted into delete with filters") {
+ withTempView("deleted_id") {
+ createAndInitTable("id INT, dep INT",
+ """{ "id": 1, "dep": 100 }
+ |{ "id": 2, "dep": 200 }
+ |{ "id": 3, "dep": 100 }
+ |""".stripMargin)
+
+ val deletedIdDF = Seq(Some(1), Some(100), None).toDF()
+ deletedIdDF.createOrReplaceTempView("deleted_id")
+
+ val q = s"DELETE FROM $tableNameAsString WHERE dep = 100 AND id IN (SELECT * FROM deleted_id)"
+ executeDeleteWithRewrite(q)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Row(2, 200) :: Row(3, 100) :: Nil)
+ }
+ }
+
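+ // Creates the test table in the in-memory row-level-operation catalog, partitioned by `dep`.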
+ private def createTable(schemaString: String): Unit = {
+ val schema = StructType.fromDDL(schemaString)
+ val tableProps = Collections.emptyMap[String, String]
+ catalog.createTable(ident, schema, Array(identity(reference(Seq("dep")))), tableProps)
+ }
+
+ private def createAndInitTable(schemaString: String, jsonData: String): Unit = {
+ createTable(schemaString)
+ append(schemaString, jsonData)
+ }
+
+ private def append(schemaString: String, jsonData: String): Unit = {
+ val df = toDF(jsonData, schemaString)
+ df.coalesce(1).writeTo(tableNameAsString).append()
+ }
+
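+ // Parses newline-delimited JSON rows into a DataFrame, applying the schema when one is given.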
+ private def toDF(jsonData: String, schemaString: String = null): DataFrame = {
+ val jsonRows = jsonData.split("\\n").filter(str => str.trim.nonEmpty)
+ val jsonDS = spark.createDataset(jsonRows)(Encoders.STRING)
+ if (schemaString == null) {
+ spark.read.json(jsonDS)
+ } else {
+ spark.read.schema(schemaString).json(jsonDS)
+ }
+ }
+
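+ // Runs the DELETE and verifies it executed as a direct delete with filters (DeleteFromTableExec).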
+ private def executeDeleteWithFilters(query: String): Unit = {
+ val executedPlan = executeAndKeepPlan {
+ sql(query)
+ }
+
+ executedPlan match {
+ case _: DeleteFromTableExec =>
+ // OK
+ case other =>
+ fail("unexpected executed plan: " + other)
+ }
+ }
+
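+ // Runs the DELETE and verifies it was rewritten into a group-based replace-data plan (ReplaceDataExec).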
+ private def executeDeleteWithRewrite(query: String): Unit = {
+ val executedPlan = executeAndKeepPlan {
+ sql(query)
+ }
+
+ executedPlan match {
+ case _: ReplaceDataExec =>
+ // OK
+ case other =>
+ fail("unexpected executed plan: " + other)
+ }
+ }
+
+ // executes an operation and keeps the executed plan
+ private def executeAndKeepPlan(func: => Unit): SparkPlan = {
+ var executedPlan: SparkPlan = null
+
+ val listener = new QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ executedPlan = qe.executedPlan
+ }
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ }
+ }
+ spark.listenerManager.register(listener)
+
+ func
+
+ sparkContext.listenerBus.waitUntilEmpty()
+
+ stripAQEPlan(executedPlan)
+ }
+}
+
+class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
new file mode 100644
index 0000000000..f4317e6327
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connector
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.connector.catalog.InMemoryCatalog
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.test.SharedSparkSession
+
+abstract class DistributionAndOrderingSuiteBase
+ extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
+ }
+
+ override def afterAll(): Unit = {
+ spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat")
+ super.afterAll()
+ }
+
+ protected val resolver: Resolver = conf.resolver
+
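+ // Rebuilds the given partitioning with its unresolved attribute references resolved against the plan output.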
+ protected def resolvePartitioning[T <: QueryPlan[T]](
+ partitioning: Partitioning,
+ plan: QueryPlan[T]): Partitioning = partitioning match {
+ case HashPartitioning(exprs, numPartitions) =>
+ HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
+ case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) =>
+ KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions,
+ partitionValues)
+ case PartitioningCollection(partitionings) =>
+ PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
+ case RangePartitioning(ordering, numPartitions) =>
+ RangePartitioning(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder]), numPartitions)
+ case p @ SinglePartition =>
+ p
+ case p: UnknownPartitioning =>
+ p
+ case p =>
+ fail(s"unexpected partitioning: $p")
+ }
+
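+ // Rebuilds the given distribution with its unresolved attribute references resolved against the plan output.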
+ protected def resolveDistribution[T <: QueryPlan[T]](
+ distribution: physical.Distribution,
+ plan: QueryPlan[T]): physical.Distribution = distribution match {
+ case physical.ClusteredDistribution(clustering, numPartitions, _) =>
+ physical.ClusteredDistribution(clustering.map(resolveAttrs(_, plan)), numPartitions)
+ case physical.OrderedDistribution(ordering) =>
+ physical.OrderedDistribution(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder]))
+ case physical.UnspecifiedDistribution =>
+ physical.UnspecifiedDistribution
+ case d =>
+ fail(s"unexpected distribution: $d")
+ }
+
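+ // Replaces UnresolvedAttribute references in the expression with matching attributes from the plan output.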
+ protected def resolveAttrs[T <: QueryPlan[T]](
+ expr: catalyst.expressions.Expression,
+ plan: QueryPlan[T]): catalyst.expressions.Expression = {
+
+ expr.transform {
+ case UnresolvedAttribute(Seq(attrName)) =>
+ plan.output.find(attr => resolver(attr.name, attrName)).get
+ case UnresolvedAttribute(nameParts) =>
+ val attrName = nameParts.mkString(".")
+ fail(s"cannot resolve a nested attr: $attrName")
+ }
+ }
+
+ protected def attr(name: String): UnresolvedAttribute = {
+ UnresolvedAttribute(name)
+ }
+
+ protected def catalog: InMemoryCatalog = {
+ val catalog = spark.sessionState.catalogManager.catalog("testcat")
+ catalog.asTableCatalog.asInstanceOf[InMemoryCatalog]
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
index be0dae2563..cfc8b2cc84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
@@ -16,7 +16,6 @@
*/
package org.apache.spark.sql.connector
-import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkConf
@@ -56,7 +55,7 @@ class DummyReadOnlyFileTable extends Table with SupportsRead {
}
override def capabilities(): java.util.Set[TableCapability] =
- Set(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA).asJava
+ java.util.EnumSet.of(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA)
}
class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 {
@@ -79,7 +78,7 @@ class DummyWriteOnlyFileTable extends Table with SupportsWrite {
throw new AnalysisException("Dummy file writer")
override def capabilities(): java.util.Set[TableCapability] =
- Set(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava
+ java.util.EnumSet.of(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA)
}
class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
@@ -185,7 +184,7 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession {
val df = spark.read.format(format).load(path.getCanonicalPath)
checkAnswer(df, inputData.toDF())
assert(
- df.queryExecution.executedPlan.find(_.isInstanceOf[FileSourceScanExec]).isDefined)
+ df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec]))
}
} finally {
spark.listenerManager.unregister(listener)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
index 0dee48fbb5..85904bbf12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
@@ -259,7 +259,7 @@ trait InsertIntoSQLOnlyTests
verifyTable(t1, spark.emptyDataFrame)
assert(exc.getMessage.contains(
- "PARTITION clause cannot contain a non-partition column name"))
+ "PARTITION clause cannot contain the non-partition column"))
assert(exc.getMessage.contains("id"))
assert(exc.getErrorClass == "NON_PARTITION_COLUMN")
}
@@ -276,28 +276,12 @@ trait InsertIntoSQLOnlyTests
verifyTable(t1, spark.emptyDataFrame)
assert(exc.getMessage.contains(
- "PARTITION clause cannot contain a non-partition column name"))
+ "PARTITION clause cannot contain the non-partition column"))
assert(exc.getMessage.contains("data"))
assert(exc.getErrorClass == "NON_PARTITION_COLUMN")
}
}
- test("InsertInto: IF PARTITION NOT EXISTS not supported") {
- val t1 = s"${catalogAndNamespace}tbl"
- withTableAndData(t1) { view =>
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)")
-
- val exc = intercept[AnalysisException] {
- sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 1) IF NOT EXISTS SELECT * FROM $view")
- }
-
- verifyTable(t1, spark.emptyDataFrame)
- assert(exc.getMessage.contains("Cannot write, IF NOT EXISTS is not supported for table"))
- assert(exc.getMessage.contains(t1))
- assert(exc.getErrorClass == "IF_PARTITION_NOT_EXISTS_UNSUPPORTED")
- }
- }
-
test("InsertInto: overwrite - dynamic clause - static mode") {
withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) {
val t1 = s"${catalogAndNamespace}tbl"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
new file mode 100644
index 0000000000..bdbf309214
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connector
+
+import java.util.Collections
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.TransformExpression
+import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
+import org.apache.spark.sql.connector.catalog.functions._
+import org.apache.spark.sql.connector.distributions.Distributions
+import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.connector.expressions.Expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf._
+import org.apache.spark.sql.types._
+
+class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
+ private var originalV2BucketingEnabled: Boolean = false
+ private var originalAutoBroadcastJoinThreshold: Long = -1
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ originalV2BucketingEnabled = conf.getConf(V2_BUCKETING_ENABLED)
+ conf.setConf(V2_BUCKETING_ENABLED, true)
+ originalAutoBroadcastJoinThreshold = conf.getConf(AUTO_BROADCASTJOIN_THRESHOLD)
+ conf.setConf(AUTO_BROADCASTJOIN_THRESHOLD, -1L)
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ super.afterAll()
+ } finally {
+ conf.setConf(V2_BUCKETING_ENABLED, originalV2BucketingEnabled)
+ conf.setConf(AUTO_BROADCASTJOIN_THRESHOLD, originalAutoBroadcastJoinThreshold)
+ }
+ }
+
+ before {
+ Seq(UnboundYearsFunction, UnboundDaysFunction, UnboundBucketFunction).foreach { f =>
+ catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
+ }
+ }
+
+ after {
+ catalog.clearTables()
+ catalog.clearFunctions()
+ }
+
+ private val emptyProps: java.util.Map[String, String] = {
+ Collections.emptyMap[String, String]
+ }
+ private val table: String = "tbl"
+ private val schema = new StructType()
+ .add("id", IntegerType)
+ .add("data", StringType)
+ .add("ts", TimestampType)
+
+ test("clustered distribution: output partitioning should be KeyGroupedPartitioning") {
+ val partitions: Array[Transform] = Array(Expressions.years("ts"))
+
+ // create a table with 3 partitions, partitioned by `years` transform
+ createTable(table, schema, partitions)
+ sql(s"INSERT INTO testcat.ns.$table VALUES " +
+ s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+ s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")
+
+ var df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY ts")
+ val catalystDistribution = physical.ClusteredDistribution(
+ Seq(TransformExpression(YearsFunction, Seq(attr("ts")))))
+ val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v)))
+
+ checkQueryPlan(df, catalystDistribution,
+ physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues))
+
+ // multiple group keys should work too as long as the partition keys are a subset of them
+ df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts")
+ checkQueryPlan(df, catalystDistribution,
+ physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues))
+ }
+
+ test("non-clustered distribution: no partition") {
+ val partitions: Array[Transform] = Array(bucket(32, "ts"))
+ createTable(table, schema, partitions)
+
+ val df = sql(s"SELECT * FROM testcat.ns.$table")
+ val distribution = physical.ClusteredDistribution(
+ Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+
+ checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
+ }
+
+ test("non-clustered distribution: single partition") {
+ val partitions: Array[Transform] = Array(bucket(32, "ts"))
+ createTable(table, schema, partitions)
+ sql(s"INSERT INTO testcat.ns.$table VALUES (0, 'aaa', CAST('2020-01-01' AS timestamp))")
+
+ val df = sql(s"SELECT * FROM testcat.ns.$table")
+ val distribution = physical.ClusteredDistribution(
+ Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+
+ checkQueryPlan(df, distribution, physical.SinglePartition)
+ }
+
+ test("non-clustered distribution: no V2 catalog") {
+ spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName)
+ val nonFunctionCatalog = spark.sessionState.catalogManager.catalog("testcat2")
+ .asInstanceOf[InMemoryTableCatalog]
+ val partitions: Array[Transform] = Array(bucket(32, "ts"))
+ createTable(table, schema, partitions, catalog = nonFunctionCatalog)
+ sql(s"INSERT INTO testcat2.ns.$table VALUES " +
+ s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+ s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")
+
+ val df = sql(s"SELECT * FROM testcat2.ns.$table")
+ val distribution = physical.UnspecifiedDistribution
+
+ try {
+ checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
+ } finally {
+ spark.conf.unset("spark.sql.catalog.testcat2")
+ }
+ }
+
+ test("non-clustered distribution: no V2 function provided") {
+ catalog.clearFunctions()
+
+ val partitions: Array[Transform] = Array(bucket(32, "ts"))
+ createTable(table, schema, partitions)
+ sql(s"INSERT INTO testcat.ns.$table VALUES " +
+ s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+ s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")
+
+ val df = sql(s"SELECT * FROM testcat.ns.$table")
+ val distribution = physical.UnspecifiedDistribution
+
+ checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
+ }
+
+ test("non-clustered distribution: V2 bucketing disabled") {
+ withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "false") {
+ val partitions: Array[Transform] = Array(bucket(32, "ts"))
+ createTable(table, schema, partitions)
+ sql(s"INSERT INTO testcat.ns.$table VALUES " +
+ s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+ s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")
+
+ val df = sql(s"SELECT * FROM testcat.ns.$table")
+ val distribution = physical.ClusteredDistribution(
+ Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+
+ checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
+ }
+ }
+
+ /**
+ * Checks whether the query plan from `df` has the expected `distribution` and `partitioning`.
+ */
+ private def checkQueryPlan(
+ df: DataFrame,
+ distribution: physical.Distribution,
+ partitioning: physical.Partitioning): Unit = {
+ // check distribution & ordering are correctly populated in logical plan
+ val relation = df.queryExecution.optimizedPlan.collect {
+ case r: DataSourceV2ScanRelation => r
+ }.head
+
+ resolveDistribution(distribution, relation) match {
+ case physical.ClusteredDistribution(clustering, _, _) =>
+ assert(relation.keyGroupedPartitioning.isDefined &&
+ relation.keyGroupedPartitioning.get == clustering)
+ case _ =>
+ assert(relation.keyGroupedPartitioning.isEmpty)
+ }
+
+ // check distribution, ordering and output partitioning are correctly populated in physical plan
+ val scan = collect(df.queryExecution.executedPlan) {
+ case s: BatchScanExec => s
+ }.head
+
+ val expectedPartitioning = resolvePartitioning(partitioning, scan)
+ assert(expectedPartitioning == scan.outputPartitioning)
+ }
+
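+ // Creates a table in the given catalog with the specified partition transforms and no required distribution.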
+ private def createTable(
+ table: String,
+ schema: StructType,
+ partitions: Array[Transform],
+ catalog: InMemoryTableCatalog = catalog): Unit = {
+ catalog.createTable(Identifier.of(Array("ns"), table),
+ schema, partitions, emptyProps, Distributions.unspecified(), Array.empty, None)
+ }
+
+ private val customers: String = "customers"
+ private val customers_schema = new StructType()
+ .add("customer_name", StringType)
+ .add("customer_age", IntegerType)
+ .add("customer_id", LongType)
+
+ private val orders: String = "orders"
+ private val orders_schema = new StructType()
+ .add("order_amount", DoubleType)
+ .add("customer_id", LongType)
+
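+ // Creates and populates the customers and orders tables with the given partition transforms,
+ // joins them on customer_id, and checks how many shuffles the sort-merge join introduces.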
+ private def testWithCustomersAndOrders(
+ customers_partitions: Array[Transform],
+ orders_partitions: Array[Transform],
+ expectedNumOfShuffleExecs: Int): Unit = {
+ createTable(customers, customers_schema, customers_partitions)
+ sql(s"INSERT INTO testcat.ns.$customers VALUES " +
+ s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)")
+
+ createTable(orders, orders_schema, orders_partitions)
+ sql(s"INSERT INTO testcat.ns.$orders VALUES " +
+ s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50, 3)")
+
+ val df = sql("SELECT customer_name, customer_age, order_amount " +
+ s"FROM testcat.ns.$customers c JOIN testcat.ns.$orders o " +
+ "ON c.customer_id = o.customer_id ORDER BY c.customer_id, order_amount")
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.length == expectedNumOfShuffleExecs)
+
+ checkAnswer(df,
+ Seq(Row("aaa", 10, 100.0), Row("aaa", 10, 200.0), Row("bbb", 20, 150.0),
+ Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50)))
+ }
+
+ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = {
+ // here we skip collecting shuffle operators that are not associated with SMJ
+ collect(plan) {
+ case s: SortMergeJoinExec => s
+ }.flatMap(smj =>
+ collect(smj) {
+ case s: ShuffleExchangeExec => s
+ })
+ }
+
+ test("partitioned join: exact distribution (same number of buckets) from both sides") {
+ val customers_partitions = Array(bucket(4, "customer_id"))
+ val orders_partitions = Array(bucket(4, "customer_id"))
+
+ testWithCustomersAndOrders(customers_partitions, orders_partitions, 0)
+ }
+
+ test("partitioned join: number of buckets mismatch should trigger shuffle") {
+ val customers_partitions = Array(bucket(4, "customer_id"))
+ val orders_partitions = Array(bucket(2, "customer_id"))
+
+ // should shuffle both sides when the number of buckets is not the same
+ testWithCustomersAndOrders(customers_partitions, orders_partitions, 2)
+ }
+
+ test("partitioned join: only one side reports partitioning") {
+ val customers_partitions = Array(bucket(4, "customer_id"))
+ val orders_partitions = Array(bucket(2, "customer_id"))
+
+ testWithCustomersAndOrders(customers_partitions, orders_partitions, 2)
+ }
+
+ private val items: String = "items"
+ private val items_schema: StructType = new StructType()
+ .add("id", LongType)
+ .add("name", StringType)
+ .add("price", FloatType)
+ .add("arrive_time", TimestampType)
+
+ private val purchases: String = "purchases"
+ private val purchases_schema: StructType = new StructType()
+ .add("item_id", LongType)
+ .add("price", FloatType)
+ .add("time", TimestampType)
+
+ test("partitioned join: join with two partition keys and matching & sorted partitions") {
+ val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
+ createTable(items, items_schema, items_partitions)
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ val purchases_partitions = Array(bucket(8, "item_id"), days("time"))
+ createTable(purchases, purchases_schema, purchases_partitions)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 44.0, cast('2020-01-15' as timestamp)), " +
+ s"(1, 45.0, cast('2020-01-15' as timestamp)), " +
+ s"(2, 11.0, cast('2020-01-01' as timestamp)), " +
+ s"(3, 19.5, cast('2020-02-01' as timestamp))")
+
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+ "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
+ checkAnswer(df,
+ Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0),
+ Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5))
+ )
+ }
+
+ test("partitioned join: join with two partition keys and unsorted partitions") {
+ val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
+ createTable(items, items_schema, items_partitions)
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp))")
+
+ val purchases_partitions = Array(bucket(8, "item_id"), days("time"))
+ createTable(purchases, purchases_schema, purchases_partitions)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ s"(2, 11.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 44.0, cast('2020-01-15' as timestamp)), " +
+ s"(1, 45.0, cast('2020-01-15' as timestamp)), " +
+ s"(3, 19.5, cast('2020-02-01' as timestamp))")
+
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+ "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
+ checkAnswer(df,
+ Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0),
+ Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5))
+ )
+ }
+
+ test("partitioned join: join with two partition keys and different # of partition keys") {
+ val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
+ createTable(items, items_schema, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ val purchases_partitions = Array(bucket(8, "item_id"), days("time"))
+ createTable(purchases, purchases_schema, purchases_partitions)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ s"(2, 11.0, cast('2020-01-01' as timestamp))")
+
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+ "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.nonEmpty, "should add shuffle when partition keys mismatch")
+ }
+
+ test("data source partitioning + dynamic partition filtering") {
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB",
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+ SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+ SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
+ SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "10") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, items_schema, items_partitions)
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ val purchases_partitions = Array(identity("item_id"))
+ createTable(purchases, purchases_schema, purchases_partitions)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ s"(1, 44.0, cast('2020-01-15' as timestamp)), " +
+ s"(1, 45.0, cast('2020-01-15' as timestamp)), " +
+ s"(2, 11.0, cast('2020-01-01' as timestamp)), " +
+ s"(3, 19.5, cast('2020-02-01' as timestamp))")
+
+ // number of unique partitions changed after dynamic filtering - should throw exception
+ var df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p WHERE " +
+ s"i.id = p.item_id AND i.price > 40.0")
+ val e = intercept[Exception](df.collect())
+ assert(e.getMessage.contains("number of unique partition values"))
+
+ // dynamic filtering doesn't change partitioning so storage-partitioned join should kick in
+ df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p WHERE " +
+ s"i.id = p.item_id AND i.price >= 10.0")
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
+ checkAnswer(df, Seq(Row(303.5)))
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
index db71eeb75e..e3d61a846f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
-import scala.collection.JavaConverters._
-
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability}
@@ -63,7 +59,7 @@ class TestLocalScanCatalog extends BasicInMemoryTableCatalog {
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val table = new TestLocalScanTable(ident.toString)
tables.put(ident, table)
table
@@ -78,7 +74,8 @@ object TestLocalScanTable {
class TestLocalScanTable(override val name: String) extends Table with SupportsRead {
override def schema(): StructType = TestLocalScanTable.schema
- override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
+ override def capabilities(): java.util.Set[TableCapability] =
+ java.util.EnumSet.of(TableCapability.BATCH_READ)
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
new TestLocalScanBuilder
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
new file mode 100644
index 0000000000..7f0e74f6bc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.functions.struct
+
+class MetadataColumnSuite extends DatasourceV2SQLBase {
+ import testImplicits._
+
+ private val tbl = "testcat.t"
+
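+ // Creates a partitioned in-memory table whose scan exposes the `index` and `_partition` metadata columns.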
+ private def prepareTable(): Unit = {
+ sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (bucket(4, id), id)")
+ sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ }
+
+ test("SPARK-31255: Project a metadata column") {
+ withTable(tbl) {
+ prepareTable()
+ val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl")
+ val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition")
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ }
+ }
+ }
+
+ test("SPARK-31255: Projects data column when metadata column has the same name") {
+ withTable(tbl) {
+ sql(s"CREATE TABLE $tbl (index bigint, data string) PARTITIONED BY (bucket(4, index), index)")
+ sql(s"INSERT INTO $tbl VALUES (3, 'c'), (2, 'b'), (1, 'a')")
+
+ val sqlQuery = sql(s"SELECT index, data, _partition FROM $tbl")
+ val dfQuery = spark.table(tbl).select("index", "data", "_partition")
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
+ }
+ }
+ }
+
+ test("SPARK-31255: * expansion does not include metadata columns") {
+ withTable(tbl) {
+ prepareTable()
+ val sqlQuery = sql(s"SELECT * FROM $tbl")
+ val dfQuery = spark.table(tbl)
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ checkAnswer(query, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c")))
+ }
+ }
+ }
+
+ test("SPARK-31255: metadata column should only be produced when necessary") {
+ withTable(tbl) {
+ prepareTable()
+ val sqlQuery = sql(s"SELECT * FROM $tbl WHERE index = 0")
+ val dfQuery = spark.table(tbl).filter("index = 0")
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ assert(query.schema.fieldNames.toSeq == Seq("id", "data"))
+ }
+ }
+ }
+
+ test("SPARK-34547: metadata columns are resolved last") {
+ withTable(tbl) {
+ prepareTable()
+ withTempView("v") {
+ sql(s"CREATE TEMPORARY VIEW v AS SELECT * FROM " +
+ s"VALUES (1, -1), (2, -2), (3, -3) AS v(id, index)")
+
+ val sqlQuery = sql(s"SELECT $tbl.id, v.id, data, index, $tbl.index, v.index " +
+ s"FROM $tbl JOIN v WHERE $tbl.id = v.id")
+ val tableDf = spark.table(tbl)
+ val viewDf = spark.table("v")
+ val dfQuery = tableDf.join(viewDf, tableDf.col("id") === viewDf.col("id"))
+ .select(s"$tbl.id", "v.id", "data", "index", s"$tbl.index", "v.index")
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ checkAnswer(query,
+ Seq(
+ Row(1, 1, "a", -1, 0, -1),
+ Row(2, 2, "b", -2, 0, -2),
+ Row(3, 3, "c", -3, 0, -3)
+ )
+ )
+ }
+ }
+ }
+ }
+
+ test("SPARK-34555: Resolve DataFrame metadata column") {
+ withTable(tbl) {
+ prepareTable()
+ val table = spark.table(tbl)
+ val dfQuery = table.select(
+ table.col("id"),
+ table.col("data"),
+ table.col("index"),
+ table.col("_partition")
+ )
+
+ checkAnswer(
+ dfQuery,
+ Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
+ )
+ }
+ }
+
+ test("SPARK-34923: propagate metadata columns through Project") {
+ withTable(tbl) {
+ prepareTable()
+ checkAnswer(
+ spark.table(tbl).select("id", "data").select("index", "_partition"),
+ Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3"))
+ )
+ }
+ }
+
+ test("SPARK-34923: do not propagate metadata columns through View") {
+ val view = "view"
+ withTable(tbl) {
+ withTempView(view) {
+ prepareTable()
+ sql(s"CACHE TABLE $view AS SELECT * FROM $tbl")
+ assertThrows[AnalysisException] {
+ sql(s"SELECT index, _partition FROM $view")
+ }
+ }
+ }
+ }
+
+ test("SPARK-34923: propagate metadata columns through Filter") {
+ withTable(tbl) {
+ prepareTable()
+ val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl WHERE id > 1")
+ val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition")
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ }
+ }
+ }
+
+ test("SPARK-34923: propagate metadata columns through Sort") {
+ withTable(tbl) {
+ prepareTable()
+ val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl ORDER BY id")
+ val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition")
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ }
+ }
+ }
+
+ test("SPARK-34923: propagate metadata columns through RepartitionBy") {
+ withTable(tbl) {
+ prepareTable()
+ val sqlQuery = sql(
+ s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $tbl")
+ val dfQuery = spark.table(tbl).repartitionByRange(3, $"id")
+ .select("id", "data", "index", "_partition")
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ }
+ }
+ }
+
+ test("SPARK-34923: propagate metadata columns through SubqueryAlias if child is leaf node") {
+ val sbq = "sbq"
+ withTable(tbl) {
+ prepareTable()
+ val sqlQuery = sql(
+ s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $tbl $sbq")
+ val dfQuery = spark.table(tbl).as(sbq).select(
+ s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
+
+ Seq(sqlQuery, dfQuery).foreach { query =>
+ checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+ }
+
+ assertThrows[AnalysisException] {
+ sql(s"SELECT $sbq.index FROM (SELECT id FROM $tbl) $sbq")
+ }
+ assertThrows[AnalysisException] {
+ spark.table(tbl).select($"id").as(sbq).select(s"$sbq.index")
+ }
+ }
+ }
+
+ test("SPARK-40149: select outer join metadata columns with DataFrame API") {
+ val df1 = Seq(1 -> "a").toDF("k", "v").as("left")
+ val df2 = Seq(1 -> "b").toDF("k", "v").as("right")
+ val dfQuery = df1.join(df2, Seq("k"), "outer")
+ .withColumn("left_all", struct($"left.*"))
+ .withColumn("right_all", struct($"right.*"))
+ checkAnswer(dfQuery, Row(1, "a", "b", Row(1, "a"), Row(1, "b")))
+ }
+
+ test("SPARK-40429: Only set KeyGroupedPartitioning when the referenced column is in the output") {
+ withTable(tbl) {
+ sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (id)")
+ sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ checkAnswer(
+ spark.table(tbl).select("index", "_partition"),
+ Seq(Row(0, "3"), Row(0, "2"), Row(0, "1"))
+ )
+
+ checkAnswer(
+ spark.table(tbl).select("id", "index", "_partition"),
+ Seq(Row(3, 0, "3"), Row(2, 0, "2"), Row(1, 0, "1"))
+ )
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
index bb2acecc78..64c893ed74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.connector
import java.io.{BufferedReader, InputStreamReader, IOException}
-import java.util
import scala.collection.JavaConverters._
@@ -138,8 +137,8 @@ class SimpleWritableDataSource extends TestingV2Source {
new MyWriteBuilder(path, info)
}
- override def capabilities(): util.Set[TableCapability] =
- Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava
+ override def capabilities(): java.util.Set[TableCapability] =
+ java.util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE)
}
override def getTable(options: CaseInsensitiveStringMap): Table = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
index 076dad7530..8d771b0736 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -17,15 +17,19 @@
package org.apache.spark.sql.connector
+import java.util.Optional
+
+import scala.concurrent.duration.MICROSECONDS
import scala.language.implicitConversions
import scala.util.Try
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
-import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, SupportsCatalogOptions, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
@@ -35,6 +39,7 @@ import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, QueryExecutionListener}
+import org.apache.spark.unsafe.types.UTF8String
class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
@@ -71,7 +76,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
saveMode: SaveMode,
withCatalogOption: Option[String],
partitionBy: Seq[String]): Unit = {
- val df = spark.range(10).withColumn("part", 'id % 5)
+ val df = spark.range(10).withColumn("part", Symbol("id") % 5)
val dfw = df.write.format(format).mode(saveMode).option("name", "t1")
withCatalogOption.foreach(cName => dfw.option("catalog", cName))
dfw.partitionBy(partitionBy: _*).save()
@@ -136,7 +141,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
test("Ignore mode if table exists - session catalog") {
sql(s"create table t1 (id bigint) using $format")
- val df = spark.range(10).withColumn("part", 'id % 5)
+ val df = spark.range(10).withColumn("part", Symbol("id") % 5)
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
dfw.save()
@@ -148,7 +153,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
test("Ignore mode if table exists - testcat catalog") {
sql(s"create table $catalogName.t1 (id bigint) using $format")
- val df = spark.range(10).withColumn("part", 'id % 5)
+ val df = spark.range(10).withColumn("part", Symbol("id") % 5)
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
dfw.option("catalog", catalogName).save()
@@ -271,6 +276,69 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
}
}
+ test("time travel") {
+ // The testing in-memory table simply appends the version/timestamp to the table name when
+ // looking up tables.
+ val t1 = s"$catalogName.tSnapshot123456789"
+ val t2 = s"$catalogName.t2345678910"
+ withTable(t1, t2) {
+ sql(s"create table $t1 (id bigint) using $format")
+ sql(s"create table $t2 (id bigint) using $format")
+
+ val df1 = spark.range(10)
+ df1.write.format(format).option("name", "tSnapshot123456789").option("catalog", catalogName)
+ .mode(SaveMode.Append).save()
+
+ val df2 = spark.range(10, 20)
+ df2.write.format(format).option("name", "t2345678910").option("catalog", catalogName)
+ .mode(SaveMode.Overwrite).save()
+
+ // load with version
+ checkAnswer(load("t", Some(catalogName), version = Some("Snapshot123456789")), df1.toDF())
+ checkAnswer(load("t", Some(catalogName), version = Some("2345678910")), df2.toDF())
+ }
+
+ val ts1 = DateTimeUtils.stringToTimestampAnsi(
+ UTF8String.fromString("2019-01-29 00:37:58"),
+ DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+ val ts2 = DateTimeUtils.stringToTimestampAnsi(
+ UTF8String.fromString("2021-01-29 00:37:58"),
+ DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+ val t3 = s"$catalogName.t$ts1"
+ val t4 = s"$catalogName.t$ts2"
+ withTable(t3, t4) {
+ sql(s"create table $t3 (id bigint) using $format")
+ sql(s"create table $t4 (id bigint) using $format")
+
+ val df3 = spark.range(30, 40)
+ df3.write.format(format).option("name", s"t$ts1").option("catalog", catalogName)
+ .mode(SaveMode.Append).save()
+
+ val df4 = spark.range(50, 60)
+ df4.write.format(format).option("name", s"t$ts2").option("catalog", catalogName)
+ .mode(SaveMode.Overwrite).save()
+
+ // load with timestamp
+ checkAnswer(load("t", Some(catalogName), version = None,
+ timestamp = Some("2019-01-29 00:37:58")), df3.toDF())
+ checkAnswer(load("t", Some(catalogName), version = None,
+ timestamp = Some("2021-01-29 00:37:58")), df4.toDF())
+
+ // load with timestamp in number format
+ checkAnswer(load("t", Some(catalogName), version = None,
+ timestamp = Some(MICROSECONDS.toSeconds(ts1).toString)), df3.toDF())
+ checkAnswer(load("t", Some(catalogName), version = None,
+ timestamp = Some(MICROSECONDS.toSeconds(ts2).toString)), df4.toDF())
+ }
+
+ val e = intercept[AnalysisException] {
+ load("t", Some(catalogName), version = Some("12345678"),
+ timestamp = Some("2019-01-29 00:37:58"))
+ }
+ assert(e.getMessage
+ .contains("Cannot specify both version and timestamp when time travelling the table."))
+ }
+
private def checkV2Identifiers(
plan: LogicalPlan,
identifier: String = "t1",
@@ -281,9 +349,19 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
assert(v2.catalog.exists(_ == catalogPlugin))
}
- private def load(name: String, catalogOpt: Option[String]): DataFrame = {
+ private def load(
+ name: String,
+ catalogOpt: Option[String],
+ version: Option[String] = None,
+ timestamp: Option[String] = None): DataFrame = {
val dfr = spark.read.format(format).option("name", name)
catalogOpt.foreach(cName => dfr.option("catalog", cName))
+ if (version.nonEmpty) {
+ dfr.option("versionAsOf", version.get)
+ }
+ if (timestamp.nonEmpty) {
+ dfr.option("timestampAsOf", timestamp.get)
+ }
dfr.load()
}
@@ -312,4 +390,20 @@ class CatalogSupportingInMemoryTableProvider
override def extractCatalog(options: CaseInsensitiveStringMap): String = {
options.get("catalog")
}
+
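+ // Map the versionAsOf/timestampAsOf read options to the connector's time travel version/timestamp.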
+ override def extractTimeTravelVersion(options: CaseInsensitiveStringMap): Optional[String] = {
+ if (options.get("versionAsOf") != null) {
+ Optional.of(options.get("versionAsOf"))
+ } else {
+ Optional.empty[String]
+ }
+ }
+
+ override def extractTimeTravelTimestamp(options: CaseInsensitiveStringMap): Optional[String] = {
+ if (options.get("timestampAsOf") != null) {
+ Optional.of(options.get("timestampAsOf"))
+ } else {
+ Optional.empty[String]
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
index ce94d3b5c2..5f2e0b28ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
-import scala.collection.JavaConverters._
-
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
@@ -217,7 +213,11 @@ private case object TestRelation extends LeafNode with NamedRelation {
private case class CapabilityTable(_capabilities: TableCapability*) extends Table {
override def name(): String = "capability_test_table"
override def schema(): StructType = TableCapabilityCheckSuite.schema
- override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava
+ override def capabilities(): java.util.Set[TableCapability] = {
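+ // Build the capability set with EnumSet directly rather than converting a Scala Set via JavaConverters.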
+ val set = java.util.EnumSet.noneOf(classOf[TableCapability])
+ _capabilities.foreach(set.add)
+ set
+ }
}
private class TestStreamSourceProvider extends StreamSourceProvider {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
index bf2749d1af..0a0aaa8021 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
@@ -35,7 +34,7 @@ import org.apache.spark.sql.types.StructType
*/
private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension {
- protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
+ protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
private val tableCreated: AtomicBoolean = new AtomicBoolean(false)
@@ -48,7 +47,7 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): T
+ properties: java.util.Map[String, String]): T
override def loadTable(ident: Identifier): Table = {
if (tables.containsKey(ident)) {
@@ -69,12 +68,12 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY
val propsWithLocation = if (properties.containsKey(key)) {
// Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified.
if (!properties.containsKey(TableCatalog.PROP_LOCATION)) {
- val newProps = new util.HashMap[String, String]()
+ val newProps = new java.util.HashMap[String, String]()
newProps.putAll(properties)
newProps.put(TableCatalog.PROP_LOCATION, "file:/abc")
newProps
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
index 847953e09c..c5be222645 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
-import scala.collection.JavaConverters._
-
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability}
@@ -106,7 +102,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog {
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
// To simplify the test implementation, only support fixed schema.
if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) {
throw new UnsupportedOperationException
@@ -131,8 +127,8 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp
override def schema(): StructType = V1ReadFallbackCatalog.schema
- override def capabilities(): util.Set[TableCapability] = {
- Set(TableCapability.BATCH_READ).asJava
+ override def capabilities(): java.util.Set[TableCapability] = {
+ java.util.EnumSet.of(TableCapability.BATCH_READ)
}
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 7effc747ab..992c46cc6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -223,7 +221,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): InMemoryTableWithV1Fallback = {
+ properties: java.util.Map[String, String]): InMemoryTableWithV1Fallback = {
val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties)
InMemoryV1Provider.tables.put(name, t)
tables.put(Identifier.of(Array("default"), name), t)
@@ -321,7 +319,7 @@ class InMemoryTableWithV1Fallback(
override val name: String,
override val schema: StructType,
override val partitioning: Array[Transform],
- override val properties: util.Map[String, String])
+ override val properties: java.util.Map[String, String])
extends Table
with SupportsWrite with SupportsRead {
@@ -331,11 +329,11 @@ class InMemoryTableWithV1Fallback(
}
}
- override def capabilities: util.Set[TableCapability] = Set(
+ override def capabilities: java.util.Set[TableCapability] = java.util.EnumSet.of(
TableCapability.BATCH_READ,
TableCapability.V1_BATCH_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
- TableCapability.TRUNCATE).asJava
+ TableCapability.TRUNCATE)
@volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty
private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
index f262cf152c..ea30a6f25c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
@@ -17,8 +17,8 @@
package org.apache.spark.sql.connector
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition}
-import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedDBObjectName, UnresolvedFieldName, UnresolvedFieldPosition}
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect, TableSpec}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
@@ -46,12 +46,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
Seq(true, false).foreach { caseSensitive =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
Seq("ID", "iD").foreach { ref =>
+ val tableSpec = TableSpec(Map.empty, None, Map.empty,
+ None, None, None, false)
val plan = CreateTableAsSelect(
- catalog,
- Identifier.of(Array(), "table_name"),
+ UnresolvedDBObjectName(Array("table_name"), isNamespace = false),
Expressions.identity(ref) :: Nil,
TestRelation2,
- Map.empty,
+ tableSpec,
Map.empty,
ignoreIfExists = false)
@@ -69,12 +70,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
Seq(true, false).foreach { caseSensitive =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref =>
+ val tableSpec = TableSpec(Map.empty, None, Map.empty,
+ None, None, None, false)
val plan = CreateTableAsSelect(
- catalog,
- Identifier.of(Array(), "table_name"),
+ UnresolvedDBObjectName(Array("table_name"), isNamespace = false),
Expressions.bucket(4, ref) :: Nil,
TestRelation2,
- Map.empty,
+ tableSpec,
Map.empty,
ignoreIfExists = false)
@@ -93,12 +95,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
Seq(true, false).foreach { caseSensitive =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
Seq("ID", "iD").foreach { ref =>
+ val tableSpec = TableSpec(Map.empty, None, Map.empty,
+ None, None, None, false)
val plan = ReplaceTableAsSelect(
- catalog,
- Identifier.of(Array(), "table_name"),
+ UnresolvedDBObjectName(Array("table_name"), isNamespace = false),
Expressions.identity(ref) :: Nil,
TestRelation2,
- Map.empty,
+ tableSpec,
Map.empty,
orCreate = true)
@@ -116,12 +119,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
Seq(true, false).foreach { caseSensitive =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref =>
+ val tableSpec = TableSpec(Map.empty, None, Map.empty,
+ None, None, None, false)
val plan = ReplaceTableAsSelect(
- catalog,
- Identifier.of(Array(), "table_name"),
+ UnresolvedDBObjectName(Array("table_name"), isNamespace = false),
Expressions.bucket(4, ref) :: Nil,
TestRelation2,
- Map.empty,
+ tableSpec,
Map.empty,
orCreate = true)
@@ -273,10 +277,21 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
test("AlterTable: drop column resolution") {
Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref =>
- alterTableTest(
- DropColumns(table, Seq(UnresolvedFieldName(ref))),
- Seq("Missing field " + ref.quoted)
- )
+ Seq(true, false).foreach { ifExists =>
+ val expectedErrors = if (ifExists) {
+ Seq.empty[String]
+ } else {
+ Seq("Missing field " + ref.quoted)
+ }
+ val alter = DropColumns(table, Seq(UnresolvedFieldName(ref)), ifExists)
+ if (ifExists) {
+ // using IF EXISTS will silence all errors for missing columns
+ assertAnalysisSuccess(alter, caseSensitive = true)
+ assertAnalysisSuccess(alter, caseSensitive = false)
+ } else {
+ alterTableTest(alter, expectedErrors, expectErrorOnCaseSensitive = true)
+ }
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index db4a9c153c..36efe5ec1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -19,39 +19,31 @@ package org.apache.spark.sql.connector
import java.util.Collections
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, QueryTest}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
-import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
+import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder}
import org.apache.spark.sql.connector.expressions.LogicalExpressions._
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.QueryExecutionListener
-class WriteDistributionAndOrderingSuite
- extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper {
-
- import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-
- before {
- spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
- }
+class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase {
+ import testImplicits._
after {
spark.sessionState.catalogManager.reset()
- spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat")
}
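+ // Commands carrying this prefix are run through the streaming micro-batch write path.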
+ private val microBatchPrefix = "micro_batch_"
private val namespace = Array("ns1")
private val ident = Identifier.of(namespace, "test_table")
private val tableNameAsString = "testcat." + ident.toString
@@ -60,8 +52,6 @@ class WriteDistributionAndOrderingSuite
.add("id", IntegerType)
.add("data", StringType)
- private val resolver = conf.resolver
-
test("ordered distribution and sort with same exprs: append") {
checkOrderedDistributionAndSortWithSameExprs("append")
}
@@ -74,6 +64,18 @@ class WriteDistributionAndOrderingSuite
checkOrderedDistributionAndSortWithSameExprs("overwriteDynamic")
}
+ test("ordered distribution and sort with same exprs: micro-batch append") {
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "append")
+ }
+
+ test("ordered distribution and sort with same exprs: micro-batch update") {
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "update")
+ }
+
+ test("ordered distribution and sort with same exprs: micro-batch complete") {
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "complete")
+ }
+
test("ordered distribution and sort with same exprs with numPartitions: append") {
checkOrderedDistributionAndSortWithSameExprs("append", Some(10))
}
@@ -86,6 +88,18 @@ class WriteDistributionAndOrderingSuite
checkOrderedDistributionAndSortWithSameExprs("overwriteDynamic", Some(10))
}
+ test("ordered distribution and sort with same exprs with numPartitions: micro-batch append") {
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "append", Some(10))
+ }
+
+ test("ordered distribution and sort with same exprs with numPartitions: micro-batch update") {
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "update", Some(10))
+ }
+
+ test("ordered distribution and sort with same exprs with numPartitions: micro-batch complete") {
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "complete", Some(10))
+ }
+
private def checkOrderedDistributionAndSortWithSameExprs(command: String): Unit = {
checkOrderedDistributionAndSortWithSameExprs(command, None)
}
@@ -129,6 +143,18 @@ class WriteDistributionAndOrderingSuite
checkClusteredDistributionAndSortWithSameExprs("overwriteDynamic")
}
+ test("clustered distribution and sort with same exprs: micro-batch append") {
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "append")
+ }
+
+ test("clustered distribution and sort with same exprs: micro-batch update") {
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "update")
+ }
+
+ test("clustered distribution and sort with same exprs: micro-batch complete") {
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "complete")
+ }
+
test("clustered distribution and sort with same exprs with numPartitions: append") {
checkClusteredDistributionAndSortWithSameExprs("append", Some(10))
}
@@ -141,6 +167,18 @@ class WriteDistributionAndOrderingSuite
checkClusteredDistributionAndSortWithSameExprs("overwriteDynamic", Some(10))
}
+ test("clustered distribution and sort with same exprs with numPartitions: micro-batch append") {
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "append", Some(10))
+ }
+
+ test("clustered distribution and sort with same exprs with numPartitions: micro-batch update") {
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "update", Some(10))
+ }
+
+ test("clustered distribution and sort with same exprs with numPartitions: micro-batch complete") {
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "complete", Some(10))
+ }
+
private def checkClusteredDistributionAndSortWithSameExprs(command: String): Unit = {
checkClusteredDistributionAndSortWithSameExprs(command, None)
}
@@ -193,6 +231,18 @@ class WriteDistributionAndOrderingSuite
checkClusteredDistributionAndSortWithExtendedExprs("overwriteDynamic")
}
+ test("clustered distribution and sort with extended exprs: micro-batch append") {
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "append")
+ }
+
+ test("clustered distribution and sort with extended exprs: micro-batch update") {
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "update")
+ }
+
+ test("clustered distribution and sort with extended exprs: micro-batch complete") {
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "complete")
+ }
+
test("clustered distribution and sort with extended exprs with numPartitions: append") {
checkClusteredDistributionAndSortWithExtendedExprs("append", Some(10))
}
@@ -206,6 +256,21 @@ class WriteDistributionAndOrderingSuite
checkClusteredDistributionAndSortWithExtendedExprs("overwriteDynamic", Some(10))
}
+ test("clustered distribution and sort with extended exprs with numPartitions: " +
+ "micro-batch append") {
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "append", Some(10))
+ }
+
+ test("clustered distribution and sort with extended exprs with numPartitions: " +
+ "micro-batch update") {
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "update", Some(10))
+ }
+
+ test("clustered distribution and sort with extended exprs with numPartitions: " +
+ "micro-batch complete") {
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "complete", Some(10))
+ }
+
private def checkClusteredDistributionAndSortWithExtendedExprs(command: String): Unit = {
checkClusteredDistributionAndSortWithExtendedExprs(command, None)
}
@@ -258,6 +323,18 @@ class WriteDistributionAndOrderingSuite
checkUnspecifiedDistributionAndLocalSort("overwriteDynamic")
}
+ test("unspecified distribution and local sort: micro-batch append") {
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "append")
+ }
+
+ test("unspecified distribution and local sort: micro-batch update") {
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "update")
+ }
+
+ test("unspecified distribution and local sort: micro-batch complete") {
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "complete")
+ }
+
test("unspecified distribution and local sort with numPartitions: append") {
checkUnspecifiedDistributionAndLocalSort("append", Some(10))
}
@@ -270,6 +347,18 @@ class WriteDistributionAndOrderingSuite
checkUnspecifiedDistributionAndLocalSort("overwriteDynamic", Some(10))
}
+ test("unspecified distribution and local sort with numPartitions: micro-batch append") {
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "append", Some(10))
+ }
+
+ test("unspecified distribution and local sort with numPartitions: micro-batch update") {
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "update", Some(10))
+ }
+
+ test("unspecified distribution and local sort with numPartitions: micro-batch complete") {
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "complete", Some(10))
+ }
+
private def checkUnspecifiedDistributionAndLocalSort(command: String): Unit = {
checkUnspecifiedDistributionAndLocalSort(command, None)
}
@@ -316,6 +405,18 @@ class WriteDistributionAndOrderingSuite
checkUnspecifiedDistributionAndNoSort("overwriteDynamic")
}
+ test("unspecified distribution and no sort: micro-batch append") {
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "append")
+ }
+
+ test("unspecified distribution and no sort: micro-batch update") {
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "update")
+ }
+
+ test("unspecified distribution and no sort: micro-batch complete") {
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "complete")
+ }
+
test("unspecified distribution and no sort with numPartitions: append") {
checkUnspecifiedDistributionAndNoSort("append", Some(10))
}
@@ -328,6 +429,18 @@ class WriteDistributionAndOrderingSuite
checkUnspecifiedDistributionAndNoSort("overwriteDynamic", Some(10))
}
+ test("unspecified distribution and no sort with numPartitions: micro-batch append") {
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "append", Some(10))
+ }
+
+ test("unspecified distribution and no sort with numPartitions: micro-batch update") {
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "update", Some(10))
+ }
+
+ test("unspecified distribution and no sort with numPartitions: micro-batch complete") {
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "complete", Some(10))
+ }
+
private def checkUnspecifiedDistributionAndNoSort(command: String): Unit = {
checkUnspecifiedDistributionAndNoSort(command, None)
}
@@ -677,7 +790,95 @@ class WriteDistributionAndOrderingSuite
writeCommand = command)
}
+ test("continuous mode does not support write distribution and ordering") {
+ val ordering = Array[SortOrder](
+ sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)
+ )
+ val distribution = Distributions.ordered(ordering)
+
+ catalog.createTable(ident, schema, Array.empty, emptyProps, distribution, ordering, None)
+
+ withTempDir { checkpointDir =>
+ val inputData = ContinuousMemoryStream[(Long, String)]
+ val inputDF = inputData.toDF().toDF("id", "data")
+
+ val writer = inputDF
+ .writeStream
+ .trigger(Trigger.Continuous(100))
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .outputMode("append")
+
+ val analysisException = intercept[AnalysisException] {
+ val query = writer.toTable(tableNameAsString)
+
+ inputData.addData((1, "a"), (2, "b"))
+
+ query.processAllAvailable()
+ query.stop()
+ }
+
+ assert(analysisException.message.contains("Sinks cannot request distribution and ordering"))
+ }
+ }
+
+ test("continuous mode allows unspecified distribution and empty ordering") {
+ catalog.createTable(ident, schema, Array.empty, emptyProps)
+
+ withTempDir { checkpointDir =>
+ val inputData = ContinuousMemoryStream[(Long, String)]
+ val inputDF = inputData.toDF().toDF("id", "data")
+
+ val writer = inputDF
+ .writeStream
+ .trigger(Trigger.Continuous(100))
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .outputMode("append")
+
+ val query = writer.toTable(tableNameAsString)
+
+ inputData.addData((1, "a"), (2, "b"))
+
+ query.processAllAvailable()
+ query.stop()
+
+ checkAnswer(spark.table(tableNameAsString), Row(1, "a") :: Row(2, "b") :: Nil)
+ }
+ }
+
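+ // Dispatch on the command prefix: micro-batch commands go through the streaming
+ // checker, everything else through the batch checker.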
private def checkWriteRequirements(
+ tableDistribution: Distribution,
+ tableOrdering: Array[SortOrder],
+ tableNumPartitions: Option[Int],
+ expectedWritePartitioning: physical.Partitioning,
+ expectedWriteOrdering: Seq[catalyst.expressions.SortOrder],
+ writeTransform: DataFrame => DataFrame = df => df,
+ writeCommand: String,
+ expectAnalysisException: Boolean = false): Unit = {
+
+ if (writeCommand.startsWith(microBatchPrefix)) {
+ checkMicroBatchWriteRequirements(
+ tableDistribution,
+ tableOrdering,
+ tableNumPartitions,
+ expectedWritePartitioning,
+ expectedWriteOrdering,
+ writeTransform,
+ outputMode = writeCommand.stripPrefix(microBatchPrefix),
+ expectAnalysisException)
+ } else {
+ checkBatchWriteRequirements(
+ tableDistribution,
+ tableOrdering,
+ tableNumPartitions,
+ expectedWritePartitioning,
+ expectedWriteOrdering,
+ writeTransform,
+ writeCommand,
+ expectAnalysisException)
+ }
+ }
+
+ private def checkBatchWriteRequirements(
tableDistribution: Distribution,
tableOrdering: Array[SortOrder],
tableNumPartitions: Option[Int],
@@ -712,15 +913,84 @@ class WriteDistributionAndOrderingSuite
}
}
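+ // Streaming counterpart of checkBatchWriteRequirements: writes MemoryStream input with the
+ // given output mode and validates the partitioning/ordering of the executed plan.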
+ private def checkMicroBatchWriteRequirements(
+ tableDistribution: Distribution,
+ tableOrdering: Array[SortOrder],
+ tableNumPartitions: Option[Int],
+ expectedWritePartitioning: physical.Partitioning,
+ expectedWriteOrdering: Seq[catalyst.expressions.SortOrder],
+ writeTransform: DataFrame => DataFrame = df => df,
+ outputMode: String = "append",
+ expectAnalysisException: Boolean = false): Unit = {
+
+ catalog.createTable(ident, schema, Array.empty, emptyProps, tableDistribution,
+ tableOrdering, tableNumPartitions)
+
+ withTempDir { checkpointDir =>
+ val inputData = MemoryStream[(Long, String)]
+ val inputDF = inputData.toDF().toDF("id", "data")
+
+ val queryDF = outputMode match {
+ case "append" | "update" =>
+ inputDF
+ case "complete" =>
+ // add an aggregate for complete mode
+ inputDF
+ .groupBy("id")
+ .agg(Map("data" -> "count"))
+ .select($"id", $"count(data)".cast("string").as("data"))
+ }
+
+ val writer = writeTransform(queryDF)
+ .writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .outputMode(outputMode)
+
+ def executeCommand(): SparkPlan = execute {
+ val query = writer.toTable(tableNameAsString)
+
+ inputData.addData((1, "a"), (2, "b"))
+
+ query.processAllAvailable()
+ query.stop()
+ }
+
+ if (expectAnalysisException) {
+ val streamingQueryException = intercept[StreamingQueryException] {
+ executeCommand()
+ }
+ val cause = streamingQueryException.cause
+ assert(cause.getMessage.contains("number of partitions can't be specified"))
+
+ } else {
+ val executedPlan = executeCommand()
+
+ checkPartitioningAndOrdering(
+ executedPlan,
+ expectedWritePartitioning,
+ expectedWriteOrdering,
+ // there is an extra shuffle for groupBy in complete mode
+ maxNumShuffles = if (outputMode != "complete") 1 else 2)
+
+ val expectedRows = outputMode match {
+ case "append" | "update" => Row(1, "a") :: Row(2, "b") :: Nil
+ case "complete" => Row(1, "1") :: Row(2, "1") :: Nil
+ }
+ checkAnswer(spark.table(tableNameAsString), expectedRows)
+ }
+ }
+ }
+
private def checkPartitioningAndOrdering(
plan: SparkPlan,
partitioning: physical.Partitioning,
- ordering: Seq[catalyst.expressions.SortOrder]): Unit = {
+ ordering: Seq[catalyst.expressions.SortOrder],
+ maxNumShuffles: Int = 1): Unit = {
val sorts = collect(plan) { case s: SortExec => s }
assert(sorts.size <= 1, "must be at most one sort")
val shuffles = collect(plan) { case s: ShuffleExchangeLike => s }
- assert(shuffles.size <= 1, "must be at most one shuffle")
+ assert(shuffles.size <= maxNumShuffles, s"must be at most $maxNumShuffles shuffles")
val actualPartitioning = plan.outputPartitioning
val expectedPartitioning = partitioning match {
@@ -730,6 +1000,9 @@ class WriteDistributionAndOrderingSuite
case p: physical.HashPartitioning =>
val resolvedExprs = p.expressions.map(resolveAttrs(_, plan))
p.copy(expressions = resolvedExprs)
+ case _: UnknownPartitioning =>
+ // don't check partitioning if no particular one is expected
+ actualPartitioning
case other => other
}
assert(actualPartitioning == expectedPartitioning, "partitioning must match")
@@ -739,28 +1012,6 @@ class WriteDistributionAndOrderingSuite
assert(actualOrdering == expectedOrdering, "ordering must match")
}
- private def resolveAttrs(
- expr: catalyst.expressions.Expression,
- plan: SparkPlan): catalyst.expressions.Expression = {
-
- expr.transform {
- case UnresolvedAttribute(Seq(attrName)) =>
- plan.output.find(attr => resolver(attr.name, attrName)).get
- case UnresolvedAttribute(nameParts) =>
- val attrName = nameParts.mkString(".")
- fail(s"cannot resolve a nested attr: $attrName")
- }
- }
-
- private def attr(name: String): UnresolvedAttribute = {
- UnresolvedAttribute(name)
- }
-
- private def catalog: InMemoryTableCatalog = {
- val catalog = spark.sessionState.catalogManager.catalog("testcat")
- catalog.asTableCatalog.asInstanceOf[InMemoryTableCatalog]
- }
-
// executes a write operation and keeps the executed physical plan
private def execute(writeFunc: => Unit): SparkPlan = {
var executedPlan: SparkPlan = null
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
new file mode 100644
index 0000000000..1994874d32
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connector.catalog.functions
+
+import org.apache.spark.sql.types._
+
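+// Minimal UnboundFunction/BoundFunction implementations for the years, days and bucket
+// transforms, used by connector tests that resolve partition transform functions.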
+object UnboundYearsFunction extends UnboundFunction {
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.size == 1 && isValidType(inputType.head.dataType)) YearsFunction
+ else throw new UnsupportedOperationException(
+ "'years' only take date or timestamp as input type")
+ }
+
+ private def isValidType(dt: DataType): Boolean = dt match {
+ case DateType | TimestampType => true
+ case _ => false
+ }
+
+ override def description(): String = name()
+ override def name(): String = "years"
+}
+
+object YearsFunction extends BoundFunction {
+ override def inputTypes(): Array[DataType] = Array(TimestampType)
+ override def resultType(): DataType = LongType
+ override def name(): String = "years"
+ override def canonicalName(): String = name()
+}
+
+object DaysFunction extends BoundFunction {
+ override def inputTypes(): Array[DataType] = Array(TimestampType)
+ override def resultType(): DataType = LongType
+ override def name(): String = "days"
+ override def canonicalName(): String = name()
+}
+
+object UnboundDaysFunction extends UnboundFunction {
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.size == 1 && isValidType(inputType.head.dataType)) DaysFunction
+ else throw new UnsupportedOperationException(
+ "'days' only take date or timestamp as input type")
+ }
+
+ private def isValidType(dt: DataType): Boolean = dt match {
+ case DateType | TimestampType => true
+ case _ => false
+ }
+
+ override def description(): String = name()
+ override def name(): String = "days"
+}
+
+object UnboundBucketFunction extends UnboundFunction {
+ override def bind(inputType: StructType): BoundFunction = BucketFunction
+ override def description(): String = name()
+ override def name(): String = "bucket"
+}
+
+object BucketFunction extends BoundFunction {
+ override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "bucket"
+ override def canonicalName(): String = name()
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala
new file mode 100644
index 0000000000..a8fbfc49a5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.errors
+
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2Provider}
+import org.apache.spark.sql.test.SharedSparkSession
+
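+// Verifies the error classes and messages of exceptions raised during DataSource V2 analysis.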
+class QueryCompilationErrorsDSv2Suite
+ extends QueryTest
+ with SharedSparkSession
+ with DatasourceV2SQLBase {
+
+ test("UNSUPPORTED_FEATURE: IF PARTITION NOT EXISTS not supported by INSERT") {
+ val v2Format = classOf[FakeV2Provider].getName
+ val tbl = "testcat.ns1.ns2.tbl"
+
+ withTable(tbl) {
+ val view = "tmp_view"
+ val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
+ df.createOrReplaceTempView(view)
+ withTempView(view) {
+ sql(s"CREATE TABLE $tbl (id bigint, data string) USING $v2Format PARTITIONED BY (id)")
+
+ val e = intercept[AnalysisException] {
+ sql(s"INSERT OVERWRITE TABLE $tbl PARTITION (id = 1) IF NOT EXISTS SELECT * FROM $view")
+ }
+
+ checkAnswer(spark.table(tbl), spark.emptyDataFrame)
+ assert(e.getMessage === "The feature is not supported: " +
+ s"""IF NOT EXISTS for the table `testcat`.`ns1`.`ns2`.`tbl` by INSERT INTO.""")
+ assert(e.getErrorClass === "UNSUPPORTED_FEATURE")
+ assert(e.getSqlState === "0A000")
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
new file mode 100644
index 0000000000..9e18e4e669
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.errors
+
+import org.apache.spark.sql.{AnalysisException, ClassData, IntegratedUDFTestUtils, QueryTest}
+import org.apache.spark.sql.functions.{grouping, grouping_id, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+case class StringLongClass(a: String, b: Long)
+
+case class StringIntClass(a: String, b: Int)
+
+case class ComplexClass(a: Long, b: StringLongClass)
+
+case class ArrayClass(arr: Seq[StringIntClass])
+
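+// Verifies the error classes and messages of AnalysisExceptions raised at analysis time.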
+class QueryCompilationErrorsSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ test("CANNOT_UP_CAST_DATATYPE: invalid upcast data type") {
+ val msg1 = intercept[AnalysisException] {
+ sql("select 'value1' as a, 1L as b").as[StringIntClass]
+ }.message
+ assert(msg1 ===
+ s"""
+ |Cannot up cast b from "BIGINT" to "INT".
+ |The type path of the target object is:
+ |- field (class: "scala.Int", name: "b")
+ |- root class: "org.apache.spark.sql.errors.StringIntClass"
+ |You can either add an explicit cast to the input data or choose a higher precision type
+ """.stripMargin.trim + " of the field in the target object")
+
+ val msg2 = intercept[AnalysisException] {
+ sql("select 1L as a," +
+ " named_struct('a', 'value1', 'b', cast(1.0 as decimal(38,18))) as b")
+ .as[ComplexClass]
+ }.message
+ assert(msg2 ===
+ s"""
+ |Cannot up cast b.`b` from "DECIMAL(38,18)" to "BIGINT".
+ |The type path of the target object is:
+ |- field (class: "scala.Long", name: "b")
+ |- field (class: "org.apache.spark.sql.errors.StringLongClass", name: "b")
+ |- root class: "org.apache.spark.sql.errors.ComplexClass"
+ |You can either add an explicit cast to the input data or choose a higher precision type
+ """.stripMargin.trim + " of the field in the target object")
+ }
+
+ test("UNSUPPORTED_GROUPING_EXPRESSION: filter with grouping/grouping_Id expression") {
+ val df = Seq(
+ (536361, "85123A", 2, 17850),
+ (536362, "85123B", 4, 17850),
+ (536363, "86123A", 6, 17851)
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID")
+ Seq("grouping", "grouping_id").foreach { grouping =>
+ val errMsg = intercept[AnalysisException] {
+ df.groupBy("CustomerId").agg(Map("Quantity" -> "max"))
+ .filter(s"$grouping(CustomerId)=17850")
+ }
+ assert(errMsg.message ===
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ assert(errMsg.errorClass === Some("UNSUPPORTED_GROUPING_EXPRESSION"))
+ }
+ }
+
+ test("UNSUPPORTED_GROUPING_EXPRESSION: Sort with grouping/grouping_Id expression") {
+ val df = Seq(
+ (536361, "85123A", 2, 17850),
+ (536362, "85123B", 4, 17850),
+ (536363, "86123A", 6, 17851)
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID")
+ Seq(grouping("CustomerId"), grouping_id("CustomerId")).foreach { grouping =>
+ val errMsg = intercept[AnalysisException] {
+ df.groupBy("CustomerId").agg(Map("Quantity" -> "max")).
+ sort(grouping)
+ }
+ assert(errMsg.errorClass === Some("UNSUPPORTED_GROUPING_EXPRESSION"))
+ assert(errMsg.message ===
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ }
+ }
+
+ test("INVALID_PARAMETER_VALUE: the argument_index of string format is invalid") {
+ withSQLConf(SQLConf.ALLOW_ZERO_INDEX_IN_FORMAT_STRING.key -> "false") {
+ val e = intercept[AnalysisException] {
+ sql("select format_string('%0$s', 'Hello')")
+ }
+ assert(e.errorClass === Some("INVALID_PARAMETER_VALUE"))
+ assert(e.message === "The value of parameter(s) 'strfmt' in `format_string` is invalid: " +
+ "expects %1$, %2$ and so on, but got %0$.")
+ }
+ }
+
+ test("CANNOT_USE_MIXTURE: Using aggregate function with grouped aggregate pandas UDF") {
+ import IntegratedUDFTestUtils._
+ assume(shouldTestGroupedAggPandasUDFs)
+
+ val df = Seq(
+ (536361, "85123A", 2, 17850),
+ (536362, "85123B", 4, 17850),
+ (536363, "86123A", 6, 17851)
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID")
+ val e = intercept[AnalysisException] {
+ val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf")
+ df.groupBy("CustomerId")
+ .agg(pandasTestUDF(df("Quantity")), sum(df("Quantity"))).collect()
+ }
+
+ assert(e.errorClass === Some("CANNOT_USE_MIXTURE"))
+ assert(e.message ===
+ "Cannot use a mixture of aggregate function and group aggregate pandas UDF")
+ }
+
+ test("UNSUPPORTED_FEATURE: Using Python UDF with unsupported join condition") {
+ import IntegratedUDFTestUtils._
+
+ val df1 = Seq(
+ (536361, "85123A", 2, 17850),
+ (536362, "85123B", 4, 17850),
+ (536363, "86123A", 6, 17851)
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID")
+ val df2 = Seq(
+ ("Bob", 17850),
+ ("Alice", 17850),
+ ("Tom", 17851)
+ ).toDF("CustomerName", "CustomerID")
+
+ val e = intercept[AnalysisException] {
+ val pythonTestUDF = TestPythonUDF(name = "python_udf")
+ df1.join(
+ df2, pythonTestUDF(df1("CustomerID") === df2("CustomerID")), "leftouter").collect()
+ }
+
+ assert(e.errorClass === Some("UNSUPPORTED_FEATURE"))
+ assert(e.getSqlState === "0A000")
+ assert(e.message ===
+ "The feature is not supported: " +
+ "Using PythonUDF in join condition of join type LEFT OUTER is not supported.")
+ }
+
+ test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") {
+ import IntegratedUDFTestUtils._
+ assume(shouldTestGroupedAggPandasUDFs)
+
+ val df = Seq(
+ (536361, "85123A", 2, 17850),
+ (536362, "85123B", 4, 17850),
+ (536363, "86123A", 6, 17851)
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID")
+
+ val e = intercept[AnalysisException] {
+ val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf")
+ df.groupBy(df("CustomerID")).pivot(df("CustomerID")).agg(pandasTestUDF(df("Quantity")))
+ }
+
+ assert(e.errorClass === Some("UNSUPPORTED_FEATURE"))
+ assert(e.getSqlState === "0A000")
+ assert(e.message ===
+ "The feature is not supported: " +
+ "Pandas UDF aggregate expressions don't support pivot.")
+ }
+
+ test("UNSUPPORTED_DESERIALIZER: data type mismatch") {
+ val e = intercept[AnalysisException] {
+ sql("select 1 as arr").as[ArrayClass]
+ }
+ assert(e.errorClass === Some("UNSUPPORTED_DESERIALIZER"))
+ assert(e.message ===
+ """The deserializer is not supported: need a(n) "ARRAY" field but got "INT".""")
+ }
+
+ test("UNSUPPORTED_DESERIALIZER:" +
+ "the real number of fields doesn't match encoder schema") {
+ val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+ val e1 = intercept[AnalysisException] {
+ ds.as[(String, Int, Long)]
+ }
+ assert(e1.errorClass === Some("UNSUPPORTED_DESERIALIZER"))
+ assert(e1.message ===
+ "The deserializer is not supported: try to map \"STRUCT<a: STRING, b: INT>\" " +
+ "to Tuple3, but failed as the number of fields does not line up.")
+
+ val e2 = intercept[AnalysisException] {
+ ds.as[Tuple1[String]]
+ }
+ assert(e2.errorClass === Some("UNSUPPORTED_DESERIALIZER"))
+ assert(e2.message ===
+ "The deserializer is not supported: try to map \"STRUCT<a: STRING, b: INT>\" " +
+ "to Tuple1, but failed as the number of fields does not line up.")
+ }
+
+ test("UNSUPPORTED_GENERATOR: " +
+ "generators are not supported when it's nested in expressions") {
+ val e = intercept[AnalysisException](
+ sql("""select explode(Array(1, 2, 3)) + 1""").collect()
+ )
+ assert(e.errorClass === Some("UNSUPPORTED_GENERATOR"))
+ assert(e.message ===
+ """The generator is not supported: """ +
+ """nested in expressions "(explode(array(1, 2, 3)) + 1)"""")
+ }
+
+ test("UNSUPPORTED_GENERATOR: only one generator allowed") {
+ val e = intercept[AnalysisException](
+ sql("""select explode(Array(1, 2, 3)), explode(Array(1, 2, 3))""").collect()
+ )
+ assert(e.errorClass === Some("UNSUPPORTED_GENERATOR"))
+ assert(e.message ===
+ "The generator is not supported: only one generator allowed per select clause " +
+ """but found 2: "explode(array(1, 2, 3))", "explode(array(1, 2, 3))"""")
+ }
+
+ test("UNSUPPORTED_GENERATOR: generators are not supported outside the SELECT clause") {
+ val e = intercept[AnalysisException](
+ sql("""select 1 from t order by explode(Array(1, 2, 3))""").collect()
+ )
+ assert(e.errorClass === Some("UNSUPPORTED_GENERATOR"))
+ assert(e.message ===
+ "The generator is not supported: outside the SELECT clause, found: " +
+ "'Sort [explode(array(1, 2, 3)) ASC NULLS FIRST], true")
+ }
+
+ test("UNSUPPORTED_GENERATOR: not a generator") {
+ val e = intercept[AnalysisException](
+ sql(
+ """
+ |SELECT explodedvalue.*
+ |FROM VALUES array(1, 2, 3) AS (value)
+ |LATERAL VIEW array_contains(value, 1) AS explodedvalue""".stripMargin).collect()
+ )
+ assert(e.errorClass === Some("UNSUPPORTED_GENERATOR"))
+ assert(e.message ===
+ """The generator is not supported: `array_contains` is expected to be a generator. """ +
+ "However, its class is org.apache.spark.sql.catalyst.expressions.ArrayContains, " +
+ "which is not a generator.")
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
new file mode 100644
index 0000000000..21acea53ed
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.errors
+
+import java.io.File
+import java.net.URI
+
+import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException}
+import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
+import org.apache.spark.sql.execution.datasources.orc.OrcTest
+import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
+import org.apache.spark.sql.functions.{lit, lower, struct, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy.EXCEPTION
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.Utils
+
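+// Verifies the error classes and messages of exceptions thrown at query execution time.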
+class QueryExecutionErrorsSuite extends QueryTest
+ with ParquetTest with OrcTest with SharedSparkSession {
+
+ import testImplicits._
+
+ private def getAesInputs(): (DataFrame, DataFrame) = {
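+ // Plaintext inputs plus base64-encoded ciphertexts for 16/24/32-byte keys, used to
+ // exercise the aes_encrypt/aes_decrypt error paths.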
+ val encryptedText16 = "4Hv0UKCx6nfUeAoPZo1z+w=="
+ val encryptedText24 = "NeTYNgA+PCQBN50DA//O2w=="
+ val encryptedText32 = "9J3iZbIxnmaG+OIA9Amd+A=="
+ val encryptedEmptyText16 = "jmTOhz8XTbskI/zYFFgOFQ=="
+ val encryptedEmptyText24 = "9RDK70sHNzqAFRcpfGM5gQ=="
+ val encryptedEmptyText32 = "j9IDsCvlYXtcVJUf4FAjQQ=="
+
+ val df1 = Seq("Spark", "").toDF
+ val df2 = Seq(
+ (encryptedText16, encryptedText24, encryptedText32),
+ (encryptedEmptyText16, encryptedEmptyText24, encryptedEmptyText32)
+ ).toDF("value16", "value24", "value32")
+
+ (df1, df2)
+ }
+
+ test("INVALID_PARAMETER_VALUE: invalid key lengths in AES functions") {
+ val (df1, df2) = getAesInputs()
+ def checkInvalidKeyLength(df: => DataFrame): Unit = {
+ val e = intercept[SparkException] {
+ df.collect
+ }.getCause.asInstanceOf[SparkRuntimeException]
+ assert(e.getErrorClass === "INVALID_PARAMETER_VALUE")
+ assert(e.getSqlState === "22023")
+ assert(e.getMessage.matches(
+ "The value of parameter\\(s\\) 'key' in the `aes_encrypt`/`aes_decrypt` function " +
+ "is invalid: expects a binary value with 16, 24 or 32 bytes, but got \\d+ bytes."))
+ }
+
+ // Encryption failure - invalid key length
+ checkInvalidKeyLength(df1.selectExpr("aes_encrypt(value, '12345678901234567')"))
+ checkInvalidKeyLength(df1.selectExpr("aes_encrypt(value, binary('123456789012345'))"))
+ checkInvalidKeyLength(df1.selectExpr("aes_encrypt(value, binary(''))"))
+
+ // Decryption failure - invalid key length
+ Seq("value16", "value24", "value32").foreach { colName =>
+ checkInvalidKeyLength(df2.selectExpr(
+ s"aes_decrypt(unbase64($colName), '12345678901234567')"))
+ checkInvalidKeyLength(df2.selectExpr(
+ s"aes_decrypt(unbase64($colName), binary('123456789012345'))"))
+ checkInvalidKeyLength(df2.selectExpr(
+ s"aes_decrypt(unbase64($colName), '')"))
+ checkInvalidKeyLength(df2.selectExpr(
+ s"aes_decrypt(unbase64($colName), binary(''))"))
+ }
+ }
+
+ test("INVALID_PARAMETER_VALUE: AES decrypt failure - key mismatch") {
+ val (_, df2) = getAesInputs()
+ Seq(
+ ("value16", "1234567812345678"),
+ ("value24", "123456781234567812345678"),
+ ("value32", "12345678123456781234567812345678")).foreach { case (colName, key) =>
+ val e = intercept[SparkException] {
+ df2.selectExpr(s"aes_decrypt(unbase64($colName), binary('$key'), 'ECB')").collect
+ }.getCause.asInstanceOf[SparkRuntimeException]
+ assert(e.getErrorClass === "INVALID_PARAMETER_VALUE")
+ assert(e.getSqlState === "22023")
+ assert(e.getMessage ===
+ "The value of parameter(s) 'expr, key' in the `aes_encrypt`/`aes_decrypt` function " +
+ "is invalid: Detail message: " +
+ "Given final block not properly padded. " +
+ "Such issues can arise if a bad key is used during decryption.")
+ }
+ }
+
+ test("UNSUPPORTED_FEATURE: unsupported combinations of AES modes and padding") {
+ val key16 = "abcdefghijklmnop"
+ val key32 = "abcdefghijklmnop12345678ABCDEFGH"
+ val (df1, df2) = getAesInputs()
+ def checkUnsupportedMode(df: => DataFrame): Unit = {
+ val e = intercept[SparkException] {
+ df.collect
+ }.getCause.asInstanceOf[SparkRuntimeException]
+ assert(e.getErrorClass === "UNSUPPORTED_FEATURE")
+ assert(e.getSqlState === "0A000")
+ assert(e.getMessage.matches("""The feature is not supported: AES-\w+ with the padding \w+""" +
+ " by the `aes_encrypt`/`aes_decrypt` function."))
+ }
+
+ // Unsupported AES mode and padding in encrypt
+ checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC')"))
+ checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'ECB', 'NoPadding')"))
+
+ // Unsupported AES mode and padding in decrypt
+ checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GSM')"))
+ checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GCM', 'PKCS')"))
+ checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'ECB', 'None')"))
+ }
+
+ test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") {
+ def checkUnsupportedTypeInLiteral(v: Any): Unit = {
+ val e1 = intercept[SparkRuntimeException] { lit(v) }
+ assert(e1.getErrorClass === "UNSUPPORTED_FEATURE")
+ assert(e1.getSqlState === "0A000")
+ assert(e1.getMessage.matches("""The feature is not supported: literal for '.+' of .+\."""))
+ }
+ checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2))
+ checkUnsupportedTypeInLiteral(("mike", 29, 1.0))
+
+ val e2 = intercept[SparkRuntimeException] {
+ trainingSales
+ .groupBy($"sales.year")
+ .pivot(struct(lower(trainingSales("sales.course")), trainingSales("training")))
+ .agg(sum($"sales.earnings"))
+ .collect()
+ }
+ assert(e2.getMessage === "The feature is not supported: pivoting by the value" +
+ """ '[dotnet,Dummies]' of the column data type "STRUCT<col1: STRING, training: STRING>".""")
+ }
+
+ test("UNSUPPORTED_FEATURE: unsupported pivot operations") {
+ val e1 = intercept[SparkUnsupportedOperationException] {
+ trainingSales
+ .groupBy($"sales.year")
+ .pivot($"sales.course")
+ .pivot($"training")
+ .agg(sum($"sales.earnings"))
+ .collect()
+ }
+ assert(e1.getErrorClass === "UNSUPPORTED_FEATURE")
+ assert(e1.getSqlState === "0A000")
+ assert(e1.getMessage === """The feature is not supported: Repeated PIVOTs.""")
+
+ val e2 = intercept[SparkUnsupportedOperationException] {
+ trainingSales
+ .rollup($"sales.year")
+ .pivot($"training")
+ .agg(sum($"sales.earnings"))
+ .collect()
+ }
+ assert(e2.getErrorClass === "UNSUPPORTED_FEATURE")
+ assert(e2.getSqlState === "0A000")
+ assert(e2.getMessage === """The feature is not supported: PIVOT not after a GROUP BY.""")
+ }
+
+ test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " +
+ "compatibility with Spark 2.4/3.2 in reading/writing dates") {
+
+ // Fail to read ancient datetime values.
+ withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) {
+ val fileName = "before_1582_date_v2_4_5.snappy.parquet"
+ val filePath = getResourceParquetFilePath("test-data/" + fileName)
+ val e = intercept[SparkException] {
+ spark.read.parquet(filePath).collect()
+ }.getCause.asInstanceOf[SparkUpgradeException]
+
+ val format = "Parquet"
+ val config = "\"" + SQLConf.PARQUET_REBASE_MODE_IN_READ.key + "\""
+ val option = "\"" + "datetimeRebaseMode" + "\""
+ assert(e.getErrorClass === "INCONSISTENT_BEHAVIOR_CROSS_VERSION")
+ assert(e.getMessage ===
+ "You may get a different result due to the upgrading to Spark >= 3.0: " +
+ s"""
+ |reading dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z
+ |from $format files can be ambiguous, as the files may be written by
+ |Spark 2.x or legacy versions of Hive, which uses a legacy hybrid calendar
+ |that is different from Spark 3.0+'s Proleptic Gregorian calendar.
+ |See more details in SPARK-31404. You can set the SQL config $config or
+ |the datasource option $option to "LEGACY" to rebase the datetime values
+ |w.r.t. the calendar difference during reading. To read the datetime values
+ |as it is, set the SQL config $config or the datasource option $option
+ |to "CORRECTED".
+ |""".stripMargin)
+ }
+
+ // Fail to write ancient datetime values.
+ withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> EXCEPTION.toString) {
+ withTempPath { dir =>
+ val df = Seq(java.sql.Date.valueOf("1001-01-01")).toDF("dt")
+ val e = intercept[SparkException] {
+ df.write.parquet(dir.getCanonicalPath)
+ }.getCause.getCause.getCause.asInstanceOf[SparkUpgradeException]
+
+ val format = "Parquet"
+ val config = "\"" + SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key + "\""
+ assert(e.getErrorClass === "INCONSISTENT_BEHAVIOR_CROSS_VERSION")
+ assert(e.getMessage ===
+ "You may get a different result due to the upgrading to Spark >= 3.0: " +
+ s"""
+ |writing dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z
+ |into $format files can be dangerous, as the files may be read by Spark 2.x
+ |or legacy versions of Hive later, which uses a legacy hybrid calendar that
+ |is different from Spark 3.0+'s Proleptic Gregorian calendar. See more
+ |details in SPARK-31404. You can set $config to "LEGACY" to rebase the
+ |datetime values w.r.t. the calendar difference during writing, to get maximum
+ |interoperability. Or set $config to "CORRECTED" to write the datetime
+ |values as it is, if you are 100% sure that the written files will only be read by
+ |Spark 3.0+ or other systems that use Proleptic Gregorian calendar.
+ |""".stripMargin)
+ }
+ }
+ }
+
+ test("UNSUPPORTED_OPERATION - SPARK-36346: can't read Timestamp as TimestampNTZ") {
+ withTempPath { file =>
+ sql("select timestamp_ltz'2019-03-21 00:02:03'").write.orc(file.getCanonicalPath)
+ withAllNativeOrcReaders {
+ val e = intercept[SparkException] {
+ spark.read.schema("time timestamp_ntz").orc(file.getCanonicalPath).collect()
+ }.getCause.asInstanceOf[SparkUnsupportedOperationException]
+
+ assert(e.getErrorClass === "UNSUPPORTED_OPERATION")
+ assert(e.getMessage === "The operation is not supported: " +
+ "Unable to convert \"TIMESTAMP\" of Orc to data type \"TIMESTAMP_NTZ\".")
+ }
+ }
+ }
+
+ test("UNSUPPORTED_OPERATION - SPARK-38504: can't read TimestampNTZ as TimestampLTZ") {
+ withTempPath { file =>
+ sql("select timestamp_ntz'2019-03-21 00:02:03'").write.orc(file.getCanonicalPath)
+ withAllNativeOrcReaders {
+ val e = intercept[SparkException] {
+ spark.read.schema("time timestamp_ltz").orc(file.getCanonicalPath).collect()
+ }.getCause.asInstanceOf[SparkUnsupportedOperationException]
+
+ assert(e.getErrorClass === "UNSUPPORTED_OPERATION")
+ assert(e.getMessage === "The operation is not supported: " +
+ "Unable to convert \"TIMESTAMP_NTZ\" of Orc to data type \"TIMESTAMP\".")
+ }
+ }
+ }
+
+ test("DATETIME_OVERFLOW: timestampadd() overflows its input timestamp") {
+ val e = intercept[SparkArithmeticException] {
+ sql("select timestampadd(YEAR, 1000000, timestamp'2022-03-09 01:02:03')").collect()
+ }
+ assert(e.getErrorClass === "DATETIME_OVERFLOW")
+ assert(e.getSqlState === "22008")
+ assert(e.getMessage ===
+ "Datetime operation overflow: add 1000000 YEAR to TIMESTAMP '2022-03-09 01:02:03'.")
+ }
+
+ test("UNSUPPORTED_SAVE_MODE: unsupported null saveMode whether the path exists or not") {
+ withTempPath { path =>
+ val e1 = intercept[SparkIllegalArgumentException] {
+ val saveMode: SaveMode = null
+ Seq(1, 2).toDS().write.mode(saveMode).parquet(path.getAbsolutePath)
+ }
+ assert(e1.getErrorClass === "UNSUPPORTED_SAVE_MODE")
+ assert(e1.getMessage === "The save mode NULL is not supported for: a non-existent path.")
+
+ Utils.createDirectory(path)
+
+ val e2 = intercept[SparkIllegalArgumentException] {
+ val saveMode: SaveMode = null
+ Seq(1, 2).toDS().write.mode(saveMode).parquet(path.getAbsolutePath)
+ }
+ assert(e2.getErrorClass === "UNSUPPORTED_SAVE_MODE")
+ assert(e2.getMessage === "The save mode NULL is not supported for: an existent path.")
+ }
+ }
+
+ test("INVALID_BUCKET_FILE: error if there exists any malformed bucket files") {
+ val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).
+ toDF("i", "j", "k").as("df1")
+
+ withTable("bucketed_table") {
+ df1.write.format("parquet").bucketBy(8, "i").
+ saveAsTable("bucketed_table")
+ val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath
+ val tableDir = new File(warehouseFilePath, "bucketed_table")
+ Utils.deleteRecursively(tableDir)
+ df1.write.parquet(tableDir.getAbsolutePath)
+
+ val aggregated = spark.table("bucketed_table").groupBy("i").count()
+
+ val e = intercept[SparkException] {
+ aggregated.count()
+ }
+ assert(e.getErrorClass === "INVALID_BUCKET_FILE")
+ assert(e.getMessage.matches("Invalid bucket file: .+"))
+ }
+ }
+}
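The suite above repeats one assertion shape for every error class: intercept the exception, then compare getErrorClass, getSqlState and the message. A minimal standalone sketch of that pattern follows; it is not part of the patch and assumes Spark's SparkThrowable interface (which the intercepted exceptions above implement) and plain assert instead of ScalaTest matchers.

```scala
import org.apache.spark.SparkThrowable

object ErrorClassCheckSketch {
  // Intercept a thrown SparkThrowable and validate its error class, SQL state
  // and (part of) its message -- the same checks the tests above spell out inline.
  def validateError(
      body: => Any,
      errorClass: String,
      sqlState: String,
      messagePart: String): Unit = {
    val e =
      try {
        body
        throw new AssertionError("expected a SparkThrowable to be thrown")
      } catch {
        case t: Throwable with SparkThrowable => t
      }
    assert(e.getErrorClass == errorClass, s"unexpected error class: ${e.getErrorClass}")
    assert(e.getSqlState == sqlState, s"unexpected SQL state: ${e.getSqlState}")
    assert(e.getMessage.contains(messagePart), s"unexpected message: ${e.getMessage}")
  }
}
```

For example, the DATETIME_OVERFLOW case above reduces to a single validateError(...) call with errorClass = "DATETIME_OVERFLOW" and sqlState = "22008".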
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
new file mode 100644
index 0000000000..6494e541d4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.errors
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.test.SharedSparkSession
+
+// Turn off the length check because most of the tests check entire error messages
+// scalastyle:off line.size.limit
+class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession {
+ def validateParsingError(
+ sqlText: String,
+ errorClass: String,
+ sqlState: String,
+ message: String): Unit = {
+ val e = intercept[ParseException] {
+ sql(sqlText)
+ }
+ assert(e.getErrorClass === errorClass)
+ assert(e.getSqlState === sqlState)
+ assert(e.getMessage === message)
+ }
+
+ test("UNSUPPORTED_FEATURE: LATERAL join with NATURAL join not supported") {
+ validateParsingError(
+ sqlText = "SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2)",
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ """
+ |The feature is not supported: LATERAL join with NATURAL join.(line 1, pos 14)
+ |
+ |== SQL ==
+ |SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2)
+ |--------------^^^
+ |""".stripMargin)
+ }
+
+ test("UNSUPPORTED_FEATURE: LATERAL join with USING join not supported") {
+ validateParsingError(
+ sqlText = "SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2)",
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ """
+ |The feature is not supported: LATERAL join with USING join.(line 1, pos 14)
+ |
+ |== SQL ==
+ |SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2)
+ |--------------^^^
+ |""".stripMargin)
+ }
+
+ test("UNSUPPORTED_FEATURE: Unsupported LATERAL join type") {
+ Seq("RIGHT OUTER", "FULL OUTER", "LEFT SEMI", "LEFT ANTI").foreach { joinType =>
+ validateParsingError(
+ sqlText = s"SELECT * FROM t1 $joinType JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3",
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ s"""
+ |The feature is not supported: LATERAL join type $joinType.(line 1, pos 14)
+ |
+ |== SQL ==
+ |SELECT * FROM t1 $joinType JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3
+ |--------------^^^
+ |""".stripMargin)
+ }
+ }
+
+ test("INVALID_SQL_SYNTAX: LATERAL can only be used with subquery") {
+ Seq(
+ "SELECT * FROM t1, LATERAL t2" -> 26,
+ "SELECT * FROM t1 JOIN LATERAL t2" -> 30,
+ "SELECT * FROM t1, LATERAL (t2 JOIN t3)" -> 26,
+ "SELECT * FROM t1, LATERAL (LATERAL t2)" -> 26,
+ "SELECT * FROM t1, LATERAL VALUES (0, 1)" -> 26,
+ "SELECT * FROM t1, LATERAL RANGE(0, 1)" -> 26
+ ).foreach { case (sqlText, pos) =>
+ validateParsingError(
+ sqlText = sqlText,
+ errorClass = "INVALID_SQL_SYNTAX",
+ sqlState = "42000",
+ message =
+ s"""
+ |Invalid SQL syntax: LATERAL can only be used with subquery.(line 1, pos $pos)
+ |
+ |== SQL ==
+ |$sqlText
+ |${"-" * pos}^^^
+ |""".stripMargin)
+ }
+ }
+
+ test("UNSUPPORTED_FEATURE: NATURAL CROSS JOIN is not supported") {
+ validateParsingError(
+ sqlText = "SELECT * FROM a NATURAL CROSS JOIN b",
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ """
+ |The feature is not supported: NATURAL CROSS JOIN.(line 1, pos 14)
+ |
+ |== SQL ==
+ |SELECT * FROM a NATURAL CROSS JOIN b
+ |--------------^^^
+ |""".stripMargin)
+ }
+
+ test("INVALID_SQL_SYNTAX: redefine window") {
+ validateParsingError(
+ sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win, win AS win2",
+ errorClass = "INVALID_SQL_SYNTAX",
+ sqlState = "42000",
+ message =
+ """
+ |Invalid SQL syntax: The definition of window `win` is repetitive.(line 1, pos 31)
+ |
+ |== SQL ==
+ |SELECT min(a) OVER win FROM t1 WINDOW win AS win, win AS win2
+ |-------------------------------^^^
+ |""".stripMargin)
+ }
+
+ test("INVALID_SQL_SYNTAX: invalid window reference") {
+ validateParsingError(
+ sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win",
+ errorClass = "INVALID_SQL_SYNTAX",
+ sqlState = "42000",
+ message =
+ """
+ |Invalid SQL syntax: Window reference `win` is not a window specification.(line 1, pos 31)
+ |
+ |== SQL ==
+ |SELECT min(a) OVER win FROM t1 WINDOW win AS win
+ |-------------------------------^^^
+ |""".stripMargin)
+ }
+
+ test("INVALID_SQL_SYNTAX: window reference cannot be resolved") {
+ validateParsingError(
+ sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win2",
+ errorClass = "INVALID_SQL_SYNTAX",
+ sqlState = "42000",
+ message =
+ """
+ |Invalid SQL syntax: Cannot resolve window reference `win2`.(line 1, pos 31)
+ |
+ |== SQL ==
+ |SELECT min(a) OVER win FROM t1 WINDOW win AS win2
+ |-------------------------------^^^
+ |""".stripMargin)
+ }
+
+ test("UNSUPPORTED_FEATURE: TRANSFORM does not support DISTINCT/ALL") {
+ validateParsingError(
+ sqlText = "SELECT TRANSFORM(DISTINCT a) USING 'a' FROM t",
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ """
+ |The feature is not supported: TRANSFORM does not support DISTINCT/ALL in inputs(line 1, pos 17)
+ |
+ |== SQL ==
+ |SELECT TRANSFORM(DISTINCT a) USING 'a' FROM t
+ |-----------------^^^
+ |""".stripMargin)
+ }
+
+ test("UNSUPPORTED_FEATURE: In-memory mode does not support TRANSFORM with serde") {
+ validateParsingError(
+ sqlText = "SELECT TRANSFORM(a) ROW FORMAT SERDE " +
+ "'org.apache.hadoop.hive.serde2.OpenCSVSerde' USING 'a' FROM t",
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ """
+ |The feature is not supported: TRANSFORM with serde is only supported in hive mode(line 1, pos 0)
+ |
+ |== SQL ==
+ |SELECT TRANSFORM(a) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' USING 'a' FROM t
+ |^^^
+ |""".stripMargin)
+ }
+
+ test("INVALID_SQL_SYNTAX: Too many arguments for transform") {
+ validateParsingError(
+ sqlText = "CREATE TABLE table(col int) PARTITIONED BY (years(col,col))",
+ errorClass = "INVALID_SQL_SYNTAX",
+ sqlState = "42000",
+ message =
+ """
+ |Invalid SQL syntax: Too many arguments for transform `years`(line 1, pos 44)
+ |
+ |== SQL ==
+ |CREATE TABLE table(col int) PARTITIONED BY (years(col,col))
+ |--------------------------------------------^^^
+ |""".stripMargin)
+ }
+
+ test("UNSUPPORTED_FEATURE: cannot set reserved namespace property") {
+ val sql = "CREATE NAMESPACE IF NOT EXISTS a.b.c WITH PROPERTIES ('location'='/home/user/db')"
+ val msg = """The feature is not supported: location is a reserved namespace property, """ +
+ """please use the LOCATION clause to specify it.(line 1, pos 0)"""
+ validateParsingError(
+ sqlText = sql,
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ s"""
+ |$msg
+ |
+ |== SQL ==
+ |$sql
+ |^^^
+ |""".stripMargin)
+ }
+
+ test("UNSUPPORTED_FEATURE: cannot set reserved table property") {
+ val sql = "CREATE TABLE student (id INT, name STRING, age INT) " +
+ "USING PARQUET TBLPROPERTIES ('provider'='parquet')"
+ val msg = """The feature is not supported: provider is a reserved table property, """ +
+ """please use the USING clause to specify it.(line 1, pos 66)"""
+ validateParsingError(
+ sqlText = sql,
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ s"""
+ |$msg
+ |
+ |== SQL ==
+ |$sql
+ |------------------------------------------------------------------^^^
+ |""".stripMargin)
+ }
+
+ test("INVALID_PROPERTY_KEY: invalid property key for set quoted configuration") {
+ val sql = "set =`value`"
+ val msg = """"" is an invalid property key, please use quotes, """ +
+ """e.g. SET ""="value"(line 1, pos 0)"""
+ validateParsingError(
+ sqlText = sql,
+ errorClass = "INVALID_PROPERTY_KEY",
+ sqlState = null,
+ message =
+ s"""
+ |$msg
+ |
+ |== SQL ==
+ |$sql
+ |^^^
+ |""".stripMargin)
+ }
+
+ test("INVALID_PROPERTY_VALUE: invalid property value for set quoted configuration") {
+ val sql = "set `key`=1;2;;"
+ val msg = """"1;2;;" is an invalid property value, please use quotes, """ +
+ """e.g. SET "key"="1;2;;"(line 1, pos 0)"""
+ validateParsingError(
+ sqlText = sql,
+ errorClass = "INVALID_PROPERTY_VALUE",
+ sqlState = null,
+ message =
+ s"""
+ |$msg
+ |
+ |== SQL ==
+ |$sql
+ |^^^
+ |""".stripMargin)
+ }
+
+ test("UNSUPPORTED_FEATURE: cannot set Properties and DbProperties at the same time") {
+ val sql = "CREATE NAMESPACE IF NOT EXISTS a.b.c WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c') " +
+ "WITH DBPROPERTIES('a'='a', 'b'='b', 'c'='c')"
+ val msg = """The feature is not supported: set PROPERTIES and DBPROPERTIES at the same time.""" +
+ """(line 1, pos 0)"""
+ validateParsingError(
+ sqlText = sql,
+ errorClass = "UNSUPPORTED_FEATURE",
+ sqlState = "0A000",
+ message =
+ s"""
+ |$msg
+ |
+ |== SQL ==
+ |$sql
+ |^^^
+ |""".stripMargin)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala
index a33b9fad7f..06fc2022c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala
@@ -35,9 +35,9 @@ class AggregatingAccumulatorSuite
extends SparkFunSuite
with SharedSparkSession
with ExpressionEvalHelper {
- private val a = 'a.long
- private val b = 'b.string
- private val c = 'c.double
+ private val a = Symbol("a").long
+ private val b = Symbol("b").string
+ private val c = Symbol("c").double
private val inputAttributes = Seq(a, b, c)
private def str(s: String): UTF8String = UTF8String.fromString(s)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
index 80a2ee03ca..09a880a706 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
@@ -133,8 +133,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
""".stripMargin)
checkAnswer(query, identity, df.select(
- 'a.cast("string"),
- 'b.cast("string"),
+ Symbol("a").cast("string"),
+ Symbol("b").cast("string"),
'c.cast("string"),
'd.cast("string"),
'e.cast("string")).collect())
@@ -164,7 +164,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
'b.cast("string").as("value")).collect())
checkAnswer(
- df.select('a, 'b),
+ df.select(Symbol("a"), Symbol("b")),
(child: SparkPlan) => createScriptTransformationExec(
script = "cat",
output = Seq(
@@ -178,7 +178,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
'b.cast("string").as("value")).collect())
checkAnswer(
- df.select('a),
+ df.select(Symbol("a")),
(child: SparkPlan) => createScriptTransformationExec(
script = "cat",
output = Seq(
@@ -242,7 +242,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
child = child,
ioschema = serde
),
- df.select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j).collect())
+ df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e"),
+ Symbol("f"), Symbol("g"), Symbol("h"), Symbol("i"), Symbol("j")).collect())
}
}
}
@@ -282,7 +283,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
child = child,
ioschema = defaultIOSchema
),
- df.select('a, 'b, 'c, 'd, 'e).collect())
+ df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e")).collect())
}
}
@@ -304,7 +305,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
|USING 'cat' AS (a timestamp, b date)
|FROM v
""".stripMargin)
- checkAnswer(query, identity, df.select('a, 'b).collect())
+ checkAnswer(query, identity, df.select(Symbol("a"), Symbol("b")).collect())
}
}
}
@@ -379,7 +380,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18)
checkAnswer(
- df.select('a, 'b),
+ df.select(Symbol("a"), Symbol("b")),
(child: SparkPlan) => createScriptTransformationExec(
script = "cat",
output = Seq(
@@ -452,10 +453,10 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
(Array(6, 7, 8), Array(Array(6, 7), Array(8)),
Map("c" -> 3), Map("d" -> Array("e", "f")))
).toDF("a", "b", "c", "d")
- .select('a, 'b, 'c, 'd,
- struct('a, 'b).as("e"),
- struct('a, 'd).as("f"),
- struct(struct('a, 'b), struct('a, 'd)).as("g")
+ .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"),
+ struct(Symbol("a"), Symbol("b")).as("e"),
+ struct(Symbol("a"), Symbol("d")).as("f"),
+ struct(struct(Symbol("a"), Symbol("b")), struct(Symbol("a"), Symbol("d"))).as("g")
)
checkAnswer(
@@ -483,7 +484,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
child = child,
ioschema = defaultIOSchema
),
- df.select('a, 'b, 'c, 'd, 'e, 'f, 'g).collect())
+ df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e"),
+ Symbol("f"), Symbol("g")).collect())
}
}
@@ -654,6 +656,20 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
df.select($"ym", $"dt").collect())
}
}
+
+ test("SPARK-36675: TRANSFORM should support timestamp_ntz (no serde)") {
+ val df = spark.sql("SELECT timestamp_ntz'2021-09-06 20:19:13' col")
+ checkAnswer(
+ df,
+ (child: SparkPlan) => createScriptTransformationExec(
+ script = "cat",
+ output = Seq(
+ AttributeReference("col", TimestampNTZType)()),
+ child = child,
+ ioschema = defaultIOSchema
+ ),
+ df.select($"col").collect())
+ }
}
case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
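Many hunks in the surrounding test suites only swap the single-quote symbol literal ('a) for Symbol("a"). That is a Scala 2.13 source-compatibility change rather than a behavioural one: symbol literals are deprecated in 2.13, and the factory form produces an equal Symbol value, so the implicit column/attribute DSL conversions in these tests resolve the same way. A tiny standalone illustration (not from the patch):

```scala
// Symbol literals vs. the Symbol factory: equal values, different deprecation status.
object SymbolLiteralSketch extends App {
  val viaFactory = Symbol("a")
  val viaLiteral = 'a // still compiles on Scala 2.13, but emits a deprecation warning
  assert(viaFactory == viaLiteral) // interchangeable wherever a Symbol is expected
  println(s"$viaFactory == $viaLiteral")
}
```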
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
index 4ff96e6574..e4f17eb601 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
@@ -26,9 +26,11 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
test("basic") {
val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator
val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator
- val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
- val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
- val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
+ val leftGrouped = GroupedIterator(leftInput, Seq(Symbol("i").int.at(0)),
+ Seq(Symbol("i").int, Symbol("s").string))
+ val rightGrouped = GroupedIterator(rightInput, Seq(Symbol("i").int.at(0)),
+ Seq(Symbol("i").int, Symbol("l").long))
+ val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq(Symbol("i").int))
val result = cogrouped.map {
case (key, leftData, rightData) =>
@@ -52,7 +54,8 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
val leftInput = Seq(create_row(2, "a")).iterator
val rightInput = Seq(create_row(1, 2L)).iterator
- val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
+ val leftGrouped = GroupedIterator(leftInput, Seq(Symbol("i").int.at(0)),
+ Seq(Symbol("i").int, Symbol("s").string))
val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
index 2a28517868..ddf4d421f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
@@ -410,14 +410,14 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
QueryTest.checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i)))
+ // Shuffle partition coalescing of the join is performed independently of the
+ // non-grouping aggregate on the other side of the union.
val finalPlan = resultDf.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- // As the pre-shuffle partition number are different, we will skip reducing
- // the shuffle partition numbers.
assert(
finalPlan.collect {
case r @ CoalescedShuffleRead() => r
- }.isEmpty)
+ }.size == 2)
}
withSparkSession(test, 100, None)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
index 612cd6f0d8..e29b7f579f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
@@ -18,13 +18,11 @@ package org.apache.spark.sql.execution
import java.io.File
-import scala.collection.mutable
import scala.util.Random
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
@@ -215,33 +213,4 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest {
}
}
}
-
- test("SPARK-30362: test input metrics for DSV2") {
- withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
- Seq("json", "orc", "parquet").foreach { format =>
- withTempPath { path =>
- val dir = path.getCanonicalPath
- spark.range(0, 10).write.format(format).save(dir)
- val df = spark.read.format(format).load(dir)
- val bytesReads = new mutable.ArrayBuffer[Long]()
- val recordsRead = new mutable.ArrayBuffer[Long]()
- val bytesReadListener = new SparkListener() {
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
- bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead
- recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
- }
- }
- sparkContext.addSparkListener(bytesReadListener)
- try {
- df.collect()
- sparkContext.listenerBus.waitUntilEmpty()
- assert(bytesReads.sum > 0)
- assert(recordsRead.sum == 10)
- } finally {
- sparkContext.removeSparkListener(bytesReadListener)
- }
- }
- }
- }
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala
index b27a940c36..635c794338 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala
@@ -36,9 +36,9 @@ class DeprecatedWholeStageCodegenSuite extends QueryTest
.groupByKey(_._1).agg(typed.sum(_._2))
val plan = ds.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]))
assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
index 4b2a2b439c..06c51cee02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -32,7 +32,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
val fromRow = encoder.createDeserializer()
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
val grouped = GroupedIterator(input.iterator.map(toRow),
- Seq('i.int.at(0)), schema.toAttributes)
+ Seq(Symbol("i").int.at(0)), schema.toAttributes)
val result = grouped.map {
case (key, data) =>
@@ -59,7 +59,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
Row(3, 2L, "e"))
val grouped = GroupedIterator(input.iterator.map(toRow),
- Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)
+ Seq(Symbol("i").int.at(0), Symbol("l").long.at(1)), schema.toAttributes)
val result = grouped.map {
case (key, data) =>
@@ -80,7 +80,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
val toRow = encoder.createSerializer()
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
val grouped = GroupedIterator(input.iterator.map(toRow),
- Seq('i.int.at(0)), schema.toAttributes)
+ Seq(Symbol("i").int.at(0)), schema.toAttributes)
assert(grouped.length == 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
index 5bcec9b1e5..743ec41dbe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag
import org.apache.spark.sql.TPCDSQuerySuite
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window, WithCTE}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -48,6 +48,7 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv
// A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes.
// We may add `ColumnarToRowExec` and `InputAdapter` above the scan node after planning.
+ @scala.annotation.tailrec
private def isScanPlanTree(plan: SparkPlan): Boolean = plan match {
case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child)
case p: ProjectExec => isScanPlanTree(p.child)
@@ -107,7 +108,11 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv
// logical = Project(Filter(Scan A))
// physical = ProjectExec(ScanExec A)
// we only check that leaf modes match between logical and physical plan.
- val logicalLeaves = getLogicalPlan(actualPlan).collectLeaves()
+ val logicalPlan = getLogicalPlan(actualPlan) match {
+ case w: WithCTE => w.plan
+ case o => o
+ }
+ val logicalLeaves = logicalPlan.collectLeaves()
val physicalLeaves = plan.collectLeaves()
assert(logicalLeaves.length == 1)
assert(physicalLeaves.length == 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index df310cbaee..000bd8c84f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -59,18 +59,21 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
test("count is partially aggregated") {
- val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
+ val query = testData.groupBy(Symbol("value")).agg(count(Symbol("key"))).queryExecution.analyzed
testPartialAggregationPlan(query)
}
test("count distinct is partially aggregated") {
- val query = testData.groupBy('value).agg(count_distinct('key)).queryExecution.analyzed
+ val query = testData.groupBy(Symbol("value")).agg(count_distinct(Symbol("key")))
+ .queryExecution.analyzed
testPartialAggregationPlan(query)
}
test("mixed aggregates are partially aggregated") {
val query =
- testData.groupBy('value).agg(count('value), count_distinct('key)).queryExecution.analyzed
+ testData.groupBy(Symbol("value"))
+ .agg(count(Symbol("value")), count_distinct(Symbol("key")))
+ .queryExecution.analyzed
testPartialAggregationPlan(query)
}
@@ -193,47 +196,49 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
test("efficient terminal limit -> sort should use TakeOrderedAndProject") {
- val query = testData.select('key, 'value).sort('key).limit(2)
+ val query = testData.select(Symbol("key"), Symbol("value")).sort(Symbol("key")).limit(2)
val planned = query.queryExecution.executedPlan
assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec])
- assert(planned.output === testData.select('key, 'value).logicalPlan.output)
+ assert(planned.output === testData.select(Symbol("key"), Symbol("value")).logicalPlan.output)
}
test("terminal limit -> project -> sort should use TakeOrderedAndProject") {
- val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2)
+ val query = testData.select(Symbol("key"), Symbol("value")).sort(Symbol("key"))
+ .select(Symbol("value"), Symbol("key")).limit(2)
val planned = query.queryExecution.executedPlan
assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec])
- assert(planned.output === testData.select('value, 'key).logicalPlan.output)
+ assert(planned.output === testData.select(Symbol("value"), Symbol("key")).logicalPlan.output)
}
test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") {
- val query = testData.select('value).limit(2)
+ val query = testData.select(Symbol("value")).limit(2)
val planned = query.queryExecution.sparkPlan
assert(planned.isInstanceOf[CollectLimitExec])
- assert(planned.output === testData.select('value).logicalPlan.output)
+ assert(planned.output === testData.select(Symbol("value")).logicalPlan.output)
}
test("TakeOrderedAndProject can appear in the middle of plans") {
- val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3)
+ val query = testData.select(Symbol("key"), Symbol("value"))
+ .sort(Symbol("key")).limit(2).filter('key === 3)
val planned = query.queryExecution.executedPlan
- assert(planned.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined)
+ assert(planned.exists(_.isInstanceOf[TakeOrderedAndProjectExec]))
}
test("CollectLimit can appear in the middle of a plan when caching is used") {
- val query = testData.select('key, 'value).limit(2).cache()
+ val query = testData.select(Symbol("key"), Symbol("value")).limit(2).cache()
val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
assert(planned.cachedPlan.isInstanceOf[CollectLimitExec])
}
test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") {
withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") {
- val query0 = testData.select('value).orderBy('key).limit(100)
+ val query0 = testData.select(Symbol("value")).orderBy(Symbol("key")).limit(100)
val planned0 = query0.queryExecution.executedPlan
- assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined)
+ assert(planned0.exists(_.isInstanceOf[TakeOrderedAndProjectExec]))
- val query1 = testData.select('value).orderBy('key).limit(2000)
+ val query1 = testData.select(Symbol("value")).orderBy(Symbol("key")).limit(2000)
val planned1 = query1.queryExecution.executedPlan
- assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty)
+ assert(!planned1.exists(_.isInstanceOf[TakeOrderedAndProjectExec]))
}
}
@@ -432,7 +437,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
test("EnsureRequirements should respect ClusteredDistribution's num partitioning") {
- val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13))
+ val distribution = ClusteredDistribution(Literal(1) :: Nil, requiredNumPartitions = Some(13))
// Number of partitions differ
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13)
val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index 9636f33ff6..41a1cd9b29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -20,10 +20,11 @@ import scala.io.Source
import org.apache.spark.sql.{AnalysisException, FastOperator}
import org.apache.spark.sql.catalyst.analysis.UnresolvedNamespace
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LogicalPlan, OneRowRelation, Project, ShowTables, SubqueryAlias}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.execution.command.{ExecutedCommandExec, ShowTablesCommand}
+import org.apache.spark.sql.execution.datasources.v2.ShowTablesExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -227,7 +228,8 @@ class QueryExecutionSuite extends SharedSparkSession {
}
Seq("=== Applying Rule org.apache.spark.sql.execution",
"=== Result of Batch Preparations ===").foreach { expectedMsg =>
- assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg)))
+ assert(testAppender.loggingEvents.exists(
+ _.getMessage.getFormattedMessage.contains(expectedMsg)))
}
}
@@ -247,9 +249,7 @@ class QueryExecutionSuite extends SharedSparkSession {
assert(showTablesQe.commandExecuted.isInstanceOf[CommandResult])
assert(showTablesQe.executedPlan.isInstanceOf[CommandResultExec])
val showTablesResultExec = showTablesQe.executedPlan.asInstanceOf[CommandResultExec]
- assert(showTablesResultExec.commandPhysicalPlan.isInstanceOf[ExecutedCommandExec])
- assert(showTablesResultExec.commandPhysicalPlan.asInstanceOf[ExecutedCommandExec]
- .cmd.isInstanceOf[ShowTablesCommand])
+ assert(showTablesResultExec.commandPhysicalPlan.isInstanceOf[ShowTablesExec])
val project = Project(showTables.output, SubqueryAlias("s", showTables))
val projectQe = qe(project)
@@ -260,9 +260,15 @@ class QueryExecutionSuite extends SharedSparkSession {
assert(projectQe.commandExecuted.children(0).children(0).isInstanceOf[CommandResult])
assert(projectQe.executedPlan.isInstanceOf[CommandResultExec])
val cmdResultExec = projectQe.executedPlan.asInstanceOf[CommandResultExec]
- assert(cmdResultExec.commandPhysicalPlan.isInstanceOf[ExecutedCommandExec])
- assert(cmdResultExec.commandPhysicalPlan.asInstanceOf[ExecutedCommandExec]
- .cmd.isInstanceOf[ShowTablesCommand])
+ assert(cmdResultExec.commandPhysicalPlan.isInstanceOf[ShowTablesExec])
+ }
+
+ test("SPARK-35378: Return UnsafeRow in CommandResultExecCheck execute methods") {
+ val plan = spark.sql("SHOW FUNCTIONS").queryExecution.executedPlan
+ assert(plan.isInstanceOf[CommandResultExec])
+ plan.executeCollect().foreach { row => assert(row.isInstanceOf[UnsafeRow]) }
+ plan.executeTake(10).foreach { row => assert(row.isInstanceOf[UnsafeRow]) }
+ plan.executeTail(10).foreach { row => assert(row.isInstanceOf[UnsafeRow]) }
}
test("SPARK-38198: check specify maxFields when call toFile method") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
index 751078d08f..21702b6cf5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
@@ -51,7 +51,7 @@ abstract class RemoveRedundantSortsSuiteBase
test("remove redundant sorts with limit") {
withTempView("t") {
- spark.range(100).select('id as "key").createOrReplaceTempView("t")
+ spark.range(100).select(Symbol("id") as "key").createOrReplaceTempView("t")
val query =
"""
|SELECT key FROM
@@ -64,8 +64,8 @@ abstract class RemoveRedundantSortsSuiteBase
test("remove redundant sorts with broadcast hash join") {
withTempView("t1", "t2") {
- spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
- spark.range(1000).select('id as "key").createOrReplaceTempView("t2")
+ spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t1")
+ spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t2")
val queryTemplate = """
|SELECT /*+ BROADCAST(%s) */ t1.key FROM
@@ -100,8 +100,8 @@ abstract class RemoveRedundantSortsSuiteBase
test("remove redundant sorts with sort merge join") {
withTempView("t1", "t2") {
- spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
- spark.range(1000).select('id as "key").createOrReplaceTempView("t2")
+ spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t1")
+ spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t2")
val query = """
|SELECT /*+ MERGE(t1) */ t1.key FROM
| (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1
@@ -123,8 +123,8 @@ abstract class RemoveRedundantSortsSuiteBase
test("cached sorted data doesn't need to be re-sorted") {
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
- val df = spark.range(1000).select('id as "key").sort('key.desc).cache()
- val resorted = df.sort('key.desc)
+ val df = spark.range(1000).select(Symbol("id") as "key").sort(Symbol("key").desc).cache()
+ val resorted = df.sort(Symbol("key").desc)
val sortedAsc = df.sort('key.asc)
checkNumSorts(df, 0)
checkNumSorts(resorted, 0)
@@ -140,7 +140,7 @@ abstract class RemoveRedundantSortsSuiteBase
test("SPARK-33472: shuffled join with different left and right side partition numbers") {
withTempView("t1", "t2") {
- spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1")
+ spark.range(0, 100, 1, 2).select(Symbol("id") as "key").createOrReplaceTempView("t1")
(0 to 100).toDF("key").createOrReplaceTempView("t2")
val queryTemplate = """
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
new file mode 100644
index 0000000000..47679ed786
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+abstract class ReplaceHashWithSortAggSuiteBase
+ extends QueryTest
+ with SharedSparkSession
+ with AdaptiveSparkPlanHelper {
+
+ private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
+ val plan = df.queryExecution.executedPlan
+ assert(collectWithSubqueries(plan) {
+ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s
+ }.length == hashAggCount)
+ assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
+ }
+
+ private def checkAggs(
+ query: String,
+ enabledHashAggCount: Int,
+ enabledSortAggCount: Int,
+ disabledHashAggCount: Int,
+ disabledSortAggCount: Int): Unit = {
+ withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "true") {
+ val df = sql(query)
+ checkNumAggs(df, enabledHashAggCount, enabledSortAggCount)
+ val result = df.collect()
+ withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "false") {
+ val df = sql(query)
+ checkNumAggs(df, disabledHashAggCount, disabledSortAggCount)
+ checkAnswer(df, result)
+ }
+ }
+ }
+
+ test("replace partial hash aggregate with sort aggregate") {
+ withTempView("t") {
+ spark.range(100).selectExpr("id as key").repartition(10).createOrReplaceTempView("t")
+ Seq("FIRST", "COLLECT_LIST").foreach { aggExpr =>
+ val query =
+ s"""
+ |SELECT key, $aggExpr(key)
+ |FROM
+ |(
+ | SELECT key
+ | FROM t
+ | WHERE key > 10
+ | SORT BY key
+ |)
+ |GROUP BY key
+ """.stripMargin
+ checkAggs(query, 1, 1, 2, 0)
+ }
+ }
+ }
+
+ test("replace partial and final hash aggregate together with sort aggregate") {
+ withTempView("t1", "t2") {
+ spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
+ spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
+ Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+ val query =
+ s"""
+ |SELECT key, $aggExpr(key)
+ |FROM
+ |(
+ | SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key
+ | FROM t1
+ | JOIN t2
+ | ON t1.key = t2.key
+ |)
+ |GROUP BY key
+ """.stripMargin
+ checkAggs(query, 0, 1, 2, 0)
+ }
+ }
+ }
+
+ test("do not replace hash aggregate if child does not have sort order") {
+ withTempView("t1", "t2") {
+ spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
+ spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
+ Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+ val query =
+ s"""
+ |SELECT key, $aggExpr(key)
+ |FROM
+ |(
+ | SELECT /*+ BROADCAST(t1) */ t1.key AS key
+ | FROM t1
+ | JOIN t2
+ | ON t1.key = t2.key
+ |)
+ |GROUP BY key
+ """.stripMargin
+ checkAggs(query, 2, 0, 2, 0)
+ }
+ }
+ }
+
+ test("do not replace hash aggregate if there is no group-by column") {
+ withTempView("t1") {
+ spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
+ Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+ val query =
+ s"""
+ |SELECT $aggExpr(key)
+ |FROM t1
+ """.stripMargin
+ checkAggs(query, 2, 0, 2, 0)
+ }
+ }
+ }
+}
+
+class ReplaceHashWithSortAggSuite extends ReplaceHashWithSortAggSuiteBase
+ with DisableAdaptiveExecutionSuite
+
+class ReplaceHashWithSortAggSuiteAE extends ReplaceHashWithSortAggSuiteBase
+ with EnableAdaptiveExecutionSuite
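The new suite above always evaluates each query twice, with the hash-to-sort-aggregate rewrite enabled and disabled, and requires identical answers because only the physical aggregate operator changes. A rough standalone sketch of that on/off comparison follows; it is not part of the patch, and the configuration key string is an assumption derived from the SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED entry the suite references.

```scala
import org.apache.spark.sql.{Row, SparkSession}

object ReplaceHashWithSortAggSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    spark.range(100).selectExpr("id % 10 AS key").createOrReplaceTempView("t")
    val query = "SELECT key, count(key) AS cnt FROM t GROUP BY key"

    // Assumed key backing SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.
    val flag = "spark.sql.execution.replaceHashWithSortAgg.enabled"
    def runWith(enabled: Boolean): Set[Row] = {
      spark.conf.set(flag, enabled.toString)
      spark.sql(query).collect().toSet
    }

    // The rewrite is purely physical, so the answers must match either way.
    assert(runWith(enabled = true) == runWith(enabled = false))
    spark.stop()
  }
}
```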
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index 81e692076b..740c10f17b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -17,17 +17,21 @@
package org.apache.spark.sql.execution
+import java.util.Locale
import java.util.concurrent.Executors
+import java.util.concurrent.atomic.AtomicInteger
import scala.collection.parallel.immutable.ParRange
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
import org.apache.spark.sql.types._
import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.Utils.REDACTION_REPLACEMENT_TEXT
class SQLExecutionSuite extends SparkFunSuite {
@@ -157,6 +161,45 @@ class SQLExecutionSuite extends SparkFunSuite {
}
}
}
+
+ test("SPARK-34735: Add modified configs for SQL execution in UI") {
+ val spark = SparkSession.builder()
+ .master("local[*]")
+ .appName("test")
+ .config("k1", "v1")
+ .getOrCreate()
+
+ try {
+ val index = new AtomicInteger(0)
+ spark.sparkContext.addSparkListener(new SparkListener {
+ override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+ case start: SparkListenerSQLExecutionStart =>
+ if (index.get() == 0 && hasProject(start)) {
+ assert(!start.modifiedConfigs.contains("k1"))
+ index.incrementAndGet()
+ } else if (index.get() == 1 && hasProject(start)) {
+ assert(start.modifiedConfigs.contains("k2"))
+ assert(start.modifiedConfigs("k2") == "v2")
+ assert(start.modifiedConfigs.contains("redaction.password"))
+ assert(start.modifiedConfigs("redaction.password") == REDACTION_REPLACEMENT_TEXT)
+ index.incrementAndGet()
+ }
+ case _ =>
+ }
+
+ private def hasProject(start: SparkListenerSQLExecutionStart): Boolean =
+ start.physicalPlanDescription.toLowerCase(Locale.ROOT).contains("project")
+ })
+ spark.sql("SELECT 1").collect()
+ spark.sql("SET k2 = v2")
+ spark.sql("SET redaction.password = 123")
+ spark.sql("SELECT 1").collect()
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+ assert(index.get() == 2)
+ } finally {
+ spark.stop()
+ }
+ }
}
object SQLExecutionSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
index 08789e63fa..55f1713422 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler.SparkListenerEvent
import org.apache.spark.sql.LocalSparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
import org.apache.spark.sql.test.TestSparkSession
@@ -28,28 +29,46 @@ import org.apache.spark.util.JsonProtocol
class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession {
test("SparkPlanGraph backward compatibility: metadata") {
- val SQLExecutionStartJsonString =
- """
- |{
- | "Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart",
- | "executionId":0,
- | "description":"test desc",
- | "details":"test detail",
- | "physicalPlanDescription":"test plan",
- | "sparkPlanInfo": {
- | "nodeName":"TestNode",
- | "simpleString":"test string",
- | "children":[],
- | "metadata":{},
- | "metrics":[]
- | },
- | "time":0
- |}
+ Seq(true, false).foreach { newExecutionStartEvent =>
+ val event = if (newExecutionStartEvent) {
+ "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart"
+ } else {
+ "org.apache.spark.sql.execution.OldVersionSQLExecutionStart"
+ }
+ val SQLExecutionStartJsonString =
+ s"""
+ |{
+ | "Event":"$event",
+ | "executionId":0,
+ | "description":"test desc",
+ | "details":"test detail",
+ | "physicalPlanDescription":"test plan",
+ | "sparkPlanInfo": {
+ | "nodeName":"TestNode",
+ | "simpleString":"test string",
+ | "children":[],
+ | "metadata":{},
+ | "metrics":[]
+ | },
+ | "time":0,
+ | "modifiedConfigs": {
+ | "k1":"v1"
+ | }
+ |}
""".stripMargin
- val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString))
- val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan",
- new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0)
- assert(reconstructedEvent == expectedEvent)
+
+ val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString))
+ if (newExecutionStartEvent) {
+ val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail",
+ "test plan", new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0,
+ Map("k1" -> "v1"))
+ assert(reconstructedEvent == expectedEvent)
+ } else {
+ val expectedOldEvent = OldVersionSQLExecutionStart(0, "test desc", "test detail",
+ "test plan", new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0)
+ assert(reconstructedEvent == expectedOldEvent)
+ }
+ }
}
test("SparkListenerSQLExecutionEnd backward compatibility") {
@@ -77,3 +96,12 @@ class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession {
assert(readBack == event)
}
}
+
+private case class OldVersionSQLExecutionStart(
+ executionId: Long,
+ description: String,
+ details: String,
+ physicalPlanDescription: String,
+ sparkPlanInfo: SparkPlanInfo,
+ time: Long)
+ extends SparkListenerEvent
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
index cf699d3234..52aa1066f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Divide}
import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
@@ -220,12 +222,6 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
}
}
- private def assertNoSuchTable(query: String): Unit = {
- intercept[NoSuchTableException] {
- sql(query)
- }
- }
-
private def assertAnalysisError(query: String, message: String): Unit = {
val e = intercept[AnalysisException](sql(query))
assert(e.message.contains(message))
@@ -600,7 +596,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
spark.range(10).write.saveAsTable("add_col")
withView("v") {
sql("CREATE VIEW v AS SELECT * FROM add_col")
- spark.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col")
+ spark.range(10).select(Symbol("id"), 'id as Symbol("a"))
+ .write.mode("overwrite").saveAsTable("add_col")
checkAnswer(sql("SELECT * FROM v"), spark.range(10).toDF())
}
}
@@ -772,7 +769,9 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
withTable("t") {
Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t")
withTempView("v1") {
- sql("CREATE TEMPORARY VIEW v1 AS SELECT 1/0")
+ withSQLConf(ANSI_ENABLED.key -> "false") {
+ sql("CREATE TEMPORARY VIEW v1 AS SELECT 1/0")
+ }
withSQLConf(
USE_CURRENT_SQL_CONFIGS_FOR_VIEW.key -> "true",
ANSI_ENABLED.key -> "true") {
@@ -845,7 +844,9 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
sql("CREATE VIEW v2 (c1) AS SELECT c1 FROM t ORDER BY 1 ASC, c1 DESC")
sql("CREATE VIEW v3 (c1, count) AS SELECT c1, count(c1) AS cnt FROM t GROUP BY 1")
sql("CREATE VIEW v4 (a, count) AS SELECT c1 as a, count(c1) AS cnt FROM t GROUP BY a")
- sql("CREATE VIEW v5 (c1) AS SELECT 1/0 AS invalid")
+ withSQLConf(ANSI_ENABLED.key -> "false") {
+ sql("CREATE VIEW v5 (c1) AS SELECT 1/0 AS invalid")
+ }
withSQLConf(CASE_SENSITIVE.key -> "true") {
checkAnswer(sql("SELECT * FROM v1"), Seq(Row(2), Row(3), Row(1)))
@@ -869,9 +870,9 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
withSQLConf(CASE_SENSITIVE.key -> "true") {
val e = intercept[AnalysisException] {
sql("SELECT * FROM v1")
- }.getMessage
- assert(e.contains("cannot resolve 'C1' given input columns: " +
- "[spark_catalog.default.t.c1]"))
+ }
+ assert(e.getErrorClass == "MISSING_COLUMN")
+ assert(e.messageParameters.sameElements(Array("C1", "spark_catalog.default.t.c1")))
}
withSQLConf(ORDER_BY_ORDINAL.key -> "false") {
checkAnswer(sql("SELECT * FROM v2"), Seq(Row(3), Row(2), Row(1)))
@@ -888,15 +889,15 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
withSQLConf(GROUP_BY_ALIASES.key -> "false") {
val e = intercept[AnalysisException] {
sql("SELECT * FROM v4")
- }.getMessage
- assert(e.contains("cannot resolve 'a' given input columns: " +
- "[spark_catalog.default.t.c1]"))
+ }
+ assert(e.getErrorClass == "MISSING_COLUMN")
+ assert(e.messageParameters.sameElements(Array("a", "spark_catalog.default.t.c1")))
}
withSQLConf(ANSI_ENABLED.key -> "true") {
val e = intercept[ArithmeticException] {
sql("SELECT * FROM v5").collect()
}.getMessage
- assert(e.contains("divide by zero"))
+ assert(e.contains("Division by zero"))
}
}
@@ -906,7 +907,59 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
val e = intercept[ArithmeticException] {
sql("SELECT * FROM v1").collect()
}.getMessage
- assert(e.contains("divide by zero"))
+ assert(e.contains("Division by zero"))
+ }
+ }
+ }
+
+ test("CurrentOrigin is correctly set in and out of the View") {
+ withTable("t") {
+ Seq((1, 1), (2, 2)).toDF("a", "b").write.format("parquet").saveAsTable("t")
+ Seq("VIEW", "TEMPORARY VIEW").foreach { viewType =>
+ val viewId = "v"
+ withView(viewId) {
+ val viewText = "SELECT a + b c FROM t"
+ sql(
+ s"""
+ |CREATE $viewType $viewId AS
+ |-- the body of the view
+ |$viewText
+ |""".stripMargin)
+ val plan = sql("select c / 2.0D d from v").logicalPlan
+ val add = plan.collectFirst {
+ case Project(Seq(Alias(a: Add, _)), _) => a
+ }
+ assert(add.isDefined)
+ val qualifiedName = if (viewType == "VIEW") {
+ s"default.$viewId"
+ } else {
+ viewId
+ }
+ val expectedAddOrigin = Origin(
+ line = Some(1),
+ startPosition = Some(7),
+ startIndex = Some(7),
+ stopIndex = Some(11),
+ sqlText = Some("SELECT a + b c FROM t"),
+ objectType = Some("VIEW"),
+ objectName = Some(qualifiedName)
+ )
+ assert(add.get.origin == expectedAddOrigin)
+
+ val divide = plan.collectFirst {
+ case Project(Seq(Alias(d: Divide, _)), _) => d
+ }
+ assert(divide.isDefined)
+ val expectedDivideOrigin = Origin(
+ line = Some(1),
+ startPosition = Some(7),
+ startIndex = Some(7),
+ stopIndex = Some(14),
+ sqlText = Some("select c / 2.0D d from v"),
+ objectType = None,
+ objectName = None)
+ assert(divide.get.origin == expectedDivideOrigin)
+ }
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala
index 4ba2b1703c..1874aa15f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.CatalogFunction
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.Repartition
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.withDefaultTimeZone
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
@@ -45,10 +46,12 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils {
viewName: String,
sqlText: String,
columnNames: Seq[String] = Seq.empty,
+ others: Seq[String] = Seq.empty,
replace: Boolean = false): String = {
val replaceString = if (replace) "OR REPLACE" else ""
val columnString = if (columnNames.nonEmpty) columnNames.mkString("(", ",", ")") else ""
- sql(s"CREATE $replaceString $viewTypeString $viewName $columnString AS $sqlText")
+ val othersString = if (others.nonEmpty) others.mkString(" ") else ""
+ sql(s"CREATE $replaceString $viewTypeString $viewName $columnString $othersString AS $sqlText")
formattedViewName(viewName)
}
@@ -117,11 +120,13 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils {
test("change SQLConf should not change view behavior - ansiEnabled") {
withTable("t") {
Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t")
- val viewName = createView("v1", "SELECT 1/0 AS invalid", Seq("c1"))
- withView(viewName) {
- Seq("true", "false").foreach { flag =>
- withSQLConf(ANSI_ENABLED.key -> flag) {
- checkViewOutput(viewName, Seq(Row(null)))
+ withSQLConf(ANSI_ENABLED.key -> "false") {
+ val viewName = createView("v1", "SELECT 1/0 AS invalid", Seq("c1"))
+ withView(viewName) {
+ Seq("true", "false").foreach { flag =>
+ withSQLConf(ANSI_ENABLED.key -> flag) {
+ checkViewOutput(viewName, Seq(Row(null)))
+ }
}
}
}
@@ -378,6 +383,31 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils {
}
}
}
+
+ test("SPARK-37219: time travel is unsupported") {
+ val viewName = createView("testView", "SELECT 1 col")
+ withView(viewName) {
+ val e1 = intercept[AnalysisException](
+ sql(s"SELECT * FROM $viewName VERSION AS OF 1").collect()
+ )
+ assert(e1.message.contains("Cannot time travel views"))
+
+ val e2 = intercept[AnalysisException](
+ sql(s"SELECT * FROM $viewName TIMESTAMP AS OF '2000-10-10'").collect()
+ )
+ assert(e2.message.contains("Cannot time travel views"))
+ }
+ }
+
+ test("SPARK-37569: view should report correct nullability information for nested fields") {
+ val sql = "SELECT id, named_struct('a', id) AS nested FROM RANGE(10)"
+ val viewName = createView("testView", sql)
+ withView(viewName) {
+ val df = spark.sql(sql)
+ val dfFromView = spark.table(viewName)
+ assert(df.schema == dfFromView.schema)
+ }
+ }
}
abstract class TempViewTestSuite extends SQLViewTestSuite {
@@ -400,6 +430,18 @@ abstract class TempViewTestSuite extends SQLViewTestSuite {
}
}
+ test("show create table does not support temp view") {
+ val viewName = "spark_28383"
+ withView(viewName) {
+ createView(viewName, "SELECT 1 AS a")
+ val ex = intercept[AnalysisException] {
+ sql(s"SHOW CREATE TABLE ${formattedViewName(viewName)}")
+ }
+ assert(ex.getMessage.contains(
+ s"$viewName is a temp view. 'SHOW CREATE TABLE' expects a table or permanent view."))
+ }
+ }
+
test("back compatibility: skip cyclic reference check if view is stored as logical plan") {
val viewName = formattedViewName("v")
withSQLConf(STORE_ANALYZED_PLAN_FOR_VIEW.key -> "false") {
@@ -581,4 +623,84 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession {
spark.sessionState.conf.clear()
}
}
+
+ test("SPARK-37266: View text can only be SELECT queries") {
+ withView("v") {
+ sql("CREATE VIEW v AS SELECT 1")
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("v"))
+ val dropView = "DROP VIEW v"
+ // Simulate the behavior of hackers
+ val tamperedTable = table.copy(viewText = Some(dropView))
+ spark.sessionState.catalog.alterTable(tamperedTable)
+ val message = intercept[AnalysisException] {
+ sql("SELECT * FROM v")
+ }.getMessage
+ assert(message.contains(s"Invalid view text: $dropView." +
+ s" The view ${table.qualifiedName} may have been tampered with"))
+ }
+ }
+
+ test("show create table for persisted simple view") {
+ val viewName = "v1"
+ Seq(true, false).foreach { serde =>
+ withView(viewName) {
+ createView(viewName, "SELECT 1 AS a")
+ val expected = s"CREATE VIEW ${formattedViewName(viewName)} ( a) AS SELECT 1 AS a"
+ assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected)
+ }
+ }
+ }
+
+ test("show create table for persisted view with output columns") {
+ val viewName = "v1"
+ Seq(true, false).foreach { serde =>
+ withView(viewName) {
+ createView(viewName, "SELECT 1 AS a, 2 AS b", Seq("a", "b COMMENT 'b column'"))
+ val expected = s"CREATE VIEW ${formattedViewName(viewName)}" +
+ s" ( a, b COMMENT 'b column') AS SELECT 1 AS a, 2 AS b"
+ assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected)
+ }
+ }
+ }
+
+ test("show create table for persisted simple view with table comment and properties") {
+ val viewName = "v1"
+ Seq(true, false).foreach { serde =>
+ withView(viewName) {
+ createView(viewName, "SELECT 1 AS c1, '2' AS c2", Seq("c1 COMMENT 'bla'", "c2"),
+ Seq("COMMENT 'table comment'", "TBLPROPERTIES ( 'prop2' = 'value2', 'prop1' = 'value1')"))
+
+ val expected = s"CREATE VIEW ${formattedViewName(viewName)} ( c1 COMMENT 'bla', c2)" +
+ " COMMENT 'table comment'" +
+ " TBLPROPERTIES ( 'prop1' = 'value1', 'prop2' = 'value2')" +
+ " AS SELECT 1 AS c1, '2' AS c2"
+ assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected)
+ }
+ }
+ }
+
+ test("capture the session time zone config while creating a view") {
+ val viewName = "v1_capture_test"
+ withView(viewName) {
+ assert(get.sessionLocalTimeZone === "America/Los_Angeles")
+ createView(viewName,
+ """select hour(ts) as H from (
+ | select cast('2022-01-01T00:00:00.000 America/Los_Angeles' as timestamp) as ts
+ |)""".stripMargin, Seq("H"))
+ withDefaultTimeZone(java.time.ZoneId.of("UTC-09:00")) {
+ withSQLConf(SESSION_LOCAL_TIMEZONE.key -> "UTC-10:00") {
+ checkAnswer(sql(s"select H from $viewName"), Row(0))
+ }
+ }
+ }
+ }
+
+ def getShowCreateDDL(view: String, serde: Boolean = false): String = {
+ val result = if (serde) {
+ sql(s"SHOW CREATE TABLE $view AS SERDE")
+ } else {
+ sql(s"SHOW CREATE TABLE $view")
+ }
+ result.head().getString(0).split("\n").map(_.trim).mkString(" ")
+ }
}
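
The odd-looking "( a)" and "( c1 COMMENT 'bla', c2)" fragments in the expected DDL strings above come from getShowCreateDDL's normalization: SHOW CREATE TABLE returns multi-line DDL, and the helper trims every line and joins them with single spaces. A small sketch of that normalization; the raw multi-line layout below is an assumed example, not Spark's verbatim output:

val rawDdl = "CREATE VIEW default.v1 (\n  a)\nAS SELECT 1 AS a"      // assumed raw layout
val normalized = rawDdl.split("\n").map(_.trim).mkString(" ")
assert(normalized == "CREATE VIEW default.v1 ( a) AS SELECT 1 AS a") // same shape as the test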
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
index 9f70c8aeca..da05373125 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala
@@ -703,27 +703,55 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext {
test("splitSizeListByTargetSize") {
val targetSize = 100
+ val smallPartitionFactor1 = ShufflePartitionsUtil.SMALL_PARTITION_FACTOR
// merge the small partitions at the beginning/end
- val sizeList1 = Seq[Long](15, 90, 15, 15, 15, 90, 15)
- assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList1, targetSize).toSeq ==
+ val sizeList1 = Array[Long](15, 90, 15, 15, 15, 90, 15)
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize(
+ sizeList1, targetSize, smallPartitionFactor1).toSeq ==
Seq(0, 2, 5))
// merge the small partitions in the middle
- val sizeList2 = Seq[Long](30, 15, 90, 10, 90, 15, 30)
- assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList2, targetSize).toSeq ==
+ val sizeList2 = Array[Long](30, 15, 90, 10, 90, 15, 30)
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize(
+ sizeList2, targetSize, smallPartitionFactor1).toSeq ==
Seq(0, 2, 4, 5))
// merge small partitions if the partition itself is smaller than
// targetSize * SMALL_PARTITION_FACTOR
- val sizeList3 = Seq[Long](15, 1000, 15, 1000)
- assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList3, targetSize).toSeq ==
+ val sizeList3 = Array[Long](15, 1000, 15, 1000)
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize(
+ sizeList3, targetSize, smallPartitionFactor1).toSeq ==
Seq(0, 3))
// merge small partitions if the combined size is smaller than
// targetSize * MERGED_PARTITION_FACTOR
- val sizeList4 = Seq[Long](35, 75, 90, 20, 35, 25, 35)
- assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList4, targetSize).toSeq ==
+ val sizeList4 = Array[Long](35, 75, 90, 20, 35, 25, 35)
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize(
+ sizeList4, targetSize, smallPartitionFactor1).toSeq ==
Seq(0, 2, 3))
+
+ val smallPartitionFactor2 = 0.5
+ // merge the last two partitions if their size is not bigger than smallPartitionFactor * targetSize
+ val sizeList5 = Array[Long](50, 50, 40, 5)
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize(
+ sizeList5, targetSize, smallPartitionFactor2).toSeq ==
+ Seq(0))
+
+ val sizeList6 = Array[Long](40, 5, 50, 45)
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize(
+ sizeList6, targetSize, smallPartitionFactor2).toSeq ==
+ Seq(0))
+
+ // do not merge
+ val sizeList7 = Array[Long](50, 50, 10, 40, 5)
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize(
+ sizeList7, targetSize, smallPartitionFactor2).toSeq ==
+ Seq(0, 2))
+
+ val sizeList8 = Array[Long](10, 40, 5, 50, 50)
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize(
+ sizeList8, targetSize, smallPartitionFactor2).toSeq ==
+ Seq(0, 3))
}
test("SPARK-35923: Coalesce empty partition with mixed CoalescedPartitionSpec and" +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index 6a4f3f6264..7c74423af6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -22,6 +22,7 @@ import scala.util.Random
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -43,13 +44,15 @@ class SortSuite extends SparkPlanTest with SharedSparkSession {
checkAnswer(
input.toDF("a", "b", "c"),
- (child: SparkPlan) => SortExec('a.asc :: 'b.asc :: Nil, global = true, child = child),
+ (child: SparkPlan) => SortExec(Symbol("a").asc :: Symbol("b").asc :: Nil,
+ global = true, child = child),
input.sortBy(t => (t._1, t._2)).map(Row.fromTuple),
sortAnswers = false)
checkAnswer(
input.toDF("a", "b", "c"),
- (child: SparkPlan) => SortExec('b.asc :: 'a.asc :: Nil, global = true, child = child),
+ (child: SparkPlan) => SortExec(Symbol("b").asc :: Symbol("a").asc :: Nil,
+ global = true, child = child),
input.sortBy(t => (t._2, t._1)).map(Row.fromTuple),
sortAnswers = false)
}
@@ -58,9 +61,9 @@ class SortSuite extends SparkPlanTest with SharedSparkSession {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"),
(child: SparkPlan) =>
- GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)),
+ GlobalLimitExec(10, SortExec(Symbol("a").asc :: Nil, global = true, child = child)),
(child: SparkPlan) =>
- GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)),
+ GlobalLimitExec(10, ReferenceSort(Symbol("a").asc :: Nil, global = true, child)),
sortAnswers = false
)
}
@@ -69,15 +72,15 @@ class SortSuite extends SparkPlanTest with SharedSparkSession {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
(child: SparkPlan) =>
- GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)),
+ GlobalLimitExec(10, SortExec(Symbol("a").asc :: Nil, global = true, child = child)),
(child: SparkPlan) =>
- GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)),
+ GlobalLimitExec(10, ReferenceSort(Symbol("a").asc :: Nil, global = true, child)),
sortAnswers = false
)
}
test("sorting does not crash for large inputs") {
- val sortOrder = 'a.asc :: Nil
+ val sortOrder = Symbol("a").asc :: Nil
val stringLength = 1024 * 1024 * 2
checkThatPlansAgree(
Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
@@ -91,8 +94,8 @@ class SortSuite extends SparkPlanTest with SharedSparkSession {
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
- (child: SparkPlan) => SortExec('a.asc :: Nil, global = true, child = child),
- (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child),
+ (child: SparkPlan) => SortExec(Symbol("a").asc :: Nil, global = true, child = child),
+ (child: SparkPlan) => ReferenceSort(Symbol("a").asc :: Nil, global = true, child),
sortAnswers = false)
}
}
@@ -105,11 +108,31 @@ class SortSuite extends SparkPlanTest with SharedSparkSession {
)
checkAnswer(
input.toDF("a", "b", "c"),
- (child: SparkPlan) => SortExec(Stream('a.asc, 'b.asc, 'c.asc), global = true, child = child),
+ (child: SparkPlan) => SortExec(Stream(Symbol("a").asc, 'b.asc, 'c.asc),
+ global = true, child = child),
input.sortBy(t => (t._1, t._2, t._3)).map(Row.fromTuple),
sortAnswers = false)
}
+ test("SPARK-40089: decimal values sort correctly") {
+ val input = Seq(
+ BigDecimal("999999999999999999.50"),
+ BigDecimal("1.11"),
+ BigDecimal("999999999999999999.49")
+ )
+ // The range partitioner does the right thing. If there are too many
+ // shuffle partitions the error might not always show up.
+ withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+ val inputDf = spark.createDataFrame(sparkContext.parallelize(input.map(v => Row(v)), 1),
+ StructType(StructField("a", DecimalType(20, 2)) :: Nil))
+ checkAnswer(
+ inputDf,
+ (child: SparkPlan) => SortExec('a.asc :: Nil, global = true, child = child),
+ input.sorted.map(Row(_)),
+ sortAnswers = false)
+ }
+ }
+
// Test sorting on different data types
for (
dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
@@ -124,12 +147,16 @@ class SortSuite extends SparkPlanTest with SharedSparkSession {
sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
StructType(StructField("a", dataType, nullable = true) :: Nil)
)
- checkThatPlansAgree(
- inputDf,
- p => SortExec(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23),
- ReferenceSort(sortOrder, global = true, _: SparkPlan),
- sortAnswers = false
- )
+ Seq(true, false).foreach { enableRadix =>
+ withSQLConf(SQLConf.RADIX_SORT_ENABLED.key -> enableRadix.toString) {
+ checkThatPlansAgree(
+ inputDf,
+ p => SortExec(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23),
+ ReferenceSort(sortOrder, global = true, _: SparkPlan),
+ sortAnswers = false
+ )
+ }
+ }
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
index 04b589de7c..12d311d683 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
@@ -18,11 +18,16 @@
package org.apache.spark.sql.execution
import org.apache.spark.SparkEnv
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.vectorized.ColumnarBatch
class SparkPlanSuite extends QueryTest with SharedSparkSession {
@@ -110,6 +115,16 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
"should have been replaced by aggregate in the optimizer"))
}
+ test("SPARK-37221: The collect-like API in SparkPlan should support columnar output") {
+ val emptyResults = ColumnarOp(LocalTableScanExec(Nil, Nil)).toRowBased.executeCollect()
+ assert(emptyResults.isEmpty)
+
+ val relation = LocalTableScanExec(
+ Seq(AttributeReference("val", IntegerType)()), Seq(InternalRow(1)))
+ val nonEmpty = ColumnarOp(relation).toRowBased.executeCollect()
+ assert(nonEmpty === relation.executeCollect())
+ }
+
test("SPARK-37779: ColumnarToRowExec should be canonicalizable after being (de)serialized") {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") {
withTempPath { path =>
@@ -129,3 +144,13 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
}
}
}
+
+case class ColumnarOp(child: SparkPlan) extends UnaryExecNode {
+ override val supportsColumnar: Boolean = true
+ override protected def doExecuteColumnar(): RDD[ColumnarBatch] =
+ RowToColumnarExec(child).executeColumnar()
+ override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarOp =
+ copy(child = newChild)
+}
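
ColumnarOp above implements only doExecuteColumnar, so the SPARK-37221 test exercises the path where collect-like APIs must convert columnar output back to rows via toRowBased. A rough sketch of the equivalent manual round trip, reusing the relation from the test (illustrative only):

val columnarPlan = ColumnarOp(relation)                      // relation: LocalTableScanExec
val rows = ColumnarToRowExec(columnarPlan).executeCollect()  // explicit columnar -> row step
assert(rows.toSeq == relation.executeCollect().toSeq)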
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index ba6dd170d8..49f65ab51c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -47,7 +47,7 @@ class SparkSqlParserSuite extends AnalysisTest {
}
private def intercept(sqlCommand: String, messages: String*): Unit =
- interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)
+ interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)()
test("Checks if SET/RESET can parse all the configurations") {
// Force to build static SQL configurations
@@ -168,11 +168,11 @@ class SparkSqlParserSuite extends AnalysisTest {
intercept("SET a=1;2;;", expectedErrMsg)
intercept("SET a b=`1;;`",
- "'a b' is an invalid property key, please use quotes, e.g. SET `a b`=`1;;`")
+ "\"a b\" is an invalid property key, please use quotes, e.g. SET \"a b\"=\"1;;\"")
intercept("SET `a`=1;2;;",
- "'1;2;;' is an invalid property value, please use quotes, e.g." +
- " SET `a`=`1;2;;`")
+ "\"1;2;;\" is an invalid property value, please use quotes, e.g." +
+ " SET \"a\"=\"1;2;;\"")
}
test("refresh resource") {
@@ -312,7 +312,7 @@ class SparkSqlParserSuite extends AnalysisTest {
Seq(AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", StringType)()),
- Project(Seq('a, 'b, 'c),
+ Project(Seq(Symbol("a"), Symbol("b"), Symbol("c")),
UnresolvedRelation(TableIdentifier("testData"))),
ioSchema))
@@ -336,9 +336,9 @@ class SparkSqlParserSuite extends AnalysisTest {
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false),
Literal(10)),
Aggregate(
- Seq('a),
+ Seq(Symbol("a")),
Seq(
- 'a,
+ Symbol("a"),
UnresolvedAlias(
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), None),
UnresolvedAlias(
@@ -363,12 +363,12 @@ class SparkSqlParserSuite extends AnalysisTest {
AttributeReference("c", StringType)()),
WithWindowDefinition(
Map("w" -> WindowSpecDefinition(
- Seq('a),
- Seq(SortOrder('b, Ascending, NullsFirst, Seq.empty)),
+ Seq(Symbol("a")),
+ Seq(SortOrder(Symbol("b"), Ascending, NullsFirst, Seq.empty)),
UnspecifiedFrame)),
Project(
Seq(
- 'a,
+ Symbol("a"),
UnresolvedAlias(
UnresolvedWindowExpression(
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false),
@@ -403,9 +403,9 @@ class SparkSqlParserSuite extends AnalysisTest {
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false),
Literal(10)),
Aggregate(
- Seq('a, 'myCol, 'myCol2),
+ Seq(Symbol("a"), Symbol("myCol"), Symbol("myCol2")),
Seq(
- 'a,
+ Symbol("a"),
UnresolvedAlias(
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), None),
UnresolvedAlias(
@@ -415,7 +415,7 @@ class SparkSqlParserSuite extends AnalysisTest {
UnresolvedGenerator(
FunctionIdentifier("explode"),
Seq(UnresolvedAttribute("myTable.myCol"))),
- Nil, false, Option("mytable2"), Seq('myCol2),
+ Nil, false, Option("mytable2"), Seq(Symbol("myCol2")),
Generate(
UnresolvedGenerator(
FunctionIdentifier("explode"),
@@ -423,7 +423,7 @@ class SparkSqlParserSuite extends AnalysisTest {
Seq(
UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), false)),
false))),
- Nil, false, Option("mytable"), Seq('myCol),
+ Nil, false, Option("mytable"), Seq(Symbol("myCol")),
UnresolvedRelation(TableIdentifier("testData")))))),
ioSchema))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala
index c025670fb8..3718b3a3c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala
@@ -49,7 +49,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark {
val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols)
val cols = (0 until numCols).map { idx =>
- from_json('value, schema).getField(s"col$idx")
+ from_json(Symbol("value"), schema).getField(s"col$idx")
}
Seq(
@@ -88,7 +88,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark {
val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols)
val predicate = (0 until numCols).map { idx =>
- (from_json('value, schema).getField(s"col$idx") >= Literal(100000)).expr
+ (from_json(Symbol("value"), schema).getField(s"col$idx") >= Literal(100000)).expr
}.asInstanceOf[Seq[Expression]].reduce(Or)
Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
index 6ec5c6287e..ce48945e52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -58,7 +58,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSparkSession {
private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan)
val limit = 250
- val sortOrder = 'a.desc :: 'b.desc :: Nil
+ val sortOrder = Symbol("a").desc :: Symbol("b").desc :: Nil
test("TakeOrderedAndProject.doExecute without project") {
withClue(s"seed = $seed") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala
index 2f626f7769..73c4e4c3e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala
@@ -40,15 +40,26 @@ class WholeStageCodegenSparkSubmitSuite extends SparkSubmitTestUtils
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
// HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched
- // settings of UseCompressedOops JVM option.
+ // settings of UseCompressedClassPointers JVM option.
val argsForSparkSubmit = Seq(
"--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"),
"--master", "local-cluster[1,1,1024]",
"--driver-memory", "1g",
"--conf", "spark.ui.enabled=false",
"--conf", "spark.master.rest.enabled=false",
- "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops",
- "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops",
+ // SPARK-37008: The value of `Platform.BYTE_ARRAY_OFFSET` under different Java versions
+ // and JVM args is shown in the following table:
+ // +---------------------------------+--------+---------+
+ // |                                 | Java 8 | Java 17 |
+ // +---------------------------------+--------+---------+
+ // | -XX:-UseCompressedOops          | 24     | 16      |
+ // | -XX:+UseCompressedOops          | 16     | 16      |
+ // | -XX:-UseCompressedClassPointers | 24     | 24      |
+ // | -XX:+UseCompressedClassPointers | 16     | 16      |
+ // +---------------------------------+--------+---------+
+ // So SPARK-37008 replaces `UseCompressedOops` with `UseCompressedClassPointers`.
+ "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedClassPointers",
+ "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedClassPointers",
"--conf", "spark.sql.adaptive.enabled=false",
unusedJar.toString)
runSparkSubmit(argsForSparkSubmit, timeout = 3.minutes)
@@ -60,7 +71,7 @@ object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging {
var spark: SparkSession = _
def main(args: Array[String]): Unit = {
- TestUtils.configTestLog4j("INFO")
+ TestUtils.configTestLog4j2("INFO")
spark = SparkSession.builder().getOrCreate()
@@ -73,7 +84,7 @@ object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging {
val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v")
.groupBy(array(col("v"))).agg(count(col("*")))
val plan = df.queryExecution.executedPlan
- assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
+ assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]))
val expectedAnswer =
Row(Array(0), 7178) ::
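
The switch to UseCompressedClassPointers matters here because Platform.BYTE_ARRAY_OFFSET is simply the JVM's base offset for byte[] elements, which these flags move between 16 and 24. A small standalone check, assuming the spark-unsafe module is on the classpath, that can be run twice with the opposite flag settings to reproduce the values in the table above:

import org.apache.spark.unsafe.Platform

object ByteArrayOffsetCheck {
  def main(args: Array[String]): Unit = {
    // Run with -XX:+UseCompressedClassPointers and again with -XX:-UseCompressedClassPointers
    // to observe 16 vs. 24 on the JVMs listed in the comment's table.
    println(s"Platform.BYTE_ARRAY_OFFSET = ${Platform.BYTE_ARRAY_OFFSET}")
  }
}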
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c2b22125cb..2be915f000 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
@@ -37,19 +37,30 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
test("range/filter should be combined") {
val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
val plan = df.queryExecution.executedPlan
- assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
+ assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]))
assert(df.collect() === Array(Row(2)))
}
- test("Aggregate should be included in WholeStageCodegen") {
+ test("HashAggregate should be included in WholeStageCodegen") {
val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
val plan = df.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]))
assert(df.collect() === Array(Row(9, 4.5)))
}
+ test("SortAggregate should be included in WholeStageCodegen") {
+ val df = spark.range(10).agg(max(col("id")), avg(col("id")))
+ withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
+ val plan = df.queryExecution.executedPlan
+ assert(plan.exists(p =>
+ p.isInstanceOf[WholeStageCodegenExec] &&
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
+ assert(df.collect() === Array(Row(9, 4.5)))
+ }
+ }
+
testWithWholeStageCodegenOnAndOff("GenerateExec should be" +
" included in WholeStageCodegen") { codegenEnabled =>
import testImplicits._
@@ -59,22 +70,22 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// Array - explode
var expDF = df.select($"name", explode($"knownLanguages"), $"properties")
var plan = expDF.queryExecution.executedPlan
- assert(plan.find {
+ assert(plan.exists {
case stage: WholeStageCodegenExec =>
- stage.find(_.isInstanceOf[GenerateExec]).isDefined
+ stage.exists(_.isInstanceOf[GenerateExec])
case _ => !codegenEnabled.toBoolean
- }.isDefined)
+ })
checkAnswer(expDF, Array(Row("James", "Java", Map("hair" -> "black", "eye" -> "brown")),
Row("James", "Scala", Map("hair" -> "black", "eye" -> "brown"))))
// Map - explode
expDF = df.select($"name", $"knownLanguages", explode($"properties"))
plan = expDF.queryExecution.executedPlan
- assert(plan.find {
+ assert(plan.exists {
case stage: WholeStageCodegenExec =>
- stage.find(_.isInstanceOf[GenerateExec]).isDefined
+ stage.exists(_.isInstanceOf[GenerateExec])
case _ => !codegenEnabled.toBoolean
- }.isDefined)
+ })
checkAnswer(expDF,
Array(Row("James", List("Java", "Scala"), "hair", "black"),
Row("James", List("Java", "Scala"), "eye", "brown")))
@@ -82,33 +93,33 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// Array - posexplode
expDF = df.select($"name", posexplode($"knownLanguages"))
plan = expDF.queryExecution.executedPlan
- assert(plan.find {
+ assert(plan.exists {
case stage: WholeStageCodegenExec =>
- stage.find(_.isInstanceOf[GenerateExec]).isDefined
+ stage.exists(_.isInstanceOf[GenerateExec])
case _ => !codegenEnabled.toBoolean
- }.isDefined)
+ })
checkAnswer(expDF,
Array(Row("James", 0, "Java"), Row("James", 1, "Scala")))
// Map - posexplode
expDF = df.select($"name", posexplode($"properties"))
plan = expDF.queryExecution.executedPlan
- assert(plan.find {
+ assert(plan.exists {
case stage: WholeStageCodegenExec =>
- stage.find(_.isInstanceOf[GenerateExec]).isDefined
+ stage.exists(_.isInstanceOf[GenerateExec])
case _ => !codegenEnabled.toBoolean
- }.isDefined)
+ })
checkAnswer(expDF,
Array(Row("James", 0, "hair", "black"), Row("James", 1, "eye", "brown")))
// Array - explode , selecting all columns
expDF = df.select($"*", explode($"knownLanguages"))
plan = expDF.queryExecution.executedPlan
- assert(plan.find {
+ assert(plan.exists {
case stage: WholeStageCodegenExec =>
- stage.find(_.isInstanceOf[GenerateExec]).isDefined
+ stage.exists(_.isInstanceOf[GenerateExec])
case _ => !codegenEnabled.toBoolean
- }.isDefined)
+ })
checkAnswer(expDF,
Array(Row("James", Seq("Java", "Scala"), Map("hair" -> "black", "eye" -> "brown"), "Java"),
Row("James", Seq("Java", "Scala"), Map("hair" -> "black", "eye" -> "brown"), "Scala")))
@@ -116,11 +127,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// Map - explode, selecting all columns
expDF = df.select($"*", explode($"properties"))
plan = expDF.queryExecution.executedPlan
- assert(plan.find {
+ assert(plan.exists {
case stage: WholeStageCodegenExec =>
- stage.find(_.isInstanceOf[GenerateExec]).isDefined
+ stage.exists(_.isInstanceOf[GenerateExec])
case _ => !codegenEnabled.toBoolean
- }.isDefined)
+ })
checkAnswer(expDF,
Array(
Row("James", List("Java", "Scala"),
@@ -129,12 +140,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
Map("hair" -> "black", "eye" -> "brown"), "eye", "brown")))
}
- test("Aggregate with grouping keys should be included in WholeStageCodegen") {
+ test("HashAggregate with grouping keys should be included in WholeStageCodegen") {
val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2)
val plan = df.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]))
assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1)))
}
@@ -143,13 +154,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val schema = new StructType().add("k", IntegerType).add("v", StringType)
val smallDF = spark.createDataFrame(rdd, schema)
val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
- assert(df.queryExecution.executedPlan.find(p =>
+ assert(df.queryExecution.executedPlan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]))
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
}
- test("ShuffledHashJoin should be included in WholeStageCodegen") {
+ test("Inner ShuffledHashJoin should be included in WholeStageCodegen") {
val df1 = spark.range(5).select($"id".as("k1"))
val df2 = spark.range(15).select($"id".as("k2"))
val df3 = spark.range(6).select($"id".as("k3"))
@@ -171,6 +182,56 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
}
+ test("Full Outer ShuffledHashJoin and SortMergeJoin should be included in WholeStageCodegen") {
+ val df1 = spark.range(5).select($"id".as("k1"))
+ val df2 = spark.range(10).select($"id".as("k2"))
+ val df3 = spark.range(3).select($"id".as("k3"))
+
+ Seq("SHUFFLE_HASH", "SHUFFLE_MERGE").foreach { hint =>
+ // test one join with unique key from build side
+ val joinUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer")
+ assert(joinUniqueDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
+ Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
+ assert(joinUniqueDF.count() === 10)
+
+ // test one join with non-unique key from build side
+ val joinNonUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2" % 3, "full_outer")
+ assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
+ Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
+
+ // test one join with non-equi condition
+ val joinWithNonEquiDF = df1.join(df2.hint(hint),
+ $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
+ assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
+ Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
+ Row(null, 5)))
+
+ // test two joins
+ val twoJoinsDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer")
+ .join(df3.hint(hint), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer")
+ assert(twoJoinsDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 2)
+ checkAnswer(twoJoinsDF,
+ Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
+ Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null),
+ Row(null, 9, null), Row(null, null, 1)))
+ }
+ }
+
test("Left/Right Outer SortMergeJoin should be included in WholeStageCodegen") {
val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
@@ -373,9 +434,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
test("Sort should be included in WholeStageCodegen") {
val df = spark.range(3, 0, -1).toDF().sort(col("id"))
val plan = df.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]))
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
@@ -384,27 +445,27 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val ds = spark.range(10).map(_.toString)
val plan = ds.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]))
assert(ds.collect() === 0.until(10).map(_.toString).toArray)
}
test("typed filter should be included in WholeStageCodegen") {
val ds = spark.range(10).filter(_ % 2 == 0)
val plan = ds.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]))
assert(ds.collect() === Array(0, 2, 4, 6, 8))
}
test("back-to-back typed filter should be included in WholeStageCodegen") {
val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
val plan = ds.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]))
assert(ds.collect() === Array(0, 6))
}
@@ -417,7 +478,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val planInt = dsIntFilter.queryExecution.executedPlan
assert(planInt.collect {
case WholeStageCodegenExec(FilterExec(_,
- ColumnarToRowExec(InputAdapter(_: InMemoryTableScanExec)))) => ()
+ InputAdapter(_: InMemoryTableScanExec))) => ()
}.length == 1)
assert(dsIntFilter.collect() === Array(1, 2))
@@ -456,10 +517,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.select("int")
val plan = df.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(!plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.children(0)
- .isInstanceOf[SortMergeJoinExec]).isEmpty)
+ .isInstanceOf[SortMergeJoinExec]))
assert(df.collect() === Array(Row(1), Row(2)))
}
}
@@ -512,7 +573,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
import testImplicits._
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*)
+ val df = spark.range(10).select(Seq.tabulate(201) {i => (Symbol("id") + i).as(s"c$i")} : _*)
df.write.mode(SaveMode.Overwrite).parquet(path)
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202",
@@ -529,7 +590,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
test("Control splitting consume function by operators with config") {
import testImplicits._
- val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*)
+ val df = spark.range(10).select(Seq.tabulate(2) {i => (Symbol("id") + i).as(s"c$i")} : _*)
Seq(true, false).foreach { config =>
withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") {
@@ -578,9 +639,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val df = spark.range(100)
val join = df.join(df, "id")
val plan = join.queryExecution.executedPlan
- assert(plan.find(p =>
+ assert(!plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isEmpty,
+ p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0),
"codegen stage IDs should be preserved through ReuseExchange")
checkAnswer(join, df.toDF)
}
@@ -592,9 +653,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") {
// the same query run twice should produce identical code, which would imply a hit in
// the generated code cache.
- val ds1 = spark.range(3).select('id + 2)
+ val ds1 = spark.range(3).select(Symbol("id") + 2)
val code1 = genCode(ds1)
- val ds2 = spark.range(3).select('id + 2)
+ val ds2 = spark.range(3).select(Symbol("id") + 2)
val code2 = genCode(ds2) // same query shape as above, deliberately
assert(code1 == code2, "Should produce same code")
}
@@ -639,10 +700,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions
val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
- assert(distinctWithId.queryExecution.executedPlan.collectFirst {
+ assert(distinctWithId.queryExecution.executedPlan.exists {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true
- }.isDefined)
+ case _ => false
+ })
checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0)))
// BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate
@@ -650,10 +712,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val groupByWithId =
baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
- assert(groupByWithId.queryExecution.executedPlan.collectFirst {
+ assert(groupByWithId.queryExecution.executedPlan.exists {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true
- }.isDefined)
+ case _ => false
+ })
checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0)))
}
}
@@ -679,11 +742,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// HashAggregateExec supports WholeStageCodegen and it's the parent of
// LocalTableScanExec so LocalTableScanExec should be within a WholeStageCodegen domain.
assert(
- executedPlan.find {
+ executedPlan.exists {
case WholeStageCodegenExec(
HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec)) => true
case _ => false
- }.isDefined,
+ },
"LocalTableScanExec should be within a WholeStageCodegen domain.")
}
@@ -699,10 +762,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
"SELECT AVG(v) FROM VALUES(1) t(v)",
// Test case with keys
"SELECT k, AVG(v) FROM VALUES((1, 1)) t(k, v) GROUP BY k").foreach { query =>
- val errMsg = intercept[IllegalStateException] {
+ val e = intercept[IllegalStateException] {
sql(query).collect
- }.getMessage
- assert(errMsg.contains(expectedErrMsg))
+ }
+ assert(e.getMessage.contains(expectedErrMsg))
}
}
}
@@ -721,10 +784,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// Test case with keys
"SELECT k, AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1, 1)) t(k, a, b, c) " +
"GROUP BY k").foreach { query =>
- val e = intercept[Exception] {
+ val e = intercept[IllegalStateException] {
sql(query).collect
}
- assert(e.isInstanceOf[IllegalStateException])
assert(e.getMessage.contains(expectedErrMsg))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 7ae162ca8a..0055b94fa0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -20,19 +20,22 @@ package org.apache.spark.sql.execution.adaptive
import java.io.File
import java.net.URI
-import org.apache.log4j.Level
+import org.apache.logging.log4j.Level
import org.scalatest.PrivateMethodTester
+import org.scalatest.time.SpanSugar._
+import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec}
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
-import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
@@ -102,6 +105,12 @@ class AdaptiveQueryExecSuite
}
}
+ def findTopLevelBroadcastNestedLoopJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
+ collect(plan) {
+ case j: BroadcastNestedLoopJoinExec => j
+ }
+ }
+
private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
collect(plan) {
case j: SortMergeJoinExec => j
@@ -126,6 +135,18 @@ class AdaptiveQueryExecSuite
}
}
+ private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = {
+ collect(plan) {
+ case agg: BaseAggregateExec => agg
+ }
+ }
+
+ private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = {
+ collect(plan) {
+ case l: CollectLimitExec => l
+ }
+ }
+
private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
collectWithSubqueries(plan) {
case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e
@@ -267,11 +288,12 @@ class AdaptiveQueryExecSuite
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
- val df1 = spark.range(10).withColumn("a", 'id)
- val df2 = spark.range(10).withColumn("b", 'id)
+ val df1 = spark.range(10).withColumn("a", Symbol("id"))
+ val df2 = spark.range(10).withColumn("b", Symbol("id"))
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer")
- .groupBy('a).count()
+ val testDf = df1.where(Symbol("a") > 10)
+ .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer")
+ .groupBy(Symbol("a")).count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
@@ -283,8 +305,9 @@ class AdaptiveQueryExecSuite
}
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
- val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer")
- .groupBy('a).count()
+ val testDf = df1.where(Symbol("a") > 10)
+ .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer")
+ .groupBy(Symbol("a")).count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
@@ -668,6 +691,47 @@ class AdaptiveQueryExecSuite
}
}
}
+ test("SPARK-37753: Allow changing outer join to broadcast join even if too many empty" +
+ " partitions on broadcast side") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") {
+ // `testData` is small enough to be broadcast, but its empty partition ratio is above the config.
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT * FROM (select * from testData where value = '1') td" +
+ " right outer join testData2 ON key = a")
+ val smj = findTopLevelSortMergeJoin(plan)
+ assert(smj.size == 1)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 1)
+ }
+ }
+ }
+
+ test("SPARK-37753: Inhibit broadcast in left outer join when there are many empty" +
+ " partitions on outer/left side") {
+ // If the right side completes first while the left side is still being executed,
+ // the right side does not know whether there are many empty partitions on the left side,
+ // so there is no demotion and the right side is broadcast in the planning stage.
+ // Retry several times here to avoid flaky unit test failures.
+ eventually(timeout(15.seconds), interval(500.milliseconds)) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") {
+ // `testData` is small enough to be broadcast, but its empty partition ratio is above the config.
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") {
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT * FROM (select * from testData where value = '1') td" +
+ " left outer join testData2 ON key = a")
+ val smj = findTopLevelSortMergeJoin(plan)
+ assert(smj.size == 1)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.isEmpty)
+ }
+ }
+ }
+ }
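
Both SPARK-37753 tests above revolve around the ratio of non-empty shuffle partitions that AQE compares against NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN (0.5 here) when deciding whether to demote a broadcast-hash-join candidate. A hypothetical helper, not Spark's API, sketching the ratio being compared:

// Fraction of map-output partitions that actually contain data for one join side.
// Purely illustrative; Spark derives this from map output statistics internally.
def nonEmptyPartitionRatio(bytesByPartition: Array[Long]): Double =
  if (bytesByPartition.isEmpty) 0.0
  else bytesByPartition.count(_ > 0).toDouble / bytesByPartition.length

// e.g. nonEmptyPartitionRatio(Array(0L, 0L, 0L, 10L, 20L)) == 0.4, below the 0.5 threshold
// configured in these tests.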
test("SPARK-29906: AQE should not introduce extra shuffle for outermost limit") {
var numStages = 0
@@ -738,17 +802,17 @@ class AdaptiveQueryExecSuite
spark
.range(0, 1000, 1, 10)
.select(
- when('id < 250, 249)
- .when('id >= 750, 1000)
- .otherwise('id).as("key1"),
- 'id as "value1")
+ when(Symbol("id") < 250, 249)
+ .when(Symbol("id") >= 750, 1000)
+ .otherwise(Symbol("id")).as("key1"),
+ Symbol("id") as "value1")
.createOrReplaceTempView("skewData1")
spark
.range(0, 1000, 1, 10)
.select(
- when('id < 250, 249)
- .otherwise('id).as("key2"),
- 'id as "value2")
+ when(Symbol("id") < 250, 249)
+ .otherwise(Symbol("id")).as("key2"),
+ Symbol("id") as "value2")
.createOrReplaceTempView("skewData2")
def checkSkewJoin(
@@ -806,11 +870,11 @@ class AdaptiveQueryExecSuite
df1.write.parquet(tableDir.getAbsolutePath)
val aggregated = spark.table("bucketed_table").groupBy("i").count()
- val error = intercept[Exception] {
+ val error = intercept[SparkException] {
aggregated.count()
}
- assert(error.toString contains "Invalid bucket file")
- assert(error.getSuppressed.size === 0)
+ assert(error.getErrorClass === "INVALID_BUCKET_FILE")
+ assert(error.getMessage contains "Invalid bucket file")
}
}
}
@@ -842,7 +906,7 @@ class AdaptiveQueryExecSuite
}
}
assert(!testAppender.loggingEvents
- .exists(msg => msg.getRenderedMessage.contains(
+ .exists(msg => msg.getMessage.getFormattedMessage.contains(
s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" +
s" enabled but is not supported for")))
}
@@ -850,6 +914,7 @@ class AdaptiveQueryExecSuite
test("test log level") {
def verifyLog(expectedLevel: Level): Unit = {
val logAppender = new LogAppender("adaptive execution")
+ logAppender.setThreshold(expectedLevel)
withLogAppender(
logAppender,
loggerNames = Seq(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)),
@@ -863,7 +928,7 @@ class AdaptiveQueryExecSuite
Seq("Plan changed", "Final plan").foreach { msg =>
assert(
logAppender.loggingEvents.exists { event =>
- event.getRenderedMessage.contains(msg) && event.getLevel == expectedLevel
+ event.getMessage.getFormattedMessage.contains(msg) && event.getLevel == expectedLevel
})
}
}
@@ -982,17 +1047,17 @@ class AdaptiveQueryExecSuite
spark
.range(0, 1000, 1, 10)
.select(
- when('id < 250, 249)
- .when('id >= 750, 1000)
- .otherwise('id).as("key1"),
- 'id as "value1")
+ when(Symbol("id") < 250, 249)
+ .when(Symbol("id") >= 750, 1000)
+ .otherwise(Symbol("id")).as("key1"),
+ Symbol("id") as "value1")
.createOrReplaceTempView("skewData1")
spark
.range(0, 1000, 1, 10)
.select(
- when('id < 250, 249)
- .otherwise('id).as("key2"),
- 'id as "value2")
+ when(Symbol("id") < 250, 249)
+ .otherwise(Symbol("id")).as("key2"),
+ Symbol("id") as "value2")
.createOrReplaceTempView("skewData2")
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM skewData1 join skewData2 ON key1 = key2")
@@ -1070,7 +1135,7 @@ class AdaptiveQueryExecSuite
test("AQE should set active session during execution") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
- val df = spark.range(10).select(sum('id))
+ val df = spark.range(10).select(sum(Symbol("id")))
assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
SparkSession.setActiveSession(null)
checkAnswer(df, Seq(Row(45)))
@@ -1097,7 +1162,7 @@ class AdaptiveQueryExecSuite
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
try {
spark.experimental.extraStrategies = TestStrategy :: Nil
- val df = spark.range(10).groupBy('id).count()
+ val df = spark.range(10).groupBy(Symbol("id")).count()
df.collect()
} finally {
spark.experimental.extraStrategies = Nil
@@ -1422,6 +1487,56 @@ class AdaptiveQueryExecSuite
}
}
+ test("SPARK-35442: Support propagate empty relation through aggregate") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult(
+ "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key")
+ assert(!plan1.isInstanceOf[LocalTableScanExec])
+ assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec])
+
+ val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult(
+ "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key limit 1")
+ assert(!plan2.isInstanceOf[LocalTableScanExec])
+ assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec])
+
+ val (plan3, adaptivePlan3) = runAdaptiveAndVerifyResult(
+ "SELECT count(*) FROM testData WHERE value = 'no_match'")
+ assert(!plan3.isInstanceOf[LocalTableScanExec])
+ assert(!stripAQEPlan(adaptivePlan3).isInstanceOf[LocalTableScanExec])
+ }
+ }
+
+ test("SPARK-35442: Support propagate empty relation through union") {
+ def checkNumUnion(plan: SparkPlan, numUnion: Int): Unit = {
+ assert(
+ collect(plan) {
+ case u: UnionExec => u
+ }.size == numUnion)
+ }
+
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key
+ |UNION ALL
+ |SELECT key, 1 FROM testData
+ |""".stripMargin)
+ checkNumUnion(plan1, 1)
+ checkNumUnion(adaptivePlan1, 0)
+ assert(!stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec])
+
+ val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key
+ |UNION ALL
+ |SELECT /*+ REPARTITION */ key, 1 FROM testData WHERE value = 'no_match'
+ |""".stripMargin)
+ checkNumUnion(plan2, 1)
+ checkNumUnion(adaptivePlan2, 0)
+ assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec])
+ }
+ }
+
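
The contrast between adaptivePlan1/adaptivePlan2 and adaptivePlan3 above comes down to aggregate semantics on empty input: a grouped aggregate over an empty relation produces no rows and can be replaced by an empty LocalTableScan, while a global count still produces a single row containing 0. A quick sketch, assuming the suite's testData temp view is registered (not part of the patch):

val grouped = spark.sql(
  "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key")
val global = spark.sql("SELECT count(*) FROM testData WHERE value = 'no_match'")
assert(grouped.collect().isEmpty)               // empty result, so it can be an empty relation
assert(global.collect().toSeq == Seq(Row(0L)))  // one row with count 0, so it cannot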
test("SPARK-32753: Only copy tags to node with no tags") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
withTempView("v1") {
@@ -1450,7 +1565,8 @@ class AdaptiveQueryExecSuite
"=== Result of Batch AQE Post Stage Creation ===",
"=== Result of Batch AQE Replanning ===",
"=== Result of Batch AQE Query Stage Optimization ===").foreach { expectedMsg =>
- assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg)))
+ assert(testAppender.loggingEvents.exists(
+ _.getMessage.getFormattedMessage.contains(expectedMsg)))
}
}
}
@@ -1502,7 +1618,7 @@ class AdaptiveQueryExecSuite
test("SPARK-33494: Do not use local shuffle read for repartition") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
- val df = spark.table("testData").repartition('key)
+ val df = spark.table("testData").repartition(Symbol("key"))
df.collect()
// local shuffle read breaks partitioning and shouldn't be used for repartition operation
// which is specified by users.
@@ -1581,28 +1697,28 @@ class AdaptiveQueryExecSuite
| SELECT * FROM testData WHERE key = 1
|)
|RIGHT OUTER JOIN testData2
- |ON value = b
+ |ON CAST(value AS INT) = b
""".stripMargin)
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
// Repartition with no partition num specified.
- checkBHJ(df.repartition('b),
+ checkBHJ(df.repartition(Symbol("b")),
// The top shuffle from repartition is optimized out.
optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true)
// Repartition with default partition num (5 in test env) specified.
- checkBHJ(df.repartition(5, 'b),
+ checkBHJ(df.repartition(5, Symbol("b")),
// The top shuffle from repartition is optimized out
// The final plan must have 5 partitions, no optimization can be made to the probe side.
optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false)
// Repartition with non-default partition num specified.
- checkBHJ(df.repartition(4, 'b),
+ checkBHJ(df.repartition(4, Symbol("b")),
// The top shuffle from repartition is not optimized out
optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
// Repartition by col and project away the partition cols
- checkBHJ(df.repartition('b).select('key),
+ checkBHJ(df.repartition(Symbol("b")).select(Symbol("key")),
// The top shuffle from repartition is not optimized out
optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
}
@@ -1614,23 +1730,23 @@ class AdaptiveQueryExecSuite
SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
// Repartition with no partition num specified.
- checkSMJ(df.repartition('b),
+ checkSMJ(df.repartition(Symbol("b")),
// The top shuffle from repartition is optimized out.
optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true)
// Repartition with default partition num (5 in test env) specified.
- checkSMJ(df.repartition(5, 'b),
+ checkSMJ(df.repartition(5, Symbol("b")),
// The top shuffle from repartition is optimized out.
// The final plan must have 5 partitions, can't do coalesced read.
optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false)
// Repartition with non-default partition num specified.
- checkSMJ(df.repartition(4, 'b),
+ checkSMJ(df.repartition(4, Symbol("b")),
// The top shuffle from repartition is not optimized out.
optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
// Repartition by col and project away the partition cols
- checkSMJ(df.repartition('b).select('key),
+ checkSMJ(df.repartition(Symbol("b")).select(Symbol("key")),
// The top shuffle from repartition is not optimized out.
optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
}
@@ -1667,6 +1783,7 @@ class AdaptiveQueryExecSuite
test("SPARK-33933: Materialize BroadcastQueryStage first in AQE") {
val testAppender = new LogAppender("aqe query stage materialization order test")
+ testAppender.setThreshold(Level.DEBUG)
val df = spark.range(1000).select($"id" % 26, $"id" % 10)
.toDF("index", "pv")
val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
@@ -1683,7 +1800,7 @@ class AdaptiveQueryExecSuite
}
}
val materializeLogs = testAppender.loggingEvents
- .map(_.getRenderedMessage)
+ .map(_.getMessage.getFormattedMessage)
.filter(_.startsWith("Materialize query stage"))
.toArray
assert(materializeLogs(0).startsWith("Materialize query stage BroadcastQueryStageExec"))
@@ -1722,10 +1839,94 @@ class AdaptiveQueryExecSuite
}
}
+ test("SPARK-34980: Support coalesce partition through union") {
+ def checkResultPartition(
+ df: Dataset[Row],
+ numUnion: Int,
+ numShuffleReader: Int,
+ numPartition: Int): Unit = {
+ df.collect()
+ assert(collect(df.queryExecution.executedPlan) {
+ case u: UnionExec => u
+ }.size == numUnion)
+ assert(collect(df.queryExecution.executedPlan) {
+ case r: AQEShuffleReadExec => r
+ }.size === numShuffleReader)
+ assert(df.rdd.partitions.length === numPartition)
+ }
+
+ Seq(true, false).foreach { combineUnionEnabled =>
+ val combineUnionConfig = if (combineUnionEnabled) {
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ""
+ } else {
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+ "org.apache.spark.sql.catalyst.optimizer.CombineUnions"
+ }
+ // advisory partition size 1048576 has no special meaning, just a big enough value
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1048576",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10",
+ combineUnionConfig) {
+ withTempView("t1", "t2") {
+ spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 2)
+ .toDF().createOrReplaceTempView("t1")
+ spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 4)
+ .toDF().createOrReplaceTempView("t2")
+
+ // positive test that could be coalesced
+ checkResultPartition(
+ sql("""
+ |SELECT key, count(*) FROM t1 GROUP BY key
+ |UNION ALL
+ |SELECT * FROM t2
+ """.stripMargin),
+ numUnion = 1,
+ numShuffleReader = 1,
+ numPartition = 1 + 4)
+
+ checkResultPartition(
+ sql("""
+ |SELECT key, count(*) FROM t1 GROUP BY key
+ |UNION ALL
+ |SELECT * FROM t2
+ |UNION ALL
+ |SELECT * FROM t1
+ """.stripMargin),
+ numUnion = if (combineUnionEnabled) 1 else 2,
+ numShuffleReader = 1,
+ numPartition = 1 + 4 + 2)
+
+ checkResultPartition(
+ sql("""
+ |SELECT /*+ merge(t2) */ t1.key, t2.key FROM t1 JOIN t2 ON t1.key = t2.key
+ |UNION ALL
+ |SELECT key, count(*) FROM t2 GROUP BY key
+ |UNION ALL
+ |SELECT * FROM t1
+ """.stripMargin),
+ numUnion = if (combineUnionEnabled) 1 else 2,
+ numShuffleReader = 3,
+ numPartition = 1 + 1 + 2)
+
+ // negative test
+ checkResultPartition(
+ sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2"),
+ numUnion = if (combineUnionEnabled) 1 else 1,
+ numShuffleReader = 0,
+ numPartition = 2 + 4
+ )
+ }
+ }
+ }
+ }
+
test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") {
withTable("t") {
withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
- SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+ SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
spark.sql("CREATE TABLE t (c1 int) USING PARQUET")
val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1")
assert(
@@ -1889,8 +2090,8 @@ class AdaptiveQueryExecSuite
withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "150") {
// partition size [0,258,72,72,72]
checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4)
- // partition size [72,216,216,144,72]
- checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 4, 7)
+ // partition size [144,72,144,72,72,144,72]
+ checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 6, 7)
}
// no skewed partition should be optimized
@@ -1925,6 +2126,74 @@ class AdaptiveQueryExecSuite
}
}
+ test("SPARK-33832: Support optimize skew join even if introduce extra shuffle") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "false",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10",
+ SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN.key -> "true") {
+ withTempView("skewData1", "skewData2") {
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 3 as key1", "id as value1")
+ .createOrReplaceTempView("skewData1")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key2", "id as value2")
+ .createOrReplaceTempView("skewData2")
+
+ // check if optimized skewed join does not satisfy the required distribution
+ Seq(true, false).foreach { hasRequiredDistribution =>
+ Seq(true, false).foreach { hasPartitionNumber =>
+ val repartition = if (hasRequiredDistribution) {
+ s"/*+ repartition(${ if (hasPartitionNumber) "10," else ""}key1) */"
+ } else {
+ ""
+ }
+
+ // check required distribution and extra shuffle
+ val (_, adaptive1) =
+ runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
+ s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
+ val shuffles1 = collect(adaptive1) {
+ case s: ShuffleExchangeExec => s
+ }
+ assert(shuffles1.size == 3)
+ // shuffles1.head is the top-level shuffle under the Aggregate operator
+ assert(shuffles1.head.shuffleOrigin == ENSURE_REQUIREMENTS)
+ val smj1 = findTopLevelSortMergeJoin(adaptive1)
+ assert(smj1.size == 1 && smj1.head.isSkewJoin)
+
+ // only check required distribution
+ val (_, adaptive2) =
+ runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
+ s"JOIN skewData2 ON key1 = key2")
+ val shuffles2 = collect(adaptive2) {
+ case s: ShuffleExchangeExec => s
+ }
+ if (hasRequiredDistribution) {
+ assert(shuffles2.size == 3)
+ val finalShuffle = shuffles2.head
+ if (hasPartitionNumber) {
+ assert(finalShuffle.shuffleOrigin == REPARTITION_BY_NUM)
+ } else {
+ assert(finalShuffle.shuffleOrigin == REPARTITION_BY_COL)
+ }
+ } else {
+ assert(shuffles2.size == 2)
+ }
+ val smj2 = findTopLevelSortMergeJoin(adaptive2)
+ assert(smj2.size == 1 && smj2.head.isSkewJoin)
+ }
+ }
+ }
+ }
+ }
+
test("SPARK-35968: AQE coalescing should not produce too small partitions by default") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val (_, adaptive) =
@@ -2033,10 +2302,108 @@ class AdaptiveQueryExecSuite
}
}
+ test("SPARK-36424: Support eliminate limits in AQE Optimizer") {
+ withTempView("v") {
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(i, if (i > 2) "2" else i.toString)), 2)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
+ val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT c2, sum(c1) FROM v GROUP BY c2 LIMIT 5
+ """.stripMargin)
+ assert(findTopLevelLimit(origin1).size == 1)
+ assert(findTopLevelLimit(adaptive1).isEmpty)
+
+ // eliminate limit through filter
+ val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT c2, sum(c1) FROM v GROUP BY c2 HAVING sum(c1) > 1 LIMIT 5
+ """.stripMargin)
+ assert(findTopLevelLimit(origin2).size == 1)
+ assert(findTopLevelLimit(adaptive2).isEmpty)
+
+ // The strategy of Eliminate Limits batch should be fixedPoint
+ val (origin3, adaptive3) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT * FROM (SELECT c1 + c2 FROM (SELECT DISTINCT * FROM v LIMIT 10086)) LIMIT 20
+ """.stripMargin
+ )
+ assert(findTopLevelLimit(origin3).size == 1)
+ assert(findTopLevelLimit(adaptive3).isEmpty)
+ }
+ }
+ }
+
+ test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize non-root node") {
+ withTempView("v") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") {
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(if (i > 2) 2 else i, i.toString)), 2)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
+ def checkRebalance(query: String, numShufflePartitions: Int): Unit = {
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query)
+ assert(adaptive.collect {
+ case sort: SortExec => sort
+ }.size == 1)
+ val read = collect(adaptive) {
+ case read: AQEShuffleReadExec => read
+ }
+ assert(read.size == 1)
+ assert(read.head.partitionSpecs.forall(_.isInstanceOf[PartialReducerPartitionSpec]))
+ assert(read.head.partitionSpecs.size == numShufflePartitions)
+ }
+
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "50") {
+ checkRebalance("SELECT /*+ REBALANCE(c1) */ * FROM v SORT BY c1", 2)
+ checkRebalance("SELECT /*+ REBALANCE */ * FROM v SORT BY c1", 2)
+ }
+ }
+ }
+ }
+
+ test("SPARK-37357: Add small partition factor for rebalance partitions") {
+ withTempView("v") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ spark.sparkContext.parallelize(
+ (1 to 8).map(i => TestData(if (i > 2) 2 else i, i.toString)), 3)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
+ def checkAQEShuffleReadExists(query: String, exists: Boolean): Unit = {
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query)
+ assert(
+ collect(adaptive) {
+ case read: AQEShuffleReadExec => read
+ }.nonEmpty == exists)
+ }
+
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "200") {
+ withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.5") {
+ // block size: [88, 97, 97]
+ checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", false)
+ }
+ withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.2") {
+ // block size: [88, 97, 97]
+ checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", true)
+ }
+ }
+ }
+ }
+ }
+
test("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly plans BHJ") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584") {
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584",
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
// Spark estimates a string column as 20 bytes so with 60k rows, these relations should be
// estimated at ~120m bytes which is greater than the broadcast join threshold.
val joinKeyOne = "00112233445566778899"
@@ -2085,6 +2452,203 @@ class AdaptiveQueryExecSuite
assert(bhj.length == 1)
}
}
+
+ test("SPARK-37328: skew join with 3 tables") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ withTempView("skewData1", "skewData2", "skewData3") {
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 3 as key1", "id % 3 as value1")
+ .createOrReplaceTempView("skewData1")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key2", "id as value2")
+ .createOrReplaceTempView("skewData2")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key3", "id as value3")
+ .createOrReplaceTempView("skewData3")
+
+ // the skewed join doesn't happen in the last stage
+ val (_, adaptive1) =
+ runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "JOIN skewData3 ON value2 = value3")
+ val shuffles1 = collect(adaptive1) {
+ case s: ShuffleExchangeExec => s
+ }
+ assert(shuffles1.size == 4)
+ val smj1 = findTopLevelSortMergeJoin(adaptive1)
+ assert(smj1.size == 2 && smj1.last.isSkewJoin && !smj1.head.isSkewJoin)
+
+ // Query has two skewJoin in two continuous stages.
+ val (_, adaptive2) =
+ runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "JOIN skewData3 ON value1 = value3")
+ val shuffles2 = collect(adaptive2) {
+ case s: ShuffleExchangeExec => s
+ }
+ assert(shuffles2.size == 4)
+ val smj2 = findTopLevelSortMergeJoin(adaptive2)
+ assert(smj2.size == 2 && smj2.forall(_.isSkewJoin))
+ }
+ }
+ }
+
+ test("SPARK-37652: optimize skewed join through union") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") {
+ withTempView("skewData1", "skewData2") {
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 3 as key1", "id as value1")
+ .createOrReplaceTempView("skewData1")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key2", "id as value2")
+ .createOrReplaceTempView("skewData2")
+
+ def checkSkewJoin(query: String, joinNums: Int, optimizeSkewJoinNums: Int): Unit = {
+ val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query)
+ val joins = findTopLevelSortMergeJoin(innerAdaptivePlan)
+ val optimizeSkewJoins = joins.filter(_.isSkewJoin)
+ assert(joins.size == joinNums && optimizeSkewJoins.size == optimizeSkewJoinNums)
+ }
+
+ // skewJoin union skewJoin
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "UNION ALL SELECT key2 FROM skewData1 JOIN skewData2 ON key1 = key2", 2, 2)
+
+ // skewJoin union aggregate
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2", 1, 1)
+
+ // skewJoin1 union (skewJoin2 join aggregate)
+ // skewJoin2 will lead to extra shuffles, but skewJoin1 cannot be optimized
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " +
+ "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " +
+ "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 0)
+ }
+ }
+ }
+
+ test("SPARK-38162: Optimize one row plan in AQE Optimizer") {
+ withTempView("v") {
+ spark.sparkContext.parallelize(
+ (1 to 4).map(i => TestData(i, i.toString)), 2)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
+ // remove sort
+ val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT * FROM v where c1 = 1 order by c1, c2
+ |""".stripMargin)
+ assert(findTopLevelSort(origin1).size == 1)
+ assert(findTopLevelSort(adaptive1).isEmpty)
+
+ // convert group only aggregate to project
+ val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT distinct c1 FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1)
+ |""".stripMargin)
+ assert(findTopLevelAggregate(origin2).size == 2)
+ assert(findTopLevelAggregate(adaptive2).isEmpty)
+
+ // remove distinct in aggregate
+ val (origin3, adaptive3) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT sum(distinct c1) FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1)
+ |""".stripMargin)
+ assert(findTopLevelAggregate(origin3).size == 4)
+ assert(findTopLevelAggregate(adaptive3).size == 2)
+
+ // do not optimize if the aggregate is inside query stage
+ val (origin4, adaptive4) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT distinct c1 FROM v where c1 = 1
+ |""".stripMargin)
+ assert(findTopLevelAggregate(origin4).size == 2)
+ assert(findTopLevelAggregate(adaptive4).size == 2)
+
+ val (origin5, adaptive5) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT sum(distinct c1) FROM v where c1 = 1
+ |""".stripMargin)
+ assert(findTopLevelAggregate(origin5).size == 4)
+ assert(findTopLevelAggregate(adaptive5).size == 4)
+ }
+ }
+
+ test("SPARK-39551: Invalid plan check - invalid broadcast query stage") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1
+ |INNER JOIN testData2 t2
+ |ON t1.b = t2.b AND t1.a = 0
+ |RIGHT OUTER JOIN testData2 t3
+ |ON t1.a > t3.a
+ |GROUP BY t3.b
+ """.stripMargin
+ )
+ assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
+ }
+ }
+
+ test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") {
+ // partitioning: HashPartitioning
+ // shuffleOrigin: REPARTITION_BY_NUM
+ assert(spark.range(0).repartition(5, $"id").rdd.getNumPartitions == 5)
+ // shuffleOrigin: REPARTITION_BY_COL
+ // The minimum partition number after AQE coalesce is 1
+ assert(spark.range(0).repartition($"id").rdd.getNumPartitions == 1)
+ // through project
+ assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2")
+ .repartition(5, $"c1").select($"c2").rdd.getNumPartitions == 5)
+
+ // partitioning: RangePartitioning
+ // shuffleOrigin: REPARTITION_BY_NUM
+ // The minimum partition number of RangePartitioner is 1
+ assert(spark.range(0).repartitionByRange(5, $"id").rdd.getNumPartitions == 1)
+ // shuffleOrigin: REPARTITION_BY_COL
+ assert(spark.range(0).repartitionByRange($"id").rdd.getNumPartitions == 1)
+
+ // partitioning: RoundRobinPartitioning
+ // shuffleOrigin: REPARTITION_BY_NUM
+ assert(spark.range(0).repartition(5).rdd.getNumPartitions == 5)
+ // shuffleOrigin: REBALANCE_PARTITIONS_BY_NONE
+ assert(spark.range(0).repartition().rdd.getNumPartitions == 0)
+ // through project
+ assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2")
+ .repartition(5).select($"c2").rdd.getNumPartitions == 5)
+
+ // partitioning: SinglePartition
+ assert(spark.range(0).repartition(1).rdd.getNumPartitions == 1)
+ }
+ }
+
+ test("SPARK-39915: Ensure the output partitioning is user-specified") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ val df1 = spark.range(1).selectExpr("id as c1")
+ val df2 = spark.range(1).selectExpr("id as c2")
+ val df = df1.join(df2, col("c1") === col("c2")).repartition(3, col("c1"))
+ assert(df.rdd.getNumPartitions == 3)
+ }
+ }
}
/**
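The SPARK-39915 tests above pin down how many partitions Dataset.repartition yields once AQE coalescing is in play. A minimal standalone sketch of the same checks outside the test harness (object name, app name, and the local master are illustrative; it assumes a Spark 3.3+ session with AQE available):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col

    object AqeRepartitionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("aqe-repartition-sketch")
          .config("spark.sql.adaptive.enabled", "true")   // SQLConf.ADAPTIVE_EXECUTION_ENABLED
          .config("spark.sql.shuffle.partitions", "6")    // SQLConf.SHUFFLE_PARTITIONS
          .getOrCreate()

        // repartition(N, col) pins the partition number, so AQE must keep exactly N partitions.
        println(spark.range(0).repartition(5, col("id")).rdd.getNumPartitions)  // expected: 5
        // repartition(col) leaves the number to AQE, which coalesces an empty input down to 1.
        println(spark.range(0).repartition(col("id")).rdd.getNumPartitions)     // expected: 1
        // Round-robin repartition with an explicit number is likewise preserved.
        println(spark.range(0).repartition(5).rdd.getNumPartitions)             // expected: 5

        spark.stop()
      }
    }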
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index a5ac2d5aa7..e876e9d6ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -1377,7 +1377,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
- val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
+ val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, ctx)
val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx)
var count = 0
@@ -1398,7 +1398,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
- val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
+ val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, ctx)
// Write batches to Arrow stream format as a byte array
val out = new ByteArrayOutputStream()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index 146d9fcdb9..a88f423ae0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution.arrow
-import org.apache.arrow.vector.IntervalDayVector
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util._
@@ -30,12 +28,12 @@ class ArrowWriterSuite extends SparkFunSuite {
test("simple") {
def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = {
- val avroDatatype = dt match {
+ val datatype = dt match {
case _: DayTimeIntervalType => DayTimeIntervalType()
case _: YearMonthIntervalType => YearMonthIntervalType()
case tpe => tpe
}
- val schema = new StructType().add("value", avroDatatype, nullable = true)
+ val schema = new StructType().add("value", datatype, nullable = true)
val writer = ArrowWriter.create(schema, timeZoneId)
assert(writer.schema === schema)
@@ -61,6 +59,7 @@ class ArrowWriterSuite extends SparkFunSuite {
case BinaryType => reader.getBinary(rowId)
case DateType => reader.getInt(rowId)
case TimestampType => reader.getLong(rowId)
+ case TimestampNTZType => reader.getLong(rowId)
case _: YearMonthIntervalType => reader.getInt(rowId)
case _: DayTimeIntervalType => reader.getLong(rowId)
}
@@ -81,6 +80,7 @@ class ArrowWriterSuite extends SparkFunSuite {
check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
check(DateType, Seq(0, 1, 2, null, 4))
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
+ check(TimestampNTZType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong))
check(NullType, Seq(null, null, null))
DataTypeTestUtils.yearMonthIntervalTypes
.foreach(check(_, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)))
@@ -88,30 +88,6 @@ class ArrowWriterSuite extends SparkFunSuite {
Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L), (Long.MinValue + 808L))))
}
- test("long overflow for DayTimeIntervalType")
- {
- val schema = new StructType().add("value", DayTimeIntervalType(), nullable = true)
- val writer = ArrowWriter.create(schema, null)
- val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
- val valueVector = writer.root.getFieldVectors().get(0).asInstanceOf[IntervalDayVector]
-
- valueVector.set(0, 106751992, 0)
- valueVector.set(1, 106751991, Int.MaxValue)
-
- // first long overflow for test Math.multiplyExact()
- val msg = intercept[java.lang.ArithmeticException] {
- reader.getLong(0)
- }.getMessage
- assert(msg.equals("long overflow"))
-
- // second long overflow for test Math.addExact()
- val msg1 = intercept[java.lang.ArithmeticException] {
- reader.getLong(1)
- }.getMessage
- assert(msg1.equals("long overflow"))
- writer.root.close()
- }
-
test("get multiple") {
def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = {
val avroDatatype = dt match {
@@ -139,6 +115,7 @@ class ArrowWriterSuite extends SparkFunSuite {
case DoubleType => reader.getDoubles(0, data.size)
case DateType => reader.getInts(0, data.size)
case TimestampType => reader.getLongs(0, data.size)
+ case TimestampNTZType => reader.getLongs(0, data.size)
case _: YearMonthIntervalType => reader.getInts(0, data.size)
case _: DayTimeIntervalType => reader.getLongs(0, data.size)
}
@@ -155,6 +132,7 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DoubleType, (0 until 10).map(_.toDouble))
check(DateType, (0 until 10))
check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles")
+ check(TimestampNTZType, (0 until 10).map(_ * 4.32e10.toLong))
DataTypeTestUtils.yearMonthIntervalTypes.foreach(check(_, (0 until 14)))
DataTypeTestUtils.dayTimeIntervalTypes.foreach(check(_, (-10 until 10).map(_ * 1000.toLong)))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnsiIntervalSortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnsiIntervalSortBenchmark.scala
new file mode 100644
index 0000000000..0537527b85
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnsiIntervalSortBenchmark.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Benchmark to measure performance for interval sort.
+ * To run this benchmark:
+ * {{{
+ * 1. without sbt:
+ * bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar>
+ * 2. build/sbt "sql/test:runMain <this class>"
+ * 3. generate result:
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ * Results will be written to "benchmarks/AnsiIntervalSortBenchmark-results.txt".
+ * }}}
+ */
+object AnsiIntervalSortBenchmark extends SqlBasedBenchmark {
+ private val numRows = 100 * 1000 * 1000
+
+ private def radixBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
+ val benchmark = new Benchmark(name, cardinality, output = output)
+ benchmark.addCase(s"$name enable radix", 3) { _ =>
+ withSQLConf(SQLConf.RADIX_SORT_ENABLED.key -> "true") {
+ f
+ }
+ }
+
+ benchmark.addCase(s"$name disable radix", 3) { _ =>
+ withSQLConf(SQLConf.RADIX_SORT_ENABLED.key -> "false") {
+ f
+ }
+ }
+ benchmark.run()
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ val dt = spark.range(numRows).selectExpr("make_dt_interval(id % 24) as c1", "id as c2")
+ radixBenchmark("year month interval one column", numRows) {
+ dt.sortWithinPartitions("c1").select("c2").noop()
+ }
+
+ radixBenchmark("year month interval two columns", numRows) {
+ dt.sortWithinPartitions("c1", "c2").select("c2").noop()
+ }
+
+ val ym = spark.range(numRows).selectExpr("make_ym_interval(id % 2000) as c1", "id as c2")
+ radixBenchmark("day time interval one columns", numRows) {
+ ym.sortWithinPartitions("c1").select("c2").noop()
+ }
+
+ radixBenchmark("day time interval two columns", numRows) {
+ ym.sortWithinPartitions("c1", "c2").select("c2").noop()
+ }
+ }
+}
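AnsiIntervalSortBenchmark above leans on Spark's internal Benchmark harness and the .noop() helper from SqlBasedBenchmark. A rough standalone approximation of one radix on/off comparison, with the timing done by hand (object name and row count are illustrative; make_dt_interval and the "noop" sink are standard Spark 3.2+ APIs):

    import org.apache.spark.sql.SparkSession

    object IntervalSortSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("interval-sort-sketch").getOrCreate()
        // Day-time interval column plus a payload column, as in the benchmark above, at a smaller scale.
        val dt = spark.range(10 * 1000 * 1000)
          .selectExpr("make_dt_interval(id % 24) as c1", "id as c2")

        Seq("true", "false").foreach { radix =>
          spark.conf.set("spark.sql.sort.enableRadixSort", radix)  // SQLConf.RADIX_SORT_ENABLED
          val start = System.nanoTime()
          // The "noop" sink forces full evaluation without writing output, like .noop() above.
          dt.sortWithinPartitions("c1").select("c2")
            .write.format("noop").mode("overwrite").save()
          println(f"radixSort=$radix took ${(System.nanoTime() - start) / 1e9}%.1f s")
        }
        spark.stop()
      }
    }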
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala
new file mode 100644
index 0000000000..eb0b896574
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+
+/**
+ * Benchmark for measuring perf of different Base64 implementations
+ * To run this benchmark:
+ * {{{
+ * 1. without sbt:
+ * bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar>
+ * 2. build/sbt "sql/test:runMain <this class>"
+ * 3. generate result:
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ * Results will be written to "benchmarks/Base64Benchmark-results.txt".
+ * }}}
+ */
+object Base64Benchmark extends SqlBasedBenchmark {
+ import spark.implicits._
+ private val N = 20L * 1000 * 1000
+
+ private def doEncode(len: Int, f: Array[Byte] => Array[Byte]): Unit = {
+ spark.range(N).map(_ => "Spark" * len).foreach { s =>
+ f(s.getBytes)
+ ()
+ }
+ }
+
+ private def doDecode(len: Int, f: Array[Byte] => Array[Byte]): Unit = {
+ spark.range(N).map(_ => "Spark" * len).map { s =>
+ // using the same encode func
+ java.util.Base64.getMimeEncoder.encode(s.getBytes)
+ }.foreach { s =>
+ f(s)
+ ()
+ }
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ Seq(1, 3, 5, 7).map { len =>
+ val benchmark = new Benchmark(s"encode for $len", N, output = output)
+ benchmark.addCase("java", 3) { _ =>
+ doEncode(len, x => java.util.Base64.getMimeEncoder().encode(x))
+ }
+ benchmark.addCase(s"apache", 3) { _ =>
+ doEncode(len, org.apache.commons.codec.binary.Base64.encodeBase64)
+ }
+ benchmark
+ }.foreach(_.run())
+
+ Seq(1, 3, 5, 7).map { len =>
+ val benchmark = new Benchmark(s"decode for $len", N, output = output)
+ benchmark.addCase("java", 3) { _ =>
+ doDecode(len, x => java.util.Base64.getMimeDecoder.decode(x))
+ }
+ benchmark.addCase(s"apache", 3) { _ =>
+ doDecode(len, org.apache.commons.codec.binary.Base64.decodeBase64)
+ }
+ benchmark
+ }.foreach(_.run())
+ }
+}
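The two codecs compared by Base64Benchmark differ mainly in chunking: the JDK MIME encoder wraps lines at 76 characters, while commons-codec's encodeBase64 emits one unchunked block, and both decoders accept either form. A small self-contained check (object name is illustrative):

    import java.nio.charset.StandardCharsets
    import java.util.{Base64 => JBase64}

    import org.apache.commons.codec.binary.{Base64 => ABase64}

    object Base64RoundTrip {
      def main(args: Array[String]): Unit = {
        val payload = ("Spark" * 100).getBytes(StandardCharsets.UTF_8)

        val javaMime = JBase64.getMimeEncoder.encode(payload)   // line-wrapped every 76 chars
        val apachePlain = ABase64.encodeBase64(payload)         // single unchunked block

        // All four decode combinations round-trip to the original bytes.
        assert(JBase64.getMimeDecoder.decode(javaMime).sameElements(payload))
        assert(JBase64.getMimeDecoder.decode(apachePlain).sameElements(payload))
        assert(ABase64.decodeBase64(javaMime).sameElements(payload))
        assert(ABase64.decodeBase64(apachePlain).sameElements(payload))
        println("both implementations agree")
      }
    }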
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala
index f78ccf9569..ccb65c7d3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.execution.benchmark
import scala.util.Random
+import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat}
+
import org.apache.spark.benchmark.Benchmark
/**
* Benchmark to measure read performance with Bloom filters.
*
- * Currently, only ORC supports bloom filters, we will add Parquet BM as soon as it becomes
- * available.
- *
* To run this benchmark:
* {{{
* 1. without sbt: bin/spark-submit --class <this class>
@@ -43,7 +42,7 @@ object BloomFilterBenchmark extends SqlBasedBenchmark {
private val N = scaleFactor * 1000 * 1000
private val df = spark.range(N).map(_ => Random.nextInt)
- private def writeBenchmark(): Unit = {
+ private def writeORCBenchmark(): Unit = {
withTempPath { dir =>
val path = dir.getCanonicalPath
@@ -61,7 +60,7 @@ object BloomFilterBenchmark extends SqlBasedBenchmark {
}
}
- private def readBenchmark(): Unit = {
+ private def readORCBenchmark(): Unit = {
withTempPath { dir =>
val path = dir.getCanonicalPath
@@ -81,8 +80,55 @@ object BloomFilterBenchmark extends SqlBasedBenchmark {
}
}
+ private def writeParquetBenchmark(): Unit = {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ runBenchmark("Parquet Write") {
+ val benchmark = new Benchmark(s"Write ${scaleFactor}M rows", N, output = output)
+ benchmark.addCase("Without bloom filter") { _ =>
+ df.write.mode("overwrite").parquet(path + "/withoutBF")
+ }
+ benchmark.addCase("With bloom filter") { _ =>
+ df.write.mode("overwrite")
+ .option(ParquetOutputFormat.BLOOM_FILTER_ENABLED + "#value", true)
+ .parquet(path + "/withBF")
+ }
+ benchmark.run()
+ }
+ }
+ }
+
+ private def readParquetBenchmark(): Unit = {
+ val blockSizes = Seq(2 * 1024 * 1024, 4 * 1024 * 1024, 6 * 1024 * 1024, 8 * 1024 * 1024,
+ 12 * 1024 * 1024, 16 * 1024 * 1024, 32 * 1024 * 1024)
+ for (blocksize <- blockSizes) {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ df.write.option("parquet.block.size", blocksize).parquet(path + "/withoutBF")
+ df.write.option(ParquetOutputFormat.BLOOM_FILTER_ENABLED + "#value", true)
+ .option("parquet.block.size", blocksize)
+ .parquet(path + "/withBF")
+
+ runBenchmark("Parquet Read") {
+ val benchmark = new Benchmark(s"Read a row from ${scaleFactor}M rows", N, output = output)
+ benchmark.addCase("Without bloom filter, blocksize: " + blocksize) { _ =>
+ spark.read.parquet(path + "/withoutBF").where("value = 0").noop()
+ }
+ benchmark.addCase("With bloom filter, blocksize: " + blocksize) { _ =>
+ spark.read.option(ParquetInputFormat.BLOOM_FILTERING_ENABLED, true)
+ .parquet(path + "/withBF").where("value = 0").noop()
+ }
+ benchmark.run()
+ }
+ }
+ }
+ }
+
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
- writeBenchmark()
- readBenchmark()
+ writeORCBenchmark()
+ readORCBenchmark()
+ writeParquetBenchmark()
+ readParquetBenchmark()
}
}
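The added readParquetBenchmark relies on two knobs: a per-column write option that builds the bloom filter, and a read option that lets row groups be skipped on point lookups. A minimal sketch of the same write/read round trip outside the harness (path and row count are illustrative):

    import scala.util.Random

    import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat}

    import org.apache.spark.sql.SparkSession

    object ParquetBloomFilterSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("parquet-bf-sketch").getOrCreate()
        import spark.implicits._

        val path = "/tmp/parquet_bf_sketch"                        // throwaway location
        val df = spark.range(1000 * 1000).map(_ => Random.nextInt) // single column named "value"

        // "#value" scopes the bloom filter to the column named "value", as in the benchmark above.
        df.write.mode("overwrite")
          .option(ParquetOutputFormat.BLOOM_FILTER_ENABLED + "#value", "true")
          .option("parquet.block.size", (4 * 1024 * 1024).toString)
          .parquet(path)

        // With bloom filtering enabled on read, row groups whose filter rules out 0 can be skipped.
        val hits = spark.read
          .option(ParquetInputFormat.BLOOM_FILTERING_ENABLED, "true")
          .parquet(path)
          .where("value = 0")
          .count()
        println(s"rows with value = 0: $hits")
        spark.stop()
      }
    }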
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala
index 361deb0d3e..45d50b5e11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala
@@ -16,6 +16,9 @@
*/
package org.apache.spark.sql.execution.benchmark
+import org.apache.parquet.column.ParquetProperties
+import org.apache.parquet.hadoop.ParquetOutputFormat
+
import org.apache.spark.sql.internal.SQLConf
/**
@@ -53,7 +56,16 @@ object BuiltInDataSourceWriteBenchmark extends DataSourceWriteBenchmark {
formats.foreach { format =>
runBenchmark(s"$format writer benchmark") {
- runDataSourceBenchmark(format)
+ if (format.equals("Parquet")) {
+ ParquetProperties.WriterVersion.values().foreach {
+ writeVersion =>
+ withSQLConf(ParquetOutputFormat.WRITER_VERSION -> writeVersion.toString) {
+ runDataSourceBenchmark("Parquet", Some(writeVersion.toString))
+ }
+ }
+ } else {
+ runDataSourceBenchmark(format)
+ }
}
}
}
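The loop above switches the Parquet data page format by setting ParquetOutputFormat.WRITER_VERSION as a session conf around each run. Assuming DataFrameWriter options reach the Parquet Hadoop configuration on the write path, as the bloom filter option above relies on, the same knob can also be set per write; a short sketch (output paths and row count are illustrative):

    import org.apache.parquet.column.ParquetProperties
    import org.apache.parquet.hadoop.ParquetOutputFormat

    import org.apache.spark.sql.SparkSession

    object ParquetWriterVersionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("writer-version-sketch").getOrCreate()
        val df = spark.range(1000 * 1000).selectExpr("id", "cast(id as string) as s")

        // PARQUET_1_0 and PARQUET_2_0 select V1 vs V2 data pages and their associated encodings.
        ParquetProperties.WriterVersion.values().foreach { v =>
          df.write.mode("overwrite")
            .option(ParquetOutputFormat.WRITER_VERSION, v.toString)
            .parquet(s"/tmp/writer_version_sketch/$v")
        }
        spark.stop()
      }
    }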
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala
new file mode 100644
index 0000000000..99016842d8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import scala.util.Random
+
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
+
+/**
+ * Benchmark to measure performance for byte array operators.
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt:
+ * bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar>
+ * 2. build/sbt "sql/test:runMain <this class>"
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ * Results will be written to "benchmarks/<this class>-results.txt".
+ * }}}
+ */
+object ByteArrayBenchmark extends BenchmarkBase {
+ private val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ private val randomChar = new Random(0)
+
+ def randomBytes(min: Int, max: Int): Array[Byte] = {
+ val len = randomChar.nextInt(max - min) + min
+ val bytes = new Array[Byte](len)
+ var i = 0
+ while (i < len) {
+ bytes(i) = chars.charAt(randomChar.nextInt(chars.length())).toByte
+ i += 1
+ }
+ bytes
+ }
+
+ def byteArrayComparisons(iters: Long): Unit = {
+ val count = 16 * 1000
+ val dataTiny = Seq.fill(count)(randomBytes(2, 7)).toArray
+ val dataSmall = Seq.fill(count)(randomBytes(8, 16)).toArray
+ val dataMedium = Seq.fill(count)(randomBytes(16, 32)).toArray
+ val dataLarge = Seq.fill(count)(randomBytes(512, 1024)).toArray
+ val dataLargeSlow = Seq.fill(count)(
+ Array.tabulate(512) {i => if (i < 511) 0.toByte else 1.toByte}).toArray
+
+ def compareBinary(data: Array[Array[Byte]]) = { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iters) {
+ var i = 0
+ while (i < count) {
+ sum += ByteArray.compareBinary(data(i), data((i + 1) % count))
+ i += 1
+ }
+ }
+ }
+
+ val benchmark = new Benchmark("Byte Array compareTo", count * iters, 25, output = output)
+ benchmark.addCase("2-7 byte")(compareBinary(dataTiny))
+ benchmark.addCase("8-16 byte")(compareBinary(dataSmall))
+ benchmark.addCase("16-32 byte")(compareBinary(dataMedium))
+ benchmark.addCase("512-1024 byte")(compareBinary(dataLarge))
+ benchmark.addCase("512 byte slow")(compareBinary(dataLargeSlow))
+ benchmark.addCase("2-7 byte")(compareBinary(dataTiny))
+ benchmark.run()
+ }
+
+ def byteArrayEquals(iters: Long): Unit = {
+ def binaryEquals(inputs: Array[BinaryEqualInfo]) = { _: Int =>
+ var res = false
+ for (_ <- 0L until iters) {
+ inputs.foreach { input =>
+ res = ByteArrayMethods.arrayEquals(
+ input.s1.getBaseObject, input.s1.getBaseOffset,
+ input.s2.getBaseObject, input.s2.getBaseOffset + input.deltaOffset,
+ input.len)
+ }
+ }
+ }
+ val count = 16 * 1000
+ val rand = new Random(0)
+ val inputs = (0 until count).map { _ =>
+ val s1 = UTF8String.fromBytes(randomBytes(1, 16))
+ val s2 = UTF8String.fromBytes(randomBytes(1, 16))
+ val len = s1.numBytes().min(s2.numBytes())
+ val deltaOffset = rand.nextInt(len)
+ BinaryEqualInfo(s1, s2, deltaOffset, len)
+ }.toArray
+
+ val benchmark = new Benchmark("Byte Array equals", count * iters, 25, output = output)
+ benchmark.addCase("Byte Array equals")(binaryEquals(inputs))
+ benchmark.run()
+ }
+
+ case class BinaryEqualInfo(
+ s1: UTF8String,
+ s2: UTF8String,
+ deltaOffset: Int,
+ len: Int)
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runBenchmark("byte array comparisons") {
+ byteArrayComparisons(1024 * 4)
+ }
+
+ runBenchmark("byte array equals") {
+ byteArrayEquals(1000 * 10)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
index 0fc43c7052..b35aa73e14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
@@ -21,11 +21,14 @@ import java.io.File
import scala.collection.JavaConverters._
import scala.util.Random
-import org.apache.spark.SparkConf
+import org.apache.parquet.column.ParquetProperties
+import org.apache.parquet.hadoop.ParquetOutputFormat
+
+import org.apache.spark.{SparkConf, TestUtils}
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader}
+import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnVector
@@ -66,16 +69,23 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
try f finally tableNames.foreach(spark.catalog.dropTempView)
}
- private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
+ private def prepareTable(
+ dir: File,
+ df: DataFrame,
+ partition: Option[String] = None,
+ onlyParquetOrc: Boolean = false): Unit = {
val testDf = if (partition.isDefined) {
df.write.partitionBy(partition.get)
} else {
df.write
}
- saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv")
- saveAsJsonTable(testDf, dir.getCanonicalPath + "/json")
- saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet")
+ if (!onlyParquetOrc) {
+ saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv")
+ saveAsJsonTable(testDf, dir.getCanonicalPath + "/json")
+ }
+ saveAsParquetV1Table(testDf, dir.getCanonicalPath + "/parquetV1")
+ saveAsParquetV2Table(testDf, dir.getCanonicalPath + "/parquetV2")
saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc")
}
@@ -89,9 +99,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
spark.read.json(dir).createOrReplaceTempView("jsonTable")
}
- private def saveAsParquetTable(df: DataFrameWriter[Row], dir: String): Unit = {
+ private def saveAsParquetV1Table(df: DataFrameWriter[Row], dir: String): Unit = {
df.mode("overwrite").option("compression", "snappy").parquet(dir)
- spark.read.parquet(dir).createOrReplaceTempView("parquetTable")
+ spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table")
+ }
+
+ private def saveAsParquetV2Table(df: DataFrameWriter[Row], dir: String): Unit = {
+ withSQLConf(ParquetOutputFormat.WRITER_VERSION ->
+ ParquetProperties.WriterVersion.PARQUET_2_0.toString) {
+ df.mode("overwrite").option("compression", "snappy").parquet(dir)
+ spark.read.parquet(dir).createOrReplaceTempView("parquetV2Table")
+ }
}
private def saveAsOrcTable(df: DataFrameWriter[Row], dir: String): Unit = {
@@ -99,6 +117,8 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
spark.read.orc(dir).createOrReplaceTempView("orcTable")
}
+ private def withParquetVersions(f: String => Unit): Unit = Seq("V1", "V2").foreach(f)
+
def numericScanBenchmark(values: Int, dataType: DataType): Unit = {
// Benchmarks running through spark sql.
val sqlBenchmark = new Benchmark(
@@ -113,112 +133,267 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
output = output)
withTempPath { dir =>
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") {
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") {
import spark.implicits._
spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1"))
+ val query = dataType match {
+ case BooleanType => "sum(cast(id as bigint))"
+ case _ => "sum(id)"
+ }
+
sqlBenchmark.addCase("SQL CSV") { _ =>
- spark.sql("select sum(id) from csvTable").noop()
+ spark.sql(s"select $query from csvTable").noop()
}
sqlBenchmark.addCase("SQL Json") { _ =>
- spark.sql("select sum(id) from jsonTable").noop()
+ spark.sql(s"select $query from jsonTable").noop()
}
- sqlBenchmark.addCase("SQL Parquet Vectorized") { _ =>
- spark.sql("select sum(id) from parquetTable").noop()
+ withParquetVersions { version =>
+ sqlBenchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ =>
+ spark.sql(s"select $query from parquet${version}Table").noop()
+ }
}
- sqlBenchmark.addCase("SQL Parquet MR") { _ =>
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql("select sum(id) from parquetTable").noop()
+ withParquetVersions { version =>
+ sqlBenchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select $query from parquet${version}Table").noop()
+ }
}
}
sqlBenchmark.addCase("SQL ORC Vectorized") { _ =>
- spark.sql("SELECT sum(id) FROM orcTable").noop()
+ spark.sql(s"SELECT $query FROM orcTable").noop()
}
sqlBenchmark.addCase("SQL ORC MR") { _ =>
withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql("SELECT sum(id) FROM orcTable").noop()
+ spark.sql(s"SELECT $query FROM orcTable").noop()
}
}
sqlBenchmark.run()
- // Driving the parquet reader in batch mode directly.
- val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray
val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled
val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize
- parquetReaderBenchmark.addCase("ParquetReader Vectorized") { _ =>
- var longSum = 0L
- var doubleSum = 0.0
- val aggregateValue: (ColumnVector, Int) => Unit = dataType match {
- case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i)
- case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i)
- case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i)
- case LongType => (col: ColumnVector, i: Int) => longSum += col.getLong(i)
- case FloatType => (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i)
- case DoubleType => (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i)
- }
-
- files.map(_.asInstanceOf[String]).foreach { p =>
- val reader = new VectorizedParquetRecordReader(
- enableOffHeapColumnVector, vectorizedReaderBatchSize)
- try {
- reader.initialize(p, ("id" :: Nil).asJava)
- val batch = reader.resultBatch()
- val col = batch.column(0)
- while (reader.nextBatch()) {
- val numRows = batch.numRows()
- var i = 0
- while (i < numRows) {
- if (!col.isNullAt(i)) aggregateValue(col, i)
- i += 1
+ withParquetVersions { version =>
+ // Driving the parquet reader in batch mode directly.
+ val files = TestUtils.listDirectory(new File(dir, s"parquet$version"))
+ parquetReaderBenchmark.addCase(s"ParquetReader Vectorized: DataPage$version") { _ =>
+ var longSum = 0L
+ var doubleSum = 0.0
+ val aggregateValue: (ColumnVector, Int) => Unit = dataType match {
+ case BooleanType =>
+ (col: ColumnVector, i: Int) => if (col.getBoolean(i)) longSum += 1L
+ case ByteType =>
+ (col: ColumnVector, i: Int) => longSum += col.getByte(i)
+ case ShortType =>
+ (col: ColumnVector, i: Int) => longSum += col.getShort(i)
+ case IntegerType =>
+ (col: ColumnVector, i: Int) => longSum += col.getInt(i)
+ case LongType =>
+ (col: ColumnVector, i: Int) => longSum += col.getLong(i)
+ case FloatType =>
+ (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i)
+ case DoubleType =>
+ (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i)
+ }
+
+ files.foreach { p =>
+ val reader = new VectorizedParquetRecordReader(
+ enableOffHeapColumnVector, vectorizedReaderBatchSize)
+ try {
+ reader.initialize(p, ("id" :: Nil).asJava)
+ val batch = reader.resultBatch()
+ val col = batch.column(0)
+ while (reader.nextBatch()) {
+ val numRows = batch.numRows()
+ var i = 0
+ while (i < numRows) {
+ if (!col.isNullAt(i)) aggregateValue(col, i)
+ i += 1
+ }
}
+ } finally {
+ reader.close()
}
- } finally {
- reader.close()
}
}
}
- // Decoding in vectorized but having the reader return rows.
- parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num =>
- var longSum = 0L
- var doubleSum = 0.0
- val aggregateValue: (InternalRow) => Unit = dataType match {
- case ByteType => (col: InternalRow) => longSum += col.getByte(0)
- case ShortType => (col: InternalRow) => longSum += col.getShort(0)
- case IntegerType => (col: InternalRow) => longSum += col.getInt(0)
- case LongType => (col: InternalRow) => longSum += col.getLong(0)
- case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0)
- case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0)
- }
-
- files.map(_.asInstanceOf[String]).foreach { p =>
- val reader = new VectorizedParquetRecordReader(
- enableOffHeapColumnVector, vectorizedReaderBatchSize)
- try {
- reader.initialize(p, ("id" :: Nil).asJava)
- val batch = reader.resultBatch()
- while (reader.nextBatch()) {
- val it = batch.rowIterator()
- while (it.hasNext) {
- val record = it.next()
- if (!record.isNullAt(0)) aggregateValue(record)
+ withParquetVersions { version =>
+ // Driving the parquet reader in batch mode directly.
+ val files = TestUtils.listDirectory(new File(dir, s"parquet$version"))
+ // Decoding in vectorized but having the reader return rows.
+ parquetReaderBenchmark
+ .addCase(s"ParquetReader Vectorized -> Row: DataPage$version") { _ =>
+ var longSum = 0L
+ var doubleSum = 0.0
+ val aggregateValue: (InternalRow) => Unit = dataType match {
+ case BooleanType => (col: InternalRow) => if (col.getBoolean(0)) longSum += 1L
+ case ByteType => (col: InternalRow) => longSum += col.getByte(0)
+ case ShortType => (col: InternalRow) => longSum += col.getShort(0)
+ case IntegerType => (col: InternalRow) => longSum += col.getInt(0)
+ case LongType => (col: InternalRow) => longSum += col.getLong(0)
+ case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0)
+ case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0)
+ }
+
+ files.foreach { p =>
+ val reader = new VectorizedParquetRecordReader(
+ enableOffHeapColumnVector, vectorizedReaderBatchSize)
+ try {
+ reader.initialize(p, ("id" :: Nil).asJava)
+ val batch = reader.resultBatch()
+ while (reader.nextBatch()) {
+ val it = batch.rowIterator()
+ while (it.hasNext) {
+ val record = it.next()
+ if (!record.isNullAt(0)) aggregateValue(record)
+ }
+ }
+ } finally {
+ reader.close()
}
}
- } finally {
- reader.close()
+ }
+ }
+ }
+
+ parquetReaderBenchmark.run()
+ }
+ }
+
+ /**
+ * Similar to [[numericScanBenchmark]] but accessed column is a struct field.
+ */
+ def nestedNumericScanBenchmark(values: Int, dataType: DataType): Unit = {
+ val sqlBenchmark = new Benchmark(
+ s"SQL Single ${dataType.sql} Column Scan in Struct",
+ values,
+ output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "parquetV1Table", "parquetV2Table", "orcTable") {
+ import spark.implicits._
+ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
+
+ prepareTable(dir,
+ spark.sql(s"SELECT named_struct('f', CAST(value as ${dataType.sql})) as col FROM t1"),
+ onlyParquetOrc = true)
+
+ sqlBenchmark.addCase(s"SQL ORC MR") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select sum(col.f) from orcTable").noop()
+ }
+ }
+
+ sqlBenchmark.addCase(s"SQL ORC Vectorized (Nested Column Disabled)") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") {
+ spark.sql(s"select sum(col.f) from orcTable").noop()
+ }
+ }
+
+ sqlBenchmark.addCase(s"SQL ORC Vectorized (Nested Column Enabled)") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ spark.sql(s"select sum(col.f) from orcTable").noop()
+ }
+ }
+
+ withParquetVersions { version =>
+ sqlBenchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select sum(col.f) from parquet${version}Table").noop()
+ }
+ }
+
+ sqlBenchmark.addCase(s"SQL Parquet Vectorized: DataPage$version " +
+ "(Nested Column Disabled)") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") {
+ spark.sql(s"select sum(col.f) from parquet${version}Table").noop()
+ }
+ }
+
+ sqlBenchmark.addCase(s"SQL Parquet Vectorized: DataPage$version " +
+ "(Nested Column Enabled)") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ spark.sql(s"select sum(col.f) from parquet${version}Table").noop()
+ }
+ }
+ }
+
+ sqlBenchmark.run()
+ }
+ }
+ }
+
+ def nestedColumnScanBenchmark(values: Int): Unit = {
+ val benchmark = new Benchmark(s"SQL Nested Column Scan", values, minNumIters = 10,
+ output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "parquetV1Table", "parquetV2Table", "orcTable") {
+ import spark.implicits._
+ spark.range(values).map(_ => Random.nextLong).map { x =>
+ val arrayOfStructColumn = (0 until 5).map(i => (x + i, s"$x" * 5))
+ val mapOfStructColumn = Map(
+ s"$x" -> (x * 0.1, (x, s"$x" * 100)),
+ (s"$x" * 2) -> (x * 0.2, (x, s"$x" * 200)),
+ (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300)))
+ (arrayOfStructColumn, mapOfStructColumn)
+ }.toDF("col1", "col2").createOrReplaceTempView("t1")
+
+ prepareTable(dir, spark.sql(s"SELECT * FROM t1"), onlyParquetOrc = true)
+
+ benchmark.addCase("SQL ORC MR") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop()
+ }
+ }
+
+ benchmark.addCase("SQL ORC Vectorized (Nested Column Disabled)") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop()
+ }
+ }
+
+ benchmark.addCase("SQL ORC Vectorized (Nested Column Enabled)") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop()
+ }
+ }
+
+
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquet${version}Table")
+ .noop()
+ }
+ }
+
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version " +
+ s"(Nested Column Disabled)") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") {
+ spark.sql(s"SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquet${version}Table")
+ .noop()
+ }
+ }
+
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version " +
+ s"(Nested Column Enabled)") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ spark.sql(s"SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquet${version}Table")
+ .noop()
}
}
}
- parquetReaderBenchmark.run()
+ benchmark.run()
}
}
}
@@ -227,7 +402,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
val benchmark = new Benchmark("Int and String Scan", values, output = output)
withTempPath { dir =>
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") {
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") {
import spark.implicits._
spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
@@ -243,13 +418,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
spark.sql("select sum(c1), sum(length(c2)) from jsonTable").noop()
}
- benchmark.addCase("SQL Parquet Vectorized") { _ =>
- spark.sql("select sum(c1), sum(length(c2)) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ =>
+ spark.sql(s"select sum(c1), sum(length(c2)) from parquet${version}Table").noop()
+ }
}
- benchmark.addCase("SQL Parquet MR") { _ =>
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql("select sum(c1), sum(length(c2)) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select sum(c1), sum(length(c2)) from parquet${version}Table").noop()
+ }
}
}
@@ -272,7 +451,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
val benchmark = new Benchmark("Repeated String", values, output = output)
withTempPath { dir =>
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") {
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") {
import spark.implicits._
spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
@@ -288,13 +467,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
spark.sql("select sum(length(c1)) from jsonTable").noop()
}
- benchmark.addCase("SQL Parquet Vectorized") { _ =>
- spark.sql("select sum(length(c1)) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ =>
+ spark.sql(s"select sum(length(c1)) from parquet${version}Table").noop()
+ }
}
- benchmark.addCase("SQL Parquet MR") { _ =>
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql("select sum(length(c1)) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select sum(length(c1)) from parquet${version}Table").noop()
+ }
}
}
@@ -317,7 +500,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
val benchmark = new Benchmark("Partitioned Table", values, output = output)
withTempPath { dir =>
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") {
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") {
import spark.implicits._
spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
@@ -331,13 +514,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
spark.sql("select sum(id) from jsonTable").noop()
}
- benchmark.addCase("Data column - Parquet Vectorized") { _ =>
- spark.sql("select sum(id) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"Data column - Parquet Vectorized: DataPage$version") { _ =>
+ spark.sql(s"select sum(id) from parquet${version}Table").noop()
+ }
}
- benchmark.addCase("Data column - Parquet MR") { _ =>
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql("select sum(id) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"Data column - Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select sum(id) from parquet${version}Table").noop()
+ }
}
}
@@ -359,13 +546,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
spark.sql("select sum(p) from jsonTable").noop()
}
- benchmark.addCase("Partition column - Parquet Vectorized") { _ =>
- spark.sql("select sum(p) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"Partition column - Parquet Vectorized: DataPage$version") { _ =>
+ spark.sql(s"select sum(p) from parquet${version}Table").noop()
+ }
}
- benchmark.addCase("Partition column - Parquet MR") { _ =>
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql("select sum(p) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"Partition column - Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select sum(p) from parquet${version}Table").noop()
+ }
}
}
@@ -387,13 +578,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
spark.sql("select sum(p), sum(id) from jsonTable").noop()
}
- benchmark.addCase("Both columns - Parquet Vectorized") { _ =>
- spark.sql("select sum(p), sum(id) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"Both columns - Parquet Vectorized: DataPage$version") { _ =>
+ spark.sql(s"select sum(p), sum(id) from parquet${version}Table").noop()
+ }
}
- benchmark.addCase("Both columns - Parquet MR") { _ =>
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql("select sum(p), sum(id) from parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"Both columns - Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select sum(p), sum(id) from parquet${version}Table").noop()
+ }
}
}
@@ -418,7 +613,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output)
withTempPath { dir =>
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") {
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") {
spark.range(values).createOrReplaceTempView("t1")
prepareTable(
@@ -437,39 +632,45 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
"not NULL and c2 is not NULL").noop()
}
- benchmark.addCase("SQL Parquet Vectorized") { _ =>
- spark.sql("select sum(length(c2)) from parquetTable where c1 is " +
- "not NULL and c2 is not NULL").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ =>
+ spark.sql(s"select sum(length(c2)) from parquet${version}Table where c1 is " +
+ "not NULL and c2 is not NULL").noop()
+ }
}
- benchmark.addCase("SQL Parquet MR") { _ =>
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql("select sum(length(c2)) from parquetTable where c1 is " +
- "not NULL and c2 is not NULL").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"select sum(length(c2)) from parquet${version}Table where c1 is " +
+ "not NULL and c2 is not NULL").noop()
+ }
}
}
- val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray
- val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled
- val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize
- benchmark.addCase("ParquetReader Vectorized") { num =>
- var sum = 0
- files.map(_.asInstanceOf[String]).foreach { p =>
- val reader = new VectorizedParquetRecordReader(
- enableOffHeapColumnVector, vectorizedReaderBatchSize)
- try {
- reader.initialize(p, ("c1" :: "c2" :: Nil).asJava)
- val batch = reader.resultBatch()
- while (reader.nextBatch()) {
- val rowIterator = batch.rowIterator()
- while (rowIterator.hasNext) {
- val row = rowIterator.next()
- val value = row.getUTF8String(0)
- if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes()
+ withParquetVersions { version =>
+ val files = TestUtils.listDirectory(new File(dir, s"parquet$version"))
+ val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled
+ val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize
+ benchmark.addCase(s"ParquetReader Vectorized: DataPage$version") { _ =>
+ var sum = 0
+ files.foreach { p =>
+ val reader = new VectorizedParquetRecordReader(
+ enableOffHeapColumnVector, vectorizedReaderBatchSize)
+ try {
+ reader.initialize(p, ("c1" :: "c2" :: Nil).asJava)
+ val batch = reader.resultBatch()
+ while (reader.nextBatch()) {
+ val rowIterator = batch.rowIterator()
+ while (rowIterator.hasNext) {
+ val row = rowIterator.next()
+ val value = row.getUTF8String(0)
+ if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes()
+ }
}
+ } finally {
+ reader.close()
}
- } finally {
- reader.close()
}
}
}
@@ -498,7 +699,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
output = output)
withTempPath { dir =>
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") {
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") {
import spark.implicits._
val middle = width / 2
val selectExpr = (1 to width).map(i => s"value as c$i")
@@ -515,13 +716,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
spark.sql(s"SELECT sum(c$middle) FROM jsonTable").noop()
}
- benchmark.addCase("SQL Parquet Vectorized") { _ =>
- spark.sql(s"SELECT sum(c$middle) FROM parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ =>
+ spark.sql(s"SELECT sum(c$middle) FROM parquet${version}Table").noop()
+ }
}
- benchmark.addCase("SQL Parquet MR") { _ =>
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- spark.sql(s"SELECT sum(c$middle) FROM parquetTable").noop()
+ withParquetVersions { version =>
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql(s"SELECT sum(c$middle) FROM parquet${version}Table").noop()
+ }
}
}
@@ -542,10 +747,18 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
runBenchmark("SQL Single Numeric Column Scan") {
- Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach {
+ Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach {
dataType => numericScanBenchmark(1024 * 1024 * 15, dataType)
}
}
+ runBenchmark("SQL Single Numeric Column Scan in Struct") {
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach {
+ dataType => nestedNumericScanBenchmark(1024 * 1024 * 15, dataType)
+ }
+ }
+ runBenchmark("SQL Nested Column Scan") {
+ nestedColumnScanBenchmark(1024 * 1024)
+ }
runBenchmark("Int and String Scan") {
intStringScanBenchmark(1024 * 1024 * 10)
}
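
Note on the Parquet cases above: each one is now registered once per data page version through a withParquetVersions helper whose definition is outside this excerpt. A minimal sketch of what such a helper is assumed to do (the object name and structure below are illustrative, not the patch's actual code): each invocation receives "V1" or "V2" and uses it to pick the matching parquetV1Table/parquetV2Table temp view and to label the benchmark case.

// Hedged sketch only: assumes the real helper simply iterates over the two
// Parquet data page versions and hands each suffix to the supplied closure.
object ParquetVersionsSketch {
  def withParquetVersions(f: String => Unit): Unit = Seq("V1", "V2").foreach(f)

  def main(args: Array[String]): Unit =
    withParquetVersions(version => println(s"registering case for DataPage$version"))
}
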
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala
index 405d60794e..77e26048e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala
@@ -66,7 +66,7 @@ trait DataSourceWriteBenchmark extends SqlBasedBenchmark {
}
}
- def runDataSourceBenchmark(format: String): Unit = {
+ def runDataSourceBenchmark(format: String, extraInfo: Option[String] = None): Unit = {
val tableInt = "tableInt"
val tableDouble = "tableDouble"
val tableIntString = "tableIntString"
@@ -75,7 +75,12 @@ trait DataSourceWriteBenchmark extends SqlBasedBenchmark {
withTempTable(tempTable) {
spark.range(numRows).createOrReplaceTempView(tempTable)
withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) {
- val benchmark = new Benchmark(s"$format writer benchmark", numRows, output = output)
+ val writerName = extraInfo match {
+ case Some(extra) => s"$format($extra)"
+ case _ => format
+ }
+ val benchmark =
+ new Benchmark(s"$writerName writer benchmark", numRows, output = output)
writeNumeric(tableInt, format, benchmark, "Int")
writeNumeric(tableDouble, format, benchmark, "Double")
writeIntString(tableIntString, format, benchmark)
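
The new optional extraInfo parameter only changes the benchmark title: Some(extra) yields "format(extra) writer benchmark", while None keeps the bare format name. A standalone sketch of that title logic with illustrative call sites (the actual call sites are not part of this excerpt):

// Self-contained sketch of the naming change above; arguments are examples only.
object WriterTitleSketch {
  def title(format: String, extraInfo: Option[String] = None): String = {
    val writerName = extraInfo match {
      case Some(extra) => s"$format($extra)"
      case _ => format
    }
    s"$writerName writer benchmark"
  }

  def main(args: Array[String]): Unit = {
    println(title("Parquet", Some("V2"))) // Parquet(V2) writer benchmark
    println(title("ORC"))                 // ORC writer benchmark
  }
}
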
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala
index 4e42330088..918f665238 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala
@@ -61,7 +61,8 @@ object DateTimeBenchmark extends SqlBasedBenchmark {
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
withDefaultTimeZone(LA) {
- withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId) {
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId,
+ SQLConf.LEGACY_INTERVAL_ENABLED.key -> true.toString) {
val N = 10000000
runBenchmark("datetime +/- interval") {
val benchmark = new Benchmark("datetime +/- interval", N, output = output)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala
index 849c413072..787fdc7b59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala
@@ -44,7 +44,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v"))
codegenBenchmark("Join w long", N) {
val df = spark.range(N).join(dim, (col("id") % M) === col("k"))
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec]))
df.noop()
}
}
@@ -55,7 +55,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val dim = broadcast(spark.range(M).selectExpr("cast(id/10 as long) as k"))
codegenBenchmark("Join w long duplicated", N) {
val df = spark.range(N).join(dim, (col("id") % M) === col("k"))
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec]))
df.noop()
}
}
@@ -70,7 +70,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val df = spark.range(N).join(dim2,
(col("id") % M).cast(IntegerType) === col("k1")
&& (col("id") % M).cast(IntegerType) === col("k2"))
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec]))
df.noop()
}
}
@@ -84,7 +84,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
codegenBenchmark("Join w 2 longs", N) {
val df = spark.range(N).join(dim3,
(col("id") % M) === col("k1") && (col("id") % M) === col("k2"))
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec]))
df.noop()
}
}
@@ -98,7 +98,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
codegenBenchmark("Join w 2 longs duplicated", N) {
val df = spark.range(N).join(dim4,
(col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec]))
df.noop()
}
}
@@ -109,7 +109,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v"))
codegenBenchmark("outer join w long", N) {
val df = spark.range(N).join(dim, (col("id") % M) === col("k"), "left")
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec]))
df.noop()
}
}
@@ -120,7 +120,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v"))
codegenBenchmark("semi join w long", N) {
val df = spark.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi")
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec]))
df.noop()
}
}
@@ -131,7 +131,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val df1 = spark.range(N).selectExpr(s"id * 2 as k1")
val df2 = spark.range(N).selectExpr(s"id * 3 as k2")
val df = df1.join(df2, col("k1") === col("k2"))
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[SortMergeJoinExec]))
df.noop()
}
}
@@ -144,7 +144,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val df2 = spark.range(N)
.selectExpr(s"(id * 15485867) % ${N*10} as k2")
val df = df1.join(df2, col("k1") === col("k2"))
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[SortMergeJoinExec]))
df.noop()
}
}
@@ -159,7 +159,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val df1 = spark.range(N).selectExpr(s"id as k1")
val df2 = spark.range(N / 3).selectExpr(s"id * 3 as k2")
val df = df1.join(df2, col("k1") === col("k2"))
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[ShuffledHashJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[ShuffledHashJoinExec]))
df.noop()
}
}
@@ -172,8 +172,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v"))
codegenBenchmark("broadcast nested loop join", N) {
val df = spark.range(N).join(dim)
- assert(df.queryExecution.sparkPlan.find(
- _.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined)
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastNestedLoopJoinExec]))
df.noop()
}
}
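
All of the JoinBenchmark changes above are the same mechanical refactor: plan.find(p).isDefined becomes the shorter, equivalent plan.exists(p). A small self-contained illustration of the equivalence on an ordinary Scala collection (the benchmark applies the same idea to SparkPlan trees):

// find(p).isDefined and exists(p) answer the same question; exists skips the Option.
object ExistsVsFindSketch {
  def main(args: Array[String]): Unit = {
    val xs = Seq(1, 2, 3, 4)
    val viaFind   = xs.find(_ % 2 == 0).isDefined
    val viaExists = xs.exists(_ % 2 == 0)
    assert(viaFind == viaExists)
    println(s"viaFind=$viaFind, viaExists=$viaExists") // both true
  }
}
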
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala
index e9bdff5853..31d5fd9ffd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala
@@ -49,7 +49,7 @@ object RangeBenchmark extends SqlBasedBenchmark {
}
benchmark.addCase("filter after range", numIters = 4) { _ =>
- spark.range(N).filter('id % 100 === 0).noop()
+ spark.range(N).filter(Symbol("id") % 100 === 0).noop()
}
benchmark.addCase("count after range", numIters = 4) { _ =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala
index f84172278b..78d6b01580 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.benchmark
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.internal.config.MAX_RESULT_SIZE
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.SaveMode.Overwrite
@@ -41,6 +42,7 @@ trait SqlBasedBenchmark extends BenchmarkBase with SQLHelper {
.config(SQLConf.SHUFFLE_PARTITIONS.key, 1)
.config(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, 1)
.config(UI_ENABLED.key, false)
+ .config(MAX_RESULT_SIZE.key, "3g")
.getOrCreate()
}
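
Raising MAX_RESULT_SIZE (the typed config entry for spark.driver.maxResultSize) to 3g presumably keeps benchmarks that collect results on the driver from hitting the 1g default. A hedged, standalone sketch of the equivalent plain-key setting, outside the benchmark harness:

// Illustrative only: sets the same key the typed MAX_RESULT_SIZE entry resolves to.
import org.apache.spark.sql.SparkSession

object MaxResultSizeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("max-result-size-sketch")
      .config("spark.driver.maxResultSize", "3g")
      .getOrCreate()
    println(spark.sparkContext.getConf.get("spark.driver.maxResultSize")) // 3g
    spark.stop()
  }
}
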
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala
new file mode 100644
index 0000000000..a98c8d8a23
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampType}
+import org.apache.spark.util.Utils
+
+/**
+ * Synthetic benchmark for State Store basic operations.
+ * To run this benchmark:
+ * {{{
+ * 1. without sbt:
+ * bin/spark-submit --class <this class>
+ * --jars <spark core test jar>,<spark catalyst test jar> <sql core test jar>
+ * 2. build/sbt "sql/test:runMain <this class>"
+ * 3. generate result:
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ * Results will be written to "benchmarks/StateStoreBasicOperationsBenchmark-results.txt".
+ * }}}
+ */
+object StateStoreBasicOperationsBenchmark extends SqlBasedBenchmark {
+
+ private val keySchema = StructType(
+ Seq(StructField("key1", IntegerType, true), StructField("key2", TimestampType, true)))
+ private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+ private val keyProjection = UnsafeProjection.create(keySchema)
+ private val valueProjection = UnsafeProjection.create(valueSchema)
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runPutBenchmark()
+ runDeleteBenchmark()
+ runEvictBenchmark()
+ }
+
+ final def skip(benchmarkName: String)(func: => Any): Unit = {
+ output.foreach(_.write(s"$benchmarkName is skipped".getBytes))
+ }
+
+ private def runPutBenchmark(): Unit = {
+ def registerPutBenchmarkCase(
+ benchmark: Benchmark,
+ testName: String,
+ provider: StateStoreProvider,
+ version: Long,
+ rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = {
+ benchmark.addTimerCase(testName) { timer =>
+ val store = provider.getStore(version)
+
+ timer.startTiming()
+ updateRows(store, rows)
+ timer.stopTiming()
+
+ store.abort()
+ }
+ }
+
+ runBenchmark("put rows") {
+ val numOfRows = Seq(10000)
+ val overwriteRates = Seq(100, 75, 50, 25, 10, 5, 0)
+
+ numOfRows.foreach { numOfRow =>
+ val testData = constructRandomizedTestData(numOfRow,
+ (1 to numOfRow).map(_ * 1000L).toList, 0)
+
+ val inMemoryProvider = newHDFSBackedStateStoreProvider()
+ val rocksDBProvider = newRocksDBStateProvider()
+ val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false)
+
+ val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData)
+ val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData)
+ val committedRocksDBWithNoTrackVersion = loadInitialData(
+ rocksDBWithNoTrackProvider, testData)
+
+ overwriteRates.foreach { overwriteRate =>
+ val numOfRowsToOverwrite = numOfRow * overwriteRate / 100
+
+ val numOfNewRows = numOfRow - numOfRowsToOverwrite
+ val newRows = if (numOfNewRows > 0) {
+ constructRandomizedTestData(numOfNewRows,
+ (1 to numOfNewRows).map(_ * 1000L).toList, 0)
+ } else {
+ Seq.empty[(UnsafeRow, UnsafeRow)]
+ }
+ val existingRows = if (numOfRowsToOverwrite > 0) {
+ Random.shuffle(testData).take(numOfRowsToOverwrite)
+ } else {
+ Seq.empty[(UnsafeRow, UnsafeRow)]
+ }
+ val rowsToPut = Random.shuffle(newRows ++ existingRows)
+
+ val benchmark = new Benchmark(s"putting $numOfRow rows " +
+ s"($numOfRowsToOverwrite rows to overwrite - rate $overwriteRate)",
+ numOfRow, minNumIters = 10000, output = output)
+
+ registerPutBenchmarkCase(benchmark, "In-memory", inMemoryProvider,
+ committedInMemoryVersion, rowsToPut)
+ registerPutBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)",
+ rocksDBProvider, committedRocksDBVersion, rowsToPut)
+ registerPutBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)",
+ rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion, rowsToPut)
+
+ benchmark.run()
+ }
+
+ inMemoryProvider.close()
+ rocksDBProvider.close()
+ rocksDBWithNoTrackProvider.close()
+ }
+ }
+ }
+
+ private def runDeleteBenchmark(): Unit = {
+ def registerDeleteBenchmarkCase(
+ benchmark: Benchmark,
+ testName: String,
+ provider: StateStoreProvider,
+ version: Long,
+ keys: Seq[UnsafeRow]): Unit = {
+ benchmark.addTimerCase(testName) { timer =>
+ val store = provider.getStore(version)
+
+ timer.startTiming()
+ deleteRows(store, keys)
+ timer.stopTiming()
+
+ store.abort()
+ }
+ }
+
+ runBenchmark("delete rows") {
+ val numOfRows = Seq(10000)
+ val nonExistRates = Seq(100, 75, 50, 25, 10, 5, 0)
+ numOfRows.foreach { numOfRow =>
+ val testData = constructRandomizedTestData(numOfRow,
+ (1 to numOfRow).map(_ * 1000L).toList, 0)
+
+ val inMemoryProvider = newHDFSBackedStateStoreProvider()
+ val rocksDBProvider = newRocksDBStateProvider()
+ val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false)
+
+ val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData)
+ val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData)
+ val committedRocksDBWithNoTrackVersion = loadInitialData(
+ rocksDBWithNoTrackProvider, testData)
+
+ nonExistRates.foreach { nonExistRate =>
+ val numOfRowsNonExist = numOfRow * nonExistRate / 100
+
+ val numOfExistingRows = numOfRow - numOfRowsNonExist
+ val nonExistingRows = if (numOfRowsNonExist > 0) {
+ constructRandomizedTestData(numOfRowsNonExist,
+ (numOfRow + 1 to numOfRow + numOfRowsNonExist).map(_ * 1000L).toList, 0)
+ } else {
+ Seq.empty[(UnsafeRow, UnsafeRow)]
+ }
+ val existingRows = if (numOfExistingRows > 0) {
+ Random.shuffle(testData).take(numOfExistingRows)
+ } else {
+ Seq.empty[(UnsafeRow, UnsafeRow)]
+ }
+ val keysToDelete = Random.shuffle(nonExistingRows ++ existingRows).map(_._1)
+
+ val benchmark = new Benchmark(s"trying to delete $numOfRow rows " +
+            s"from $numOfRow rows " +
+ s"($numOfRowsNonExist rows are non-existing - rate $nonExistRate)",
+ numOfRow, minNumIters = 10000, output = output)
+
+ registerDeleteBenchmarkCase(benchmark, "In-memory", inMemoryProvider,
+ committedInMemoryVersion, keysToDelete)
+ registerDeleteBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)",
+ rocksDBProvider, committedRocksDBVersion, keysToDelete)
+ registerDeleteBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)",
+ rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion, keysToDelete)
+
+ benchmark.run()
+ }
+
+ inMemoryProvider.close()
+ rocksDBProvider.close()
+ rocksDBWithNoTrackProvider.close()
+ }
+ }
+ }
+
+ private def runEvictBenchmark(): Unit = {
+ def registerEvictBenchmarkCase(
+ benchmark: Benchmark,
+ testName: String,
+ provider: StateStoreProvider,
+ version: Long,
+ maxTimestampToEvictInMillis: Long,
+ expectedNumOfRows: Long): Unit = {
+ benchmark.addTimerCase(testName) { timer =>
+ val store = provider.getStore(version)
+
+ timer.startTiming()
+ evictAsFullScanAndRemove(store, maxTimestampToEvictInMillis,
+ expectedNumOfRows)
+ timer.stopTiming()
+
+ store.abort()
+ }
+ }
+
+ runBenchmark("evict rows") {
+ val numOfRows = Seq(10000)
+ val numOfEvictionRates = Seq(100, 75, 50, 25, 10, 5, 0)
+
+ numOfRows.foreach { numOfRow =>
+ val timestampsInMicros = (0L until numOfRow).map(ts => ts * 1000L).toList
+
+ val testData = constructRandomizedTestData(numOfRow, timestampsInMicros, 0)
+
+ val inMemoryProvider = newHDFSBackedStateStoreProvider()
+ val rocksDBProvider = newRocksDBStateProvider()
+ val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false)
+
+ val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData)
+ val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData)
+ val committedRocksDBWithNoTrackVersion = loadInitialData(
+ rocksDBWithNoTrackProvider, testData)
+
+ numOfEvictionRates.foreach { numOfEvictionRate =>
+ val numOfRowsToEvict = numOfRow * numOfEvictionRate / 100
+ val maxTimestampToEvictInMillis = timestampsInMicros
+ .take(numOfRow * numOfEvictionRate / 100)
+ .lastOption.map(_ / 1000).getOrElse(-1L)
+
+ val benchmark = new Benchmark(s"evicting $numOfRowsToEvict rows " +
+ s"(maxTimestampToEvictInMillis: $maxTimestampToEvictInMillis) " +
+ s"from $numOfRow rows",
+ numOfRow, minNumIters = 10000, output = output)
+
+ registerEvictBenchmarkCase(benchmark, "In-memory", inMemoryProvider,
+ committedInMemoryVersion, maxTimestampToEvictInMillis, numOfRowsToEvict)
+
+ registerEvictBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)",
+ rocksDBProvider, committedRocksDBVersion, maxTimestampToEvictInMillis,
+ numOfRowsToEvict)
+
+ registerEvictBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)",
+ rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion,
+ maxTimestampToEvictInMillis, numOfRowsToEvict)
+
+ benchmark.run()
+ }
+
+ inMemoryProvider.close()
+ rocksDBProvider.close()
+ rocksDBWithNoTrackProvider.close()
+ }
+ }
+ }
+
+ private def getRows(store: StateStore, keys: Seq[UnsafeRow]): Seq[UnsafeRow] = {
+ keys.map(store.get)
+ }
+
+ private def loadInitialData(
+ provider: StateStoreProvider,
+ data: Seq[(UnsafeRow, UnsafeRow)]): Long = {
+ val store = provider.getStore(0)
+ updateRows(store, data)
+ store.commit()
+ }
+
+ private def updateRows(
+ store: StateStore,
+ rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = {
+ rows.foreach { case (key, value) =>
+ store.put(key, value)
+ }
+ }
+
+ private def deleteRows(
+ store: StateStore,
+ rows: Seq[UnsafeRow]): Unit = {
+ rows.foreach { key =>
+ store.remove(key)
+ }
+ }
+
+ private def evictAsFullScanAndRemove(
+ store: StateStore,
+ maxTimestampToEvictMillis: Long,
+ expectedNumOfRows: Long): Unit = {
+ var removedRows: Long = 0
+ store.iterator().foreach { r =>
+ if (r.key.getLong(1) <= maxTimestampToEvictMillis * 1000L) {
+ store.remove(r.key)
+ removedRows += 1
+ }
+ }
+ assert(removedRows == expectedNumOfRows,
+ s"expected: $expectedNumOfRows actual: $removedRows")
+ }
+
+  // This prevents created keys from being in order, which may affect performance on RocksDB.
+ private def constructRandomizedTestData(
+ numRows: Int,
+ timestamps: List[Long],
+ minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = {
+ assert(numRows >= timestamps.length)
+ assert(numRows % timestamps.length == 0)
+
+ (1 to numRows).map { idx =>
+ val keyRow = new GenericInternalRow(2)
+ keyRow.setInt(0, Random.nextInt(Int.MaxValue))
+ keyRow.setLong(1, timestamps((minIdx + idx) % timestamps.length)) // microseconds
+ val valueRow = new GenericInternalRow(1)
+ valueRow.setInt(0, minIdx + idx)
+
+ val keyUnsafeRow = keyProjection(keyRow).copy()
+ val valueUnsafeRow = valueProjection(valueRow).copy()
+
+ (keyUnsafeRow, valueUnsafeRow)
+ }
+ }
+
+ private def newHDFSBackedStateStoreProvider(): StateStoreProvider = {
+ val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
+ val provider = new HDFSBackedStateStoreProvider()
+ val storeConf = new StateStoreConf(new SQLConf())
+ provider.init(
+ storeId, keySchema, valueSchema, 0,
+ storeConf, new Configuration)
+ provider
+ }
+
+ private def newRocksDBStateProvider(
+ trackTotalNumberOfRows: Boolean = true): StateStoreProvider = {
+ val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
+ val provider = new RocksDBStateStoreProvider()
+ val sqlConf = new SQLConf()
+ sqlConf.setConfString("spark.sql.streaming.stateStore.rocksdb.trackTotalNumberOfRows",
+ trackTotalNumberOfRows.toString)
+ val storeConf = new StateStoreConf(sqlConf)
+
+ provider.init(
+ storeId, keySchema, valueSchema, 0,
+ storeConf, new Configuration)
+ provider
+ }
+
+ private def newDir(): String = Utils.createTempDir().toString
+}
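
In the new benchmark, key timestamps are generated in microseconds while the eviction bound is kept in milliseconds, so evictAsFullScanAndRemove scales the bound by 1000 before comparing. A self-contained sketch of that unit handling for the 10,000-row, 25% eviction configuration (the numbers mirror one case above; the sketch itself is illustrative only):

// Mirrors the micro/milli conversion used by the eviction benchmark above.
object EvictionThresholdSketch {
  def main(args: Array[String]): Unit = {
    val numOfRow = 10000
    val evictionRate = 25
    val timestampsInMicros = (0 until numOfRow).map(_ * 1000L)

    val maxTimestampToEvictInMillis = timestampsInMicros
      .take(numOfRow * evictionRate / 100)
      .lastOption.map(_ / 1000).getOrElse(-1L)

    val evicted = timestampsInMicros.count(_ <= maxTimestampToEvictInMillis * 1000L)
    println(s"threshold=${maxTimestampToEvictInMillis}ms, evicted=$evicted rows") // 2500 rows
  }
}
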
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala
index 099a1aa996..645dc870d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.execution.columnar
+import scala.collection.JavaConverters._
+
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection}
import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer}
import org.apache.spark.sql.execution.columnar.InMemoryRelation.clearSerializer
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -101,7 +103,13 @@ class TestSingleIntColumnarCachedBatchSerializer extends CachedBatchSerializer {
cacheAttributes: Seq[Attribute],
selectedAttributes: Seq[Attribute],
conf: SQLConf): RDD[InternalRow] = {
- throw new IllegalStateException("This does not work. This is only for testing")
+ convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf)
+ .mapPartitionsInternal { batches =>
+ val toUnsafe = UnsafeProjection.create(selectedAttributes, selectedAttributes)
+ batches.flatMap { batch =>
+ batch.rowIterator().asScala.map(toUnsafe)
+ }
+ }
}
override def buildFilter(
@@ -138,8 +146,9 @@ class CachedBatchSerializerSuite extends QueryTest with SharedSparkSession {
input.write.parquet(workDirPath)
val data = spark.read.parquet(workDirPath)
data.cache()
- assert(data.count() == 3)
- checkAnswer(data, Row(100) :: Row(200) :: Row(300) :: Nil)
+ val df = data.union(data)
+ assert(df.count() == 6)
+ checkAnswer(df, Row(100) :: Row(200) :: Row(300) :: Row(100) :: Row(200) :: Row(300) :: Nil)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala
new file mode 100644
index 0000000000..55d9fb2731
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.columnar
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.execution.ColumnarToRowExec
+import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
+
+/**
+ * Benchmark for scanning an in-memory (cached) columnar relation, comparing columnar
+ * deserialization plus columnar-to-row conversion against row-based deserialization.
+ * To run this benchmark:
+ * {{{
+ * 1. without sbt:
+ * bin/spark-submit --class <this class>
+ * --jars <spark core test jar>,<spark catalyst test jar> <spark sql test jar> <rowsNum>
+ * 2. build/sbt "sql/Test/runMain <this class> <rowsNum>"
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>
+ * <rowsNum>"
+ * Results will be written to "benchmarks/InMemoryColumnarBenchmark-results.txt".
+ * }}}
+ */
+object InMemoryColumnarBenchmark extends SqlBasedBenchmark {
+ def intCache(rowsNum: Long, numIters: Int): Unit = {
+ val data = spark.range(0, rowsNum, 1, 1).toDF("i").cache()
+
+ val inMemoryScan = data.queryExecution.executedPlan.collect {
+ case m: InMemoryTableScanExec => m
+ }
+
+    assert(inMemoryScan.size == 1)
+
+    val columnarScan = ColumnarToRowExec(inMemoryScan(0))
+    val rowScan = inMemoryScan(0)
+
+ val benchmark = new Benchmark("Int In-Memory scan", rowsNum, output = output)
+
+ benchmark.addCase("columnar deserialization + columnar-to-row", numIters) { _ =>
+ columnarScan.executeCollect()
+ }
+
+ benchmark.addCase("row-based deserialization", numIters) { _ =>
+ rowScan.executeCollect()
+ }
+
+ benchmark.run()
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ val rowsNum = if (mainArgs.length > 0) mainArgs(0).toLong else 1000000
+ runBenchmark(s"Int In-memory with $rowsNum rows") {
+ intCache(rowsNum = rowsNum, numIters = 3)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index b8f73f4563..779aa49a34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
+import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
@@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, In}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.columnar.CachedBatch
-import org.apache.spark.sql.execution.{ColumnarToRowExec, FilterExec, InputAdapter, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, InputAdapter, WholeStageCodegenExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -152,7 +153,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
}
test("projection") {
- val logicalPlan = testData.select('value, 'key).logicalPlan
+ val logicalPlan = testData.select(Symbol("value"), Symbol("key")).logicalPlan
val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan
val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5),
MEMORY_ONLY, plan, None, logicalPlan)
@@ -504,8 +505,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
val df2 = df1.where("y = 3")
val planBeforeFilter = df2.queryExecution.executedPlan.collect {
- case FilterExec(_, c: ColumnarToRowExec) => c.child
- case WholeStageCodegenExec(FilterExec(_, ColumnarToRowExec(i: InputAdapter))) => i.child
+ case f: FilterExec => f.child
+ case WholeStageCodegenExec(FilterExec(_, i: InputAdapter)) => i.child
}
assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec])
@@ -563,4 +564,56 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
}
}
}
+
+ test("SPARK-39104: InMemoryRelation#isCachedColumnBuffersLoaded should be thread-safe") {
+ val plan = spark.range(1).queryExecution.executedPlan
+ val serializer = new TestCachedBatchSerializer(true, 1)
+ val cachedRDDBuilder = CachedRDDBuilder(serializer, MEMORY_ONLY, plan, None)
+
+ @volatile var isCachedColumnBuffersLoaded = false
+ @volatile var stopped = false
+
+ val th1 = new Thread {
+ override def run(): Unit = {
+ while (!isCachedColumnBuffersLoaded && !stopped) {
+ cachedRDDBuilder.cachedColumnBuffers
+ cachedRDDBuilder.clearCache()
+ }
+ }
+ }
+
+ val th2 = new Thread {
+ override def run(): Unit = {
+ while (!isCachedColumnBuffersLoaded && !stopped) {
+ isCachedColumnBuffersLoaded = cachedRDDBuilder.isCachedColumnBuffersLoaded
+ }
+ }
+ }
+
+ val th3 = new Thread {
+ override def run(): Unit = {
+ Thread.sleep(3000L)
+ stopped = true
+ }
+ }
+
+ val exceptionCnt = new AtomicInteger
+ val exceptionHandler: Thread.UncaughtExceptionHandler = (_: Thread, cause: Throwable) => {
+ exceptionCnt.incrementAndGet
+ fail(cause)
+ }
+
+ th1.setUncaughtExceptionHandler(exceptionHandler)
+ th2.setUncaughtExceptionHandler(exceptionHandler)
+ th1.start()
+ th2.start()
+ th3.start()
+ th1.join()
+ th2.join()
+ th3.join()
+
+ cachedRDDBuilder.clearCache()
+
+ assert(exceptionCnt.get == 0)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala
new file mode 100644
index 0000000000..bc1ffb93fe
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedNamespace}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan
+import org.apache.spark.sql.catalyst.plans.logical.SetNamespaceLocation
+
+class AlterNamespaceSetLocationParserSuite extends AnalysisTest {
+ test("set namespace location") {
+ comparePlans(
+ parsePlan("ALTER DATABASE a.b.c SET LOCATION '/home/user/db'"),
+ SetNamespaceLocation(
+ UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db"))
+
+ comparePlans(
+ parsePlan("ALTER SCHEMA a.b.c SET LOCATION '/home/user/db'"),
+ SetNamespaceLocation(
+ UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db"))
+
+ comparePlans(
+ parsePlan("ALTER NAMESPACE a.b.c SET LOCATION '/home/user/db'"),
+ SetNamespaceLocation(
+ UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala
new file mode 100644
index 0000000000..25bae01821
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.connector.catalog.SupportsNamespaces
+
+/**
+ * This base suite contains unified tests for the `ALTER NAMESPACE ... SET LOCATION` command that
+ * check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are located
+ * in more specific test suites:
+ *
+ * - V2 table catalog tests:
+ * `org.apache.spark.sql.execution.command.v2.AlterNamespaceSetLocationSuite`
+ * - V1 table catalog tests:
+ * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetLocationSuiteBase`
+ * - V1 In-Memory catalog:
+ * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetLocationSuite`
+ * - V1 Hive External catalog:
+ * `org.apache.spark.sql.hive.execution.command.AlterNamespaceSetLocationSuite`
+ */
+trait AlterNamespaceSetLocationSuiteBase extends QueryTest with DDLCommandTestUtils {
+ override val command = "ALTER NAMESPACE ... SET LOCATION"
+
+ protected def namespace: String
+
+ protected def notFoundMsgPrefix: String
+
+ test("Empty location string") {
+ val ns = s"$catalog.$namespace"
+ withNamespace(ns) {
+ sql(s"CREATE NAMESPACE $ns")
+ val message = intercept[IllegalArgumentException] {
+ sql(s"ALTER NAMESPACE $ns SET LOCATION ''")
+ }.getMessage
+ assert(message.contains("Can not create a Path from an empty string"))
+ }
+ }
+
+ test("Namespace does not exist") {
+ val ns = "not_exist"
+ val message = intercept[AnalysisException] {
+ sql(s"ALTER DATABASE $catalog.$ns SET LOCATION 'loc'")
+ }.getMessage
+ assert(message.contains(s"$notFoundMsgPrefix '$ns' not found"))
+ }
+
+ // Hive catalog does not support "ALTER NAMESPACE ... SET LOCATION", thus
+ // this is called from non-Hive v1 and v2 tests.
+ protected def runBasicTest(): Unit = {
+ val ns = s"$catalog.$namespace"
+ withNamespace(ns) {
+ sql(s"CREATE NAMESPACE IF NOT EXISTS $ns COMMENT " +
+ "'test namespace' LOCATION '/tmp/loc_test_1'")
+ sql(s"ALTER NAMESPACE $ns SET LOCATION '/tmp/loc_test_2'")
+ assert(getLocation(ns).contains("file:/tmp/loc_test_2"))
+ }
+ }
+
+ protected def getLocation(namespace: String): String = {
+ val locationRow = sql(s"DESCRIBE NAMESPACE EXTENDED $namespace")
+ .toDF("key", "value")
+ .where(s"key like '${SupportsNamespaces.PROP_LOCATION.capitalize}%'")
+ .collect()
+ assert(locationRow.length == 1)
+ locationRow(0).getString(1)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala
new file mode 100644
index 0000000000..868dc275b8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedNamespace}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.logical.SetNamespaceProperties
+
+class AlterNamespaceSetPropertiesParserSuite extends AnalysisTest {
+ test("set namespace properties") {
+ Seq("DATABASE", "SCHEMA", "NAMESPACE").foreach { nsToken =>
+ Seq("PROPERTIES", "DBPROPERTIES").foreach { propToken =>
+ comparePlans(
+ parsePlan(s"ALTER $nsToken a.b.c SET $propToken ('a'='a', 'b'='b', 'c'='c')"),
+ SetNamespaceProperties(
+ UnresolvedNamespace(Seq("a", "b", "c")), Map("a" -> "a", "b" -> "b", "c" -> "c")))
+
+ comparePlans(
+ parsePlan(s"ALTER $nsToken a.b.c SET $propToken ('a'='a')"),
+ SetNamespaceProperties(
+ UnresolvedNamespace(Seq("a", "b", "c")), Map("a" -> "a")))
+ }
+ }
+ }
+
+ test("property values must be set") {
+ val e = intercept[ParseException] {
+ parsePlan("ALTER NAMESPACE my_db SET PROPERTIES('key_without_value', 'key_with_value'='x')")
+ }
+ assert(e.getMessage.contains(
+ "Operation not allowed: Values must be specified for key(s): [key_without_value]"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala
new file mode 100644
index 0000000000..1351d09e03
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * This base suite contains unified tests for the `ALTER NAMESPACE ... SET PROPERTIES` command that
+ * check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are located
+ * in more specific test suites:
+ *
+ * - V2 table catalog tests:
+ * `org.apache.spark.sql.execution.command.v2.AlterNamespaceSetPropertiesSuite`
+ * - V1 table catalog tests:
+ * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetPropertiesSuiteBase`
+ * - V1 In-Memory catalog:
+ * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetPropertiesSuite`
+ * - V1 Hive External catalog:
+ * `org.apache.spark.sql.hive.execution.command.AlterNamespaceSetPropertiesSuite`
+ */
+trait AlterNamespaceSetPropertiesSuiteBase extends QueryTest with DDLCommandTestUtils {
+ override val command = "ALTER NAMESPACE ... SET PROPERTIES"
+
+ protected def namespace: String
+
+ protected def notFoundMsgPrefix: String
+
+ test("Namespace does not exist") {
+ val ns = "not_exist"
+ val message = intercept[AnalysisException] {
+ sql(s"ALTER DATABASE $catalog.$ns SET PROPERTIES ('d'='d')")
+ }.getMessage
+ assert(message.contains(s"$notFoundMsgPrefix '$ns' not found"))
+ }
+
+ test("basic test") {
+ val ns = s"$catalog.$namespace"
+ withNamespace(ns) {
+ sql(s"CREATE NAMESPACE $ns")
+ assert(getProperties(ns) === "")
+ sql(s"ALTER NAMESPACE $ns SET PROPERTIES ('a'='a', 'b'='b', 'c'='c')")
+ assert(getProperties(ns) === "((a,a), (b,b), (c,c))")
+ sql(s"ALTER DATABASE $ns SET PROPERTIES ('d'='d')")
+ assert(getProperties(ns) === "((a,a), (b,b), (c,c), (d,d))")
+ }
+ }
+
+ test("test with properties set while creating namespace") {
+ val ns = s"$catalog.$namespace"
+ withNamespace(ns) {
+ sql(s"CREATE NAMESPACE $ns WITH PROPERTIES ('a'='a','b'='b','c'='c')")
+ assert(getProperties(ns) === "((a,a), (b,b), (c,c))")
+ sql(s"ALTER NAMESPACE $ns SET PROPERTIES ('a'='b', 'b'='a')")
+ assert(getProperties(ns) === "((a,b), (b,a), (c,c))")
+ }
+ }
+
+ test("test reserved properties") {
+ import SupportsNamespaces._
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+ val ns = s"$catalog.$namespace"
+ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) {
+ CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key =>
+ withNamespace(ns) {
+ sql(s"CREATE NAMESPACE $ns")
+ val exception = intercept[ParseException] {
+ sql(s"ALTER NAMESPACE $ns SET PROPERTIES ('$key'='dummyVal')")
+ }
+ assert(exception.getMessage.contains(s"$key is a reserved namespace property"))
+ }
+ }
+ }
+ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "true")) {
+ CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key =>
+ withNamespace(ns) {
+ // Set the location explicitly because v2 catalog may not set the default location.
+ // Without this, `meta.get(key)` below may return null.
+ sql(s"CREATE NAMESPACE $ns LOCATION 'tmp/prop_test'")
+ assert(getProperties(ns) === "")
+ sql(s"ALTER NAMESPACE $ns SET PROPERTIES ('$key'='foo')")
+ assert(getProperties(ns) === "", s"$key is a reserved namespace property and ignored")
+ val meta = spark.sessionState.catalogManager.catalog(catalog)
+ .asNamespaceCatalog.loadNamespaceMetadata(namespace.split('.'))
+ assert(!meta.get(key).contains("foo"),
+ "reserved properties should not have side effects")
+ }
+ }
+ }
+ }
+
+ protected def getProperties(namespace: String): String = {
+ val propsRow = sql(s"DESCRIBE NAMESPACE EXTENDED $namespace")
+ .toDF("key", "value")
+ .where("key like 'Properties%'")
+ .collect()
+ assert(propsRow.length == 1)
+ propsRow(0).getString(1)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddColumnsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddColumnsSuiteBase.scala
new file mode 100644
index 0000000000..9a9b3378e8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddColumnsSuiteBase.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.{QueryTest, Row}
+
+/**
+ * This base suite contains unified tests for the `ALTER TABLE .. ADD COLUMNS` command that
+ * check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are
+ * located in more specific test suites:
+ *
+ * - V2 table catalog tests:
+ * `org.apache.spark.sql.execution.command.v2.AlterTableAddColumnsSuite`
+ * - V1 table catalog tests:
+ * `org.apache.spark.sql.execution.command.v1.AlterTableAddColumnsSuiteBase`
+ * - V1 In-Memory catalog:
+ * `org.apache.spark.sql.execution.command.v1.AlterTableAddColumnsSuite`
+ * - V1 Hive External catalog:
+ * `org.apache.spark.sql.hive.execution.command.AlterTableAddColumnsSuite`
+ */
+trait AlterTableAddColumnsSuiteBase extends QueryTest with DDLCommandTestUtils {
+ override val command = "ALTER TABLE .. ADD COLUMNS"
+
+ test("add an ANSI interval columns") {
+ assume(!catalogVersion.contains("Hive")) // Hive catalog doesn't support the interval types
+
+ withNamespaceAndTable("ns", "tbl") { t =>
+ sql(s"CREATE TABLE $t (id bigint) $defaultUsing")
+ sql(s"ALTER TABLE $t ADD COLUMNS (ym INTERVAL YEAR, dt INTERVAL HOUR)")
+ sql(s"INSERT INTO $t SELECT 0, INTERVAL '100' YEAR, INTERVAL '10' HOUR")
+ checkAnswer(
+ sql(s"SELECT id, ym, dt data FROM $t"),
+ Seq(Row(0, Period.ofYears(100), Duration.ofHours(10))))
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala
index e2e15917d7..dee14953c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution.command
-import org.apache.spark.sql.{AnalysisException, QueryTest}
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException
import org.apache.spark.sql.internal.SQLConf
@@ -189,4 +191,40 @@ trait AlterTableAddPartitionSuiteBase extends QueryTest with DDLCommandTestUtils
checkPartitions(t, Map("part" ->"2020-01-01"))
}
}
+
+ test("SPARK-37261: Add ANSI intervals as partition values") {
+ assume(!catalogVersion.contains("Hive")) // Hive catalog doesn't support the interval types
+
+ withNamespaceAndTable("ns", "tbl") { t =>
+ sql(
+ s"""CREATE TABLE $t (
+ | ym INTERVAL YEAR,
+ | dt INTERVAL DAY,
+ | data STRING) $defaultUsing
+ |PARTITIONED BY (ym, dt)""".stripMargin)
+ sql(
+ s"""ALTER TABLE $t ADD PARTITION (
+ | ym = INTERVAL '100' YEAR,
+ | dt = INTERVAL '10' DAY
+ |) LOCATION 'loc'""".stripMargin)
+
+ checkPartitions(t, Map("ym" -> "INTERVAL '100' YEAR", "dt" -> "INTERVAL '10' DAY"))
+ checkLocation(t, Map("ym" -> "INTERVAL '100' YEAR", "dt" -> "INTERVAL '10' DAY"), "loc")
+
+ sql(
+ s"""INSERT INTO $t PARTITION (
+ | ym = INTERVAL '100' YEAR,
+ | dt = INTERVAL '10' DAY) SELECT 'aaa'""".stripMargin)
+ sql(
+ s"""INSERT INTO $t PARTITION (
+ | ym = INTERVAL '1' YEAR,
+ | dt = INTERVAL '-1' DAY) SELECT 'bbb'""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT ym, dt, data FROM $t"),
+ Seq(
+ Row(Period.ofYears(100), Duration.ofDays(10), "aaa"),
+ Row(Period.ofYears(1), Duration.ofDays(-1), "bbb")))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala
index ebc1bd3468..394392299b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala
@@ -29,7 +29,7 @@ class AlterTableRecoverPartitionsParserSuite extends AnalysisTest with SharedSpa
val errMsg = intercept[ParseException] {
parsePlan("ALTER TABLE RECOVER PARTITIONS")
}.getMessage
- assert(errMsg.contains("no viable alternative at input 'ALTER TABLE RECOVER PARTITIONS'"))
+ assert(errMsg.contains("Syntax error at or near 'PARTITIONS'"))
}
test("recover partitions of a table") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala
index 1803ec0469..2942d61f7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala
@@ -136,4 +136,12 @@ trait AlterTableRenameSuiteBase extends QueryTest with DDLCommandTestUtils {
checkAnswer(spark.table(dst), Row(1, 2))
}
}
+
+ test("SPARK-38587: use formatted names") {
+ withNamespaceAndTable("CaseUpperCaseLower", "CaseUpperCaseLower") { t =>
+ sql(s"CREATE TABLE ${t}_Old (i int) $defaultUsing")
+ sql(s"ALTER TABLE ${t}_Old RENAME TO CaseUpperCaseLower.CaseUpperCaseLower")
+ assert(spark.table(t).isEmpty)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableReplaceColumnsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableReplaceColumnsSuiteBase.scala
new file mode 100644
index 0000000000..fed40767c2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableReplaceColumnsSuiteBase.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.{QueryTest, Row}
+
+/**
+ * This base suite contains unified tests for the `ALTER TABLE .. REPLACE COLUMNS` command that
+ * check the V2 table catalog. The tests that cannot run for all supported catalogs are
+ * located in more specific test suites:
+ *
+ * - V2 table catalog tests:
+ * `org.apache.spark.sql.execution.command.v2.AlterTableReplaceColumnsSuite`
+ */
+trait AlterTableReplaceColumnsSuiteBase extends QueryTest with DDLCommandTestUtils {
+ override val command = "ALTER TABLE .. REPLACE COLUMNS"
+
+ test("SPARK-37304: Replace columns by ANSI intervals") {
+ withNamespaceAndTable("ns", "tbl") { t =>
+ sql(s"CREATE TABLE $t (ym INTERVAL MONTH, dt INTERVAL HOUR, data STRING) $defaultUsing")
+ // TODO(SPARK-37303): Uncomment the command below after REPLACE COLUMNS is fixed
+ // sql(s"INSERT INTO $t SELECT INTERVAL '1' MONTH, INTERVAL '2' HOUR, 'abc'")
+ sql(
+ s"""
+ |ALTER TABLE $t REPLACE COLUMNS (
+ | new_ym INTERVAL YEAR,
+ | new_dt INTERVAL MINUTE,
+ | new_data INT)""".stripMargin)
+ sql(s"INSERT INTO $t SELECT INTERVAL '3' YEAR, INTERVAL '4' MINUTE, 5")
+
+ checkAnswer(
+ sql(s"SELECT new_ym, new_dt, new_data FROM $t"),
+ Seq(
+ Row(Period.ofYears(3), Duration.ofMinutes(4), 5)))
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala
index ba683c049a..f77b6336b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala
@@ -98,23 +98,75 @@ trait CharVarcharDDLTestBase extends QueryTest with SQLTestUtils {
}
}
- def checkTableSchemaTypeStr(expected: Seq[Row]): Unit = {
- checkAnswer(sql("desc t").selectExpr("data_type").where("data_type like '%char%'"), expected)
+ def checkTableSchemaTypeStr(table: String, expected: Seq[Row]): Unit = {
+ checkAnswer(
+ sql(s"desc $table").selectExpr("data_type").where("data_type like '%char%'"),
+ expected)
}
test("SPARK-33901: alter table add columns should not change original table's schema") {
withTable("t") {
sql(s"CREATE TABLE t(i CHAR(5), c VARCHAR(4)) USING $format")
sql("ALTER TABLE t ADD COLUMNS (d VARCHAR(5))")
- checkTableSchemaTypeStr(Seq(Row("char(5)"), Row("varchar(4)"), Row("varchar(5)")))
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(4)"), Row("varchar(5)")))
}
}
test("SPARK-33901: ctas should should not change table's schema") {
- withTable("t", "tt") {
- sql(s"CREATE TABLE tt(i CHAR(5), c VARCHAR(4)) USING $format")
- sql(s"CREATE TABLE t USING $format AS SELECT * FROM tt")
- checkTableSchemaTypeStr(Seq(Row("char(5)"), Row("varchar(4)")))
+ withTable("t1", "t2") {
+ sql(s"CREATE TABLE t1(i CHAR(5), c VARCHAR(4)) USING $format")
+ sql(s"CREATE TABLE t2 USING $format AS SELECT * FROM t1")
+ checkTableSchemaTypeStr("t2", Seq(Row("char(5)"), Row("varchar(4)")))
+ }
+ }
+
+ test("SPARK-37160: CREATE TABLE with CHAR_AS_VARCHAR") {
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(col CHAR(5)) USING $format")
+ checkTableSchemaTypeStr("t", Seq(Row("varchar(5)")))
+ }
+ }
+ }
+
+ test("SPARK-37160: CREATE TABLE AS SELECT with CHAR_AS_VARCHAR") {
+ withTable("t1", "t2") {
+ sql(s"CREATE TABLE t1(col CHAR(5)) USING $format")
+ checkTableSchemaTypeStr("t1", Seq(Row("char(5)")))
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
+ sql(s"CREATE TABLE t2 USING $format AS SELECT * FROM t1")
+ checkTableSchemaTypeStr("t2", Seq(Row("varchar(5)")))
+ }
+ }
+ }
+
+ test("SPARK-37160: ALTER TABLE ADD COLUMN with CHAR_AS_VARCHAR") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(col CHAR(5)) USING $format")
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)")))
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
+ sql("ALTER TABLE t ADD COLUMN c2 CHAR(10)")
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(10)")))
+ }
+ }
+ }
+
+ test("SPARK-33892: DESCRIBE COLUMN w/ char/varchar") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format")
+ checkAnswer(sql("desc t v").selectExpr("info_value").where("info_value like '%char%'"),
+ Row("varchar(3)"))
+ checkAnswer(sql("desc t c").selectExpr("info_value").where("info_value like '%char%'"),
+ Row("char(5)"))
+ }
+ }
+
+ test("SPARK-33892: SHOW CREATE TABLE w/ char/varchar") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format")
+ val rest = sql("SHOW CREATE TABLE t").head().getString(0)
+ assert(rest.contains("VARCHAR(3)"))
+ assert(rest.contains("CHAR(5)"))
}
}
}
@@ -130,7 +182,7 @@ class FileSourceCharVarcharDDLTestSuite extends CharVarcharDDLTestBase with Shar
withTable("t", "tt") {
sql(s"CREATE TABLE tt(i CHAR(5), c VARCHAR(4)) USING $format")
sql("CREATE TABLE t LIKE tt")
- checkTableSchemaTypeStr(Seq(Row("char(5)"), Row("varchar(4)")))
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(4)")))
}
}
@@ -140,7 +192,19 @@ class FileSourceCharVarcharDDLTestSuite extends CharVarcharDDLTestBase with Shar
sql(s"CREATE TABLE tt(i CHAR(5), c VARCHAR(4)) USING $format")
withView("t") {
sql("CREATE VIEW t AS SELECT * FROM tt")
- checkTableSchemaTypeStr(Seq(Row("char(5)"), Row("varchar(4)")))
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(4)")))
+ }
+ }
+ }
+
+  // TODO(SPARK-33902): MOVE TO SUPER CLASS AFTER THE TARGET TICKET IS RESOLVED
+ test("SPARK-37160: CREATE TABLE LIKE with CHAR_AS_VARCHAR") {
+ withTable("t1", "t2") {
+ sql(s"CREATE TABLE t1(col CHAR(5)) USING $format")
+ checkTableSchemaTypeStr("t1", Seq(Row("char(5)")))
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
+ sql(s"CREATE TABLE t2 LIKE t1")
+ checkTableSchemaTypeStr("t2", Seq(Row("varchar(5)")))
}
}
}
@@ -196,4 +260,40 @@ class DSV2CharVarcharDDLTestSuite extends CharVarcharDDLTestBase
assert(e.getMessage contains "char(4) cannot be cast to varchar(3)")
}
}
+
+ test("SPARK-37160: REPLACE TABLE with CHAR_AS_VARCHAR") {
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(col INT) USING $format")
+ sql(s"REPLACE TABLE t(col CHAR(5)) USING $format")
+ checkTableSchemaTypeStr("t", Seq(Row("varchar(5)")))
+ }
+ }
+ }
+
+ test("SPARK-37160: REPLACE TABLE AS SELECT with CHAR_AS_VARCHAR") {
+ withTable("t1", "t2") {
+ sql(s"CREATE TABLE t1(col CHAR(5)) USING $format")
+ checkTableSchemaTypeStr("t1", Seq(Row("char(5)")))
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
+ sql(s"CREATE TABLE t2(col INT) USING $format")
+ sql(s"REPLACE TABLE t2 AS SELECT * FROM t1")
+ checkTableSchemaTypeStr("t2", Seq(Row("varchar(5)")))
+ }
+ }
+ }
+
+ test("SPARK-37160: ALTER TABLE ALTER/REPLACE COLUMN with CHAR_AS_VARCHAR") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(col CHAR(5), c2 VARCHAR(10)) USING $format")
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(10)")))
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
+ sql("ALTER TABLE t ALTER c2 TYPE CHAR(20)")
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(20)")))
+
+ sql("ALTER TABLE t REPLACE COLUMNS (col CHAR(5))")
+ checkTableSchemaTypeStr("t", Seq(Row("varchar(5)")))
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala
new file mode 100644
index 0000000000..6c59512148
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedDBObjectName}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan
+import org.apache.spark.sql.catalyst.plans.logical.CreateNamespace
+
+class CreateNamespaceParserSuite extends AnalysisTest {
+ test("create namespace -- backward compatibility with DATABASE/DBPROPERTIES") {
+ val expected = CreateNamespace(
+ UnresolvedDBObjectName(Seq("a", "b", "c"), true),
+ ifNotExists = true,
+ Map(
+ "a" -> "a",
+ "b" -> "b",
+ "c" -> "c",
+ "comment" -> "namespace_comment",
+ "location" -> "/home/user/db"))
+
+ comparePlans(
+ parsePlan(
+ """
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c
+ |WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ |COMMENT 'namespace_comment' LOCATION '/home/user/db'
+ """.stripMargin),
+ expected)
+
+ comparePlans(
+ parsePlan(
+ """
+ |CREATE DATABASE IF NOT EXISTS a.b.c
+ |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ |COMMENT 'namespace_comment' LOCATION '/home/user/db'
+ """.stripMargin),
+ expected)
+ }
+
+ test("create namespace -- check duplicates") {
+ def createNamespace(duplicateClause: String): String = {
+ s"""
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c
+ |$duplicateClause
+ |$duplicateClause
+ """.stripMargin
+ }
+ val sql1 = createNamespace("COMMENT 'namespace_comment'")
+ val sql2 = createNamespace("LOCATION '/home/user/db'")
+ val sql3 = createNamespace("WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')")
+ val sql4 = createNamespace("WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')")
+
+ intercept(sql1, "Found duplicate clauses: COMMENT")
+ intercept(sql2, "Found duplicate clauses: LOCATION")
+ intercept(sql3, "Found duplicate clauses: WITH PROPERTIES")
+ intercept(sql4, "Found duplicate clauses: WITH DBPROPERTIES")
+ }
+
+ test("create namespace - property values must be set") {
+ intercept(
+ "CREATE NAMESPACE a.b.c WITH PROPERTIES('key_without_value', 'key_with_value'='x')",
+ "Operation not allowed: Values must be specified for key(s): [key_without_value]")
+ }
+
+ test("create namespace -- either PROPERTIES or DBPROPERTIES is allowed") {
+ val sql =
+ s"""
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c
+ |WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ """.stripMargin
+ intercept(sql, "The feature is not supported: " +
+ "set PROPERTIES and DBPROPERTIES at the same time.")
+ }
+
+ test("create namespace - support for other types in PROPERTIES") {
+ val sql =
+ """
+ |CREATE NAMESPACE a.b.c
+ |LOCATION '/home/user/db'
+ |WITH PROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE)
+ """.stripMargin
+ comparePlans(
+ parsePlan(sql),
+ CreateNamespace(
+ UnresolvedDBObjectName(Seq("a", "b", "c"), true),
+ ifNotExists = false,
+ Map(
+ "a" -> "1",
+ "b" -> "0.1",
+ "c" -> "true",
+ "location" -> "/home/user/db")))
+ }
+
+ private def intercept(sqlCommand: String, messages: String*): Unit =
+ interceptParseException(parsePlan)(sqlCommand, messages: _*)()
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala
new file mode 100644
index 0000000000..7db8fba8ac
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Util, SupportsNamespaces}
+import org.apache.spark.sql.execution.command.DDLCommandTestUtils.V1_COMMAND_VERSION
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * This base suite contains unified tests for the `CREATE NAMESPACE` command that check V1 and V2
+ * table catalogs. The tests that cannot run for all supported catalogs are located in more
+ * specific test suites:
+ *
+ * - V2 table catalog tests: `org.apache.spark.sql.execution.command.v2.CreateNamespaceSuite`
+ * - V1 table catalog tests:
+ * `org.apache.spark.sql.execution.command.v1.CreateNamespaceSuiteBase`
+ * - V1 In-Memory catalog: `org.apache.spark.sql.execution.command.v1.CreateNamespaceSuite`
+ * - V1 Hive External catalog:
+ *     `org.apache.spark.sql.hive.execution.command.CreateNamespaceSuite`
+ */
+trait CreateNamespaceSuiteBase extends QueryTest with DDLCommandTestUtils {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override val command = "CREATE NAMESPACE"
+
+ protected def namespace: String
+
+ protected def namespaceArray: Array[String] = namespace.split('.')
+
+ protected def notFoundMsgPrefix: String =
+ if (commandVersion == V1_COMMAND_VERSION) "Database" else "Namespace"
+
+ test("basic") {
+ val ns = s"$catalog.$namespace"
+ withNamespace(ns) {
+ sql(s"CREATE NAMESPACE $ns")
+ assert(getCatalog(catalog).asNamespaceCatalog.namespaceExists(namespaceArray))
+ }
+ }
+
+ test("namespace with location") {
+ val ns = s"$catalog.$namespace"
+ withNamespace(ns) {
+ withTempDir { tmpDir =>
+ // The generated temp path is not qualified.
+ val path = tmpDir.getCanonicalPath
+ assert(!path.startsWith("file:/"))
+
+ val e = intercept[IllegalArgumentException] {
+ sql(s"CREATE NAMESPACE $ns LOCATION ''")
+ }
+ assert(e.getMessage.contains("Can not create a Path from an empty string"))
+
+ val uri = new Path(path).toUri
+ sql(s"CREATE NAMESPACE $ns LOCATION '$uri'")
+
+ // Make sure the location is qualified.
+ val expected = makeQualifiedPath(tmpDir.toString)
+ assert("file" === expected.getScheme)
+ assert(new Path(getNamespaceLocation(catalog, namespaceArray)).toUri === expected)
+ }
+ }
+ }
+
+ test("Namespace already exists") {
+ val ns = s"$catalog.$namespace"
+ withNamespace(ns) {
+ sql(s"CREATE NAMESPACE $ns")
+
+ val e = intercept[NamespaceAlreadyExistsException] {
+ sql(s"CREATE NAMESPACE $ns")
+ }
+ assert(e.getMessage.contains(s"$notFoundMsgPrefix '$namespace' already exists"))
+
+      // The following will be a no-op since the namespace already exists.
+ sql(s"CREATE NAMESPACE IF NOT EXISTS $ns")
+ }
+ }
+
+ test("CreateNameSpace: reserved properties") {
+ import SupportsNamespaces._
+ val ns = s"$catalog.$namespace"
+ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) {
+ CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key =>
+ val exception = intercept[ParseException] {
+ sql(s"CREATE NAMESPACE $ns WITH DBPROPERTIES('$key'='dummyVal')")
+ }
+ assert(exception.getMessage.contains(s"$key is a reserved namespace property"))
+ }
+ }
+ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "true")) {
+ CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key =>
+ withNamespace(ns) {
+ sql(s"CREATE NAMESPACE $ns WITH DBPROPERTIES('$key'='foo')")
+ assert(sql(s"DESC NAMESPACE EXTENDED $ns")
+ .toDF("k", "v")
+ .where("k='Properties'")
+ .where("v=''")
+ .count == 1, s"$key is a reserved namespace property and ignored")
+ val meta =
+ getCatalog(catalog).asNamespaceCatalog.loadNamespaceMetadata(namespaceArray)
+ assert(meta.get(key) == null || !meta.get(key).contains("foo"),
+ "reserved properties should not have side effects")
+ }
+ }
+ }
+ }
+
+ protected def getNamespaceLocation(catalog: String, namespace: Array[String]): String = {
+ val metadata = getCatalog(catalog).asNamespaceCatalog
+ .loadNamespaceMetadata(namespace).asScala
+ metadata(SupportsNamespaces.PROP_LOCATION)
+ }
+
+ private def getCatalog(name: String): CatalogPlugin = {
+ spark.sessionState.catalogManager.catalog(name)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala
index f9e26f8277..39f2abd35c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala
@@ -39,7 +39,9 @@ import org.apache.spark.sql.test.SQLTestUtils
*/
trait DDLCommandTestUtils extends SQLTestUtils {
// The version of the catalog under testing such as "V1", "V2", "Hive V1".
- protected def version: String
+ protected def catalogVersion: String
+ // The version of the SQL command under testing such as "V1", "V2".
+ protected def commandVersion: String
// Name of the command as SQL statement, for instance "SHOW PARTITIONS"
protected def command: String
// The catalog name which can be used in SQL statements under testing
@@ -51,7 +53,8 @@ trait DDLCommandTestUtils extends SQLTestUtils {
// the failed test in logs belongs to.
override def test(testName: String, testTags: Tag*)(testFun: => Any)
(implicit pos: Position): Unit = {
- super.test(s"$command $version: " + testName, testTags: _*)(testFun)
+ val testNamePrefix = s"$command using $catalogVersion catalog $commandVersion command"
+ super.test(s"$testNamePrefix: $testName", testTags: _*)(testFun)
}
protected def withNamespaceAndTable(ns: String, tableName: String, cat: String = catalog)
@@ -170,3 +173,8 @@ trait DDLCommandTestUtils extends SQLTestUtils {
part1Loc
}
}
+
+object DDLCommandTestUtils {
+ val V1_COMMAND_VERSION = "V1"
+ val V2_COMMAND_VERSION = "V2"
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
index 6c337e3e82..7bf2b8ff04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.command
import java.util.Locale
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, GlobalTempView, LocalTempView, UnresolvedAttribute, UnresolvedDBObjectName, UnresolvedFunc}
+import org.apache.spark.sql.catalyst.catalog.{ArchiveResource, FileResource, FunctionResource, JarResource}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
@@ -31,6 +32,7 @@ import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.test.SharedSparkSession
class DDLParserSuite extends AnalysisTest with SharedSparkSession {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
private lazy val parser = new SparkSqlParser()
private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = {
@@ -43,15 +45,18 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
}
}
+ private def intercept(sqlCommand: String, messages: String*): Unit =
+ interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)()
+
private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = {
val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null)
comparePlans(plan, expected, checkAnalysis = false)
}
- test("alter database - property values must be set") {
- assertUnsupported(
- sql = "ALTER DATABASE my_db SET DBPROPERTIES('key_without_value', 'key_with_value'='x')",
- containsThesePhrases = Seq("key_without_value"))
+ test("show current namespace") {
+ comparePlans(
+ parser.parsePlan("SHOW CURRENT NAMESPACE"),
+ ShowCurrentNamespaceCommand())
}
test("insert overwrite directory") {
@@ -195,21 +200,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
assert(parsed.isInstanceOf[Project])
}
- test("duplicate keys in table properties") {
- val e = intercept[ParseException] {
- parser.parsePlan("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('key1' = '1', 'key1' = '2')")
- }.getMessage
- assert(e.contains("Found duplicate keys 'key1'"))
- }
-
- test("duplicate columns in partition specs") {
- val e = intercept[ParseException] {
- parser.parsePlan(
- "ALTER TABLE dbx.tab1 PARTITION (a='1', a='2') RENAME TO PARTITION (a='100', a='200')")
- }.getMessage
- assert(e.contains("Found duplicate keys 'a'"))
- }
-
test("unsupported operations") {
intercept[ParseException] {
parser.parsePlan(
@@ -283,12 +273,12 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
val s = ScriptTransformation("func", Seq.empty, p, null)
compareTransformQuery("select transform(a, b) using 'func' from e where f < 10",
- s.copy(child = p.copy(child = p.child.where('f < 10)),
- output = Seq('key.string, 'value.string)))
+ s.copy(child = p.copy(child = p.child.where(Symbol("f") < 10)),
+ output = Seq(Symbol("key").string, Symbol("value").string)))
compareTransformQuery("map a, b using 'func' as c, d from e",
- s.copy(output = Seq('c.string, 'd.string)))
+ s.copy(output = Seq(Symbol("c").string, Symbol("d").string)))
compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e",
- s.copy(output = Seq('c.int, 'd.decimal(10, 0))))
+ s.copy(output = Seq(Symbol("c").int, Symbol("d").decimal(10, 0))))
}
test("use backticks in output of Script Transform") {
@@ -319,6 +309,182 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
""".stripMargin)
}
+ test("create view -- basic") {
+ val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1"
+ val parsed1 = parser.parsePlan(v1)
+
+ val expected1 = CreateView(
+ UnresolvedDBObjectName(Seq("view1"), false),
+ Seq.empty[(String,