diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala | |
index 171e93c1bf..53662c6560 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala | |
@@ -17,6 +17,8 @@ | |
package org.apache.spark.sql | |
+import java.time.{Duration, Period} | |
+ | |
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal} | |
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals | |
import org.apache.spark.sql.catalyst.plans.logical.Aggregate | |
@@ -58,4 +60,30 @@ class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSpa | |
} | |
} | |
} | |
+ | |
+ test("SPARK-37138: Support Ansi Interval type in ApproxCountDistinctForIntervals") { | |
+ val table = "approx_count_distinct_for_ansi_intervals_tbl" | |
+ withTable(table) { | |
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)), | |
+ (Period.ofMonths(200), Duration.ofSeconds(200L)), | |
+ (Period.ofMonths(300), Duration.ofSeconds(300L))) | |
+ .toDF("col1", "col2").createOrReplaceTempView(table) | |
+ val endpoints = (0 to 5).map(_ / 10) | |
+ | |
+ val relation = spark.table(table).logicalPlan | |
+ val ymAttr = relation.output.find(_.name == "col1").get | |
+ val ymAggFunc = | |
+ ApproxCountDistinctForIntervals(ymAttr, CreateArray(endpoints.map(Literal(_)))) | |
+ val ymAggExpr = ymAggFunc.toAggregateExpression() | |
+ val ymNamedExpr = Alias(ymAggExpr, ymAggExpr.toString)() | |
+ | |
+ val dtAttr = relation.output.find(_.name == "col2").get | |
+ val dtAggFunc = | |
+ ApproxCountDistinctForIntervals(dtAttr, CreateArray(endpoints.map(Literal(_)))) | |
+ val dtAggExpr = dtAggFunc.toAggregateExpression() | |
+ val dtNamedExpr = Alias(dtAggExpr, dtAggExpr.toString)() | |
+ val result = Dataset.ofRows(spark, Aggregate(Nil, Seq(ymNamedExpr, dtNamedExpr), relation)) | |
+ checkAnswer(result, Row(Array(1, 1, 1, 1, 1), Array(1, 1, 1, 1, 1))) | |
+ } | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala | |
index 5ff15c9710..9237c9e948 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala | |
@@ -18,7 +18,7 @@ | |
package org.apache.spark.sql | |
import java.sql.{Date, Timestamp} | |
-import java.time.LocalDateTime | |
+import java.time.{Duration, LocalDateTime, Period} | |
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile | |
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY | |
@@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession | |
class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession { | |
import testImplicits._ | |
- private val table = "percentile_test" | |
+ private val table = "percentile_approx" | |
test("percentile_approx, single percentile value") { | |
withTempView(table) { | |
@@ -319,4 +319,22 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession | |
Row(18, 17, 17, 17)) | |
} | |
} | |
+ | |
+ test("SPARK-37138: Support Ansi Interval type in ApproximatePercentile") { | |
+ withTempView(table) { | |
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)), | |
+ (Period.ofMonths(200), Duration.ofSeconds(200L)), | |
+ (Period.ofMonths(300), Duration.ofSeconds(300L))) | |
+ .toDF("col1", "col2").createOrReplaceTempView(table) | |
+ checkAnswer( | |
+ spark.sql( | |
+ s"""SELECT | |
+ | percentile_approx(col1, 0.5), | |
+ | SUM(null), | |
+ | percentile_approx(col2, 0.5) | |
+ |FROM $table | |
+ """.stripMargin), | |
+ Row(Period.ofMonths(200).normalized(), null, Duration.ofSeconds(200L))) | |
+ } | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala | |
new file mode 100644 | |
index 0000000000..05513cddcc | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala | |
@@ -0,0 +1,217 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql | |
+ | |
+import org.apache.spark.sql.catalyst.FunctionIdentifier | |
+import org.apache.spark.sql.catalyst.expressions._ | |
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate | |
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec | |
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec | |
+import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+ | |
+/** | |
+ * Query tests for the Bloom filter aggregate and filter function. | |
+ */ | |
+class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { | |
+ import testImplicits._ | |
+ | |
+ val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg") | |
+ val funcId_might_contain = new FunctionIdentifier("might_contain") | |
+ | |
+ override def beforeAll(): Unit = { | |
+ super.beforeAll() | |
+ // Register 'bloom_filter_agg' to builtin. | |
+ spark.sessionState.functionRegistry.registerFunction(funcId_bloom_filter_agg, | |
+ new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), | |
+ (children: Seq[Expression]) => children.size match { | |
+ case 1 => new BloomFilterAggregate(children.head) | |
+ case 2 => new BloomFilterAggregate(children.head, children(1)) | |
+ case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) | |
+ }) | |
+ | |
+ // Register 'might_contain' to builtin. | |
+ spark.sessionState.functionRegistry.registerFunction(funcId_might_contain, | |
+ new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"), | |
+ (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1))) | |
+ } | |
+ | |
+ override def afterAll(): Unit = { | |
+ spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg) | |
+ spark.sessionState.functionRegistry.dropFunction(funcId_might_contain) | |
+ super.afterAll() | |
+ } | |
+ | |
+ test("Test bloom_filter_agg and might_contain") { | |
+ val conf = SQLConf.get | |
+ val table = "bloom_filter_test" | |
+ for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, | |
+ conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) { | |
+ for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, | |
+ conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))) { | |
+ val sqlString = s""" | |
+ |SELECT every(might_contain( | |
+ | (SELECT bloom_filter_agg(col, | |
+ | cast($numEstimatedItems as long), | |
+ | cast($numBits as long)) | |
+ | FROM $table), | |
+ | col)) positive_membership_test, | |
+ | every(might_contain( | |
+ | (SELECT bloom_filter_agg(col, | |
+ | cast($numEstimatedItems as long), | |
+ | cast($numBits as long)) | |
+ | FROM values (-1L), (100001L), (20000L) as t(col)), | |
+ | col)) negative_membership_test | |
+ |FROM $table | |
+ """.stripMargin | |
+ withTempView(table) { | |
+ (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 10000L)) | |
+ .toDF("col").createOrReplaceTempView(table) | |
+ // Validate error messages as well as answers when there's no error. | |
+ if (numEstimatedItems <= 0) { | |
+ val exception = intercept[AnalysisException] { | |
+ spark.sql(sqlString) | |
+ } | |
+ assert(exception.getMessage.contains( | |
+ "The estimated number of items must be a positive value")) | |
+ } else if (numBits <= 0) { | |
+ val exception = intercept[AnalysisException] { | |
+ spark.sql(sqlString) | |
+ } | |
+ assert(exception.getMessage.contains("The number of bits must be a positive value")) | |
+ } else { | |
+ checkAnswer(spark.sql(sqlString), Row(true, false)) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Test that bloom_filter_agg errors out disallowed input value types") { | |
+ val exception1 = intercept[AnalysisException] { | |
+ spark.sql(""" | |
+ |SELECT bloom_filter_agg(a) | |
+ |FROM values (1.2), (2.5) as t(a)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception1.getMessage.contains( | |
+ "Input to function bloom_filter_agg should have been a bigint value")) | |
+ | |
+ val exception2 = intercept[AnalysisException] { | |
+ spark.sql(""" | |
+ |SELECT bloom_filter_agg(a, 2) | |
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception2.getMessage.contains( | |
+ "function bloom_filter_agg should have been a bigint value followed with two bigint")) | |
+ | |
+ val exception3 = intercept[AnalysisException] { | |
+ spark.sql(""" | |
+ |SELECT bloom_filter_agg(a, cast(2 as long), 5) | |
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception3.getMessage.contains( | |
+ "function bloom_filter_agg should have been a bigint value followed with two bigint")) | |
+ | |
+ val exception4 = intercept[AnalysisException] { | |
+ spark.sql(""" | |
+ |SELECT bloom_filter_agg(a, null, 5) | |
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception4.getMessage.contains("Null typed values cannot be used as size arguments")) | |
+ | |
+ val exception5 = intercept[AnalysisException] { | |
+ spark.sql(""" | |
+ |SELECT bloom_filter_agg(a, 5, null) | |
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception5.getMessage.contains("Null typed values cannot be used as size arguments")) | |
+ } | |
+ | |
+ test("Test that might_contain errors out disallowed input value types") { | |
+ val exception1 = intercept[AnalysisException] { | |
+ spark.sql("""|SELECT might_contain(1.0, 1L)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception1.getMessage.contains( | |
+ "Input to function might_contain should have been binary followed by a value with bigint")) | |
+ | |
+ val exception2 = intercept[AnalysisException] { | |
+ spark.sql("""|SELECT might_contain(NULL, 0.1)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception2.getMessage.contains( | |
+ "Input to function might_contain should have been binary followed by a value with bigint")) | |
+ } | |
+ | |
+ test("Test that might_contain errors out non-constant Bloom filter") { | |
+ val exception1 = intercept[AnalysisException] { | |
+ spark.sql(""" | |
+ |SELECT might_contain(cast(a as binary), cast(5 as long)) | |
+ |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception1.getMessage.contains( | |
+ "The Bloom filter binary input to might_contain should be either a constant value or " + | |
+ "a scalar subquery expression")) | |
+ | |
+ val exception2 = intercept[AnalysisException] { | |
+ spark.sql(""" | |
+ |SELECT might_contain((select cast(a as binary)), cast(5 as long)) | |
+ |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)""" | |
+ .stripMargin) | |
+ } | |
+ assert(exception2.getMessage.contains( | |
+ "The Bloom filter binary input to might_contain should be either a constant value or " + | |
+ "a scalar subquery expression")) | |
+ } | |
+ | |
+ test("Test that might_contain can take a constant value input") { | |
+ checkAnswer(spark.sql( | |
+ """SELECT might_contain( | |
+ |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', | |
+ |cast(201 as long))""".stripMargin), | |
+ Row(false)) | |
+ } | |
+ | |
+ test("Test that bloom_filter_agg produces a NULL with empty input") { | |
+ checkAnswer(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""), | |
+ Row(null)) | |
+ } | |
+ | |
+ test("Test NULL inputs for might_contain") { | |
+ checkAnswer(spark.sql( | |
+ s""" | |
+ |SELECT might_contain(null, null) both_null, | |
+ | might_contain(null, 1L) null_bf, | |
+ | might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)), | |
+ | null) null_value | |
+ """.stripMargin), | |
+ Row(null, null, null)) | |
+ } | |
+ | |
+ test("Test that a query with bloom_filter_agg has partial aggregates") { | |
+ assert(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""") | |
+ .queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].inputPlan | |
+ .collect({case agg: BaseAggregateExec => agg}).size == 2) | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEHintSuite.scala | |
index 13039bbbf6..a596ebc6b6 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/CTEHintSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEHintSuite.scala | |
@@ -17,7 +17,7 @@ | |
package org.apache.spark.sql | |
-import org.apache.log4j.Level | |
+import org.apache.logging.log4j.Level | |
import org.apache.spark.sql.catalyst.plans.logical._ | |
import org.apache.spark.sql.test.SharedSparkSession | |
@@ -65,7 +65,7 @@ class CTEHintSuite extends QueryTest with SharedSparkSession { | |
} | |
val warningMessages = logAppender.loggingEvents | |
.filter(_.getLevel == Level.WARN) | |
- .map(_.getRenderedMessage) | |
+ .map(_.getMessage.getFormattedMessage) | |
.filter(_.contains("hint")) | |
assert(warningMessages.size == warnings.size) | |
warnings.foreach { w => | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala | |
index 7ee533ac26..e758c6f8df 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala | |
@@ -17,7 +17,8 @@ | |
package org.apache.spark.sql | |
-import org.apache.spark.sql.catalyst.plans.logical.WithCTE | |
+import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, LessThan, Literal, Or} | |
+import org.apache.spark.sql.catalyst.plans.logical._ | |
import org.apache.spark.sql.execution.adaptive._ | |
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec | |
import org.apache.spark.sql.internal.SQLConf | |
@@ -42,7 +43,7 @@ abstract class CTEInlineSuiteBase | |
""".stripMargin) | |
checkAnswer(df, Nil) | |
assert( | |
- df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).nonEmpty, | |
+ df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]), | |
"Non-deterministic With-CTE with multiple references should be not inlined.") | |
} | |
} | |
@@ -59,7 +60,7 @@ abstract class CTEInlineSuiteBase | |
""".stripMargin) | |
checkAnswer(df, Nil) | |
assert( | |
- df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).nonEmpty, | |
+ df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]), | |
"Non-deterministic With-CTE with multiple references should be not inlined.") | |
} | |
} | |
@@ -76,10 +77,10 @@ abstract class CTEInlineSuiteBase | |
""".stripMargin) | |
checkAnswer(df, Row(0, 1) :: Row(1, 2) :: Nil) | |
assert( | |
- df.queryExecution.analyzed.find(_.isInstanceOf[WithCTE]).nonEmpty, | |
+ df.queryExecution.analyzed.exists(_.isInstanceOf[WithCTE]), | |
"With-CTE should not be inlined in analyzed plan.") | |
assert( | |
- df.queryExecution.optimizedPlan.find(_.isInstanceOf[WithCTE]).isEmpty, | |
+ !df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]), | |
"With-CTE with one reference should be inlined in optimized plan.") | |
} | |
} | |
@@ -107,8 +108,8 @@ abstract class CTEInlineSuiteBase | |
"With-CTE should contain 2 CTE defs after analysis.") | |
assert( | |
df.queryExecution.optimizedPlan.collect { | |
- case WithCTE(_, cteDefs) => cteDefs | |
- }.head.length == 2, | |
+ case r: RepartitionOperation => r | |
+ }.length == 6, | |
"With-CTE should contain 2 CTE def after optimization.") | |
} | |
} | |
@@ -136,8 +137,8 @@ abstract class CTEInlineSuiteBase | |
"With-CTE should contain 2 CTE defs after analysis.") | |
assert( | |
df.queryExecution.optimizedPlan.collect { | |
- case WithCTE(_, cteDefs) => cteDefs | |
- }.head.length == 1, | |
+ case r: RepartitionOperation => r | |
+ }.length == 4, | |
"One CTE def should be inlined after optimization.") | |
} | |
} | |
@@ -163,7 +164,7 @@ abstract class CTEInlineSuiteBase | |
"With-CTE should contain 2 CTE defs after analysis.") | |
assert( | |
df.queryExecution.optimizedPlan.collect { | |
- case WithCTE(_, cteDefs) => cteDefs | |
+ case r: RepartitionOperation => r | |
}.isEmpty, | |
"CTEs with one reference should all be inlined after optimization.") | |
} | |
@@ -248,7 +249,7 @@ abstract class CTEInlineSuiteBase | |
"With-CTE should contain 2 CTE defs after analysis.") | |
assert( | |
df.queryExecution.optimizedPlan.collect { | |
- case WithCTE(_, cteDefs) => cteDefs | |
+ case r: RepartitionOperation => r | |
}.isEmpty, | |
"Deterministic CTEs should all be inlined after optimization.") | |
} | |
@@ -272,6 +273,372 @@ abstract class CTEInlineSuiteBase | |
assert(ex.message.contains("Table or view not found: v1")) | |
} | |
} | |
+ | |
+ test("CTE Predicate push-down and column pruning") { | |
+ withView("t") { | |
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") | |
+ val df = sql( | |
+ s"""with | |
+ |v as ( | |
+ | select c1, c2, 's' c3, rand() c4 from t | |
+ |), | |
+ |vv as ( | |
+ | select v1.c1, v1.c2, rand() c5 from v v1, v v2 | |
+ | where v1.c1 > 0 and v1.c3 = 's' and v1.c2 = v2.c2 | |
+ |) | |
+ |select vv1.c1, vv1.c2, vv2.c1, vv2.c2 from vv vv1, vv vv2 | |
+ |where vv1.c2 > 0 and vv2.c2 > 0 and vv1.c1 = vv2.c1 | |
+ """.stripMargin) | |
+ checkAnswer(df, Row(1, 2, 1, 2) :: Nil) | |
+ assert( | |
+ df.queryExecution.analyzed.collect { | |
+ case WithCTE(_, cteDefs) => cteDefs | |
+ }.head.length == 2, | |
+ "With-CTE should contain 2 CTE defs after analysis.") | |
+ val cteRepartitions = df.queryExecution.optimizedPlan.collect { | |
+ case r: RepartitionOperation => r | |
+ } | |
+ assert(cteRepartitions.length == 6, | |
+ "CTE should not be inlined after optimization.") | |
+ val distinctCteRepartitions = cteRepartitions.map(_.canonicalized).distinct | |
+ // Check column pruning and predicate push-down. | |
+ assert(distinctCteRepartitions.length == 2) | |
+ assert(distinctCteRepartitions(1).collectFirst { | |
+ case p: Project if p.projectList.length == 3 => p | |
+ }.isDefined, "CTE columns should be pruned.") | |
+ assert(distinctCteRepartitions(1).collectFirst { | |
+ case f: Filter if f.condition.semanticEquals(GreaterThan(f.output(1), Literal(0))) => f | |
+ }.isDefined, "Predicate 'c2 > 0' should be pushed down to the CTE def 'v'.") | |
+ assert(distinctCteRepartitions(0).collectFirst { | |
+ case f: Filter if f.condition.find(_.semanticEquals(f.output(0))).isDefined => f | |
+ }.isDefined, "CTE 'vv' definition contains predicate 'c1 > 0'.") | |
+ assert(distinctCteRepartitions(1).collectFirst { | |
+ case f: Filter if f.condition.find(_.semanticEquals(f.output(0))).isDefined => f | |
+ }.isEmpty, "Predicate 'c1 > 0' should be not pushed down to the CTE def 'v'.") | |
+ // Check runtime repartition reuse. | |
+ assert( | |
+ collectWithSubqueries(df.queryExecution.executedPlan) { | |
+ case r: ReusedExchangeExec => r | |
+ }.length == 2, | |
+ "CTE repartition is reused.") | |
+ } | |
+ } | |
+ | |
+ test("CTE Predicate push-down and column pruning - combined predicate") { | |
+ withView("t") { | |
+ Seq((0, 1, 2), (1, 2, 3)).toDF("c1", "c2", "c3").createOrReplaceTempView("t") | |
+ val df = sql( | |
+ s"""with | |
+ |v as ( | |
+ | select c1, c2, c3, rand() c4 from t | |
+ |), | |
+ |vv as ( | |
+ | select v1.c1, v1.c2, rand() c5 from v v1, v v2 | |
+ | where v1.c1 > 0 and v2.c3 < 5 and v1.c2 = v2.c2 | |
+ |) | |
+ |select vv1.c1, vv1.c2, vv2.c1, vv2.c2 from vv vv1, vv vv2 | |
+ |where vv1.c2 > 0 and vv2.c2 > 0 and vv1.c1 = vv2.c1 | |
+ """.stripMargin) | |
+ checkAnswer(df, Row(1, 2, 1, 2) :: Nil) | |
+ assert( | |
+ df.queryExecution.analyzed.collect { | |
+ case WithCTE(_, cteDefs) => cteDefs | |
+ }.head.length == 2, | |
+ "With-CTE should contain 2 CTE defs after analysis.") | |
+ val cteRepartitions = df.queryExecution.optimizedPlan.collect { | |
+ case r: RepartitionOperation => r | |
+ } | |
+ assert(cteRepartitions.length == 6, | |
+ "CTE should not be inlined after optimization.") | |
+ val distinctCteRepartitions = cteRepartitions.map(_.canonicalized).distinct | |
+ // Check column pruning and predicate push-down. | |
+ assert(distinctCteRepartitions.length == 2) | |
+ assert(distinctCteRepartitions(1).collectFirst { | |
+ case p: Project if p.projectList.length == 3 => p | |
+ }.isDefined, "CTE columns should be pruned.") | |
+ assert( | |
+ distinctCteRepartitions(1).collectFirst { | |
+ case f: Filter | |
+ if f.condition.semanticEquals( | |
+ And( | |
+ GreaterThan(f.output(1), Literal(0)), | |
+ Or( | |
+ GreaterThan(f.output(0), Literal(0)), | |
+ LessThan(f.output(2), Literal(5))))) => | |
+ f | |
+ }.isDefined, | |
+ "Predicate 'c2 > 0 AND (c1 > 0 OR c3 < 5)' should be pushed down to the CTE def 'v'.") | |
+ // Check runtime repartition reuse. | |
+ assert( | |
+ collectWithSubqueries(df.queryExecution.executedPlan) { | |
+ case r: ReusedExchangeExec => r | |
+ }.length == 2, | |
+ "CTE repartition is reused.") | |
+ } | |
+ } | |
+ | |
+ test("Views with CTEs - 1 temp view") { | |
+ withView("t", "t2") { | |
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") | |
+ sql( | |
+ s"""with | |
+ |v as ( | |
+ | select c1 + c2 c3 from t | |
+ |) | |
+ |select sum(c3) s from v | |
+ """.stripMargin).createOrReplaceTempView("t2") | |
+ val df = sql( | |
+ s"""with | |
+ |v as ( | |
+ | select c1 * c2 c3 from t | |
+ |) | |
+ |select sum(c3) from v except select s from t2 | |
+ """.stripMargin) | |
+ checkAnswer(df, Row(2) :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("Views with CTEs - 2 temp views") { | |
+ withView("t", "t2", "t3") { | |
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") | |
+ sql( | |
+ s"""with | |
+ |v as ( | |
+ | select c1 + c2 c3 from t | |
+ |) | |
+ |select sum(c3) s from v | |
+ """.stripMargin).createOrReplaceTempView("t2") | |
+ sql( | |
+ s"""with | |
+ |v as ( | |
+ | select c1 * c2 c3 from t | |
+ |) | |
+ |select sum(c3) s from v | |
+ """.stripMargin).createOrReplaceTempView("t3") | |
+ val df = sql("select s from t3 except select s from t2") | |
+ checkAnswer(df, Row(2) :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("Views with CTEs - temp view + sql view") { | |
+ withTable("t") { | |
+ withView ("t2", "t3") { | |
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").write.saveAsTable("t") | |
+ sql( | |
+ s"""with | |
+ |v as ( | |
+ | select c1 + c2 c3 from t | |
+ |) | |
+ |select sum(c3) s from v | |
+ """.stripMargin).createOrReplaceTempView("t2") | |
+ sql( | |
+ s"""create view t3 as | |
+ |with | |
+ |v as ( | |
+ | select c1 * c2 c3 from t | |
+ |) | |
+ |select sum(c3) s from v | |
+ """.stripMargin) | |
+ val df = sql("select s from t3 except select s from t2") | |
+ checkAnswer(df, Row(2) :: Nil) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Union of Dataframes with CTEs") { | |
+ val a = spark.sql("with t as (select 1 as n) select * from t ") | |
+ val b = spark.sql("with t as (select 2 as n) select * from t ") | |
+ val df = a.union(b) | |
+ checkAnswer(df, Row(1) :: Row(2) :: Nil) | |
+ } | |
+ | |
+ test("CTE definitions out of original order when not inlined") { | |
+ withView("t1", "t2") { | |
+ Seq((1, 2, 10, 100), (2, 3, 20, 200)).toDF("workspace_id", "issue_id", "shard_id", "field_id") | |
+ .createOrReplaceTempView("issue_current") | |
+ withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> | |
+ "org.apache.spark.sql.catalyst.optimizer.InlineCTE") { | |
+ val df = sql( | |
+ """ | |
+ |WITH cte_0 AS ( | |
+ | SELECT workspace_id, issue_id, shard_id, field_id FROM issue_current | |
+ |), | |
+ |cte_1 AS ( | |
+ | WITH filtered_source_table AS ( | |
+ | SELECT * FROM cte_0 WHERE shard_id in ( 10 ) | |
+ | ) | |
+ | SELECT source_table.workspace_id, field_id FROM cte_0 source_table | |
+ | INNER JOIN ( | |
+ | SELECT workspace_id, issue_id FROM filtered_source_table GROUP BY 1, 2 | |
+ | ) target_table | |
+ | ON source_table.issue_id = target_table.issue_id | |
+ | AND source_table.workspace_id = target_table.workspace_id | |
+ | WHERE source_table.shard_id IN ( 10 ) | |
+ |) | |
+ |SELECT * FROM cte_1 | |
+ """.stripMargin) | |
+ checkAnswer(df, Row(1, 100) :: Nil) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Make sure CTESubstitution places WithCTE back in the plan correctly.") { | |
+ withView("t") { | |
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") | |
+ | |
+ // CTE on both sides of join - WithCTE placed over first common parent, i.e., the join. | |
+ val df1 = sql( | |
+ s""" | |
+ |select count(v1.c3), count(v2.c3) from ( | |
+ | with | |
+ | v1 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v1 | |
+ |) v1 join ( | |
+ | with | |
+ | v2 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v2 | |
+ |) v2 on v1.c1 = v2.c1 | |
+ """.stripMargin) | |
+ checkAnswer(df1, Row(2, 2) :: Nil) | |
+ df1.queryExecution.analyzed match { | |
+ case Aggregate(_, _, WithCTE(_, cteDefs)) => assert(cteDefs.length == 2) | |
+ case other => fail(s"Expect pattern Aggregate(WithCTE(_)) but got $other") | |
+ } | |
+ | |
+ // CTE on one side of join - WithCTE placed back where it was. | |
+ val df2 = sql( | |
+ s""" | |
+ |select count(v1.c3), count(v2.c3) from ( | |
+ | select c1, c2, rand() c3 from t | |
+ |) v1 join ( | |
+ | with | |
+ | v2 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v2 | |
+ |) v2 on v1.c1 = v2.c1 | |
+ """.stripMargin) | |
+ checkAnswer(df2, Row(2, 2) :: Nil) | |
+ df2.queryExecution.analyzed match { | |
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_, cteDefs)), _, _, _)) => | |
+ assert(cteDefs.length == 1) | |
+ case other => fail(s"Expect pattern Aggregate(Join(_, WithCTE(_))) but got $other") | |
+ } | |
+ | |
+ // CTE on one side of join and both sides of union - WithCTE placed on first common parent. | |
+ val df3 = sql( | |
+ s""" | |
+ |select count(v1.c3), count(v2.c3) from ( | |
+ | select c1, c2, rand() c3 from t | |
+ |) v1 join ( | |
+ | select * from ( | |
+ | with | |
+ | v1 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v1 | |
+ | ) | |
+ | union all | |
+ | select * from ( | |
+ | with | |
+ | v2 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v2 | |
+ | ) | |
+ |) v2 on v1.c1 = v2.c1 | |
+ """.stripMargin) | |
+ checkAnswer(df3, Row(4, 4) :: Nil) | |
+ df3.queryExecution.analyzed match { | |
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_: Union, cteDefs)), _, _, _)) => | |
+ assert(cteDefs.length == 2) | |
+ case other => fail( | |
+ s"Expect pattern Aggregate(Join(_, (WithCTE(Union(_, _))))) but got $other") | |
+ } | |
+ | |
+ // CTE on one side of join and one side of union - WithCTE placed back where it was. | |
+ val df4 = sql( | |
+ s""" | |
+ |select count(v1.c3), count(v2.c3) from ( | |
+ | select c1, c2, rand() c3 from t | |
+ |) v1 join ( | |
+ | select * from ( | |
+ | with | |
+ | v1 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v1 | |
+ | ) | |
+ | union all | |
+ | select c1, c2, rand() c3 from t | |
+ |) v2 on v1.c1 = v2.c1 | |
+ """.stripMargin) | |
+ checkAnswer(df4, Row(4, 4) :: Nil) | |
+ df4.queryExecution.analyzed match { | |
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, Union(children, _, _)), _, _, _)) | |
+ if children.head.find(_.isInstanceOf[WithCTE]).isDefined => | |
+ assert( | |
+ children.head.collect { | |
+ case w: WithCTE => w | |
+ }.head.cteDefs.length == 1) | |
+ case other => fail( | |
+ s"Expect pattern Aggregate(Join(_, (WithCTE(Union(_, _))))) but got $other") | |
+ } | |
+ | |
+ // CTE on both sides of join and one side of union - WithCTE placed on first common parent. | |
+ val df5 = sql( | |
+ s""" | |
+ |select count(v1.c3), count(v2.c3) from ( | |
+ | with | |
+ | v1 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v1 | |
+ |) v1 join ( | |
+ | select c1, c2, rand() c3 from t | |
+ | union all | |
+ | select * from ( | |
+ | with | |
+ | v2 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v2 | |
+ | ) | |
+ |) v2 on v1.c1 = v2.c1 | |
+ """.stripMargin) | |
+ checkAnswer(df5, Row(4, 4) :: Nil) | |
+ df5.queryExecution.analyzed match { | |
+ case Aggregate(_, _, WithCTE(_, cteDefs)) => assert(cteDefs.length == 2) | |
+ case other => fail(s"Expect pattern Aggregate(WithCTE(_)) but got $other") | |
+ } | |
+ | |
+ // CTE as root node - WithCTE placed back where it was. | |
+ val df6 = sql( | |
+ s""" | |
+ |with | |
+ |v1 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ |) | |
+ |select count(v1.c3), count(v2.c3) from | |
+ |v1 join ( | |
+ | with | |
+ | v2 as ( | |
+ | select c1, c2, rand() c3 from t | |
+ | ) | |
+ | select * from v2 | |
+ |) v2 on v1.c1 = v2.c1 | |
+ """.stripMargin) | |
+ checkAnswer(df6, Row(2, 2) :: Nil) | |
+ df6.queryExecution.analyzed match { | |
+ case WithCTE(_, cteDefs) => assert(cteDefs.length == 2) | |
+ case other => fail(s"Expect pattern WithCTE(_) but got $other") | |
+ } | |
+ } | |
+ } | |
} | |
class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with DisableAdaptiveExecutionSuite | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | |
index bad28152e4..4de409f56d 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | |
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException | |
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression | |
import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join, JoinStrategyHint, SHUFFLE_HASH} | |
import org.apache.spark.sql.catalyst.util.DateTimeConstants | |
-import org.apache.spark.sql.execution.{ExecSubqueryExpression, RDDScanExec, SparkPlan} | |
+import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan} | |
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
import org.apache.spark.sql.execution.columnar._ | |
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec | |
@@ -604,9 +604,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils | |
uncacheTable("t2") | |
} | |
- // One side of join is not partitioned in the desired way. Since the number of partitions of | |
- // the side that has already partitioned is smaller than the side that is not partitioned, | |
- // we shuffle both side. | |
+ // One side of join is not partitioned in the desired way. We'll only shuffle this side. | |
withTempView("t1", "t2") { | |
testData.repartition(6, $"value").createOrReplaceTempView("t1") | |
testData2.repartition(3, $"a").createOrReplaceTempView("t2") | |
@@ -614,7 +612,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils | |
spark.catalog.cacheTable("t2") | |
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") | |
- verifyNumExchanges(query, 2) | |
+ verifyNumExchanges(query, 1) | |
checkAnswer( | |
query, | |
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) | |
@@ -881,8 +879,10 @@ class CachedTableSuite extends QueryTest with SQLTestUtils | |
test("SPARK-23312: vectorized cache reader can be disabled") { | |
Seq(true, false).foreach { vectorized => | |
withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { | |
- val df = spark.range(10).cache() | |
- df.queryExecution.executedPlan.foreach { | |
+ val df1 = spark.range(10).cache() | |
+ val df2 = spark.range(10).cache() | |
+ val union = df1.union(df2) | |
+ union.queryExecution.executedPlan.foreach { | |
case i: InMemoryTableScanExec => | |
assert(i.supportsColumnar == vectorized) | |
case _ => | |
@@ -891,6 +891,19 @@ class CachedTableSuite extends QueryTest with SQLTestUtils | |
} | |
} | |
+ test("SPARK-37369: Avoid redundant ColumnarToRow transition on InMemoryTableScan") { | |
+ Seq(true, false).foreach { vectorized => | |
+ withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { | |
+ val cache = spark.range(10).cache() | |
+ val df = cache.filter($"id" > 0) | |
+ val columnarToRow = df.queryExecution.executedPlan.collect { | |
+ case c: ColumnarToRowExec => c | |
+ } | |
+ assert(columnarToRow.isEmpty) | |
+ } | |
+ } | |
+ } | |
+ | |
private def checkIfNoJobTriggered[T](f: => T): T = { | |
var numJobTriggered = 0 | |
val jobListener = new SparkListener { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala | |
index 7be54d49a9..978e3f8d36 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala | |
@@ -100,6 +100,19 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { | |
} | |
} | |
+ test("char type values should not be padded when charVarcharAsString is true") { | |
+ withSQLConf(SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key -> "true") { | |
+ withTable("t") { | |
+ sql(s"CREATE TABLE t(a STRING, b CHAR(5), c CHAR(5)) USING $format partitioned by (c)") | |
+ sql("INSERT INTO t VALUES ('abc', 'abc', 'abc')") | |
+ checkAnswer(sql("SELECT b FROM t WHERE b='abc'"), Row("abc")) | |
+ checkAnswer(sql("SELECT b FROM t WHERE b in ('abc')"), Row("abc")) | |
+ checkAnswer(sql("SELECT c FROM t WHERE c='abc'"), Row("abc")) | |
+ checkAnswer(sql("SELECT c FROM t WHERE c in ('abc')"), Row("abc")) | |
+ } | |
+ } | |
+ } | |
+ | |
test("varchar type values length check and trim: partitioned columns") { | |
(0 to 5).foreach { n => | |
// SPARK-34192: we need to create a a new table for each round of test because of | |
@@ -332,8 +345,8 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { | |
sql(s"CREATE TABLE t(c STRUCT<c: $typeName(5)>) USING $format") | |
sql("INSERT INTO t SELECT struct(null)") | |
checkAnswer(spark.table("t"), Row(Row(null))) | |
- val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')")) | |
- assert(e.getCause.getMessage.contains(s"Exceeds char/varchar type length limitation: 5")) | |
+ val e = intercept[RuntimeException](sql("INSERT INTO t SELECT struct('123456')")) | |
+ assert(e.getMessage.contains(s"Exceeds char/varchar type length limitation: 5")) | |
} | |
} | |
@@ -843,27 +856,6 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa | |
} | |
} | |
- // TODO(SPARK-33875): Move these tests to super after DESCRIBE COLUMN v2 implemented | |
- test("SPARK-33892: DESCRIBE COLUMN w/ char/varchar") { | |
- withTable("t") { | |
- sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format") | |
- checkAnswer(sql("desc t v").selectExpr("info_value").where("info_value like '%char%'"), | |
- Row("varchar(3)")) | |
- checkAnswer(sql("desc t c").selectExpr("info_value").where("info_value like '%char%'"), | |
- Row("char(5)")) | |
- } | |
- } | |
- | |
- // TODO(SPARK-33898): Move these tests to super after SHOW CREATE TABLE for v2 implemented | |
- test("SPARK-33892: SHOW CREATE TABLE w/ char/varchar") { | |
- withTable("t") { | |
- sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format") | |
- val rest = sql("SHOW CREATE TABLE t").head().getString(0) | |
- assert(rest.contains("VARCHAR(3)")) | |
- assert(rest.contains("CHAR(5)")) | |
- } | |
- } | |
- | |
test("SPARK-34114: should not trim right for read-side length check and char padding") { | |
Seq("char", "varchar").foreach { typ => | |
withTempPath { dir => | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | |
index e7ca431726..1f8dc6f80d 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | |
@@ -137,6 +137,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") | |
} | |
+ test("SPARK-34805: as propagates metadata from nested column") { | |
+ val metadata = new MetadataBuilder | |
+ metadata.putString("key", "value") | |
+ val df = spark.createDataFrame(sparkContext.emptyRDD[Row], | |
+ StructType(Seq( | |
+ StructField("parent", StructType(Seq( | |
+ StructField("child", StringType, metadata = metadata.build()) | |
+ )))) | |
+ )) | |
+ val newCol = df("parent.child") | |
+ assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") | |
+ } | |
+ | |
test("collect on column produced by a binary operator") { | |
val df = Seq((1, 2, 3)).toDF("a", "b", "c") | |
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) | |
@@ -281,9 +294,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
testData.select(isnan($"a"), isnan($"b")), | |
Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) | |
- checkAnswer( | |
- sql("select isnan(15), isnan('invalid')"), | |
- Row(false, false)) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer( | |
+ sql("select isnan(15), isnan('invalid')"), | |
+ Row(false, false)) | |
+ } | |
} | |
test("nanvl") { | |
@@ -932,7 +947,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39))) | |
} | |
+ test("SPARK-37646: lit") { | |
+ assert(lit($"foo") == $"foo") | |
+ assert(lit(Symbol("foo")) == $"foo") | |
+ assert(lit(1) == Column(Literal(1))) | |
+ assert(lit(null) == Column(Literal(null, NullType))) | |
+ } | |
+ | |
test("typedLit") { | |
+ assert(typedLit($"foo") == $"foo") | |
+ assert(typedLit(Symbol("foo")) == $"foo") | |
+ assert(typedLit(1) == Column(Literal(1))) | |
+ assert(typedLit[String](null) == Column(Literal(null, StringType))) | |
+ | |
val df = Seq(Tuple1(0)).toDF("a") | |
// Only check the types `lit` cannot handle | |
checkAnswer( | |
@@ -1017,17 +1044,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should throw an exception if any intermediate structs don't exist") { | |
intercept[AnalysisException] { | |
- structLevel2.withColumn("a", 'a.withField("x.b", lit(2))) | |
+ structLevel2.withColumn("a", Symbol("a").withField("x.b", lit(2))) | |
}.getMessage should include("No such struct field x in a") | |
intercept[AnalysisException] { | |
- structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2))) | |
+ structLevel3.withColumn("a", Symbol("a").withField("a.x.b", lit(2))) | |
}.getMessage should include("No such struct field x in a") | |
} | |
test("withField should throw an exception if intermediate field is not a struct") { | |
intercept[AnalysisException] { | |
- structLevel1.withColumn("a", 'a.withField("b.a", lit(2))) | |
+ structLevel1.withColumn("a", Symbol("a").withField("b.a", lit(2))) | |
}.getMessage should include("struct argument should be struct type, got: int") | |
} | |
@@ -1041,7 +1068,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
StructField("a", structType, nullable = false))), | |
nullable = false)))) | |
- structLevel2.withColumn("a", 'a.withField("a.b", lit(2))) | |
+ structLevel2.withColumn("a", Symbol("a").withField("a.b", lit(2))) | |
}.getMessage should include("Ambiguous reference to fields") | |
} | |
@@ -1060,7 +1087,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should add field to struct") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("d", lit(4))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("d", lit(4))), | |
Row(Row(1, null, 3, 4)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1101,7 +1128,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should add null field to struct") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("d", lit(null).cast(IntegerType))), | |
Row(Row(1, null, 3, null)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1114,7 +1141,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should add multiple fields to struct") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("d", lit(4)).withField("e", lit(5))), | |
Row(Row(1, null, 3, 4, 5)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1128,7 +1155,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should add multiple fields to nullable struct") { | |
checkAnswer( | |
- nullableStructLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), | |
+ nullableStructLevel1.withColumn("a", Symbol("a") | |
+ .withField("d", lit(4)).withField("e", lit(5))), | |
Row(null) :: Row(Row(1, null, 3, 4, 5)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1142,8 +1170,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should add field to nested struct") { | |
Seq( | |
- structLevel2.withColumn("a", 'a.withField("a.d", lit(4))), | |
- structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4)))) | |
+ structLevel2.withColumn("a", Symbol("a").withField("a.d", lit(4))), | |
+ structLevel2.withColumn("a", Symbol("a").withField("a", $"a.a".withField("d", lit(4)))) | |
).foreach { df => | |
checkAnswer( | |
df, | |
@@ -1204,7 +1232,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should add field to deeply nested struct") { | |
checkAnswer( | |
- structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))), | |
+ structLevel3.withColumn("a", Symbol("a").withField("a.a.d", lit(4))), | |
Row(Row(Row(Row(1, null, 3, 4)))) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1221,7 +1249,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace field in struct") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("b", lit(2))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), | |
Row(Row(1, 2, 3)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1233,7 +1261,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace field in nullable struct") { | |
checkAnswer( | |
- nullableStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), | |
+ nullableStructLevel1.withColumn("a", Symbol("a").withField("b", lit("foo"))), | |
Row(null) :: Row(Row(1, "foo", 3)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1259,7 +1287,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace field with null value in struct") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("c", lit(null).cast(IntegerType))), | |
Row(Row(1, null, null)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1271,7 +1299,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace multiple fields in struct") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("a", lit(10)).withField("b", lit(20))), | |
Row(Row(10, 20, 3)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1283,7 +1311,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace multiple fields in nullable struct") { | |
checkAnswer( | |
- nullableStructLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), | |
+ nullableStructLevel1.withColumn("a", Symbol("a").withField("a", lit(10)) | |
+ .withField("b", lit(20))), | |
Row(null) :: Row(Row(10, 20, 3)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1296,7 +1325,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace field in nested struct") { | |
Seq( | |
structLevel2.withColumn("a", $"a".withField("a.b", lit(2))), | |
- structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2)))) | |
+ structLevel2.withColumn("a", Symbol("a").withField("a", $"a.a".withField("b", lit(2)))) | |
).foreach { df => | |
checkAnswer( | |
df, | |
@@ -1377,7 +1406,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("b", lit(100))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("b", lit(100))), | |
Row(Row(1, 100, 100)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1389,7 +1418,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace fields in struct in given order") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("b", lit(2)).withField("b", lit(20))), | |
Row(Row(1, 20, 3)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1401,7 +1430,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should add field and then replace same field in struct") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))), | |
+ structLevel1.withColumn("a", Symbol("a").withField("d", lit(4)).withField("d", lit(5))), | |
Row(Row(1, null, 3, 5)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1425,7 +1454,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))), | |
+ df.withColumn("a", Symbol("a").withField("`a.b`.`e.f`", lit(2))), | |
Row(Row(Row(1, 2, 3))) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1437,7 +1466,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
intercept[AnalysisException] { | |
- df.withColumn("a", 'a.withField("a.b.e.f", lit(2))) | |
+ df.withColumn("a", Symbol("a").withField("a.b.e.f", lit(2))) | |
}.getMessage should include("No such struct field a in a.b") | |
} | |
@@ -1452,7 +1481,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace field in struct even if casing is different") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { | |
checkAnswer( | |
- mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), | |
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("A", lit(2))), | |
Row(Row(2, 1)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1461,7 +1490,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), | |
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), | |
Row(Row(1, 2)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1474,7 +1503,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should add field to struct because casing is different") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { | |
checkAnswer( | |
- mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), | |
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("A", lit(2))), | |
Row(Row(1, 1, 2)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1484,7 +1513,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), | |
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), | |
Row(Row(1, 1, 2)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1512,7 +1541,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should replace nested field in struct even if casing is different") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { | |
checkAnswer( | |
- mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))), | |
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("A.a", lit(2))), | |
Row(Row(Row(2, 1), Row(1, 1))) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1527,7 +1556,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))), | |
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("b.a", lit(2))), | |
Row(Row(Row(1, 1), Row(2, 1))) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1546,11 +1575,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("withField should throw an exception because casing is different") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { | |
intercept[AnalysisException] { | |
- mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))) | |
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("A.a", lit(2))) | |
}.getMessage should include("No such struct field A in a, B") | |
intercept[AnalysisException] { | |
- mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))) | |
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("b.a", lit(2))) | |
}.getMessage should include("No such struct field b in a, B") | |
} | |
} | |
@@ -1757,17 +1786,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should throw an exception if any intermediate structs don't exist") { | |
intercept[AnalysisException] { | |
- structLevel2.withColumn("a", 'a.dropFields("x.b")) | |
+ structLevel2.withColumn("a", Symbol("a").dropFields("x.b")) | |
}.getMessage should include("No such struct field x in a") | |
intercept[AnalysisException] { | |
- structLevel3.withColumn("a", 'a.dropFields("a.x.b")) | |
+ structLevel3.withColumn("a", Symbol("a").dropFields("a.x.b")) | |
}.getMessage should include("No such struct field x in a") | |
} | |
test("dropFields should throw an exception if intermediate field is not a struct") { | |
intercept[AnalysisException] { | |
- structLevel1.withColumn("a", 'a.dropFields("b.a")) | |
+ structLevel1.withColumn("a", Symbol("a").dropFields("b.a")) | |
}.getMessage should include("struct argument should be struct type, got: int") | |
} | |
@@ -1781,13 +1810,13 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
StructField("a", structType, nullable = false))), | |
nullable = false)))) | |
- structLevel2.withColumn("a", 'a.dropFields("a.b")) | |
+ structLevel2.withColumn("a", Symbol("a").dropFields("a.b")) | |
}.getMessage should include("Ambiguous reference to fields") | |
} | |
test("dropFields should drop field in struct") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.dropFields("b")), | |
+ structLevel1.withColumn("a", Symbol("a").dropFields("b")), | |
Row(Row(1, 3)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1810,7 +1839,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should drop multiple fields in struct") { | |
Seq( | |
structLevel1.withColumn("a", $"a".dropFields("b", "c")), | |
- structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c")) | |
+ structLevel1.withColumn("a", Symbol("a").dropFields("b").dropFields("c")) | |
).foreach { df => | |
checkAnswer( | |
df, | |
@@ -1824,7 +1853,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should throw an exception if no fields will be left in struct") { | |
intercept[AnalysisException] { | |
- structLevel1.withColumn("a", 'a.dropFields("a", "b", "c")) | |
+ structLevel1.withColumn("a", Symbol("a").dropFields("a", "b", "c")) | |
}.getMessage should include("cannot drop all fields in struct") | |
} | |
@@ -1848,7 +1877,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should drop field in nested struct") { | |
checkAnswer( | |
- structLevel2.withColumn("a", 'a.dropFields("a.b")), | |
+ structLevel2.withColumn("a", Symbol("a").dropFields("a.b")), | |
Row(Row(Row(1, 3))) :: Nil, | |
StructType( | |
Seq(StructField("a", StructType(Seq( | |
@@ -1861,7 +1890,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should drop multiple fields in nested struct") { | |
checkAnswer( | |
- structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")), | |
+ structLevel2.withColumn("a", Symbol("a").dropFields("a.b", "a.c")), | |
Row(Row(Row(1))) :: Nil, | |
StructType( | |
Seq(StructField("a", StructType(Seq( | |
@@ -1898,7 +1927,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should drop field in deeply nested struct") { | |
checkAnswer( | |
- structLevel3.withColumn("a", 'a.dropFields("a.a.b")), | |
+ structLevel3.withColumn("a", Symbol("a").dropFields("a.a.b")), | |
Row(Row(Row(Row(1, 3)))) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1922,7 +1951,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.dropFields("b")), | |
+ structLevel1.withColumn("a", Symbol("a").dropFields("b")), | |
Row(Row(1)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1933,7 +1962,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should drop field in struct even if casing is different") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { | |
checkAnswer( | |
- mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), | |
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("A")), | |
Row(Row(1)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1941,7 +1970,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), | |
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("b")), | |
Row(Row(1)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1953,7 +1982,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should not drop field in struct because casing is different") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { | |
checkAnswer( | |
- mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), | |
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("A")), | |
Row(Row(1, 1)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1962,7 +1991,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), | |
+ mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("b")), | |
Row(Row(1, 1)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1975,7 +2004,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should drop nested field in struct even if casing is different") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { | |
checkAnswer( | |
- mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")), | |
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("A.a")), | |
Row(Row(Row(1), Row(1, 1))) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -1989,7 +2018,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")), | |
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("b.a")), | |
Row(Row(Row(1, 1), Row(1))) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -2007,18 +2036,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
test("dropFields should throw an exception because casing is different") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { | |
intercept[AnalysisException] { | |
- mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")) | |
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("A.a")) | |
}.getMessage should include("No such struct field A in a, B") | |
intercept[AnalysisException] { | |
- mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")) | |
+ mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("b.a")) | |
}.getMessage should include("No such struct field b in a, B") | |
} | |
} | |
test("dropFields should drop only fields that exist") { | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.dropFields("d")), | |
+ structLevel1.withColumn("a", Symbol("a").dropFields("d")), | |
Row(Row(1, null, 3)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -2028,7 +2057,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
nullable = false)))) | |
checkAnswer( | |
- structLevel1.withColumn("a", 'a.dropFields("b", "d")), | |
+ structLevel1.withColumn("a", Symbol("a").dropFields("b", "d")), | |
Row(Row(1, 3)) :: Nil, | |
StructType(Seq( | |
StructField("a", StructType(Seq( | |
@@ -2737,19 +2766,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect() | |
}.getCause | |
assert(e.isInstanceOf[ArithmeticException]) | |
- assert(e.getMessage.contains("divide by zero")) | |
+ assert(e.getMessage.contains("Division by zero")) | |
val e2 = intercept[SparkException] { | |
Seq((Period.ofYears(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect() | |
}.getCause | |
assert(e2.isInstanceOf[ArithmeticException]) | |
- assert(e2.getMessage.contains("divide by zero")) | |
+ assert(e2.getMessage.contains("Division by zero")) | |
val e3 = intercept[SparkException] { | |
Seq((Period.ofYears(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect() | |
}.getCause | |
assert(e3.isInstanceOf[ArithmeticException]) | |
- assert(e3.getMessage.contains("divide by zero")) | |
+ assert(e3.getMessage.contains("Division by zero")) | |
} | |
test("SPARK-34875: divide day-time interval by numeric") { | |
@@ -2784,19 +2813,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
Seq((Duration.ofDays(9999), 0)).toDF("i", "n").select($"i" / $"n").collect() | |
}.getCause | |
assert(e.isInstanceOf[ArithmeticException]) | |
- assert(e.getMessage.contains("divide by zero")) | |
+ assert(e.getMessage.contains("Division by zero")) | |
val e2 = intercept[SparkException] { | |
Seq((Duration.ofDays(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect() | |
}.getCause | |
assert(e2.isInstanceOf[ArithmeticException]) | |
- assert(e2.getMessage.contains("divide by zero")) | |
+ assert(e2.getMessage.contains("Division by zero")) | |
val e3 = intercept[SparkException] { | |
Seq((Duration.ofDays(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect() | |
}.getCause | |
assert(e3.isInstanceOf[ArithmeticException]) | |
- assert(e3.getMessage.contains("divide by zero")) | |
+ assert(e3.getMessage.contains("Division by zero")) | |
} | |
test("SPARK-34896: return day-time interval from dates subtraction") { | |
@@ -2928,4 +2957,47 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { | |
} | |
} | |
} | |
+ | |
+ test("SPARK-36778: add ilike API for scala") { | |
+ // scalastyle:off | |
+ // Non-ASCII characters are not allowed in the code, so we disable scalastyle here. | |
+ // null handling | |
+ val nullDf = Seq("a", null).toDF("src") | |
+ checkAnswer(nullDf.filter($"src".ilike("A")), Row("a")) | |
+ checkAnswer(nullDf.filter($"src".ilike(null)), spark.emptyDataFrame) | |
+ // simple pattern | |
+ val simpleDf = Seq("a", "A", "abdef", "a_%b", "addb", "abC", "a\nb").toDF("src") | |
+ checkAnswer(simpleDf.filter($"src".ilike("a")), Seq("a", "A").toDF()) | |
+ checkAnswer(simpleDf.filter($"src".ilike("A")), Seq("a", "A").toDF()) | |
+ checkAnswer(simpleDf.filter($"src".ilike("b")), spark.emptyDataFrame) | |
+ checkAnswer(simpleDf.filter($"src".ilike("aBdef")), Seq("abdef").toDF()) | |
+ checkAnswer(simpleDf.filter($"src".ilike("a\\__b")), Seq("a_%b").toDF()) | |
+ checkAnswer(simpleDf.filter($"src".ilike("A_%b")), Seq("a_%b", "addb", "a\nb").toDF()) | |
+ checkAnswer(simpleDf.filter($"src".ilike("a%")), simpleDf) | |
+ checkAnswer(simpleDf.filter($"src".ilike("a_b")), Seq("a\nb").toDF()) | |
+ // double-escaping backslash | |
+ val dEscDf = Seq("""\__""", """\\__""").toDF("src") | |
+ checkAnswer(dEscDf.filter($"src".ilike("""\\\__""")), Seq("""\__""").toDF()) | |
+ checkAnswer(dEscDf.filter($"src".ilike("""%\\%\%""")), spark.emptyDataFrame) | |
+ // unicode | |
+ val uncDf = Seq("a\u20ACA", "Aâ¬a", "aâ¬AA", "a\u20ACaz", "ÐÐÐѺÎá»").toDF("src") | |
+ checkAnswer(uncDf.filter($"src".ilike("_\u20AC_")), Seq("a\u20ACA", "Aâ¬a").toDF()) | |
+ checkAnswer(uncDf.filter($"src".ilike("_â¬_")), Seq("a\u20ACA", "Aâ¬a").toDF()) | |
+ checkAnswer(uncDf.filter($"src".ilike("_\u20AC_a")), Seq("aâ¬AA").toDF()) | |
+ checkAnswer(uncDf.filter($"src".ilike("_â¬_Z")), Seq("a\u20ACaz").toDF()) | |
+ checkAnswer(uncDf.filter($"src".ilike("ÑÑÑÑ»Ïá»")), Seq("ÐÐÐѺÎá»").toDF()) | |
+ // scalastyle:on | |
+ } | |
+ | |
+ test("SPARK-39093: divide period by integral expression") { | |
+ val df = Seq(((Period.ofMonths(10)), 2)).toDF("pd", "num") | |
+ checkAnswer(df.select($"pd" / ($"num" + 3)), | |
+ Seq((Period.ofMonths(2))).toDF) | |
+ } | |
+ | |
+ test("SPARK-39093: divide duration by integral expression") { | |
+ val df = Seq(((Duration.ofDays(10)), 2)).toDF("dd", "num") | |
+ checkAnswer(df.select($"dd" / ($"num" + 3)), | |
+ Seq((Duration.ofDays(2))).toDF) | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala | |
index 711a4bc3fd..b683f3573b 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala | |
@@ -82,16 +82,16 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { | |
test("schema_of_csv - infers schemas") { | |
checkAnswer( | |
spark.range(1).select(schema_of_csv(lit("0.1,1"))), | |
- Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>"))) | |
+ Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>"))) | |
checkAnswer( | |
spark.range(1).select(schema_of_csv("0.1,1")), | |
- Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>"))) | |
+ Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>"))) | |
} | |
test("schema_of_csv - infers schemas using options") { | |
val df = spark.range(1) | |
.select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) | |
- checkAnswer(df, Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>"))) | |
+ checkAnswer(df, Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>"))) | |
} | |
test("to_csv - struct") { | |
@@ -220,7 +220,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { | |
val input = concat_ws(",", lit(0.1), lit(1)) | |
checkAnswer( | |
spark.range(1).select(schema_of_csv(input)), | |
- Seq(Row("STRUCT<`_c0`: DOUBLE, `_c1`: INT>"))) | |
+ Seq(Row("STRUCT<_c0: DOUBLE, _c1: INT>"))) | |
} | |
test("optional datetime parser does not affect csv time formatting") { | |
@@ -353,4 +353,22 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { | |
} | |
} | |
} | |
+ | |
+ test("SPARK-37326: Handle incorrectly formatted timestamp_ntz values in from_csv") { | |
+ val fromCsvDF = Seq("2021-08-12T15:16:23.000+11:00").toDF("csv") | |
+ .select( | |
+ from_csv( | |
+ $"csv", | |
+ StructType(StructField("a", TimestampNTZType) :: Nil), | |
+ Map.empty[String, String]) as "value") | |
+ .selectExpr("value.a") | |
+ checkAnswer(fromCsvDF, Row(null)) | |
+ } | |
+ | |
+ test("SPARK-38955: disable lineSep option in from_csv and schema_of_csv") { | |
+ val df = Seq[String]("1,2\n2").toDF("csv") | |
+ val actual = df.select(from_csv( | |
+ $"csv", schema_of_csv("1,2\n2"), Map.empty[String, String].asJava)) | |
+ checkAnswer(actual, Row(Row(1, "2\n2"))) | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | |
index c074a207af..c460f1e43b 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | |
@@ -875,6 +875,11 @@ class DataFrameAggregateSuite extends QueryTest | |
sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course") | |
checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil) | |
+ checkAnswer( | |
+ courseSales.groupBy("course").agg(max_by(col("year"), col("earnings"))), | |
+ Row("dotNET", 2013) :: Row("Java", 2013) :: Nil | |
+ ) | |
+ | |
checkAnswer( | |
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"), | |
Row("b") :: Nil | |
@@ -931,6 +936,11 @@ class DataFrameAggregateSuite extends QueryTest | |
sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course") | |
checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil) | |
+ checkAnswer( | |
+ courseSales.groupBy("course").agg(min_by(col("year"), col("earnings"))), | |
+ Row("dotNET", 2012) :: Row("Java", 2012) :: Nil | |
+ ) | |
+ | |
checkAnswer( | |
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"), | |
Row("a") :: Nil | |
@@ -1015,10 +1025,15 @@ class DataFrameAggregateSuite extends QueryTest | |
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"), | |
Nil) | |
- val error = intercept[AnalysisException] { | |
- sql("SELECT COUNT_IF(x) FROM tempView") | |
+ // When ANSI mode is on, it will implicitly cast the string to boolean and throw a runtime | |
+ // error. Here we simply test with ANSI mode off. | |
+ if (!conf.ansiEnabled) { | |
+ val error = intercept[AnalysisException] { | |
+ sql("SELECT COUNT_IF(x) FROM tempView") | |
+ } | |
+ assert(error.message.contains("cannot resolve 'count_if(tempview.x)' due to data type " + | |
+ "mismatch: argument 1 requires boolean type, however, 'tempview.x' is of string type")) | |
} | |
- assert(error.message.contains("function count_if requires boolean type")) | |
} | |
} | |
@@ -1125,9 +1140,11 @@ class DataFrameAggregateSuite extends QueryTest | |
val mapDF = Seq(Tuple1(Map("a" -> "a"))).toDF("col") | |
checkAnswer(mapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) | |
- val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col") | |
- // Spark implicit casts string literal "a" to int to match the key type. | |
- checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) | |
+ if (!conf.ansiEnabled) { | |
+ val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col") | |
+ // Spark implicit casts string literal "a" to int to match the key type. | |
+ checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) | |
+ } | |
val arrayDF = Seq(Tuple1(Seq(1))).toDF("col") | |
val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count()) | |
@@ -1261,12 +1278,12 @@ class DataFrameAggregateSuite extends QueryTest | |
val error = intercept[SparkException] { | |
checkAnswer(df2.select(sum($"year-month")), Nil) | |
} | |
- assert(error.toString contains "java.lang.ArithmeticException: integer overflow") | |
+ assert(error.toString contains "SparkArithmeticException: integer overflow") | |
val error2 = intercept[SparkException] { | |
checkAnswer(df2.select(sum($"day")), Nil) | |
} | |
- assert(error2.toString contains "java.lang.ArithmeticException: long overflow") | |
+ assert(error2.toString contains "SparkArithmeticException: long overflow") | |
} | |
test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") { | |
@@ -1395,12 +1412,12 @@ class DataFrameAggregateSuite extends QueryTest | |
val error = intercept[SparkException] { | |
checkAnswer(df2.select(avg($"year-month")), Nil) | |
} | |
- assert(error.toString contains "java.lang.ArithmeticException: integer overflow") | |
+ assert(error.toString contains "SparkArithmeticException: integer overflow") | |
val error2 = intercept[SparkException] { | |
checkAnswer(df2.select(avg($"day")), Nil) | |
} | |
- assert(error2.toString contains "java.lang.ArithmeticException: long overflow") | |
+ assert(error2.toString contains "SparkArithmeticException: long overflow") | |
val df3 = intervalData.filter($"class" > 4) | |
val avgDF3 = df3.select(avg($"year-month"), avg($"day")) | |
@@ -1432,6 +1449,16 @@ class DataFrameAggregateSuite extends QueryTest | |
val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id") | |
checkAnswer(df, Row(2, 3, 1)) | |
} | |
+ | |
+ test("SPARK-41035: Reuse of literal in distinct aggregations should work") { | |
+ val res = sql( | |
+ """select a, count(distinct 100), count(distinct b, 100) | |
+ |from values (1, 2), (4, 5), (4, 6) as data(a, b) | |
+ |group by a; | |
+ |""".stripMargin | |
+ ) | |
+ checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil) | |
+ } | |
} | |
case class B(c: Option[Double]) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala | |
new file mode 100644 | |
index 0000000000..749efe95c5 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala | |
@@ -0,0 +1,169 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql | |
+ | |
+import scala.collection.JavaConverters._ | |
+ | |
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
+import org.apache.spark.sql.functions._ | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+import org.apache.spark.sql.types._ | |
+ | |
+class DataFrameAsOfJoinSuite extends QueryTest | |
+ with SharedSparkSession | |
+ with AdaptiveSparkPlanHelper { | |
+ | |
+ def prepareForAsOfJoin(): (DataFrame, DataFrame) = { | |
+ val schema1 = StructType( | |
+ StructField("a", IntegerType, false) :: | |
+ StructField("b", StringType, false) :: | |
+ StructField("left_val", StringType, false) :: Nil) | |
+ val rowSeq1: List[Row] = List(Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c")) | |
+ val df1 = spark.createDataFrame(rowSeq1.asJava, schema1) | |
+ | |
+ val schema2 = StructType( | |
+ StructField("a", IntegerType) :: | |
+ StructField("b", StringType) :: | |
+ StructField("right_val", IntegerType) :: Nil) | |
+ val rowSeq2: List[Row] = List(Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3), | |
+ Row(6, "y", 6), Row(7, "z", 7)) | |
+ val df2 = spark.createDataFrame(rowSeq2.asJava, schema2) | |
+ | |
+ (df1, df2) | |
+ } | |
+ | |
+ test("as-of join - simple") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ checkAnswer( | |
+ df1.joinAsOf( | |
+ df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, | |
+ joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"), | |
+ Seq( | |
+ Row(1, "x", "a", 1, "v", 1), | |
+ Row(5, "y", "b", 3, "x", 3), | |
+ Row(10, "z", "c", 7, "z", 7) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("as-of join - usingColumns") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ checkAnswer( | |
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), | |
+ joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"), | |
+ Seq( | |
+ Row(10, "z", "c", 7, "z", 7) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("as-of join - usingColumns, left outer") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ checkAnswer( | |
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), | |
+ joinType = "left", tolerance = null, allowExactMatches = true, direction = "backward"), | |
+ Seq( | |
+ Row(1, "x", "a", null, null, null), | |
+ Row(5, "y", "b", null, null, null), | |
+ Row(10, "z", "c", 7, "z", 7) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("as-of join - tolerance = 1") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ checkAnswer( | |
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, | |
+ joinType = "inner", tolerance = lit(1), allowExactMatches = true, direction = "backward"), | |
+ Seq( | |
+ Row(1, "x", "a", 1, "v", 1) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("as-of join - tolerance should be a constant") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ val errMsg = intercept[AnalysisException] { | |
+ df1.joinAsOf( | |
+ df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, | |
+ joinType = "inner", tolerance = df1.col("b"), allowExactMatches = true, | |
+ direction = "backward") | |
+ }.getMessage | |
+ assert(errMsg.contains("Input argument tolerance must be a constant.")) | |
+ } | |
+ | |
+ test("as-of join - tolerance should be non-negative") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ val errMsg = intercept[AnalysisException] { | |
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, | |
+ joinType = "inner", tolerance = lit(-1), allowExactMatches = true, direction = "backward") | |
+ }.getMessage | |
+ assert(errMsg.contains("Input argument tolerance must be non-negative.")) | |
+ } | |
+ | |
+ test("as-of join - allowExactMatches = false") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ checkAnswer( | |
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, | |
+ joinType = "inner", tolerance = null, allowExactMatches = false, direction = "backward"), | |
+ Seq( | |
+ Row(5, "y", "b", 3, "x", 3), | |
+ Row(10, "z", "c", 7, "z", 7) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("as-of join - direction = \"forward\"") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ checkAnswer( | |
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, | |
+ joinType = "inner", tolerance = null, allowExactMatches = true, direction = "forward"), | |
+ Seq( | |
+ Row(1, "x", "a", 1, "v", 1), | |
+ Row(5, "y", "b", 6, "y", 6) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("as-of join - direction = \"nearest\"") { | |
+ val (df1, df2) = prepareForAsOfJoin() | |
+ checkAnswer( | |
+ df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, | |
+ joinType = "inner", tolerance = null, allowExactMatches = true, direction = "nearest"), | |
+ Seq( | |
+ Row(1, "x", "a", 1, "v", 1), | |
+ Row(5, "y", "b", 6, "y", 6), | |
+ Row(10, "z", "c", 7, "z", 7) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("as-of join - self") { | |
+ val (df1, _) = prepareForAsOfJoin() | |
+ checkAnswer( | |
+ df1.joinAsOf( | |
+ df1, df1.col("a"), df1.col("a"), usingColumns = Seq.empty, | |
+ joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest"), | |
+ Seq( | |
+ Row(1, "x", "a", 5, "y", "b"), | |
+ Row(5, "y", "b", 1, "x", "a"), | |
+ Row(10, "z", "c", 5, "y", "b") | |
+ ) | |
+ ) | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | |
index 38b9a75dfb..697cce9b50 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | |
@@ -247,6 +247,93 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
Row(2743272264L, 2180413220L)) | |
} | |
+ test("misc aes function") { | |
+ val key16 = "abcdefghijklmnop" | |
+ val key24 = "abcdefghijklmnop12345678" | |
+ val key32 = "abcdefghijklmnop12345678ABCDEFGH" | |
+ val encryptedText16 = "4Hv0UKCx6nfUeAoPZo1z+w==" | |
+ val encryptedText24 = "NeTYNgA+PCQBN50DA//O2w==" | |
+ val encryptedText32 = "9J3iZbIxnmaG+OIA9Amd+A==" | |
+ val encryptedEmptyText16 = "jmTOhz8XTbskI/zYFFgOFQ==" | |
+ val encryptedEmptyText24 = "9RDK70sHNzqAFRcpfGM5gQ==" | |
+ val encryptedEmptyText32 = "j9IDsCvlYXtcVJUf4FAjQQ==" | |
+ | |
+ val df1 = Seq("Spark", "").toDF | |
+ | |
+ // Successful encryption | |
+ Seq( | |
+ (key16, encryptedText16, encryptedEmptyText16), | |
+ (key24, encryptedText24, encryptedEmptyText24), | |
+ (key32, encryptedText32, encryptedEmptyText32)).foreach { | |
+ case (key, encryptedText, encryptedEmptyText) => | |
+ checkAnswer( | |
+ df1.selectExpr(s"base64(aes_encrypt(value, '$key', 'ECB'))"), | |
+ Seq(Row(encryptedText), Row(encryptedEmptyText))) | |
+ checkAnswer( | |
+ df1.selectExpr(s"base64(aes_encrypt(binary(value), '$key', 'ECB'))"), | |
+ Seq(Row(encryptedText), Row(encryptedEmptyText))) | |
+ } | |
+ | |
+ // Encryption failure - input or key is null | |
+ Seq(key16, key24, key32).foreach { key => | |
+ checkAnswer( | |
+ df1.selectExpr(s"aes_encrypt(cast(null as string), '$key')"), | |
+ Seq(Row(null), Row(null))) | |
+ checkAnswer( | |
+ df1.selectExpr(s"aes_encrypt(cast(null as binary), '$key')"), | |
+ Seq(Row(null), Row(null))) | |
+ checkAnswer( | |
+ df1.selectExpr(s"aes_encrypt(cast(null as string), binary('$key'))"), | |
+ Seq(Row(null), Row(null))) | |
+ checkAnswer( | |
+ df1.selectExpr(s"aes_encrypt(cast(null as binary), binary('$key'))"), | |
+ Seq(Row(null), Row(null))) | |
+ } | |
+ checkAnswer( | |
+ df1.selectExpr("aes_encrypt(value, cast(null as string))"), | |
+ Seq(Row(null), Row(null))) | |
+ checkAnswer( | |
+ df1.selectExpr("aes_encrypt(value, cast(null as binary))"), | |
+ Seq(Row(null), Row(null))) | |
+ | |
+ val df2 = Seq( | |
+ (encryptedText16, encryptedText24, encryptedText32), | |
+ (encryptedEmptyText16, encryptedEmptyText24, encryptedEmptyText32) | |
+ ).toDF("value16", "value24", "value32") | |
+ | |
+ // Successful decryption | |
+ Seq( | |
+ ("value16", key16), | |
+ ("value24", key24), | |
+ ("value32", key32)).foreach { | |
+ case (colName, key) => | |
+ checkAnswer( | |
+ df2.selectExpr(s"cast(aes_decrypt(unbase64($colName), '$key', 'ECB') as string)"), | |
+ Seq(Row("Spark"), Row(""))) | |
+ checkAnswer( | |
+ df2.selectExpr(s"cast(aes_decrypt(unbase64($colName), binary('$key'), 'ECB') as string)"), | |
+ Seq(Row("Spark"), Row(""))) | |
+ } | |
+ | |
+ // Decryption failure - input or key is null | |
+ Seq(key16, key24, key32).foreach { key => | |
+ checkAnswer( | |
+ df2.selectExpr(s"aes_decrypt(cast(null as binary), '$key')"), | |
+ Seq(Row(null), Row(null))) | |
+ checkAnswer( | |
+ df2.selectExpr(s"aes_decrypt(cast(null as binary), binary('$key'))"), | |
+ Seq(Row(null), Row(null))) | |
+ } | |
+ Seq("value16", "value24", "value32").foreach { colName => | |
+ checkAnswer( | |
+ df2.selectExpr(s"aes_decrypt($colName, cast(null as string))"), | |
+ Seq(Row(null), Row(null))) | |
+ checkAnswer( | |
+ df2.selectExpr(s"aes_decrypt($colName, cast(null as binary))"), | |
+ Seq(Row(null), Row(null))) | |
+ } | |
+ } | |
+ | |
test("string function find_in_set") { | |
val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b") | |
@@ -394,6 +481,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
spark.sql("drop temporary function fStringLength") | |
} | |
+ test("SPARK-38130: array_sort with lambda of non-orderable items") { | |
+ val df6 = Seq((Array[Map[String, Int]](Map("a" -> 1), Map("b" -> 2, "c" -> 3), | |
+ Map()), "x")).toDF("a", "b") | |
+ checkAnswer( | |
+ df6.selectExpr("array_sort(a, (x, y) -> cardinality(x) - cardinality(y))"), | |
+ Seq( | |
+ Row(Seq[Map[String, Int]](Map(), Map("a" -> 1), Map("b" -> 2, "c" -> 3)))) | |
+ ) | |
+ } | |
+ | |
test("sort_array/array_sort functions") { | |
val df = Seq( | |
(Array[Int](2, 1, 3), Array("b", "c", "a")), | |
@@ -482,8 +579,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
} | |
test("array size function - legacy") { | |
- withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { | |
- testSizeOfArray(sizeOfNull = -1) | |
+ if (!conf.ansiEnabled) { | |
+ withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { | |
+ testSizeOfArray(sizeOfNull = -1) | |
+ } | |
} | |
} | |
@@ -622,6 +721,70 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
} | |
} | |
+ test("SPARK-40292: arrays_zip should retain field names in nested structs") { | |
+ val df = spark.sql(""" | |
+ select | |
+ named_struct( | |
+ 'arr_1', array(named_struct('a', 1, 'b', 2)), | |
+ 'arr_2', array(named_struct('p', 1, 'q', 2)), | |
+ 'field', named_struct( | |
+ 'arr_3', array(named_struct('x', 1, 'y', 2)) | |
+ ) | |
+ ) as obj | |
+ """) | |
+ | |
+ val res = df.selectExpr("arrays_zip(obj.arr_1, obj.arr_2, obj.field.arr_3) as arr") | |
+ | |
+ val fieldNames = res.schema.head.dataType.asInstanceOf[ArrayType] | |
+ .elementType.asInstanceOf[StructType].fieldNames | |
+ assert(fieldNames.toSeq === Seq("arr_1", "arr_2", "arr_3")) | |
+ } | |
+ | |
+ test("SPARK-40470: array_zip should return field names in GetArrayStructFields") { | |
+ val df = spark.read.json(Seq( | |
+ """ | |
+ { | |
+ "arr": [ | |
+ { | |
+ "obj": { | |
+ "nested": { | |
+ "field1": [1], | |
+ "field2": [2] | |
+ } | |
+ } | |
+ } | |
+ ] | |
+ } | |
+ """).toDS()) | |
+ | |
+ val res = df | |
+ .selectExpr("arrays_zip(arr.obj.nested.field1, arr.obj.nested.field2) as arr") | |
+ .select(col("arr.field1"), col("arr.field2")) | |
+ | |
+ val fieldNames = res.schema.fieldNames | |
+ assert(fieldNames.toSeq === Seq("field1", "field2")) | |
+ | |
+ checkAnswer(res, Row(Seq(Seq(1)), Seq(Seq(2))) :: Nil) | |
+ } | |
+ | |
+ test("SPARK-40470: arrays_zip should return field names in GetMapValue") { | |
+ val df = spark.sql(""" | |
+ select | |
+ map( | |
+ 'arr_1', array(1, 2), | |
+ 'arr_2', array(3, 4) | |
+ ) as map_obj | |
+ """) | |
+ | |
+ val res = df.selectExpr("arrays_zip(map_obj.arr_1, map_obj.arr_2) as arr") | |
+ | |
+ val fieldNames = res.schema.head.dataType.asInstanceOf[ArrayType] | |
+ .elementType.asInstanceOf[StructType].fieldNames | |
+ assert(fieldNames.toSeq === Seq("arr_1", "arr_2")) | |
+ | |
+ checkAnswer(res, Row(Seq(Row(1, 3), Row(2, 4)))) | |
+ } | |
+ | |
def testSizeOfMap(sizeOfNull: Any): Unit = { | |
val df = Seq( | |
(Map[Int, Int](1 -> 1, 2 -> 2), "x"), | |
@@ -635,8 +798,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
} | |
test("map size function - legacy") { | |
- withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { | |
- testSizeOfMap(sizeOfNull = -1: Int) | |
+ if (!conf.ansiEnabled) { | |
+ withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { | |
+ testSizeOfMap(sizeOfNull = -1: Int) | |
+ } | |
} | |
} | |
@@ -930,15 +1095,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
Seq(Row(false)) | |
) | |
- val e1 = intercept[AnalysisException] { | |
- OneRowRelation().selectExpr("array_contains(array(1), .01234567890123456790123456780)") | |
- } | |
- val errorMsg1 = | |
- s""" | |
- |Input to function array_contains should have been array followed by a | |
- |value with same element type, but it's [array<int>, decimal(38,29)]. | |
+ if (!conf.ansiEnabled) { | |
+ val e1 = intercept[AnalysisException] { | |
+ OneRowRelation().selectExpr("array_contains(array(1), .01234567890123456790123456780)") | |
+ } | |
+ val errorMsg1 = | |
+ s""" | |
+ |Input to function array_contains should have been array followed by a | |
+ |value with same element type, but it's [array<int>, decimal(38,29)]. | |
""".stripMargin.replace("\n", " ").trim() | |
- assert(e1.message.contains(errorMsg1)) | |
+ assert(e1.message.contains(errorMsg1)) | |
+ } | |
val e2 = intercept[AnalysisException] { | |
OneRowRelation().selectExpr("array_contains(array(1), 'foo')") | |
@@ -1367,41 +1534,43 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
Seq(Row(null), Row(null), Row(null)) | |
) | |
} | |
- checkAnswer( | |
- df.select(element_at(df("a"), 4)), | |
- Seq(Row(null), Row(null), Row(null)) | |
- ) | |
- checkAnswer( | |
- df.select(element_at(df("a"), df("b"))), | |
- Seq(Row("1"), Row(""), Row(null)) | |
- ) | |
- checkAnswer( | |
- df.selectExpr("element_at(a, b)"), | |
- Seq(Row("1"), Row(""), Row(null)) | |
- ) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer( | |
+ df.select(element_at(df("a"), 4)), | |
+ Seq(Row(null), Row(null), Row(null)) | |
+ ) | |
+ checkAnswer( | |
+ df.select(element_at(df("a"), df("b"))), | |
+ Seq(Row("1"), Row(""), Row(null)) | |
+ ) | |
+ checkAnswer( | |
+ df.selectExpr("element_at(a, b)"), | |
+ Seq(Row("1"), Row(""), Row(null)) | |
+ ) | |
- checkAnswer( | |
- df.select(element_at(df("a"), 1)), | |
- Seq(Row("1"), Row(null), Row(null)) | |
- ) | |
- checkAnswer( | |
- df.select(element_at(df("a"), -1)), | |
- Seq(Row("3"), Row(""), Row(null)) | |
- ) | |
+ checkAnswer( | |
+ df.select(element_at(df("a"), 1)), | |
+ Seq(Row("1"), Row(null), Row(null)) | |
+ ) | |
+ checkAnswer( | |
+ df.select(element_at(df("a"), -1)), | |
+ Seq(Row("3"), Row(""), Row(null)) | |
+ ) | |
- checkAnswer( | |
- df.selectExpr("element_at(a, 4)"), | |
- Seq(Row(null), Row(null), Row(null)) | |
- ) | |
+ checkAnswer( | |
+ df.selectExpr("element_at(a, 4)"), | |
+ Seq(Row(null), Row(null), Row(null)) | |
+ ) | |
- checkAnswer( | |
- df.selectExpr("element_at(a, 1)"), | |
- Seq(Row("1"), Row(null), Row(null)) | |
- ) | |
- checkAnswer( | |
- df.selectExpr("element_at(a, -1)"), | |
- Seq(Row("3"), Row(""), Row(null)) | |
- ) | |
+ checkAnswer( | |
+ df.selectExpr("element_at(a, 1)"), | |
+ Seq(Row("1"), Row(null), Row(null)) | |
+ ) | |
+ checkAnswer( | |
+ df.selectExpr("element_at(a, -1)"), | |
+ Seq(Row("3"), Row(""), Row(null)) | |
+ ) | |
+ } | |
val e1 = intercept[AnalysisException] { | |
Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") | |
@@ -1463,10 +1632,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
Seq(Row("a")) | |
) | |
- checkAnswer( | |
- OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"), | |
- Seq(Row(null)) | |
- ) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer( | |
+ OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"), | |
+ Seq(Row(null)) | |
+ ) | |
+ } | |
val e3 = intercept[AnalysisException] { | |
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')") | |
@@ -1541,10 +1712,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
// Simple test cases | |
def simpleTest(): Unit = { | |
- checkAnswer ( | |
- df.select(concat($"i1", $"s1")), | |
- Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) | |
- ) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer( | |
+ df.select(concat($"i1", $"s1")), | |
+ Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) | |
+ ) | |
+ } | |
checkAnswer( | |
df.select(concat($"i1", $"i2", $"i3")), | |
Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) | |
@@ -2328,7 +2501,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
val ex3 = intercept[AnalysisException] { | |
df.selectExpr("transform(a, x -> x)") | |
} | |
- assert(ex3.getMessage.contains("cannot resolve 'a'")) | |
+ assert(ex3.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex3.messageParameters.head == "a") | |
} | |
test("map_filter") { | |
@@ -2399,7 +2573,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
val ex4 = intercept[AnalysisException] { | |
df.selectExpr("map_filter(a, (k, v) -> k > v)") | |
} | |
- assert(ex4.getMessage.contains("cannot resolve 'a'")) | |
+ assert(ex4.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex4.messageParameters.head == "a") | |
} | |
test("filter function - array for primitive type not containing null") { | |
@@ -2558,7 +2733,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
val ex4 = intercept[AnalysisException] { | |
df.selectExpr("filter(a, x -> x)") | |
} | |
- assert(ex4.getMessage.contains("cannot resolve 'a'")) | |
+ assert(ex4.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex4.messageParameters.head == "a") | |
} | |
test("exists function - array for primitive type not containing null") { | |
@@ -2690,7 +2866,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
val ex4 = intercept[AnalysisException] { | |
df.selectExpr("exists(a, x -> x)") | |
} | |
- assert(ex4.getMessage.contains("cannot resolve 'a'")) | |
+ assert(ex4.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex4.messageParameters.head == "a") | |
} | |
test("forall function - array for primitive type not containing null") { | |
@@ -2836,12 +3013,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
val ex4 = intercept[AnalysisException] { | |
df.selectExpr("forall(a, x -> x)") | |
} | |
- assert(ex4.getMessage.contains("cannot resolve 'a'")) | |
+ assert(ex4.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex4.messageParameters.head == "a") | |
val ex4a = intercept[AnalysisException] { | |
df.select(forall(col("a"), x => x)) | |
} | |
- assert(ex4a.getMessage.contains("cannot resolve 'a'")) | |
+ assert(ex4a.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex4a.messageParameters.head == "a") | |
} | |
test("aggregate function - array for primitive type not containing null") { | |
@@ -3018,7 +3197,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
val ex5 = intercept[AnalysisException] { | |
df.selectExpr("aggregate(a, 0, (acc, x) -> x)") | |
} | |
- assert(ex5.getMessage.contains("cannot resolve 'a'")) | |
+ assert(ex5.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex5.messageParameters.head == "a") | |
} | |
test("map_zip_with function - map of primitive types") { | |
@@ -3571,7 +3751,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
val ex4 = intercept[AnalysisException] { | |
df.selectExpr("zip_with(a1, a, (acc, x) -> x)") | |
} | |
- assert(ex4.getMessage.contains("cannot resolve 'a'")) | |
+ assert(ex4.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex4.messageParameters.head == "a") | |
} | |
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala | |
index a803fa88ed..4298d503b1 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala | |
@@ -288,6 +288,24 @@ class DataFrameJoinSuite extends QueryTest | |
} | |
} | |
+ Seq("left_semi", "left_anti").foreach { joinType => | |
+ test(s"SPARK-41162: $joinType self-joined aggregated dataframe") { | |
+ // aggregated dataframe | |
+ val ids = Seq(1, 2, 3).toDF("id").distinct() | |
+ | |
+ // self-joined via joinType | |
+ val result = ids.withColumn("id", $"id" + 1) | |
+ .join(ids, usingColumns = Seq("id"), joinType = joinType).collect() | |
+ | |
+ val expected = joinType match { | |
+ case "left_semi" => 2 | |
+ case "left_anti" => 1 | |
+ case _ => -1 // unsupported test type, test will always fail | |
+ } | |
+ assert(result.length == expected) | |
+ } | |
+ } | |
+ | |
def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match { | |
case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left) | |
case Filter(_, child) => extractLeftDeepInnerJoins(child) | |
@@ -499,4 +517,26 @@ class DataFrameJoinSuite extends QueryTest | |
) | |
} | |
} | |
+ | |
+ test("SPARK-39376: Hide duplicated columns in star expansion of subquery alias from USING JOIN") { | |
+ val joinDf = testData2.as("testData2").join( | |
+ testData3.as("testData3"), usingColumns = Seq("a"), joinType = "fullouter") | |
+ val equivalentQueries = Seq( | |
+ joinDf.select($"*"), | |
+ joinDf.as("r").select($"*"), | |
+ joinDf.as("r").select($"r.*") | |
+ ) | |
+ equivalentQueries.foreach { query => | |
+ checkAnswer(query, | |
+ Seq( | |
+ Row(1, 1, null), | |
+ Row(1, 2, null), | |
+ Row(2, 1, 2), | |
+ Row(2, 2, 2), | |
+ Row(3, 1, null), | |
+ Row(3, 2, null) | |
+ ) | |
+ ) | |
+ } | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | |
index 20ae995af6..8dbc57c042 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | |
@@ -444,21 +444,25 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { | |
} | |
test("replace float with nan") { | |
- checkAnswer( | |
- createNaNDF().na.replace("*", Map( | |
- 1.0f -> Float.NaN | |
- )), | |
- Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: | |
- Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer( | |
+ createNaNDF().na.replace("*", Map( | |
+ 1.0f -> Float.NaN | |
+ )), | |
+ Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: | |
+ Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) | |
+ } | |
} | |
test("replace double with nan") { | |
- checkAnswer( | |
- createNaNDF().na.replace("*", Map( | |
- 1.0 -> Double.NaN | |
- )), | |
- Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: | |
- Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer( | |
+ createNaNDF().na.replace("*", Map( | |
+ 1.0 -> Double.NaN | |
+ )), | |
+ Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: | |
+ Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) | |
+ } | |
} | |
test("SPARK-34417: test fillMap() for column with a dot in the name") { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala | |
index 32cbb8b457..1a0c95beb1 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala | |
@@ -17,6 +17,7 @@ | |
package org.apache.spark.sql | |
+import java.time.LocalDateTime | |
import java.util.Locale | |
import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst | |
@@ -323,17 +324,6 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { | |
checkAnswer(df, expected) | |
} | |
- test("pivoting column list") { | |
- val exception = intercept[RuntimeException] { | |
- trainingSales | |
- .groupBy($"sales.year") | |
- .pivot(struct(lower($"sales.course"), $"training")) | |
- .agg(sum($"sales.earnings")) | |
- .collect() | |
- } | |
- assert(exception.getMessage.contains("Unsupported literal type")) | |
- } | |
- | |
test("SPARK-26403: pivoting by array column") { | |
val df = Seq( | |
(2, Seq.empty[String]), | |
@@ -352,4 +342,16 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { | |
percentile_approx(col("value"), array(lit(0.5)), lit(10000))) | |
checkAnswer(actual, Row(Array(2.5), Array(3.0))) | |
} | |
+ | |
+ test("SPARK-38133: Grouping by TIMESTAMP_NTZ should not corrupt results") { | |
+ checkAnswer( | |
+ courseSales.withColumn("ts", $"year".cast("string").cast("timestamp_ntz")) | |
+ .groupBy("ts") | |
+ .pivot("course", Seq("dotNET", "Java")) | |
+ .agg(sum($"earnings")) | |
+ .select("ts", "dotNET", "Java"), | |
+ Row(LocalDateTime.of(2012, 1, 1, 0, 0, 0, 0), 15000.0, 20000.0) :: | |
+ Row(LocalDateTime.of(2013, 1, 1, 0, 0, 0, 0), 48000.0, 30000.0) :: Nil | |
+ ) | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala | |
index fc549e307c..917f80e581 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala | |
@@ -63,13 +63,15 @@ class DataFrameRangeSuite extends QueryTest with SharedSparkSession with Eventua | |
val res7 = spark.range(-10, -9, -20, 1).select("id") | |
assert(res7.count == 0) | |
- val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") | |
- assert(res8.count == 3) | |
- assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) | |
- | |
- val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") | |
- assert(res9.count == 2) | |
- assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) | |
+ if (!conf.ansiEnabled) { | |
+ val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") | |
+ assert(res8.count == 3) | |
+ assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) | |
+ | |
+ val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") | |
+ assert(res9.count == 2) | |
+ assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) | |
+ } | |
// only end provided as argument | |
val res10 = spark.range(10).select("id") | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala | |
index e520cbea48..4d0dd46b95 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala | |
@@ -481,7 +481,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { | |
val ex = intercept[AnalysisException]( | |
df3.join(df1, year($"df1.timeStr") === year($"df3.tsStr")) | |
) | |
- assert(ex.message.contains("cannot resolve 'df1.timeStr'")) | |
+ assert(ex.message.contains("Column 'df1.timeStr' does not exist.")) | |
} | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala | |
index 7a0cd420d4..a5414f3e80 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala | |
@@ -17,12 +17,16 @@ | |
package org.apache.spark.sql | |
+import java.time.LocalDateTime | |
+ | |
import org.scalatest.BeforeAndAfterEach | |
-import org.apache.spark.sql.catalyst.plans.logical.Expand | |
+import org.apache.spark.sql.catalyst.encoders.RowEncoder | |
+import org.apache.spark.sql.catalyst.expressions.AttributeReference | |
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand} | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.test.SharedSparkSession | |
-import org.apache.spark.sql.types.StringType | |
+import org.apache.spark.sql.types._ | |
class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession | |
with BeforeAndAfterEach { | |
@@ -79,7 +83,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession | |
// key "b" => (19:39:27 ~ 19:39:37) | |
checkAnswer( | |
- df.groupBy(session_window($"time", "10 seconds"), 'id) | |
+ df.groupBy(session_window($"time", "10 seconds"), Symbol("id")) | |
.agg(count("*").as("counts"), sum("value").as("sum")) | |
.orderBy($"session_window.start".asc) | |
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", | |
@@ -109,7 +113,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession | |
// key "b" => (19:39:27 ~ 19:39:37) | |
checkAnswer( | |
- df.groupBy(session_window($"time", "10 seconds"), 'id) | |
+ df.groupBy(session_window($"time", "10 seconds"), Symbol("id")) | |
.agg(count("*").as("counts"), sum_distinct(col("value")).as("sum")) | |
.orderBy($"session_window.start".asc) | |
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", | |
@@ -138,7 +142,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession | |
// key "b" => (19:39:27 ~ 19:39:37) | |
checkAnswer( | |
- df.groupBy(session_window($"time", "10 seconds"), 'id) | |
+ df.groupBy(session_window($"time", "10 seconds"), Symbol("id")) | |
.agg(sum_distinct(col("value")).as("sum"), sum_distinct(col("value2")).as("sum2")) | |
.orderBy($"session_window.start".asc) | |
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", | |
@@ -167,7 +171,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession | |
// b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) | |
checkAnswer( | |
- df.groupBy(session_window($"time", "10 seconds"), 'id) | |
+ df.groupBy(session_window($"time", "10 seconds"), Symbol("id")) | |
.agg(count("*").as("counts"), sum("value").as("sum")) | |
.orderBy($"session_window.start".asc) | |
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", | |
@@ -377,4 +381,118 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession | |
) | |
} | |
} | |
+ | |
+ test("SPARK-36465: filter out events with invalid gap duration.") { | |
+ val df = Seq( | |
+ ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") | |
+ | |
+ checkAnswer( | |
+ df.groupBy(session_window($"time", "x sec")) | |
+ .agg(count("*").as("counts")) | |
+ .orderBy($"session_window.start".asc) | |
+ .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), | |
+ $"counts"), | |
+ Seq() | |
+ ) | |
+ | |
+ withTempTable { table => | |
+ checkAnswer( | |
+ spark.sql("select session_window(time, " + | |
+ """case when value = 1 then "2 seconds" when value = 2 then "invalid gap duration" """ + | |
+ s"""else "20 seconds" end), value from $table""") | |
+ .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType), | |
+ $"value"), | |
+ Seq( | |
+ Row("2016-03-27 19:39:27", "2016-03-27 19:39:47", 4), | |
+ Row("2016-03-27 19:39:34", "2016-03-27 19:39:36", 1) | |
+ ) | |
+ ) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-36724: Support timestamp_ntz as a type of time column for SessionWindow") { | |
+ val df = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a"), | |
+ (LocalDateTime.parse("2016-03-27T19:39:25"), 2, "a")).toDF("time", "value", "id") | |
+ val aggDF = | |
+ df.groupBy(session_window($"time", "10 seconds")) | |
+ .agg(count("*").as("counts")) | |
+ .orderBy($"session_window.start".asc) | |
+ .select($"session_window.start".cast("string"), | |
+ $"session_window.end".cast("string"), $"counts") | |
+ | |
+ val aggregate = aggDF.queryExecution.analyzed.children(0).children(0) | |
+ assert(aggregate.isInstanceOf[Aggregate]) | |
+ | |
+ val timeWindow = aggregate.asInstanceOf[Aggregate].groupingExpressions(0) | |
+ assert(timeWindow.isInstanceOf[AttributeReference]) | |
+ | |
+ val attributeReference = timeWindow.asInstanceOf[AttributeReference] | |
+ assert(attributeReference.name == "session_window") | |
+ | |
+ val expectedSchema = StructType( | |
+ Seq(StructField("start", TimestampNTZType), StructField("end", TimestampNTZType))) | |
+ assert(attributeReference.dataType == expectedSchema) | |
+ | |
+ checkAnswer(aggDF, Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 2))) | |
+ } | |
+ | |
+ test("SPARK-38227: 'start' and 'end' fields should be nullable") { | |
+ // We expect the fields in the window struct to be nullable since the dataType of SessionWindow | |
+ // defines them as nullable. The rule 'SessionWindowing' should respect the dataType. | |
+ val df1 = Seq( | |
+ ("hello", "2016-03-27 09:00:05", 1), | |
+ ("structured", "2016-03-27 09:00:32", 2)).toDF("id", "time", "value") | |
+ val df2 = Seq( | |
+ ("world", LocalDateTime.parse("2016-03-27T09:00:05"), 1), | |
+ ("spark", LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("id", "time", "value") | |
+ | |
+ val udf = spark.udf.register("gapDuration", (s: String) => { | |
+ if (s == "hello") { | |
+ "1 second" | |
+ } else if (s == "structured") { | |
+ // zero gap duration will be filtered out from aggregation | |
+ "0 second" | |
+ } else if (s == "world") { | |
+ // negative gap duration will be filtered out from aggregation | |
+ "-10 seconds" | |
+ } else { | |
+ "10 seconds" | |
+ } | |
+ }) | |
+ | |
+ def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = { | |
+ schema.find(_.name == colName) match { | |
+ case Some(StructField(_, st: StructType, _, _)) => | |
+ assertFieldInWindowStruct(st, "start") | |
+ assertFieldInWindowStruct(st, "end") | |
+ | |
+ case _ => fail("Failed to find suitable window column from DataFrame!") | |
+ } | |
+ } | |
+ | |
+ def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = { | |
+ val field = windowType.fields.find(_.name == fieldName) | |
+ assert(field.isDefined, s"'$fieldName' field should exist in window struct") | |
+ assert(field.get.nullable, s"'$fieldName' field should be nullable") | |
+ } | |
+ | |
+ for { | |
+ df <- Seq(df1, df2) | |
+ nullable <- Seq(true, false) | |
+ } { | |
+ val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder( | |
+ StructType(df.schema.fields.map(_.copy(nullable = nullable))))) | |
+ // session window without dynamic gap | |
+ val windowedProject = dfWithDesiredNullability | |
+ .select(session_window($"time", "10 seconds").as("session"), $"value") | |
+ val schema = windowedProject.queryExecution.optimizedPlan.schema | |
+ validateWindowColumnInSchema(schema, "session") | |
+ | |
+ // session window with dynamic gap | |
+ val windowedProject2 = dfWithDesiredNullability | |
+ .select(session_window($"time", udf($"id")).as("session"), $"value") | |
+ val schema2 = windowedProject2.queryExecution.optimizedPlan.schema | |
+ validateWindowColumnInSchema(schema2, "session") | |
+ } | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala | |
index f8e0cfc32a..ca04adf642 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala | |
@@ -21,6 +21,8 @@ import java.sql.{Date, Timestamp} | |
import org.apache.spark.sql.catalyst.optimizer.RemoveNoopUnion | |
import org.apache.spark.sql.catalyst.plans.logical.Union | |
+import org.apache.spark.sql.execution.{SparkPlan, UnionExec} | |
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession, SQLTestData} | |
@@ -339,7 +341,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { | |
).toDF("date", "timestamp", "decimal") | |
val widenTypedRows = Seq( | |
- (new Timestamp(2), 10.5D, "string") | |
+ (new Timestamp(2), 10.5D, "2021-01-01 00:00:00") | |
).toDF("date", "timestamp", "decimal") | |
dates.union(widenTypedRows).collect() | |
@@ -536,24 +538,25 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { | |
} | |
test("union by name - type coercion") { | |
- var df1 = Seq((1, "a")).toDF("c0", "c1") | |
- var df2 = Seq((3, 1L)).toDF("c1", "c0") | |
- checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) | |
- | |
- df1 = Seq((1, 1.0)).toDF("c0", "c1") | |
- df2 = Seq((8L, 3.0)).toDF("c1", "c0") | |
+ var df1 = Seq((1, 1.0)).toDF("c0", "c1") | |
+ var df2 = Seq((8L, 3.0)).toDF("c1", "c0") | |
checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) | |
- | |
- df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") | |
- df2 = Seq(("a", 4.0)).toDF("c1", "c0") | |
- checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) | |
- | |
- df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") | |
- df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") | |
- val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") | |
- checkAnswer(df1.unionByName(df2.unionByName(df3)), | |
- Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil | |
- ) | |
+ if (!conf.ansiEnabled) { | |
+ df1 = Seq((1, "a")).toDF("c0", "c1") | |
+ df2 = Seq((3, 1L)).toDF("c1", "c0") | |
+ checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) | |
+ | |
+ df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") | |
+ df2 = Seq(("a", 4.0)).toDF("c1", "c0") | |
+ checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) | |
+ | |
+ df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") | |
+ df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") | |
+ val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") | |
+ checkAnswer(df1.unionByName(df2.unionByName(df3)), | |
+ Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil | |
+ ) | |
+ } | |
} | |
test("union by name - check case sensitivity") { | |
@@ -802,7 +805,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { | |
StructType(Seq(StructField("topLevelCol", nestedStructType2)))) | |
val union = df1.unionByName(df2, allowMissingColumns = true) | |
- assert(union.schema.toDDL == "`topLevelCol` STRUCT<`b`: STRING, `a`: STRING>") | |
+ assert(union.schema.toDDL == "topLevelCol STRUCT<b: STRING, a: STRING>") | |
checkAnswer(union, Row(Row("b", null)) :: Row(Row("b", "a")) :: Nil) | |
} | |
@@ -834,15 +837,15 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { | |
StructType(Seq(StructField("topLevelCol", nestedStructType2)))) | |
var unionDf = df1.unionByName(df2, true) | |
- assert(unionDf.schema.toDDL == "`topLevelCol` " + | |
- "STRUCT<`b`: STRUCT<`ba`: STRING, `bb`: STRING>, `a`: STRUCT<`aa`: STRING>>") | |
+ assert(unionDf.schema.toDDL == "topLevelCol " + | |
+ "STRUCT<b: STRUCT<ba: STRING, bb: STRING>, a: STRUCT<aa: STRING>>") | |
checkAnswer(unionDf, | |
Row(Row(Row("ba", null), null)) :: | |
Row(Row(Row(null, "bb"), Row("aa"))) :: Nil) | |
unionDf = df2.unionByName(df1, true) | |
- assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRUCT<`aa`: STRING>, " + | |
- "`b`: STRUCT<`bb`: STRING, `ba`: STRING>>") | |
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<a: STRUCT<aa: STRING>, " + | |
+ "b: STRUCT<bb: STRING, ba: STRING>>") | |
checkAnswer(unionDf, | |
Row(Row(null, Row(null, "ba"))) :: | |
Row(Row(Row("aa"), Row("bb", null))) :: Nil) | |
@@ -999,8 +1002,9 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { | |
}.getMessage | |
assert(errMsg.contains("Union can only be performed on tables with" + | |
" the compatible column types." + | |
- " struct<c1:int,c2:int,c3:struct<c3:int,c5:int>> <> struct<c1:int,c2:int,c3:struct<c3:int>>" + | |
- " at the third column of the second table")) | |
+ " The third column of the second table is struct<c1:int,c2:int,c3:struct<c3:int,c5:int>>" + | |
+ " type which is not compatible with struct<c1:int,c2:int,c3:struct<c3:int>> at same" + | |
+ " column of first table")) | |
// diff Case sensitive attributes names and diff sequence scenario for unionByName | |
df1 = Seq((1, 2, UnionClass1d(1, 2, Struct3(1)))).toDF("a", "b", "c") | |
@@ -1039,26 +1043,355 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { | |
} | |
} | |
- test("SPARK-36673: Union of structs with different orders") { | |
+ test("SPARK-36797: Union should resolve nested columns as top-level columns") { | |
+ // Different nested field names, but same nested field types. Union resolves column by position. | |
val df1 = spark.range(2).withColumn("nested", | |
struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2")))) | |
val df2 = spark.range(2).withColumn("nested", | |
struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1")))) | |
- val err1 = intercept[AnalysisException](df1.union(df2).collect()) | |
- | |
- assert(err1.message | |
- .contains("Union can only be performed on tables with the compatible column types")) | |
+ checkAnswer(df1.union(df2), | |
+ Row(0, Row(0, Row(0))) :: Row(0, Row(0, Row(0))) :: Row(1, Row(5, Row(10))) :: | |
+ Row(1, Row(5, Row(10))) :: Nil) | |
- val df3 = spark.range(2).withColumn("nested", | |
- struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2").cast("string")))) | |
+ val df3 = spark.range(2).withColumn("nested array", | |
+ array(struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2"))))) | |
val df4 = spark.range(2).withColumn("nested", | |
+ array(struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1"))))) | |
+ | |
+ checkAnswer(df3.union(df4), | |
+ Row(0, Seq(Row(0, Row(0)))) :: Row(0, Seq(Row(0, Row(0)))) :: Row(1, Seq(Row(5, Row(10)))) :: | |
+ Row(1, Seq(Row(5, Row(10)))) :: Nil) | |
+ | |
+ val df5 = spark.range(2).withColumn("nested array", | |
+ map(struct(expr("id * 5 AS key1")), | |
+ struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2"))))) | |
+ val df6 = spark.range(2).withColumn("nested", | |
+ map(struct(expr("id * 5 AS key2")), | |
+ struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1"))))) | |
+ | |
+ checkAnswer(df5.union(df6), | |
+ Row(0, Map(Row(0) -> Row(0, Row(0)))) :: | |
+ Row(0, Map(Row(0) -> Row(0, Row(0)))) :: | |
+      Row(1, Map(Row(5) -> Row(5, Row(10)))) :: | |
+      Row(1, Map(Row(5) -> Row(5, Row(10)))) :: Nil) | |
+ | |
+ // Different nested field names, and different nested field types. | |
+ val df7 = spark.range(2).withColumn("nested", | |
+ struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2").cast("string")))) | |
+ val df8 = spark.range(2).withColumn("nested", | |
struct(expr("id * 5 AS inner2").cast("string"), struct(expr("id * 10 AS inner1")))) | |
- val err2 = intercept[AnalysisException](df3.union(df4).collect()) | |
- assert(err2.message | |
+ val err = intercept[AnalysisException](df7.union(df8).collect()) | |
+ assert(err.message | |
.contains("Union can only be performed on tables with the compatible column types")) | |
} | |
+ | |
+ test("SPARK-36546: Add unionByName support to arrays of structs") { | |
+ val arrayType1 = ArrayType( | |
+ StructType(Seq( | |
+ StructField("ba", StringType), | |
+ StructField("bb", StringType) | |
+ )) | |
+ ) | |
+ val arrayValues1 = Seq(Row("ba", "bb")) | |
+ | |
+ val arrayType2 = ArrayType( | |
+ StructType(Seq( | |
+ StructField("bb", StringType), | |
+ StructField("ba", StringType) | |
+ )) | |
+ ) | |
+ val arrayValues2 = Seq(Row("bb", "ba")) | |
+ | |
+ val df1 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(arrayValues1) :: Nil), | |
+ StructType(Seq(StructField("arr", arrayType1)))) | |
+ | |
+ val df2 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(arrayValues2) :: Nil), | |
+ StructType(Seq(StructField("arr", arrayType2)))) | |
+ | |
+ var unionDf = df1.unionByName(df2) | |
+ assert(unionDf.schema.toDDL == "arr ARRAY<STRUCT<ba: STRING, bb: STRING>>") | |
+ checkAnswer(unionDf, | |
+ Row(Seq(Row("ba", "bb"))) :: | |
+ Row(Seq(Row("ba", "bb"))) :: Nil) | |
+ | |
+ unionDf = df2.unionByName(df1) | |
+ assert(unionDf.schema.toDDL == "arr ARRAY<STRUCT<bb: STRING, ba: STRING>>") | |
+ checkAnswer(unionDf, | |
+ Row(Seq(Row("bb", "ba"))) :: | |
+ Row(Seq(Row("bb", "ba"))) :: Nil) | |
+ | |
+ val arrayType3 = ArrayType( | |
+ StructType(Seq( | |
+ StructField("ba", StringType) | |
+ )) | |
+ ) | |
+ val arrayValues3 = Seq(Row("ba")) | |
+ | |
+ val arrayType4 = ArrayType( | |
+ StructType(Seq( | |
+ StructField("bb", StringType) | |
+ )) | |
+ ) | |
+ val arrayValues4 = Seq(Row("bb")) | |
+ | |
+ val df3 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(arrayValues3) :: Nil), | |
+ StructType(Seq(StructField("arr", arrayType3)))) | |
+ | |
+ val df4 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(arrayValues4) :: Nil), | |
+ StructType(Seq(StructField("arr", arrayType4)))) | |
+ | |
+ assertThrows[AnalysisException] { | |
+ df3.unionByName(df4) | |
+ } | |
+ | |
+ unionDf = df3.unionByName(df4, true) | |
+ assert(unionDf.schema.toDDL == "arr ARRAY<STRUCT<ba: STRING, bb: STRING>>") | |
+ checkAnswer(unionDf, | |
+ Row(Seq(Row("ba", null))) :: | |
+ Row(Seq(Row(null, "bb"))) :: Nil) | |
+ | |
+ assertThrows[AnalysisException] { | |
+ df4.unionByName(df3) | |
+ } | |
+ | |
+ unionDf = df4.unionByName(df3, true) | |
+ assert(unionDf.schema.toDDL == "arr ARRAY<STRUCT<bb: STRING, ba: STRING>>") | |
+ checkAnswer(unionDf, | |
+ Row(Seq(Row("bb", null))) :: | |
+ Row(Seq(Row(null, "ba"))) :: Nil) | |
+ } | |
+ | |
+ test("SPARK-36546: Add unionByName support to nested arrays of structs") { | |
+ val nestedStructType1 = StructType(Seq( | |
+ StructField("b", ArrayType( | |
+ StructType(Seq( | |
+ StructField("ba", StringType), | |
+ StructField("bb", StringType) | |
+ )) | |
+ )) | |
+ )) | |
+ val nestedStructValues1 = Row(Seq(Row("ba", "bb"))) | |
+ | |
+ val nestedStructType2 = StructType(Seq( | |
+ StructField("b", ArrayType( | |
+ StructType(Seq( | |
+ StructField("bb", StringType), | |
+ StructField("ba", StringType) | |
+ )) | |
+ )) | |
+ )) | |
+ val nestedStructValues2 = Row(Seq(Row("bb", "ba"))) | |
+ | |
+ val df1 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(nestedStructValues1) :: Nil), | |
+ StructType(Seq(StructField("topLevelCol", nestedStructType1)))) | |
+ | |
+ val df2 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(nestedStructValues2) :: Nil), | |
+ StructType(Seq(StructField("topLevelCol", nestedStructType2)))) | |
+ | |
+ var unionDf = df1.unionByName(df2) | |
+ assert(unionDf.schema.toDDL == "topLevelCol " + | |
+ "STRUCT<b: ARRAY<STRUCT<ba: STRING, bb: STRING>>>") | |
+ checkAnswer(unionDf, | |
+ Row(Row(Seq(Row("ba", "bb")))) :: | |
+ Row(Row(Seq(Row("ba", "bb")))) :: Nil) | |
+ | |
+ unionDf = df2.unionByName(df1) | |
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" + | |
+ "b: ARRAY<STRUCT<bb: STRING, ba: STRING>>>") | |
+ checkAnswer(unionDf, | |
+ Row(Row(Seq(Row("bb", "ba")))) :: | |
+ Row(Row(Seq(Row("bb", "ba")))) :: Nil) | |
+ | |
+ val nestedStructType3 = StructType(Seq( | |
+ StructField("b", ArrayType( | |
+ StructType(Seq( | |
+ StructField("ba", StringType) | |
+ )) | |
+ )) | |
+ )) | |
+ val nestedStructValues3 = Row(Seq(Row("ba"))) | |
+ | |
+ val nestedStructType4 = StructType(Seq( | |
+ StructField("b", ArrayType( | |
+ StructType(Seq( | |
+ StructField("bb", StringType) | |
+ )) | |
+ )) | |
+ )) | |
+ val nestedStructValues4 = Row(Seq(Row("bb"))) | |
+ | |
+ val df3 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(nestedStructValues3) :: Nil), | |
+ StructType(Seq(StructField("topLevelCol", nestedStructType3)))) | |
+ | |
+ val df4 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(nestedStructValues4) :: Nil), | |
+ StructType(Seq(StructField("topLevelCol", nestedStructType4)))) | |
+ | |
+ assertThrows[AnalysisException] { | |
+ df3.unionByName(df4) | |
+ } | |
+ | |
+ unionDf = df3.unionByName(df4, true) | |
+ assert(unionDf.schema.toDDL == "topLevelCol " + | |
+ "STRUCT<b: ARRAY<STRUCT<ba: STRING, bb: STRING>>>") | |
+ checkAnswer(unionDf, | |
+ Row(Row(Seq(Row("ba", null)))) :: | |
+ Row(Row(Seq(Row(null, "bb")))) :: Nil) | |
+ | |
+ assertThrows[AnalysisException] { | |
+ df4.unionByName(df3) | |
+ } | |
+ | |
+ unionDf = df4.unionByName(df3, true) | |
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" + | |
+ "b: ARRAY<STRUCT<bb: STRING, ba: STRING>>>") | |
+ checkAnswer(unionDf, | |
+ Row(Row(Seq(Row("bb", null)))) :: | |
+ Row(Row(Seq(Row(null, "ba")))) :: Nil) | |
+ } | |
+ | |
+ test("SPARK-36546: Add unionByName support to multiple levels of nested arrays of structs") { | |
+ val nestedStructType1 = StructType(Seq( | |
+ StructField("b", ArrayType( | |
+ ArrayType( | |
+ StructType(Seq( | |
+ StructField("ba", StringType), | |
+ StructField("bb", StringType) | |
+ )) | |
+ ) | |
+ )) | |
+ )) | |
+ val nestedStructValues1 = Row(Seq(Seq(Row("ba", "bb")))) | |
+ | |
+ val nestedStructType2 = StructType(Seq( | |
+ StructField("b", ArrayType( | |
+ ArrayType( | |
+ StructType(Seq( | |
+ StructField("bb", StringType), | |
+ StructField("ba", StringType) | |
+ )) | |
+ ) | |
+ )) | |
+ )) | |
+ val nestedStructValues2 = Row(Seq(Seq(Row("bb", "ba")))) | |
+ | |
+ val df1 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(nestedStructValues1) :: Nil), | |
+ StructType(Seq(StructField("topLevelCol", nestedStructType1)))) | |
+ | |
+ val df2 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(nestedStructValues2) :: Nil), | |
+ StructType(Seq(StructField("topLevelCol", nestedStructType2)))) | |
+ | |
+ var unionDf = df1.unionByName(df2) | |
+ assert(unionDf.schema.toDDL == "topLevelCol " + | |
+ "STRUCT<b: ARRAY<ARRAY<STRUCT<ba: STRING, bb: STRING>>>>") | |
+ checkAnswer(unionDf, | |
+ Row(Row(Seq(Seq(Row("ba", "bb"))))) :: | |
+ Row(Row(Seq(Seq(Row("ba", "bb"))))) :: Nil) | |
+ | |
+ unionDf = df2.unionByName(df1) | |
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" + | |
+ "b: ARRAY<ARRAY<STRUCT<bb: STRING, ba: STRING>>>>") | |
+ checkAnswer(unionDf, | |
+ Row(Row(Seq(Seq(Row("bb", "ba"))))) :: | |
+ Row(Row(Seq(Seq(Row("bb", "ba"))))) :: Nil) | |
+ | |
+ val nestedStructType3 = StructType(Seq( | |
+ StructField("b", ArrayType( | |
+ ArrayType( | |
+ StructType(Seq( | |
+ StructField("ba", StringType) | |
+ )) | |
+ ) | |
+ )) | |
+ )) | |
+ val nestedStructValues3 = Row(Seq(Seq(Row("ba")))) | |
+ | |
+ val nestedStructType4 = StructType(Seq( | |
+ StructField("b", ArrayType( | |
+ ArrayType( | |
+ StructType(Seq( | |
+ StructField("bb", StringType) | |
+ )) | |
+ ) | |
+ )) | |
+ )) | |
+ val nestedStructValues4 = Row(Seq(Seq(Row("bb")))) | |
+ | |
+ val df3 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(nestedStructValues3) :: Nil), | |
+ StructType(Seq(StructField("topLevelCol", nestedStructType3)))) | |
+ | |
+ val df4 = spark.createDataFrame( | |
+ sparkContext.parallelize(Row(nestedStructValues4) :: Nil), | |
+ StructType(Seq(StructField("topLevelCol", nestedStructType4)))) | |
+ | |
+ assertThrows[AnalysisException] { | |
+ df3.unionByName(df4) | |
+ } | |
+ | |
+ unionDf = df3.unionByName(df4, true) | |
+ assert(unionDf.schema.toDDL == "topLevelCol " + | |
+ "STRUCT<b: ARRAY<ARRAY<STRUCT<ba: STRING, bb: STRING>>>>") | |
+ checkAnswer(unionDf, | |
+ Row(Row(Seq(Seq(Row("ba", null))))) :: | |
+ Row(Row(Seq(Seq(Row(null, "bb"))))) :: Nil) | |
+ | |
+ assertThrows[AnalysisException] { | |
+ df4.unionByName(df3) | |
+ } | |
+ | |
+ unionDf = df4.unionByName(df3, true) | |
+ assert(unionDf.schema.toDDL == "topLevelCol STRUCT<" + | |
+ "b: ARRAY<ARRAY<STRUCT<bb: STRING, ba: STRING>>>>") | |
+ checkAnswer(unionDf, | |
+ Row(Row(Seq(Seq(Row("bb", null))))) :: | |
+ Row(Row(Seq(Seq(Row(null, "ba"))))) :: Nil) | |
+ } | |
+ | |
+ test("SPARK-37371: UnionExec should support columnar if all children support columnar") { | |
+ def checkIfColumnar( | |
+ plan: SparkPlan, | |
+ targetPlan: (SparkPlan) => Boolean, | |
+ isColumnar: Boolean): Unit = { | |
+ val target = plan.collect { | |
+ case p if targetPlan(p) => p | |
+ } | |
+ assert(target.nonEmpty) | |
+ assert(target.forall(_.supportsColumnar == isColumnar)) | |
+ } | |
+ | |
+ Seq(true, false).foreach { supported => | |
+ withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> supported.toString) { | |
+ val df1 = Seq(1, 2, 3).toDF("i").cache() | |
+ val df2 = Seq(4, 5, 6).toDF("j").cache() | |
+ | |
+ val union = df1.union(df2) | |
+        checkIfColumnar(union.queryExecution.executedPlan, | |
+          _.isInstanceOf[InMemoryTableScanExec], supported) | |
+ checkIfColumnar(union.queryExecution.executedPlan, _.isInstanceOf[UnionExec], supported) | |
+ checkAnswer(union, Row(1) :: Row(2) :: Row(3) :: Row(4) :: Row(5) :: Row(6) :: Nil) | |
+ | |
+ val nonColumnarUnion = df1.union(Seq(7, 8, 9).toDF("k")) | |
+ checkIfColumnar(nonColumnarUnion.queryExecution.executedPlan, | |
+ _.isInstanceOf[UnionExec], false) | |
+ checkAnswer(nonColumnarUnion, | |
+ Row(1) :: Row(2) :: Row(3) :: Row(7) :: Row(8) :: Row(9) :: Nil) | |
+ } | |
+ } | |
+ } | |
} | |
case class UnionClass1a(a: Int, b: Long, nested: UnionClass2) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | |
index d9b75c7794..a696c3fd49 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | |
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window} | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} | |
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2} | |
+import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2} | |
import org.apache.spark.sql.types._ | |
import org.apache.spark.unsafe.types.CalendarInterval | |
import org.apache.spark.util.Utils | |
@@ -86,7 +86,9 @@ class DataFrameSuite extends QueryTest | |
test("access complex data") { | |
assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1) | |
- assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) | |
+ if (!conf.ansiEnabled) { | |
+ assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) | |
+ } | |
assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1) | |
} | |
@@ -480,7 +482,7 @@ class DataFrameSuite extends QueryTest | |
testData.select("key").coalesce(1).select("key"), | |
testData.select("key").collect().toSeq) | |
- assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 0) | |
+ assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) | |
} | |
test("convert $\"attribute name\" into unresolved attribute") { | |
@@ -631,7 +633,19 @@ class DataFrameSuite extends QueryTest | |
assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) | |
} | |
- test("withColumns") { | |
+ test("withColumns: public API, with Map input") { | |
+ val df = testData.toDF().withColumns(Map( | |
+ "newCol1" -> (col("key") + 1), "newCol2" -> (col("key") + 2) | |
+ )) | |
+ checkAnswer( | |
+ df, | |
+ testData.collect().map { case Row(key: Int, value: String) => | |
+ Row(key, value, key + 1, key + 2) | |
+ }.toSeq) | |
+ assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2")) | |
+ } | |
+ | |
+ test("withColumns: internal method") { | |
val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"), | |
Seq(col("key") + 1, col("key") + 2)) | |
checkAnswer( | |
@@ -655,7 +669,7 @@ class DataFrameSuite extends QueryTest | |
assert(err2.getMessage.contains("Found duplicate column(s)")) | |
} | |
- test("withColumns: case sensitive") { | |
+ test("withColumns: internal method, case sensitive") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { | |
val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"), | |
Seq(col("key") + 1, col("key") + 2)) | |
@@ -674,7 +688,7 @@ class DataFrameSuite extends QueryTest | |
} | |
} | |
- test("withColumns: given metadata") { | |
+ test("withColumns: internal method, given metadata") { | |
def buildMetadata(num: Int): Seq[Metadata] = { | |
(0 until num).map { n => | |
val builder = new MetadataBuilder | |
@@ -702,6 +716,18 @@ class DataFrameSuite extends QueryTest | |
"The size of column names: 2 isn't equal to the size of metadata elements: 1")) | |
} | |
+ test("SPARK-36642: withMetadata: replace metadata of a column") { | |
+ val metadata = new MetadataBuilder().putLong("key", 1L).build() | |
+ val df1 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") | |
+ val df2 = df1.withMetadata("x", metadata) | |
+ assert(df2.schema(0).metadata === metadata) | |
+ | |
+ val err = intercept[AnalysisException] { | |
+ df1.withMetadata("x1", metadata) | |
+ } | |
+ assert(err.getMessage.contains("Cannot resolve column name")) | |
+ } | |
+ | |
test("replace column using withColumn") { | |
val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") | |
val df3 = df2.withColumn("x", df2("x") + 1) | |
@@ -834,6 +860,56 @@ class DataFrameSuite extends QueryTest | |
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) | |
} | |
+ test("SPARK-20384: Value class filter") { | |
+ val df = spark.sparkContext | |
+ .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c"))) | |
+ .toDF() | |
+ val filtered = df.where("s = \"a\"") | |
+ checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF) | |
+ } | |
+ | |
+ test("SPARK-20384: Tuple2 of value class filter") { | |
+ val df = spark.sparkContext | |
+ .parallelize(Seq( | |
+ (StringWrapper("a1"), StringWrapper("a2")), | |
+ (StringWrapper("b1"), StringWrapper("b2")))) | |
+ .toDF() | |
+ val filtered = df.where("_2.s = \"a2\"") | |
+ checkAnswer(filtered, | |
+ spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF) | |
+ } | |
+ | |
+ test("SPARK-20384: Tuple3 of value class filter") { | |
+ val df = spark.sparkContext | |
+ .parallelize(Seq( | |
+ (StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")), | |
+ (StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3")))) | |
+ .toDF() | |
+ val filtered = df.where("_3.s = \"a3\"") | |
+ checkAnswer(filtered, | |
+ spark.sparkContext.parallelize( | |
+ Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF) | |
+ } | |
+ | |
+ test("SPARK-20384: Array value class filter") { | |
+ val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b"))) | |
+ val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d"))) | |
+ | |
+ val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF | |
+ val filtered = df.where(array_contains(col("wrappers.s"), "b")) | |
+ checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF) | |
+ } | |
+ | |
+ test("SPARK-20384: Nested value class filter") { | |
+ val a = ContainerStringWrapper(StringWrapper("a")) | |
+ val b = ContainerStringWrapper(StringWrapper("b")) | |
+ | |
+ val df = spark.sparkContext.parallelize(Seq(a, b)).toDF | |
+ // flat value class, `s` field is not in schema | |
+ val filtered = df.where("wrapper = \"a\"") | |
+ checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF) | |
+ } | |
+ | |
private lazy val person2: DataFrame = Seq( | |
("Bob", 16, 176), | |
("Alice", 32, 164), | |
@@ -1489,7 +1565,9 @@ class DataFrameSuite extends QueryTest | |
test("SPARK-7133: Implement struct, array, and map field accessor") { | |
assert(complexData.filter(complexData("a")(0) === 2).count() == 1) | |
- assert(complexData.filter(complexData("m")("1") === 1).count() == 1) | |
+ if (!conf.ansiEnabled) { | |
+ assert(complexData.filter(complexData("m")("1") === 1).count() == 1) | |
+ } | |
assert(complexData.filter(complexData("s")("key") === 1).count() == 1) | |
assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1) | |
assert(complexData.filter(complexData("a")(complexData("s")("key")) === 1).count() == 1) | |
@@ -2384,8 +2462,10 @@ class DataFrameSuite extends QueryTest | |
val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) | |
checkAnswer(aggPlusSort1, aggPlusSort2.collect()) | |
- val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) | |
- val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) | |
+ val aggPlusFilter1 = | |
+ df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === "test1") | |
+ val aggPlusFilter2 = | |
+ df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === "test1") | |
checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) | |
} | |
} | |
@@ -2569,7 +2649,8 @@ class DataFrameSuite extends QueryTest | |
val err = intercept[AnalysisException] { | |
df.groupBy($"d", $"b").as[GroupByKey, Row] | |
} | |
- assert(err.getMessage.contains("cannot resolve 'd'")) | |
+ assert(err.getErrorClass == "MISSING_COLUMN") | |
+ assert(err.messageParameters.head == "d") | |
} | |
test("emptyDataFrame should be foldable") { | |
@@ -2852,6 +2933,25 @@ class DataFrameSuite extends QueryTest | |
checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil) | |
} | |
+ test("SPARK-39293: The accumulator of ArrayAggregate to handle complex types properly") { | |
+ val reverse = udf((s: String) => s.reverse) | |
+ | |
+ val df = Seq(Array("abc", "def")).toDF("array") | |
+ val testArray = df.select( | |
+ aggregate( | |
+ col("array"), | |
+ array().cast("array<string>"), | |
+ (acc, s) => concat(acc, array(reverse(s))))) | |
+ checkAnswer(testArray, Row(Array("cba", "fed")) :: Nil) | |
+ | |
+ val testMap = df.select( | |
+ aggregate( | |
+ col("array"), | |
+ map().cast("map<string, string>"), | |
+ (acc, s) => map_concat(acc, map(s, reverse(s))))) | |
+ checkAnswer(testMap, Row(Map("abc" -> "cba", "def" -> "fed")) :: Nil) | |
+ } | |
+ | |
test("SPARK-34882: Aggregate with multiple distinct null sensitive aggregators") { | |
withUserDefinedFunction(("countNulls", true)) { | |
spark.udf.register("countNulls", udaf(new Aggregator[JLong, JLong, JLong] { | |
@@ -2952,6 +3052,33 @@ class DataFrameSuite extends QueryTest | |
assert(ids.toSet === Range(0, 10).toSet) | |
} | |
+ test("SPARK-35320: Reading JSON with key type different to String in a map should fail") { | |
+ Seq( | |
+ (MapType(IntegerType, StringType), """{"1": "test"}"""), | |
+ (StructType(Seq(StructField("test", MapType(IntegerType, StringType)))), | |
+ """"test": {"1": "test"}"""), | |
+ (ArrayType(MapType(IntegerType, StringType)), """[{"1": "test"}]"""), | |
+ (MapType(StringType, MapType(IntegerType, StringType)), """{"key": {"1" : "test"}}""") | |
+ ).foreach { case (schema, jsonData) => | |
+ withTempDir { dir => | |
+ val colName = "col" | |
+ val msg = "can only contain STRING as a key type for a MAP" | |
+ | |
+ val thrown1 = intercept[AnalysisException]( | |
+ spark.read.schema(StructType(Seq(StructField(colName, schema)))) | |
+ .json(Seq(jsonData).toDS()).collect()) | |
+ assert(thrown1.getMessage.contains(msg)) | |
+ | |
+ val jsonDir = new File(dir, "json").getCanonicalPath | |
+ Seq(jsonData).toDF(colName).write.json(jsonDir) | |
+ val thrown2 = intercept[AnalysisException]( | |
+ spark.read.schema(StructType(Seq(StructField(colName, schema)))) | |
+ .json(jsonDir).collect()) | |
+ assert(thrown2.getMessage.contains(msg)) | |
+ } | |
+ } | |
+ } | |
+ | |
test("SPARK-37855: IllegalStateException when transforming an array inside a nested struct") { | |
def makeInput(): DataFrame = { | |
val innerElement1 = Row(3, 3.12) | |
@@ -3088,6 +3215,79 @@ class DataFrameSuite extends QueryTest | |
} | |
} | |
} | |
+ | |
+ test("SPARK-39612: exceptAll with following count should work") { | |
+ val d1 = Seq("a").toDF | |
+ assert(d1.exceptAll(d1).count() === 0) | |
+ } | |
+ | |
+ test("SPARK-39887: RemoveRedundantAliases should keep attributes of a Union's first child") { | |
+ val df = sql( | |
+ """ | |
+ |SELECT a, b AS a FROM ( | |
+ | SELECT a, a AS b FROM (SELECT a FROM VALUES (1) AS t(a)) | |
+ | UNION ALL | |
+ | SELECT a, b FROM (SELECT a, b FROM VALUES (1, 2) AS t(a, b)) | |
+ |) | |
+ |""".stripMargin) | |
+ val stringCols = df.logicalPlan.output.map(Column(_).cast(StringType)) | |
+ val castedDf = df.select(stringCols: _*) | |
+ checkAnswer(castedDf, Row("1", "1") :: Row("1", "2") :: Nil) | |
+ } | |
+ | |
+ test("SPARK-39887: RemoveRedundantAliases should keep attributes of a Union's first child 2") { | |
+ val df = sql( | |
+ """ | |
+ |SELECT | |
+ | to_date(a) a, | |
+ | to_date(b) b | |
+ |FROM | |
+ | ( | |
+ | SELECT | |
+ | a, | |
+ | a AS b | |
+ | FROM | |
+ | ( | |
+ | SELECT | |
+ | to_date(a) a | |
+ | FROM | |
+ | VALUES | |
+ | ('2020-02-01') AS t1(a) | |
+ | GROUP BY | |
+ | to_date(a) | |
+ | ) t3 | |
+ | UNION ALL | |
+ | SELECT | |
+ | a, | |
+ | b | |
+ | FROM | |
+ | ( | |
+ | SELECT | |
+ | to_date(a) a, | |
+ | to_date(b) b | |
+ | FROM | |
+ | VALUES | |
+ | ('2020-01-01', '2020-01-02') AS t1(a, b) | |
+ | GROUP BY | |
+ | to_date(a), | |
+ | to_date(b) | |
+ | ) t4 | |
+ | ) t5 | |
+ |GROUP BY | |
+ | to_date(a), | |
+ | to_date(b); | |
+ |""".stripMargin) | |
+ checkAnswer(df, | |
+ Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) :: | |
+ Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil) | |
+ } | |
+ | |
+ test("SPARK-39915: Dataset.repartition(N) may not create N partitions") { | |
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { | |
+ val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2) | |
+ assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2) | |
+ } | |
+ } | |
} | |
case class GroupByKey(a: Int, b: Int) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala | |
index c385d9f58c..bd39453f51 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala | |
@@ -19,8 +19,9 @@ package org.apache.spark.sql | |
import java.time.LocalDateTime | |
+import org.apache.spark.sql.catalyst.encoders.RowEncoder | |
import org.apache.spark.sql.catalyst.expressions.AttributeReference | |
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand} | |
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Filter} | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.sql.types._ | |
@@ -490,4 +491,88 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { | |
assert(attributeReference.dataType == tuple._2) | |
} | |
} | |
+ | |
+ test("No need to filter windows when windowDuration is multiple of slideDuration") { | |
+ val df1 = Seq( | |
+ ("2022-02-15 19:39:34", 1, "a"), | |
+ ("2022-02-15 19:39:56", 2, "a"), | |
+ ("2022-02-15 19:39:27", 4, "b")).toDF("time", "value", "id") | |
+ .select(window($"time", "9 seconds", "3 seconds", "0 second"), $"value") | |
+ .orderBy($"window.start".asc, $"value".desc).select("value") | |
+ val df2 = Seq( | |
+ (LocalDateTime.parse("2022-02-15T19:39:34"), 1, "a"), | |
+ (LocalDateTime.parse("2022-02-15T19:39:56"), 2, "a"), | |
+ (LocalDateTime.parse("2022-02-15T19:39:27"), 4, "b")).toDF("time", "value", "id") | |
+ .select(window($"time", "9 seconds", "3 seconds", "0 second"), $"value") | |
+ .orderBy($"window.start".asc, $"value".desc).select("value") | |
+ | |
+ val df3 = Seq( | |
+ ("2022-02-15 19:39:34", 1, "a"), | |
+ ("2022-02-15 19:39:56", 2, "a"), | |
+ ("2022-02-15 19:39:27", 4, "b")).toDF("time", "value", "id") | |
+ .select(window($"time", "9 seconds", "3 seconds", "-2 second"), $"value") | |
+ .orderBy($"window.start".asc, $"value".desc).select("value") | |
+ val df4 = Seq( | |
+ (LocalDateTime.parse("2022-02-15T19:39:34"), 1, "a"), | |
+ (LocalDateTime.parse("2022-02-15T19:39:56"), 2, "a"), | |
+ (LocalDateTime.parse("2022-02-15T19:39:27"), 4, "b")).toDF("time", "value", "id") | |
+ .select(window($"time", "9 seconds", "3 seconds", "2 second"), $"value") | |
+ .orderBy($"window.start".asc, $"value".desc).select("value") | |
+ | |
+ Seq(df1, df2, df3, df4).foreach { df => | |
+ val filter = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Filter]) | |
+ assert(filter.isDefined) | |
+ val exist = filter.get.constraints.filter(e => | |
+ e.toString.contains(">=") || e.toString.contains("<")) | |
+ assert(exist.isEmpty, "No need to filter windows " + | |
+ "when windowDuration is multiple of slideDuration") | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38227: 'start' and 'end' fields should be nullable") { | |
+ // We expect the fields in window struct as nullable since the dataType of TimeWindow defines | |
+ // them as nullable. The rule 'TimeWindowing' should respect the dataType. | |
+ val df1 = Seq( | |
+ ("2016-03-27 09:00:05", 1), | |
+ ("2016-03-27 09:00:32", 2)).toDF("time", "value") | |
+ val df2 = Seq( | |
+ (LocalDateTime.parse("2016-03-27T09:00:05"), 1), | |
+ (LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("time", "value") | |
+ | |
+ def validateWindowColumnInSchema(schema: StructType, colName: String): Unit = { | |
+ schema.find(_.name == colName) match { | |
+ case Some(StructField(_, st: StructType, _, _)) => | |
+ assertFieldInWindowStruct(st, "start") | |
+ assertFieldInWindowStruct(st, "end") | |
+ | |
+ case _ => fail("Failed to find suitable window column from DataFrame!") | |
+ } | |
+ } | |
+ | |
+ def assertFieldInWindowStruct(windowType: StructType, fieldName: String): Unit = { | |
+ val field = windowType.fields.find(_.name == fieldName) | |
+ assert(field.isDefined, s"'$fieldName' field should exist in window struct") | |
+ assert(field.get.nullable, s"'$fieldName' field should be nullable") | |
+ } | |
+ | |
+ for { | |
+ df <- Seq(df1, df2) | |
+ nullable <- Seq(true, false) | |
+ } { | |
+ val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder( | |
+ StructType(df.schema.fields.map(_.copy(nullable = nullable))))) | |
+ // tumbling windows | |
+ val windowedProject = dfWithDesiredNullability | |
+ .select(window($"time", "10 seconds").as("window"), $"value") | |
+ val schema = windowedProject.queryExecution.optimizedPlan.schema | |
+ validateWindowColumnInSchema(schema, "window") | |
+ | |
+ // sliding windows | |
+ val windowedProject2 = dfWithDesiredNullability | |
+ .select(window($"time", "10 seconds", "3 seconds").as("window"), | |
+ $"value") | |
+ val schema2 = windowedProject2.queryExecution.optimizedPlan.schema | |
+ validateWindowColumnInSchema(schema2, "window") | |
+ } | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala | |
index 666bf739ca..e57650ff62 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala | |
@@ -20,9 +20,12 @@ package org.apache.spark.sql | |
import org.scalatest.matchers.must.Matchers.the | |
import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} | |
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} | |
import org.apache.spark.sql.catalyst.optimizer.TransposeWindow | |
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning | |
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
-import org.apache.spark.sql.execution.exchange.Exchange | |
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec} | |
+import org.apache.spark.sql.execution.window.WindowExec | |
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window} | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.internal.SQLConf | |
@@ -94,7 +97,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest | |
} | |
test("corr, covar_pop, stddev_pop functions in specific window") { | |
- withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") { | |
+ withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true", | |
+ SQLConf.ANSI_ENABLED.key -> "false") { | |
val df = Seq( | |
("a", "p1", 10.0, 20.0), | |
("b", "p1", 20.0, 10.0), | |
@@ -147,7 +151,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest | |
test("SPARK-13860: " + | |
"corr, covar_pop, stddev_pop functions in specific window " + | |
"LEGACY_STATISTICAL_AGGREGATE off") { | |
- withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false") { | |
+ withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false", | |
+ SQLConf.ANSI_ENABLED.key -> "false") { | |
val df = Seq( | |
("a", "p1", 10.0, 20.0), | |
("b", "p1", 20.0, 10.0), | |
@@ -399,26 +404,29 @@ class DataFrameWindowFunctionsSuite extends QueryTest | |
val df = Seq((1, "1")).toDF("key", "value") | |
val e = intercept[AnalysisException]( | |
df.select($"key", count("invalid").over())) | |
- assert(e.message.contains("cannot resolve 'invalid' given input columns: [key, value]")) | |
+ assert(e.getErrorClass == "MISSING_COLUMN") | |
+ assert(e.messageParameters.sameElements(Array("invalid", "value, key"))) | |
} | |
test("numerical aggregate functions on string column") { | |
- val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") | |
- checkAnswer( | |
- df.select($"key", | |
- var_pop("value1").over(), | |
- variance("value1").over(), | |
- stddev_pop("value1").over(), | |
- stddev("value1").over(), | |
- sum("value1").over(), | |
- mean("value1").over(), | |
- avg("value1").over(), | |
- corr("value1", "value2").over(), | |
- covar_pop("value1", "value2").over(), | |
- covar_samp("value1", "value2").over(), | |
- skewness("value1").over(), | |
- kurtosis("value1").over()), | |
- Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) | |
+ if (!conf.ansiEnabled) { | |
+ val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") | |
+ checkAnswer( | |
+ df.select($"key", | |
+ var_pop("value1").over(), | |
+ variance("value1").over(), | |
+ stddev_pop("value1").over(), | |
+ stddev("value1").over(), | |
+ sum("value1").over(), | |
+ mean("value1").over(), | |
+ avg("value1").over(), | |
+ corr("value1", "value2").over(), | |
+ covar_pop("value1", "value2").over(), | |
+ covar_samp("value1", "value2").over(), | |
+ skewness("value1").over(), | |
+ kurtosis("value1").over()), | |
+ Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) | |
+ } | |
} | |
test("statistical functions") { | |
@@ -702,6 +710,50 @@ class DataFrameWindowFunctionsSuite extends QueryTest | |
Row("a", 4, "x", "x", "y", "x", "x", "y"), | |
Row("b", 1, null, null, null, null, null, null), | |
Row("b", 2, null, null, null, null, null, null))) | |
+ | |
+ val df2 = Seq( | |
+ ("a", 1, "x"), | |
+ ("a", 2, "y"), | |
+ ("a", 3, "z")). | |
+ toDF("key", "order", "value") | |
+ checkAnswer( | |
+ df2.select( | |
+ $"key", | |
+ $"order", | |
+ nth_value($"value", 2).over(window1), | |
+ nth_value($"value", 2, ignoreNulls = true).over(window1), | |
+ nth_value($"value", 2).over(window2), | |
+ nth_value($"value", 2, ignoreNulls = true).over(window2), | |
+ nth_value($"value", 3).over(window1), | |
+ nth_value($"value", 3, ignoreNulls = true).over(window1), | |
+ nth_value($"value", 3).over(window2), | |
+ nth_value($"value", 3, ignoreNulls = true).over(window2), | |
+ nth_value($"value", 4).over(window1), | |
+ nth_value($"value", 4, ignoreNulls = true).over(window1), | |
+ nth_value($"value", 4).over(window2), | |
+ nth_value($"value", 4, ignoreNulls = true).over(window2)), | |
+ Seq( | |
+ Row("a", 1, "y", "y", null, null, "z", "z", null, null, null, null, null, null), | |
+ Row("a", 2, "y", "y", "y", "y", "z", "z", null, null, null, null, null, null), | |
+ Row("a", 3, "y", "y", "y", "y", "z", "z", "z", "z", null, null, null, null))) | |
+ | |
+ val df3 = Seq( | |
+ ("a", 1, "x"), | |
+ ("a", 2, nullStr), | |
+ ("a", 3, "z")). | |
+ toDF("key", "order", "value") | |
+ checkAnswer( | |
+ df3.select( | |
+ $"key", | |
+ $"order", | |
+ nth_value($"value", 3).over(window1), | |
+ nth_value($"value", 3, ignoreNulls = true).over(window1), | |
+ nth_value($"value", 3).over(window2), | |
+ nth_value($"value", 3, ignoreNulls = true).over(window2)), | |
+ Seq( | |
+ Row("a", 1, "z", null, null, null), | |
+ Row("a", 2, "z", null, null, null), | |
+ Row("a", 3, "z", null, "z", null))) | |
} | |
test("nth_value on descending ordered window") { | |
@@ -1070,4 +1122,98 @@ class DataFrameWindowFunctionsSuite extends QueryTest | |
Row("a", 1, "x", "x"), | |
Row("b", 0, null, null))) | |
} | |
+ | |
+ test("SPARK-38237: require all cluster keys for child required distribution for window query") { | |
+ def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = { | |
+ expressions.flatMap { | |
+ case ref: AttributeReference => Some(ref.name) | |
+ } | |
+ } | |
+ | |
+ def isShuffleExecByRequirement( | |
+ plan: ShuffleExchangeExec, | |
+ desiredClusterColumns: Seq[String]): Boolean = plan match { | |
+ case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS) => | |
+ partitionExpressionsColumns(op.expressions) === desiredClusterColumns | |
+ case _ => false | |
+ } | |
+ | |
+ val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value") | |
+ val windowSpec = Window.partitionBy("key1", "key2").orderBy("value") | |
+ | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", | |
+ SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") { | |
+ | |
+ val windowed = df | |
+ // repartition by subset of window partitionBy keys which satisfies ClusteredDistribution | |
+ .repartition($"key1") | |
+ .select( | |
+ lead($"key1", 1).over(windowSpec), | |
+ lead($"value", 1).over(windowSpec)) | |
+ | |
+ checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), Row(null, null))) | |
+ | |
+ val shuffleByRequirement = windowed.queryExecution.executedPlan.exists { | |
+ case w: WindowExec => | |
+ w.child.exists { | |
+ case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2")) | |
+ case _ => false | |
+ } | |
+ case _ => false | |
+ } | |
+ | |
+ assert(shuffleByRequirement, "Can't find desired shuffle node from the query plan") | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38308: Properly handle Stream of window expressions") { | |
+ val df = Seq( | |
+ (1, 2, 3), | |
+ (1, 3, 4), | |
+ (2, 4, 5), | |
+ (2, 5, 6) | |
+ ).toDF("a", "b", "c") | |
+ | |
+ val w = Window.partitionBy("a").orderBy("b") | |
+ val selectExprs = Stream( | |
+ sum("c").over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow)).as("sumc"), | |
+ avg("c").over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow)).as("avgc") | |
+ ) | |
+ checkAnswer( | |
+ df.select(selectExprs: _*), | |
+ Seq( | |
+ Row(3, 3), | |
+ Row(7, 3.5), | |
+ Row(5, 5), | |
+ Row(11, 5.5) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("SPARK-38614: percent_rank should apply before limit") { | |
+ val df = Seq.tabulate(101)(identity).toDF("id") | |
+ val w = Window.orderBy("id") | |
+ checkAnswer( | |
+ df.select($"id", percent_rank().over(w)).limit(3), | |
+ Seq( | |
+ Row(0, 0.0d), | |
+ Row(1, 0.01d), | |
+ Row(2, 0.02d) | |
+ ) | |
+ ) | |
+ } | |
+ | |
+ test("SPARK-40002: ntile should apply before limit") { | |
+ val df = Seq.tabulate(101)(identity).toDF("id") | |
+ val w = Window.orderBy("id") | |
+ checkAnswer( | |
+ df.select($"id", ntile(10).over(w)).limit(3), | |
+ Seq( | |
+ Row(0, 1), | |
+ Row(1, 1), | |
+ Row(2, 1) | |
+ ) | |
+ ) | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala | |
index 8aef27a1b6..86108a81da 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala | |
@@ -23,12 +23,15 @@ import scala.collection.JavaConverters._ | |
import org.scalatest.BeforeAndAfter | |
+import org.apache.spark.sql.catalyst.TableIdentifier | |
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, TableAlreadyExistsException} | |
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} | |
+import org.apache.spark.sql.connector.InMemoryV1Provider | |
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, InMemoryTableCatalog, TableCatalog} | |
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} | |
import org.apache.spark.sql.execution.QueryExecution | |
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation | |
+import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.sources.FakeSourceOne | |
import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} | |
@@ -531,6 +534,23 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo | |
assert(table.properties === (Map("provider" -> "foo") ++ defaultOwnership).asJava) | |
} | |
+ test("SPARK-39543 writeOption should be passed to storage properties when fallback to v1") { | |
+ val provider = classOf[InMemoryV1Provider].getName | |
+ | |
+ withSQLConf((SQLConf.USE_V1_SOURCE_LIST.key, provider)) { | |
+ spark.range(10) | |
+ .writeTo("table_name") | |
+ .option("compression", "zstd").option("name", "table_name") | |
+ .using(provider) | |
+ .create() | |
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("table_name")) | |
+ | |
+ assert(table.identifier === TableIdentifier("table_name", Some("default"))) | |
+ assert(table.storage.properties.contains("compression")) | |
+ assert(table.storage.properties.getOrElse("compression", "foo") == "zstd") | |
+ } | |
+ } | |
+ | |
test("Replace: basic behavior") { | |
spark.sql( | |
"CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala | |
index 009ccb9a45..2f4098d7cc 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala | |
@@ -250,7 +250,7 @@ class DatasetCacheSuite extends QueryTest | |
case i: InMemoryRelation => i.cacheBuilder.cachedPlan | |
} | |
assert(df2LimitInnerPlan.isDefined && | |
- df2LimitInnerPlan.get.find(_.isInstanceOf[InMemoryTableScanExec]).isEmpty) | |
+ !df2LimitInnerPlan.get.exists(_.isInstanceOf[InMemoryTableScanExec])) | |
} | |
test("SPARK-27739 Save stats from optimized plan") { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | |
index 347e9fc08a..f5e736621e 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | |
@@ -20,11 +20,14 @@ package org.apache.spark.sql | |
import java.io.{Externalizable, ObjectInput, ObjectOutput} | |
import java.sql.{Date, Timestamp} | |
+import org.apache.hadoop.fs.{Path, PathFilter} | |
import org.scalatest.Assertions._ | |
import org.scalatest.exceptions.TestFailedException | |
import org.scalatest.prop.TableDrivenPropertyChecks._ | |
import org.apache.spark.{SparkException, TaskContext} | |
+import org.apache.spark.TestUtils.withListener | |
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} | |
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample} | |
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} | |
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} | |
@@ -322,23 +325,25 @@ class DatasetSuite extends QueryTest | |
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { | |
var e = intercept[AnalysisException] { | |
ds.select(expr("`(_1)?+.+`").as[Int]) | |
- }.getMessage | |
- assert(e.contains("cannot resolve '`(_1)?+.+`'")) | |
+ } | |
+ assert(e.getErrorClass == "MISSING_COLUMN") | |
+ assert(e.messageParameters.head == "`(_1)?+.+`") | |
e = intercept[AnalysisException] { | |
ds.select(expr("`(_1|_2)`").as[Int]) | |
- }.getMessage | |
- assert(e.contains("cannot resolve '`(_1|_2)`'")) | |
+ } | |
+ assert(e.getErrorClass == "MISSING_COLUMN") | |
+ assert(e.messageParameters.head == "`(_1|_2)`") | |
e = intercept[AnalysisException] { | |
ds.select(ds("`(_1)?+.+`")) | |
- }.getMessage | |
- assert(e.contains("Cannot resolve column name \"`(_1)?+.+`\"")) | |
+ } | |
+ assert(e.getMessage.contains("Cannot resolve column name \"`(_1)?+.+`\"")) | |
e = intercept[AnalysisException] { | |
ds.select(ds("`(_1|_2)`")) | |
- }.getMessage | |
- assert(e.contains("Cannot resolve column name \"`(_1|_2)`\"")) | |
+ } | |
+ assert(e.getMessage.contains("Cannot resolve column name \"`(_1|_2)`\"")) | |
} | |
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") { | |
@@ -703,6 +708,72 @@ class DatasetSuite extends QueryTest | |
1 -> "a", 2 -> "bc", 3 -> "d") | |
} | |
+ test("SPARK-34806: observation on datasets") { | |
+ val namedObservation = Observation("named") | |
+ val unnamedObservation = Observation() | |
+ | |
+ val df = spark.range(100) | |
+ val observed_df = df | |
+ .observe( | |
+ namedObservation, | |
+ min($"id").as("min_val"), | |
+ max($"id").as("max_val"), | |
+ sum($"id").as("sum_val"), | |
+ count(when($"id" % 2 === 0, 1)).as("num_even") | |
+ ) | |
+ .observe( | |
+ unnamedObservation, | |
+ avg($"id").cast("int").as("avg_val") | |
+ ) | |
+ | |
+ def checkMetrics(namedMetric: Observation, unnamedMetric: Observation): Unit = { | |
+ assert(namedMetric.get === Map( | |
+ "min_val" -> 0L, "max_val" -> 99L, "sum_val" -> 4950L, "num_even" -> 50L) | |
+ ) | |
+ assert(unnamedMetric.get === Map("avg_val" -> 49)) | |
+ } | |
+ | |
+ observed_df.collect() | |
+ // we can get the result multiple times | |
+ checkMetrics(namedObservation, unnamedObservation) | |
+ checkMetrics(namedObservation, unnamedObservation) | |
+ | |
+ // an observation can be used only once | |
+ val err = intercept[IllegalArgumentException] { | |
+ df.observe(namedObservation, sum($"id").as("sum_val")) | |
+ } | |
+ assert(err.getMessage.contains("An Observation can be used with a Dataset only once")) | |
+ | |
+ // streaming datasets are not supported | |
+ val streamDf = new MemoryStream[Int](0, sqlContext).toDF() | |
+ val streamObservation = Observation("stream") | |
+ val streamErr = intercept[IllegalArgumentException] { | |
+ streamDf.observe(streamObservation, avg($"value").cast("int").as("avg_val")) | |
+ } | |
+ assert(streamErr.getMessage.contains("Observation does not support streaming Datasets")) | |
+ | |
+ // an observation cannot have an empty name | |
+ val err2 = intercept[IllegalArgumentException] { | |
+ Observation("") | |
+ } | |
+ assert(err2.getMessage.contains("Name must not be empty")) | |
+ } | |
+ | |
+ test("SPARK-37203: Fix NotSerializableException when observe with TypedImperativeAggregate") { | |
+ def observe[T](df: Dataset[T], expected: Map[String, _]): Unit = { | |
+ val namedObservation = Observation("named") | |
+ val observed_df = df.observe( | |
+ namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val")) | |
+ observed_df.collect() | |
+ assert(namedObservation.get === expected) | |
+ } | |
+ | |
+ observe(spark.range(100), Map("percentile_approx_val" -> 49)) | |
+ observe(spark.range(0), Map("percentile_approx_val" -> null)) | |
+ observe(spark.range(1, 10), Map("percentile_approx_val" -> 5)) | |
+ observe(spark.range(1, 10, 1, 11), Map("percentile_approx_val" -> 5)) | |
+ } | |
+ | |
test("sample with replacement") { | |
val n = 100 | |
val data = sparkContext.parallelize(1 to n, 2).toDS() | |
@@ -866,7 +937,8 @@ class DatasetSuite extends QueryTest | |
val e = intercept[AnalysisException] { | |
ds.as[ClassData2] | |
} | |
- assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) | |
+ assert(e.getErrorClass == "MISSING_COLUMN") | |
+ assert(e.messageParameters.sameElements(Array("c", "a, b"))) | |
} | |
test("runtime nullability check") { | |
@@ -933,24 +1005,6 @@ class DatasetSuite extends QueryTest | |
checkDataset(cogrouped, "a13", "b24") | |
} | |
- test("give nice error message when the real number of fields doesn't match encoder schema") { | |
- val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() | |
- | |
- val message = intercept[AnalysisException] { | |
- ds.as[(String, Int, Long)] | |
- }.message | |
- assert(message == | |
- "Try to map struct<a:string,b:int> to Tuple3, " + | |
- "but failed as the number of fields does not line up.") | |
- | |
- val message2 = intercept[AnalysisException] { | |
- ds.as[Tuple1[String]] | |
- }.message | |
- assert(message2 == | |
- "Try to map struct<a:string,b:int> to Tuple1, " + | |
- "but failed as the number of fields does not line up.") | |
- } | |
- | |
test("SPARK-13440: Resolving option fields") { | |
val df = Seq(1, 2, 3).toDS() | |
val ds = df.as[Option[Int]] | |
@@ -1392,8 +1446,29 @@ class DatasetSuite extends QueryTest | |
} | |
testCheckpointing("basic") { | |
- val ds = spark.range(10).repartition($"id" % 2).filter($"id" > 5).orderBy($"id".desc) | |
- val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager) | |
+ val ds = spark | |
+ .range(10) | |
+ // Num partitions is set to 1 to avoid a RangePartitioner in the orderBy below | |
+ .repartition(1, $"id" % 2) | |
+ .filter($"id" > 5) | |
+ .orderBy($"id".desc) | |
+ @volatile var jobCounter = 0 | |
+ val listener = new SparkListener { | |
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = { | |
+ jobCounter += 1 | |
+ } | |
+ } | |
+ var cp = ds | |
+ withListener(spark.sparkContext, listener) { _ => | |
+ // AQE adds a job per shuffle. The expression above does multiple shuffles and | |
+ // that screws up the job counting | |
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { | |
+ cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager) | |
+ } | |
+ } | |
+ if (eager) { | |
+ assert(jobCounter === 1) | |
+ } | |
val logicalRDD = cp.logicalPlan match { | |
case plan: LogicalRDD => plan | |
@@ -1861,7 +1936,7 @@ class DatasetSuite extends QueryTest | |
.map(b => b - 1) | |
.collect() | |
} | |
- assert(thrownException.message.contains("Cannot up cast id from bigint to tinyint")) | |
+ assert(thrownException.message.contains("""Cannot up cast id from "BIGINT" to "TINYINT"""")) | |
} | |
test("SPARK-26690: checkpoints should be executed with an execution id") { | |
@@ -2060,6 +2135,23 @@ class DatasetSuite extends QueryTest | |
(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), | |
(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)) | |
} | |
+ | |
+ test("SPARK-40407: repartition should not result in severe data skew") { | |
+ val df = spark.range(0, 100, 1, 50).repartition(4) | |
+ val result = df.mapPartitions(iter => Iterator.single(iter.length)).collect() | |
+ assert(result.sorted.toSeq === Seq(23, 25, 25, 27)) | |
+ } | |
+ | |
+ test("SPARK-40660: Switch to XORShiftRandom to distribute elements") { | |
+ withTempDir { dir => | |
+ spark.range(10).repartition(10).write.mode(SaveMode.Overwrite).parquet(dir.getCanonicalPath) | |
+ val fs = new Path(dir.getAbsolutePath).getFileSystem(spark.sessionState.newHadoopConf()) | |
+ val parquetFiles = fs.listStatus(new Path(dir.getAbsolutePath), new PathFilter { | |
+ override def accept(path: Path): Boolean = path.getName.endsWith("parquet") | |
+ }) | |
+ assert(parquetFiles.size === 10) | |
+ } | |
+ } | |
} | |
case class Bar(a: Int) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala | |
index 543f845aff..fa246fa79b 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala | |
@@ -23,7 +23,7 @@ import java.time.{Instant, LocalDateTime, ZoneId} | |
import java.util.{Locale, TimeZone} | |
import java.util.concurrent.TimeUnit | |
-import org.apache.spark.{SparkException, SparkUpgradeException} | |
+import org.apache.spark.{SparkConf, SparkException, SparkUpgradeException} | |
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{CEST, LA} | |
import org.apache.spark.sql.catalyst.util.DateTimeUtils | |
import org.apache.spark.sql.functions._ | |
@@ -35,6 +35,10 @@ import org.apache.spark.unsafe.types.CalendarInterval | |
class DateFunctionsSuite extends QueryTest with SharedSparkSession { | |
import testImplicits._ | |
+ // The test cases which throw exceptions under ANSI mode are covered by date.sql and | |
+ // datetime-parsing-invalid.sql in org.apache.spark.sql.SQLQueryTestSuite. | |
+ override def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED.key, "false") | |
+ | |
test("function current_date") { | |
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") | |
val d0 = DateTimeUtils.currentDate(ZoneId.systemDefault()) | |
@@ -512,7 +516,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { | |
Seq(Row(null), Row(null), Row(null))) | |
val e = intercept[SparkUpgradeException](df.select(to_date(col("s"), "yyyy-dd-aa")).collect()) | |
assert(e.getCause.isInstanceOf[IllegalArgumentException]) | |
- assert(e.getMessage.contains("You may get a different result due to the upgrading of Spark")) | |
+ assert(e.getMessage.contains("You may get a different result due to the upgrading to Spark")) | |
// February | |
val x1 = "2016-02-29" | |
@@ -695,7 +699,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { | |
val e = intercept[SparkUpgradeException](invalid.collect()) | |
assert(e.getCause.isInstanceOf[IllegalArgumentException]) | |
assert( | |
- e.getMessage.contains("You may get a different result due to the upgrading of Spark")) | |
+ e.getMessage.contains("You may get a different result due to the upgrading to Spark")) | |
} | |
// February | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala | |
index 1033d43a1b..d5498c469c 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala | |
@@ -31,14 +31,14 @@ import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec | |
import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.internal.SQLConf | |
-import org.apache.spark.sql.test.SharedSparkSession | |
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} | |
/** | |
* Test suite for the filtering ratio policy used to trigger dynamic partition pruning (DPP). | |
*/ | |
abstract class DynamicPartitionPruningSuiteBase | |
extends QueryTest | |
- with SharedSparkSession | |
+ with SQLTestUtils | |
with GivenWhenThen | |
with AdaptiveSparkPlanHelper { | |
@@ -49,7 +49,7 @@ abstract class DynamicPartitionPruningSuiteBase | |
protected def initState(): Unit = {} | |
protected def runAnalyzeColumnCommands: Boolean = true | |
- override def beforeAll(): Unit = { | |
+ override protected def beforeAll(): Unit = { | |
super.beforeAll() | |
initState() | |
@@ -107,6 +107,10 @@ abstract class DynamicPartitionPruningSuiteBase | |
(6, 60) | |
) | |
+ if (tableFormat == "hive") { | |
+ spark.sql("set hive.exec.dynamic.partition.mode=nonstrict") | |
+ } | |
+ | |
spark.range(1000) | |
.select($"id" as "product_id", ($"id" % 10) as "store_id", ($"id" + 1) as "code") | |
.write | |
@@ -150,11 +154,12 @@ abstract class DynamicPartitionPruningSuiteBase | |
if (runAnalyzeColumnCommands) { | |
sql("ANALYZE TABLE fact_stats COMPUTE STATISTICS FOR COLUMNS store_id") | |
sql("ANALYZE TABLE dim_stats COMPUTE STATISTICS FOR COLUMNS store_id") | |
+ sql("ANALYZE TABLE dim_store COMPUTE STATISTICS FOR COLUMNS store_id") | |
sql("ANALYZE TABLE code_stats COMPUTE STATISTICS FOR COLUMNS store_id") | |
} | |
} | |
- override def afterAll(): Unit = { | |
+ override protected def afterAll(): Unit = { | |
try { | |
sql("DROP TABLE IF EXISTS fact_np") | |
sql("DROP TABLE IF EXISTS fact_sk") | |
@@ -183,11 +188,11 @@ abstract class DynamicPartitionPruningSuiteBase | |
val plan = df.queryExecution.executedPlan | |
val dpExprs = collectDynamicPruningExpressions(plan) | |
val hasSubquery = dpExprs.exists { | |
- case InSubqueryExec(_, _: SubqueryExec, _, _) => true | |
+ case InSubqueryExec(_, _: SubqueryExec, _, _, _, _) => true | |
case _ => false | |
} | |
val subqueryBroadcast = dpExprs.collect { | |
- case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _) => b | |
+ case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) => b | |
} | |
val hasFilter = if (withSubquery) "Should" else "Shouldn't" | |
@@ -202,10 +207,10 @@ abstract class DynamicPartitionPruningSuiteBase | |
case _: ReusedExchangeExec => // reuse check ok. | |
case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => // reuse check ok. | |
case b: BroadcastExchangeLike => | |
- val hasReuse = plan.find { | |
+ val hasReuse = plan.exists { | |
case ReusedExchangeExec(_, e) => e eq b | |
case _ => false | |
- }.isDefined | |
+ } | |
assert(hasReuse, s"$s\nshould have been reused in\n$plan") | |
case a: AdaptiveSparkPlanExec => | |
val broadcastQueryStage = collectFirst(a) { | |
@@ -229,7 +234,7 @@ abstract class DynamicPartitionPruningSuiteBase | |
case r: ReusedSubqueryExec => r.child | |
case o => o | |
} | |
- assert(subquery.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined == isMainQueryAdaptive) | |
+ assert(subquery.exists(_.isInstanceOf[AdaptiveSparkPlanExec]) == isMainQueryAdaptive) | |
} | |
} | |
@@ -240,7 +245,7 @@ abstract class DynamicPartitionPruningSuiteBase | |
df.collect() | |
val buf = collectDynamicPruningExpressions(df.queryExecution.executedPlan).collect { | |
- case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _) => | |
+ case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) => | |
b.index | |
} | |
assert(buf.distinct.size == n) | |
@@ -249,7 +254,7 @@ abstract class DynamicPartitionPruningSuiteBase | |
/** | |
* Collect the children of all correctly pushed down dynamic pruning expressions in a spark plan. | |
*/ | |
- private def collectDynamicPruningExpressions(plan: SparkPlan): Seq[Expression] = { | |
+ protected def collectDynamicPruningExpressions(plan: SparkPlan): Seq[Expression] = { | |
flatMap(plan) { | |
case s: FileSourceScanExec => s.partitionFilters.collect { | |
case d: DynamicPruningExpression => d.child | |
@@ -339,12 +344,12 @@ abstract class DynamicPartitionPruningSuiteBase | |
| ) | |
""".stripMargin) | |
- val found = df.queryExecution.executedPlan.find { | |
+ val found = df.queryExecution.executedPlan.exists { | |
case BroadcastHashJoinExec(_, _, p: ExistenceJoin, _, _, _, _, _) => true | |
case _ => false | |
} | |
- assert(found.isEmpty) | |
+ assert(!found) | |
} | |
} | |
@@ -467,31 +472,70 @@ abstract class DynamicPartitionPruningSuiteBase | |
Given("no stats and selective predicate with the size of dim too large") | |
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", | |
- SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") { | |
- sql( | |
- """ | |
- |SELECT f.date_id, f.product_id, f.units_sold, f.store_id | |
- |FROM fact_sk f WHERE store_id < 5 | |
- """.stripMargin) | |
- .write | |
- .partitionBy("store_id") | |
- .saveAsTable("fact_aux") | |
+ SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true", | |
+ SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0.02") { | |
+ withTable("fact_aux") { | |
+ sql( | |
+ """ | |
+ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id | |
+ |FROM fact_sk f WHERE store_id < 5 | |
+ """.stripMargin) | |
+ .write | |
+ .partitionBy("store_id") | |
+ .saveAsTable("fact_aux") | |
+ | |
+ val df = sql( | |
+ """ | |
+ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id | |
+ |FROM fact_aux f JOIN dim_store s | |
+ |ON f.store_id = s.store_id WHERE s.country = 'US' | |
+ """.stripMargin) | |
+ | |
+ checkPartitionPruningPredicate(df, false, false) | |
+ | |
+ checkAnswer(df, | |
+ Row(1070, 2, 10, 4) :: | |
+ Row(1080, 3, 20, 4) :: | |
+ Row(1090, 3, 10, 4) :: | |
+ Row(1100, 3, 10, 4) :: Nil | |
+ ) | |
+ } | |
+ } | |
- val df = sql( | |
- """ | |
- |SELECT f.date_id, f.product_id, f.units_sold, f.store_id | |
- |FROM fact_aux f JOIN dim_store s | |
- |ON f.store_id = s.store_id WHERE s.country = 'US' | |
- """.stripMargin) | |
+ Given("no stats and selective predicate with the size of dim too large but cached") | |
+ withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", | |
+ SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") { | |
+ withTable("fact_aux") { | |
+ withTempView("cached_dim_store") { | |
+ sql( | |
+ """ | |
+ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id | |
+ |FROM fact_sk f WHERE store_id < 5 | |
+ """.stripMargin) | |
+ .write | |
+ .partitionBy("store_id") | |
+ .saveAsTable("fact_aux") | |
- checkPartitionPruningPredicate(df, false, false) | |
+ spark.table("dim_store").cache() | |
+ .createOrReplaceTempView("cached_dim_store") | |
- checkAnswer(df, | |
- Row(1070, 2, 10, 4) :: | |
- Row(1080, 3, 20, 4) :: | |
- Row(1090, 3, 10, 4) :: | |
- Row(1100, 3, 10, 4) :: Nil | |
- ) | |
+ val df = sql( | |
+ """ | |
+ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id | |
+ |FROM fact_aux f JOIN cached_dim_store s | |
+ |ON f.store_id = s.store_id WHERE s.country = 'US' | |
+ """.stripMargin) | |
+ | |
+ checkPartitionPruningPredicate(df, true, false) | |
+ | |
+ checkAnswer(df, | |
+ Row(1070, 2, 10, 4) :: | |
+ Row(1080, 3, 20, 4) :: | |
+ Row(1090, 3, 10, 4) :: | |
+ Row(1100, 3, 10, 4) :: Nil | |
+ ) | |
+ } | |
+ } | |
} | |
Given("no stats and selective predicate with the size of dim small") | |
@@ -983,41 +1027,6 @@ abstract class DynamicPartitionPruningSuiteBase | |
} | |
} | |
- test("no partition pruning when the build side is a stream") { | |
- withTable("fact") { | |
- val input = MemoryStream[Int] | |
- val stream = input.toDF.select($"value" as "one", ($"value" * 3) as "code") | |
- spark.range(100).select( | |
- $"id", | |
- ($"id" + 1).as("one"), | |
- ($"id" + 2).as("two"), | |
- ($"id" + 3).as("three")) | |
- .write.partitionBy("one") | |
- .format(tableFormat).mode("overwrite").saveAsTable("fact") | |
- val table = sql("SELECT * from fact f") | |
- | |
- // join a partitioned table with a stream | |
- val joined = table.join(stream, Seq("one")).where("code > 40") | |
- val query = joined.writeStream.format("memory").queryName("test").start() | |
- input.addData(1, 10, 20, 40, 50) | |
- try { | |
- query.processAllAvailable() | |
- } finally { | |
- query.stop() | |
- } | |
- // search dynamic pruning predicates on the executed plan | |
- val plan = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan | |
- val ret = plan.find { | |
- case s: FileSourceScanExec => s.partitionFilters.exists { | |
- case _: DynamicPruningExpression => true | |
- case _ => false | |
- } | |
- case _ => false | |
- } | |
- assert(ret.isDefined == false) | |
- } | |
- } | |
- | |
test("avoid reordering broadcast join keys to match input hash partitioning") { | |
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", | |
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { | |
@@ -1144,7 +1153,8 @@ abstract class DynamicPartitionPruningSuiteBase | |
test("join key with multiple references on the filtering plan") { | |
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", | |
- SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName | |
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName, | |
+ SQLConf.ANSI_ENABLED.key -> "false" // ANSI mode doesn't support "String + String" | |
) { | |
// when enable AQE, the reusedExchange is inserted when executed. | |
withTable("fact", "dim") { | |
@@ -1473,9 +1483,124 @@ abstract class DynamicPartitionPruningSuiteBase | |
checkAnswer(df, Row(1150, 1) :: Row(1130, 4) :: Row(1140, 4) :: Nil) | |
} | |
} | |
+ | |
+ test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " + | |
+ "pruning") { | |
+ withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { | |
+ Seq( | |
+ "f.store_id = 1" -> false, | |
+ "1 = f.store_id" -> false, | |
+ "f.store_id <=> 1" -> false, | |
+ "1 <=> f.store_id" -> false, | |
+ "f.store_id > 1" -> true, | |
+ "5 > f.store_id" -> true).foreach { case (condition, hasDPP) => | |
+ // partitioned table at left side | |
+ val df1 = sql( | |
+ s""" | |
+ |SELECT /*+ broadcast(s) */ * FROM fact_sk f | |
+ |JOIN dim_store s ON f.store_id = s.store_id AND $condition | |
+ """.stripMargin) | |
+ checkPartitionPruningPredicate(df1, false, withBroadcast = hasDPP) | |
+ | |
+ val df2 = sql( | |
+ s""" | |
+ |SELECT /*+ broadcast(s) */ * FROM fact_sk f | |
+ |JOIN dim_store s ON f.store_id = s.store_id | |
+ |WHERE $condition | |
+ """.stripMargin) | |
+ checkPartitionPruningPredicate(df2, false, withBroadcast = hasDPP) | |
+ | |
+ // partitioned table at right side | |
+ val df3 = sql( | |
+ s""" | |
+ |SELECT /*+ broadcast(s) */ * FROM dim_store s | |
+ |JOIN fact_sk f ON f.store_id = s.store_id AND $condition | |
+ """.stripMargin) | |
+ checkPartitionPruningPredicate(df3, false, withBroadcast = hasDPP) | |
+ | |
+ val df4 = sql( | |
+ s""" | |
+ |SELECT /*+ broadcast(s) */ * FROM dim_store s | |
+ |JOIN fact_sk f ON f.store_id = s.store_id | |
+ |WHERE $condition | |
+ """.stripMargin) | |
+ checkPartitionPruningPredicate(df4, false, withBroadcast = hasDPP) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38570: Fix incorrect DynamicPartitionPruning caused by Literal") { | |
+ withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { | |
+ val df = sql( | |
+ """ | |
+ |SELECT f.store_id, | |
+ | f.date_id, | |
+ | s.state_province | |
+ |FROM (SELECT 4 AS store_id, | |
+ | date_id, | |
+ | product_id | |
+ | FROM fact_sk | |
+ | WHERE date_id >= 1300 | |
+ | UNION ALL | |
+ | SELECT 5 AS store_id, | |
+ | date_id, | |
+ | product_id | |
+ | FROM fact_stats | |
+ | WHERE date_id <= 1000) f | |
+ |JOIN dim_store s | |
+ |ON f.store_id = s.store_id | |
+ |WHERE s.country = 'US' | |
+ |""".stripMargin) | |
+ | |
+ checkPartitionPruningPredicate(df, withSubquery = false, withBroadcast = false) | |
+ checkAnswer(df, Row(4, 1300, "California") :: Row(5, 1000, "Texas") :: Nil) | |
+ } | |
+ } | |
} | |
-abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningSuiteBase { | |
+abstract class DynamicPartitionPruningDataSourceSuiteBase | |
+ extends DynamicPartitionPruningSuiteBase | |
+ with SharedSparkSession { | |
+ | |
+ import testImplicits._ | |
+ | |
+ test("no partition pruning when the build side is a stream") { | |
+ withTable("fact") { | |
+ val input = MemoryStream[Int] | |
+ val stream = input.toDF.select($"value" as "one", ($"value" * 3) as "code") | |
+ spark.range(100).select( | |
+ $"id", | |
+ ($"id" + 1).as("one"), | |
+ ($"id" + 2).as("two"), | |
+ ($"id" + 3).as("three")) | |
+ .write.partitionBy("one") | |
+ .format(tableFormat).mode("overwrite").saveAsTable("fact") | |
+ val table = sql("SELECT * from fact f") | |
+ | |
+ // join a partitioned table with a stream | |
+ val joined = table.join(stream, Seq("one")).where("code > 40") | |
+ val query = joined.writeStream.format("memory").queryName("test").start() | |
+ input.addData(1, 10, 20, 40, 50) | |
+ try { | |
+ query.processAllAvailable() | |
+ } finally { | |
+ query.stop() | |
+ } | |
+ // search dynamic pruning predicates on the executed plan | |
+ val plan = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan | |
+ val ret = plan.exists { | |
+ case s: FileSourceScanExec => s.partitionFilters.exists { | |
+ case _: DynamicPruningExpression => true | |
+ case _ => false | |
+ } | |
+ case _ => false | |
+ } | |
+ assert(!ret) | |
+ } | |
+ } | |
+} | |
+ | |
+abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDataSourceSuiteBase { | |
import testImplicits._ | |
@@ -1510,10 +1635,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningSui | |
val scanOption = | |
find(plan) { | |
case s: FileSourceScanExec => | |
- s.output.exists(_.find(_.argString(maxFields = 100).contains("fid")).isDefined) | |
+ s.output.exists(_.exists(_.argString(maxFields = 100).contains("fid"))) | |
case s: BatchScanExec => | |
// we use f1 col for v2 tables due to schema pruning | |
- s.output.exists(_.find(_.argString(maxFields = 100).contains("f1")).isDefined) | |
+ s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) | |
case _ => false | |
} | |
assert(scanOption.isDefined) | |
@@ -1569,6 +1694,25 @@ class DynamicPartitionPruningV1SuiteAEOff extends DynamicPartitionPruningV1Suite | |
class DynamicPartitionPruningV1SuiteAEOn extends DynamicPartitionPruningV1Suite | |
with EnableAdaptiveExecutionSuite { | |
+ test("SPARK-39447: Avoid AssertionError in AdaptiveSparkPlanExec.doExecuteBroadcast") { | |
+ val df = sql( | |
+ """ | |
+ |WITH empty_result AS ( | |
+ | SELECT * FROM fact_stats WHERE product_id < 0 | |
+ |) | |
+ |SELECT * | |
+ |FROM (SELECT /*+ SHUFFLE_MERGE(fact_sk) */ empty_result.store_id | |
+ | FROM fact_sk | |
+ | JOIN empty_result | |
+ | ON fact_sk.product_id = empty_result.product_id) t2 | |
+ | JOIN empty_result | |
+ | ON t2.store_id = empty_result.store_id | |
+ """.stripMargin) | |
+ | |
+ checkPartitionPruningPredicate(df, false, false) | |
+ checkAnswer(df, Nil) | |
+ } | |
+ | |
test("SPARK-37995: PlanAdaptiveDynamicPruningFilters should use prepareExecutedPlan " + | |
"rather than createSparkPlan to re-plan subquery") { | |
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", | |
@@ -1587,7 +1731,7 @@ class DynamicPartitionPruningV1SuiteAEOn extends DynamicPartitionPruningV1Suite | |
} | |
} | |
-abstract class DynamicPartitionPruningV2Suite extends DynamicPartitionPruningSuiteBase { | |
+abstract class DynamicPartitionPruningV2Suite extends DynamicPartitionPruningDataSourceSuiteBase { | |
override protected def runAnalyzeColumnCommands: Boolean = false | |
override protected def initState(): Unit = { | |
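A self-contained sketch of the find(...).isDefined -> exists(...) cleanup that recurs in the hunks above, shown on a plain Scala collection so it compiles on its own; the plan-level TreeNode API is assumed to read the same way:

object ExistsVsFind {
  def main(args: Array[String]): Unit = {
    val xs = Seq(1, 2, 3)
    val oldStyle = xs.find(_ > 2).isDefined // materializes an Option just to test presence
    val newStyle = xs.exists(_ > 2)         // answers the membership question directly
    assert(oldStyle == newStyle)
  }
}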
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala | |
index aef39a24e9..d637283446 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala | |
@@ -106,7 +106,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite | |
keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)") | |
} | |
- test("optimized plan should show the rewritten aggregate expression") { | |
+ test("optimized plan should show the rewritten expression") { | |
withTempView("test_agg") { | |
sql( | |
""" | |
@@ -125,6 +125,13 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite | |
"Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, " + | |
"any(v#x) AS any(v)#x]") | |
} | |
+ | |
+ withTable("t") { | |
+ sql("CREATE TABLE t(col TIMESTAMP) USING parquet") | |
+ val df = sql("SELECT date_part('month', col) FROM t") | |
+ checkKeywordsExistsInExplain(df, | |
+ "Project [month(cast(col#x as date)) AS date_part(month, col)#x]") | |
+ } | |
} | |
test("explain inline tables cross-joins") { | |
@@ -217,8 +224,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite | |
// AND conjunction | |
// OR disjunction | |
// --------------------------------------------------------------------------------------- | |
- checkKeywordsExistsInExplain(sql("select 'a' || 1 + 2"), | |
- "Project [null AS (concat(a, 1) + 2)#x]") | |
+ checkKeywordsExistsInExplain(sql("select '1' || 1 + 2"), | |
+ "Project [13", " AS (concat(1, 1) + 2)#x") | |
checkKeywordsExistsInExplain(sql("select 1 - 2 || 'b'"), | |
"Project [-1b AS concat((1 - 2), b)#x]") | |
checkKeywordsExistsInExplain(sql("select 2 * 4 + 3 || 'b'"), | |
@@ -232,12 +239,11 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite | |
} | |
test("explain for these functions; use range to avoid constant folding") { | |
- val df = sql("select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') " + | |
+ val df = sql("select ifnull(id, 1), nullif(id, 1), nvl(id, 1), nvl2(id, 1, 2) " + | |
"from range(2)") | |
checkKeywordsExistsInExplain(df, | |
- "Project [coalesce(cast(id#xL as string), x) AS ifnull(id, x)#x, " + | |
- "id#xL AS nullif(id, x)#xL, coalesce(cast(id#xL as string), x) AS nvl(id, x)#x, " + | |
- "x AS nvl2(id, x, y)#x]") | |
+ "Project [id#xL AS ifnull(id, 1)#xL, if ((id#xL = 1)) null " + | |
+ "else id#xL AS nullif(id, 1)#xL, id#xL AS nvl(id, 1)#xL, 1 AS nvl2(id, 1, 2)#x]") | |
} | |
test("SPARK-26659: explain of DataWritingCommandExec should not contain duplicate cmd.nodeName") { | |
@@ -522,6 +528,21 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite | |
"== Analyzed Logical Plan ==\nCreateViewCommand") | |
} | |
} | |
+ | |
+ test("SPARK-39112: UnsupportedOperationException if explain cost command using v2 command") { | |
+ withTempDir { dir => | |
+ sql("EXPLAIN COST CREATE DATABASE tmp") | |
+ sql("EXPLAIN COST DESC DATABASE tmp") | |
+ sql(s"EXPLAIN COST ALTER DATABASE tmp SET LOCATION '${dir.toURI.toString}'") | |
+ sql("EXPLAIN COST USE tmp") | |
+ sql("EXPLAIN COST CREATE TABLE t(c1 int) USING PARQUET") | |
+ sql("EXPLAIN COST SHOW TABLES") | |
+ sql("EXPLAIN COST SHOW CREATE TABLE t") | |
+ sql("EXPLAIN COST SHOW TBLPROPERTIES t") | |
+ sql("EXPLAIN COST DROP TABLE t") | |
+ sql("EXPLAIN COST DROP DATABASE tmp") | |
+ } | |
+ } | |
} | |
class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite { | |
@@ -594,7 +615,7 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit | |
} | |
test("SPARK-35884: Explain should only display one plan before AQE takes effect") { | |
- val df = (0 to 10).toDF("id").where('id > 5) | |
+ val df = (0 to 10).toDF("id").where(Symbol("id") > 5) | |
val modes = Seq(SimpleMode, ExtendedMode, CostMode, FormattedMode) | |
modes.foreach { mode => | |
checkKeywordsExistsInExplain(df, mode, "AdaptiveSparkPlan") | |
@@ -609,7 +630,8 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit | |
test("SPARK-35884: Explain formatted with subquery") { | |
withTempView("t1", "t2") { | |
- spark.range(100).select('id % 10 as "key", 'id as "value").createOrReplaceTempView("t1") | |
+ spark.range(100).select(Symbol("id") % 10 as "key", Symbol("id") as "value") | |
+ .createOrReplaceTempView("t1") | |
spark.range(10).createOrReplaceTempView("t2") | |
val query = | |
""" | |
@@ -678,6 +700,33 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit | |
} | |
} | |
+ test("SPARK-32986: Bucketed scan info should be a part of explain string") { | |
+ withTable("t1", "t2") { | |
+ Seq((1, 2), (2, 3)).toDF("i", "j").write.bucketBy(8, "i").saveAsTable("t1") | |
+ Seq(2, 3).toDF("i").write.bucketBy(8, "i").saveAsTable("t2") | |
+ val df1 = spark.table("t1") | |
+ val df2 = spark.table("t2") | |
+ | |
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { | |
+ checkKeywordsExistsInExplain( | |
+ df1.join(df2, df1("i") === df2("i")), | |
+ "Bucketed: true") | |
+ } | |
+ | |
+ withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { | |
+ checkKeywordsExistsInExplain( | |
+ df1.join(df2, df1("i") === df2("i")), | |
+ "Bucketed: false (disabled by configuration)") | |
+ } | |
+ | |
+ checkKeywordsExistsInExplain(df1, "Bucketed: false (disabled by query planner)" ) | |
+ | |
+ checkKeywordsExistsInExplain( | |
+ df1.select("j"), | |
+ "Bucketed: false (bucket column(s) not read)") | |
+ } | |
+ } | |
+ | |
test("SPARK-36795: Node IDs should not be duplicated when InMemoryRelation present") { | |
withTempView("t1", "t2") { | |
Seq(1).toDF("k").write.saveAsTable("t1") | |
@@ -696,6 +745,35 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit | |
assert(inMemoryRelationNodeId != columnarToRowNodeId) | |
} | |
} | |
+ | |
+ test("SPARK-38232: Explain formatted does not collect subqueries under query stage in AQE") { | |
+ withTable("t") { | |
+ sql("CREATE TABLE t USING PARQUET AS SELECT 1 AS c") | |
+ val expected = | |
+ "Subquery:1 Hosting operator id = 2 Hosting Expression = Subquery subquery#x, [id=#x]" | |
+ val df = sql("SELECT count(s) FROM (SELECT (SELECT c FROM t) as s)") | |
+ df.collect() | |
+ withNormalizedExplain(df, FormattedMode) { output => | |
+ assert(output.contains(expected)) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38322: Support query stage show runtime statistics in formatted explain mode") { | |
+ val df = Seq(1, 2).toDF("c").distinct() | |
+ val statistics = "Statistics(sizeInBytes=32.0 B, rowCount=2)" | |
+ | |
+ checkKeywordsNotExistsInExplain( | |
+ df, | |
+ FormattedMode, | |
+ statistics) | |
+ | |
+ df.collect() | |
+ checkKeywordsExistsInExplain( | |
+ df, | |
+ FormattedMode, | |
+ statistics) | |
+ } | |
} | |
case class ExplainSingleData(id: Int) | |
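A hedged sketch of what the withSQLConf helper used throughout these suites boils down to: set a SQL conf for the duration of a block and restore the previous value afterwards. This is an illustration against the public SparkSession.conf API, not the actual SQLTestUtils implementation:

import org.apache.spark.sql.SparkSession

object WithSQLConfSketch {
  def withConf[T](spark: SparkSession, key: String, value: String)(body: => T): T = {
    val previous = spark.conf.getOption(key) // remember the old value, if any
    spark.conf.set(key, value)
    try body
    finally previous match {
      case Some(v) => spark.conf.set(key, v)
      case None => spark.conf.unset(key)
    }
  }
}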
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | |
index 001b6a00af..17dfde65ca 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | |
@@ -382,7 +382,9 @@ class FileBasedDataSourceSuite extends QueryTest | |
msg.toLowerCase(Locale.ROOT).contains(msg2)) | |
} | |
- withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1List) { | |
+ withSQLConf( | |
+ SQLConf.USE_V1_SOURCE_LIST.key -> useV1List, | |
+ SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { | |
// write path | |
Seq("csv", "json", "parquet", "orc").foreach { format => | |
val msg = intercept[AnalysisException] { | |
@@ -532,6 +534,64 @@ class FileBasedDataSourceSuite extends QueryTest | |
} | |
} | |
+ test("SPARK-30362: test input metrics for DSV2") { | |
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { | |
+ Seq("json", "orc", "parquet").foreach { format => | |
+ withTempPath { path => | |
+ val dir = path.getCanonicalPath | |
+ spark.range(0, 10).write.format(format).save(dir) | |
+ val df = spark.read.format(format).load(dir) | |
+ val bytesReads = new mutable.ArrayBuffer[Long]() | |
+ val recordsRead = new mutable.ArrayBuffer[Long]() | |
+ val bytesReadListener = new SparkListener() { | |
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { | |
+ bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead | |
+ recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead | |
+ } | |
+ } | |
+ sparkContext.addSparkListener(bytesReadListener) | |
+ try { | |
+ df.collect() | |
+ sparkContext.listenerBus.waitUntilEmpty() | |
+ assert(bytesReads.sum > 0) | |
+ assert(recordsRead.sum == 10) | |
+ } finally { | |
+ sparkContext.removeSparkListener(bytesReadListener) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37585: test input metrics for DSV2 with output limits") { | |
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { | |
+ Seq("json", "orc", "parquet").foreach { format => | |
+ withTempPath { path => | |
+ val dir = path.getCanonicalPath | |
+ spark.range(0, 100).write.format(format).save(dir) | |
+ val df = spark.read.format(format).load(dir) | |
+ val bytesReads = new mutable.ArrayBuffer[Long]() | |
+ val recordsRead = new mutable.ArrayBuffer[Long]() | |
+ val bytesReadListener = new SparkListener() { | |
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { | |
+ bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead | |
+ recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead | |
+ } | |
+ } | |
+ sparkContext.addSparkListener(bytesReadListener) | |
+ try { | |
+ df.limit(10).collect() | |
+ sparkContext.listenerBus.waitUntilEmpty() | |
+ assert(bytesReads.sum > 0) | |
+ assert(recordsRead.sum > 0) | |
+ } finally { | |
+ sparkContext.removeSparkListener(bytesReadListener) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
test("Do not use cache on overwrite") { | |
Seq("", "orc").foreach { useV1SourceReaderList => | |
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1SourceReaderList) { | |
@@ -695,7 +755,7 @@ class FileBasedDataSourceSuite extends QueryTest | |
test("SPARK-22790,SPARK-27668: spark.sql.sources.compressionFactor takes effect") { | |
Seq(1.0, 0.5).foreach { compressionFactor => | |
withSQLConf(SQLConf.FILE_COMPRESSION_FACTOR.key -> compressionFactor.toString, | |
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "350") { | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "457") { | |
withTempPath { workDir => | |
// the file size is 504 bytes | |
val workDirPath = workDir.getAbsolutePath | |
@@ -731,6 +791,28 @@ class FileBasedDataSourceSuite extends QueryTest | |
} | |
} | |
+ test("SPARK-36568: FileScan statistics estimation takes read schema into account") { | |
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { | |
+ withTempDir { dir => | |
+ spark.range(1000).map(x => (x / 100, x, x)).toDF("k", "v1", "v2"). | |
+ write.partitionBy("k").mode(SaveMode.Overwrite).orc(dir.toString) | |
+ val dfAll = spark.read.orc(dir.toString) | |
+ val dfK = dfAll.select("k") | |
+ val dfV1 = dfAll.select("v1") | |
+ val dfV2 = dfAll.select("v2") | |
+ val dfV1V2 = dfAll.select("v1", "v2") | |
+ | |
+ def sizeInBytes(df: DataFrame): BigInt = df.queryExecution.optimizedPlan.stats.sizeInBytes | |
+ | |
+ assert(sizeInBytes(dfAll) === BigInt(getLocalDirSize(dir))) | |
+ assert(sizeInBytes(dfK) < sizeInBytes(dfAll)) | |
+ assert(sizeInBytes(dfV1) < sizeInBytes(dfAll)) | |
+ assert(sizeInBytes(dfV2) === sizeInBytes(dfV1)) | |
+ assert(sizeInBytes(dfV1V2) < sizeInBytes(dfAll)) | |
+ } | |
+ } | |
+ } | |
+ | |
test("File source v2: support partition pruning") { | |
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { | |
allFileBasedDataSources.foreach { format => | |
@@ -754,12 +836,13 @@ class FileBasedDataSourceSuite extends QueryTest | |
} | |
assert(filterCondition.isDefined) | |
// The partitions filters should be pushed down and no need to be reevaluated. | |
- assert(filterCondition.get.collectFirst { | |
- case a: AttributeReference if a.name == "p1" || a.name == "p2" => a | |
- }.isEmpty) | |
+ assert(!filterCondition.get.exists { | |
+ case a: AttributeReference => a.name == "p1" || a.name == "p2" | |
+ case _ => false | |
+ }) | |
val fileScan = df.queryExecution.executedPlan collectFirst { | |
- case BatchScanExec(_, f: FileScan, _) => f | |
+ case BatchScanExec(_, f: FileScan, _, _) => f | |
} | |
assert(fileScan.nonEmpty) | |
assert(fileScan.get.partitionFilters.nonEmpty) | |
@@ -799,7 +882,7 @@ class FileBasedDataSourceSuite extends QueryTest | |
assert(filterCondition.isDefined) | |
val fileScan = df.queryExecution.executedPlan collectFirst { | |
- case BatchScanExec(_, f: FileScan, _) => f | |
+ case BatchScanExec(_, f: FileScan, _, _) => f | |
} | |
assert(fileScan.nonEmpty) | |
assert(fileScan.get.partitionFilters.isEmpty) | |
@@ -885,52 +968,57 @@ class FileBasedDataSourceSuite extends QueryTest | |
// cases when value == MAX | |
var v = Short.MaxValue | |
- checkPushedFilters(format, df.where('id > v.toInt), Array(), noScan = true) | |
- checkPushedFilters(format, df.where('id >= v.toInt), Array(sources.IsNotNull("id"), | |
- sources.EqualTo("id", v))) | |
- checkPushedFilters(format, df.where('id === v.toInt), Array(sources.IsNotNull("id"), | |
- sources.EqualTo("id", v))) | |
- checkPushedFilters(format, df.where('id <=> v.toInt), | |
+ checkPushedFilters(format, df.where(Symbol("id") > v.toInt), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(Symbol("id") >= v.toInt), | |
+ Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) | |
+ checkPushedFilters(format, df.where(Symbol("id") === v.toInt), | |
+ Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) | |
+ checkPushedFilters(format, df.where(Symbol("id") <=> v.toInt), | |
Array(sources.EqualNullSafe("id", v))) | |
- checkPushedFilters(format, df.where('id <= v.toInt), Array(sources.IsNotNull("id"))) | |
- checkPushedFilters(format, df.where('id < v.toInt), Array(sources.IsNotNull("id"), | |
- sources.Not(sources.EqualTo("id", v)))) | |
+ checkPushedFilters(format, df.where(Symbol("id") <= v.toInt), | |
+ Array(sources.IsNotNull("id"))) | |
+ checkPushedFilters(format, df.where(Symbol("id") < v.toInt), | |
+ Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) | |
// cases when value > MAX | |
var v1: Int = positiveInt | |
- checkPushedFilters(format, df.where('id > v1), Array(), noScan = true) | |
- checkPushedFilters(format, df.where('id >= v1), Array(), noScan = true) | |
- checkPushedFilters(format, df.where('id === v1), Array(), noScan = true) | |
- checkPushedFilters(format, df.where('id <=> v1), Array(), noScan = true) | |
- checkPushedFilters(format, df.where('id <= v1), Array(sources.IsNotNull("id"))) | |
- checkPushedFilters(format, df.where('id < v1), Array(sources.IsNotNull("id"))) | |
+ checkPushedFilters(format, df.where(Symbol("id") > v1), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(Symbol("id") >= v1), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(Symbol("id") === v1), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(Symbol("id") <=> v1), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(Symbol("id") <= v1), Array(sources.IsNotNull("id"))) | |
+ checkPushedFilters(format, df.where(Symbol("id") < v1), Array(sources.IsNotNull("id"))) | |
// cases when value = MIN | |
v = Short.MinValue | |
- checkPushedFilters(format, df.where(lit(v.toInt) < 'id), Array(sources.IsNotNull("id"), | |
- sources.Not(sources.EqualTo("id", v)))) | |
- checkPushedFilters(format, df.where(lit(v.toInt) <= 'id), Array(sources.IsNotNull("id"))) | |
- checkPushedFilters(format, df.where(lit(v.toInt) === 'id), Array(sources.IsNotNull("id"), | |
+ checkPushedFilters(format, df.where(lit(v.toInt) < Symbol("id")), | |
+ Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) | |
+ checkPushedFilters(format, df.where(lit(v.toInt) <= Symbol("id")), | |
+ Array(sources.IsNotNull("id"))) | |
+ checkPushedFilters(format, df.where(lit(v.toInt) === Symbol("id")), | |
+ Array(sources.IsNotNull("id"), | |
sources.EqualTo("id", v))) | |
- checkPushedFilters(format, df.where(lit(v.toInt) <=> 'id), | |
+ checkPushedFilters(format, df.where(lit(v.toInt) <=> Symbol("id")), | |
Array(sources.EqualNullSafe("id", v))) | |
- checkPushedFilters(format, df.where(lit(v.toInt) >= 'id), Array(sources.IsNotNull("id"), | |
- sources.EqualTo("id", v))) | |
- checkPushedFilters(format, df.where(lit(v.toInt) > 'id), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(lit(v.toInt) >= Symbol("id")), | |
+ Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) | |
+ checkPushedFilters(format, df.where(lit(v.toInt) > Symbol("id")), Array(), noScan = true) | |
// cases when value < MIN | |
v1 = negativeInt | |
- checkPushedFilters(format, df.where(lit(v1) < 'id), Array(sources.IsNotNull("id"))) | |
- checkPushedFilters(format, df.where(lit(v1) <= 'id), Array(sources.IsNotNull("id"))) | |
- checkPushedFilters(format, df.where(lit(v1) === 'id), Array(), noScan = true) | |
- checkPushedFilters(format, df.where(lit(v1) >= 'id), Array(), noScan = true) | |
- checkPushedFilters(format, df.where(lit(v1) > 'id), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(lit(v1) < Symbol("id")), | |
+ Array(sources.IsNotNull("id"))) | |
+ checkPushedFilters(format, df.where(lit(v1) <= Symbol("id")), | |
+ Array(sources.IsNotNull("id"))) | |
+ checkPushedFilters(format, df.where(lit(v1) === Symbol("id")), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(lit(v1) >= Symbol("id")), Array(), noScan = true) | |
+ checkPushedFilters(format, df.where(lit(v1) > Symbol("id")), Array(), noScan = true) | |
// cases when value is within range (MIN, MAX) | |
- checkPushedFilters(format, df.where('id > 30), Array(sources.IsNotNull("id"), | |
+ checkPushedFilters(format, df.where(Symbol("id") > 30), Array(sources.IsNotNull("id"), | |
sources.GreaterThan("id", 30))) | |
- checkPushedFilters(format, df.where(lit(100) >= 'id), Array(sources.IsNotNull("id"), | |
- sources.LessThanOrEqual("id", 100))) | |
+ checkPushedFilters(format, df.where(lit(100) >= Symbol("id")), | |
+ Array(sources.IsNotNull("id"), sources.LessThanOrEqual("id", 100))) | |
} | |
} | |
} | |
@@ -967,28 +1055,6 @@ class FileBasedDataSourceSuite extends QueryTest | |
checkAnswer(df, Row("v1", "v2")) | |
} | |
} | |
- | |
- test("SPARK-36271: V1 insert should check schema field name too") { | |
- withView("v") { | |
- spark.range(1).createTempView("v") | |
- withTempDir { dir => | |
- val e = intercept[AnalysisException] { | |
- sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite) | |
- .format("parquet").save(dir.getCanonicalPath) | |
- }.getMessage | |
- assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s).")) | |
- } | |
- | |
- withTempDir { dir => | |
- val e = intercept[AnalysisException] { | |
- sql("SELECT NAMED_STRUCT('(IF((ID = 1), 1, 0))', IF(ID=1,ID,0)) AS col1 FROM v") | |
- .write.mode(SaveMode.Overwrite) | |
- .format("parquet").save(dir.getCanonicalPath) | |
- }.getMessage | |
- assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s).")) | |
- } | |
- } | |
- } | |
} | |
object TestingUDT { | |
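A minimal sketch of the listener-based metrics probe in the SPARK-30362/SPARK-37585 tests above: register a SparkListener, run an action, and sum the input metrics its tasks reported. The real tests also drain the listener bus via a package-private waitUntilEmpty call before asserting, which is omitted here:

import scala.collection.mutable
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

object InputMetricsProbe {
  // Runs `action` and returns the total records read by its tasks (best effort:
  // without draining the listener bus some task-end events may still be in flight).
  def recordsRead(sc: SparkContext)(action: => Unit): Long = {
    val counts = new mutable.ArrayBuffer[Long]()
    val listener = new SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit =
        counts += taskEnd.taskMetrics.inputMetrics.recordsRead
    }
    sc.addSparkListener(listener)
    try action finally sc.removeSparkListener(listener)
    counts.sum
  }
}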
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala | |
index 4e7fe8455f..ce98fd2735 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala | |
@@ -85,10 +85,11 @@ trait FileScanSuiteBase extends SharedSparkSession { | |
val options = new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap)) | |
val optionsNotEqual = | |
new CaseInsensitiveStringMap(ImmutableMap.copyOf(ImmutableMap.of("key2", "value2"))) | |
- val partitionFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0))) | |
- val partitionFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1))) | |
- val dataFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0))) | |
- val dataFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1))) | |
+ val partitionFilters = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 0))) | |
+ val partitionFiltersNotEqual = Seq(And(IsNull(Symbol("data").int), | |
+ LessThan(Symbol("data").int, 1))) | |
+ val dataFilters = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 0))) | |
+ val dataFiltersNotEqual = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 1))) | |
scanBuilders.foreach { case (name, scanBuilder, exclusions) => | |
test(s"SPARK-33482: Test $name equals") { | |
@@ -354,11 +355,11 @@ class FileScanSuite extends FileScanSuiteBase { | |
val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( | |
("ParquetScan", | |
(s, fi, ds, rds, rps, f, o, pf, df) => | |
- ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), | |
+ ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df), | |
Seq.empty), | |
("OrcScan", | |
(s, fi, ds, rds, rps, f, o, pf, df) => | |
- OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df), | |
+ OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, None, f, pf, df), | |
Seq.empty), | |
("CSVScan", | |
(s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df), | |
@@ -367,7 +368,7 @@ class FileScanSuite extends FileScanSuiteBase { | |
(s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, f, pf, df), | |
Seq.empty), | |
("TextScan", | |
- (s, fi, _, rds, rps, _, o, pf, df) => TextScan(s, fi, rds, rps, o, pf, df), | |
+ (s, fi, ds, rds, rps, _, o, pf, df) => TextScan(s, fi, ds, rds, rps, o, pf, df), | |
Seq("dataSchema", "pushedFilters"))) | |
run(scanBuilders) | |
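A self-contained sketch of the 'col -> Symbol("col") rewrite seen in this and the surrounding suites; the single-quote symbol literal is deprecated in Scala 2.13 and removed in Scala 3, so the explicit spelling is used instead:

object SymbolLiteralSketch {
  def main(args: Array[String]): Unit = {
    val literal = 'id            // deprecated symbol-literal syntax
    val explicit = Symbol("id")  // equivalent, deprecation-free spelling
    assert(literal == explicit)
  }
}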
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala | |
index d5c2d93055..49cdc80241 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala | |
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo | |
import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | |
import org.apache.spark.sql.catalyst.trees.LeafLike | |
import org.apache.spark.sql.functions._ | |
+import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.sql.types.{IntegerType, StructType} | |
@@ -331,7 +332,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { | |
val msg1 = intercept[AnalysisException] { | |
sql("select 1 + explode(array(min(c2), max(c2))) from t1 group by c1") | |
}.getMessage | |
- assert(msg1.contains("Generators are not supported when it's nested in expressions")) | |
+ assert(msg1.contains("The generator is not supported: nested in expressions")) | |
val msg2 = intercept[AnalysisException] { | |
sql( | |
@@ -341,7 +342,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { | |
|from t1 group by c1 | |
""".stripMargin) | |
}.getMessage | |
- assert(msg2.contains("Only one generator allowed per aggregate clause")) | |
+ assert(msg2.contains("The generator is not supported: " + | |
+ "only one generator allowed per aggregate clause")) | |
} | |
} | |
@@ -349,14 +351,99 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { | |
val errMsg = intercept[AnalysisException] { | |
sql("SELECT array(array(1, 2), array(3)) v").select(explode(explode($"v"))).collect | |
}.getMessage | |
- assert(errMsg.contains("Generators are not supported when it's nested in expressions, " + | |
- "but got: explode(explode(v))")) | |
+ assert(errMsg.contains("The generator is not supported: " + | |
+ "nested in expressions \"explode(explode(v))\"")) | |
} | |
test("SPARK-30997: generators in aggregate expressions for dataframe") { | |
val df = Seq(1, 2, 3).toDF("v") | |
checkAnswer(df.select(explode(array(min($"v"), max($"v")))), Row(1) :: Row(3) :: Nil) | |
} | |
+ | |
+ test("SPARK-38528: generator in stream of aggregate expressions") { | |
+ val df = Seq(1, 2, 3).toDF("v") | |
+ checkAnswer( | |
+ df.select(Stream(explode(array(min($"v"), max($"v"))), sum($"v")): _*), | |
+ Row(1, 6) :: Row(3, 6) :: Nil) | |
+ } | |
+ | |
+ test("SPARK-37947: lateral view <func>_outer()") { | |
+ checkAnswer( | |
+ sql("select * from values 1, 2 lateral view explode_outer(array()) a as b"), | |
+ Row(1, null) :: Row(2, null) :: Nil) | |
+ | |
+ checkAnswer( | |
+ sql("select * from values 1, 2 lateral view outer explode_outer(array()) a as b"), | |
+ Row(1, null) :: Row(2, null) :: Nil) | |
+ | |
+ withTempView("t1") { | |
+ sql( | |
+ """select * from values | |
+ |array(struct(0, 1), struct(3, 4)), | |
+ |array(struct(6, 7)), | |
+ |array(), | |
+ |null | |
+ |as tbl(arr) | |
+ """.stripMargin).createOrReplaceTempView("t1") | |
+ checkAnswer( | |
+ sql("select f1, f2 from t1 lateral view inline_outer(arr) as f1, f2"), | |
+ Row(0, 1) :: Row(3, 4) :: Row(6, 7) :: Row(null, null) :: Row(null, null) :: Nil) | |
+ } | |
+ } | |
+ | |
+ def testNullStruct(): Unit = { | |
+ val df = sql( | |
+ """select * from values | |
+ |( | |
+ | 1, | |
+ | array( | |
+ | named_struct('c1', 0, 'c2', 1), | |
+ | null, | |
+ | named_struct('c1', 2, 'c2', 3), | |
+ | null | |
+ | ) | |
+ |) | |
+ |as tbl(a, b) | |
+ """.stripMargin) | |
+ df.createOrReplaceTempView("t1") | |
+ | |
+ checkAnswer( | |
+ sql("select inline(b) from t1"), | |
+ Row(0, 1) :: Row(null, null) :: Row(2, 3) :: Row(null, null) :: Nil) | |
+ | |
+ checkAnswer( | |
+ sql("select a, inline(b) from t1"), | |
+ Row(1, 0, 1) :: Row(1, null, null) :: Row(1, 2, 3) :: Row(1, null, null) :: Nil) | |
+ } | |
+ | |
+ test("SPARK-39061: inline should handle null struct") { | |
+ testNullStruct | |
+ } | |
+ | |
+ test("SPARK-39496: inline eval path should handle null struct") { | |
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { | |
+ testNullStruct | |
+ } | |
+ } | |
+ | |
+ test("SPARK-40963: generator output has correct nullability") { | |
+ // This test does not check nullability directly. Before SPARK-40963, | |
+ // the below query got wrong results due to incorrect nullability. | |
+ val df = sql( | |
+ """select c1, explode(c4) as c5 from ( | |
+ | select c1, array(c3) as c4 from ( | |
+ | select c1, explode_outer(c2) as c3 | |
+ | from values | |
+ | (1, array(1, 2)), | |
+ | (2, array(2, 3)), | |
+ | (3, null) | |
+ | as data(c1, c2) | |
+ | ) | |
+ |) | |
+ |""".stripMargin) | |
+ checkAnswer(df, | |
+ Row(1, 1) :: Row(1, 2) :: Row(2, 2) :: Row(2, 3) :: Row(3, null) :: Nil) | |
+ } | |
} | |
case class EmptyGenerator() extends Generator with LeafLike[Expression] { | |
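A hedged, self-contained sketch of the explode vs explode_outer distinction that the new lateral-view and null-struct tests above exercise; the local-mode session and column names are illustrative:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, explode_outer}

object ExplodeOuterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._
    val df = Seq((1, Seq.empty[Int]), (2, Seq(10))).toDF("id", "arr")
    df.select($"id", explode($"arr")).show()       // the empty-array row is dropped
    df.select($"id", explode_outer($"arr")).show() // the empty-array row is kept, with null
    spark.stop()
  }
}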
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala | |
new file mode 100644 | |
index 0000000000..6c6bd1799e | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala | |
@@ -0,0 +1,579 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql | |
+ | |
+import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal} | |
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate} | |
+import org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries | |
+import org.apache.spark.sql.catalyst.plans.LeftSemi | |
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan} | |
+import org.apache.spark.sql.execution.{ReusedSubqueryExec, SubqueryExec} | |
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation} | |
+import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} | |
+import org.apache.spark.sql.types.{IntegerType, StructType} | |
+ | |
+class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession | |
+ with AdaptiveSparkPlanHelper { | |
+ | |
+ protected override def beforeAll(): Unit = { | |
+ super.beforeAll() | |
+ val schema = new StructType().add("a1", IntegerType, nullable = true) | |
+ .add("b1", IntegerType, nullable = true) | |
+ .add("c1", IntegerType, nullable = true) | |
+ .add("d1", IntegerType, nullable = true) | |
+ .add("e1", IntegerType, nullable = true) | |
+ .add("f1", IntegerType, nullable = true) | |
+ | |
+ val data1 = Seq(Seq(null, 47, null, 4, 6, 48), | |
+ Seq(73, 63, null, 92, null, null), | |
+ Seq(76, 10, 74, 98, 37, 5), | |
+ Seq(0, 63, null, null, null, null), | |
+ Seq(15, 77, null, null, null, null), | |
+ Seq(null, 57, 33, 55, null, 58), | |
+ Seq(4, 0, 86, null, 96, 14), | |
+ Seq(28, 16, 58, null, null, null), | |
+ Seq(1, 88, null, 8, null, 79), | |
+ Seq(59, null, null, null, 20, 25), | |
+ Seq(1, 50, null, 94, 94, null), | |
+ Seq(null, null, null, 67, 51, 57), | |
+ Seq(77, 50, 8, 90, 16, 21), | |
+ Seq(34, 28, null, 5, null, 64), | |
+ Seq(null, null, 88, 11, 63, 79), | |
+ Seq(92, 94, 23, 1, null, 64), | |
+ Seq(57, 56, null, 83, null, null), | |
+ Seq(null, 35, 8, 35, null, 70), | |
+ Seq(null, 8, null, 35, null, 87), | |
+ Seq(9, null, null, 60, null, 5), | |
+ Seq(null, 15, 66, null, 83, null)) | |
+ val rdd1 = spark.sparkContext.parallelize(data1) | |
+ val rddRow1 = rdd1.map(s => Row.fromSeq(s)) | |
+ spark.createDataFrame(rddRow1, schema).write.saveAsTable("bf1") | |
+ | |
+ val schema2 = new StructType().add("a2", IntegerType, nullable = true) | |
+ .add("b2", IntegerType, nullable = true) | |
+ .add("c2", IntegerType, nullable = true) | |
+ .add("d2", IntegerType, nullable = true) | |
+ .add("e2", IntegerType, nullable = true) | |
+ .add("f2", IntegerType, nullable = true) | |
+ | |
+ | |
+ val data2 = Seq(Seq(67, 17, 45, 91, null, null), | |
+ Seq(98, 63, 0, 89, null, 40), | |
+ Seq(null, 76, 68, 75, 20, 19), | |
+ Seq(8, null, null, null, 78, null), | |
+ Seq(48, 62, null, null, 11, 98), | |
+ Seq(84, null, 99, 65, 66, 51), | |
+ Seq(98, null, null, null, 42, 51), | |
+ Seq(10, 3, 29, null, 68, 8), | |
+ Seq(85, 36, 41, null, 28, 71), | |
+ Seq(89, null, 94, 95, 67, 21), | |
+ Seq(44, null, 24, 33, null, 6), | |
+ Seq(null, 6, 78, 31, null, 69), | |
+ Seq(59, 2, 63, 9, 66, 20), | |
+ Seq(5, 23, 10, 86, 68, null), | |
+ Seq(null, 63, 99, 55, 9, 65), | |
+ Seq(57, 62, 68, 5, null, 0), | |
+ Seq(75, null, 15, null, 81, null), | |
+ Seq(53, null, 6, 68, 28, 13), | |
+ Seq(null, null, null, null, 89, 23), | |
+ Seq(36, 73, 40, null, 8, null), | |
+ Seq(24, null, null, 40, null, null)) | |
+ val rdd2 = spark.sparkContext.parallelize(data2) | |
+ val rddRow2 = rdd2.map(s => Row.fromSeq(s)) | |
+ spark.createDataFrame(rddRow2, schema2).write.saveAsTable("bf2") | |
+ | |
+ val schema3 = new StructType().add("a3", IntegerType, nullable = true) | |
+ .add("b3", IntegerType, nullable = true) | |
+ .add("c3", IntegerType, nullable = true) | |
+ .add("d3", IntegerType, nullable = true) | |
+ .add("e3", IntegerType, nullable = true) | |
+ .add("f3", IntegerType, nullable = true) | |
+ | |
+ val data3 = Seq(Seq(67, 17, 45, 91, null, null), | |
+ Seq(98, 63, 0, 89, null, 40), | |
+ Seq(null, 76, 68, 75, 20, 19), | |
+ Seq(8, null, null, null, 78, null), | |
+ Seq(48, 62, null, null, 11, 98), | |
+ Seq(84, null, 99, 65, 66, 51), | |
+ Seq(98, null, null, null, 42, 51), | |
+ Seq(10, 3, 29, null, 68, 8), | |
+ Seq(85, 36, 41, null, 28, 71), | |
+ Seq(89, null, 94, 95, 67, 21), | |
+ Seq(44, null, 24, 33, null, 6), | |
+ Seq(null, 6, 78, 31, null, 69), | |
+ Seq(59, 2, 63, 9, 66, 20), | |
+ Seq(5, 23, 10, 86, 68, null), | |
+ Seq(null, 63, 99, 55, 9, 65), | |
+ Seq(57, 62, 68, 5, null, 0), | |
+ Seq(75, null, 15, null, 81, null), | |
+ Seq(53, null, 6, 68, 28, 13), | |
+ Seq(null, null, null, null, 89, 23), | |
+ Seq(36, 73, 40, null, 8, null), | |
+ Seq(24, null, null, 40, null, null)) | |
+ val rdd3 = spark.sparkContext.parallelize(data3) | |
+ val rddRow3 = rdd3.map(s => Row.fromSeq(s)) | |
+ spark.createDataFrame(rddRow3, schema3).write.saveAsTable("bf3") | |
+ | |
+ | |
+ val schema4 = new StructType().add("a4", IntegerType, nullable = true) | |
+ .add("b4", IntegerType, nullable = true) | |
+ .add("c4", IntegerType, nullable = true) | |
+ .add("d4", IntegerType, nullable = true) | |
+ .add("e4", IntegerType, nullable = true) | |
+ .add("f4", IntegerType, nullable = true) | |
+ | |
+ val data4 = Seq(Seq(67, 17, 45, 91, null, null), | |
+ Seq(98, 63, 0, 89, null, 40), | |
+ Seq(null, 76, 68, 75, 20, 19), | |
+ Seq(8, null, null, null, 78, null), | |
+ Seq(48, 62, null, null, 11, 98), | |
+ Seq(84, null, 99, 65, 66, 51), | |
+ Seq(98, null, null, null, 42, 51), | |
+ Seq(10, 3, 29, null, 68, 8), | |
+ Seq(85, 36, 41, null, 28, 71), | |
+ Seq(89, null, 94, 95, 67, 21), | |
+ Seq(44, null, 24, 33, null, 6), | |
+ Seq(null, 6, 78, 31, null, 69), | |
+ Seq(59, 2, 63, 9, 66, 20), | |
+ Seq(5, 23, 10, 86, 68, null), | |
+ Seq(null, 63, 99, 55, 9, 65), | |
+ Seq(57, 62, 68, 5, null, 0), | |
+ Seq(75, null, 15, null, 81, null), | |
+ Seq(53, null, 6, 68, 28, 13), | |
+ Seq(null, null, null, null, 89, 23), | |
+ Seq(36, 73, 40, null, 8, null), | |
+ Seq(24, null, null, 40, null, null)) | |
+ val rdd4 = spark.sparkContext.parallelize(data4) | |
+ val rddRow4 = rdd4.map(s => Row.fromSeq(s)) | |
+ spark.createDataFrame(rddRow4, schema4).write.saveAsTable("bf4") | |
+ | |
+ val schema5part = new StructType().add("a5", IntegerType, nullable = true) | |
+ .add("b5", IntegerType, nullable = true) | |
+ .add("c5", IntegerType, nullable = true) | |
+ .add("d5", IntegerType, nullable = true) | |
+ .add("e5", IntegerType, nullable = true) | |
+ .add("f5", IntegerType, nullable = true) | |
+ | |
+ val data5part = Seq(Seq(67, 17, 45, 91, null, null), | |
+ Seq(98, 63, 0, 89, null, 40), | |
+ Seq(null, 76, 68, 75, 20, 19), | |
+ Seq(8, null, null, null, 78, null), | |
+ Seq(48, 62, null, null, 11, 98), | |
+ Seq(84, null, 99, 65, 66, 51), | |
+ Seq(98, null, null, null, 42, 51), | |
+ Seq(10, 3, 29, null, 68, 8), | |
+ Seq(85, 36, 41, null, 28, 71), | |
+ Seq(89, null, 94, 95, 67, 21), | |
+ Seq(44, null, 24, 33, null, 6), | |
+ Seq(null, 6, 78, 31, null, 69), | |
+ Seq(59, 2, 63, 9, 66, 20), | |
+ Seq(5, 23, 10, 86, 68, null), | |
+ Seq(null, 63, 99, 55, 9, 65), | |
+ Seq(57, 62, 68, 5, null, 0), | |
+ Seq(75, null, 15, null, 81, null), | |
+ Seq(53, null, 6, 68, 28, 13), | |
+ Seq(null, null, null, null, 89, 23), | |
+ Seq(36, 73, 40, null, 8, null), | |
+ Seq(24, null, null, 40, null, null)) | |
+ val rdd5part = spark.sparkContext.parallelize(data5part) | |
+ val rddRow5part = rdd5part.map(s => Row.fromSeq(s)) | |
+ spark.createDataFrame(rddRow5part, schema5part).write.partitionBy("f5") | |
+ .saveAsTable("bf5part") | |
+ spark.createDataFrame(rddRow5part, schema5part).filter("a5 > 30") | |
+ .write.partitionBy("f5") | |
+ .saveAsTable("bf5filtered") | |
+ | |
+ sql("analyze table bf1 compute statistics for columns a1, b1, c1, d1, e1, f1") | |
+ sql("analyze table bf2 compute statistics for columns a2, b2, c2, d2, e2, f2") | |
+ sql("analyze table bf3 compute statistics for columns a3, b3, c3, d3, e3, f3") | |
+ sql("analyze table bf4 compute statistics for columns a4, b4, c4, d4, e4, f4") | |
+ sql("analyze table bf5part compute statistics for columns a5, b5, c5, d5, e5, f5") | |
+ sql("analyze table bf5filtered compute statistics for columns a5, b5, c5, d5, e5, f5") | |
+ | |
+ // `MergeScalarSubqueries` can duplicate subqueries in the optimized plan and would make testing | |
+ // complicated. | |
+ conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, MergeScalarSubqueries.ruleName) | |
+ } | |
+ | |
+ protected override def afterAll(): Unit = try { | |
+ conf.setConfString(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, | |
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.defaultValueString) | |
+ | |
+ sql("DROP TABLE IF EXISTS bf1") | |
+ sql("DROP TABLE IF EXISTS bf2") | |
+ sql("DROP TABLE IF EXISTS bf3") | |
+ sql("DROP TABLE IF EXISTS bf4") | |
+ sql("DROP TABLE IF EXISTS bf5part") | |
+ sql("DROP TABLE IF EXISTS bf5filtered") | |
+ } finally { | |
+ super.afterAll() | |
+ } | |
+ | |
+ private def ensureLeftSemiJoinExists(plan: LogicalPlan): Unit = { | |
+ assert( | |
+ plan.find { | |
+ case j: Join if j.joinType == LeftSemi => true | |
+ case _ => false | |
+ }.isDefined | |
+ ) | |
+ } | |
+ | |
+ def checkWithAndWithoutFeatureEnabled(query: String, testSemiJoin: Boolean, | |
+ shouldReplace: Boolean): Unit = { | |
+ var planDisabled: LogicalPlan = null | |
+ var planEnabled: LogicalPlan = null | |
+ var expectedAnswer: Array[Row] = null | |
+ | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { | |
+ planDisabled = sql(query).queryExecution.optimizedPlan | |
+ expectedAnswer = sql(query).collect() | |
+ } | |
+ | |
+ if (testSemiJoin) { | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "true", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { | |
+ planEnabled = sql(query).queryExecution.optimizedPlan | |
+ checkAnswer(sql(query), expectedAnswer) | |
+ } | |
+ if (shouldReplace) { | |
+ val normalizedEnabled = normalizePlan(normalizeExprIds(planEnabled)) | |
+ val normalizedDisabled = normalizePlan(normalizeExprIds(planDisabled)) | |
+ ensureLeftSemiJoinExists(planEnabled) | |
+ assert(normalizedEnabled != normalizedDisabled) | |
+ } else { | |
+ comparePlans(planDisabled, planEnabled) | |
+ } | |
+ } else { | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") { | |
+ planEnabled = sql(query).queryExecution.optimizedPlan | |
+ checkAnswer(sql(query), expectedAnswer) | |
+ if (shouldReplace) { | |
+ assert(!columnPruningTakesEffect(planEnabled)) | |
+ assert(getNumBloomFilters(planEnabled) > getNumBloomFilters(planDisabled)) | |
+ } else { | |
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled)) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ def getNumBloomFilters(plan: LogicalPlan): Integer = { | |
+ val numBloomFilterAggs = plan.collect { | |
+ case Filter(condition, _) => condition.collect { | |
+ case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery | |
+ => subquery.plan.collect { | |
+ case Aggregate(_, aggregateExpressions, _) => | |
+ aggregateExpressions.map { | |
+ case Alias(AggregateExpression(bfAgg : BloomFilterAggregate, _, _, _, _), | |
+ _) => | |
+ assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal]) | |
+ assert(bfAgg.numBitsExpression.isInstanceOf[Literal]) | |
+ 1 | |
+ }.sum | |
+ }.sum | |
+ }.sum | |
+ }.sum | |
+ val numMightContains = plan.collect { | |
+ case Filter(condition, _) => condition.collect { | |
+ case BloomFilterMightContain(_, _) => 1 | |
+ }.sum | |
+ }.sum | |
+ assert(numBloomFilterAggs == numMightContains) | |
+ numMightContains | |
+ } | |
+ | |
+ def columnPruningTakesEffect(plan: LogicalPlan): Boolean = { | |
+ def takesEffect(plan: LogicalPlan): Boolean = { | |
+ val result = org.apache.spark.sql.catalyst.optimizer.ColumnPruning.apply(plan) | |
+ !result.fastEquals(plan) | |
+ } | |
+ | |
+ plan.collectFirst { | |
+ case Filter(condition, _) if condition.collectFirst { | |
+ case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery | |
+ if takesEffect(subquery.plan) => true | |
+ }.nonEmpty => true | |
+ }.nonEmpty | |
+ } | |
+ | |
+ def assertRewroteSemiJoin(query: String): Unit = { | |
+ checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true) | |
+ } | |
+ | |
+ def assertDidNotRewriteSemiJoin(query: String): Unit = { | |
+ checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = false) | |
+ } | |
+ | |
+ def assertRewroteWithBloomFilter(query: String): Unit = { | |
+ checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = true) | |
+ } | |
+ | |
+ def assertDidNotRewriteWithBloomFilter(query: String): Unit = { | |
+ checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = false) | |
+ } | |
+ | |
+ test("Runtime semi join reduction: simple") { | |
+ // Filter creation side is 3409 bytes | |
+ // Filter application side scan is 3362 bytes | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ assertRewroteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62") | |
+ assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2") | |
+ } | |
+ } | |
+ | |
+ test("Runtime semi join reduction: two joins") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " + | |
+ "and bf3.c3 = bf2.c2 where bf2.a2 = 5") | |
+ } | |
+ } | |
+ | |
+ test("Runtime semi join reduction: three joins") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 join bf4 on " + | |
+ "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5") | |
+ } | |
+ } | |
+ | |
+ test("Runtime semi join reduction: simple expressions only") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ val squared = (s: Long) => { | |
+ s * s | |
+ } | |
+ spark.udf.register("square", squared) | |
+ assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " + | |
+ "bf1.c1 = bf2.c2 where square(bf2.a2) = 62") | |
+ assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " + | |
+ "bf1.c1 = square(bf2.c2) where bf2.a2= 62") | |
+ } | |
+ } | |
+ | |
+ test("Runtime bloom filter join: simple") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + | |
+ "where bf2.a2 = 62") | |
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2") | |
+ } | |
+ } | |
+ | |
+ test("Runtime bloom filter join: two filters single join") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ var planDisabled: LogicalPlan = null | |
+ var planEnabled: LogicalPlan = null | |
+ var expectedAnswer: Array[Row] = null | |
+ | |
+ val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + | |
+ "bf1.b1 = bf2.b2 where bf2.a2 = 62" | |
+ | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { | |
+ planDisabled = sql(query).queryExecution.optimizedPlan | |
+ expectedAnswer = sql(query).collect() | |
+ } | |
+ | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") { | |
+ planEnabled = sql(query).queryExecution.optimizedPlan | |
+ checkAnswer(sql(query), expectedAnswer) | |
+ } | |
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2) | |
+ } | |
+ } | |
+ | |
+ test("Runtime bloom filter join: test the number of filter threshold") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ var planDisabled: LogicalPlan = null | |
+ var planEnabled: LogicalPlan = null | |
+ var expectedAnswer: Array[Row] = null | |
+ | |
+ val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + | |
+ "bf1.b1 = bf2.b2 where bf2.a2 = 62" | |
+ | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { | |
+ planDisabled = sql(query).queryExecution.optimizedPlan | |
+ expectedAnswer = sql(query).collect() | |
+ } | |
+ | |
+ for (numFilterThreshold <- 0 to 3) { | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true", | |
+ SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD.key -> numFilterThreshold.toString) { | |
+ planEnabled = sql(query).queryExecution.optimizedPlan | |
+ checkAnswer(sql(query), expectedAnswer) | |
+ } | |
+ if (numFilterThreshold < 3) { | |
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) | |
+ + numFilterThreshold) | |
+ } else { | |
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Runtime bloom filter join: insert one bloom filter per column") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ var planDisabled: LogicalPlan = null | |
+ var planEnabled: LogicalPlan = null | |
+ var expectedAnswer: Array[Row] = null | |
+ | |
+ val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + | |
+ "bf1.c1 = bf2.b2 where bf2.a2 = 62" | |
+ | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { | |
+ planDisabled = sql(query).queryExecution.optimizedPlan | |
+ expectedAnswer = sql(query).collect() | |
+ } | |
+ | |
+ withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") { | |
+ planEnabled = sql(query).queryExecution.optimizedPlan | |
+ checkAnswer(sql(query), expectedAnswer) | |
+ } | |
+ assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 1) | |
+ } | |
+ } | |
+ | |
+ test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " + | |
+ "on the same column") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ assertDidNotRewriteWithBloomFilter("select * from bf5part join bf2 on " + | |
+ "bf5part.f5 = bf2.c2 where bf2.a2 = 62") | |
+ } | |
+ } | |
+ | |
+ test("Runtime bloom filter join: add bloom filter if dpp filter exists on " + | |
+ "a different column") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ assertRewroteWithBloomFilter("select * from bf5part join bf2 on " + | |
+ "bf5part.c5 = bf2.c2 and bf5part.f5 = bf2.f2 where bf2.a2 = 62") | |
+ } | |
+ } | |
+ | |
+ test("Runtime bloom filter join: BF rewrite triggering threshold test") { | |
+ // Filter creation side data size is 3409 bytes. On the filter application side, an individual | |
+ // scan's byte size is 3362. | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000" | |
+ ) { | |
+ assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + | |
+ "where bf2.a2 = 62") | |
+ } | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "50" | |
+ ) { | |
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + | |
+ "where bf2.a2 = 62") | |
+ } | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000" | |
+ ) { | |
+ // Rewrite should not be triggered as the Bloom filter application side scan size is small. | |
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " | |
+ + "where bf2.a2 = 62") | |
+ } | |
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000") { | |
+ // Test that the max scan size rather than an individual scan size on the filter | |
+ // application side matters. `bf5filtered` has 14168 bytes and `bf2` has 3409 bytes. | |
+ withSQLConf( | |
+ SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000") { | |
+ assertRewroteWithBloomFilter("select * from " + | |
+ "(select * from bf5filtered union all select * from bf2) t " + | |
+ "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5") | |
+ } | |
+ withSQLConf( | |
+ SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "15000") { | |
+ assertDidNotRewriteWithBloomFilter("select * from " + | |
+ "(select * from bf5filtered union all select * from bf2) t " + | |
+ "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5") | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Runtime bloom filter join: simple expressions only") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { | |
+ val squared = (s: Long) => { | |
+ s * s | |
+ } | |
+ spark.udf.register("square", squared) | |
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " + | |
+ "bf1.c1 = bf2.c2 where square(bf2.a2) = 62" ) | |
+ assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " + | |
+ "bf1.c1 = square(bf2.c2) where bf2.a2 = 62" ) | |
+ } | |
+ } | |
+ | |
+ test("Support Left Semi join in row level runtime filters") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32") { | |
+ assertRewroteWithBloomFilter( | |
+ """ | |
+ |SELECT * | |
+ |FROM bf1 LEFT SEMI | |
+ |JOIN (SELECT * FROM bf2 WHERE bf2.a2 = 62) tmp | |
+ |ON bf1.c1 = tmp.c2 | |
+ """.stripMargin) | |
+ } | |
+ } | |
+ | |
+ test("Merge runtime bloom filters") { | |
+ withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000", | |
+ SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", | |
+ SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true", | |
+ // Re-enable `MergeScalarSubqueries` | |
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "", | |
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { | |
+ | |
+ val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + | |
+ "bf1.b1 = bf2.b2 where bf2.a2 = 62" | |
+ val df = sql(query) | |
+ df.collect() | |
+ val plan = df.queryExecution.executedPlan | |
+ | |
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } | |
+ val reusedSubqueryIds = collectWithSubqueries(plan) { | |
+ case rs: ReusedSubqueryExec => rs.child.id | |
+ } | |
+ | |
+ assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") | |
+ assert(reusedSubqueryIds.size == 1, | |
+ "Missing or unexpected reused ReusedSubqueryExec in the plan") | |
+ } | |
+ } | |
+} | |
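A minimal sketch, not part of the patch, of what the suite's helpers detect: with the bloom-filter rewrite enabled, the optimized plan on the application side should carry a might-contain filter fed by a scalar subquery over the creation side, which is exactly what `getNumBloomFilters` counts. It assumes the suite's `bf1`/`bf2` tables, `withSQLConf`, `sql` and `getNumBloomFilters` are in scope.

  withSQLConf(
      SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000",
      SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
    val plan = sql("select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62")
      .queryExecution.optimizedPlan
    // The rewrite surfaces as Filter(BloomFilterMightContain(scalar-subquery, ...)) on the
    // application side; the subquery aggregates a BloomFilterAggregate over the filtered
    // creation side, so the two counts checked by getNumBloomFilters match.
    assert(getNumBloomFilters(plan) >= 1)
  }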
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala | |
index a090eba430..86c8c8261e 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala | |
@@ -27,19 +27,20 @@ import org.scalatest.Assertions._ | |
import org.apache.spark.TestUtils | |
import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunction, PythonUtils} | |
import org.apache.spark.broadcast.Broadcast | |
-import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} | |
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF} | |
import org.apache.spark.sql.catalyst.plans.SQLHelper | |
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction | |
import org.apache.spark.sql.expressions.SparkUserDefinedFunction | |
-import org.apache.spark.sql.types.StringType | |
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType} | |
/** | |
- * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF and | |
- * Scalar Pandas UDFs can be tested in SBT & Maven tests. | |
+ * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF, | |
+ * Scalar Pandas UDF and Grouped Aggregate Pandas UDF can be tested in SBT & Maven tests. | |
* | |
- * The available UDFs are special. It defines an UDF wrapped by cast. So, the input column is | |
- * casted into string, UDF returns strings as are, and then output column is casted back to | |
- * the input column. In this way, UDF is virtually no-op. | |
+ * The available UDFs are special. For Scalar UDF, Python UDF and Scalar Pandas UDF, | 
+ * each defines a UDF wrapped by a cast: the input column is cast to string, the UDF | 
+ * returns the string as is, and the output column is cast back to the input column's | 
+ * type. In this way, the UDF is virtually a no-op. | 
* | |
* Note that, due to this implementation limitation, complex types such as map, array and struct | |
* types do not work with this UDFs because they cannot be same after the cast roundtrip. | |
@@ -69,6 +70,28 @@ import org.apache.spark.sql.types.StringType | |
* df.select(expr("udf_name(id)") | |
* df.select(pandasTestUDF(df("id"))) | |
* }}} | |
+ * | |
+ * For Grouped Aggregate Pandas UDF, it defines a UDF that calculates the count using pandas. | 
+ * The UDF returns the count of the given column, so unlike the others it is not a no-op. | 
+ * | |
+ * To register Grouped Aggregate Pandas UDF in SQL: | |
+ * {{{ | |
+ * val groupedAggPandasTestUDF = TestGroupedAggPandasUDF(name = "udf_name") | |
+ * registerTestUDF(groupedAggPandasTestUDF, spark) | |
+ * }}} | |
+ * | |
+ * To use it in Scala API and SQL: | |
+ * {{{ | |
+ * sql("SELECT udf_name(1)") | |
+ * val df = Seq( | |
+ * (536361, "85123A", 2, 17850), | |
+ * (536362, "85123B", 4, 17850), | |
+ * (536363, "86123A", 6, 17851) | |
+ * ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") | |
+ * | |
+ * df.groupBy("CustomerID").agg(expr("udf_name(Quantity)")) | |
+ * df.groupBy("CustomerID").agg(groupedAggPandasTestUDF(df("Quantity"))) | |
+ * }}} | |
*/ | |
object IntegratedUDFTestUtils extends SQLHelper { | |
import scala.sys.process._ | |
@@ -190,6 +213,28 @@ object IntegratedUDFTestUtils extends SQLHelper { | |
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") | |
} | |
+ private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) { | |
+ var binaryPandasFunc: Array[Byte] = null | |
+ withTempPath { path => | |
+ Process( | |
+ Seq( | |
+ pythonExec, | |
+ "-c", | |
+ "from pyspark.sql.types import IntegerType; " + | |
+ "from pyspark.serializers import CloudPickleSerializer; " + | |
+ s"f = open('$path', 'wb');" + | |
+ "f.write(CloudPickleSerializer().dumps((" + | |
+ "lambda x: x.agg('count'), IntegerType())))"), | |
+ None, | |
+ "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! | |
+ binaryPandasFunc = Files.readAllBytes(path.toPath) | |
+ } | |
+ assert(binaryPandasFunc != null) | |
+ binaryPandasFunc | |
+ } else { | |
+ throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") | |
+ } | |
+ | |
// Make sure this map stays mutable - this map gets updated later in Python runners. | |
private val workerEnv = new java.util.HashMap[String, String]() | |
workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") | |
@@ -209,6 +254,8 @@ object IntegratedUDFTestUtils extends SQLHelper { | |
lazy val shouldTestScalarPandasUDFs: Boolean = | |
isPythonAvailable && isPandasAvailable && isPyArrowAvailable | |
+ lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs | |
+ | |
/** | |
* A base trait for various UDFs defined in this object. | |
*/ | |
@@ -218,6 +265,29 @@ object IntegratedUDFTestUtils extends SQLHelper { | |
val prettyName: String | |
} | |
+ class PythonUDFWithoutId( | |
+ name: String, | |
+ func: PythonFunction, | |
+ dataType: DataType, | |
+ children: Seq[Expression], | |
+ evalType: Int, | |
+ udfDeterministic: Boolean, | |
+ resultId: ExprId) | |
+ extends PythonUDF(name, func, dataType, children, evalType, udfDeterministic, resultId) { | |
+ | |
+ def this(pudf: PythonUDF) = { | |
+ this(pudf.name, pudf.func, pudf.dataType, pudf.children, | |
+ pudf.evalType, pudf.udfDeterministic, pudf.resultId) | |
+ } | |
+ | |
+ override def toString: String = s"$name(${children.mkString(", ")})" | |
+ | |
+ override protected def withNewChildrenInternal( | |
+ newChildren: IndexedSeq[Expression]): PythonUDFWithoutId = { | |
+ new PythonUDFWithoutId(super.withNewChildrenInternal(newChildren)) | |
+ } | |
+ } | |
+ | |
/** | |
* A Python UDF that takes one column, casts into string, executes the Python native function, | |
* and casts back to the type of input column. | |
@@ -253,7 +323,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | |
val expr = e.head | |
assert(expr.resolved, "column should be resolved to use the same type " + | |
"as input. Try df(name) or df.col(name)") | |
- Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType) | |
+ val pythonUDF = new PythonUDFWithoutId( | |
+ super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF]) | |
+ Cast(pythonUDF, expr.dataType) | |
} | |
} | |
@@ -297,7 +369,9 @@ object IntegratedUDFTestUtils extends SQLHelper { | |
val expr = e.head | |
assert(expr.resolved, "column should be resolved to use the same type " + | |
"as input. Try df(name) or df.col(name)") | |
- Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType) | |
+ val pythonUDF = new PythonUDFWithoutId( | |
+ super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF]) | |
+ Cast(pythonUDF, expr.dataType) | |
} | |
} | |
@@ -306,6 +380,46 @@ object IntegratedUDFTestUtils extends SQLHelper { | |
val prettyName: String = "Scalar Pandas UDF" | |
} | |
+ /** | |
+ * A Grouped Aggregate Pandas UDF that takes one column, executes the | |
+ * Python native function calculating the count of the column using pandas. | |
+ * | |
+ * Virtually equivalent to: | |
+ * | |
+ * {{{ | |
+ * import pandas as pd | |
+ * from pyspark.sql.functions import pandas_udf | |
+ * | |
+ * df = spark.createDataFrame( | |
+ * [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) | |
+ * | |
+ * @pandas_udf("double") | |
+ * def pandas_count(v: pd.Series) -> int: | |
+ * return v.count() | |
+ * | |
+ * count_col = pandas_count(df['v']) | |
+ * }}} | |
+ */ | |
+ case class TestGroupedAggPandasUDF(name: String) extends TestUDF { | |
+ private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( | |
+ name = name, | |
+ func = PythonFunction( | |
+ command = pandasGroupedAggFunc, | |
+ envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], | |
+ pythonIncludes = List.empty[String].asJava, | |
+ pythonExec = pythonExec, | |
+ pythonVer = pythonVer, | |
+ broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, | |
+ accumulator = null), | |
+ dataType = IntegerType, | |
+ pythonEvalType = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, | |
+ udfDeterministic = true) | |
+ | |
+ def apply(exprs: Column*): Column = udf(exprs: _*) | |
+ | |
+ val prettyName: String = "Grouped Aggregate Pandas UDF" | |
+ } | |
+ | |
/** | |
* A Scala UDF that takes one column, casts into string, executes the | |
* Scala native function, and casts back to the type of input column. | |
@@ -360,6 +474,7 @@ object IntegratedUDFTestUtils extends SQLHelper { | |
def registerTestUDF(testUDF: TestUDF, session: SparkSession): Unit = testUDF match { | |
case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf) | |
case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf) | |
+ case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf) | |
case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf) | |
case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]") | |
} | |
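A sketch, not part of the patch, of how the new grouped-aggregate test UDF might be exercised from a test. It assumes a `SparkSession` named `spark` with pandas and pyarrow available (`shouldTestGroupedAggPandasUDFs`); the names come from the utilities above.

  import org.apache.spark.sql.IntegratedUDFTestUtils._

  // Register the count-style grouped aggregate pandas UDF under a SQL name and call it
  // in an aggregation, the way the scaladoc above describes.
  val pandasCount = TestGroupedAggPandasUDF(name = "pandas_count")
  registerTestUDF(pandasCount, spark)
  spark.sql("SELECT id % 2 AS k, pandas_count(id) FROM range(10) GROUP BY id % 2").show()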
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala | |
index 9f4c24b46a..1792b4c32e 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala | |
@@ -17,7 +17,7 @@ | |
package org.apache.spark.sql | |
-import org.apache.log4j.Level | |
+import org.apache.logging.log4j.Level | |
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, EliminateResolvedHint} | |
import org.apache.spark.sql.catalyst.plans.PlanTest | |
@@ -55,7 +55,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP | |
} | |
val warningMessages = logAppender.loggingEvents | |
.filter(_.getLevel == Level.WARN) | |
- .map(_.getRenderedMessage) | |
+ .map(_.getMessage.getFormattedMessage) | |
.filter(_.contains("hint")) | |
assert(warningMessages.size == warnings.size) | |
warnings.foreach { w => | |
@@ -597,4 +597,115 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP | |
assert(df9.collect().size == df10.collect().size) | |
} | |
} | |
+ | |
+ test("SPARK-35221: Add join hint build side check") { | |
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", | |
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "true") { | |
+ Seq("left_outer", "left_semi", "left_anti").foreach { joinType => | |
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType") | |
+ withLogAppender(hintAppender, level = Some(Level.WARN)) { | |
+ assertShuffleMergeJoin( | |
+ df1.hint("BROADCAST").join(df2, $"a1" === $"b1", joinType)) | |
+ assertShuffleMergeJoin( | |
+ df1.hint("SHUFFLE_HASH").join(df2, $"a1" === $"b1", joinType)) | |
+ } | |
+ | |
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ .filter(_.contains("is not supported in the query:")) | |
+ assert(logs.size === 2) | |
+ logs.foreach(log => | |
+ assert(log.contains(s"build left for ${joinType.split("_").mkString(" ")} join."))) | |
+ } | |
+ | |
+ Seq("left_outer", "left_semi", "left_anti").foreach { joinType => | |
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType") | |
+ withLogAppender(hintAppender, level = Some(Level.WARN)) { | |
+ assertBroadcastHashJoin( | |
+ df1.join(df2.hint("BROADCAST"), $"a1" === $"b1", joinType), BuildRight) | |
+ assertShuffleHashJoin( | |
+ df1.join(df2.hint("SHUFFLE_HASH"), $"a1" === $"b1", joinType), BuildRight) | |
+ } | |
+ | |
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ .filter(_.contains("is not supported in the query:")) | |
+ assert(logs.isEmpty) | |
+ } | |
+ | |
+ Seq("right_outer").foreach { joinType => | |
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType") | |
+ withLogAppender(hintAppender, level = Some(Level.WARN)) { | |
+ assertShuffleMergeJoin( | |
+ df1.join(df2.hint("BROADCAST"), $"a1" === $"b1", joinType)) | |
+ assertShuffleMergeJoin( | |
+ df1.join(df2.hint("SHUFFLE_HASH"), $"a1" === $"b1", joinType)) | |
+ } | |
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ .filter(_.contains("is not supported in the query:")) | |
+ assert(logs.size === 2) | |
+ logs.foreach(log => | |
+ assert(log.contains(s"build right for ${joinType.split("_").mkString(" ")} join."))) | |
+ } | |
+ | |
+ Seq("right_outer").foreach { joinType => | |
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType") | |
+ withLogAppender(hintAppender, level = Some(Level.WARN)) { | |
+ assertBroadcastHashJoin( | |
+ df1.hint("BROADCAST").join(df2, $"a1" === $"b1", joinType), BuildLeft) | |
+ assertShuffleHashJoin( | |
+ df1.hint("SHUFFLE_HASH").join(df2, $"a1" === $"b1", joinType), BuildLeft) | |
+ } | |
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ .filter(_.contains("is not supported in the query:")) | |
+ assert(logs.isEmpty) | |
+ } | |
+ | |
+ Seq("inner", "cross").foreach { joinType => | |
+ val hintAppender = new LogAppender(s"join hint build side check for $joinType") | |
+ withLogAppender(hintAppender, level = Some(Level.WARN)) { | |
+ assertBroadcastHashJoin( | |
+ df1.hint("BROADCAST").join(df2, $"a1" === $"b1", joinType), BuildLeft) | |
+ assertBroadcastHashJoin( | |
+ df1.join(df2.hint("BROADCAST"), $"a1" === $"b1", joinType), BuildRight) | |
+ | |
+ assertShuffleHashJoin( | |
+ df1.hint("SHUFFLE_HASH").join(df2, $"a1" === $"b1", joinType), BuildLeft) | |
+ assertShuffleHashJoin( | |
+ df1.join(df2.hint("SHUFFLE_HASH"), $"a1" === $"b1", joinType), BuildRight) | |
+ } | |
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ .filter(_.contains("is not supported in the query:")) | |
+ assert(logs.isEmpty) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-35221: Add join hint non equi-join check") { | |
+ val hintAppender = new LogAppender(s"join hint check for equi-join") | |
+ withLogAppender(hintAppender, level = Some(Level.WARN)) { | |
+ assertBroadcastNLJoin( | |
+ df1.hint("SHUFFLE_HASH").join(df2, $"a1" !== $"b1"), BuildRight) | |
+ } | |
+ withLogAppender(hintAppender, level = Some(Level.WARN)) { | |
+ assertBroadcastNLJoin( | |
+ df1.join(df2.hint("MERGE"), $"a1" !== $"b1"), BuildRight) | |
+ } | |
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ .filter(_.contains("is not supported in the query:")) | |
+ assert(logs.size === 2) | |
+ logs.foreach(log => assert(log.contains("no equi-join keys"))) | |
+ } | |
+ | |
+ test("SPARK-36652: AQE dynamic join selection should not apply to non-equi join") { | |
+ val hintAppender = new LogAppender(s"join hint check for equi-join") | |
+ withLogAppender(hintAppender, level = Some(Level.WARN)) { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
+ SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "64MB") { | |
+ df1.join(df2.repartition($"b1"), $"a1" =!= $"b1").collect() | |
+ } | |
+ val logs = hintAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ .filter(_.contains("is not supported in the query: no equi-join keys")) | |
+ assert(logs.isEmpty) | |
+ } | |
+ } | |
} | |
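A condensed sketch, not part of the patch, of the build-side rule the SPARK-35221 tests assert, assuming the suite's `df1`/`df2` with join keys `a1`/`b1`.

  // Hinting the left side of a left outer (or left semi/anti) join cannot make it the
  // build side, so Spark falls back to sort-merge join and logs a
  // "... is not supported in the query" warning.
  df1.hint("BROADCAST").join(df2, $"a1" === $"b1", "left_outer")
  // Hinting the right side is honored: a broadcast hash join building on the right.
  df1.join(df2.hint("BROADCAST"), $"a1" === $"b1", "left_outer")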
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | |
index abfc19ac6d..4a8421a221 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | |
@@ -183,7 +183,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan | |
test("inner join where, one match per row") { | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { | |
checkAnswer( | |
- upperCaseData.join(lowerCaseData).where('n === 'N), | |
+ upperCaseData.join(lowerCaseData).where(Symbol("n") === 'N), | |
Seq( | |
Row(1, "A", 1, "a"), | |
Row(2, "B", 2, "b"), | |
@@ -404,8 +404,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan | |
test("full outer join") { | |
withTempView("`left`", "`right`") { | |
- upperCaseData.where('N <= 4).createOrReplaceTempView("`left`") | |
- upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") | |
+ upperCaseData.where(Symbol("N") <= 4).createOrReplaceTempView("`left`") | |
+ upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`") | |
val left = UnresolvedRelation(TableIdentifier("left")) | |
val right = UnresolvedRelation(TableIdentifier("right")) | |
@@ -623,7 +623,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan | |
testData.createOrReplaceTempView("B") | |
testData2.createOrReplaceTempView("C") | |
testData3.createOrReplaceTempView("D") | |
- upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") | |
+ upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`") | |
val cartesianQueries = Seq( | |
/** The following should error out since there is no explicit cross join */ | |
"SELECT * FROM testData inner join testData2", | |
@@ -1074,8 +1074,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan | |
val df = left.crossJoin(right).where(pythonTestUDF(left("a")) === right.col("c")) | |
// Before optimization, there is a logical Filter operator. | |
- val filterInAnalysis = df.queryExecution.analyzed.find(_.isInstanceOf[Filter]) | |
- assert(filterInAnalysis.isDefined) | |
+ val filterInAnalysis = df.queryExecution.analyzed.exists(_.isInstanceOf[Filter]) | |
+ assert(filterInAnalysis) | |
// Filter predicate was pushdown as join condition. So there is no Filter exec operator. | |
val filterExec = find(df.queryExecution.executedPlan)(_.isInstanceOf[FilterExec]) | |
@@ -1402,4 +1402,42 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan | |
assertJoin(sql, classOf[ShuffledHashJoinExec]) | |
} | |
} | |
+ | |
+ test("SPARK-36794: Ignore duplicated key when building relation for semi/anti hash join") { | |
+ withTable("t1", "t2") { | |
+ spark.range(10).map(i => (i.toString, i + 1)).toDF("c1", "c2").write.saveAsTable("t1") | |
+ spark.range(10).map(i => ((i % 5).toString, i % 3)).toDF("c1", "c2").write.saveAsTable("t2") | |
+ | |
+ val semiJoinQueries = Seq( | |
+ // No join condition, ignore duplicated key. | |
+ (s"SELECT /*+ SHUFFLE_HASH(t2) */ t1.c1 FROM t1 LEFT SEMI JOIN t2 ON t1.c1 = t2.c1", | |
+ true), | |
+ // Have join condition on build join key only, ignore duplicated key. | |
+ (s""" | |
+ |SELECT /*+ SHUFFLE_HASH(t2) */ t1.c1 FROM t1 LEFT SEMI JOIN t2 | |
+ |ON t1.c1 = t2.c1 AND CAST(t1.c2 * 2 AS STRING) != t2.c1 | |
+ """.stripMargin, | |
+ true), | |
+ // Have join condition on other build attribute beside join key, do not ignore | |
+ // duplicated key. | |
+ (s""" | |
+ |SELECT /*+ SHUFFLE_HASH(t2) */ t1.c1 FROM t1 LEFT SEMI JOIN t2 | |
+ |ON t1.c1 = t2.c1 AND t1.c2 * 100 != t2.c2 | |
+ """.stripMargin, | |
+ false) | |
+ ) | |
+ semiJoinQueries.foreach { | |
+ case (query, ignoreDuplicatedKey) => | |
+ val semiJoinDF = sql(query) | |
+ val antiJoinDF = sql(query.replaceAll("SEMI", "ANTI")) | |
+ checkAnswer(semiJoinDF, Seq(Row("0"), Row("1"), Row("2"), Row("3"), Row("4"))) | |
+ checkAnswer(antiJoinDF, Seq(Row("5"), Row("6"), Row("7"), Row("8"), Row("9"))) | |
+ Seq(semiJoinDF, antiJoinDF).foreach { df => | |
+ assert(collect(df.queryExecution.executedPlan) { | |
+ case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true | |
+ }.size == 1) | |
+ } | |
+ } | |
+ } | |
+ } | |
} | |
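A stripped-down sketch, not part of the patch, of the SPARK-36794 plan check: for a shuffled-hash semi/anti join, duplicate build-side keys may be dropped only when any extra condition touches nothing but the build-side join key. It assumes the test's `t1`/`t2` tables and the `collect` helper from `AdaptiveSparkPlanHelper` used above.

  val semiDf = sql(
    "SELECT /*+ SHUFFLE_HASH(t2) */ t1.c1 FROM t1 LEFT SEMI JOIN t2 ON t1.c1 = t2.c1")
  val hashJoins = collect(semiDf.queryExecution.executedPlan) {
    case j: ShuffledHashJoinExec => j
  }
  // With no condition beyond the build-side join key, the hash relation drops duplicates.
  assert(hashJoins.nonEmpty && hashJoins.forall(_.ignoreDuplicatedKey))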
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala | |
index 82cca2b737..1c6bbc5a09 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala | |
@@ -390,16 +390,16 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { | |
test("SPARK-24027: from_json of a map with unsupported key type") { | |
val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType) | |
- | |
- checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)), | |
- Row(null)) | |
- checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), | |
- Row(null)) | |
+ val startMsg = "cannot resolve 'entries' due to data type mismatch:" | |
+ val exception = intercept[AnalysisException] { | |
+ Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)) | |
+ }.getMessage | |
+ assert(exception.contains(startMsg)) | |
} | |
test("SPARK-24709: infers schemas of json strings and pass them to from_json") { | |
val in = Seq("""{"a": [1, 2, 3]}""").toDS() | |
- val out = in.select(from_json('value, schema_of_json("""{"a": [1]}""")) as "parsed") | |
+ val out = in.select(from_json(Symbol("value"), schema_of_json("""{"a": [1]}""")) as "parsed") | |
val expected = StructType(StructField( | |
"parsed", | |
StructType(StructField( | |
@@ -413,7 +413,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { | |
test("infers schemas using options") { | |
val df = spark.range(1) | |
.select(schema_of_json(lit("{a:1}"), Map("allowUnquotedFieldNames" -> "true").asJava)) | |
- checkAnswer(df, Seq(Row("STRUCT<`a`: BIGINT>"))) | |
+ checkAnswer(df, Seq(Row("STRUCT<a: BIGINT>"))) | |
} | |
test("from_json - array of primitive types") { | |
@@ -595,6 +595,31 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { | |
} | |
} | |
+ test("SPARK-36069: from_json invalid json schema - check field name and field value") { | |
+ withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { | |
+ val schema = new StructType() | |
+ .add("a", IntegerType) | |
+ .add("b", IntegerType) | |
+ .add("_unparsed", StringType) | |
+ val badRec = """{"a": "1", "b": 11}""" | |
+ val df = Seq(badRec, """{"a": 2, "b": 12}""").toDS() | |
+ | |
+ checkAnswer( | |
+ df.select(from_json($"value", schema, Map("mode" -> "PERMISSIVE"))), | |
+ Row(Row(null, 11, badRec)) :: Row(Row(2, 12, null)) :: Nil) | |
+ | |
+ val errMsg = intercept[SparkException] { | |
+ df.select(from_json($"value", schema, Map("mode" -> "FAILFAST"))).collect() | |
+ }.getMessage | |
+ | |
+ assert(errMsg.contains( | |
+ "Malformed records are detected in record parsing. Parse Mode: FAILFAST.")) | |
+ assert(errMsg.contains( | |
+ "Failed to parse field name a, field value 1, " + | |
+ "[VALUE_STRING] to target spark data type [IntegerType].")) | |
+ } | |
+ } | |
+ | |
test("corrupt record column in the middle") { | |
val schema = new StructType() | |
.add("a", IntegerType) | |
@@ -668,14 +693,14 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { | |
val input = regexp_replace(lit("""{"item_id": 1, "item_price": 0.1}"""), "item_", "") | |
checkAnswer( | |
spark.range(1).select(schema_of_json(input)), | |
- Seq(Row("STRUCT<`id`: BIGINT, `price`: DOUBLE>"))) | |
+ Seq(Row("STRUCT<id: BIGINT, price: DOUBLE>"))) | |
} | |
test("SPARK-31065: schema_of_json - null and empty strings as strings") { | |
Seq("""{"id": null}""", """{"id": ""}""").foreach { input => | |
checkAnswer( | |
spark.range(1).select(schema_of_json(input)), | |
- Seq(Row("STRUCT<`id`: STRING>"))) | |
+ Seq(Row("STRUCT<id: STRING>"))) | |
} | |
} | |
@@ -687,7 +712,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { | |
schema_of_json( | |
lit("""{"id": "a", "drop": {"drop": null}}"""), | |
options.asJava)), | |
- Seq(Row("STRUCT<`id`: STRING>"))) | |
+ Seq(Row("STRUCT<id: STRING>"))) | |
// Array of structs | |
checkAnswer( | |
@@ -695,7 +720,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { | |
schema_of_json( | |
lit("""[{"id": "a", "drop": {"drop": null}}]"""), | |
options.asJava)), | |
- Seq(Row("ARRAY<STRUCT<`id`: STRING>>"))) | |
+ Seq(Row("ARRAY<STRUCT<id: STRING>>"))) | |
// Other types are not affected. | |
checkAnswer( | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala | |
index 1e7e8c1c5e..88071da293 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala | |
@@ -18,6 +18,7 @@ | |
package org.apache.spark.sql | |
import java.nio.charset.StandardCharsets | |
+import java.time.Period | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.functions.{log => logarithm} | |
@@ -46,12 +47,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
c: Column => Column, | |
f: T => U): Unit = { | |
checkAnswer( | |
- doubleData.select(c('a)), | |
+ doubleData.select(c(Symbol("a"))), | |
(1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) | |
) | |
checkAnswer( | |
- doubleData.select(c('b)), | |
+ doubleData.select(c(Symbol("b"))), | |
(1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) | |
) | |
@@ -64,13 +65,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = | |
{ | |
checkAnswer( | |
- nnDoubleData.select(c('a)), | |
+ nnDoubleData.select(c(Symbol("a"))), | |
(1 to 10).map(n => Row(f(n * 0.1))) | |
) | |
if (f(-1) === StrictMath.log1p(-1)) { | |
checkAnswer( | |
- nnDoubleData.select(c('b)), | |
+ nnDoubleData.select(c(Symbol("b"))), | |
(1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) | |
) | |
} | |
@@ -86,12 +87,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
d: (Column, Double) => Column, | |
f: (Double, Double) => Double): Unit = { | |
checkAnswer( | |
- nnDoubleData.select(c('a, 'a)), | |
+ nnDoubleData.select(c('a, Symbol("a"))), | |
nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) | |
) | |
checkAnswer( | |
- nnDoubleData.select(c('a, 'b)), | |
+ nnDoubleData.select(c('a, Symbol("b"))), | |
nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) | |
) | |
@@ -108,7 +109,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) | |
checkAnswer( | |
- nullDoubles.select(c('a, 'a)).orderBy('a.asc), | |
+ nullDoubles.select(c('a, Symbol("a"))).orderBy(Symbol("a").asc), | |
Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) | |
) | |
} | |
@@ -117,6 +118,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
testOneToOneMathFunction(sin, math.sin) | |
} | |
+ test("csc") { | |
+ testOneToOneMathFunction(csc, | |
+ (x: Double) => (1 / math.sin(x)) ) | |
+ } | |
+ | |
test("asin") { | |
testOneToOneMathFunction(asin, math.asin) | |
} | |
@@ -134,6 +140,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
testOneToOneMathFunction(cos, math.cos) | |
} | |
+ test("sec") { | |
+ testOneToOneMathFunction(sec, | |
+ (x: Double) => (1 / math.cos(x)) ) | |
+ } | |
+ | |
test("acos") { | |
testOneToOneMathFunction(acos, math.acos) | |
} | |
@@ -151,6 +162,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
testOneToOneMathFunction(tan, math.tan) | |
} | |
+ test("cot") { | |
+ testOneToOneMathFunction(cot, | |
+ (x: Double) => (1 / math.tan(x)) ) | |
+ } | |
+ | |
test("atan") { | |
testOneToOneMathFunction(atan, math.atan) | |
} | |
@@ -186,6 +202,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
test("ceil and ceiling") { | |
testOneToOneMathFunction(ceil, (d: Double) => math.ceil(d).toLong) | |
+ // testOneToOneMathFunction does not validate the resulting data type | |
+ assert( | |
+ spark.range(1).select(ceil(col("id")).alias("a")).schema == | |
+ types.StructType(Seq(types.StructField("a", types.LongType)))) | |
+ assert( | |
+ spark.range(1).select(ceil(col("id"), lit(0)).alias("a")).schema == | |
+ types.StructType(Seq(types.StructField("a", types.DecimalType(21, 0))))) | |
checkAnswer( | |
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), | |
Row(0L, 1L, 2L)) | |
@@ -234,12 +257,19 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
test("floor") { | |
testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong) | |
+ // testOneToOneMathFunction does not validate the resulting data type | |
+ assert( | |
+ spark.range(1).select(floor(col("id")).alias("a")).schema == | |
+ types.StructType(Seq(types.StructField("a", types.LongType)))) | |
+ assert( | |
+ spark.range(1).select(floor(col("id"), lit(0)).alias("a")).schema == | |
+ types.StructType(Seq(types.StructField("a", types.DecimalType(21, 0))))) | |
} | |
test("factorial") { | |
val df = (0 to 5).map(i => (i, i)).toDF("a", "b") | |
checkAnswer( | |
- df.select(factorial('a)), | |
+ df.select(factorial(Symbol("a"))), | |
Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) | |
) | |
checkAnswer( | |
@@ -252,16 +282,24 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
testOneToOneMathFunction(rint, math.rint) | |
} | |
- test("round/bround") { | |
+ test("round/bround/ceil/floor") { | |
val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") | |
checkAnswer( | |
- df.select(round('a), round('a, -1), round('a, -2)), | |
+ df.select(round(Symbol("a")), round('a, -1), round('a, -2)), | |
Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) | |
) | |
checkAnswer( | |
- df.select(bround('a), bround('a, -1), bround('a, -2)), | |
+ df.select(bround(Symbol("a")), bround('a, -1), bround('a, -2)), | |
Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600)) | |
) | |
+ checkAnswer( | |
+ df.select(ceil('a), ceil('a, lit(-1)), ceil('a, lit(-2))), | |
+ Seq(Row(5, 10, 100), Row(55, 60, 100), Row(555, 560, 600)) | |
+ ) | |
+ checkAnswer( | |
+ df.select(floor('a), floor('a, lit(-1)), floor('a, lit(-2))), | |
+ Seq(Row(5, 0, 0), Row(55, 50, 0), Row(555, 550, 500)) | |
+ ) | |
withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") { | |
val pi = "3.1415" | |
@@ -277,6 +315,18 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), | |
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) | |
) | |
+ checkAnswer( | |
+ sql(s"SELECT ceil($pi), ceil($pi, -3), ceil($pi, -2), ceil($pi, -1), " + | |
+ s"ceil($pi, 0), ceil($pi, 1), ceil($pi, 2), ceil($pi, 3)"), | |
+ Seq(Row(BigDecimal(4), BigDecimal("1E3"), BigDecimal("1E2"), BigDecimal("1E1"), | |
+ BigDecimal(4), BigDecimal("3.2"), BigDecimal("3.15"), BigDecimal("3.142"))) | |
+ ) | |
+ checkAnswer( | |
+ sql(s"SELECT floor($pi), floor($pi, -3), floor($pi, -2), floor($pi, -1), " + | |
+ s"floor($pi, 0), floor($pi, 1), floor($pi, 2), floor($pi, 3)"), | |
+ Seq(Row(BigDecimal(3), BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), | |
+ BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.141"))) | |
+ ) | |
} | |
val bdPi: BigDecimal = BigDecimal(31415925L, 7) | |
@@ -291,21 +341,46 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"), | |
Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) | |
) | |
+ checkAnswer( | |
+ sql(s"SELECT ceil($bdPi, 7), ceil($bdPi, 8), ceil($bdPi, 9), ceil($bdPi, 10), " + | |
+ s"ceil($bdPi, 100), ceil($bdPi, 6), ceil(null, 8)"), | |
+ Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null)) | |
+ ) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT floor($bdPi, 7), floor($bdPi, 8), floor($bdPi, 9), floor($bdPi, 10), " + | |
+ s"floor($bdPi, 100), floor($bdPi, 6), floor(null, 8)"), | |
+ Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) | |
+ ) | |
} | |
- test("round/bround with data frame from a local Seq of Product") { | |
+ test("round/bround/ceil/floor with data frame from a local Seq of Product") { | |
val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value") | |
checkAnswer( | |
- df.withColumn("value_rounded", round('value)), | |
+ df.withColumn("value_rounded", round(Symbol("value"))), | |
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) | |
) | |
checkAnswer( | |
- df.withColumn("value_brounded", bround('value)), | |
+ df.withColumn("value_brounded", bround(Symbol("value"))), | |
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) | |
) | |
+ checkAnswer( | |
+ df | |
+ .withColumn("value_ceil", ceil('value)) | |
+ .withColumn("value_ceil1", ceil('value, lit(0))) | |
+ .withColumn("value_ceil2", ceil('value, lit(1))), | |
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6"), BigDecimal("6"), BigDecimal("5.9"))) | |
+ ) | |
+ checkAnswer( | |
+ df | |
+ .withColumn("value_floor", floor('value)) | |
+ .withColumn("value_floor1", floor('value, lit(0))) | |
+ .withColumn("value_floor2", floor('value, lit(1))), | |
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5"), BigDecimal("5"), BigDecimal("5.9"))) | |
+ ) | |
} | |
- test("round/bround with table columns") { | |
+ test("round/bround/ceil/floor with table columns") { | |
withTable("t") { | |
Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t") | |
checkAnswer( | |
@@ -314,6 +389,24 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
checkAnswer( | |
sql("select i, bround(i) from t"), | |
Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) | |
+ checkAnswer( | |
+ sql("select i, ceil(i) from t"), | |
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) | |
+ checkAnswer( | |
+ sql("select i, ceil(i, 0) from t"), | |
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) | |
+ checkAnswer( | |
+ sql("select i, ceil(i, 1) from t"), | |
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5.9")))) | |
+ checkAnswer( | |
+ sql("select i, floor(i) from t"), | |
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5")))) | |
+ checkAnswer( | |
+ sql("select i, floor(i, 0) from t"), | |
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5")))) | |
+ checkAnswer( | |
+ sql("select i, floor(i, 1) from t"), | |
+ Seq(Row(BigDecimal("5.9"), BigDecimal("5.9")))) | |
} | |
} | |
@@ -344,10 +437,10 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
test("hex") { | |
val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") | |
- checkAnswer(data.select(hex('a)), Seq(Row("1C"))) | |
- checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) | |
- checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) | |
- checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) | |
+ checkAnswer(data.select(hex(Symbol("a"))), Seq(Row("1C"))) | |
+ checkAnswer(data.select(hex(Symbol("b"))), Seq(Row("FFFFFFFFFFFFFFE4"))) | |
+ checkAnswer(data.select(hex(Symbol("c"))), Seq(Row("177828FED4"))) | |
+ checkAnswer(data.select(hex(Symbol("d"))), Seq(Row("68656C6C6F"))) | |
checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) | |
checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) | |
checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) | |
@@ -357,8 +450,8 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
test("unhex") { | |
val data = Seq(("1C", "737472696E67")).toDF("a", "b") | |
- checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) | |
- checkAnswer(data.select(unhex('b)), Row("string".getBytes(StandardCharsets.UTF_8))) | |
+ checkAnswer(data.select(unhex(Symbol("a"))), Row(Array[Byte](28.toByte))) | |
+ checkAnswer(data.select(unhex(Symbol("b"))), Row("string".getBytes(StandardCharsets.UTF_8))) | |
checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) | |
checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes(StandardCharsets.UTF_8))) | |
checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) | |
@@ -505,4 +598,20 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { | |
checkAnswer(df.selectExpr("positive(a)"), Row(1)) | |
checkAnswer(df.selectExpr("positive(b)"), Row(-1)) | |
} | |
+ | |
+ test("SPARK-35926: Support YearMonthIntervalType in width-bucket function") { | |
+ Seq( | |
+ (Period.ofMonths(-1), Period.ofYears(0), Period.ofYears(10), 10) -> 0, | |
+ (Period.ofMonths(0), Period.ofYears(0), Period.ofYears(10), 10) -> 1, | |
+ (Period.ofMonths(13), Period.ofYears(0), Period.ofYears(10), 10) -> 2, | |
+ (Period.ofYears(1), Period.ofYears(0), Period.ofYears(10), 10) -> 2, | |
+ (Period.ofYears(1), Period.ofYears(0), Period.ofYears(1), 10) -> 11, | |
+ (Period.ofMonths(Int.MaxValue), Period.ofYears(0), Period.ofYears(1), 10) -> 11, | |
+ (Period.ofMonths(0), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 6, | |
+ (Period.ofMonths(-1), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 5 | |
+ ).foreach { case ((value, start, end, num), expected) => | |
+ val df = Seq((value, start, end, num)).toDF("v", "s", "e", "n") | |
+ checkAnswer(df.selectExpr("width_bucket(v, s, e, n)"), Row(expected)) | |
+ } | |
+ } | |
} | |
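A small sketch, not part of the patch, of the two-argument ceil/floor the new assertions cover, assuming `spark.implicits._` is imported: the second argument is a scale, mirroring round/bround.

  // A positive scale keeps that many fraction digits on the DECIMAL result.
  Seq(BigDecimal("3.1415")).toDF("v")
    .selectExpr("ceil(v, 2)", "floor(v, 2)")   // 3.15 and 3.14, per the SQL assertions above
    .show()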
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala | |
index b166b9b684..37ba52023d 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala | |
@@ -21,6 +21,7 @@ import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT} | |
import org.apache.spark.sql.catalyst.parser.ParseException | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
+import org.apache.spark.sql.types.BinaryType | |
class MiscFunctionsSuite extends QueryTest with SharedSparkSession { | |
import testImplicits._ | |
@@ -49,13 +50,32 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession { | |
val df = sql("select current_user(), current_user") | |
checkAnswer(df, Row(user, user)) | |
} | |
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { | |
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true", | |
+ SQLConf.ENFORCE_RESERVED_KEYWORDS.key -> "true") { | |
val df = sql("select current_user") | |
checkAnswer(df, Row(spark.sparkContext.sparkUser)) | |
val e = intercept[ParseException](sql("select current_user()")) | |
assert(e.getMessage.contains("current_user")) | |
} | |
} | |
+ | |
+ test("SPARK-37591: AES functions - GCM mode") { | |
+ Seq( | |
+ ("abcdefghijklmnop", ""), | |
+ ("abcdefghijklmnop", "abcdefghijklmnop"), | |
+ ("abcdefghijklmnop12345678", "Spark"), | |
+ ("abcdefghijklmnop12345678ABCDEFGH", "GCM mode") | |
+ ).foreach { case (key, input) => | |
+ val df = Seq((key, input)).toDF("key", "input") | |
+ val encrypted = df.selectExpr("aes_encrypt(input, key, 'GCM', 'NONE') AS enc", "input", "key") | |
+ assert(encrypted.schema("enc").dataType === BinaryType) | |
+ assert(encrypted.filter($"enc" === $"input").isEmpty) | |
+ val result = encrypted.selectExpr( | |
+ "CAST(aes_decrypt(enc, key, 'GCM', 'NONE') AS STRING) AS res", "input") | |
+ assert(!result.filter($"res" === $"input").isEmpty && | |
+ result.filter($"res" =!= $"input").isEmpty) | |
+ } | |
+ } | |
} | |
object ReflectClass { | |
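The AES-GCM round trip exercised above, as a one-line sketch (not part of the patch) with an assumed 16-byte key.

  // Encrypt then decrypt with the same key, mode and padding; the round trip restores the input.
  spark.sql(
    "SELECT CAST(aes_decrypt(aes_encrypt('Spark', '0000111122223333', 'GCM', 'NONE'), " +
      "'0000111122223333', 'GCM', 'NONE') AS STRING)"
  ).show()   // Spark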
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala | |
new file mode 100644 | |
index 0000000000..823c1375de | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala | |
@@ -0,0 +1,49 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql | |
+ | |
+import java.time.{Duration, Period} | |
+ | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+ | |
+/** | |
+ * End-to-end tests for percentile aggregate function. | |
+ */ | |
+class PercentileQuerySuite extends QueryTest with SharedSparkSession { | |
+ import testImplicits._ | |
+ | |
+ private val table = "percentile_test" | |
+ | |
+ test("SPARK-37138, SPARK-39427: Disable Ansi Interval type in Percentile") { | |
+ withTempView(table) { | |
+ Seq((Period.ofMonths(100), Duration.ofSeconds(100L)), | |
+ (Period.ofMonths(200), Duration.ofSeconds(200L)), | |
+ (Period.ofMonths(300), Duration.ofSeconds(300L))) | |
+ .toDF("col1", "col2").createOrReplaceTempView(table) | |
+ val e = intercept[AnalysisException] { | |
+ spark.sql( | |
+ s"""SELECT | |
+ | CAST(percentile(col1, 0.5) AS STRING), | |
+ | SUM(null), | |
+ | CAST(percentile(col2, 0.5) AS STRING) | |
+ |FROM $table""".stripMargin).collect() | |
+ } | |
+ assert(e.getMessage.contains("data type mismatch")) | |
+ } | |
+ } | |
+} | |
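The same analysis-time failure, reproduced without the temp view in a sketch that is not part of the patch; it assumes `spark.implicits._` and `java.time.Period` are imported.

  val e = intercept[AnalysisException] {
    Seq(Period.ofMonths(100), Period.ofMonths(200)).toDF("col1")
      .selectExpr("percentile(col1, 0.5)")
      .collect()
  }
  // Analysis rejects the year-month interval input with a data type mismatch.
  assert(e.getMessage.contains("data type mismatch"))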
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala | |
index 69a8c72153..8cbb841e7d 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala | |
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution._ | |
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite | |
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec, ValidateRequirements} | |
import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.tags.ExtendedSQLTest | |
// scalastyle:off line.size.limit | |
/** | |
@@ -47,38 +48,28 @@ import org.apache.spark.sql.internal.SQLConf | |
* | |
* To run the entire test suite: | |
* {{{ | |
- * build/sbt "sql/testOnly *PlanStability[WithStats]Suite" | |
+ * build/sbt "sql/testOnly *PlanStability*Suite" | |
* }}} | |
* | |
* To run a single test file upon change: | |
* {{{ | |
- * build/sbt "sql/testOnly *PlanStability[WithStats]Suite -- -z (tpcds-v1.4/q49)" | |
+ * build/sbt "sql/testOnly *PlanStability*Suite -- -z (tpcds-v1.4/q49)" | |
* }}} | |
* | |
* To re-generate golden files for entire suite, run: | |
* {{{ | |
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *PlanStability[WithStats]Suite" | |
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *PlanStability*Suite" | |
+ * SPARK_GENERATE_GOLDEN_FILES=1 SPARK_ANSI_SQL_MODE=true build/sbt "sql/testOnly *PlanStability*Suite" | |
* }}} | |
* | |
* To re-generate golden file for a single test, run: | |
* {{{ | |
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *PlanStability[WithStats]Suite -- -z (tpcds-v1.4/q49)" | |
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *PlanStability*Suite -- -z (tpcds-v1.4/q49)" | |
+ * SPARK_GENERATE_GOLDEN_FILES=1 SPARK_ANSI_SQL_MODE=true build/sbt "sql/testOnly *PlanStability*Suite -- -z (tpcds-v1.4/q49)" | |
* }}} | |
*/ | |
// scalastyle:on line.size.limit | |
-trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite { | |
- | |
- private val originalMaxToStringFields = conf.maxToStringFields | |
- | |
- override def beforeAll(): Unit = { | |
- conf.setConf(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue) | |
- super.beforeAll() | |
- } | |
- | |
- override def afterAll(): Unit = { | |
- super.afterAll() | |
- conf.setConf(SQLConf.MAX_TO_STRING_FIELDS, originalMaxToStringFields) | |
- } | |
+trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { | |
private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" | |
@@ -89,13 +80,24 @@ trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite { | |
private val referenceRegex = "#\\d+".r | |
private val normalizeRegex = "#\\d+L?".r | |
+ private val planIdRegex = "plan_id=\\d+".r | |
private val clsName = this.getClass.getCanonicalName | |
def goldenFilePath: String | |
+ private val approvedAnsiPlans: Seq[String] = Seq( | |
+ "q83", | |
+ "q83.sf100" | |
+ ) | |
+ | |
private def getDirForTest(name: String): File = { | |
- new File(goldenFilePath, name) | |
+ val goldenFileName = if (SQLConf.get.ansiEnabled && approvedAnsiPlans.contains(name)) { | |
+ name + ".ansi" | |
+ } else { | |
+ name | |
+ } | |
+ new File(goldenFilePath, goldenFileName) | |
} | |
private def isApproved( | |
@@ -233,7 +235,15 @@ trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite { | |
val map = new mutable.HashMap[String, String]() | |
normalizeRegex.findAllMatchIn(plan).map(_.toString) | |
.foreach(map.getOrElseUpdate(_, (map.size + 1).toString)) | |
- normalizeRegex.replaceAllIn(plan, regexMatch => s"#${map(regexMatch.toString)}") | |
+ val exprIdNormalized = normalizeRegex.replaceAllIn( | |
+ plan, regexMatch => s"#${map(regexMatch.toString)}") | |
+ | |
+ // Normalize the plan id in Exchange nodes. See `Exchange.stringArgs`. | |
+ val planIdMap = new mutable.HashMap[String, String]() | |
+ planIdRegex.findAllMatchIn(exprIdNormalized).map(_.toString) | |
+ .foreach(planIdMap.getOrElseUpdate(_, (planIdMap.size + 1).toString)) | |
+ planIdRegex.replaceAllIn( | |
+ exprIdNormalized, regexMatch => s"plan_id=${planIdMap(regexMatch.toString)}") | |
} | |
private def normalizeLocation(plan: String): String = { | |
@@ -262,7 +272,8 @@ trait PlanStabilitySuite extends TPCDSBase with DisableAdaptiveExecutionSuite { | |
} | |
} | |
-class TPCDSV1_4_PlanStabilitySuite extends PlanStabilitySuite { | |
+@ExtendedSQLTest | |
+class TPCDSV1_4_PlanStabilitySuite extends PlanStabilitySuite with TPCDSBase { | |
override val goldenFilePath: String = | |
new File(baseResourcePath, s"approved-plans-v1_4").getAbsolutePath | |
@@ -273,7 +284,8 @@ class TPCDSV1_4_PlanStabilitySuite extends PlanStabilitySuite { | |
} | |
} | |
-class TPCDSV1_4_PlanStabilityWithStatsSuite extends PlanStabilitySuite { | |
+@ExtendedSQLTest | |
+class TPCDSV1_4_PlanStabilityWithStatsSuite extends PlanStabilitySuite with TPCDSBase { | |
override def injectStats: Boolean = true | |
override val goldenFilePath: String = | |
@@ -286,7 +298,8 @@ class TPCDSV1_4_PlanStabilityWithStatsSuite extends PlanStabilitySuite { | |
} | |
} | |
-class TPCDSV2_7_PlanStabilitySuite extends PlanStabilitySuite { | |
+@ExtendedSQLTest | |
+class TPCDSV2_7_PlanStabilitySuite extends PlanStabilitySuite with TPCDSBase { | |
override val goldenFilePath: String = | |
new File(baseResourcePath, s"approved-plans-v2_7").getAbsolutePath | |
@@ -297,7 +310,8 @@ class TPCDSV2_7_PlanStabilitySuite extends PlanStabilitySuite { | |
} | |
} | |
-class TPCDSV2_7_PlanStabilityWithStatsSuite extends PlanStabilitySuite { | |
+@ExtendedSQLTest | |
+class TPCDSV2_7_PlanStabilityWithStatsSuite extends PlanStabilitySuite with TPCDSBase { | |
override def injectStats: Boolean = true | |
override val goldenFilePath: String = | |
@@ -310,7 +324,8 @@ class TPCDSV2_7_PlanStabilityWithStatsSuite extends PlanStabilitySuite { | |
} | |
} | |
-class TPCDSModifiedPlanStabilitySuite extends PlanStabilitySuite { | |
+@ExtendedSQLTest | |
+class TPCDSModifiedPlanStabilitySuite extends PlanStabilitySuite with TPCDSBase { | |
override val goldenFilePath: String = | |
new File(baseResourcePath, s"approved-plans-modified").getAbsolutePath | |
@@ -321,7 +336,8 @@ class TPCDSModifiedPlanStabilitySuite extends PlanStabilitySuite { | |
} | |
} | |
-class TPCDSModifiedPlanStabilityWithStatsSuite extends PlanStabilitySuite { | |
+@ExtendedSQLTest | |
+class TPCDSModifiedPlanStabilityWithStatsSuite extends PlanStabilitySuite with TPCDSBase { | |
override def injectStats: Boolean = true | |
override val goldenFilePath: String = | |
@@ -333,3 +349,15 @@ class TPCDSModifiedPlanStabilityWithStatsSuite extends PlanStabilitySuite { | |
} | |
} | |
} | |
+ | |
+@ExtendedSQLTest | |
+class TPCHPlanStabilitySuite extends PlanStabilitySuite with TPCHBase { | |
+ override def goldenFilePath: String = getWorkspaceFilePath( | |
+ "sql", "core", "src", "test", "resources", "tpch-plan-stability").toFile.getAbsolutePath | |
+ | |
+ tpchQueries.foreach { q => | |
+ test(s"check simplified (tpch/$q)") { | |
+ testQuery("tpch", q) | |
+ } | |
+ } | |
+} | |
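The normalizeIds change above renumbers expression IDs (#123 / #123L) together with the plan_id markers that Exchange nodes now print, so regenerated golden plans stay stable even when the absolute IDs drift between runs. A standalone sketch of the same two-pass regex renumbering in plain Scala, with no Spark dependencies (the example string and names are illustrative):

import scala.collection.mutable

// Sketch of the renumbering used in PlanStabilitySuite.normalizeIds above: each distinct
// "#<n>"/"#<n>L" token and each "plan_id=<n>" token is replaced by a stable 1-based index
// in order of first appearance, so unrelated ID churn does not show up in golden-file diffs.
object NormalizeIdsSketch {
  private val exprIdRegex = "#\\d+L?".r
  private val planIdRegex = "plan_id=\\d+".r

  def normalize(plan: String): String = {
    val exprIds = new mutable.HashMap[String, String]()
    exprIdRegex.findAllMatchIn(plan).map(_.toString)
      .foreach(exprIds.getOrElseUpdate(_, (exprIds.size + 1).toString))
    val exprNormalized =
      exprIdRegex.replaceAllIn(plan, m => s"#${exprIds(m.toString)}")

    val planIds = new mutable.HashMap[String, String]()
    planIdRegex.findAllMatchIn(exprNormalized).map(_.toString)
      .foreach(planIds.getOrElseUpdate(_, (planIds.size + 1).toString))
    planIdRegex.replaceAllIn(exprNormalized, m => s"plan_id=${planIds(m.toString)}")
  }

  def main(args: Array[String]): Unit = {
    println(normalize("Exchange hashpartitioning(id#27L, 5), [plan_id=103] -> id#27L, name#31"))
    // Exchange hashpartitioning(id#1, 5), [plan_id=1] -> id#1, name#2
  }
}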
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | |
index 8469216901..06f94c62d9 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | |
@@ -207,11 +207,12 @@ abstract class QueryTest extends PlanTest { | |
*/ | |
def assertCached(query: Dataset[_], cachedName: String, storageLevel: StorageLevel): Unit = { | |
val planWithCaching = query.queryExecution.withCachedData | |
- val matched = planWithCaching.collectFirst { case cached: InMemoryRelation => | |
- val cacheBuilder = cached.asInstanceOf[InMemoryRelation].cacheBuilder | |
- cachedName == cacheBuilder.tableName.get && | |
- (storageLevel == cacheBuilder.storageLevel) | |
- }.getOrElse(false) | |
+ val matched = planWithCaching.exists { | |
+ case cached: InMemoryRelation => | |
+ val cacheBuilder = cached.cacheBuilder | |
+ cachedName == cacheBuilder.tableName.get && (storageLevel == cacheBuilder.storageLevel) | |
+ case _ => false | |
+ } | |
assert(matched, s"Expected query plan to hit cache $cachedName with storage " + | |
s"level $storageLevel, but it doesn't.") | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala | |
index 739b4052ee..8883e9be19 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala | |
@@ -59,13 +59,13 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared | |
val q5 = df1.selectExpr("IF(l > 1 AND null, 5, 1) AS out") | |
checkAnswer(q5, Row(1) :: Row(1) :: Nil) | |
q5.queryExecution.executedPlan.foreach { p => | |
- assert(p.expressions.forall(e => e.find(_.isInstanceOf[If]).isEmpty)) | |
+ assert(p.expressions.forall(e => !e.exists(_.isInstanceOf[If]))) | |
} | |
val q6 = df1.selectExpr("CASE WHEN (l > 2 AND null) THEN 3 ELSE 2 END") | |
checkAnswer(q6, Row(2) :: Row(2) :: Nil) | |
q6.queryExecution.executedPlan.foreach { p => | |
- assert(p.expressions.forall(e => e.find(_.isInstanceOf[CaseWhen]).isEmpty)) | |
+ assert(p.expressions.forall(e => !e.exists(_.isInstanceOf[CaseWhen]))) | |
} | |
checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true)) | |
@@ -75,10 +75,10 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared | |
test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { | |
def assertNoLiteralNullInPlan(df: DataFrame): Unit = { | |
df.queryExecution.executedPlan.foreach { p => | |
- assert(p.expressions.forall(_.find { | |
+ assert(p.expressions.forall(!_.exists { | |
case Literal(null, BooleanType) => true | |
case _ => false | |
- }.isEmpty)) | |
+ })) | |
} | |
} | |
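The assertions above switch from e.find(...).isEmpty to !e.exists(...), stating the property (no If, CaseWhen, or null boolean literal survives optimization) directly instead of going through an intermediate Option. A toy illustration of the two styles on a hand-rolled expression tree; this is not Spark's TreeNode API, it only mirrors the shape of find/exists:

// Toy expression tree with find/exists helpers (illustrative only, not Spark code).
sealed trait Expr {
  def children: Seq[Expr]
  def exists(p: Expr => Boolean): Boolean = p(this) || children.exists(_.exists(p))
  def find(p: Expr => Boolean): Option[Expr] =
    if (p(this)) Some(this) else children.flatMap(_.find(p)).headOption
}
case class Lit(value: Any) extends Expr { val children: Seq[Expr] = Nil }
case class IfExpr(cond: Expr, onTrue: Expr, onFalse: Expr) extends Expr {
  val children: Seq[Expr] = Seq(cond, onTrue, onFalse)
}

object NoIfRemainsSketch {
  def main(args: Array[String]): Unit = {
    val original: Expr = IfExpr(Lit(true), Lit(1), Lit(2))
    assert(original.exists(_.isInstanceOf[IfExpr]))

    val optimized: Expr = Lit(1)
    // Before: build an Option just to test emptiness.
    assert(optimized.find(_.isInstanceOf[IfExpr]).isEmpty)
    // After: state the property directly.
    assert(!optimized.exists(_.isInstanceOf[IfExpr]))
  }
}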
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala | |
index 2f56fbaf7f..f30465203d 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala | |
@@ -264,7 +264,7 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils { | |
val e = intercept[AnalysisException] { | |
sql("INSERT OVERWRITE t PARTITION (c='2', C='3') VALUES (1)") | |
} | |
- assert(e.getMessage.contains("Found duplicate keys 'c'")) | |
+ assert(e.getMessage.contains("Found duplicate keys `c`")) | |
} | |
    // The following code is skipped for Hive because columns stored in Hive Metastore are always | 
    // case insensitive and we cannot create such a table in Hive Metastore. | 
@@ -286,20 +286,44 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils { | |
} else { | |
SQLConf.StoreAssignmentPolicy.values | |
} | |
+ | |
+ def shouldThrowException(policy: SQLConf.StoreAssignmentPolicy.Value): Boolean = policy match { | |
+ case SQLConf.StoreAssignmentPolicy.ANSI | SQLConf.StoreAssignmentPolicy.STRICT => | |
+ true | |
+ case SQLConf.StoreAssignmentPolicy.LEGACY => | |
+ false | |
+ } | |
+ | |
testingPolicies.foreach { policy => | |
- withSQLConf( | |
- SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) { | |
+ withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) { | |
withTable("t") { | |
sql("create table t(a int, b string) using parquet partitioned by (a)") | |
- policy match { | |
- case SQLConf.StoreAssignmentPolicy.ANSI | SQLConf.StoreAssignmentPolicy.STRICT => | |
- val errorMsg = intercept[NumberFormatException] { | |
- sql("insert into t partition(a='ansi') values('ansi')") | |
- }.getMessage | |
- assert(errorMsg.contains("invalid input syntax for type numeric: ansi")) | |
- case SQLConf.StoreAssignmentPolicy.LEGACY => | |
+ if (shouldThrowException(policy)) { | |
+ val errorMsg = intercept[NumberFormatException] { | |
sql("insert into t partition(a='ansi') values('ansi')") | |
- checkAnswer(sql("select * from t"), Row("ansi", null) :: Nil) | |
+ }.getMessage | |
+ assert(errorMsg.contains( | |
+ """The value 'ansi' of the type "STRING" cannot be cast to "INT"""")) | |
+ } else { | |
+ sql("insert into t partition(a='ansi') values('ansi')") | |
+ checkAnswer(sql("select * from t"), Row("ansi", null) :: Nil) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38228: legacy store assignment should not fail on error under ANSI mode") { | |
+ // DS v2 doesn't support the legacy policy | |
+ if (format != "foo") { | |
+ Seq(true, false).foreach { ansiEnabled => | |
+ withSQLConf( | |
+ SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString, | |
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { | |
+ withTable("t") { | |
+ sql("create table t(a int) using parquet") | |
+ sql("insert into t values('ansi')") | |
+ checkAnswer(spark.table("t"), Row(null)) | |
} | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | |
index 1515abb052..66f9700e8a 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | |
@@ -24,34 +24,40 @@ import java.time.{Duration, Period} | |
import java.util.Locale | |
import java.util.concurrent.atomic.AtomicBoolean | |
+import scala.collection.mutable | |
+ | |
import org.apache.commons.io.FileUtils | |
import org.apache.spark.{AccumulatorSuite, SparkException} | |
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} | |
-import org.apache.spark.sql.catalyst.expressions.GenericRow | |
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry | |
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Hex} | |
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} | |
import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite} | |
import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, RepartitionByExpression, Sort} | |
import org.apache.spark.sql.catalyst.util.StringUtils | |
import org.apache.spark.sql.execution.{CommandResultExec, UnionExec} | |
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} | |
+import org.apache.spark.sql.execution.aggregate._ | |
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec | |
-import org.apache.spark.sql.execution.command.{DataWritingCommandExec, FunctionsCommand} | |
+import org.apache.spark.sql.execution.command.DataWritingCommandExec | |
import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, LogicalRelation} | |
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec | |
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan | |
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan | |
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec | |
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} | |
+import org.apache.spark.sql.expressions.Aggregator | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.sql.test.SQLTestData._ | |
import org.apache.spark.sql.types._ | |
-import org.apache.spark.unsafe.types.CalendarInterval | |
+import org.apache.spark.tags.ExtendedSQLTest | |
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} | |
import org.apache.spark.util.ResetSystemProperties | |
+@ExtendedSQLTest | |
class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper | |
with ResetSystemProperties { | |
import testImplicits._ | |
@@ -65,8 +71,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") | |
val queryCoalesce = sql("select coalesce(null, 1, '1') from src ") | |
- checkAnswer(queryCaseWhen, Row("1.0") :: Nil) | |
- checkAnswer(queryCoalesce, Row("1") :: Nil) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer(queryCaseWhen, Row("1.0") :: Nil) | |
+ checkAnswer(queryCoalesce, Row("1") :: Nil) | |
+ } | |
} | |
} | |
@@ -74,7 +82,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
def getFunctions(pattern: String): Seq[Row] = { | |
StringUtils.filterPattern( | |
spark.sessionState.catalog.listFunctions("default").map(_._1.funcName) | |
- ++ FunctionsCommand.virtualOperators, pattern) | |
+ ++ FunctionRegistry.builtinOperators.keys, pattern) | |
.map(Row(_)) | |
} | |
@@ -123,7 +131,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
checkKeywordsNotExist(sql("describe functioN Upper"), "Extended Usage") | |
- checkKeywordsExist(sql("describe functioN abcadf"), "Function: abcadf not found.") | |
+ val e = intercept[AnalysisException](sql("describe functioN abcadf")) | |
+ assert(e.message.contains("Undefined function: abcadf. This function is neither a " + | |
+ "built-in/temporary function, nor a persistent function")) | |
} | |
test("SPARK-34678: describe functions for table-valued functions") { | |
@@ -387,10 +397,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
testCodeGen( | |
"SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", | |
Row(100, 1, 50.5, 300, 100) :: Nil) | |
- // Aggregate with Code generation handling all null values | |
- testCodeGen( | |
- "SELECT sum('a'), avg('a'), count(null) FROM testData", | |
- Row(null, null, 0) :: Nil) | |
+ // Aggregate with Code generation handling all null values. | |
+      // If ANSI mode is on, there will be an error since 'a' cannot be converted to Numeric. | 
+ // Here we simply test it when ANSI mode is off. | |
+ if (!conf.ansiEnabled) { | |
+ testCodeGen( | |
+ "SELECT sum('a'), avg('a'), count(null) FROM testData", | |
+ Row(null, null, 0) :: Nil) | |
+ } | |
} finally { | |
spark.catalog.dropTempView("testData3x") | |
} | |
@@ -482,9 +496,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
Seq(Row(Timestamp.valueOf("1969-12-31 16:00:00.001")), | |
Row(Timestamp.valueOf("1969-12-31 16:00:00.002")))) | |
- checkAnswer(sql( | |
- "SELECT time FROM timestamps WHERE time='123'"), | |
- Nil) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer(sql( | |
+ "SELECT time FROM timestamps WHERE time='123'"), | |
+ Nil) | |
+ } | |
} | |
} | |
@@ -579,6 +595,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
|select * from q1 union all select * from q2""".stripMargin), | |
Row(5, "5") :: Row(4, "4") :: Nil) | |
+ // inner CTE relation refers to outer CTE relation. | |
+ withSQLConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY.key -> "CORRECTED") { | |
+ checkAnswer( | |
+ sql( | |
+ """ | |
+ |with temp1 as (select 1 col), | |
+ |temp2 as ( | |
+ | with temp1 as (select col + 1 AS col from temp1), | |
+ | temp3 as (select col + 1 from temp1) | |
+ | select * from temp3 | |
+ |) | |
+ |select * from temp2 | |
+ |""".stripMargin), | |
+ Row(3)) | |
+ } | |
} | |
test("Allow only a single WITH clause per query") { | |
@@ -933,9 +964,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: | |
Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) | |
// Column type mismatches are not allowed, forcing a type coercion. | |
- checkAnswer( | |
- sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), | |
- ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) | |
+ // When ANSI mode is on, the String input will be cast as Int in the following Union, which will | |
+ // cause a runtime error. Here we simply test the case when ANSI mode is off. | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer( | |
+ sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), | |
+ ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) | |
+ } | |
// Column type mismatches where a coercion is not possible, in this case between integer | |
// and array types, trigger a TreeNodeException. | |
intercept[AnalysisException] { | |
@@ -1032,32 +1067,35 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
Row(Row(3, true), Map("C3" -> null)) :: | |
Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) | |
- checkAnswer( | |
- sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), | |
- Row(1, null) :: | |
- Row(2, null) :: | |
- Row(3, null) :: | |
- Row(4, 2147483644) :: Nil) | |
- | |
- // The value of a MapType column can be a mutable map. | |
- val rowRDD3 = unparsedStrings.map { r => | |
- val values = r.split(",").map(_.trim) | |
- val v4 = try values(3).toInt catch { | |
- case _: NumberFormatException => null | |
+ // If ANSI mode is on, there will be an error "Key D4 does not exist". | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer( | |
+ sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), | |
+ Row(1, null) :: | |
+ Row(2, null) :: | |
+ Row(3, null) :: | |
+ Row(4, 2147483644) :: Nil) | |
+ | |
+ // The value of a MapType column can be a mutable map. | |
+ val rowRDD3 = unparsedStrings.map { r => | |
+ val values = r.split(",").map(_.trim) | |
+ val v4 = try values(3).toInt catch { | |
+ case _: NumberFormatException => null | |
+ } | |
+ Row(Row(values(0).toInt, values(2).toBoolean), | |
+ scala.collection.mutable.Map(values(1) -> v4)) | |
} | |
- Row(Row(values(0).toInt, values(2).toBoolean), | |
- scala.collection.mutable.Map(values(1) -> v4)) | |
- } | |
- val df3 = spark.createDataFrame(rowRDD3, schema2) | |
- df3.createOrReplaceTempView("applySchema3") | |
+ val df3 = spark.createDataFrame(rowRDD3, schema2) | |
+ df3.createOrReplaceTempView("applySchema3") | |
- checkAnswer( | |
- sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), | |
- Row(1, null) :: | |
- Row(2, null) :: | |
- Row(3, null) :: | |
- Row(4, 2147483644) :: Nil) | |
+ checkAnswer( | |
+ sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), | |
+ Row(1, null) :: | |
+ Row(2, null) :: | |
+ Row(3, null) :: | |
+ Row(4, 2147483644) :: Nil) | |
+ } | |
} | |
} | |
@@ -1098,7 +1136,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
|order by struct.a, struct.b | |
|""".stripMargin) | |
} | |
- assert(error.message contains "cannot resolve 'struct.a' given input columns: [a, b]") | |
+ assert(error.getErrorClass == "MISSING_COLUMN") | |
+ assert(error.messageParameters.sameElements(Array("struct.a", "a, b"))) | |
} | |
@@ -1396,22 +1435,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
test("SPARK-7952: fix the equality check between boolean and numeric types") { | |
- withTempView("t") { | |
- // numeric field i, boolean field j, result of i = j, result of i <=> j | |
- Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( | |
- (1, true, true, true), | |
- (0, false, true, true), | |
- (2, true, false, false), | |
- (2, false, false, false), | |
- (null, true, null, false), | |
- (null, false, null, false), | |
- (0, null, null, false), | |
- (1, null, null, false), | |
- (null, null, null, true) | |
- ).toDF("i", "b", "r1", "r2").createOrReplaceTempView("t") | |
- | |
- checkAnswer(sql("select i = b from t"), sql("select r1 from t")) | |
- checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) | |
+ // If ANSI mode is on, Spark disallows comparing Int with Boolean. | |
+ if (!conf.ansiEnabled) { | |
+ withTempView("t") { | |
+ // numeric field i, boolean field j, result of i = j, result of i <=> j | |
+ Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( | |
+ (1, true, true, true), | |
+ (0, false, true, true), | |
+ (2, true, false, false), | |
+ (2, false, false, false), | |
+ (null, true, null, false), | |
+ (null, false, null, false), | |
+ (0, null, null, false), | |
+ (1, null, null, false), | |
+ (null, null, null, true) | |
+ ).toDF("i", "b", "r1", "r2").createOrReplaceTempView("t") | |
+ | |
+ checkAnswer(sql("select i = b from t"), sql("select r1 from t")) | |
+ checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) | |
+ } | |
} | |
} | |
@@ -1445,32 +1487,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
val ymDF = sql("select interval 3 years -3 month") | |
checkAnswer(ymDF, Row(Period.of(2, 9, 0))) | |
- withTempPath(f => { | |
- val e = intercept[AnalysisException] { | |
- ymDF.write.json(f.getCanonicalPath) | |
- } | |
- e.message.contains("Cannot save interval data type into external storage") | |
- }) | |
val dtDF = sql("select interval 5 days 8 hours 12 minutes 50 seconds") | |
checkAnswer(dtDF, Row(Duration.ofDays(5).plusHours(8).plusMinutes(12).plusSeconds(50))) | |
- withTempPath(f => { | |
- val e = intercept[AnalysisException] { | |
- dtDF.write.json(f.getCanonicalPath) | |
- } | |
- e.message.contains("Cannot save interval data type into external storage") | |
- }) | |
withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { | |
val df = sql("select interval 3 years -3 month 7 week 123 microseconds") | |
checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7 * 7, 123))) | |
- withTempPath(f => { | |
- // Currently we don't yet support saving out values of interval data type. | |
- val e = intercept[AnalysisException] { | |
- df.write.json(f.getCanonicalPath) | |
- } | |
- e.message.contains("Cannot save interval data type into external storage") | |
- }) | |
} | |
} | |
@@ -2689,10 +2712,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { | |
- val m1 = intercept[AnalysisException] { | |
- sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") | |
- }.message | |
- assert(m1.contains("Union can only be performed on tables with the compatible column types")) | |
+ // Union resolves nested columns by position too. | |
+ checkAnswer(sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))"), | |
+ Row(Row(1)) :: Row(Row(2)) :: Nil) | |
val m2 = intercept[AnalysisException] { | |
sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") | |
@@ -2720,8 +2742,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
checkAnswer(sql("SELECT i from (SELECT i FROM v)"), Row(1)) | |
val e = intercept[AnalysisException](sql("SELECT v.i from (SELECT i FROM v)")) | |
- assert(e.message == | |
- "cannot resolve 'v.i' given input columns: [__auto_generated_subquery_name.i]") | |
+ assert(e.getErrorClass == "MISSING_COLUMN") | |
+ assert(e.messageParameters.sameElements(Array("v.i", "__auto_generated_subquery_name.i"))) | |
checkAnswer(sql("SELECT __auto_generated_subquery_name.i from (SELECT i FROM v)"), Row(1)) | |
} | |
@@ -2809,15 +2831,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
test("Non-deterministic aggregate functions should not be deduplicated") { | |
- val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a" | |
- val df = sql(query) | |
- val physical = df.queryExecution.sparkPlan | |
- val aggregateExpressions = physical.collectFirst { | |
- case agg : HashAggregateExec => agg.aggregateExpressions | |
- case agg : SortAggregateExec => agg.aggregateExpressions | |
+ withUserDefinedFunction("sumND" -> true) { | |
+ spark.udf.register("sumND", udaf(new Aggregator[Long, Long, Long] { | |
+ def zero: Long = 0L | |
+ def reduce(b: Long, a: Long): Long = b + a | |
+ def merge(b1: Long, b2: Long): Long = b1 + b2 | |
+ def finish(r: Long): Long = r | |
+ def bufferEncoder: Encoder[Long] = Encoders.scalaLong | |
+ def outputEncoder: Encoder[Long] = Encoders.scalaLong | |
+ }).asNondeterministic()) | |
+ | |
+ val query = "SELECT a, sumND(b), sumND(b) + 1 FROM testData2 GROUP BY a" | |
+ val df = sql(query) | |
+ val physical = df.queryExecution.sparkPlan | |
+ val aggregateExpressions = physical.collectFirst { | |
+ case agg: BaseAggregateExec => agg.aggregateExpressions | |
+ } | |
+ assert(aggregateExpressions.isDefined) | |
+ assert(aggregateExpressions.get.size == 2) | |
} | |
- assert (aggregateExpressions.isDefined) | |
- assert (aggregateExpressions.get.size == 2) | |
} | |
test("SPARK-22356: overlapped columns between data and partition schema in data source tables") { | |
@@ -3051,15 +3083,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
val df = spark.read.format(format).load(dir.getCanonicalPath) | |
checkPushedFilters( | |
format, | |
- df.where(('id < 2 and 's.contains("foo")) or ('id > 10 and 's.contains("bar"))), | |
+ df.where((Symbol("id") < 2 and Symbol("s").contains("foo")) or | |
+ (Symbol("id") > 10 and Symbol("s").contains("bar"))), | |
Array(sources.Or(sources.LessThan("id", 2), sources.GreaterThan("id", 10)))) | |
checkPushedFilters( | |
format, | |
- df.where('s.contains("foo") or ('id > 10 and 's.contains("bar"))), | |
+ df.where(Symbol("s").contains("foo") or | |
+ (Symbol("id") > 10 and Symbol("s").contains("bar"))), | |
Array.empty) | |
checkPushedFilters( | |
format, | |
- df.where('id < 2 and not('id > 10 and 's.contains("bar"))), | |
+ df.where(Symbol("id") < 2 and not(Symbol("id") > 10 and Symbol("s").contains("bar"))), | |
Array(sources.IsNotNull("id"), sources.LessThan("id", 2))) | |
} | |
} | |
@@ -3140,16 +3174,20 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
checkAnswer(sql("select * from t1 where d >= '2000-01-01'"), Row(result)) | |
checkAnswer(sql("select * from t1 where d >= '2000-01-02'"), Nil) | |
checkAnswer(sql("select * from t1 where '2000' >= d"), Row(result)) | |
- checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil) | |
+ } | |
withSQLConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING.key -> "true") { | |
checkAnswer(sql("select * from t1 where d < '2000'"), Nil) | |
checkAnswer(sql("select * from t1 where d < '2001'"), Row(result)) | |
- checkAnswer(sql("select * from t1 where d < '2000-1-1'"), Row(result)) | |
checkAnswer(sql("select * from t1 where d <= '1999'"), Nil) | |
checkAnswer(sql("select * from t1 where d >= '2000'"), Row(result)) | |
- checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result)) | |
- checkAnswer(sql("select to_date('2000-01-01') > '1'"), Row(true)) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer(sql("select * from t1 where d < '2000-1-1'"), Row(result)) | |
+ checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result)) | |
+ checkAnswer(sql("select to_date('2000-01-01') > '1'"), Row(true)) | |
+ } | |
} | |
} | |
} | |
@@ -3182,17 +3220,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
checkAnswer(sql("select * from t1 where d >= '2000-01-01 01:10:00.000'"), Row(result)) | |
checkAnswer(sql("select * from t1 where d >= '2000-01-02 01:10:00.000'"), Nil) | |
checkAnswer(sql("select * from t1 where '2000' >= d"), Nil) | |
- checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer(sql("select * from t1 where d > '2000-13'"), Nil) | |
+ } | |
withSQLConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING.key -> "true") { | |
checkAnswer(sql("select * from t1 where d < '2000'"), Nil) | |
checkAnswer(sql("select * from t1 where d < '2001'"), Row(result)) | |
- checkAnswer(sql("select * from t1 where d <= '2000-1-1'"), Row(result)) | |
checkAnswer(sql("select * from t1 where d <= '2000-01-02'"), Row(result)) | |
checkAnswer(sql("select * from t1 where d <= '1999'"), Nil) | |
checkAnswer(sql("select * from t1 where d >= '2000'"), Row(result)) | |
- checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result)) | |
- checkAnswer(sql("select to_timestamp('2000-01-01 01:10:00') > '1'"), Row(true)) | |
+ if (!conf.ansiEnabled) { | |
+ checkAnswer(sql("select * from t1 where d <= '2000-1-1'"), Row(result)) | |
+ checkAnswer(sql("select * from t1 where d > '1999-13'"), Row(result)) | |
+ checkAnswer(sql("select to_timestamp('2000-01-01 01:10:00') > '1'"), Row(true)) | |
+ } | |
} | |
sql("DROP VIEW t1") | |
} | |
@@ -3257,28 +3299,31 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
test("SPARK-29213: FilterExec should not throw NPE") { | |
- withTempView("t1", "t2", "t3") { | |
- sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1") | |
- sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)") | |
- .as[java.lang.Long] | |
- .map(identity) | |
- .toDF("x") | |
- .createOrReplaceTempView("t2") | |
- sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3") | |
- sql( | |
- """ | |
- |SELECT t1.x | |
- |FROM t1 | |
- |LEFT JOIN ( | |
- | SELECT x FROM ( | |
- | SELECT x FROM t2 | |
- | UNION ALL | |
- | SELECT SUBSTR(x,5) x FROM t3 | |
- | ) a | |
- | WHERE LENGTH(x)>0 | |
- |) t3 | |
- |ON t1.x=t3.x | |
+    // Under ANSI mode, casting string '' as numeric will cause a runtime error | 
+ if (!conf.ansiEnabled) { | |
+ withTempView("t1", "t2", "t3") { | |
+ sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1") | |
+ sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)") | |
+ .as[java.lang.Long] | |
+ .map(identity) | |
+ .toDF("x") | |
+ .createOrReplaceTempView("t2") | |
+ sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3") | |
+ sql( | |
+ """ | |
+ |SELECT t1.x | |
+ |FROM t1 | |
+ |LEFT JOIN ( | |
+ | SELECT x FROM ( | |
+ | SELECT x FROM t2 | |
+ | UNION ALL | |
+ | SELECT SUBSTR(x,5) x FROM t3 | |
+ | ) a | |
+ | WHERE LENGTH(x)>0 | |
+ |) t3 | |
+ |ON t1.x=t3.x | |
""".stripMargin).collect() | |
+ } | |
} | |
} | |
@@ -3298,7 +3343,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
sql("CREATE TEMPORARY VIEW tc AS SELECT * FROM VALUES(CAST(1 AS DOUBLE)) AS tc(id)") | |
sql("CREATE TEMPORARY VIEW td AS SELECT * FROM VALUES(CAST(1 AS FLOAT)) AS td(id)") | |
sql("CREATE TEMPORARY VIEW te AS SELECT * FROM VALUES(CAST(1 AS BIGINT)) AS te(id)") | |
- sql("CREATE TEMPORARY VIEW tf AS SELECT * FROM VALUES(CAST(1 AS DECIMAL(38, 38))) AS tf(id)") | |
val df1 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tb)") | |
checkAnswer(df1, Row(new java.math.BigDecimal(1))) | |
val df2 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tc)") | |
@@ -3307,8 +3351,12 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
checkAnswer(df3, Row(new java.math.BigDecimal(1))) | |
val df4 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM te)") | |
checkAnswer(df4, Row(new java.math.BigDecimal(1))) | |
- val df5 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tf)") | |
- checkAnswer(df5, Array.empty[Row]) | |
+ if (!conf.ansiEnabled) { | |
+ sql( | |
+ "CREATE TEMPORARY VIEW tf AS SELECT * FROM VALUES(CAST(1 AS DECIMAL(38, 38))) AS tf(id)") | |
+ val df5 = sql("SELECT id FROM ta WHERE id IN (SELECT id FROM tf)") | |
+ checkAnswer(df5, Array.empty[Row]) | |
+ } | |
} | |
} | |
@@ -4037,6 +4085,31 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
} | |
+ test("SPARK-40247: Fix BitSet equals") { | |
+ withTable("td") { | |
+ testData | |
+ .withColumn("bucket", $"key" % 3) | |
+ .write | |
+ .mode(SaveMode.Overwrite) | |
+ .bucketBy(2, "bucket") | |
+ .format("parquet") | |
+ .saveAsTable("td") | |
+ val df = sql( | |
+ """ | |
+ |SELECT t1.key, t2.key, t3.key | |
+ |FROM td AS t1 | |
+ |JOIN td AS t2 ON t2.key = t1.key | |
+ |JOIN td AS t3 ON t3.key = t2.key | |
+ |WHERE t1.bucket = 1 AND t2.bucket = 1 AND t3.bucket = 1 | |
+ |""".stripMargin) | |
+ df.collect() | |
+ val reusedExchanges = collect(df.queryExecution.executedPlan) { | |
+ case r: ReusedExchangeExec => r | |
+ } | |
+ assert(reusedExchanges.size == 1) | |
+ } | |
+ } | |
+ | |
test("SPARK-35331: Fix resolving original expression in RepartitionByExpression after aliased") { | |
Seq("CLUSTER", "DISTRIBUTE").foreach { keyword => | |
Seq("a", "substr(a, 0, 3)").foreach { expr => | |
@@ -4215,12 +4288,329 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
} | |
+ test("SPARK-36371: Support raw string literal") { | |
+ checkAnswer(sql("""SELECT r'a\tb\nc'"""), Row("""a\tb\nc""")) | |
+ checkAnswer(sql("""SELECT R'a\tb\nc'"""), Row("""a\tb\nc""")) | |
+ checkAnswer(sql("""SELECT r"a\tb\nc""""), Row("""a\tb\nc""")) | |
+ checkAnswer(sql("""SELECT R"a\tb\nc""""), Row("""a\tb\nc""")) | |
+ checkAnswer(sql("""SELECT from_json(r'{"a": "\\"}', 'a string')"""), Row(Row("\\"))) | |
+ checkAnswer(sql("""SELECT from_json(R'{"a": "\\"}', 'a string')"""), Row(Row("\\"))) | |
+ } | |
+ | |
test("SPARK-36979: Add RewriteLateralSubquery rule into nonExcludableRules") { | |
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> | |
"org.apache.spark.sql.catalyst.optimizer.RewriteLateralSubquery") { | |
sql("SELECT * FROM testData, LATERAL (SELECT * FROM testData)").collect() | |
} | |
} | |
+ | |
+ test("TABLE SAMPLE") { | |
+ withTable("test") { | |
+ sql("CREATE TABLE test(c int) USING PARQUET") | |
+ for (i <- 0 to 20) { | |
+ sql(s"INSERT INTO test VALUES ($i)") | |
+ } | |
+ val df1 = sql("SELECT * FROM test TABLESAMPLE (20 PERCENT) REPEATABLE (12345)") | |
+ val df2 = sql("SELECT * FROM test TABLESAMPLE (20 PERCENT) REPEATABLE (12345)") | |
+ checkAnswer(df1, df2) | |
+ | |
+ val df3 = sql("SELECT * FROM test TABLESAMPLE (BUCKET 4 OUT OF 10) REPEATABLE (6789)") | |
+ val df4 = sql("SELECT * FROM test TABLESAMPLE (BUCKET 4 OUT OF 10) REPEATABLE (6789)") | |
+ checkAnswer(df3, df4) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-27442: Spark support read/write parquet file with invalid char in field name") { | |
+ withTempDir { dir => | |
+ Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (2, 4, 6, 8, 10, 12, 14, 16, 18, 20)) | |
+ .toDF("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a") | |
+ .repartition(1) | |
+ .write.mode(SaveMode.Overwrite).parquet(dir.getAbsolutePath) | |
+ val df = spark.read.parquet(dir.getAbsolutePath) | |
+ checkAnswer(df, | |
+ Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) :: | |
+ Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20) :: Nil) | |
+ assert(df.schema.names.sameElements( | |
+ Array("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a"))) | |
+ checkAnswer(df.select("`max(t)`", "`a b`", "`{`", "`.`", "`a.b`"), | |
+ Row(1, 6, 7, 8, 9) :: Row(2, 12, 14, 16, 18) :: Nil) | |
+ checkAnswer(df.where("`a.b` > 10"), | |
+ Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20) :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37965: Spark support read/write orc file with invalid char in field name") { | |
+ withTempDir { dir => | |
+ Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), (2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22)) | |
+ .toDF("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a", ",") | |
+ .repartition(1) | |
+ .write.mode(SaveMode.Overwrite).orc(dir.getAbsolutePath) | |
+ val df = spark.read.orc(dir.getAbsolutePath) | |
+ checkAnswer(df, | |
+ Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) :: | |
+ Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22) :: Nil) | |
+ assert(df.schema.names.sameElements( | |
+ Array("max(t)", "max(t", "=", "\n", ";", "a b", "{", ".", "a.b", "a", ","))) | |
+ checkAnswer(df.select("`max(t)`", "`a b`", "`{`", "`.`", "`a.b`"), | |
+ Row(1, 6, 7, 8, 9) :: Row(2, 12, 14, 16, 18) :: Nil) | |
+ checkAnswer(df.where("`a.b` > 10"), | |
+ Row(2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22) :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38173: Quoted column cannot be recognized correctly " + | |
+ "when quotedRegexColumnNames is true") { | |
+ withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") { | |
+ checkAnswer( | |
+ sql( | |
+ """ | |
+ |SELECT `(C3)?+.+`,T.`C1` * `C2` AS CC | |
+ |FROM (SELECT 3 AS C1,2 AS C2,1 AS C3) T | |
+ |""".stripMargin), | |
+ Row(3, 2, 6) :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38548: try_sum should return null if overflow happens before merging") { | |
+ val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v") | |
+ val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2) | |
+ .map(Period.ofMonths) | |
+ .toDF("v") | |
+ val dayTimeDf = Seq(106751991L, 106751991L, 2L) | |
+ .map(Duration.ofDays) | |
+ .toDF("v") | |
+ Seq(longDf, yearMonthDf, dayTimeDf).foreach { df => | |
+ checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"), Row(null)) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" + | |
+ " when WSCG is off") { | |
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", | |
+ SQLConf.ANSI_ENABLED.key -> "true") { | |
+ withTable("t") { | |
+ sql("create table t(i int, j int) using parquet") | |
+ sql("insert into t values(2147483647, 10)") | |
+ Seq( | |
+ "select i + j from t", | |
+ "select -i - j from t", | |
+ "select i * j from t", | |
+ "select i / (j - 10) from t").foreach { query => | |
+ val msg = intercept[SparkException] { | |
+ sql(query).collect() | |
+ }.getMessage | |
+ assert(msg.contains(query)) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39175: Query context of Cast should be serialized to executors" + | |
+ " when WSCG is off") { | |
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", | |
+ SQLConf.ANSI_ENABLED.key -> "true") { | |
+ withTable("t") { | |
+ sql("create table t(s string) using parquet") | |
+ sql("insert into t values('a')") | |
+ Seq( | |
+ "select cast(s as int) from t", | |
+ "select cast(s as long) from t", | |
+ "select cast(s as double) from t", | |
+ "select cast(s as decimal(10, 2)) from t", | |
+ "select cast(s as date) from t", | |
+ "select cast(s as timestamp) from t", | |
+ "select cast(s as boolean) from t").foreach { query => | |
+ val msg = intercept[SparkException] { | |
+ sql(query).collect() | |
+ }.getMessage | |
+ assert(msg.contains(query)) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39177: Query context of getting map value should be serialized to executors" + | |
+ " when WSCG is off") { | |
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", | |
+ SQLConf.ANSI_ENABLED.key -> "true") { | |
+ withTable("t") { | |
+ sql("create table t(m map<string, string>) using parquet") | |
+ sql("insert into t values map('a', 'b')") | |
+ Seq( | |
+ "select m['foo'] from t", | |
+ "select element_at(m, 'foo') from t").foreach { query => | |
+ val msg = intercept[SparkException] { | |
+ sql(query).collect() | |
+ }.getMessage | |
+ assert(msg.contains(query)) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " + | |
+ "be serialized to executors when WSCG is off") { | |
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", | |
+ SQLConf.ANSI_ENABLED.key -> "true") { | |
+ withTable("t") { | |
+ sql("create table t(d decimal(38, 0)) using parquet") | |
+ sql("insert into t values (6e37BD),(6e37BD)") | |
+ Seq( | |
+ "select d / 0.1 from t", | |
+ "select sum(d) from t", | |
+ "select avg(d) from t").foreach { query => | |
+ val msg = intercept[SparkException] { | |
+ sql(query).collect() | |
+ }.getMessage | |
+ assert(msg.contains(query)) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38589: try_avg should return null if overflow happens before merging") { | |
+ val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2) | |
+ .map(Period.ofMonths) | |
+ .toDF("v") | |
+ val dayTimeDf = Seq(106751991L, 106751991L, 2L) | |
+ .map(Duration.ofDays) | |
+ .toDF("v") | |
+ Seq(yearMonthDf, dayTimeDf).foreach { df => | |
+ checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_avg(v)"), Row(null)) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39012: SparkSQL cast partition value does not support all data types") { | |
+ withTempDir { dir => | |
+ val binary1 = Hex.hex(UTF8String.fromString("Spark").getBytes).getBytes | |
+ val binary2 = Hex.hex(UTF8String.fromString("SQL").getBytes).getBytes | |
+ val data = Seq[(Int, Boolean, Array[Byte])]( | |
+ (1, false, binary1), | |
+ (2, true, binary2) | |
+ ) | |
+ data.toDF("a", "b", "c") | |
+ .write | |
+ .mode("overwrite") | |
+ .partitionBy("b", "c") | |
+ .parquet(dir.getCanonicalPath) | |
+ val res = spark.read | |
+ .schema("a INT, b BOOLEAN, c BINARY") | |
+ .parquet(dir.getCanonicalPath) | |
+ checkAnswer(res, | |
+ Seq( | |
+ Row(1, false, mutable.WrappedArray.make(binary1)), | |
+ Row(2, true, mutable.WrappedArray.make(binary2)) | |
+ )) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39216: Don't collapse projects in CombineUnions if it hasCorrelatedSubquery") { | |
+ checkAnswer( | |
+ sql( | |
+ """ | |
+ |SELECT (SELECT IF(x, 1, 0)) AS a | |
+ |FROM (SELECT true) t(x) | |
+ |UNION | |
+ |SELECT 1 AS a | |
+ """.stripMargin), | |
+ Seq(Row(1))) | |
+ | |
+ checkAnswer( | |
+ sql( | |
+ """ | |
+ |SELECT x + 1 | |
+ |FROM (SELECT id | |
+ | + (SELECT Max(id) | |
+ | FROM range(2)) AS x | |
+ | FROM range(1)) t | |
+ |UNION | |
+ |SELECT 1 AS a | |
+ """.stripMargin), | |
+ Seq(Row(2), Row(1))) | |
+ } | |
+ | |
+ test("SPARK-39548: CreateView will make queries go into inline CTE code path thus" + | |
+ "trigger a mis-clarified `window definition not found` issue") { | |
+ sql( | |
+ """ | |
+ |create or replace temporary view test_temp_view as | |
+ |with step_1 as ( | |
+ |select * , min(a) over w2 as min_a_over_w2 from | |
+ |(select 1 as a, 2 as b, 3 as c) window w2 as (partition by b order by c)) , step_2 as | |
+ |( | |
+ |select *, max(e) over w1 as max_a_over_w1 | |
+ |from (select 1 as e, 2 as f, 3 as g) | |
+ |join step_1 on true | |
+ |window w1 as (partition by f order by g) | |
+ |) | |
+ |select * | |
+ |from step_2 | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql("select * from test_temp_view"), | |
+ Row(1, 2, 3, 1, 2, 3, 1, 1)) | |
+ } | |
+ | |
+ test("SPARK-40389: Don't eliminate a cast which can cause overflow") { | |
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { | |
+ withTable("dt") { | |
+ sql("create table dt using parquet as select 9000000000BD as d") | |
+ val msg = intercept[SparkException] { | |
+ sql("select cast(cast(d as int) as long) from dt").collect() | |
+ }.getCause.getMessage | |
+ assert(msg.contains("The value 9000000000BD of the type \"DECIMAL(10,0)\" " + | |
+ "cannot be cast to \"INT\" due to an overflow")) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-41144: Unresolved hint should not cause query failure") { | |
+ withTable("t1", "t2") { | |
+ sql("CREATE TABLE t1(c1 bigint) USING PARQUET") | |
+ sql("CREATE TABLE t2(c2 bigint) USING PARQUET") | |
+ sql("SELECT /*+ hash(t2) */ * FROM t1 join t2 on c1 = c2") | |
+ } | |
+ } | |
+ | |
+ test("SPARK-41538: Metadata column should be appended at the end of project") { | |
+ val tableName = "table_1" | |
+ val viewName = "view_1" | |
+ withTable(tableName) { | |
+ withView(viewName) { | |
+ sql(s"CREATE TABLE $tableName (a ARRAY<STRING>, s STRUCT<id: STRING>) USING parquet") | |
+ val id = "id1" | |
+ sql(s"INSERT INTO $tableName values(ARRAY('a'), named_struct('id', '$id'))") | |
+ sql( | |
+ s""" | |
+ |CREATE VIEW $viewName (id) | |
+ |AS WITH source AS ( | |
+ | SELECT * FROM $tableName | |
+ |), | |
+ |renamed AS ( | |
+ | SELECT s.id FROM source | |
+ |) | |
+ |SELECT id FROM renamed | |
+ |""".stripMargin) | |
+ val query = | |
+ s""" | |
+ |with foo AS ( | |
+ | SELECT '$id' as id | |
+ |), | |
+ |bar AS ( | |
+ | SELECT '$id' as id | |
+ |) | |
+ |SELECT | |
+ | 1 | |
+ |FROM foo | |
+ |FULL OUTER JOIN bar USING(id) | |
+ |FULL OUTER JOIN $viewName USING(id) | |
+ |WHERE foo.id IS NOT NULL | |
+ |""".stripMargin | |
+ checkAnswer(sql(query), Row(1)) | |
+ } | |
+ } | |
+ } | |
} | |
case class Foo(bar: Option[String]) | |
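Among the tests added to SQLQuerySuite above, the SPARK-36371 case covers raw string literals (r'...' and R"..."), which keep backslashes verbatim instead of treating them as escape sequences. A short sketch of the contrast with ordinary literals under Spark's default escaped-literal handling, assuming a local SparkSession (the commented output is the expected result, not taken from the patch):

import org.apache.spark.sql.SparkSession

// Sketch: with default settings an ordinary literal interprets \t as a tab, while a raw
// literal keeps the backslash and the 't' as two separate characters.
object RawStringLiteralSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.sql("""SELECT r'a\tb' AS raw, 'a\tb' AS escaped""").show(false)
    // raw     -> a\tb    (backslash and 't' preserved)
    // escaped -> a<TAB>b (\t interpreted as a tab character)
    spark.stop()
  }
}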
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala | |
index b9ca2a0f03..987e09adb1 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala | |
@@ -35,6 +35,7 @@ trait SQLQueryTestHelper { | |
protected def replaceNotIncludedMsg(line: String): String = { | |
line.replaceAll("#\\d+", "#x") | |
+ .replaceAll("plan_id=\\d+", "plan_id=x") | |
.replaceAll( | |
s"Location.*$clsName/", | |
s"Location $notIncludedMsg/{warehouse_dir}/") | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala | |
index d5a34ae64a..d6a7c69018 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala | |
@@ -36,6 +36,7 @@ import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.tags.ExtendedSQLTest | |
import org.apache.spark.util.Utils | |
+// scalastyle:off line.size.limit | |
/** | |
* End-to-end test cases for SQL queries. | |
* | |
@@ -44,22 +45,22 @@ import org.apache.spark.util.Utils | |
* | |
* To run the entire test suite: | |
* {{{ | |
- * build/sbt "sql/testOnly *SQLQueryTestSuite" | |
+ * build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite" | |
* }}} | |
* | |
* To run a single test file upon change: | |
* {{{ | |
- * build/sbt "~sql/testOnly *SQLQueryTestSuite -- -z inline-table.sql" | |
+ * build/sbt "~sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z inline-table.sql" | |
* }}} | |
* | |
* To re-generate golden files for entire suite, run: | |
* {{{ | |
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *SQLQueryTestSuite" | |
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite" | |
* }}} | |
* | |
* To re-generate golden file for a single test, run: | |
* {{{ | |
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *SQLQueryTestSuite -- -z describe.sql" | |
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z describe.sql" | |
* }}} | |
* | |
* The format for input files is simple: | |
@@ -120,6 +121,7 @@ import org.apache.spark.util.Utils | |
* Therefore, UDF test cases should have single input and output files but executed by three | |
* different types of UDFs. See 'udf/udf-inner-join.sql' as an example. | |
*/ | |
+// scalastyle:on line.size.limit | |
@ExtendedSQLTest | |
class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper | |
with SQLQueryTestHelper { | |
@@ -386,6 +388,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper | |
localSparkSession.conf.set(SQLConf.TIMESTAMP_TYPE.key, | |
TimestampTypes.TIMESTAMP_NTZ.toString) | |
case _ => | |
+ localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false) | |
} | |
if (configSet.nonEmpty) { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala | |
deleted file mode 100644 | |
index 6839294348..0000000000 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala | |
+++ /dev/null | |
@@ -1,241 +0,0 @@ | |
-/* | |
- * Licensed to the Apache Software Foundation (ASF) under one or more | |
- * contributor license agreements. See the NOTICE file distributed with | |
- * this work for additional information regarding copyright ownership. | |
- * The ASF licenses this file to You under the Apache License, Version 2.0 | |
- * (the "License"); you may not use this file except in compliance with | |
- * the License. You may obtain a copy of the License at | |
- * | |
- * http://www.apache.org/licenses/LICENSE-2.0 | |
- * | |
- * Unless required by applicable law or agreed to in writing, software | |
- * distributed under the License is distributed on an "AS IS" BASIS, | |
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
- * See the License for the specific language governing permissions and | |
- * limitations under the License. | |
- */ | |
- | |
-package org.apache.spark.sql | |
- | |
-import org.apache.spark.sql.catalyst.TableIdentifier | |
-import org.apache.spark.sql.catalyst.catalog.CatalogTable | |
-import org.apache.spark.sql.sources.SimpleInsertSource | |
-import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} | |
-import org.apache.spark.util.Utils | |
- | |
-class SimpleShowCreateTableSuite extends ShowCreateTableSuite with SharedSparkSession | |
- | |
-abstract class ShowCreateTableSuite extends QueryTest with SQLTestUtils { | |
- import testImplicits._ | |
- | |
- test("data source table with user specified schema") { | |
- withTable("ddl_test") { | |
- val jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile | |
- | |
- sql( | |
- s"""CREATE TABLE ddl_test ( | |
- | a STRING, | |
- | b STRING, | |
- | `extra col` ARRAY<INT>, | |
- | `<another>` STRUCT<x: INT, y: ARRAY<BOOLEAN>> | |
- |) | |
- |USING json | |
- |OPTIONS ( | |
- | PATH '$jsonFilePath' | |
- |) | |
- """.stripMargin | |
- ) | |
- | |
- checkCreateTable("ddl_test") | |
- } | |
- } | |
- | |
- test("data source table CTAS") { | |
- withTable("ddl_test") { | |
- sql( | |
- s"""CREATE TABLE ddl_test | |
- |USING json | |
- |AS SELECT 1 AS a, "foo" AS b | |
- """.stripMargin | |
- ) | |
- | |
- checkCreateTable("ddl_test") | |
- } | |
- } | |
- | |
- test("partitioned data source table") { | |
- withTable("ddl_test") { | |
- sql( | |
- s"""CREATE TABLE ddl_test | |
- |USING json | |
- |PARTITIONED BY (b) | |
- |AS SELECT 1 AS a, "foo" AS b | |
- """.stripMargin | |
- ) | |
- | |
- checkCreateTable("ddl_test") | |
- } | |
- } | |
- | |
- test("bucketed data source table") { | |
- withTable("ddl_test") { | |
- sql( | |
- s"""CREATE TABLE ddl_test | |
- |USING json | |
- |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS | |
- |AS SELECT 1 AS a, "foo" AS b | |
- """.stripMargin | |
- ) | |
- | |
- checkCreateTable("ddl_test") | |
- } | |
- } | |
- | |
- test("partitioned bucketed data source table") { | |
- withTable("ddl_test") { | |
- sql( | |
- s"""CREATE TABLE ddl_test | |
- |USING json | |
- |PARTITIONED BY (c) | |
- |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS | |
- |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c | |
- """.stripMargin | |
- ) | |
- | |
- checkCreateTable("ddl_test") | |
- } | |
- } | |
- | |
- test("data source table with a comment") { | |
- withTable("ddl_test") { | |
- sql( | |
- s"""CREATE TABLE ddl_test | |
- |USING json | |
- |COMMENT 'This is a comment' | |
- |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c | |
- """.stripMargin | |
- ) | |
- | |
- checkCreateTable("ddl_test") | |
- } | |
- } | |
- | |
- test("data source table with table properties") { | |
- withTable("ddl_test") { | |
- sql( | |
- s"""CREATE TABLE ddl_test | |
- |USING json | |
- |TBLPROPERTIES ('a' = '1') | |
- |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c | |
- """.stripMargin | |
- ) | |
- | |
- checkCreateTable("ddl_test") | |
- } | |
- } | |
- | |
- test("data source table using Dataset API") { | |
- withTable("ddl_test") { | |
- spark | |
- .range(3) | |
- .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd, 'id as 'e) | |
- .write | |
- .mode("overwrite") | |
- .partitionBy("a", "b") | |
- .bucketBy(2, "c", "d") | |
- .saveAsTable("ddl_test") | |
- | |
- checkCreateTable("ddl_test") | |
- } | |
- } | |
- | |
- test("temp view") { | |
- val viewName = "spark_28383" | |
- withTempView(viewName) { | |
- sql(s"CREATE TEMPORARY VIEW $viewName AS SELECT 1 AS a") | |
- val ex = intercept[AnalysisException] { | |
- sql(s"SHOW CREATE TABLE $viewName") | |
- } | |
- assert(ex.getMessage.contains( | |
- s"$viewName is a temp view. 'SHOW CREATE TABLE' expects a table or permanent view.")) | |
- } | |
- | |
- withGlobalTempView(viewName) { | |
- sql(s"CREATE GLOBAL TEMPORARY VIEW $viewName AS SELECT 1 AS a") | |
- val globalTempViewDb = spark.sessionState.catalog.globalTempViewManager.database | |
- val ex = intercept[AnalysisException] { | |
- sql(s"SHOW CREATE TABLE $globalTempViewDb.$viewName") | |
- } | |
- assert(ex.getMessage.contains( | |
- s"$globalTempViewDb.$viewName is a temp view. " + | |
- "'SHOW CREATE TABLE' expects a table or permanent view.")) | |
- } | |
- } | |
- | |
- test("SPARK-24911: keep quotes for nested fields") { | |
- withTable("t1") { | |
- val createTable = "CREATE TABLE `t1` (`a` STRUCT<`b`: STRING>)" | |
- sql(s"$createTable USING json") | |
- val shownDDL = getShowDDL("SHOW CREATE TABLE t1") | |
- assert(shownDDL == "CREATE TABLE `default`.`t1` ( `a` STRUCT<`b`: STRING>) USING json") | |
- | |
- checkCreateTable("t1") | |
- } | |
- } | |
- | |
- test("SPARK-36012: Add NULL flag when SHOW CREATE TABLE") { | |
- val t = "SPARK_36012" | |
- withTable(t) { | |
- sql( | |
- s""" | |
- |CREATE TABLE $t ( | |
- | a bigint NOT NULL, | |
- | b bigint | |
- |) | |
- |USING ${classOf[SimpleInsertSource].getName} | |
- """.stripMargin) | |
- val showDDL = getShowDDL(s"SHOW CREATE TABLE $t") | |
- assert(showDDL == s"CREATE TABLE `default`.`$t` ( `a` BIGINT NOT NULL," + | |
- s" `b` BIGINT) USING ${classOf[SimpleInsertSource].getName}") | |
- } | |
- } | |
- | |
- protected def getShowDDL(showCreateTableSql: String): String = { | |
- sql(showCreateTableSql).head().getString(0).split("\n").map(_.trim).mkString(" ") | |
- } | |
- | |
- protected def checkCreateTable(table: String, serde: Boolean = false): Unit = { | |
- checkCreateTableOrView(TableIdentifier(table, Some("default")), "TABLE", serde) | |
- } | |
- | |
- protected def checkCreateView(table: String, serde: Boolean = false): Unit = { | |
- checkCreateTableOrView(TableIdentifier(table, Some("default")), "VIEW", serde) | |
- } | |
- | |
- protected def checkCreateTableOrView( | |
- table: TableIdentifier, | |
- checkType: String, | |
- serde: Boolean): Unit = { | |
- val db = table.database.getOrElse("default") | |
- val expected = spark.sharedState.externalCatalog.getTable(db, table.table) | |
- val shownDDL = if (serde) { | |
- sql(s"SHOW CREATE TABLE ${table.quotedString} AS SERDE").head().getString(0) | |
- } else { | |
- sql(s"SHOW CREATE TABLE ${table.quotedString}").head().getString(0) | |
- } | |
- | |
- sql(s"DROP $checkType ${table.quotedString}") | |
- | |
- try { | |
- sql(shownDDL) | |
- val actual = spark.sharedState.externalCatalog.getTable(db, table.table) | |
- checkCatalogTables(expected, actual) | |
- } finally { | |
- sql(s"DROP $checkType IF EXISTS ${table.table}") | |
- } | |
- } | |
- | |
- protected def checkCatalogTables(expected: CatalogTable, actual: CatalogTable): Unit = { | |
- assert(CatalogTable.normalize(actual) == CatalogTable.normalize(expected)) | |
- } | |
-} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala | |
index 9d4e57093c..0a7c684a68 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala | |
@@ -20,6 +20,7 @@ package org.apache.spark.sql | |
import scala.collection.JavaConverters._ | |
import org.apache.hadoop.fs.Path | |
+import org.apache.logging.log4j.Level | |
import org.scalatest.BeforeAndAfterEach | |
import org.scalatest.concurrent.Eventually | |
import org.scalatest.time.SpanSugar._ | |
@@ -431,7 +432,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach wit | |
.getOrCreate() | |
.sharedState | |
} | |
- assert(logAppender.loggingEvents.exists(_.getRenderedMessage.contains(msg))) | |
+ assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) | |
} | |
test("SPARK-33944: no warning setting spark.sql.warehouse.dir using session options") { | |
@@ -444,7 +445,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach wit | |
.getOrCreate() | |
.sharedState | |
} | |
- assert(!logAppender.loggingEvents.exists(_.getRenderedMessage.contains(msg))) | |
+ assert(!logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) | |
} | |
Seq(".", "..", "dir0", "dir0/dir1", "/dir0/dir1", "./dir0").foreach { pathStr => | |
@@ -484,6 +485,89 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach wit | |
.getOrCreate() | |
session.sql("SELECT 1").collect() | |
} | |
- assert(logAppender.loggingEvents.exists(_.getRenderedMessage.contains(msg))) | |
+ assert(logAppender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains(msg))) | |
+ } | |
+ | |
+ test("SPARK-37727: Show ignored configurations in debug level logs") { | |
+ // Create one existing SparkSession to check following logs. | |
+ SparkSession.builder().master("local").getOrCreate() | |
+ | |
+ val logAppender = new LogAppender | |
+ logAppender.setThreshold(Level.DEBUG) | |
+ withLogAppender(logAppender, level = Some(Level.DEBUG)) { | |
+ SparkSession.builder() | |
+ .config("spark.sql.warehouse.dir", "2") | |
+ .config("spark.abc", "abcb") | |
+ .config("spark.abcd", "abcb4") | |
+ .getOrCreate() | |
+ } | |
+ | |
+ val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ Seq( | |
+ "Ignored static SQL configurations", | |
+ "spark.sql.warehouse.dir=2", | |
+ "Configurations that might not take effect", | |
+ "spark.abcd=abcb4", | |
+ "spark.abc=abcb").foreach { msg => | |
+ assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37727: Hide the same configuration already explicitly set in logs") { | |
+ // Create one existing SparkSession to check following logs. | |
+ SparkSession.builder().master("local").config("spark.abc", "abc").getOrCreate() | |
+ | |
+ val logAppender = new LogAppender | |
+ logAppender.setThreshold(Level.DEBUG) | |
+ withLogAppender(logAppender, level = Some(Level.DEBUG)) { | |
+ // Ignore logs because it's already set. | |
+ SparkSession.builder().config("spark.abc", "abc").getOrCreate() | |
+ // Show logs for only configuration newly set. | |
+ SparkSession.builder().config("spark.abc.new", "abc").getOrCreate() | |
+ // Ignore logs because it's set ^. | |
+ SparkSession.builder().config("spark.abc.new", "abc").getOrCreate() | |
+ } | |
+ | |
+ val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ Seq( | |
+ "Using an existing Spark session; only runtime SQL configurations will take effect", | |
+ "Configurations that might not take effect", | |
+ "spark.abc.new=abc").foreach { msg => | |
+ assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") | |
+ } | |
+ | |
+ assert( | |
+ !logs.exists(_.contains("spark.abc=abc")), | |
+ s"'spark.abc=abc' existed in:\n${logs.mkString("\n")}") | |
+ } | |
+ | |
+ test("SPARK-37727: Hide runtime SQL configurations in logs") { | |
+ // Create one existing SparkSession to check following logs. | |
+ SparkSession.builder().master("local").getOrCreate() | |
+ | |
+ val logAppender = new LogAppender | |
+ logAppender.setThreshold(Level.DEBUG) | |
+ withLogAppender(logAppender, level = Some(Level.DEBUG)) { | |
+ // Ignore logs for runtime SQL configurations | |
+ SparkSession.builder().config("spark.sql.ansi.enabled", "true").getOrCreate() | |
+ // Show logs for Spark core configuration | |
+ SparkSession.builder().config("spark.buffer.size", "1234").getOrCreate() | |
+ // Show logs for custom runtime options | |
+ SparkSession.builder().config("spark.sql.source.abc", "abc").getOrCreate() | |
+ // Show logs for static SQL configurations | |
+ SparkSession.builder().config("spark.sql.warehouse.dir", "xyz").getOrCreate() | |
+ } | |
+ | |
+ val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage) | |
+ Seq( | |
+ "spark.buffer.size=1234", | |
+ "spark.sql.source.abc=abc", | |
+ "spark.sql.warehouse.dir=xyz").foreach { msg => | |
+ assert(logs.exists(_.contains(msg)), s"$msg did not exist in:\n${logs.mkString("\n")}") | |
+ } | |
+ | |
+ assert( | |
+ !logs.exists(_.contains("spark.sql.ansi.enabled\"")), | |
+ s"'spark.sql.ansi.enabled' existed in:\n${logs.mkString("\n")}") | |
} | |
} | |
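
The SPARK-37727 tests above revolve around builder options applied against an already-created session. A minimal sketch of that behaviour, assuming a local master and an illustrative warehouse path:

  import org.apache.spark.sql.SparkSession

  val first = SparkSession.builder().master("local").getOrCreate()
  val second = SparkSession.builder()
    .config("spark.sql.warehouse.dir", "/tmp/ignored")  // static SQL conf: ignored, reported only in DEBUG logs
    .config("spark.sql.shuffle.partitions", "5")        // runtime SQL conf: applied to the returned session
    .getOrCreate()
  assert(first eq second)  // the pre-existing session is reused
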
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala | |
index f8a155ec1b..5c8eec6b10 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala | |
@@ -27,7 +27,8 @@ import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} | |
import org.apache.spark.sql.catalyst.expressions._ | |
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} | |
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Statistics, UnresolvedHint} | |
+import org.apache.spark.sql.catalyst.plans.SQLHelper | |
+import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint} | |
import org.apache.spark.sql.catalyst.plans.physical.Partitioning | |
import org.apache.spark.sql.catalyst.rules.Rule | |
import org.apache.spark.sql.catalyst.trees.TreeNodeTag | |
@@ -45,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String | |
/** | |
* Test cases for the [[SparkSessionExtensions]]. | |
*/ | |
-class SparkSessionExtensionSuite extends SparkFunSuite { | |
+class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper { | |
private def create( | |
builder: SparkSessionExtensionsProvider): Seq[SparkSessionExtensionsProvider] = Seq(builder) | |
@@ -171,7 +172,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite { | |
} | |
withSession(extensions) { session => | |
session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true) | |
- assert(session.sessionState.queryStagePrepRules.contains(MyQueryStagePrepRule())) | |
+ assert(session.sessionState.adaptiveRulesHolder.queryStagePrepRules | |
+ .contains(MyQueryStagePrepRule())) | |
assert(session.sessionState.columnarRules.contains( | |
MyColumnarRule(MyNewQueryStageRule(), MyNewQueryStageRule()))) | |
import session.sqlContext.implicits._ | |
@@ -406,6 +408,26 @@ class SparkSessionExtensionSuite extends SparkFunSuite { | |
session.sql("SELECT * FROM v") | |
} | |
} | |
+ | |
+ test("SPARK-38697: Extend SparkSessionExtensions to inject rules into AQE Optimizer") { | |
+ def executedPlan(df: Dataset[java.lang.Long]): SparkPlan = { | |
+ assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) | |
+ df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan | |
+ } | |
+ val extensions = create { extensions => | |
+ extensions.injectRuntimeOptimizerRule(_ => AddLimit) | |
+ } | |
+ withSession(extensions) { session => | |
+ assert(session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules.contains(AddLimit)) | |
+ | |
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { | |
+ val df = session.range(2).repartition() | |
+ assert(!executedPlan(df).isInstanceOf[CollectLimitExec]) | |
+ df.collect() | |
+ assert(executedPlan(df).isInstanceOf[CollectLimitExec]) | |
+ } | |
+ } | |
+ } | |
} | |
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { | |
@@ -441,6 +463,9 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars | |
override def parseDataType(sqlText: String): DataType = | |
delegate.parseDataType(sqlText) | |
+ | |
+ override def parseQuery(sqlText: String): LogicalPlan = | |
+ delegate.parseQuery(sqlText) | |
} | |
object MyExtensions { | |
@@ -546,7 +571,7 @@ case class NoCloseColumnVector(wrapped: ColumnVector) extends ColumnVector(wrapp | |
override def getBinary(rowId: Int): Array[Byte] = wrapped.getBinary(rowId) | |
- override protected def getChild(ordinal: Int): ColumnVector = wrapped.getChild(ordinal) | |
+ override def getChild(ordinal: Int): ColumnVector = wrapped.getChild(ordinal) | |
} | |
trait ColumnarExpression extends Expression with Serializable { | |
@@ -722,37 +747,32 @@ class BrokenColumnarAdd( | |
lhs = left.columnarEval(batch) | |
rhs = right.columnarEval(batch) | |
- if (lhs == null || rhs == null) { | |
- ret = null | |
- } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) { | |
- val l = lhs.asInstanceOf[ColumnVector] | |
- val r = rhs.asInstanceOf[ColumnVector] | |
- val result = new OnHeapColumnVector(batch.numRows(), dataType) | |
- ret = result | |
- | |
- for (i <- 0 until batch.numRows()) { | |
- result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add | |
- } | |
- } else if (rhs.isInstanceOf[ColumnVector]) { | |
- val l = lhs.asInstanceOf[Long] | |
- val r = rhs.asInstanceOf[ColumnVector] | |
- val result = new OnHeapColumnVector(batch.numRows(), dataType) | |
- ret = result | |
- | |
- for (i <- 0 until batch.numRows()) { | |
- result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add | |
- } | |
- } else if (lhs.isInstanceOf[ColumnVector]) { | |
- val l = lhs.asInstanceOf[ColumnVector] | |
- val r = rhs.asInstanceOf[Long] | |
- val result = new OnHeapColumnVector(batch.numRows(), dataType) | |
- ret = result | |
- | |
- for (i <- 0 until batch.numRows()) { | |
- result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add | |
- } | |
- } else { | |
- ret = nullSafeEval(lhs, rhs) | |
+ (lhs, rhs) match { | |
+ case (null, null) => | |
+ ret = null | |
+ case (l: ColumnVector, r: ColumnVector) => | |
+ val result = new OnHeapColumnVector(batch.numRows(), dataType) | |
+ ret = result | |
+ | |
+ for (i <- 0 until batch.numRows()) { | |
+ result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add | |
+ } | |
+ case (l: Long, r: ColumnVector) => | |
+ val result = new OnHeapColumnVector(batch.numRows(), dataType) | |
+ ret = result | |
+ | |
+ for (i <- 0 until batch.numRows()) { | |
+ result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add | |
+ } | |
+ case (l: ColumnVector, r: Long) => | |
+ val result = new OnHeapColumnVector(batch.numRows(), dataType) | |
+ ret = result | |
+ | |
+ for (i <- 0 until batch.numRows()) { | |
+ result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add | |
+ } | |
+ case (l, r) => | |
+ ret = nullSafeEval(l, r) | |
} | |
} finally { | |
if (lhs != null && lhs.isInstanceOf[ColumnVector]) { | |
@@ -1026,3 +1046,10 @@ class YourExtensions extends SparkSessionExtensionsProvider { | |
v1.injectFunction(getAppName) | |
} | |
} | |
+ | |
+object AddLimit extends Rule[LogicalPlan] { | |
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match { | |
+ case Limit(_, _) => plan | |
+ case _ => Limit(Literal(1), plan) | |
+ } | |
+} | |
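
SPARK-38697, exercised above, lets an extension inject rules into the AQE runtime optimizer. A sketch of registering such a rule through spark.sql.extensions; AddLimitExtensions is a hypothetical provider name and the injected rule is the AddLimit object defined above:

  import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, SparkSessionExtensionsProvider}

  class AddLimitExtensions extends SparkSessionExtensionsProvider {
    override def apply(extensions: SparkSessionExtensions): Unit = {
      extensions.injectRuntimeOptimizerRule(_ => AddLimit)
    }
  }

  val spark = SparkSession.builder()
    .master("local")
    .config("spark.sql.extensions", classOf[AddLimitExtensions].getName)
    .config("spark.sql.adaptive.enabled", "true")  // runtime optimizer rules only run under AQE
    .getOrCreate()
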
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala | |
index 9f8000a08f..c37309d97a 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala | |
@@ -27,8 +27,9 @@ import scala.collection.mutable | |
import org.apache.spark.sql.catalyst.TableIdentifier | |
import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat | |
import org.apache.spark.sql.catalyst.plans.logical._ | |
-import org.apache.spark.sql.catalyst.util.DateTimeTestUtils | |
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneUTC | |
+import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} | |
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, PST, UTC} | |
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, TimeZoneUTC} | |
import org.apache.spark.sql.functions.timestamp_seconds | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
@@ -406,9 +407,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared | |
withTable("TBL1", "TBL") { | |
import org.apache.spark.sql.functions._ | |
val df = spark.range(1000L).select('id, | |
- 'id * 2 as "FLD1", | |
- 'id * 12 as "FLD2", | |
- lit("aaa") + 'id as "fld3") | |
+ Symbol("id") * 2 as "FLD1", | |
+ Symbol("id") * 12 as "FLD2", | |
+ lit(null).cast(DoubleType) + Symbol("id") as "fld3") | |
df.write | |
.mode(SaveMode.Overwrite) | |
.bucketBy(10, "id", "FLD1", "FLD2") | |
@@ -424,7 +425,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared | |
|WHERE t1.fld3 IN (-123.23,321.23) | |
""".stripMargin) | |
df2.createTempView("TBL2") | |
- sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan | |
+ sql("SELECT * FROM tbl2 WHERE fld3 IN (0,1) ").queryExecution.executedPlan | |
} | |
} | |
} | |
@@ -470,7 +471,89 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared | |
} | |
} | |
- def getStatAttrNames(tableName: String): Set[String] = { | |
+ private def checkDescTimestampColStats( | |
+ tableName: String, | |
+ timestampColumn: String, | |
+ expectedMinTimestamp: String, | |
+ expectedMaxTimestamp: String): Unit = { | |
+ | |
+ def extractColumnStatsFromDesc(statsName: String, rows: Array[Row]): String = { | |
+ rows.collect { | |
+ case r: Row if r.getString(0) == statsName => | |
+ r.getString(1) | |
+ }.head | |
+ } | |
+ | |
+ val descTsCol = sql(s"DESC FORMATTED $tableName $timestampColumn").collect() | |
+ assert(extractColumnStatsFromDesc("min", descTsCol) == expectedMinTimestamp) | |
+ assert(extractColumnStatsFromDesc("max", descTsCol) == expectedMaxTimestamp) | |
+ } | |
+ | |
+ test("SPARK-38140: describe column stats (min, max) for timestamp column: desc results should " + | |
+ "be consistent with the written value if writing and desc happen in the same time zone") { | |
+ | |
+ val zoneIdAndOffsets = | |
+ Seq((UTC, "+0000"), (PST, "-0800"), (getZoneId("Asia/Hong_Kong"), "+0800")) | |
+ | |
+ zoneIdAndOffsets.foreach { case (zoneId, offset) => | |
+ withDefaultTimeZone(zoneId) { | |
+ val table = "insert_desc_same_time_zone" | |
+ val tsCol = "timestamp_typed_col" | |
+ withTable(table) { | |
+ val minTimestamp = "make_timestamp(2022, 1, 1, 0, 0, 1.123456)" | |
+ val maxTimestamp = "make_timestamp(2022, 1, 3, 0, 0, 2.987654)" | |
+ sql(s"CREATE TABLE $table ($tsCol Timestamp) USING parquet") | |
+ sql(s"INSERT INTO $table VALUES $minTimestamp, $maxTimestamp") | |
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR ALL COLUMNS") | |
+ | |
+ checkDescTimestampColStats( | |
+ tableName = table, | |
+ timestampColumn = tsCol, | |
+ expectedMinTimestamp = "2022-01-01 00:00:01.123456 " + offset, | |
+ expectedMaxTimestamp = "2022-01-03 00:00:02.987654 " + offset) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38140: describe column stats (min, max) for timestamp column: desc should show " + | |
+ "different results if writing in UTC and desc in other time zones") { | |
+ | |
+ val table = "insert_desc_diff_time_zones" | |
+ val tsCol = "timestamp_typed_col" | |
+ | |
+ withDefaultTimeZone(UTC) { | |
+ withTable(table) { | |
+ val minTimestamp = "make_timestamp(2022, 1, 1, 0, 0, 1.123456)" | |
+ val maxTimestamp = "make_timestamp(2022, 1, 3, 0, 0, 2.987654)" | |
+ sql(s"CREATE TABLE $table ($tsCol Timestamp) USING parquet") | |
+ sql(s"INSERT INTO $table VALUES $minTimestamp, $maxTimestamp") | |
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR ALL COLUMNS") | |
+ | |
+ checkDescTimestampColStats( | |
+ tableName = table, | |
+ timestampColumn = tsCol, | |
+ expectedMinTimestamp = "2022-01-01 00:00:01.123456 +0000", | |
+ expectedMaxTimestamp = "2022-01-03 00:00:02.987654 +0000") | |
+ | |
+ TimeZone.setDefault(DateTimeUtils.getTimeZone("PST")) | |
+ checkDescTimestampColStats( | |
+ tableName = table, | |
+ timestampColumn = tsCol, | |
+ expectedMinTimestamp = "2021-12-31 16:00:01.123456 -0800", | |
+ expectedMaxTimestamp = "2022-01-02 16:00:02.987654 -0800") | |
+ | |
+ TimeZone.setDefault(DateTimeUtils.getTimeZone("Asia/Hong_Kong")) | |
+ checkDescTimestampColStats( | |
+ tableName = table, | |
+ timestampColumn = tsCol, | |
+ expectedMinTimestamp = "2022-01-01 08:00:01.123456 +0800", | |
+ expectedMaxTimestamp = "2022-01-03 08:00:02.987654 +0800") | |
+ } | |
+ } | |
+ } | |
+ | |
+ private def getStatAttrNames(tableName: String): Set[String] = { | |
val queryStats = spark.table(tableName).queryExecution.optimizedPlan.stats.attributeStats | |
queryStats.map(_._1.name).toSet | |
} | |
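
The SPARK-38140 tests above check how DESC FORMATTED renders min/max statistics for a timestamp column. A small end-to-end sketch of the same flow; the table name is illustrative:

  spark.sql("CREATE TABLE ts_stats_demo (ts TIMESTAMP) USING parquet")
  spark.sql("INSERT INTO ts_stats_demo VALUES (make_timestamp(2022, 1, 1, 0, 0, 1.123456))")
  spark.sql("ANALYZE TABLE ts_stats_demo COMPUTE STATISTICS FOR ALL COLUMNS")
  // min/max are formatted in the current session time zone with an offset suffix,
  // e.g. "2022-01-01 00:00:01.123456 +0000" when the default zone is UTC
  spark.sql("DESC FORMATTED ts_stats_demo ts").show(truncate = false)
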
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala | |
index dd8a1a8478..2f118f236e 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala | |
@@ -112,9 +112,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { | |
val df = Seq[(String, String, String, Int)](("hello", "world", null, 15)) | |
.toDF("a", "b", "c", "d") | |
- checkAnswer( | |
- df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"), | |
- Row(null, "hello", null)) | |
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { | |
+ checkAnswer( | |
+ df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"), | |
+ Row(null, "hello", null)) | |
+ } | |
// check implicit type cast | |
checkAnswer( | |
@@ -383,9 +385,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { | |
Row("host", "/file;param", "query;p2", null, "http", "/file;param?query;p2", | |
"user:pass@host", "user:pass", null)) | |
- testUrl( | |
- "inva lid://user:pass@host/file;param?query;p2", | |
- Row(null, null, null, null, null, null, null, null, null)) | |
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { | |
+ testUrl( | |
+ "inva lid://user:pass@host/file;param?query;p2", | |
+ Row(null, null, null, null, null, null, null, null, null)) | |
+ } | |
} | |
@@ -486,6 +490,58 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { | |
) | |
} | |
+ test("SPARK-36751: add octet length api for scala") { | |
+ // scalastyle:off | |
+ // non-ASCII characters are not allowed in the code, so we disable the scalastyle check here.
+ val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08")) | |
+ .toDF("a", "b", "c", "d", "e", "f") | |
+ // string and binary input | |
+ checkAnswer( | |
+ df.select(octet_length($"a"), octet_length($"b")), | |
+ Row(3, 4)) | |
+ // string and binary input | |
+ checkAnswer( | |
+ df.selectExpr("octet_length(a)", "octet_length(b)"), | |
+ Row(3, 4)) | |
+ // integer, float and double input | |
+ checkAnswer( | |
+ df.selectExpr("octet_length(c)", "octet_length(d)", "octet_length(e)"), | |
+ Row(3, 3, 5) | |
+ ) | |
+ // multi-byte character input | |
+ checkAnswer( | |
+ df.selectExpr("octet_length(f)"), | |
+ Row(4) | |
+ ) | |
+ // scalastyle:on | |
+ } | |
+ | |
+ test("SPARK-36751: add bit length api for scala") { | |
+ // scalastyle:off | |
+ // non-ASCII characters are not allowed in the code, so we disable the scalastyle check here.
+ val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015, "\ud83d\udc08")) | |
+ .toDF("a", "b", "c", "d", "e", "f") | |
+ // string and binary input | |
+ checkAnswer( | |
+ df.select(bit_length($"a"), bit_length($"b")), | |
+ Row(24, 32)) | |
+ // string and binary input | |
+ checkAnswer( | |
+ df.selectExpr("bit_length(a)", "bit_length(b)"), | |
+ Row(24, 32)) | |
+ // integer, float and double input | |
+ checkAnswer( | |
+ df.selectExpr("bit_length(c)", "bit_length(d)", "bit_length(e)"), | |
+ Row(24, 24, 40) | |
+ ) | |
+ // multi-byte character input | |
+ checkAnswer( | |
+ df.selectExpr("bit_length(f)"), | |
+ Row(32) | |
+ ) | |
+ // scalastyle:on | |
+ } | |
+ | |
test("initcap function") { | |
val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z") | |
checkAnswer( | |
@@ -598,4 +654,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { | |
) | |
} | |
+ | |
+ test("SPARK-36148: check input data types of regexp_replace") { | |
+ val m = intercept[AnalysisException] { | |
+ sql("select regexp_replace(collect_list(1), '1', '2')") | |
+ }.getMessage | |
+ assert(m.contains("data type mismatch: argument 1 requires string type")) | |
+ } | |
} | |
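
The SPARK-36751 tests above cover the new Scala wrappers for octet_length and bit_length. A short usage sketch, assuming spark.implicits._ is in scope:

  import org.apache.spark.sql.functions.{bit_length, octet_length}

  val demo = Seq(("spark", Array[Byte](1, 2, 3))).toDF("s", "b")
  demo.select(octet_length($"s"), bit_length($"s"), octet_length($"b"), bit_length($"b")).show()
  // "spark" is 5 bytes / 40 bits; the 3-byte binary column is 3 bytes / 24 bits
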
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala | |
index 277cb1bceb..b3aefac05b 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala | |
@@ -20,10 +20,11 @@ package org.apache.spark.sql | |
import scala.collection.mutable.ArrayBuffer | |
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression | |
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} | |
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project, Sort} | |
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, FileSourceScanExec, InputAdapter, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec} | |
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution} | |
import org.apache.spark.sql.execution.datasources.FileScanRDD | |
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec | |
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
@@ -145,12 +146,12 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
test("runtime error when the number of rows is greater than 1") { | |
- val error2 = intercept[RuntimeException] { | |
+ val e = intercept[IllegalStateException] { | |
sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect() | |
} | |
- assert(error2.getMessage.contains( | |
- "more than one row returned by a subquery used as an expression") | |
- ) | |
+ // TODO(SPARK-39167): Throw an exception w/ an error class for multiple rows from a subquery | |
+ assert(e.getMessage.contains( | |
+ "more than one row returned by a subquery used as an expression")) | |
} | |
test("uncorrelated scalar subquery on a DataFrame generated query") { | |
@@ -895,7 +896,8 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
withTempView("t") { | |
Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("t") | |
val e = intercept[AnalysisException](sql("SELECT (SELECT count(*) FROM t WHERE a = 1)")) | |
- assert(e.message.contains("cannot resolve 'a' given input columns: [t.i, t.j]")) | |
+ assert(e.getErrorClass == "MISSING_COLUMN") | |
+ assert(e.messageParameters.sameElements(Array("a", "t.i, t.j"))) | |
} | |
} | |
@@ -1407,7 +1409,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
test("Scalar subquery name should start with scalar-subquery#") { | |
val df = sql("SELECT a FROM l WHERE a = (SELECT max(c) FROM r WHERE c = 1)".stripMargin) | |
- var subqueryExecs: ArrayBuffer[SubqueryExec] = ArrayBuffer.empty | |
+ val subqueryExecs: ArrayBuffer[SubqueryExec] = ArrayBuffer.empty | |
df.queryExecution.executedPlan.transformAllExpressions { | |
case s @ ScalarSubquery(p: SubqueryExec, _) => | |
subqueryExecs += p | |
@@ -1878,6 +1880,30 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
} | |
+ test("SPARK-36280: Remove redundant aliases after RewritePredicateSubquery") { | |
+ withTable("t1", "t2") { | |
+ sql("CREATE TABLE t1 USING parquet AS SELECT id AS a, id AS b, id AS c FROM range(10)") | |
+ sql("CREATE TABLE t2 USING parquet AS SELECT id AS x, id AS y FROM range(8)") | |
+ val df = sql( | |
+ """ | |
+ |SELECT * | |
+ |FROM t1 | |
+ |WHERE a IN (SELECT x | |
+ | FROM (SELECT x AS x, | |
+ | RANK() OVER (PARTITION BY x ORDER BY SUM(y) DESC) AS ranking | |
+ | FROM t2 | |
+ | GROUP BY x) tmp1 | |
+ | WHERE ranking <= 5) | |
+ |""".stripMargin) | |
+ | |
+ df.collect() | |
+ val exchanges = collect(df.queryExecution.executedPlan) { | |
+ case s: ShuffleExchangeExec => s | |
+ } | |
+ assert(exchanges.size === 1) | |
+ } | |
+ } | |
+ | |
test("SPARK-36747: should not combine Project with Aggregate") { | |
withTempView("t") { | |
Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") | |
@@ -1896,6 +1922,77 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
} | |
} | |
+ test("SPARK-36656: Do not collapse projects with correlate scalar subqueries") { | |
+ withTempView("t1", "t2") { | |
+ Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") | |
+ Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2") | |
+ val correctAnswer = Row(0, 2, 20) :: Row(1, null, null) :: Nil | |
+ checkAnswer( | |
+ sql( | |
+ """ | |
+ |SELECT c1, s, s * 10 FROM ( | |
+ | SELECT c1, (SELECT FIRST(c2) FROM t2 WHERE t1.c1 = t2.c1) s FROM t1) | |
+ |""".stripMargin), | |
+ correctAnswer) | |
+ checkAnswer( | |
+ sql( | |
+ """ | |
+ |SELECT c1, s, s * 10 FROM ( | |
+ | SELECT c1, SUM((SELECT FIRST(c2) FROM t2 WHERE t1.c1 = t2.c1)) s | |
+ | FROM t1 GROUP BY c1 | |
+ |) | |
+ |""".stripMargin), | |
+ correctAnswer) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37199: deterministic in QueryPlan considers subquery") { | |
+ val deterministicQueryPlan = sql("select (select 1 as b) as b") | |
+ .queryExecution.executedPlan | |
+ assert(deterministicQueryPlan.deterministic) | |
+ | |
+ val nonDeterministicQueryPlan = sql("select (select rand(1) as b) as b") | |
+ .queryExecution.executedPlan | |
+ assert(!nonDeterministicQueryPlan.deterministic) | |
+ } | |
+ | |
+ test("SPARK-38132: Not IN subquery correctness checks") { | |
+ val t = "test_table" | |
+ withTable(t) { | |
+ Seq[(Integer, Integer)]( | |
+ (1, 1), | |
+ (2, 2), | |
+ (3, 3), | |
+ (4, null), | |
+ (null, 0)) | |
+ .toDF("c1", "c2").write.saveAsTable(t) | |
+ val df = spark.table(t) | |
+ | |
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) = true"), Seq.empty) | |
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) = true"), | |
+ Row(4, null) :: Nil) | |
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) <=> true"), Seq.empty) | |
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) <=> true"), | |
+ Row(4, null) :: Nil) | |
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t)) != false"), Seq.empty) | |
+ checkAnswer(df.where(s"(c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) != false"), | |
+ Row(4, null) :: Nil) | |
+ checkAnswer(df.where(s"NOT((c1 NOT IN (SELECT c2 FROM $t)) <=> false)"), Seq.empty) | |
+ checkAnswer(df.where(s"NOT((c1 NOT IN (SELECT c2 FROM $t WHERE c2 IS NOT NULL)) <=> false)"), | |
+ Row(4, null) :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38155: disallow distinct aggregate in lateral subqueries") { | |
+ withTempView("t1", "t2") { | |
+ Seq((0, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") | |
+ Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") | |
+ assert(intercept[AnalysisException] { | |
+ sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)") | |
+ }.getMessage.contains("Correlated column is not allowed in predicate")) | |
+ } | |
+ } | |
+ | |
test("SPARK-38180: allow safe cast expressions in correlated equality conditions") { | |
withTempView("t1", "t2") { | |
Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") | |
@@ -1921,4 +2018,302 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
}.getMessage.contains("Correlated column is not allowed in predicate")) | |
} | |
} | |
+ | |
+ test("Merge non-correlated scalar subqueries") { | |
+ Seq(false, true).foreach { enableAQE => | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { | |
+ val df = sql( | |
+ """ | |
+ |SELECT | |
+ | (SELECT avg(key) FROM testData), | |
+ | (SELECT sum(key) FROM testData), | |
+ | (SELECT count(distinct key) FROM testData) | |
+ """.stripMargin) | |
+ | |
+ checkAnswer(df, Row(50.5, 5050, 100) :: Nil) | |
+ | |
+ val plan = df.queryExecution.executedPlan | |
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } | |
+ val reusedSubqueryIds = collectWithSubqueries(plan) { | |
+ case rs: ReusedSubqueryExec => rs.child.id | |
+ } | |
+ | |
+ assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") | |
+ assert(reusedSubqueryIds.size == 2, | |
+ "Missing or unexpected reused ReusedSubqueryExec in the plan") | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Merge non-correlated scalar subqueries in a subquery") { | |
+ Seq(false, true).foreach { enableAQE => | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { | |
+ val df = sql( | |
+ """ | |
+ |SELECT ( | |
+ | SELECT | |
+ | SUM( | |
+ | (SELECT avg(key) FROM testData) + | |
+ | (SELECT sum(key) FROM testData) + | |
+ | (SELECT count(distinct key) FROM testData)) | |
+ | FROM testData | |
+ |) | |
+ """.stripMargin) | |
+ | |
+ checkAnswer(df, Row(520050.0) :: Nil) | |
+ | |
+ val plan = df.queryExecution.executedPlan | |
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } | |
+ val reusedSubqueryIds = collectWithSubqueries(plan) { | |
+ case rs: ReusedSubqueryExec => rs.child.id | |
+ } | |
+ | |
+ if (enableAQE) { | |
+ assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan") | |
+ assert(reusedSubqueryIds.size == 4, | |
+ "Missing or unexpected reused ReusedSubqueryExec in the plan") | |
+ } else { | |
+ assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") | |
+ assert(reusedSubqueryIds.size == 5, | |
+ "Missing or unexpected reused ReusedSubqueryExec in the plan") | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Merge non-correlated scalar subqueries from different levels") { | |
+ Seq(false, true).foreach { enableAQE => | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { | |
+ val df = sql( | |
+ """ | |
+ |SELECT | |
+ | (SELECT avg(key) FROM testData), | |
+ | ( | |
+ | SELECT | |
+ | SUM( | |
+ | (SELECT sum(key) FROM testData) | |
+ | ) | |
+ | FROM testData | |
+ | ) | |
+ """.stripMargin) | |
+ | |
+ checkAnswer(df, Row(50.5, 505000) :: Nil) | |
+ | |
+ val plan = df.queryExecution.executedPlan | |
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } | |
+ val reusedSubqueryIds = collectWithSubqueries(plan) { | |
+ case rs: ReusedSubqueryExec => rs.child.id | |
+ } | |
+ | |
+ assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") | |
+ assert(reusedSubqueryIds.size == 2, | |
+ "Missing or unexpected reused ReusedSubqueryExec in the plan") | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Merge non-correlated scalar subqueries from different parent plans") { | |
+ Seq(false, true).foreach { enableAQE => | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { | |
+ val df = sql( | |
+ """ | |
+ |SELECT | |
+ | ( | |
+ | SELECT | |
+ | SUM( | |
+ | (SELECT avg(key) FROM testData) | |
+ | ) | |
+ | FROM testData | |
+ | ), | |
+ | ( | |
+ | SELECT | |
+ | SUM( | |
+ | (SELECT sum(key) FROM testData) | |
+ | ) | |
+ | FROM testData | |
+ | ) | |
+ """.stripMargin) | |
+ | |
+ checkAnswer(df, Row(5050.0, 505000) :: Nil) | |
+ | |
+ val plan = df.queryExecution.executedPlan | |
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } | |
+ val reusedSubqueryIds = collectWithSubqueries(plan) { | |
+ case rs: ReusedSubqueryExec => rs.child.id | |
+ } | |
+ | |
+ if (enableAQE) { | |
+ assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan") | |
+ assert(reusedSubqueryIds.size == 3, | |
+ "Missing or unexpected reused ReusedSubqueryExec in the plan") | |
+ } else { | |
+ assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") | |
+ assert(reusedSubqueryIds.size == 4, | |
+ "Missing or unexpected reused ReusedSubqueryExec in the plan") | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Merge non-correlated scalar subqueries with conflicting names") { | |
+ Seq(false, true).foreach { enableAQE => | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { | |
+ val df = sql( | |
+ """ | |
+ |SELECT | |
+ | (SELECT avg(key) AS key FROM testData), | |
+ | (SELECT sum(key) AS key FROM testData), | |
+ | (SELECT count(distinct key) AS key FROM testData) | |
+ """.stripMargin) | |
+ | |
+ checkAnswer(df, Row(50.5, 5050, 100) :: Nil) | |
+ | |
+ val plan = df.queryExecution.executedPlan | |
+ val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } | |
+ val reusedSubqueryIds = collectWithSubqueries(plan) { | |
+ case rs: ReusedSubqueryExec => rs.child.id | |
+ } | |
+ | |
+ assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") | |
+ assert(reusedSubqueryIds.size == 2, | |
+ "Missing or unexpected reused ReusedSubqueryExec in the plan") | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39355: Single column uses quoted to construct UnresolvedAttribute") { | |
+ checkAnswer( | |
+ sql(""" | |
+ |SELECT * | |
+ |FROM ( | |
+ | SELECT '2022-06-01' AS c1 | |
+ |) a | |
+ |WHERE c1 IN ( | |
+ | SELECT date_add('2022-06-01', 0) | |
+ |) | |
+ |""".stripMargin), | |
+ Row("2022-06-01")) | |
+ checkAnswer( | |
+ sql(""" | |
+ |SELECT * | |
+ |FROM ( | |
+ | SELECT '2022-06-01' AS c1 | |
+ |) a | |
+ |WHERE c1 IN ( | |
+ | SELECT date_add(a.c1.k1, 0) | |
+ | FROM ( | |
+ | SELECT named_struct('k1', '2022-06-01') AS c1 | |
+ | ) a | |
+ |) | |
+ |""".stripMargin), | |
+ Row("2022-06-01")) | |
+ } | |
+ | |
+ test("SPARK-39672: Fix removing project before filter with correlated subquery") { | |
+ withTempView("v1", "v2") { | |
+ Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c").createTempView("v1") | |
+ Seq((1, 3, 5), (4, 5, 6)).toDF("a", "b", "c").createTempView("v2") | |
+ | |
+ def findProject(df: DataFrame): Seq[Project] = { | |
+ df.queryExecution.optimizedPlan.collect { | |
+ case p: Project => p | |
+ } | |
+ } | |
+ | |
+ // project before filter cannot be removed since subquery has conflicting attributes | |
+ // with outer reference | |
+ val df1 = sql( | |
+ """ | |
+ |select * from | |
+ |( | |
+ |select | |
+ |v1.a, | |
+ |v1.b, | |
+ |v2.c | |
+ |from v1 | |
+ |inner join v2 | |
+ |on v1.a=v2.a) t3 | |
+ |where not exists ( | |
+ | select 1 | |
+ | from v2 | |
+ | where t3.a=v2.a and t3.b=v2.b and t3.c=v2.c | |
+ |) | |
+ |""".stripMargin) | |
+ checkAnswer(df1, Row(1, 2, 5)) | |
+ assert(findProject(df1).size == 4) | |
+ | |
+ // project before filter can be removed when there are no conflicting attributes | |
+ val df2 = sql( | |
+ """ | |
+ |select * from | |
+ |( | |
+ |select | |
+ |v1.b, | |
+ |v2.c | |
+ |from v1 | |
+ |inner join v2 | |
+ |on v1.b=v2.c) t3 | |
+ |where not exists ( | |
+ | select 1 | |
+ | from v2 | |
+ | where t3.b=v2.b and t3.c=v2.c | |
+ |) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer(df2, Row(5, 5)) | |
+ assert(findProject(df2).size == 3) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-42346: Rewrite distinct aggregates after merging subqueries") { | |
+ withTempView("t1") { | |
+ Seq((1, 2), (3, 4)).toDF("c1", "c2").createOrReplaceTempView("t1") | |
+ | |
+ checkAnswer(sql( | |
+ """ | |
+ |SELECT | |
+ | (SELECT count(distinct c1) FROM t1), | |
+ | (SELECT count(distinct c2) FROM t1) | |
+ |""".stripMargin), | |
+ Row(2, 2)) | |
+ | |
+ // In this case the subqueries are not merged: `RewriteDistinctAggregates` kicks in for the
+ // two subqueries first, and `MergeScalarSubqueries` is not prepared for the `Expand` nodes
+ // that the rewrite inserts.
+ checkAnswer(sql( | |
+ """ | |
+ |SELECT | |
+ | (SELECT count(distinct c1) + sum(distinct c2) FROM t1), | |
+ | (SELECT count(distinct c2) + sum(distinct c1) FROM t1) | |
+ |""".stripMargin), | |
+ Row(8, 6)) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-42937: Outer join with subquery in condition") { | |
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", | |
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { | |
+ withTempView("t2") { | |
+ // this is the same as the view t created in beforeAll, but that gets dropped by | |
+ // one of the tests above | |
+ r.filter($"c".isNotNull && $"d".isNotNull).createOrReplaceTempView("t2") | |
+ val expected = Row(1, 2.0d, null, null) :: Row(1, 2.0d, null, null) :: | |
+ Row(3, 3.0d, 3, 2.0d) :: Row(null, 5.0d, null, null) :: Nil | |
+ checkAnswer(sql( | |
+ """ | |
+ |select * | |
+ |from l | |
+ |left outer join r | |
+ |on a = c | |
+ |and a in (select c from t2 where d in (1.0, 2.0)) | |
+ |where b > 1.0""".stripMargin), | |
+ expected) | |
+ } | |
+ } | |
+ } | |
} | |
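
The merge tests above assert on SubqueryExec and ReusedSubqueryExec counts. A compact sketch of the kind of query they target; `tbl` is an arbitrary example table with an integer `key` column:

  val merged = spark.sql(
    """SELECT
      |  (SELECT avg(key) FROM tbl),
      |  (SELECT sum(key) FROM tbl)
      |""".stripMargin)
  // With MergeScalarSubqueries, the physical plan is expected to contain a single SubqueryExec
  // that computes both aggregates and a ReusedSubqueryExec for the second occurrence.
  merged.explain()
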
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCBase.scala | |
new file mode 100644 | |
index 0000000000..1764584922 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCBase.scala | |
@@ -0,0 +1,53 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql | |
+ | |
+import org.apache.spark.SparkConf | |
+import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+ | |
+trait TPCBase extends SharedSparkSession { | |
+ | |
+ protected def injectStats: Boolean = false | |
+ | |
+ override protected def sparkConf: SparkConf = { | |
+ if (injectStats) { | |
+ super.sparkConf | |
+ .set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue) | |
+ .set(SQLConf.CBO_ENABLED, true) | |
+ .set(SQLConf.PLAN_STATS_ENABLED, true) | |
+ .set(SQLConf.JOIN_REORDER_ENABLED, true) | |
+ } else { | |
+ super.sparkConf.set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue) | |
+ } | |
+ } | |
+ | |
+ override def beforeAll(): Unit = { | |
+ super.beforeAll() | |
+ createTables() | |
+ } | |
+ | |
+ override def afterAll(): Unit = { | |
+ dropTables() | |
+ super.afterAll() | |
+ } | |
+ | |
+ protected def createTables(): Unit | |
+ | |
+ protected def dropTables(): Unit | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSBase.scala | |
index 20cfcecc22..39587ce063 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSBase.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSBase.scala | |
@@ -18,10 +18,8 @@ | |
package org.apache.spark.sql | |
import org.apache.spark.sql.catalyst.TableIdentifier | |
-import org.apache.spark.sql.internal.SQLConf | |
-import org.apache.spark.sql.test.SharedSparkSession | |
-trait TPCDSBase extends SharedSparkSession with TPCDSSchema { | |
+trait TPCDSBase extends TPCBase with TPCDSSchema { | |
// The TPCDS queries below are based on v1.4 | |
private val tpcdsAllQueries: Seq[String] = Seq( | |
@@ -79,19 +77,7 @@ trait TPCDSBase extends SharedSparkSession with TPCDSSchema { | |
""".stripMargin) | |
} | |
- private val originalCBCEnabled = conf.cboEnabled | |
- private val originalPlanStatsEnabled = conf.planStatsEnabled | |
- private val originalJoinReorderEnabled = conf.joinReorderEnabled | |
- | |
- override def beforeAll(): Unit = { | |
- super.beforeAll() | |
- if (injectStats) { | |
- // Sets configurations for enabling the optimization rules that | |
- // exploit data statistics. | |
- conf.setConf(SQLConf.CBO_ENABLED, true) | |
- conf.setConf(SQLConf.PLAN_STATS_ENABLED, true) | |
- conf.setConf(SQLConf.JOIN_REORDER_ENABLED, true) | |
- } | |
+ override def createTables(): Unit = { | |
tableNames.foreach { tableName => | |
createTable(spark, tableName) | |
if (injectStats) { | |
@@ -102,15 +88,9 @@ trait TPCDSBase extends SharedSparkSession with TPCDSSchema { | |
} | |
} | |
- override def afterAll(): Unit = { | |
- conf.setConf(SQLConf.CBO_ENABLED, originalCBCEnabled) | |
- conf.setConf(SQLConf.PLAN_STATS_ENABLED, originalPlanStatsEnabled) | |
- conf.setConf(SQLConf.JOIN_REORDER_ENABLED, originalJoinReorderEnabled) | |
+ override def dropTables(): Unit = { | |
tableNames.foreach { tableName => | |
spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true) | |
} | |
- super.afterAll() | |
} | |
- | |
- protected def injectStats: Boolean = false | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala | |
index 22e1b838f3..8c4d25a7eb 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala | |
@@ -29,7 +29,8 @@ import org.apache.spark.tags.ExtendedSQLTest | |
@ExtendedSQLTest | |
class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSBase { | |
- tpcdsQueries.foreach { name => | |
+ // q72 is skipped due to GitHub Actions' memory limit. | |
+ tpcdsQueries.filterNot(sys.env.contains("GITHUB_ACTIONS") && _ == "q72").foreach { name => | |
val queryString = resourceToString(s"tpcds/$name.sql", | |
classLoader = Thread.currentThread().getContextClassLoader) | |
test(name) { | |
@@ -39,7 +40,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSBase { | |
} | |
} | |
- tpcdsQueriesV2_7_0.foreach { name => | |
+ // q72 is skipped due to GitHub Actions' memory limit. | |
+ tpcdsQueriesV2_7_0.filterNot(sys.env.contains("GITHUB_ACTIONS") && _ == "q72").foreach { name => | |
val queryString = resourceToString(s"tpcds-v2.7.0/$name.sql", | |
classLoader = Thread.currentThread().getContextClassLoader) | |
test(s"$name-v2.7") { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala | |
index 952e896802..8019fc98a5 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala | |
@@ -20,10 +20,13 @@ package org.apache.spark.sql | |
import java.io.File | |
import java.nio.file.{Files, Paths} | |
+import scala.collection.JavaConverters._ | |
+ | |
import org.apache.spark.{SparkConf, SparkContext} | |
import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile} | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.TestSparkSession | |
+import org.apache.spark.tags.ExtendedSQLTest | |
/** | |
* End-to-end tests to check TPCDS query results. | |
@@ -51,13 +54,14 @@ import org.apache.spark.sql.test.TestSparkSession | |
* build/sbt "sql/testOnly *TPCDSQueryTestSuite -- -z q79" | |
* }}} | |
*/ | |
+@ExtendedSQLTest | |
class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelper { | |
private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA") | |
private val regenerateGoldenFiles = sys.env.get("SPARK_GENERATE_GOLDEN_FILES").exists(_ == "1") | |
// To make output results deterministic | |
- protected override def sparkConf: SparkConf = super.sparkConf | |
+ override protected def sparkConf: SparkConf = super.sparkConf | |
.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") | |
protected override def createSparkSession: TestSparkSession = { | |
@@ -97,50 +101,111 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp | |
""".stripMargin) | |
} | |
- private def runQuery(query: String, goldenFile: File): Unit = { | |
- val (schema, output) = handleExceptions(getNormalizedResult(spark, query)) | |
- val queryString = query.trim | |
- val outputString = output.mkString("\n").replaceAll("\\s+$", "") | |
- if (regenerateGoldenFiles) { | |
- val goldenOutput = { | |
- s"-- Automatically generated by ${getClass.getSimpleName}\n\n" + | |
- s"-- !query schema\n" + | |
- schema + "\n" + | |
- s"-- !query output\n" + | |
- outputString + | |
- "\n" | |
- } | |
- val parent = goldenFile.getParentFile | |
- if (!parent.exists()) { | |
- assert(parent.mkdirs(), "Could not create directory: " + parent) | |
+ private def runQuery( | |
+ query: String, | |
+ goldenFile: File, | |
+ conf: Map[String, String]): Unit = { | |
+ val shouldSortResults = sortMergeJoinConf != conf // Sort for other joins | |
+ withSQLConf(conf.toSeq: _*) { | |
+ try { | |
+ val (schema, output) = handleExceptions(getNormalizedResult(spark, query)) | |
+ val queryString = query.trim | |
+ val outputString = output.mkString("\n").replaceAll("\\s+$", "") | |
+ if (regenerateGoldenFiles) { | |
+ val goldenOutput = { | |
+ s"-- Automatically generated by ${getClass.getSimpleName}\n\n" + | |
+ s"-- !query schema\n" + | |
+ schema + "\n" + | |
+ s"-- !query output\n" + | |
+ outputString + | |
+ "\n" | |
+ } | |
+ val parent = goldenFile.getParentFile | |
+ if (!parent.exists()) { | |
+ assert(parent.mkdirs(), "Could not create directory: " + parent) | |
+ } | |
+ stringToFile(goldenFile, goldenOutput) | |
+ } | |
+ | |
+ // Read back the golden file. | |
+ val (expectedSchema, expectedOutput) = { | |
+ val goldenOutput = fileToString(goldenFile) | |
+ val segments = goldenOutput.split("-- !query.*\n") | |
+ | |
+ // query has 3 segments, plus the header | |
+ assert(segments.size == 3, | |
+ s"Expected 3 blocks in result file but got ${segments.size}. " + | |
+ "Try regenerate the result files.") | |
+ | |
+ (segments(1).trim, segments(2).replaceAll("\\s+$", "")) | |
+ } | |
+ | |
+ assertResult(expectedSchema, s"Schema did not match\n$queryString") { | |
+ schema | |
+ } | |
+ if (shouldSortResults) { | |
+ val expectSorted = expectedOutput.split("\n").sorted.map(_.trim) | |
+ .mkString("\n").replaceAll("\\s+$", "") | |
+ val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "") | |
+ assertResult(expectSorted, s"Result did not match\n$queryString") { | |
+ outputSorted | |
+ } | |
+ } else { | |
+ assertResult(expectedOutput, s"Result did not match\n$queryString") { | |
+ outputString | |
+ } | |
+ } | |
+ } catch { | |
+ case e: Throwable => | |
+ val configs = conf.map { | |
+ case (k, v) => s"$k=$v" | |
+ } | |
+ throw new Exception(s"${e.getMessage}\nError using configs:\n${configs.mkString("\n")}") | |
} | |
- stringToFile(goldenFile, goldenOutput) | |
} | |
+ } | |
- // Read back the golden file. | |
- val (expectedSchema, expectedOutput) = { | |
- val goldenOutput = fileToString(goldenFile) | |
- val segments = goldenOutput.split("-- !query.*\n") | |
+ val sortMergeJoinConf = Map( | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", | |
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "true") | |
- // query has 3 segments, plus the header | |
- assert(segments.size == 3, | |
- s"Expected 3 blocks in result file but got ${segments.size}. " + | |
- "Try regenerate the result files.") | |
+ val broadcastHashJoinConf = Map(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760") | |
- (segments(1).trim, segments(2).replaceAll("\\s+$", "")) | |
- } | |
+ val shuffledHashJoinConf = Map( | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", | |
+ "spark.sql.join.forceApplyShuffledHashJoin" -> "true") | |
- assertResult(expectedSchema, s"Schema did not match\n$queryString") { schema } | |
- assertResult(expectedOutput, s"Result did not match\n$queryString") { outputString } | |
+ val allJoinConfCombinations = Seq( | |
+ sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf) | |
+ | |
+ val joinConfs: Seq[Map[String, String]] = if (regenerateGoldenFiles) { | |
+ require( | |
+ !sys.env.contains("SPARK_TPCDS_JOIN_CONF"), | |
+ "'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'") | |
+ Seq(sortMergeJoinConf) | |
+ } else { | |
+ sys.env.get("SPARK_TPCDS_JOIN_CONF").map { s => | |
+ val p = new java.util.Properties() | |
+ p.load(new java.io.StringReader(s)) | |
+ Seq(p.asScala.toMap) | |
+ }.getOrElse(allJoinConfCombinations) | |
} | |
+ assert(joinConfs.nonEmpty) | |
+ joinConfs.foreach(conf => require( | |
+ allJoinConfCombinations.contains(conf), | |
+ s"Join configurations [$conf] should be one of $allJoinConfCombinations")) | |
+ | |
if (tpcdsDataPath.nonEmpty) { | |
tpcdsQueries.foreach { name => | |
val queryString = resourceToString(s"tpcds/$name.sql", | |
classLoader = Thread.currentThread().getContextClassLoader) | |
test(name) { | |
val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out") | |
- runQuery(queryString, goldenFile) | |
+ joinConfs.foreach { conf => | |
+ System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368 | |
+ runQuery(queryString, goldenFile, conf) | |
+ } | |
} | |
} | |
@@ -149,7 +214,10 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp | |
classLoader = Thread.currentThread().getContextClassLoader) | |
test(s"$name-v2.7") { | |
val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out") | |
- runQuery(queryString, goldenFile) | |
+ joinConfs.foreach { conf => | |
+ System.gc() // SPARK-37368 | |
+ runQuery(queryString, goldenFile, conf) | |
+ } | |
} | |
} | |
} else { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCHBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCHBase.scala | |
new file mode 100644 | |
index 0000000000..e7edd1090b | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCHBase.scala | |
@@ -0,0 +1,96 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql | |
+ | |
+import org.apache.spark.sql.catalyst.TableIdentifier | |
+ | |
+trait TPCHBase extends TPCBase { | |
+ | |
+ override def createTables(): Unit = { | |
+ tpchCreateTable.values.foreach { sql => | |
+ spark.sql(sql) | |
+ } | |
+ } | |
+ | |
+ override def dropTables(): Unit = { | |
+ tpchCreateTable.keys.foreach { tableName => | |
+ spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true) | |
+ } | |
+ } | |
+ | |
+ val tpchCreateTable = Map( | |
+ "orders" -> | |
+ """ | |
+ |CREATE TABLE `orders` ( | |
+ |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING, | |
+ |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING, | |
+ |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING) | |
+ |USING parquet | |
+ """.stripMargin, | |
+ "nation" -> | |
+ """ | |
+ |CREATE TABLE `nation` ( | |
+ |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING) | |
+ |USING parquet | |
+ """.stripMargin, | |
+ "region" -> | |
+ """ | |
+ |CREATE TABLE `region` ( | |
+ |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING) | |
+ |USING parquet | |
+ """.stripMargin, | |
+ "part" -> | |
+ """ | |
+ |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING, | |
+ |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING, | |
+ |`p_retailprice` DECIMAL(10,0), `p_comment` STRING) | |
+ |USING parquet | |
+ """.stripMargin, | |
+ "partsupp" -> | |
+ """ | |
+ |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT, | |
+ |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING) | |
+ |USING parquet | |
+ """.stripMargin, | |
+ "customer" -> | |
+ """ | |
+ |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, | |
+ |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), | |
+ |`c_mktsegment` STRING, `c_comment` STRING) | |
+ |USING parquet | |
+ """.stripMargin, | |
+ "supplier" -> | |
+ """ | |
+ |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING, | |
+ |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING) | |
+ |USING parquet | |
+ """.stripMargin, | |
+ "lineitem" -> | |
+ """ | |
+ |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT, | |
+ |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0), | |
+ |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING, | |
+ |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE, | |
+ |`l_shipinstruct` STRING, `l_shipmode` STRING, `l_comment` STRING) | |
+ |USING parquet | |
+ """.stripMargin) | |
+ | |
+ val tpchQueries = Seq( | |
+ "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", | |
+ "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", "q22") | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala | |
index ba99e18714..89dfcefb1b 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala | |
@@ -23,79 +23,7 @@ import org.apache.spark.sql.catalyst.util.resourceToString | |
* This test suite ensures all the TPC-H queries can be successfully analyzed, optimized | |
* and compiled without hitting the max iteration threshold. | |
*/ | |
-class TPCHQuerySuite extends BenchmarkQueryTest { | |
- | |
- override def beforeAll(): Unit = { | |
- super.beforeAll() | |
- | |
- sql( | |
- """ | |
- |CREATE TABLE `orders` ( | |
- |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING, | |
- |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING, | |
- |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING) | |
- |USING parquet | |
- """.stripMargin) | |
- | |
- sql( | |
- """ | |
- |CREATE TABLE `nation` ( | |
- |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING) | |
- |USING parquet | |
- """.stripMargin) | |
- | |
- sql( | |
- """ | |
- |CREATE TABLE `region` ( | |
- |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING) | |
- |USING parquet | |
- """.stripMargin) | |
- | |
- sql( | |
- """ | |
- |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING, | |
- |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING, | |
- |`p_retailprice` DECIMAL(10,0), `p_comment` STRING) | |
- |USING parquet | |
- """.stripMargin) | |
- | |
- sql( | |
- """ | |
- |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT, | |
- |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING) | |
- |USING parquet | |
- """.stripMargin) | |
- | |
- sql( | |
- """ | |
- |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, | |
- |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), | |
- |`c_mktsegment` STRING, `c_comment` STRING) | |
- |USING parquet | |
- """.stripMargin) | |
- | |
- sql( | |
- """ | |
- |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING, | |
- |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING) | |
- |USING parquet | |
- """.stripMargin) | |
- | |
- sql( | |
- """ | |
- |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT, | |
- |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0), | |
- |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING, | |
- |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE, | |
- |`l_shipinstruct` STRING, `l_shipmode` STRING, `l_comment` STRING) | |
- |USING parquet | |
- """.stripMargin) | |
- } | |
- | |
- val tpchQueries = Seq( | |
- "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", | |
- "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", "q22") | |
- | |
+class TPCHQuerySuite extends BenchmarkQueryTest with TPCHBase { | |
tpchQueries.foreach { name => | |
val queryString = resourceToString(s"tpch/$name.sql", | |
classLoader = Thread.currentThread().getContextClassLoader) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | |
index aad1060f11..912811bfda 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | |
@@ -424,7 +424,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { | |
("N", Integer.valueOf(3), null)).toDF("a", "b", "c") | |
val udf1 = udf((a: String, b: Int, c: Any) => a + b + c) | |
- val df = input.select(udf1('a, 'b, 'c)) | |
+ val df = input.select(udf1(Symbol("a"), 'b, 'c)) | |
checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) | |
// test Java UDF. Java UDF can't have primitive inputs, as it's generic typed. | |
@@ -554,7 +554,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { | |
spark.udf.register("buildLocalDateInstantType", | |
udf((d: LocalDate, i: Instant) => LocalDateInstantType(d, i))) | |
checkAnswer(df.selectExpr(s"buildLocalDateInstantType(d, i) as di") | |
- .select('di.cast(StringType)), | |
+ .select(Symbol("di").cast(StringType)), | |
Row(s"{$expectedDate, $expectedInstant}") :: Nil) | |
// test null cases | |
@@ -584,7 +584,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { | |
spark.udf.register("buildTimestampInstantType", | |
udf((t: Timestamp, i: Instant) => TimestampInstantType(t, i))) | |
checkAnswer(df.selectExpr("buildTimestampInstantType(t, i) as ti") | |
- .select('ti.cast(StringType)), | |
+ .select(Symbol("ti").cast(StringType)), | |
Row(s"{$expectedTimestamp, $expectedInstant}")) | |
// test null cases | |
@@ -730,7 +730,8 @@ class UDFSuite extends QueryTest with SharedSparkSession { | |
.select(lit(50).as("a")) | |
.select(struct("a").as("col")) | |
val error = intercept[AnalysisException](df.select(myUdf(Column("col")))) | |
- assert(error.getMessage.contains("cannot resolve 'b' given input columns: [a]")) | |
+ assert(error.getErrorClass == "MISSING_COLUMN") | |
+ assert(error.messageParameters.sameElements(Array("b", "a"))) | |
} | |
test("wrong order of input fields for case class") { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala | |
index e6f0426428..1d7af84ef6 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala | |
@@ -190,5 +190,55 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess | |
} | |
} | |
+ test("SPARK-36607: Support BooleanType in UnwrapCastInBinaryComparison") { | |
+ // If ANSI mode is on, Spark disallows comparing Int with Boolean. | |
+ if (!conf.ansiEnabled) { | |
+ withTable(t) { | |
+ Seq(Some(true), Some(false), None).toDF().write.saveAsTable(t) | |
+ val df = spark.table(t) | |
+ | |
+ checkAnswer(df.where("value = -1"), Seq.empty) | |
+ checkAnswer(df.where("value = 0"), Row(false)) | |
+ checkAnswer(df.where("value = 1"), Row(true)) | |
+ checkAnswer(df.where("value = 2"), Seq.empty) | |
+ checkAnswer(df.where("value <=> -1"), Seq.empty) | |
+ checkAnswer(df.where("value <=> 0"), Row(false)) | |
+ checkAnswer(df.where("value <=> 1"), Row(true)) | |
+ checkAnswer(df.where("value <=> 2"), Seq.empty) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39476: Should not unwrap cast from Long to Double/Float") { | |
+ withTable(t) { | |
+ Seq((6470759586864300301L)) | |
+ .toDF("c1").write.saveAsTable(t) | |
+ val df = spark.table(t) | |
+ | |
+ checkAnswer( | |
+ df.where("cast(c1 as double) == cast(6470759586864300301L as double)") | |
+ .select("c1"), | |
+ Row(6470759586864300301L)) | |
+ | |
+ checkAnswer( | |
+ df.where("cast(c1 as float) == cast(6470759586864300301L as float)") | |
+ .select("c1"), | |
+ Row(6470759586864300301L)) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39476: Should not unwrap cast from Integer to Float") { | |
+ withTable(t) { | |
+ Seq((33554435)) | |
+ .toDF("c1").write.saveAsTable(t) | |
+ val df = spark.table(t) | |
+ | |
+ checkAnswer( | |
+ df.where("cast(c1 as float) == cast(33554435 as float)") | |
+ .select("c1"), | |
+ Row(33554435)) | |
+ } | |
+ } | |
+ | |
private def decimal(v: BigDecimal): Decimal = Decimal(v, 5, 2) | |
} | |
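The SPARK-39476 tests above exist because a 64-bit integer has more precision than a double (53 significand bits) or a float (24), so comparing the casted column against a casted literal is not equivalent to an exact comparison on the original column. A small standalone illustration of the collision, using the same constants as the tests:

object CastPrecisionDemo extends App {
  val a = 6470759586864300301L
  // At this magnitude adjacent doubles are more than 1000 apart, so a and a + 1
  // round to the same Double; unwrapping the cast would change query results.
  assert(a.toDouble == (a + 1L).toDouble)

  // Floats are coarser still: 33554435 (2^25 + 3) rounds to 33554436.
  assert(33554435.toFloat == 33554436.toFloat)
  println("both pairs collide after the cast")
}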
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | |
index cc52b6d8a1..729312c3e5 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | |
@@ -82,14 +82,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque | |
} | |
test("register user type: MyDenseVector for MyLabeledPoint") { | |
- val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } | |
+ val labels: RDD[Double] = pointsRDD.select(Symbol("label")).rdd.map { case Row(v: Double) => v } | |
val labelsArrays: Array[Double] = labels.collect() | |
assert(labelsArrays.size === 2) | |
assert(labelsArrays.contains(1.0)) | |
assert(labelsArrays.contains(0.0)) | |
val features: RDD[TestUDT.MyDenseVector] = | |
- pointsRDD.select('features).rdd.map { case Row(v: TestUDT.MyDenseVector) => v } | |
+ pointsRDD.select(Symbol("features")).rdd.map { case Row(v: TestUDT.MyDenseVector) => v } | |
val featuresArrays: Array[TestUDT.MyDenseVector] = features.collect() | |
assert(featuresArrays.size === 2) | |
assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.1, 1.0)))) | |
@@ -137,8 +137,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque | |
val df = Seq((1, vec)).toDF("int", "vec") | |
assert(vec === df.collect()(0).getAs[TestUDT.MyDenseVector](1)) | |
assert(vec === df.take(1)(0).getAs[TestUDT.MyDenseVector](1)) | |
- checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) | |
- checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) | |
+ checkAnswer(df.limit(1).groupBy(Symbol("int")).agg(first(Symbol("vec"))), Row(1, vec)) | |
+ checkAnswer(df.orderBy(Symbol("int")).limit(1).groupBy(Symbol("int")) | |
+ .agg(first(Symbol("vec"))), Row(1, vec)) | |
} | |
test("UDTs with JSON") { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala | |
index 1b0898fbc1..19f3f86c94 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala | |
@@ -1070,7 +1070,7 @@ trait AlterTableTests extends SharedSparkSession { | |
} | |
} | |
- test("AlterTable: drop column must exist") { | |
+ test("AlterTable: drop column must exist if required") { | |
val t = s"${catalogAndNamespace}table_name" | |
withTable(t) { | |
sql(s"CREATE TABLE $t (id int) USING $v2Format") | |
@@ -1080,10 +1080,15 @@ trait AlterTableTests extends SharedSparkSession { | |
} | |
assert(exc.getMessage.contains("Missing field data")) | |
+ | |
+ // With IF EXISTS, dropping the missing column should succeed as a no-op. | |
+ sql(s"ALTER TABLE $t DROP COLUMN IF EXISTS data") | |
+ val table = getTableMetadata(fullTableName(t)) | |
+ assert(table.schema == new StructType().add("id", IntegerType)) | |
} | |
} | |
- test("AlterTable: nested drop column must exist") { | |
+ test("AlterTable: nested drop column must exist if required") { | |
val t = s"${catalogAndNamespace}table_name" | |
withTable(t) { | |
sql(s"CREATE TABLE $t (id int) USING $v2Format") | |
@@ -1093,6 +1098,27 @@ trait AlterTableTests extends SharedSparkSession { | |
} | |
assert(exc.getMessage.contains("Missing field point.x")) | |
+ | |
+ // With IF EXISTS, dropping the missing nested column should succeed as a no-op. | |
+ sql(s"ALTER TABLE $t DROP COLUMN IF EXISTS point.x") | |
+ val table = getTableMetadata(fullTableName(t)) | |
+ assert(table.schema == new StructType().add("id", IntegerType)) | |
+ | |
+ } | |
+ } | |
+ | |
+ test("AlterTable: drop mixed existing/non-existing columns using IF EXISTS") { | |
+ val t = s"${catalogAndNamespace}table_name" | |
+ withTable(t) { | |
+ sql(s"CREATE TABLE $t (id int, name string, points array<struct<x: double, y: double>>) " + | |
+ s"USING $v2Format") | |
+ | |
+ // With IF EXISTS, existing columns are dropped and missing names are ignored. | |
+ sql(s"ALTER TABLE $t DROP COLUMNS IF EXISTS " + | |
+ s"names, name, points.element.z, id, points.element.x") | |
+ val table = getTableMetadata(fullTableName(t)) | |
+ assert(table.schema == new StructType() | |
+ .add("points", ArrayType(StructType(Seq(StructField("y", DoubleType)))))) | |
} | |
} | |
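The new AlterTableTests cases above cover ALTER TABLE ... DROP COLUMN(S) IF EXISTS. A hedged sketch of the behavior at the SQL level; the catalog, namespace and provider names are illustrative, not taken from the suite:

// Sketch: IF EXISTS skips missing columns instead of failing analysis.
spark.sql("CREATE TABLE testcat.ns.events (id INT, payload STRING) USING foo")

// Fails: `no_such_col` does not exist.
// spark.sql("ALTER TABLE testcat.ns.events DROP COLUMN no_such_col")

// Succeeds: the missing name is ignored, `payload` is dropped.
spark.sql("ALTER TABLE testcat.ns.events DROP COLUMNS IF EXISTS payload, no_such_col")
spark.table("testcat.ns.events").printSchema()  // only `id` remains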
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala | |
index 91ac7db335..98d95e48f5 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala | |
@@ -17,8 +17,6 @@ | |
package org.apache.spark.sql.connector | |
-import java.util | |
- | |
import org.scalatest.BeforeAndAfter | |
import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} | |
@@ -83,10 +81,10 @@ class DataSourceV2DataFrameSessionCatalogSuite | |
test("saveAsTable passes path and provider information properly") { | |
val t1 = "prop_table" | |
withTable(t1) { | |
- spark.range(20).write.format(v2Format).option("path", "abc").saveAsTable(t1) | |
+ spark.range(20).write.format(v2Format).option("path", "/abc").saveAsTable(t1) | |
val cat = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] | |
val tableInfo = cat.loadTable(Identifier.of(Array("default"), t1)) | |
- assert(tableInfo.properties().get("location") === "abc") | |
+ assert(tableInfo.properties().get("location") === "file:/abc") | |
assert(tableInfo.properties().get("provider") === v2Format) | |
} | |
} | |
@@ -97,7 +95,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable | |
name: String, | |
schema: StructType, | |
partitions: Array[Transform], | |
- properties: util.Map[String, String]): InMemoryTable = { | |
+ properties: java.util.Map[String, String]): InMemoryTable = { | |
new InMemoryTable(name, schema, partitions, properties) | |
} | |
@@ -210,7 +208,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio | |
verifyTable(t1, df) | |
// Check that appends are by name | |
- df.select('data, 'id).write.format(v2Format).mode("append").saveAsTable(t1) | |
+ df.select(Symbol("data"), Symbol("id")).write.format(v2Format).mode("append").saveAsTable(t1) | |
verifyTable(t1, df.union(df)) | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala | |
index 951d787571..03dcfcf7dd 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala | |
@@ -21,7 +21,7 @@ import java.util.Collections | |
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} | |
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException | |
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan} | |
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect} | |
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog} | |
import org.apache.spark.sql.execution.QueryExecution | |
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation | |
@@ -93,7 +93,7 @@ class DataSourceV2DataFrameSuite | |
assert(spark.table(t1).count() === 0) | |
// appends are by name not by position | |
- df.select('data, 'id).write.mode("append").saveAsTable(t1) | |
+ df.select(Symbol("data"), Symbol("id")).write.mode("append").saveAsTable(t1) | |
checkAnswer(spark.table(t1), df) | |
} | |
} | |
@@ -207,4 +207,50 @@ class DataSourceV2DataFrameSuite | |
assert(options.get(optionName) === "false") | |
} | |
} | |
+ | |
+ test("CTAS and RTAS should take write options") { | |
+ | |
+ var plan: LogicalPlan = null | |
+ val listener = new QueryExecutionListener { | |
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { | |
+ plan = qe.analyzed | |
+ } | |
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} | |
+ } | |
+ | |
+ try { | |
+ spark.listenerManager.register(listener) | |
+ | |
+ val t1 = "testcat.ns1.ns2.tbl" | |
+ | |
+ val df1 = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") | |
+ df1.write.option("option1", "20").saveAsTable(t1) | |
+ | |
+ sparkContext.listenerBus.waitUntilEmpty() | |
+ plan match { | |
+ case o: CreateTableAsSelect => | |
+ assert(o.writeOptions == Map("option1" -> "20")) | |
+ case other => | |
+ fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query," + | |
+ s"got ${other.getClass.getName}: $plan") | |
+ } | |
+ checkAnswer(spark.table(t1), df1) | |
+ | |
+ val df2 = Seq((1L, "d"), (2L, "e"), (3L, "f")).toDF("id", "data") | |
+ df2.write.option("option2", "30").mode("overwrite").saveAsTable(t1) | |
+ | |
+ sparkContext.listenerBus.waitUntilEmpty() | |
+ plan match { | |
+ case o: ReplaceTableAsSelect => | |
+ assert(o.writeOptions == Map("option2" -> "30")) | |
+ case other => | |
+ fail(s"Expected to parse ${classOf[ReplaceTableAsSelect].getName} from query," + | |
+ s"got ${other.getClass.getName}: $plan") | |
+ } | |
+ | |
+ checkAnswer(spark.table(t1), df2) | |
+ } finally { | |
+ spark.listenerManager.unregister(listener) | |
+ } | |
+ } | |
} | |
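The CTAS/RTAS test above inspects the analyzed plan through a QueryExecutionListener. A trimmed-down sketch of that capture pattern, written for use inside a Spark test suite (listenerBus is internal API, reachable there but not from user code); the table name is illustrative:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

var captured: LogicalPlan = null
val listener = new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    captured = qe.analyzed
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
}

spark.listenerManager.register(listener)
try {
  spark.range(3).write.saveAsTable("listener_demo_tbl")
  // Listener callbacks are delivered asynchronously; flush them before asserting.
  spark.sparkContext.listenerBus.waitUntilEmpty()
  assert(captured != null)
} finally {
  spark.listenerManager.unregister(listener)
}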
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala | |
index ace66199f3..92a5c55210 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala | |
@@ -17,7 +17,6 @@ | |
package org.apache.spark.sql.connector | |
-import java.util | |
import java.util.Collections | |
import test.org.apache.spark.sql.connector.catalog.functions._ | |
@@ -37,7 +36,7 @@ import org.apache.spark.sql.types._ | |
import org.apache.spark.unsafe.types.UTF8String | |
class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { | |
- private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] | |
+ private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String] | |
private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = { | |
catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn) | |
@@ -53,10 +52,73 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { | |
withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { | |
assert(intercept[AnalysisException]( | |
sql("SELECT testcat.strlen('abc')").collect() | |
- ).getMessage.contains("is not a FunctionCatalog")) | |
+ ).getMessage.contains("Catalog testcat does not support functions")) | |
} | |
} | |
+ test("DESCRIBE FUNCTION: only support session catalog") { | |
+ addFunction(Identifier.of(Array.empty, "abc"), new JavaStrLen(new JavaStrLenNoImpl)) | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ sql("DESCRIBE FUNCTION testcat.abc") | |
+ } | |
+ assert(e.message.contains("Catalog testcat does not support functions")) | |
+ | |
+ val e1 = intercept[AnalysisException] { | |
+ sql("DESCRIBE FUNCTION default.ns1.ns2.fun") | |
+ } | |
+ assert(e1.message.contains("requires a single-part namespace")) | |
+ } | |
+ | |
+ test("SHOW FUNCTIONS: only support session catalog") { | |
+ addFunction(Identifier.of(Array.empty, "abc"), new JavaStrLen(new JavaStrLenNoImpl)) | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ sql(s"SHOW FUNCTIONS LIKE testcat.abc") | |
+ } | |
+ assert(e.message.contains("Catalog testcat does not support functions")) | |
+ } | |
+ | |
+ test("DROP FUNCTION: only support session catalog") { | |
+ addFunction(Identifier.of(Array.empty, "abc"), new JavaStrLen(new JavaStrLenNoImpl)) | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ sql("DROP FUNCTION testcat.abc") | |
+ } | |
+ assert(e.message.contains("Catalog testcat does not support DROP FUNCTION")) | |
+ | |
+ val e1 = intercept[AnalysisException] { | |
+ sql("DROP FUNCTION default.ns1.ns2.fun") | |
+ } | |
+ assert(e1.message.contains("requires a single-part namespace")) | |
+ } | |
+ | |
+ test("CREATE FUNCTION: only support session catalog") { | |
+ val e = intercept[AnalysisException] { | |
+ sql("CREATE FUNCTION testcat.ns1.ns2.fun as 'f'") | |
+ } | |
+ assert(e.message.contains("Catalog testcat does not support CREATE FUNCTION")) | |
+ | |
+ val e1 = intercept[AnalysisException] { | |
+ sql("CREATE FUNCTION default.ns1.ns2.fun as 'f'") | |
+ } | |
+ assert(e1.message.contains("requires a single-part namespace")) | |
+ } | |
+ | |
+ test("REFRESH FUNCTION: only support session catalog") { | |
+ addFunction(Identifier.of(Array.empty, "abc"), new JavaStrLen(new JavaStrLenNoImpl)) | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ sql("REFRESH FUNCTION testcat.abc") | |
+ } | |
+ assert(e.message.contains("Catalog testcat does not support REFRESH FUNCTION")) | |
+ | |
+ val e1 = intercept[AnalysisException] { | |
+ sql("REFRESH FUNCTION default.ns1.ns2.fun") | |
+ } | |
+ assert(e1.message.contains("requires a single-part namespace")) | |
+ } | |
+ | |
test("built-in with non-function catalog should still work") { | |
withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "testcat", | |
"spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala | |
index 44fbc639a5..95624f3f61 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala | |
@@ -17,7 +17,7 @@ | |
package org.apache.spark.sql.connector | |
-import org.apache.spark.sql.{DataFrame, Row, SaveMode} | |
+import org.apache.spark.sql.{DataFrame, SaveMode} | |
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, Table, TableCatalog} | |
class DataSourceV2SQLSessionCatalogSuite | |
@@ -64,22 +64,6 @@ class DataSourceV2SQLSessionCatalogSuite | |
} | |
} | |
- test("SPARK-31624: SHOW TBLPROPERTIES working with V2 tables and the session catalog") { | |
- val t1 = "tbl" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format TBLPROPERTIES " + | |
- "(key='v', key2='v2')") | |
- | |
- checkAnswer(sql(s"SHOW TBLPROPERTIES $t1"), Seq(Row("key", "v"), Row("key2", "v2"))) | |
- | |
- checkAnswer(sql(s"SHOW TBLPROPERTIES $t1('key')"), Row("key", "v")) | |
- | |
- checkAnswer( | |
- sql(s"SHOW TBLPROPERTIES $t1('keyX')"), | |
- Row("keyX", s"Table default.$t1 does not have property: keyX")) | |
- } | |
- } | |
- | |
test("SPARK-33651: allow CREATE EXTERNAL TABLE without LOCATION") { | |
withTable("t") { | |
val prop = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY + "=true" | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala | |
index 68a0f8b282..7470911c9e 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala | |
@@ -18,15 +18,16 @@ | |
package org.apache.spark.sql.connector | |
import java.sql.Timestamp | |
-import java.time.LocalDate | |
+import java.time.{Duration, LocalDate, Period} | |
import scala.collection.JavaConverters._ | |
+import scala.concurrent.duration.MICROSECONDS | |
-import org.apache.spark.SparkException | |
import org.apache.spark.sql._ | |
import org.apache.spark.sql.catalyst.InternalRow | |
-import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException} | |
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException} | |
import org.apache.spark.sql.catalyst.parser.ParseException | |
+import org.apache.spark.sql.catalyst.util.DateTimeUtils | |
import org.apache.spark.sql.connector.catalog._ | |
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME | |
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership | |
@@ -125,7 +126,7 @@ class DataSourceV2SQLSuite | |
" PARTITIONED BY (id)" + | |
" TBLPROPERTIES ('bar'='baz')" + | |
" COMMENT 'this is a test table'" + | |
- " LOCATION '/tmp/testcat/table_name'") | |
+ " LOCATION 'file:/tmp/testcat/table_name'") | |
val descriptionDf = spark.sql("DESCRIBE TABLE EXTENDED testcat.table_name") | |
assert(descriptionDf.schema.map(field => (field.name, field.dataType)) | |
=== Seq( | |
@@ -148,7 +149,7 @@ class DataSourceV2SQLSuite | |
Array("# Detailed Table Information", "", ""), | |
Array("Name", "testcat.table_name", ""), | |
Array("Comment", "this is a test table", ""), | |
- Array("Location", "/tmp/testcat/table_name", ""), | |
+ Array("Location", "file:/tmp/testcat/table_name", ""), | |
Array("Provider", "foo", ""), | |
Array(TableCatalog.PROP_OWNER.capitalize, defaultUser, ""), | |
Array("Table Properties", "[bar=baz]", ""))) | |
@@ -173,9 +174,10 @@ class DataSourceV2SQLSuite | |
Row("data_type", "string"), | |
Row("comment", "hello"))) | |
- assertAnalysisError( | |
+ assertAnalysisErrorClass( | |
s"DESCRIBE $t invalid_col", | |
- "cannot resolve 'invalid_col' given input columns: [testcat.tbl.data, testcat.tbl.id]") | |
+ "MISSING_COLUMN", | |
+ Array("invalid_col", "testcat.tbl.id, testcat.tbl.data")) | |
} | |
} | |
@@ -339,13 +341,15 @@ class DataSourceV2SQLSuite | |
} | |
test("CTAS/RTAS: invalid schema if has interval type") { | |
- Seq("CREATE", "REPLACE").foreach { action => | |
- val e1 = intercept[AnalysisException]( | |
- sql(s"$action TABLE table_name USING $v2Format as select interval 1 day")) | |
- assert(e1.getMessage.contains(s"Cannot use interval type in the table schema.")) | |
- val e2 = intercept[AnalysisException]( | |
- sql(s"$action TABLE table_name USING $v2Format as select array(interval 1 day)")) | |
- assert(e2.getMessage.contains(s"Cannot use interval type in the table schema.")) | |
+ withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { | |
+ Seq("CREATE", "REPLACE").foreach { action => | |
+ val e1 = intercept[AnalysisException]( | |
+ sql(s"$action TABLE table_name USING $v2Format as select interval 1 day")) | |
+ assert(e1.getMessage.contains(s"Cannot use interval type in the table schema.")) | |
+ val e2 = intercept[AnalysisException]( | |
+ sql(s"$action TABLE table_name USING $v2Format as select array(interval 1 day)")) | |
+ assert(e2.getMessage.contains(s"Cannot use interval type in the table schema.")) | |
+ } | |
} | |
} | |
@@ -404,6 +408,83 @@ class DataSourceV2SQLSuite | |
} | |
} | |
+ test("SPARK-36850: CreateTableAsSelect partitions can be specified using " + | |
+ "PARTITIONED BY and/or CLUSTERED BY") { | |
+ val identifier = "testcat.table_name" | |
+ val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"), | |
+ (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4") | |
+ df.createOrReplaceTempView("source_table") | |
+ withTable(identifier) { | |
+ spark.sql(s"CREATE TABLE $identifier USING foo PARTITIONED BY (id) " + | |
+ s"CLUSTERED BY (data1, data2, data3, data4) INTO 4 BUCKETS AS SELECT * FROM source_table") | |
+ val describe = spark.sql(s"DESCRIBE $identifier") | |
+ val part1 = describe | |
+ .filter("col_name = 'Part 0'") | |
+ .select("data_type").head.getString(0) | |
+ assert(part1 === "id") | |
+ val part2 = describe | |
+ .filter("col_name = 'Part 1'") | |
+ .select("data_type").head.getString(0) | |
+ assert(part2 === "bucket(4, data1, data2, data3, data4)") | |
+ } | |
+ } | |
+ | |
+ test("SPARK-36850: ReplaceTableAsSelect partitions can be specified using " + | |
+ "PARTITIONED BY and/or CLUSTERED BY") { | |
+ val identifier = "testcat.table_name" | |
+ val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"), | |
+ (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4") | |
+ df.createOrReplaceTempView("source_table") | |
+ withTable(identifier) { | |
+ spark.sql(s"CREATE TABLE $identifier USING foo " + | |
+ "AS SELECT id FROM source") | |
+ spark.sql(s"REPLACE TABLE $identifier USING foo PARTITIONED BY (id) " + | |
+ s"CLUSTERED BY (data1, data2) SORTED by (data3, data4) INTO 4 BUCKETS " + | |
+ s"AS SELECT * FROM source_table") | |
+ val describe = spark.sql(s"DESCRIBE $identifier") | |
+ val part1 = describe | |
+ .filter("col_name = 'Part 0'") | |
+ .select("data_type").head.getString(0) | |
+ assert(part1 === "id") | |
+ val part2 = describe | |
+ .filter("col_name = 'Part 1'") | |
+ .select("data_type").head.getString(0) | |
+ assert(part2 === "sorted_bucket(data1, data2, 4, data3, data4)") | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37545: CreateTableAsSelect should store location as qualified") { | |
+ val basicIdentifier = "testcat.table_name" | |
+ val atomicIdentifier = "testcat_atomic.table_name" | |
+ Seq(basicIdentifier, atomicIdentifier).foreach { identifier => | |
+ withTable(identifier) { | |
+ spark.sql(s"CREATE TABLE $identifier USING foo LOCATION '/tmp/foo' " + | |
+ "AS SELECT id FROM source") | |
+ val location = spark.sql(s"DESCRIBE EXTENDED $identifier") | |
+ .filter("col_name = 'Location'") | |
+ .select("data_type").head.getString(0) | |
+ assert(location === "file:/tmp/foo") | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37546: ReplaceTableAsSelect should store location as qualified") { | |
+ val basicIdentifier = "testcat.table_name" | |
+ val atomicIdentifier = "testcat_atomic.table_name" | |
+ Seq(basicIdentifier, atomicIdentifier).foreach { identifier => | |
+ withTable(identifier) { | |
+ spark.sql(s"CREATE TABLE $identifier USING foo LOCATION '/tmp/foo' " + | |
+ "AS SELECT id, data FROM source") | |
+ spark.sql(s"REPLACE TABLE $identifier USING foo LOCATION '/tmp/foo' " + | |
+ "AS SELECT id FROM source") | |
+ val location = spark.sql(s"DESCRIBE EXTENDED $identifier") | |
+ .filter("col_name = 'Location'") | |
+ .select("data_type").head.getString(0) | |
+ assert(location === "file:/tmp/foo") | |
+ } | |
+ } | |
+ } | |
+ | |
test("ReplaceTableAsSelect: basic v2 implementation.") { | |
val basicCatalog = catalog("testcat").asTableCatalog | |
val atomicCatalog = catalog("testcat_atomic").asTableCatalog | |
@@ -962,7 +1043,8 @@ class DataSourceV2SQLSuite | |
val ex = intercept[AnalysisException] { | |
sql(s"SELECT ns1.ns2.ns3.tbl.id from $t") | |
} | |
- assert(ex.getMessage.contains("cannot resolve 'ns1.ns2.ns3.tbl.id")) | |
+ assert(ex.getErrorClass == "MISSING_COLUMN") | |
+ assert(ex.messageParameters.head == "ns1.ns2.ns3.tbl.id") | |
} | |
} | |
@@ -1015,76 +1097,7 @@ class DataSourceV2SQLSuite | |
sql("SHOW VIEWS FROM testcat") | |
} | |
- assert(exception.getMessage.contains("Catalog testcat doesn't support SHOW VIEWS," + | |
- " only SessionCatalog supports this command.")) | |
- } | |
- | |
- test("CreateNameSpace: basic tests") { | |
- // Session catalog is used. | |
- withNamespace("ns") { | |
- sql("CREATE NAMESPACE ns") | |
- testShowNamespaces("SHOW NAMESPACES", Seq("default", "ns")) | |
- } | |
- | |
- // V2 non-session catalog is used. | |
- withNamespace("testcat.ns1.ns2") { | |
- sql("CREATE NAMESPACE testcat.ns1.ns2") | |
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1")) | |
- testShowNamespaces("SHOW NAMESPACES IN testcat.ns1", Seq("ns1.ns2")) | |
- } | |
- | |
- withNamespace("testcat.test") { | |
- withTempDir { tmpDir => | |
- val path = tmpDir.getCanonicalPath | |
- sql(s"CREATE NAMESPACE testcat.test LOCATION '$path'") | |
- val metadata = | |
- catalog("testcat").asNamespaceCatalog.loadNamespaceMetadata(Array("test")).asScala | |
- val catalogPath = metadata(SupportsNamespaces.PROP_LOCATION) | |
- assert(catalogPath.equals(catalogPath)) | |
- } | |
- } | |
- } | |
- | |
- test("CreateNameSpace: test handling of 'IF NOT EXIST'") { | |
- withNamespace("testcat.ns1") { | |
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1") | |
- | |
- // The 'ns1' namespace already exists, so this should fail. | |
- val exception = intercept[NamespaceAlreadyExistsException] { | |
- sql("CREATE NAMESPACE testcat.ns1") | |
- } | |
- assert(exception.getMessage.contains("Namespace 'ns1' already exists")) | |
- | |
- // The following will be no-op since the namespace already exists. | |
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1") | |
- } | |
- } | |
- | |
- test("CreateNameSpace: reserved properties") { | |
- import SupportsNamespaces._ | |
- withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) { | |
- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => | |
- val exception = intercept[ParseException] { | |
- sql(s"CREATE NAMESPACE testcat.reservedTest WITH DBPROPERTIES('$key'='dummyVal')") | |
- } | |
- assert(exception.getMessage.contains(s"$key is a reserved namespace property")) | |
- } | |
- } | |
- withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "true")) { | |
- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => | |
- withNamespace("testcat.reservedTest") { | |
- sql(s"CREATE NAMESPACE testcat.reservedTest WITH DBPROPERTIES('$key'='foo')") | |
- assert(sql("DESC NAMESPACE EXTENDED testcat.reservedTest") | |
- .toDF("k", "v") | |
- .where("k='Properties'") | |
- .isEmpty, s"$key is a reserved namespace property and ignored") | |
- val meta = | |
- catalog("testcat").asNamespaceCatalog.loadNamespaceMetadata(Array("reservedTest")) | |
- assert(meta.get(key) == null || !meta.get(key).contains("foo"), | |
- "reserved properties should not have side effects") | |
- } | |
- } | |
- } | |
+ assert(exception.getMessage.contains("Catalog testcat does not support views")) | |
} | |
test("create/replace/alter table - reserved properties") { | |
@@ -1155,8 +1168,9 @@ class DataSourceV2SQLSuite | |
s" ('path'='bar', 'Path'='noop')") | |
val tableCatalog = catalog("testcat").asTableCatalog | |
val identifier = Identifier.of(Array(), "reservedTest") | |
- assert(tableCatalog.loadTable(identifier).properties() | |
- .get(TableCatalog.PROP_LOCATION) == "foo", | |
+ val location = tableCatalog.loadTable(identifier).properties() | |
+ .get(TableCatalog.PROP_LOCATION) | |
+ assert(location.startsWith("file:") && location.endsWith("foo"), | |
"path as a table property should not have side effects") | |
assert(tableCatalog.loadTable(identifier).properties().get("path") == "bar", | |
"path as a table property should not have side effects") | |
@@ -1168,151 +1182,6 @@ class DataSourceV2SQLSuite | |
} | |
} | |
- test("DropNamespace: basic tests") { | |
- // Session catalog is used. | |
- sql("CREATE NAMESPACE ns") | |
- testShowNamespaces("SHOW NAMESPACES", Seq("default", "ns")) | |
- sql("DROP NAMESPACE ns") | |
- testShowNamespaces("SHOW NAMESPACES", Seq("default")) | |
- | |
- // V2 non-session catalog is used. | |
- sql("CREATE NAMESPACE testcat.ns1") | |
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1")) | |
- sql("DROP NAMESPACE testcat.ns1") | |
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq()) | |
- } | |
- | |
- test("DropNamespace: drop non-empty namespace with a non-cascading mode") { | |
- sql("CREATE TABLE testcat.ns1.table (id bigint) USING foo") | |
- sql("CREATE TABLE testcat.ns1.ns2.table (id bigint) USING foo") | |
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1")) | |
- testShowNamespaces("SHOW NAMESPACES IN testcat.ns1", Seq("ns1.ns2")) | |
- | |
- def assertDropFails(): Unit = { | |
- val e = intercept[SparkException] { | |
- sql("DROP NAMESPACE testcat.ns1") | |
- } | |
- assert(e.getMessage.contains("Cannot drop a non-empty namespace: ns1")) | |
- } | |
- | |
- // testcat.ns1.table is present, thus testcat.ns1 cannot be dropped. | |
- assertDropFails() | |
- sql("DROP TABLE testcat.ns1.table") | |
- | |
- // testcat.ns1.ns2.table is present, thus testcat.ns1 cannot be dropped. | |
- assertDropFails() | |
- sql("DROP TABLE testcat.ns1.ns2.table") | |
- | |
- // testcat.ns1.ns2 namespace is present, thus testcat.ns1 cannot be dropped. | |
- assertDropFails() | |
- sql("DROP NAMESPACE testcat.ns1.ns2") | |
- | |
- // Now that testcat.ns1 is empty, it can be dropped. | |
- sql("DROP NAMESPACE testcat.ns1") | |
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq()) | |
- } | |
- | |
- test("DropNamespace: drop non-empty namespace with a cascade mode") { | |
- sql("CREATE TABLE testcat.ns1.table (id bigint) USING foo") | |
- sql("CREATE TABLE testcat.ns1.ns2.table (id bigint) USING foo") | |
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1")) | |
- testShowNamespaces("SHOW NAMESPACES IN testcat.ns1", Seq("ns1.ns2")) | |
- | |
- sql("DROP NAMESPACE testcat.ns1 CASCADE") | |
- testShowNamespaces("SHOW NAMESPACES IN testcat", Seq()) | |
- } | |
- | |
- test("DropNamespace: test handling of 'IF EXISTS'") { | |
- sql("DROP NAMESPACE IF EXISTS testcat.unknown") | |
- | |
- val exception = intercept[NoSuchNamespaceException] { | |
- sql("DROP NAMESPACE testcat.ns1") | |
- } | |
- assert(exception.getMessage.contains("Namespace 'ns1' not found")) | |
- } | |
- | |
- test("DescribeNamespace using v2 catalog") { | |
- withNamespace("testcat.ns1.ns2") { | |
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1.ns2 COMMENT " + | |
- "'test namespace' LOCATION '/tmp/ns_test'") | |
- val descriptionDf = sql("DESCRIBE NAMESPACE testcat.ns1.ns2") | |
- assert(descriptionDf.schema.map(field => (field.name, field.dataType)) === | |
- Seq( | |
- ("info_name", StringType), | |
- ("info_value", StringType) | |
- )) | |
- val description = descriptionDf.collect() | |
- assert(description === Seq( | |
- Row("Namespace Name", "ns2"), | |
- Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"), | |
- Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test"), | |
- Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser)) | |
- ) | |
- } | |
- } | |
- | |
- test("ALTER NAMESPACE .. SET PROPERTIES using v2 catalog") { | |
- withNamespace("testcat.ns1.ns2") { | |
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1.ns2 COMMENT " + | |
- "'test namespace' LOCATION '/tmp/ns_test' WITH PROPERTIES ('a'='a','b'='b','c'='c')") | |
- sql("ALTER NAMESPACE testcat.ns1.ns2 SET PROPERTIES ('a'='b','b'='a')") | |
- val descriptionDf = sql("DESCRIBE NAMESPACE EXTENDED testcat.ns1.ns2") | |
- assert(descriptionDf.collect() === Seq( | |
- Row("Namespace Name", "ns2"), | |
- Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"), | |
- Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test"), | |
- Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser), | |
- Row("Properties", "((a,b),(b,a),(c,c))")) | |
- ) | |
- } | |
- } | |
- | |
- test("ALTER NAMESPACE .. SET PROPERTIES reserved properties") { | |
- import SupportsNamespaces._ | |
- withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) { | |
- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => | |
- withNamespace("testcat.reservedTest") { | |
- sql("CREATE NAMESPACE testcat.reservedTest") | |
- val exception = intercept[ParseException] { | |
- sql(s"ALTER NAMESPACE testcat.reservedTest SET PROPERTIES ('$key'='dummyVal')") | |
- } | |
- assert(exception.getMessage.contains(s"$key is a reserved namespace property")) | |
- } | |
- } | |
- } | |
- withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "true")) { | |
- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => | |
- withNamespace("testcat.reservedTest") { | |
- sql(s"CREATE NAMESPACE testcat.reservedTest") | |
- sql(s"ALTER NAMESPACE testcat.reservedTest SET PROPERTIES ('$key'='foo')") | |
- assert(sql("DESC NAMESPACE EXTENDED testcat.reservedTest") | |
- .toDF("k", "v") | |
- .where("k='Properties'") | |
- .isEmpty, s"$key is a reserved namespace property and ignored") | |
- val meta = | |
- catalog("testcat").asNamespaceCatalog.loadNamespaceMetadata(Array("reservedTest")) | |
- assert(meta.get(key) == null || !meta.get(key).contains("foo"), | |
- "reserved properties should not have side effects") | |
- } | |
- } | |
- } | |
- } | |
- | |
- test("ALTER NAMESPACE .. SET LOCATION using v2 catalog") { | |
- withNamespace("testcat.ns1.ns2") { | |
- sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1.ns2 COMMENT " + | |
- "'test namespace' LOCATION '/tmp/ns_test_1'") | |
- sql("ALTER NAMESPACE testcat.ns1.ns2 SET LOCATION '/tmp/ns_test_2'") | |
- val descriptionDf = sql("DESCRIBE NAMESPACE EXTENDED testcat.ns1.ns2") | |
- assert(descriptionDf.collect() === Seq( | |
- Row("Namespace Name", "ns2"), | |
- Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"), | |
- Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test_2"), | |
- Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser)) | |
- ) | |
- } | |
- } | |
- | |
private def testShowNamespaces( | |
sqlText: String, | |
expected: Seq[String]): Unit = { | |
@@ -1335,6 +1204,7 @@ class DataSourceV2SQLSuite | |
sql("CREATE TABLE testcat2.ns2.ns2_2.table (id bigint) USING foo") | |
sql("CREATE TABLE testcat2.ns3.ns3_3.table (id bigint) USING foo") | |
sql("CREATE TABLE testcat2.testcat.table (id bigint) USING foo") | |
+ sql("CREATE TABLE testcat2.testcat.ns1.ns1_1.table (id bigint) USING foo") | |
// Catalog is resolved to 'testcat'. | |
sql("USE testcat.ns1.ns1_1") | |
@@ -1356,6 +1226,11 @@ class DataSourceV2SQLSuite | |
assert(catalogManager.currentCatalog.name() == "testcat2") | |
assert(catalogManager.currentNamespace === Array("testcat")) | |
+ // Only the namespace is changed (explicit). | |
+ sql("USE NAMESPACE testcat.ns1.ns1_1") | |
+ assert(catalogManager.currentCatalog.name() == "testcat2") | |
+ assert(catalogManager.currentNamespace === Array("testcat", "ns1", "ns1_1")) | |
+ | |
// Catalog is resolved to `testcat`. | |
sql("USE testcat") | |
assert(catalogManager.currentCatalog.name() == "testcat") | |
@@ -1609,6 +1484,27 @@ class DataSourceV2SQLSuite | |
} | |
} | |
+ test("create table using - with sorted bucket") { | |
+ val identifier = "testcat.table_name" | |
+ withTable(identifier) { | |
+ sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" + | |
+ s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS") | |
+ val describe = spark.sql(s"DESCRIBE $identifier") | |
+ val part1 = describe | |
+ .filter("col_name = 'Part 0'") | |
+ .select("data_type").head.getString(0) | |
+ assert(part1 === "a") | |
+ val part2 = describe | |
+ .filter("col_name = 'Part 1'") | |
+ .select("data_type").head.getString(0) | |
+ assert(part2 === "b") | |
+ val part3 = describe | |
+ .filter("col_name = 'Part 2'") | |
+ .select("data_type").head.getString(0) | |
+ assert(part3 === "sorted_bucket(c, d, 4, e, f)") | |
+ } | |
+ } | |
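The sorted-bucket test above relies on the v2 CLUSTERED BY ... SORTED BY ... INTO n BUCKETS clause being reported back by DESCRIBE as a sorted_bucket transform. A hedged sketch of that round trip; catalog and provider names are illustrative:

// Sketch: bucketed-and-sorted partitioning surfaces as a sorted_bucket transform.
spark.sql(
  """CREATE TABLE testcat.db.t (a INT, b STRING, c INT, d INT, e INT, f INT)
    |USING foo
    |PARTITIONED BY (a, b)
    |CLUSTERED BY (c, d) SORTED BY (e, f) INTO 4 BUCKETS
    |""".stripMargin)
// Expected partition transforms: a, b, sorted_bucket(c, d, 4, e, f).
spark.sql("DESCRIBE testcat.db.t").show(truncate = false)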
+ | |
test("REFRESH TABLE: v2 table") { | |
val t = "testcat.ns1.ns2.tbl" | |
withTable(t) { | |
@@ -1805,12 +1701,20 @@ class DataSourceV2SQLSuite | |
"Table or view not found") | |
// UPDATE non-existing column | |
- assertAnalysisError( | |
+ assertAnalysisErrorClass( | |
s"UPDATE $t SET dummy='abc'", | |
- "cannot resolve") | |
- assertAnalysisError( | |
+ "MISSING_COLUMN", | |
+ Array( | |
+ "dummy", | |
+ "testcat.ns1.ns2.tbl.p, testcat.ns1.ns2.tbl.id, " + | |
+ "testcat.ns1.ns2.tbl.age, testcat.ns1.ns2.tbl.name")) | |
+ assertAnalysisErrorClass( | |
s"UPDATE $t SET name='abc' WHERE dummy=1", | |
- "cannot resolve") | |
+ "MISSING_COLUMN", | |
+ Array( | |
+ "dummy", | |
+ "testcat.ns1.ns2.tbl.p, testcat.ns1.ns2.tbl.id, " + | |
+ "testcat.ns1.ns2.tbl.age, testcat.ns1.ns2.tbl.name")) | |
// UPDATE is not implemented yet. | |
val e = intercept[UnsupportedOperationException] { | |
@@ -1961,109 +1865,6 @@ class DataSourceV2SQLSuite | |
} | |
} | |
- test("SPARK-33898: SHOW CREATE TABLE AS SERDE") { | |
- val t = "testcat.ns1.ns2.tbl" | |
- withTable(t) { | |
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") | |
- val e = intercept[AnalysisException] { | |
- sql(s"SHOW CREATE TABLE $t AS SERDE") | |
- } | |
- assert(e.message.contains(s"SHOW CREATE TABLE AS SERDE is not supported for v2 tables.")) | |
- } | |
- } | |
- | |
- test("SPARK-33898: SHOW CREATE TABLE") { | |
- val t = "testcat.ns1.ns2.tbl" | |
- withTable(t) { | |
- sql( | |
- s""" | |
- |CREATE TABLE $t ( | |
- | a bigint NOT NULL, | |
- | b bigint, | |
- | c bigint, | |
- | `extra col` ARRAY<INT>, | |
- | `<another>` STRUCT<x: INT, y: ARRAY<BOOLEAN>> | |
- |) | |
- |USING foo | |
- |OPTIONS ( | |
- | from = 0, | |
- | to = 1, | |
- | via = 2) | |
- |COMMENT 'This is a comment' | |
- |TBLPROPERTIES ('prop1' = '1', 'prop2' = '2', 'prop3' = 3, 'prop4' = 4) | |
- |PARTITIONED BY (a) | |
- |LOCATION '/tmp' | |
- """.stripMargin) | |
- val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t") | |
- assert(showDDL === Array( | |
- "CREATE TABLE testcat.ns1.ns2.tbl (", | |
- "`a` BIGINT NOT NULL,", | |
- "`b` BIGINT,", | |
- "`c` BIGINT,", | |
- "`extra col` ARRAY<INT>,", | |
- "`<another>` STRUCT<`x`: INT, `y`: ARRAY<BOOLEAN>>)", | |
- "USING foo", | |
- "OPTIONS(", | |
- "'from' = '0',", | |
- "'to' = '1',", | |
- "'via' = '2')", | |
- "PARTITIONED BY (a)", | |
- "COMMENT 'This is a comment'", | |
- "LOCATION '/tmp'", | |
- "TBLPROPERTIES(", | |
- "'prop1' = '1',", | |
- "'prop2' = '2',", | |
- "'prop3' = '3',", | |
- "'prop4' = '4')" | |
- )) | |
- } | |
- } | |
- | |
- test("SPARK-33898: SHOW CREATE TABLE WITH AS SELECT") { | |
- val t = "testcat.ns1.ns2.tbl" | |
- withTable(t) { | |
- sql( | |
- s""" | |
- |CREATE TABLE $t | |
- |USING foo | |
- |AS SELECT 1 AS a, "foo" AS b | |
- """.stripMargin) | |
- val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t") | |
- assert(showDDL === Array( | |
- "CREATE TABLE testcat.ns1.ns2.tbl (", | |
- "`a` INT,", | |
- "`b` STRING)", | |
- "USING foo" | |
- )) | |
- } | |
- } | |
- | |
- test("SPARK-33898: SHOW CREATE TABLE PARTITIONED BY Transforms") { | |
- val t = "testcat.ns1.ns2.tbl" | |
- withTable(t) { | |
- sql( | |
- s""" | |
- |CREATE TABLE $t (a INT, b STRING, ts TIMESTAMP) USING foo | |
- |PARTITIONED BY ( | |
- | a, | |
- | bucket(16, b), | |
- | years(ts), | |
- | months(ts), | |
- | days(ts), | |
- | hours(ts)) | |
- """.stripMargin) | |
- val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t") | |
- assert(showDDL === Array( | |
- "CREATE TABLE testcat.ns1.ns2.tbl (", | |
- "`a` INT,", | |
- "`b` STRING,", | |
- "`ts` TIMESTAMP)", | |
- "USING foo", | |
- "PARTITIONED BY (a, bucket(16, b), years(ts), months(ts), days(ts), hours(ts))" | |
- )) | |
- } | |
- } | |
- | |
test("CACHE/UNCACHE TABLE") { | |
val t = "testcat.ns1.ns2.tbl" | |
withTable(t) { | |
@@ -2117,120 +1918,7 @@ class DataSourceV2SQLSuite | |
val e = intercept[AnalysisException] { | |
sql(s"CREATE VIEW $v AS SELECT 1") | |
} | |
- assert(e.message.contains("CREATE VIEW is only supported with v1 tables")) | |
- } | |
- | |
- test("SHOW TBLPROPERTIES: v2 table") { | |
- val t = "testcat.ns1.ns2.tbl" | |
- withTable(t) { | |
- val user = "andrew" | |
- val status = "new" | |
- val provider = "foo" | |
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING $provider " + | |
- s"TBLPROPERTIES ('user'='$user', 'status'='$status')") | |
- | |
- val properties = sql(s"SHOW TBLPROPERTIES $t") | |
- | |
- val schema = new StructType() | |
- .add("key", StringType, nullable = false) | |
- .add("value", StringType, nullable = false) | |
- | |
- val expected = Seq( | |
- Row("status", status), | |
- Row("user", user)) | |
- | |
- assert(properties.schema === schema) | |
- assert(expected === properties.collect()) | |
- } | |
- } | |
- | |
- test("SHOW TBLPROPERTIES(key): v2 table") { | |
- val t = "testcat.ns1.ns2.tbl" | |
- withTable(t) { | |
- val user = "andrew" | |
- val status = "new" | |
- val provider = "foo" | |
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING $provider " + | |
- s"TBLPROPERTIES ('user'='$user', 'status'='$status')") | |
- | |
- val properties = sql(s"SHOW TBLPROPERTIES $t ('status')") | |
- | |
- val expected = Seq(Row("status", status)) | |
- | |
- assert(expected === properties.collect()) | |
- } | |
- } | |
- | |
- test("SHOW TBLPROPERTIES(key): v2 table, key not found") { | |
- val t = "testcat.ns1.ns2.tbl" | |
- withTable(t) { | |
- val nonExistingKey = "nonExistingKey" | |
- spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo " + | |
- s"TBLPROPERTIES ('user'='andrew', 'status'='new')") | |
- | |
- val properties = sql(s"SHOW TBLPROPERTIES $t ('$nonExistingKey')") | |
- | |
- val expected = Seq(Row(nonExistingKey, s"Table $t does not have property: $nonExistingKey")) | |
- | |
- assert(expected === properties.collect()) | |
- } | |
- } | |
- | |
- test("DESCRIBE FUNCTION: only support session catalog") { | |
- val e = intercept[AnalysisException] { | |
- sql("DESCRIBE FUNCTION testcat.ns1.ns2.fun") | |
- } | |
- assert(e.message.contains("function is only supported in v1 catalog")) | |
- | |
- val e1 = intercept[AnalysisException] { | |
- sql("DESCRIBE FUNCTION default.ns1.ns2.fun") | |
- } | |
- assert(e1.message.contains("requires a single-part namespace")) | |
- } | |
- | |
- test("SHOW FUNCTIONS not valid v1 namespace") { | |
- val function = "testcat.ns1.ns2.fun" | |
- | |
- val e = intercept[AnalysisException] { | |
- sql(s"SHOW FUNCTIONS LIKE $function") | |
- } | |
- assert(e.message.contains("function is only supported in v1 catalog")) | |
- } | |
- | |
- test("DROP FUNCTION: only support session catalog") { | |
- val e = intercept[AnalysisException] { | |
- sql("DROP FUNCTION testcat.ns1.ns2.fun") | |
- } | |
- assert(e.message.contains("function is only supported in v1 catalog")) | |
- | |
- val e1 = intercept[AnalysisException] { | |
- sql("DROP FUNCTION default.ns1.ns2.fun") | |
- } | |
- assert(e1.message.contains("requires a single-part namespace")) | |
- } | |
- | |
- test("CREATE FUNCTION: only support session catalog") { | |
- val e = intercept[AnalysisException] { | |
- sql("CREATE FUNCTION testcat.ns1.ns2.fun as 'f'") | |
- } | |
- assert(e.message.contains("function is only supported in v1 catalog")) | |
- | |
- val e1 = intercept[AnalysisException] { | |
- sql("CREATE FUNCTION default.ns1.ns2.fun as 'f'") | |
- } | |
- assert(e1.message.contains("requires a single-part namespace")) | |
- } | |
- | |
- test("REFRESH FUNCTION: only support session catalog") { | |
- val e = intercept[AnalysisException] { | |
- sql("REFRESH FUNCTION testcat.ns1.ns2.fun") | |
- } | |
- assert(e.message.contains("function is only supported in v1 catalog")) | |
- | |
- val e1 = intercept[AnalysisException] { | |
- sql("REFRESH FUNCTION default.ns1.ns2.fun") | |
- } | |
- assert(e1.message.contains("requires a single-part namespace")) | |
+ assert(e.message.contains("Catalog testcat does not support views")) | |
} | |
test("global temp view should not be masked by v2 catalog") { | |
@@ -2457,14 +2145,6 @@ class DataSourceV2SQLSuite | |
.head().getString(1) === expectedComment) | |
} | |
- test("SPARK-30799: temp view name can't contain catalog name") { | |
- val sessionCatalogName = CatalogManager.SESSION_CATALOG_NAME | |
- val e2 = intercept[AnalysisException] { | |
- sql(s"CREATE TEMP VIEW $sessionCatalogName.v AS SELECT 1") | |
- } | |
- assert(e2.message.contains("It is not allowed to add database prefix")) | |
- } | |
- | |
test("SPARK-31015: star expression should work for qualified column names for v2 tables") { | |
val t = "testcat.ns1.ns2.tbl" | |
withTable(t) { | |
@@ -2524,100 +2204,6 @@ class DataSourceV2SQLSuite | |
} | |
} | |
- test("SPARK-31255: Project a metadata column") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
- | |
- val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1") | |
- val dfQuery = spark.table(t1).select("id", "data", "index", "_partition") | |
- | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
- } | |
- } | |
- } | |
- | |
- test("SPARK-31255: Projects data column when metadata column has the same name") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (index bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, index), index)") | |
- sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')") | |
- | |
- val sqlQuery = spark.sql(s"SELECT index, data, _partition FROM $t1") | |
- val dfQuery = spark.table(t1).select("index", "data", "_partition") | |
- | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1"))) | |
- } | |
- } | |
- } | |
- | |
- test("SPARK-31255: * expansion does not include metadata columns") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')") | |
- | |
- val sqlQuery = spark.sql(s"SELECT * FROM $t1") | |
- val dfQuery = spark.table(t1) | |
- | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- checkAnswer(query, Seq(Row(3, "c"), Row(2, "b"), Row(1, "a"))) | |
- } | |
- } | |
- } | |
- | |
- test("SPARK-31255: metadata column should only be produced when necessary") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- | |
- val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0") | |
- val dfQuery = spark.table(t1).filter("index = 0") | |
- | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- assert(query.schema.fieldNames.toSeq == Seq("id", "data")) | |
- } | |
- } | |
- } | |
- | |
- test("SPARK-34547: metadata columns are resolved last") { | |
- val t1 = s"${catalogAndNamespace}tableOne" | |
- val t2 = "t2" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
- withTempView(t2) { | |
- sql(s"CREATE TEMPORARY VIEW $t2 AS SELECT * FROM " + | |
- s"VALUES (1, -1), (2, -2), (3, -3) AS $t2(id, index)") | |
- | |
- val sqlQuery = spark.sql(s"SELECT $t1.id, $t2.id, data, index, $t1.index, $t2.index FROM " + | |
- s"$t1 JOIN $t2 WHERE $t1.id = $t2.id") | |
- val t1Table = spark.table(t1) | |
- val t2Table = spark.table(t2) | |
- val dfQuery = t1Table.join(t2Table, t1Table.col("id") === t2Table.col("id")) | |
- .select(s"$t1.id", s"$t2.id", "data", "index", s"$t1.index", s"$t2.index") | |
- | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- checkAnswer(query, | |
- Seq( | |
- Row(1, 1, "a", -1, 0, -1), | |
- Row(2, 2, "b", -2, 0, -2), | |
- Row(3, 3, "c", -3, 0, -3) | |
- ) | |
- ) | |
- } | |
- } | |
- } | |
- } | |
- | |
test("SPARK-33505: insert into partitioned table") { | |
val t = "testpart.ns1.ns2.tbl" | |
withTable(t) { | |
@@ -2702,27 +2288,6 @@ class DataSourceV2SQLSuite | |
} | |
} | |
- test("SPARK-34555: Resolve DataFrame metadata column") { | |
- val tbl = s"${catalogAndNamespace}table" | |
- withTable(tbl) { | |
- sql(s"CREATE TABLE $tbl (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
- val table = spark.table(tbl) | |
- val dfQuery = table.select( | |
- table.col("id"), | |
- table.col("data"), | |
- table.col("index"), | |
- table.col("_partition") | |
- ) | |
- | |
- checkAnswer( | |
- dfQuery, | |
- Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")) | |
- ) | |
- } | |
- } | |
- | |
test("SPARK-34561: drop/add columns to a dataset of `DESCRIBE TABLE`") { | |
val tbl = s"${catalogAndNamespace}tbl" | |
withTable(tbl) { | |
@@ -2766,125 +2331,265 @@ class DataSourceV2SQLSuite | |
} | |
} | |
- test("SPARK-34577: drop/add columns to a dataset of `DESCRIBE NAMESPACE`") { | |
- withNamespace("ns") { | |
- sql("CREATE NAMESPACE ns") | |
- val description = sql(s"DESCRIBE NAMESPACE ns") | |
- val noCommentDataset = description.drop("info_name") | |
- val expectedSchema = new StructType() | |
- .add( | |
- name = "info_value", | |
- dataType = StringType, | |
- nullable = true, | |
- metadata = new MetadataBuilder() | |
- .putString("comment", "value of the namespace info").build()) | |
- assert(noCommentDataset.schema === expectedSchema) | |
- val isNullDataset = noCommentDataset | |
- .withColumn("is_null", noCommentDataset("info_value").isNull) | |
- assert(isNullDataset.schema === expectedSchema.add("is_null", BooleanType, false)) | |
- } | |
+ test("SPARK-36481: Test for SET CATALOG statement") { | |
+ val catalogManager = spark.sessionState.catalogManager | |
+ assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME) | |
+ | |
+ sql("SET CATALOG testcat") | |
+ assert(catalogManager.currentCatalog.name() == "testcat") | |
+ | |
+ sql("SET CATALOG testcat2") | |
+ assert(catalogManager.currentCatalog.name() == "testcat2") | |
+ | |
+ val errMsg = intercept[CatalogNotFoundException] { | |
+ sql("SET CATALOG not_exist_catalog") | |
+ }.getMessage | |
+ assert(errMsg.contains("Catalog 'not_exist_catalog' plugin class not found")) | |
} | |
- test("SPARK-34923: do not propagate metadata columns through Project") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
+ test("SPARK-35973: ShowCatalogs") { | |
+ val schema = new StructType() | |
+ .add("catalog", StringType, nullable = false) | |
- assertThrows[AnalysisException] { | |
- sql(s"SELECT index, _partition from (SELECT id, data FROM $t1)") | |
- } | |
- assertThrows[AnalysisException] { | |
- spark.table(t1).select("id", "data").select("index", "_partition") | |
- } | |
- } | |
+ val df = sql("SHOW CATALOGS") | |
+ assert(df.schema === schema) | |
+ assert(df.collect === Array(Row("spark_catalog"))) | |
+ | |
+ sql("use testcat") | |
+ sql("use testpart") | |
+ sql("use testcat2") | |
+ assert(sql("SHOW CATALOGS").collect === Array( | |
+ Row("spark_catalog"), Row("testcat"), Row("testcat2"), Row("testpart"))) | |
+ | |
+ assert(sql("SHOW CATALOGS LIKE 'test*'").collect === Array( | |
+ Row("testcat"), Row("testcat2"), Row("testpart"))) | |
+ | |
+ assert(sql("SHOW CATALOGS LIKE 'testcat*'").collect === Array( | |
+ Row("testcat"), Row("testcat2"))) | |
} | |
- test("SPARK-34923: do not propagate metadata columns through View") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- val view = "view" | |
+ test("CREATE INDEX should fail") { | |
+ val t = "testcat.tbl" | |
+ withTable(t) { | |
+ sql(s"CREATE TABLE $t (id bigint, data string COMMENT 'hello') USING foo") | |
+ val e1 = intercept[AnalysisException] { | |
+ sql(s"CREATE index i1 ON $t(non_exist)") | |
+ } | |
+ assert(e1.getMessage.contains(s"Missing field non_exist in table $t")) | |
- withTable(t1) { | |
- withTempView(view) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
- sql(s"CACHE TABLE $view AS SELECT * FROM $t1") | |
- assertThrows[AnalysisException] { | |
- sql(s"SELECT index, _partition FROM $view") | |
+ val e2 = intercept[AnalysisException] { | |
+ sql(s"CREATE index i1 ON $t(id)") | |
+ } | |
+ assert(e2.getMessage.contains(s"CreateIndex is not supported in this table $t.")) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37294: insert ANSI intervals into a table partitioned by the interval columns") { | |
+ val tbl = "testpart.interval_table" | |
+ Seq(PartitionOverwriteMode.DYNAMIC, PartitionOverwriteMode.STATIC).foreach { mode => | |
+ withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> mode.toString) { | |
+ withTable(tbl) { | |
+ sql( | |
+ s""" | |
+ |CREATE TABLE $tbl (i INT, part1 INTERVAL YEAR, part2 INTERVAL DAY) USING $v2Format | |
+ |PARTITIONED BY (part1, part2) | |
+ """.stripMargin) | |
+ | |
+ sql( | |
+ s"""ALTER TABLE $tbl ADD PARTITION ( | |
+ |part1 = INTERVAL '2' YEAR, | |
+ |part2 = INTERVAL '3' DAY)""".stripMargin) | |
+ sql(s"INSERT OVERWRITE TABLE $tbl SELECT 1, INTERVAL '2' YEAR, INTERVAL '3' DAY") | |
+ sql(s"INSERT INTO TABLE $tbl SELECT 4, INTERVAL '5' YEAR, INTERVAL '6' DAY") | |
+ sql( | |
+ s""" | |
+ |INSERT INTO $tbl | |
+ | PARTITION (part1 = INTERVAL '8' YEAR, part2 = INTERVAL '9' DAY) | |
+ |SELECT 7""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ spark.table(tbl), | |
+ Seq(Row(1, Period.ofYears(2), Duration.ofDays(3)), | |
+ Row(4, Period.ofYears(5), Duration.ofDays(6)), | |
+ Row(7, Period.ofYears(8), Duration.ofDays(9)))) | |
} | |
} | |
} | |
} | |
- test("SPARK-34923: propagate metadata columns through Filter") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
+ test("Check HasPartitionKey from InMemoryPartitionTable") { | |
+ val t = "testpart.tbl" | |
+ withTable(t) { | |
+ sql(s"CREATE TABLE $t (id string) USING foo PARTITIONED BY (key int)") | |
+ val table = catalog("testpart").asTableCatalog | |
+ .loadTable(Identifier.of(Array(), "tbl")) | |
+ .asInstanceOf[InMemoryPartitionTable] | |
+ | |
+ sql(s"INSERT INTO $t VALUES ('a', 1), ('b', 2), ('c', 3)") | |
+ var partKeys = table.data.map(_.partitionKey().getInt(0)) | |
+ assert(partKeys.length == 3) | |
+ assert(partKeys.toSet == Set(1, 2, 3)) | |
+ | |
+ sql(s"ALTER TABLE $t DROP PARTITION (key=3)") | |
+ partKeys = table.data.map(_.partitionKey().getInt(0)) | |
+ assert(partKeys.length == 2) | |
+ assert(partKeys.toSet == Set(1, 2)) | |
+ | |
+ sql(s"ALTER TABLE $t ADD PARTITION (key=4)") | |
+ partKeys = table.data.map(_.partitionKey().getInt(0)) | |
+ assert(partKeys.length == 3) | |
+ assert(partKeys.toSet == Set(1, 2, 4)) | |
+ | |
+ sql(s"INSERT INTO $t VALUES ('c', 3), ('e', 5)") | |
+ partKeys = table.data.map(_.partitionKey().getInt(0)) | |
+ assert(partKeys.length == 5) | |
+ assert(partKeys.toSet == Set(1, 2, 3, 4, 5)) | |
+ } | |
+ } | |
+ | |
+ test("time travel") { | |
+ sql("use testcat") | |
+ // The testing in-memory table simply appends the version/timestamp to the table name when | 
+ // looking up tables. | |
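+ // For example, "SELECT * FROM t VERSION AS OF 'Snapshot123456789'" below reads the table named "tSnapshot123456789". | 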
+ val t1 = "testcat.tSnapshot123456789" | |
+ val t2 = "testcat.t2345678910" | |
+ withTable(t1, t2) { | |
+ sql(s"CREATE TABLE $t1 (id int) USING foo") | |
+ sql(s"CREATE TABLE $t2 (id int) USING foo") | |
+ | |
+ sql(s"INSERT INTO $t1 VALUES (1)") | |
+ sql(s"INSERT INTO $t1 VALUES (2)") | |
+ sql(s"INSERT INTO $t2 VALUES (3)") | |
+ sql(s"INSERT INTO $t2 VALUES (4)") | |
+ | |
+ assert(sql("SELECT * FROM t VERSION AS OF 'Snapshot123456789'").collect | |
+ === Array(Row(1), Row(2))) | |
+ assert(sql("SELECT * FROM t VERSION AS OF 2345678910").collect | |
+ === Array(Row(3), Row(4))) | |
+ } | |
+ | |
+ val ts1 = DateTimeUtils.stringToTimestampAnsi( | |
+ UTF8String.fromString("2019-01-29 00:37:58"), | |
+ DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) | |
+ val ts2 = DateTimeUtils.stringToTimestampAnsi( | |
+ UTF8String.fromString("2021-01-29 00:00:00"), | |
+ DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) | |
+ val ts1InSeconds = MICROSECONDS.toSeconds(ts1).toString | |
+ val ts2InSeconds = MICROSECONDS.toSeconds(ts2).toString | |
+ val t3 = s"testcat.t$ts1" | |
+ val t4 = s"testcat.t$ts2" | |
+ | |
+ withTable(t3, t4) { | |
+ sql(s"CREATE TABLE $t3 (id int) USING foo") | |
+ sql(s"CREATE TABLE $t4 (id int) USING foo") | |
+ | |
+ sql(s"INSERT INTO $t3 VALUES (5)") | |
+ sql(s"INSERT INTO $t3 VALUES (6)") | |
+ sql(s"INSERT INTO $t4 VALUES (7)") | |
+ sql(s"INSERT INTO $t4 VALUES (8)") | |
+ | |
+ assert(sql("SELECT * FROM t TIMESTAMP AS OF '2019-01-29 00:37:58'").collect | |
+ === Array(Row(5), Row(6))) | |
+ assert(sql("SELECT * FROM t TIMESTAMP AS OF '2021-01-29 00:00:00'").collect | |
+ === Array(Row(7), Row(8))) | |
+ assert(sql(s"SELECT * FROM t TIMESTAMP AS OF $ts1InSeconds").collect | |
+ === Array(Row(5), Row(6))) | |
+ assert(sql(s"SELECT * FROM t TIMESTAMP AS OF $ts2InSeconds").collect | |
+ === Array(Row(7), Row(8))) | |
+ assert(sql(s"SELECT * FROM t FOR SYSTEM_TIME AS OF $ts1InSeconds").collect | |
+ === Array(Row(5), Row(6))) | |
+ assert(sql(s"SELECT * FROM t FOR SYSTEM_TIME AS OF $ts2InSeconds").collect | |
+ === Array(Row(7), Row(8))) | |
+ assert(sql("SELECT * FROM t TIMESTAMP AS OF make_date(2021, 1, 29)").collect | |
+ === Array(Row(7), Row(8))) | |
+ assert(sql("SELECT * FROM t TIMESTAMP AS OF to_timestamp('2021-01-29 00:00:00')").collect | |
+ === Array(Row(7), Row(8))) | |
- val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 WHERE id > 1") | |
- val dfQuery = spark.table(t1).where("id > 1").select("id", "data", "index", "_partition") | |
+ val e1 = intercept[AnalysisException]( | |
+ sql("SELECT * FROM t TIMESTAMP AS OF INTERVAL 1 DAY").collect() | |
+ ) | |
+ assert(e1.message.contains("is not a valid timestamp expression for time travel")) | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
- } | |
- } | |
- } | |
+ val e2 = intercept[AnalysisException]( | |
+ sql("SELECT * FROM t TIMESTAMP AS OF 'abc'").collect() | |
+ ) | |
+ assert(e2.message.contains("is not a valid timestamp expression for time travel")) | |
- test("SPARK-34923: propagate metadata columns through Sort") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
+ val e3 = intercept[AnalysisException]( | |
+ sql("SELECT * FROM t TIMESTAMP AS OF current_user()").collect() | |
+ ) | |
+ assert(e3.message.contains("is not a valid timestamp expression for time travel")) | |
- val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 ORDER BY id") | |
- val dfQuery = spark.table(t1).orderBy("id").select("id", "data", "index", "_partition") | |
+ val e4 = intercept[AnalysisException]( | |
+ sql("SELECT * FROM t TIMESTAMP AS OF CAST(rand() AS STRING)").collect() | |
+ ) | |
+ assert(e4.message.contains("is not a valid timestamp expression for time travel")) | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
- } | |
- } | |
- } | |
+ val e5 = intercept[AnalysisException]( | |
+ sql("SELECT * FROM t TIMESTAMP AS OF abs(true)").collect() | |
+ ) | |
+ assert(e5.message.contains("cannot resolve 'abs(true)' due to data type mismatch")) | |
- test("SPARK-34923: propagate metadata columns through RepartitionBy") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
- | |
- val sqlQuery = spark.sql( | |
- s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $t1") | |
- val tbl = spark.table(t1) | |
- val dfQuery = tbl.repartitionByRange(3, tbl.col("id")) | |
- .select("id", "data", "index", "_partition") | |
- | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
- } | |
+ val e6 = intercept[AnalysisException]( | |
+ sql("SELECT * FROM parquet.`/the/path` VERSION AS OF 1") | |
+ ) | |
+ assert(e6.message.contains("Cannot time travel path-based tables")) | |
+ | |
+ val e7 = intercept[AnalysisException]( | |
+ sql("WITH x AS (SELECT 1) SELECT * FROM x VERSION AS OF 1") | |
+ ) | |
+ assert(e7.message.contains("Cannot time travel subqueries from WITH clause")) | |
} | |
} | |
- test("SPARK-34923: propagate metadata columns through SubqueryAlias") { | |
- val t1 = s"${catalogAndNamespace}table" | |
- val sbq = "sbq" | |
- withTable(t1) { | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + | |
- "PARTITIONED BY (bucket(4, id), id)") | |
- sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
- | |
- val sqlQuery = spark.sql( | |
- s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $t1 as $sbq") | |
- val dfQuery = spark.table(t1).as(sbq).select( | |
- s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition") | |
+ test("SPARK-37827: put build-in properties into V1Table.properties to adapt v2 command") { | |
+ val t = "tbl" | |
+ withTable(t) { | |
+ sql( | |
+ s""" | |
+ |CREATE TABLE $t ( | |
+ | a bigint, | |
+ | b bigint | |
+ |) | |
+ |using parquet | |
+ |OPTIONS ( | |
+ | from = 0, | |
+ | to = 1) | |
+ |COMMENT 'This is a comment' | |
+ |TBLPROPERTIES ('prop1' = '1', 'prop2' = '2') | |
+ |PARTITIONED BY (a) | |
+ |LOCATION '/tmp' | |
+ """.stripMargin) | |
- Seq(sqlQuery, dfQuery).foreach { query => | |
- checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
- } | |
+ val table = spark.sessionState.catalogManager.v2SessionCatalog.asTableCatalog | |
+ .loadTable(Identifier.of(Array("default"), t)) | |
+ val properties = table.properties | |
+ assert(properties.get(TableCatalog.PROP_PROVIDER) == "parquet") | |
+ assert(properties.get(TableCatalog.PROP_COMMENT) == "This is a comment") | |
+ assert(properties.get(TableCatalog.PROP_LOCATION) == "file:///tmp") | |
+ assert(properties.containsKey(TableCatalog.PROP_OWNER)) | |
+ assert(properties.get(TableCatalog.PROP_EXTERNAL) == "true") | |
+ assert(properties.get(s"${TableCatalog.OPTION_PREFIX}from") == "0") | |
+ assert(properties.get(s"${TableCatalog.OPTION_PREFIX}to") == "1") | |
+ assert(properties.get("prop1") == "1") | |
+ assert(properties.get("prop2") == "2") | |
+ } | |
+ } | |
+ | |
+ test("SPARK-41154: Incorrect relation caching for queries with time travel spec") { | |
+ sql("use testcat") | |
+ val t1 = "testcat.t1" | |
+ val t2 = "testcat.t2" | |
+ withTable(t1, t2) { | |
+ sql(s"CREATE TABLE $t1 USING foo AS SELECT 1 as c") | |
+ sql(s"CREATE TABLE $t2 USING foo AS SELECT 2 as c") | |
+ assert( | |
+ sql(""" | |
+ |SELECT * FROM t VERSION AS OF '1' | |
+ |UNION ALL | |
+ |SELECT * FROM t VERSION AS OF '2' | |
+ |""".stripMargin | |
+ ).collect() === Array(Row(1), Row(2))) | |
} | |
} | |
@@ -2895,15 +2600,24 @@ class DataSourceV2SQLSuite | |
assert(e.message.contains(s"$sqlCommand is not supported for v2 tables")) | |
} | |
- private def assertAnalysisError(sqlStatement: String, expectedError: String): Unit = { | |
- val errMsg = intercept[AnalysisException] { | |
+ private def assertAnalysisError( | |
+ sqlStatement: String, | |
+ expectedError: String): Unit = { | |
+ val ex = intercept[AnalysisException] { | |
sql(sqlStatement) | |
- }.getMessage | |
- assert(errMsg.contains(expectedError)) | |
+ } | |
+ assert(ex.getMessage.contains(expectedError)) | |
} | |
- private def getShowCreateDDL(showCreateTableSql: String): Array[String] = { | |
- sql(showCreateTableSql).head().getString(0).split("\n").map(_.trim) | |
+ private def assertAnalysisErrorClass( | |
+ sqlStatement: String, | |
+ expectedErrorClass: String, | |
+ expectedErrorMessageParameters: Array[String]): Unit = { | |
+ val ex = intercept[AnalysisException] { | |
+ sql(sqlStatement) | |
+ } | |
+ assert(ex.getErrorClass == expectedErrorClass) | |
+ assert(ex.messageParameters.sameElements(expectedErrorMessageParameters)) | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala | |
index b42d48d873..491d27e546 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala | |
@@ -18,11 +18,8 @@ | |
package org.apache.spark.sql.connector | |
import java.io.File | |
-import java.util | |
import java.util.OptionalLong | |
-import scala.collection.JavaConverters._ | |
- | |
import test.org.apache.spark.sql.connector._ | |
import org.apache.spark.SparkException | |
@@ -30,14 +27,16 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} | |
import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} | |
import org.apache.spark.sql.connector.catalog.TableCapability._ | |
-import org.apache.spark.sql.connector.expressions.Transform | |
+import org.apache.spark.sql.connector.expressions.{FieldReference, Literal, Transform} | |
+import org.apache.spark.sql.connector.expressions.filter.Predicate | |
import org.apache.spark.sql.connector.read._ | |
-import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning} | |
+import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning} | |
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} | |
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} | |
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector | |
import org.apache.spark.sql.functions._ | |
+import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.sources.{Filter, GreaterThan} | |
import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.sql.types.{IntegerType, StructType} | |
@@ -54,6 +53,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
}.head | |
} | |
+ private def getBatchWithV2Filter(query: DataFrame): AdvancedBatchWithV2Filter = { | |
+ query.queryExecution.executedPlan.collect { | |
+ case d: BatchScanExec => | |
+ d.batch.asInstanceOf[AdvancedBatchWithV2Filter] | |
+ }.head | |
+ } | |
+ | |
private def getJavaBatch(query: DataFrame): JavaAdvancedDataSourceV2.AdvancedBatch = { | |
query.queryExecution.executedPlan.collect { | |
case d: BatchScanExec => | |
@@ -61,13 +67,21 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
}.head | |
} | |
+ private def getJavaBatchWithV2Filter( | |
+ query: DataFrame): JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter = { | |
+ query.queryExecution.executedPlan.collect { | |
+ case d: BatchScanExec => | |
+ d.batch.asInstanceOf[JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter] | |
+ }.head | |
+ } | |
+ | |
test("simplest implementation") { | |
Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => | |
withClue(cls.getName) { | |
val df = spark.read.format(cls.getName).load() | |
checkAnswer(df, (0 until 10).map(i => Row(i, -i))) | |
- checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) | |
- checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) | |
+ checkAnswer(df.select(Symbol("j")), (0 until 10).map(i => Row(-i))) | |
+ checkAnswer(df.filter(Symbol("i") > 5), (6 until 10).map(i => Row(i, -i))) | |
} | |
} | |
} | |
@@ -78,7 +92,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
val df = spark.read.format(cls.getName).load() | |
checkAnswer(df, (0 until 10).map(i => Row(i, -i))) | |
- val q1 = df.select('j) | |
+ val q1 = df.select(Symbol("j")) | |
checkAnswer(q1, (0 until 10).map(i => Row(-i))) | |
if (cls == classOf[AdvancedDataSourceV2]) { | |
val batch = getBatch(q1) | |
@@ -90,7 +104,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
assert(batch.requiredSchema.fieldNames === Seq("j")) | |
} | |
- val q2 = df.filter('i > 3) | |
+ val q2 = df.filter(Symbol("i") > 3) | |
checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) | |
if (cls == classOf[AdvancedDataSourceV2]) { | |
val batch = getBatch(q2) | |
@@ -102,7 +116,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
assert(batch.requiredSchema.fieldNames === Seq("i", "j")) | |
} | |
- val q3 = df.select('i).filter('i > 6) | |
+ val q3 = df.select(Symbol("i")).filter(Symbol("i") > 6) | |
checkAnswer(q3, (7 until 10).map(i => Row(i))) | |
if (cls == classOf[AdvancedDataSourceV2]) { | |
val batch = getBatch(q3) | |
@@ -114,16 +128,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
assert(batch.requiredSchema.fieldNames === Seq("i")) | |
} | |
- val q4 = df.select('j).filter('j < -10) | |
+ val q4 = df.select(Symbol("j")).filter(Symbol("j") < -10) | |
checkAnswer(q4, Nil) | |
if (cls == classOf[AdvancedDataSourceV2]) { | |
val batch = getBatch(q4) | |
- // 'j < 10 is not supported by the testing data source. | |
+ // Symbol("j") < 10 is not supported by the testing data source. | |
assert(batch.filters.isEmpty) | |
assert(batch.requiredSchema.fieldNames === Seq("j")) | |
} else { | |
val batch = getJavaBatch(q4) | |
- // 'j < 10 is not supported by the testing data source. | |
+ // Symbol("j") < 10 is not supported by the testing data source. | |
assert(batch.filters.isEmpty) | |
assert(batch.requiredSchema.fieldNames === Seq("j")) | |
} | |
@@ -131,13 +145,73 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
} | |
} | |
+ test("advanced implementation with V2 Filter") { | |
+ Seq(classOf[AdvancedDataSourceV2WithV2Filter], classOf[JavaAdvancedDataSourceV2WithV2Filter]) | |
+ .foreach { cls => | |
+ withClue(cls.getName) { | |
+ val df = spark.read.format(cls.getName).load() | |
+ checkAnswer(df, (0 until 10).map(i => Row(i, -i))) | |
+ | |
+ val q1 = df.select(Symbol("j")) | |
+ checkAnswer(q1, (0 until 10).map(i => Row(-i))) | |
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { | |
+ val batch = getBatchWithV2Filter(q1) | |
+ assert(batch.predicates.isEmpty) | |
+ assert(batch.requiredSchema.fieldNames === Seq("j")) | |
+ } else { | |
+ val batch = getJavaBatchWithV2Filter(q1) | |
+ assert(batch.predicates.isEmpty) | |
+ assert(batch.requiredSchema.fieldNames === Seq("j")) | |
+ } | |
+ | |
+ val q2 = df.filter(Symbol("i") > 3) | |
+ checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) | |
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { | |
+ val batch = getBatchWithV2Filter(q2) | |
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) | |
+ assert(batch.requiredSchema.fieldNames === Seq("i", "j")) | |
+ } else { | |
+ val batch = getJavaBatchWithV2Filter(q2) | |
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) | |
+ assert(batch.requiredSchema.fieldNames === Seq("i", "j")) | |
+ } | |
+ | |
+ val q3 = df.select(Symbol("i")).filter(Symbol("i") > 6) | |
+ checkAnswer(q3, (7 until 10).map(i => Row(i))) | |
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { | |
+ val batch = getBatchWithV2Filter(q3) | |
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) | |
+ assert(batch.requiredSchema.fieldNames === Seq("i")) | |
+ } else { | |
+ val batch = getJavaBatchWithV2Filter(q3) | |
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) | |
+ assert(batch.requiredSchema.fieldNames === Seq("i")) | |
+ } | |
+ | |
+ val q4 = df.select(Symbol("j")).filter(Symbol("j") < -10) | |
+ checkAnswer(q4, Nil) | |
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { | |
+ val batch = getBatchWithV2Filter(q4) | |
+ // Symbol("j") < 10 is not supported by the testing data source. | |
+ assert(batch.predicates.isEmpty) | |
+ assert(batch.requiredSchema.fieldNames === Seq("j")) | |
+ } else { | |
+ val batch = getJavaBatchWithV2Filter(q4) | |
+ // Symbol("j") < 10 is not supported by the testing data source. | |
+ assert(batch.predicates.isEmpty) | |
+ assert(batch.requiredSchema.fieldNames === Seq("j")) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
test("columnar batch scan implementation") { | |
Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => | |
withClue(cls.getName) { | |
val df = spark.read.format(cls.getName).load() | |
checkAnswer(df, (0 until 90).map(i => Row(i, -i))) | |
- checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) | |
- checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) | |
+ checkAnswer(df.select(Symbol("j")), (0 until 90).map(i => Row(-i))) | |
+ checkAnswer(df.filter(Symbol("i") > 50), (51 until 90).map(i => Row(i, -i))) | |
} | |
} | |
} | |
@@ -161,45 +235,47 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
"supports external metadata") { | |
withTempDir { dir => | |
val cls = classOf[SupportsExternalMetadataWritableDataSource].getName | |
- spark.range(10).select('id as 'i, -'id as 'j).write.format(cls) | |
- .option("path", dir.getCanonicalPath).mode("append").save() | |
+ spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) | |
+ .write.format(cls).option("path", dir.getCanonicalPath).mode("append").save() | |
val schema = new StructType().add("i", "long").add("j", "long") | |
checkAnswer( | |
spark.read.format(cls).option("path", dir.getCanonicalPath).schema(schema).load(), | |
- spark.range(10).select('id, -'id)) | |
+ spark.range(10).select(Symbol("id"), -Symbol("id"))) | |
} | |
} | |
test("partitioning reporting") { | |
import org.apache.spark.sql.functions.{count, sum} | |
- Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => | |
- withClue(cls.getName) { | |
- val df = spark.read.format(cls.getName).load() | |
- checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) | |
- | |
- val groupByColA = df.groupBy('i).agg(sum('j)) | |
- checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) | |
- assert(collectFirst(groupByColA.queryExecution.executedPlan) { | |
- case e: ShuffleExchangeExec => e | |
- }.isEmpty) | |
- | |
- val groupByColAB = df.groupBy('i, 'j).agg(count("*")) | |
- checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) | |
- assert(collectFirst(groupByColAB.queryExecution.executedPlan) { | |
- case e: ShuffleExchangeExec => e | |
- }.isEmpty) | |
- | |
- val groupByColB = df.groupBy('j).agg(sum('i)) | |
- checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) | |
- assert(collectFirst(groupByColB.queryExecution.executedPlan) { | |
- case e: ShuffleExchangeExec => e | |
- }.isDefined) | |
- | |
- val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) | |
- checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) | |
- assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) { | |
- case e: ShuffleExchangeExec => e | |
- }.isDefined) | |
+ withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "true") { | |
+ Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => | |
+ withClue(cls.getName) { | |
+ val df = spark.read.format(cls.getName).load() | |
+ checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) | |
+ | |
+ val groupByColA = df.groupBy(Symbol("i")).agg(sum(Symbol("j"))) | |
+ checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) | |
+ assert(collectFirst(groupByColA.queryExecution.executedPlan) { | |
+ case e: ShuffleExchangeExec => e | |
+ }.isEmpty) | |
+ | |
+ val groupByColAB = df.groupBy(Symbol("i"), Symbol("j")).agg(count("*")) | |
+ checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) | |
+ assert(collectFirst(groupByColAB.queryExecution.executedPlan) { | |
+ case e: ShuffleExchangeExec => e | |
+ }.isEmpty) | |
+ | |
+ val groupByColB = df.groupBy(Symbol("j")).agg(sum(Symbol("i"))) | |
+ checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) | |
+ assert(collectFirst(groupByColB.queryExecution.executedPlan) { | |
+ case e: ShuffleExchangeExec => e | |
+ }.isDefined) | |
+ | |
+ val groupByAPlusB = df.groupBy(Symbol("i") + Symbol("j")).agg(count("*")) | |
+ checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) | |
+ assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) { | |
+ case e: ShuffleExchangeExec => e | |
+ }.isDefined) | |
+ } | |
} | |
} | |
} | |
@@ -233,37 +309,43 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
val path = file.getCanonicalPath | |
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) | |
- spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) | |
+ spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) | |
+ .write.format(cls.getName) | |
.option("path", path).mode("append").save() | |
checkAnswer( | |
spark.read.format(cls.getName).option("path", path).load(), | |
- spark.range(10).select('id, -'id)) | |
+ spark.range(10).select(Symbol("id"), -Symbol("id"))) | |
// default save mode is ErrorIfExists | |
intercept[AnalysisException] { | |
- spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) | |
+ spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) | |
+ .write.format(cls.getName) | |
.option("path", path).save() | |
} | |
- spark.range(10).select('id as 'i, -'id as 'j).write.mode("append").format(cls.getName) | |
+ spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) | |
+ .write.mode("append").format(cls.getName) | |
.option("path", path).save() | |
checkAnswer( | |
spark.read.format(cls.getName).option("path", path).load(), | |
- spark.range(10).union(spark.range(10)).select('id, -'id)) | |
+ spark.range(10).union(spark.range(10)).select(Symbol("id"), -Symbol("id"))) | |
- spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) | |
+ spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) | |
+ .write.format(cls.getName) | |
.option("path", path).mode("overwrite").save() | |
checkAnswer( | |
spark.read.format(cls.getName).option("path", path).load(), | |
- spark.range(5).select('id, -'id)) | |
+ spark.range(5).select(Symbol("id"), -Symbol("id"))) | |
val e = intercept[AnalysisException] { | |
- spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) | |
+ spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) | |
+ .write.format(cls.getName) | |
.option("path", path).mode("ignore").save() | |
} | |
assert(e.message.contains("please use Append or Overwrite modes instead")) | |
val e2 = intercept[AnalysisException] { | |
- spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) | |
+ spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) | |
+ .write.format(cls.getName) | |
.option("path", path).mode("error").save() | |
} | |
assert(e2.getMessage.contains("please use Append or Overwrite modes instead")) | |
@@ -280,7 +362,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
} | |
} | |
// this input data will fail to read midway. | 
- val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j) | |
+ val input = spark.range(15).select(failingUdf(Symbol("id")).as(Symbol("i"))) | |
+ .select(Symbol("i"), -Symbol("i") as Symbol("j")) | |
val e3 = intercept[SparkException] { | |
input.write.format(cls.getName).option("path", path).mode("overwrite").save() | |
} | |
@@ -300,11 +383,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) | |
val numPartition = 6 | |
- spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName) | |
+ spark.range(0, 10, 1, numPartition) | |
+ .select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) | |
+ .write.format(cls.getName) | |
.mode("append").option("path", path).save() | |
checkAnswer( | |
spark.read.format(cls.getName).option("path", path).load(), | |
- spark.range(10).select('id, -'id)) | |
+ spark.range(10).select(Symbol("id"), -Symbol("id"))) | |
assert(SimpleCounter.getCounter == numPartition, | |
"method onDataWriterCommit should be called as many as the number of partitions") | |
@@ -321,7 +406,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
test("SPARK-23301: column pruning with arbitrary expressions") { | |
val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() | |
- val q1 = df.select('i + 1) | |
+ val q1 = df.select(Symbol("i") + 1) | |
checkAnswer(q1, (1 until 11).map(i => Row(i))) | |
val batch1 = getBatch(q1) | |
assert(batch1.requiredSchema.fieldNames === Seq("i")) | |
@@ -332,14 +417,14 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
assert(batch2.requiredSchema.isEmpty) | |
// 'j === -1 can't be pushed down, but we should still be able to do column pruning | 
- val q3 = df.filter('j === -1).select('j * 2) | |
+ val q3 = df.filter(Symbol("j") === -1).select(Symbol("j") * 2) | |
checkAnswer(q3, Row(-2)) | |
val batch3 = getBatch(q3) | |
assert(batch3.filters.isEmpty) | |
assert(batch3.requiredSchema.fieldNames === Seq("j")) | |
// column pruning should work with other operators. | |
- val q4 = df.sort('i).limit(1).select('i + 1) | |
+ val q4 = df.sort(Symbol("i")).limit(1).select(Symbol("i") + 1) | |
checkAnswer(q4, Row(1)) | |
val batch4 = getBatch(q4) | |
assert(batch4.requiredSchema.fieldNames === Seq("i")) | |
@@ -361,7 +446,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() | |
checkCanonicalizedOutput(df, 2, 2) | |
- checkCanonicalizedOutput(df.select('i), 2, 1) | |
+ checkCanonicalizedOutput(df.select(Symbol("i")), 2, 1) | |
} | |
test("SPARK-25425: extra options should override sessions options during reading") { | |
@@ -400,7 +485,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
withTempView("t1") { | |
val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load() | |
Seq(2, 3).toDF("a").createTempView("t1") | |
- val df = t2.where("i < (select max(a) from t1)").select('i) | |
+ val df = t2.where("i < (select max(a) from t1)").select(Symbol("i")) | |
val subqueries = stripAQEPlan(df.queryExecution.executedPlan).collect { | |
case p => p.subqueries | |
}.flatten | |
@@ -419,8 +504,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => | |
withClue(cls.getName) { | |
val df = spark.read.format(cls.getName).load() | |
- val q1 = df.select('i).filter('i > 6) | |
- val q2 = df.select('i).filter('i > 5) | |
+ val q1 = df.select(Symbol("i")).filter(Symbol("i") > 6) | |
+ val q2 = df.select(Symbol("i")).filter(Symbol("i") > 5) | |
val scan1 = getScanExec(q1) | |
val scan2 = getScanExec(q2) | |
assert(!scan1.equals(scan2)) | |
@@ -433,7 +518,19 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
withClue(cls.getName) { | |
val df = spark.read.format(cls.getName).load() | |
// before SPARK-33267, the query below just threw an NPE | 
- df.select('i).where("i in (1, null)").collect() | |
+ df.select(Symbol("i")).where("i in (1, null)").collect() | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-35803: Support datasorce V2 in CREATE VIEW USING") { | |
+ Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => | |
+ withClue(cls.getName) { | |
+ sql(s"CREATE or REPLACE TEMPORARY VIEW s1 USING ${cls.getName}") | |
+ checkAnswer(sql("select * from s1"), (0 until 10).map(i => Row(i, -i))) | |
+ checkAnswer(sql("select j from s1"), (0 until 10).map(i => Row(-i))) | |
+ checkAnswer(sql("select * from s1 where i > 5"), | |
+ (6 until 10).map(i => Row(i, -i))) | |
} | |
} | |
} | |
@@ -466,7 +563,7 @@ abstract class SimpleBatchTable extends Table with SupportsRead { | |
override def name(): String = this.getClass.toString | |
- override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava | |
+ override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ) | |
} | |
abstract class SimpleScanBuilder extends ScanBuilder | |
@@ -489,7 +586,7 @@ trait TestingV2Source extends TableProvider { | |
override def getTable( | |
schema: StructType, | |
partitioning: Array[Transform], | |
- properties: util.Map[String, String]): Table = { | |
+ properties: java.util.Map[String, String]): Table = { | |
getTable(new CaseInsensitiveStringMap(properties)) | |
} | |
@@ -597,6 +694,75 @@ class AdvancedBatch(val filters: Array[Filter], val requiredSchema: StructType) | |
} | |
} | |
+class AdvancedDataSourceV2WithV2Filter extends TestingV2Source { | |
+ | |
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { | |
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { | |
+ new AdvancedScanBuilderWithV2Filter() | |
+ } | |
+ } | |
+} | |
+ | |
+class AdvancedScanBuilderWithV2Filter extends ScanBuilder | |
+ with Scan with SupportsPushDownV2Filters with SupportsPushDownRequiredColumns { | |
+ | |
+ var requiredSchema = TestingV2Source.schema | |
+ var predicates = Array.empty[Predicate] | |
+ | |
+ override def pruneColumns(requiredSchema: StructType): Unit = { | |
+ this.requiredSchema = requiredSchema | |
+ } | |
+ | |
+ override def readSchema(): StructType = requiredSchema | |
+ | |
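+ // Only ">" predicates are accepted for pushdown; anything else is returned as unsupported | 
+ // so that Spark evaluates it after the scan. | 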
+ override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { | |
+ val (supported, unsupported) = predicates.partition { | |
+ case p: Predicate if p.name() == ">" => true | |
+ case _ => false | |
+ } | |
+ this.predicates = supported | |
+ unsupported | |
+ } | |
+ | |
+ override def pushedPredicates(): Array[Predicate] = predicates | |
+ | |
+ override def build(): Scan = this | |
+ | |
+ override def toBatch: Batch = new AdvancedBatchWithV2Filter(predicates, requiredSchema) | |
+} | |
+ | |
+class AdvancedBatchWithV2Filter( | |
+ val predicates: Array[Predicate], | |
+ val requiredSchema: StructType) extends Batch { | |
+ | |
+ override def planInputPartitions(): Array[InputPartition] = { | |
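+ // Plan partitions over the 0-10 range, trimmed by the literal lower bound of a pushed ">" predicate, if any. | 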
+ val lowerBound = predicates.collectFirst { | |
+ case p: Predicate if p.name().equals(">") => | |
+ val value = p.children()(1) | |
+ assert(value.isInstanceOf[Literal[_]]) | |
+ value.asInstanceOf[Literal[_]] | |
+ } | |
+ | |
+ val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] | |
+ | |
+ if (lowerBound.isEmpty) { | |
+ res.append(RangeInputPartition(0, 5)) | |
+ res.append(RangeInputPartition(5, 10)) | |
+ } else if (lowerBound.get.value.asInstanceOf[Integer] < 4) { | |
+ res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 5)) | |
+ res.append(RangeInputPartition(5, 10)) | |
+ } else if (lowerBound.get.value.asInstanceOf[Integer] < 9) { | |
+ res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 10)) | |
+ } | |
+ | |
+ res.toArray | |
+ } | |
+ | |
+ override def createReaderFactory(): PartitionReaderFactory = { | |
+ new AdvancedReaderFactory(requiredSchema) | |
+ } | |
+} | |
+ | |
class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { | |
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { | |
@@ -640,7 +806,7 @@ class SchemaRequiredDataSource extends TableProvider { | |
override def getTable( | |
schema: StructType, | |
partitioning: Array[Transform], | |
- properties: util.Map[String, String]): Table = { | |
+ properties: java.util.Map[String, String]): Table = { | |
val userGivenSchema = schema | |
new SimpleBatchTable { | |
override def schema(): StructType = userGivenSchema | |
@@ -733,7 +899,8 @@ class PartitionAwareDataSource extends TestingV2Source { | |
SpecificReaderFactory | |
} | |
- override def outputPartitioning(): Partitioning = new MyPartitioning | |
+ override def outputPartitioning(): Partitioning = | |
+ new KeyGroupedPartitioning(Array(FieldReference("i")), 2) | |
} | |
override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { | |
@@ -741,18 +908,13 @@ class PartitionAwareDataSource extends TestingV2Source { | |
new MyScanBuilder() | |
} | |
} | |
- | |
- class MyPartitioning extends Partitioning { | |
- override def numPartitions(): Int = 2 | |
- | |
- override def satisfy(distribution: Distribution): Boolean = distribution match { | |
- case c: ClusteredDistribution => c.clusteredColumns.contains("i") | |
- case _ => false | |
- } | |
- } | |
} | |
-case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition | |
+case class SpecificInputPartition( | |
+ i: Array[Int], | |
+ j: Array[Int]) extends InputPartition with HasPartitionKey { | |
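+ // Report the first "i" value as the partition key, matching the KeyGroupedPartitioning on column "i" above. | 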
+ override def partitionKey(): InternalRow = InternalRow.fromSeq(Seq(i(0))) | |
+} | |
object SpecificReaderFactory extends PartitionReaderFactory { | |
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala | |
new file mode 100644 | |
index 0000000000..a2cfdde267 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala | |
@@ -0,0 +1,629 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.connector | |
+ | |
+import java.util.Collections | |
+ | |
+import org.scalatest.BeforeAndAfter | |
+ | |
+import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row} | |
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog} | |
+import org.apache.spark.sql.connector.expressions.LogicalExpressions._ | |
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} | |
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
+import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec} | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+import org.apache.spark.sql.types.StructType | |
+import org.apache.spark.sql.util.QueryExecutionListener | |
+ | |
+abstract class DeleteFromTableSuiteBase | |
+ extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { | |
+ | |
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | |
+ import testImplicits._ | |
+ | |
+ before { | |
+ spark.conf.set("spark.sql.catalog.cat", classOf[InMemoryRowLevelOperationTableCatalog].getName) | |
+ } | |
+ | |
+ after { | |
+ spark.sessionState.catalogManager.reset() | |
+ spark.sessionState.conf.unsetConf("spark.sql.catalog.cat") | |
+ } | |
+ | |
+ private val namespace = Array("ns1") | |
+ private val ident = Identifier.of(namespace, "test_table") | |
+ private val tableNameAsString = "cat." + ident.toString | |
+ | |
+ private def catalog: InMemoryRowLevelOperationTableCatalog = { | |
+ val catalog = spark.sessionState.catalogManager.catalog("cat") | |
+ catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog] | |
+ } | |
+ | |
+ test("EXPLAIN only delete") { | |
+ createAndInitTable("id INT, dep STRING", """{ "id": 1, "dep": "hr" }""") | |
+ | |
+ sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE id <= 10") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(1, "hr") :: Nil) | |
+ } | |
+ | |
+ test("delete from empty tables") { | |
+ createTable("id INT, dep STRING") | |
+ | |
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1") | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) | |
+ } | |
+ | |
+ test("delete with basic filters") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "software" } | |
+ |{ "id": 3, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "software") :: Row(3, "hr") :: Nil) | |
+ } | |
+ | |
+ test("delete with aliases") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "software" } | |
+ |{ "id": 3, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ sql(s"DELETE FROM $tableNameAsString AS t WHERE t.id <= 1 OR t.dep = 'hr'") | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "software") :: Nil) | |
+ } | |
+ | |
+ test("delete with IN predicates") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "software" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ sql(s"DELETE FROM $tableNameAsString WHERE id IN (1, null)") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "software") :: Row(null, "hr") :: Nil) | |
+ } | |
+ | |
+ test("delete with NOT IN predicates") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "software" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (null, 1)") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(1, "hr") :: Row(2, "software") :: Row(null, "hr") :: Nil) | |
+ | |
+ sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (1, 10)") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(1, "hr") :: Row(null, "hr") :: Nil) | |
+ } | |
+ | |
+ test("delete with conditions on nested columns") { | |
+ createAndInitTable("id INT, complex STRUCT<c1:INT,c2:STRING>, dep STRING", | |
+ """{ "id": 1, "complex": { "c1": 3, "c2": "v1" }, "dep": "hr" } | |
+ |{ "id": 2, "complex": { "c1": 2, "c2": "v2" }, "dep": "software" } | |
+ |""".stripMargin) | |
+ | |
+ sql(s"DELETE FROM $tableNameAsString WHERE complex.c1 = id + 2") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, Row(2, "v2"), "software") :: Nil) | |
+ | |
+ sql(s"DELETE FROM $tableNameAsString t WHERE t.complex.c1 = id") | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) | |
+ } | |
+ | |
+ test("delete with IN subqueries") { | |
+ withTempView("deleted_id", "deleted_dep") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "hardware" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ val deletedIdDF = Seq(Some(0), Some(1), None).toDF() | |
+ deletedIdDF.createOrReplaceTempView("deleted_id") | |
+ | |
+ val deletedDepDF = Seq("software", "hr").toDF() | |
+ deletedDepDF.createOrReplaceTempView("deleted_dep") | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString | |
+ |WHERE | |
+ | id IN (SELECT * FROM deleted_id) | |
+ | AND | |
+ | dep IN (SELECT * FROM deleted_dep) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "hardware") :: Row(null, "hr") :: Nil) | |
+ | |
+ append("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": -1, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(-1, "hr") :: Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString | |
+ |WHERE | |
+ | id IS NULL | |
+ | OR | |
+ | id IN (SELECT value + 2 FROM deleted_id) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(-1, "hr") :: Row(1, "hr") :: Nil) | |
+ | |
+ append("id INT, dep STRING", | |
+ """{ "id": null, "dep": "hr" } | |
+ |{ "id": 2, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(-1, "hr") :: Row(1, "hr") :: Row(2, "hr") :: Row(null, "hr") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString | |
+ |WHERE | |
+ | id IN (SELECT value + 2 FROM deleted_id) | |
+ | AND | |
+ | dep = 'hr' | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(-1, "hr") :: Row(1, "hr") :: Row(null, "hr") :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("delete with multi-column IN subqueries") { | |
+ withTempView("deleted_employee") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "hardware" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ val deletedEmployeeDF = Seq((None, "hr"), (Some(1), "hr")).toDF() | |
+ deletedEmployeeDF.createOrReplaceTempView("deleted_employee") | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString | |
+ |WHERE | |
+ | (id, dep) IN (SELECT * FROM deleted_employee) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "hardware") :: Row(null, "hr") :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("delete with NOT IN subqueries") { | |
+ withTempView("deleted_id", "deleted_dep") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "hardware" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF() | |
+ deletedIdDF.createOrReplaceTempView("deleted_id") | |
+ | |
+ val deletedDepDF = Seq("software", "hr").toDF() | |
+ deletedDepDF.createOrReplaceTempView("deleted_dep") | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString | |
+ |WHERE | |
+ | id NOT IN (SELECT * FROM deleted_id) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString | |
+ |WHERE | |
+ | id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(null, "hr") :: Nil) | |
+ | |
+ append("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "hardware" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Row(null, "hr") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString | |
+ |WHERE | |
+ | id NOT IN (SELECT * FROM deleted_id) | |
+ | OR | |
+ | dep IN ('software', 'hr') | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "hardware") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString | |
+ |WHERE | |
+ | id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) | |
+ | AND | |
+ | EXISTS (SELECT 1 FROM deleted_dep WHERE dep = deleted_dep.value) | 
+ |""".stripMargin) | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "hardware") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | t.id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) | |
+ | OR | |
+ | EXISTS (SELECT 1 FROM deleted_dep WHERE t.dep = deleted_dep.value) | 
+ |""".stripMargin) | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) | |
+ } | |
+ } | |
+ | |
+ test("delete with EXISTS subquery") { | |
+ withTempView("deleted_id", "deleted_dep") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "hardware" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF() | |
+ deletedIdDF.createOrReplaceTempView("deleted_id") | |
+ | |
+ val deletedDepDF = Seq("software", "hr").toDF() | |
+ deletedDepDF.createOrReplaceTempView("deleted_dep") | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "hardware") :: Row(null, "hr") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "hardware") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value) | |
+ | AND | |
+ | EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "hardware") :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("delete with NOT EXISTS subquery") { | |
+ withTempView("deleted_id", "deleted_dep") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "hardware" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF() | |
+ deletedIdDF.createOrReplaceTempView("deleted_id") | |
+ | |
+ val deletedDepDF = Seq("software", "hr").toDF() | |
+ deletedDepDF.createOrReplaceTempView("deleted_dep") | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2) | |
+ | AND | |
+ | NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(1, "hr") :: Row(null, "hr") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(1, "hr") :: Nil) | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2) | |
+ | OR | |
+ | t.id = 1 | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) | |
+ } | |
+ } | |
+ | |
+ test("delete with a scalar subquery") { | |
+ withTempView("deleted_id") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "hardware" } | |
+ |{ "id": null, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ val deletedIdDF = Seq(Some(1), Some(100), None).toDF() | |
+ deletedIdDF.createOrReplaceTempView("deleted_id") | |
+ | |
+ sql( | |
+ s"""DELETE FROM $tableNameAsString t | |
+ |WHERE | |
+ | id <= (SELECT min(value) FROM deleted_id) | |
+ |""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "hardware") :: Row(null, "hr") :: Nil) | |
+ } | |
+ } | |
+ | |
+ test("delete refreshes relation cache") { | |
+ withTempView("temp") { | |
+ withCache("temp") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 1, "dep": "hardware" } | |
+ |{ "id": 2, "dep": "hardware" } | |
+ |{ "id": 3, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ // define a view on top of the table | |
+ val query = sql(s"SELECT * FROM $tableNameAsString WHERE id = 1") | |
+ query.createOrReplaceTempView("temp") | |
+ | |
+ // cache the view | |
+ sql("CACHE TABLE temp") | |
+ | |
+ // verify the view returns expected results | |
+ checkAnswer( | |
+ sql("SELECT * FROM temp"), | |
+ Row(1, "hr") :: Row(1, "hardware") :: Nil) | |
+ | |
+ // delete some records from the table | |
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1") | |
+ | |
+ // verify the delete was successful | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, "hardware") :: Row(3, "hr") :: Nil) | |
+ | |
+ // verify the view reflects the changes in the table | |
+ checkAnswer(sql("SELECT * FROM temp"), Nil) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("delete with nondeterministic conditions") { | |
+ createAndInitTable("id INT, dep STRING", | |
+ """{ "id": 1, "dep": "hr" } | |
+ |{ "id": 2, "dep": "software" } | |
+ |{ "id": 3, "dep": "hr" } | |
+ |""".stripMargin) | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ sql(s"DELETE FROM $tableNameAsString WHERE id <= 1 AND rand() > 0.5") | |
+ } | |
+ assert(e.message.contains("nondeterministic expressions are only allowed")) | |
+ } | |
+ | |
+ test("delete without condition executed as delete with filters") { | |
+ createAndInitTable("id INT, dep INT", | |
+ """{ "id": 1, "dep": 100 } | |
+ |{ "id": 2, "dep": 200 } | |
+ |{ "id": 3, "dep": 100 } | |
+ |""".stripMargin) | |
+ | |
+ executeDeleteWithFilters(s"DELETE FROM $tableNameAsString") | |
+ | |
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) | |
+ } | |
+ | |
+ test("delete with supported predicates gets converted into delete with filters") { | |
+ createAndInitTable("id INT, dep INT", | |
+ """{ "id": 1, "dep": 100 } | |
+ |{ "id": 2, "dep": 200 } | |
+ |{ "id": 3, "dep": 100 } | |
+ |""".stripMargin) | |
+ | |
+ executeDeleteWithFilters(s"DELETE FROM $tableNameAsString WHERE dep = 100") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, 200) :: Nil) | |
+ } | |
+ | |
+ test("delete with unsupported predicates cannot be converted into delete with filters") { | |
+ createAndInitTable("id INT, dep INT", | |
+ """{ "id": 1, "dep": 100 } | |
+ |{ "id": 2, "dep": 200 } | |
+ |{ "id": 3, "dep": 100 } | |
+ |""".stripMargin) | |
+ | |
+ executeDeleteWithRewrite(s"DELETE FROM $tableNameAsString WHERE dep = 100 OR dep < 200") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, 200) :: Nil) | |
+ } | |
+ | |
+ test("delete with subquery cannot be converted into delete with filters") { | |
+ withTempView("deleted_id") { | |
+ createAndInitTable("id INT, dep INT", | |
+ """{ "id": 1, "dep": 100 } | |
+ |{ "id": 2, "dep": 200 } | |
+ |{ "id": 3, "dep": 100 } | |
+ |""".stripMargin) | |
+ | |
+ val deletedIdDF = Seq(Some(1), Some(100), None).toDF() | |
+ deletedIdDF.createOrReplaceTempView("deleted_id") | |
+ | |
+ val q = s"DELETE FROM $tableNameAsString WHERE dep = 100 AND id IN (SELECT * FROM deleted_id)" | |
+ executeDeleteWithRewrite(q) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT * FROM $tableNameAsString"), | |
+ Row(2, 200) :: Row(3, 100) :: Nil) | |
+ } | |
+ } | |
+ | |
+ private def createTable(schemaString: String): Unit = { | |
+ val schema = StructType.fromDDL(schemaString) | |
+ val tableProps = Collections.emptyMap[String, String] | |
+ catalog.createTable(ident, schema, Array(identity(reference(Seq("dep")))), tableProps) | |
+ } | |
+ | |
+ private def createAndInitTable(schemaString: String, jsonData: String): Unit = { | |
+ createTable(schemaString) | |
+ append(schemaString, jsonData) | |
+ } | |
+ | |
+ private def append(schemaString: String, jsonData: String): Unit = { | |
+ val df = toDF(jsonData, schemaString) | |
+ df.coalesce(1).writeTo(tableNameAsString).append() | |
+ } | |
+ | |
+ private def toDF(jsonData: String, schemaString: String = null): DataFrame = { | |
+ val jsonRows = jsonData.split("\\n").filter(str => str.trim.nonEmpty) | |
+ val jsonDS = spark.createDataset(jsonRows)(Encoders.STRING) | |
+ if (schemaString == null) { | |
+ spark.read.json(jsonDS) | |
+ } else { | |
+ spark.read.schema(schemaString).json(jsonDS) | |
+ } | |
+ } | |
+ | |
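+  // verifies the DELETE is executed as a simple delete with filters (DeleteFromTableExec) | 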
+ private def executeDeleteWithFilters(query: String): Unit = { | |
+ val executedPlan = executeAndKeepPlan { | |
+ sql(query) | |
+ } | |
+ | |
+ executedPlan match { | |
+ case _: DeleteFromTableExec => | |
+ // OK | |
+ case other => | |
+ fail("unexpected executed plan: " + other) | |
+ } | |
+ } | |
+ | |
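+  // verifies the DELETE falls back to a group-based rewrite that replaces data (ReplaceDataExec) | 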
+ private def executeDeleteWithRewrite(query: String): Unit = { | |
+ val executedPlan = executeAndKeepPlan { | |
+ sql(query) | |
+ } | |
+ | |
+ executedPlan match { | |
+ case _: ReplaceDataExec => | |
+ // OK | |
+ case other => | |
+ fail("unexpected executed plan: " + other) | |
+ } | |
+ } | |
+ | |
+ // executes an operation and keeps the executed plan | |
+ private def executeAndKeepPlan(func: => Unit): SparkPlan = { | |
+ var executedPlan: SparkPlan = null | |
+ | |
+ val listener = new QueryExecutionListener { | |
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { | |
+ executedPlan = qe.executedPlan | |
+ } | |
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { | |
+ } | |
+ } | |
+ spark.listenerManager.register(listener) | |
+ | |
+ func | |
+ | |
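+    // wait for all pending listener events to be processed so the captured plan is available | 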
+ sparkContext.listenerBus.waitUntilEmpty() | |
+ | |
+ stripAQEPlan(executedPlan) | |
+ } | |
+} | |
+ | |
+class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala | |
new file mode 100644 | |
index 0000000000..f4317e6327 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala | |
@@ -0,0 +1,103 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+package org.apache.spark.sql.connector | |
+ | |
+import org.scalatest.BeforeAndAfter | |
+ | |
+import org.apache.spark.sql.QueryTest | |
+import org.apache.spark.sql.catalyst | |
+import org.apache.spark.sql.catalyst.analysis.Resolver | |
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute | |
+import org.apache.spark.sql.catalyst.expressions.SortOrder | |
+import org.apache.spark.sql.catalyst.plans.QueryPlan | |
+import org.apache.spark.sql.catalyst.plans.physical | |
+import org.apache.spark.sql.catalyst.plans.physical._ | |
+import org.apache.spark.sql.connector.catalog.InMemoryCatalog | |
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+ | |
+abstract class DistributionAndOrderingSuiteBase | |
+ extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { | |
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | |
+ | |
+ override def beforeAll(): Unit = { | |
+ super.beforeAll() | |
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName) | |
+ } | |
+ | |
+ override def afterAll(): Unit = { | |
+ spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") | |
+ super.afterAll() | |
+ } | |
+ | |
+ protected val resolver: Resolver = conf.resolver | |
+ | |
+ protected def resolvePartitioning[T <: QueryPlan[T]]( | |
+ partitioning: Partitioning, | |
+ plan: QueryPlan[T]): Partitioning = partitioning match { | |
+ case HashPartitioning(exprs, numPartitions) => | |
+ HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) | |
+ case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) => | |
+ KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, | |
+ partitionValues) | |
+ case PartitioningCollection(partitionings) => | |
+ PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) | |
+ case RangePartitioning(ordering, numPartitions) => | |
+ RangePartitioning(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder]), numPartitions) | |
+ case p @ SinglePartition => | |
+ p | |
+ case p: UnknownPartitioning => | |
+ p | |
+ case p => | |
+ fail(s"unexpected partitioning: $p") | |
+ } | |
+ | |
+ protected def resolveDistribution[T <: QueryPlan[T]]( | |
+ distribution: physical.Distribution, | |
+ plan: QueryPlan[T]): physical.Distribution = distribution match { | |
+ case physical.ClusteredDistribution(clustering, numPartitions, _) => | |
+ physical.ClusteredDistribution(clustering.map(resolveAttrs(_, plan)), numPartitions) | |
+ case physical.OrderedDistribution(ordering) => | |
+ physical.OrderedDistribution(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder])) | |
+ case physical.UnspecifiedDistribution => | |
+ physical.UnspecifiedDistribution | |
+ case d => | |
+ fail(s"unexpected distribution: $d") | |
+ } | |
+ | |
+ protected def resolveAttrs[T <: QueryPlan[T]]( | |
+ expr: catalyst.expressions.Expression, | |
+ plan: QueryPlan[T]): catalyst.expressions.Expression = { | |
+ | |
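+    // resolve top-level attribute names against the plan output; nested attributes are rejected | 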
+ expr.transform { | |
+ case UnresolvedAttribute(Seq(attrName)) => | |
+ plan.output.find(attr => resolver(attr.name, attrName)).get | |
+ case UnresolvedAttribute(nameParts) => | |
+ val attrName = nameParts.mkString(".") | |
+ fail(s"cannot resolve a nested attr: $attrName") | |
+ } | |
+ } | |
+ | |
+ protected def attr(name: String): UnresolvedAttribute = { | |
+ UnresolvedAttribute(name) | |
+ } | |
+ | |
+ protected def catalog: InMemoryCatalog = { | |
+ val catalog = spark.sessionState.catalogManager.catalog("testcat") | |
+ catalog.asTableCatalog.asInstanceOf[InMemoryCatalog] | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala | |
index be0dae2563..cfc8b2cc84 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala | |
@@ -16,7 +16,6 @@ | |
*/ | |
package org.apache.spark.sql.connector | |
-import scala.collection.JavaConverters._ | |
import scala.collection.mutable.ArrayBuffer | |
import org.apache.spark.SparkConf | |
@@ -56,7 +55,7 @@ class DummyReadOnlyFileTable extends Table with SupportsRead { | |
} | |
override def capabilities(): java.util.Set[TableCapability] = | |
- Set(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA).asJava | |
+ java.util.EnumSet.of(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA) | |
} | |
class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { | |
@@ -79,7 +78,7 @@ class DummyWriteOnlyFileTable extends Table with SupportsWrite { | |
throw new AnalysisException("Dummy file writer") | |
override def capabilities(): java.util.Set[TableCapability] = | |
- Set(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava | |
+ java.util.EnumSet.of(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA) | |
} | |
class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { | |
@@ -185,7 +184,7 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { | |
val df = spark.read.format(format).load(path.getCanonicalPath) | |
checkAnswer(df, inputData.toDF()) | |
assert( | |
- df.queryExecution.executedPlan.find(_.isInstanceOf[FileSourceScanExec]).isDefined) | |
+ df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec])) | |
} | |
} finally { | |
spark.listenerManager.unregister(listener) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala | |
index 0dee48fbb5..85904bbf12 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala | |
@@ -259,7 +259,7 @@ trait InsertIntoSQLOnlyTests | |
verifyTable(t1, spark.emptyDataFrame) | |
assert(exc.getMessage.contains( | |
- "PARTITION clause cannot contain a non-partition column name")) | |
+ "PARTITION clause cannot contain the non-partition column")) | |
assert(exc.getMessage.contains("id")) | |
assert(exc.getErrorClass == "NON_PARTITION_COLUMN") | |
} | |
@@ -276,28 +276,12 @@ trait InsertIntoSQLOnlyTests | |
verifyTable(t1, spark.emptyDataFrame) | |
assert(exc.getMessage.contains( | |
- "PARTITION clause cannot contain a non-partition column name")) | |
+ "PARTITION clause cannot contain the non-partition column")) | |
assert(exc.getMessage.contains("data")) | |
assert(exc.getErrorClass == "NON_PARTITION_COLUMN") | |
} | |
} | |
- test("InsertInto: IF PARTITION NOT EXISTS not supported") { | |
- val t1 = s"${catalogAndNamespace}tbl" | |
- withTableAndData(t1) { view => | |
- sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") | |
- | |
- val exc = intercept[AnalysisException] { | |
- sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 1) IF NOT EXISTS SELECT * FROM $view") | |
- } | |
- | |
- verifyTable(t1, spark.emptyDataFrame) | |
- assert(exc.getMessage.contains("Cannot write, IF NOT EXISTS is not supported for table")) | |
- assert(exc.getMessage.contains(t1)) | |
- assert(exc.getErrorClass == "IF_PARTITION_NOT_EXISTS_UNSUPPORTED") | |
- } | |
- } | |
- | |
test("InsertInto: overwrite - dynamic clause - static mode") { | |
withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { | |
val t1 = s"${catalogAndNamespace}tbl" | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala | |
new file mode 100644 | |
index 0000000000..bdbf309214 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala | |
@@ -0,0 +1,425 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+package org.apache.spark.sql.connector | |
+ | |
+import java.util.Collections | |
+ | |
+import org.apache.spark.sql.{DataFrame, Row} | |
+import org.apache.spark.sql.catalyst.InternalRow | |
+import org.apache.spark.sql.catalyst.expressions.TransformExpression | |
+import org.apache.spark.sql.catalyst.plans.physical | |
+import org.apache.spark.sql.connector.catalog.Identifier | |
+import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog | |
+import org.apache.spark.sql.connector.catalog.functions._ | |
+import org.apache.spark.sql.connector.distributions.Distributions | |
+import org.apache.spark.sql.connector.expressions._ | |
+import org.apache.spark.sql.connector.expressions.Expressions._ | |
+import org.apache.spark.sql.execution.SparkPlan | |
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec | |
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation | |
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec | |
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec | |
+import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.sql.internal.SQLConf._ | |
+import org.apache.spark.sql.types._ | |
+ | |
+class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { | |
+ private var originalV2BucketingEnabled: Boolean = false | |
+ private var originalAutoBroadcastJoinThreshold: Long = -1 | |
+ | |
+ override def beforeAll(): Unit = { | |
+ super.beforeAll() | |
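+    // enable storage-partitioned joins (v2 bucketing) and disable auto broadcast joins | 
+    // so that sort-merge joins are planned and shuffle behavior can be observed | 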
+ originalV2BucketingEnabled = conf.getConf(V2_BUCKETING_ENABLED) | |
+ conf.setConf(V2_BUCKETING_ENABLED, true) | |
+ originalAutoBroadcastJoinThreshold = conf.getConf(AUTO_BROADCASTJOIN_THRESHOLD) | |
+ conf.setConf(AUTO_BROADCASTJOIN_THRESHOLD, -1L) | |
+ } | |
+ | |
+ override def afterAll(): Unit = { | |
+ try { | |
+ super.afterAll() | |
+ } finally { | |
+ conf.setConf(V2_BUCKETING_ENABLED, originalV2BucketingEnabled) | |
+ conf.setConf(AUTO_BROADCASTJOIN_THRESHOLD, originalAutoBroadcastJoinThreshold) | |
+ } | |
+ } | |
+ | |
+ before { | |
+ Seq(UnboundYearsFunction, UnboundDaysFunction, UnboundBucketFunction).foreach { f => | |
+ catalog.createFunction(Identifier.of(Array.empty, f.name()), f) | |
+ } | |
+ } | |
+ | |
+ after { | |
+ catalog.clearTables() | |
+ catalog.clearFunctions() | |
+ } | |
+ | |
+ private val emptyProps: java.util.Map[String, String] = { | |
+ Collections.emptyMap[String, String] | |
+ } | |
+ private val table: String = "tbl" | |
+ private val schema = new StructType() | |
+ .add("id", IntegerType) | |
+ .add("data", StringType) | |
+ .add("ts", TimestampType) | |
+ | |
+ test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { | |
+ val partitions: Array[Transform] = Array(Expressions.years("ts")) | |
+ | |
+    // create a table with 3 partitions, partitioned by the `years` transform | 
+ createTable(table, schema, partitions) | |
+ sql(s"INSERT INTO testcat.ns.$table VALUES " + | |
+ s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + | |
+ s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + | |
+ s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") | |
+ | |
+ var df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY ts") | |
+ val catalystDistribution = physical.ClusteredDistribution( | |
+ Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) | |
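+    // the years transform yields years since 1970, so 2020/2021/2022 map to 50/51/52 | 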
+ val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) | |
+ | |
+ checkQueryPlan(df, catalystDistribution, | |
+ physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) | |
+ | |
+    // multiple group keys should work too as long as the partition keys are a subset of them | 
+ df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") | |
+ checkQueryPlan(df, catalystDistribution, | |
+ physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) | |
+ } | |
+ | |
+ test("non-clustered distribution: no partition") { | |
+ val partitions: Array[Transform] = Array(bucket(32, "ts")) | |
+ createTable(table, schema, partitions) | |
+ | |
+ val df = sql(s"SELECT * FROM testcat.ns.$table") | |
+ val distribution = physical.ClusteredDistribution( | |
+ Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) | |
+ | |
+ checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) | |
+ } | |
+ | |
+ test("non-clustered distribution: single partition") { | |
+ val partitions: Array[Transform] = Array(bucket(32, "ts")) | |
+ createTable(table, schema, partitions) | |
+ sql(s"INSERT INTO testcat.ns.$table VALUES (0, 'aaa', CAST('2020-01-01' AS timestamp))") | |
+ | |
+ val df = sql(s"SELECT * FROM testcat.ns.$table") | |
+ val distribution = physical.ClusteredDistribution( | |
+ Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) | |
+ | |
+ checkQueryPlan(df, distribution, physical.SinglePartition) | |
+ } | |
+ | |
+ test("non-clustered distribution: no V2 catalog") { | |
+ spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName) | |
+ val nonFunctionCatalog = spark.sessionState.catalogManager.catalog("testcat2") | |
+ .asInstanceOf[InMemoryTableCatalog] | |
+ val partitions: Array[Transform] = Array(bucket(32, "ts")) | |
+ createTable(table, schema, partitions, catalog = nonFunctionCatalog) | |
+ sql(s"INSERT INTO testcat2.ns.$table VALUES " + | |
+ s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + | |
+ s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + | |
+ s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") | |
+ | |
+ val df = sql(s"SELECT * FROM testcat2.ns.$table") | |
+ val distribution = physical.UnspecifiedDistribution | |
+ | |
+ try { | |
+ checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) | |
+ } finally { | |
+ spark.conf.unset("spark.sql.catalog.testcat2") | |
+ } | |
+ } | |
+ | |
+ test("non-clustered distribution: no V2 function provided") { | |
+ catalog.clearFunctions() | |
+ | |
+ val partitions: Array[Transform] = Array(bucket(32, "ts")) | |
+ createTable(table, schema, partitions) | |
+ sql(s"INSERT INTO testcat.ns.$table VALUES " + | |
+ s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + | |
+ s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + | |
+ s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") | |
+ | |
+ val df = sql(s"SELECT * FROM testcat.ns.$table") | |
+ val distribution = physical.UnspecifiedDistribution | |
+ | |
+ checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) | |
+ } | |
+ | |
+ test("non-clustered distribution: V2 bucketing disabled") { | |
+ withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "false") { | |
+ val partitions: Array[Transform] = Array(bucket(32, "ts")) | |
+ createTable(table, schema, partitions) | |
+ sql(s"INSERT INTO testcat.ns.$table VALUES " + | |
+ s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + | |
+ s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + | |
+ s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") | |
+ | |
+ val df = sql(s"SELECT * FROM testcat.ns.$table") | |
+ val distribution = physical.ClusteredDistribution( | |
+ Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) | |
+ | |
+ checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) | |
+ } | |
+ } | |
+ | |
+ /** | |
+   * Check whether the query plan from `df` has the expected `distribution` and `partitioning`. | 
+ */ | |
+ private def checkQueryPlan( | |
+ df: DataFrame, | |
+ distribution: physical.Distribution, | |
+ partitioning: physical.Partitioning): Unit = { | |
+ // check distribution & ordering are correctly populated in logical plan | |
+ val relation = df.queryExecution.optimizedPlan.collect { | |
+ case r: DataSourceV2ScanRelation => r | |
+ }.head | |
+ | |
+ resolveDistribution(distribution, relation) match { | |
+ case physical.ClusteredDistribution(clustering, _, _) => | |
+ assert(relation.keyGroupedPartitioning.isDefined && | |
+ relation.keyGroupedPartitioning.get == clustering) | |
+ case _ => | |
+ assert(relation.keyGroupedPartitioning.isEmpty) | |
+ } | |
+ | |
+    // check the output partitioning is correctly populated in the physical plan | 
+ val scan = collect(df.queryExecution.executedPlan) { | |
+ case s: BatchScanExec => s | |
+ }.head | |
+ | |
+ val expectedPartitioning = resolvePartitioning(partitioning, scan) | |
+ assert(expectedPartitioning == scan.outputPartitioning) | |
+ } | |
+ | |
+ private def createTable( | |
+ table: String, | |
+ schema: StructType, | |
+ partitions: Array[Transform], | |
+ catalog: InMemoryTableCatalog = catalog): Unit = { | |
+ catalog.createTable(Identifier.of(Array("ns"), table), | |
+ schema, partitions, emptyProps, Distributions.unspecified(), Array.empty, None) | |
+ } | |
+ | |
+ private val customers: String = "customers" | |
+ private val customers_schema = new StructType() | |
+ .add("customer_name", StringType) | |
+ .add("customer_age", IntegerType) | |
+ .add("customer_id", LongType) | |
+ | |
+ private val orders: String = "orders" | |
+ private val orders_schema = new StructType() | |
+ .add("order_amount", DoubleType) | |
+ .add("customer_id", LongType) | |
+ | |
+ private def testWithCustomersAndOrders( | |
+ customers_partitions: Array[Transform], | |
+ orders_partitions: Array[Transform], | |
+ expectedNumOfShuffleExecs: Int): Unit = { | |
+ createTable(customers, customers_schema, customers_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$customers VALUES " + | |
+ s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)") | |
+ | |
+ createTable(orders, orders_schema, orders_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$orders VALUES " + | |
+ s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50, 3)") | |
+ | |
+ val df = sql("SELECT customer_name, customer_age, order_amount " + | |
+ s"FROM testcat.ns.$customers c JOIN testcat.ns.$orders o " + | |
+ "ON c.customer_id = o.customer_id ORDER BY c.customer_id, order_amount") | |
+ | |
+ val shuffles = collectShuffles(df.queryExecution.executedPlan) | |
+ assert(shuffles.length == expectedNumOfShuffleExecs) | |
+ | |
+ checkAnswer(df, | |
+ Seq(Row("aaa", 10, 100.0), Row("aaa", 10, 200.0), Row("bbb", 20, 150.0), | |
+ Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50))) | |
+ } | |
+ | |
+ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = { | |
+    // skip shuffle operators that are not associated with a sort-merge join (SMJ) | 
+ collect(plan) { | |
+ case s: SortMergeJoinExec => s | |
+ }.flatMap(smj => | |
+ collect(smj) { | |
+ case s: ShuffleExchangeExec => s | |
+ }) | |
+ } | |
+ | |
+ test("partitioned join: exact distribution (same number of buckets) from both sides") { | |
+ val customers_partitions = Array(bucket(4, "customer_id")) | |
+ val orders_partitions = Array(bucket(4, "customer_id")) | |
+ | |
+ testWithCustomersAndOrders(customers_partitions, orders_partitions, 0) | |
+ } | |
+ | |
+ test("partitioned join: number of buckets mismatch should trigger shuffle") { | |
+ val customers_partitions = Array(bucket(4, "customer_id")) | |
+ val orders_partitions = Array(bucket(2, "customer_id")) | |
+ | |
+    // should shuffle both sides when the number of buckets is not the same | 
+ testWithCustomersAndOrders(customers_partitions, orders_partitions, 2) | |
+ } | |
+ | |
+ test("partitioned join: only one side reports partitioning") { | |
+ val customers_partitions = Array(bucket(4, "customer_id")) | |
+ val orders_partitions = Array(bucket(2, "customer_id")) | |
+ | |
+ testWithCustomersAndOrders(customers_partitions, orders_partitions, 2) | |
+ } | |
+ | |
+ private val items: String = "items" | |
+ private val items_schema: StructType = new StructType() | |
+ .add("id", LongType) | |
+ .add("name", StringType) | |
+ .add("price", FloatType) | |
+ .add("arrive_time", TimestampType) | |
+ | |
+ private val purchases: String = "purchases" | |
+ private val purchases_schema: StructType = new StructType() | |
+ .add("item_id", LongType) | |
+ .add("price", FloatType) | |
+ .add("time", TimestampType) | |
+ | |
+ test("partitioned join: join with two partition keys and matching & sorted partitions") { | |
+ val items_partitions = Array(bucket(8, "id"), days("arrive_time")) | |
+ createTable(items, items_schema, items_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$items VALUES " + | |
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + | |
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") | |
+ | |
+ val purchases_partitions = Array(bucket(8, "item_id"), days("time")) | |
+ createTable(purchases, purchases_schema, purchases_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " + | |
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(1, 44.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(1, 45.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(2, 11.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(3, 19.5, cast('2020-02-01' as timestamp))") | |
+ | |
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + | |
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + | |
+ "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") | |
+ | |
+ val shuffles = collectShuffles(df.queryExecution.executedPlan) | |
+ assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") | |
+ checkAnswer(df, | |
+ Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0), | |
+ Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5)) | |
+ ) | |
+ } | |
+ | |
+ test("partitioned join: join with two partition keys and unsorted partitions") { | |
+ val items_partitions = Array(bucket(8, "id"), days("arrive_time")) | |
+ createTable(items, items_schema, items_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$items VALUES " + | |
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " + | |
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp))") | |
+ | |
+ val purchases_partitions = Array(bucket(8, "item_id"), days("time")) | |
+ createTable(purchases, purchases_schema, purchases_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " + | |
+ s"(2, 11.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(1, 44.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(1, 45.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(3, 19.5, cast('2020-02-01' as timestamp))") | |
+ | |
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + | |
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + | |
+ "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") | |
+ | |
+ val shuffles = collectShuffles(df.queryExecution.executedPlan) | |
+ assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") | |
+ checkAnswer(df, | |
+ Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0), | |
+ Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5)) | |
+ ) | |
+ } | |
+ | |
+ test("partitioned join: join with two partition keys and different # of partition keys") { | |
+ val items_partitions = Array(bucket(8, "id"), days("arrive_time")) | |
+ createTable(items, items_schema, items_partitions) | |
+ | |
+ sql(s"INSERT INTO testcat.ns.$items VALUES " + | |
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") | |
+ | |
+ val purchases_partitions = Array(bucket(8, "item_id"), days("time")) | |
+ createTable(purchases, purchases_schema, purchases_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " + | |
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(2, 11.0, cast('2020-01-01' as timestamp))") | |
+ | |
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + | |
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + | |
+ "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") | |
+ | |
+ val shuffles = collectShuffles(df.queryExecution.executedPlan) | |
+ assert(shuffles.nonEmpty, "should add shuffle when partition keys mismatch") | |
+ } | |
+ | |
+ test("data source partitioning + dynamic partition filtering") { | |
+ withSQLConf( | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", | |
+ SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", | |
+ SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", | |
+ SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "10") { | |
+ val items_partitions = Array(identity("id")) | |
+ createTable(items, items_schema, items_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$items VALUES " + | |
+ s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + | |
+ s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") | |
+ | |
+ val purchases_partitions = Array(identity("item_id")) | |
+ createTable(purchases, purchases_schema, purchases_partitions) | |
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " + | |
+ s"(1, 42.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(1, 44.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(1, 45.0, cast('2020-01-15' as timestamp)), " + | |
+ s"(2, 11.0, cast('2020-01-01' as timestamp)), " + | |
+ s"(3, 19.5, cast('2020-02-01' as timestamp))") | |
+ | |
+      // the number of unique partitions changes after dynamic filtering, so this should throw an exception | 
+ var df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p WHERE " + | |
+ s"i.id = p.item_id AND i.price > 40.0") | |
+ val e = intercept[Exception](df.collect()) | |
+ assert(e.getMessage.contains("number of unique partition values")) | |
+ | |
+      // dynamic filtering doesn't change the partitioning, so the storage-partitioned join should kick in | 
+ df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p WHERE " + | |
+ s"i.id = p.item_id AND i.price >= 10.0") | |
+ val shuffles = collectShuffles(df.queryExecution.executedPlan) | |
+ assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") | |
+ checkAnswer(df, Seq(Row(303.5))) | |
+ } | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala | |
index db71eeb75e..e3d61a846f 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala | |
@@ -17,10 +17,6 @@ | |
package org.apache.spark.sql.connector | |
-import java.util | |
- | |
-import scala.collection.JavaConverters._ | |
- | |
import org.apache.spark.sql.{QueryTest, Row} | |
import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} | |
@@ -63,7 +59,7 @@ class TestLocalScanCatalog extends BasicInMemoryTableCatalog { | |
ident: Identifier, | |
schema: StructType, | |
partitions: Array[Transform], | |
- properties: util.Map[String, String]): Table = { | |
+ properties: java.util.Map[String, String]): Table = { | |
val table = new TestLocalScanTable(ident.toString) | |
tables.put(ident, table) | |
table | |
@@ -78,7 +74,8 @@ object TestLocalScanTable { | |
class TestLocalScanTable(override val name: String) extends Table with SupportsRead { | |
override def schema(): StructType = TestLocalScanTable.schema | |
- override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava | |
+ override def capabilities(): java.util.Set[TableCapability] = | |
+ java.util.EnumSet.of(TableCapability.BATCH_READ) | |
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = | |
new TestLocalScanBuilder | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala | |
new file mode 100644 | |
index 0000000000..7f0e74f6bc | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala | |
@@ -0,0 +1,235 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.connector | |
+ | |
+import org.apache.spark.sql.{AnalysisException, Row} | |
+import org.apache.spark.sql.functions.struct | |
+ | |
+class MetadataColumnSuite extends DatasourceV2SQLBase { | |
+ import testImplicits._ | |
+ | |
+ private val tbl = "testcat.t" | |
+ | |
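+  // the in-memory test table exposes `index` and `_partition` as metadata columns | 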
+ private def prepareTable(): Unit = { | |
+ sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (bucket(4, id), id)") | |
+ sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
+ } | |
+ | |
+ test("SPARK-31255: Project a metadata column") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl") | |
+ val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition") | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-31255: Projects data column when metadata column has the same name") { | |
+ withTable(tbl) { | |
+ sql(s"CREATE TABLE $tbl (index bigint, data string) PARTITIONED BY (bucket(4, index), index)") | |
+ sql(s"INSERT INTO $tbl VALUES (3, 'c'), (2, 'b'), (1, 'a')") | |
+ | |
+ val sqlQuery = sql(s"SELECT index, data, _partition FROM $tbl") | |
+ val dfQuery = spark.table(tbl).select("index", "data", "_partition") | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-31255: * expansion does not include metadata columns") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ val sqlQuery = sql(s"SELECT * FROM $tbl") | |
+ val dfQuery = spark.table(tbl) | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ checkAnswer(query, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-31255: metadata column should only be produced when necessary") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ val sqlQuery = sql(s"SELECT * FROM $tbl WHERE index = 0") | |
+ val dfQuery = spark.table(tbl).filter("index = 0") | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ assert(query.schema.fieldNames.toSeq == Seq("id", "data")) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-34547: metadata columns are resolved last") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ withTempView("v") { | |
+ sql(s"CREATE TEMPORARY VIEW v AS SELECT * FROM " + | |
+ s"VALUES (1, -1), (2, -2), (3, -3) AS v(id, index)") | |
+ | |
+ val sqlQuery = sql(s"SELECT $tbl.id, v.id, data, index, $tbl.index, v.index " + | |
+ s"FROM $tbl JOIN v WHERE $tbl.id = v.id") | |
+ val tableDf = spark.table(tbl) | |
+ val viewDf = spark.table("v") | |
+ val dfQuery = tableDf.join(viewDf, tableDf.col("id") === viewDf.col("id")) | |
+ .select(s"$tbl.id", "v.id", "data", "index", s"$tbl.index", "v.index") | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ checkAnswer(query, | |
+ Seq( | |
+ Row(1, 1, "a", -1, 0, -1), | |
+ Row(2, 2, "b", -2, 0, -2), | |
+ Row(3, 3, "c", -3, 0, -3) | |
+ ) | |
+ ) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-34555: Resolve DataFrame metadata column") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ val table = spark.table(tbl) | |
+ val dfQuery = table.select( | |
+ table.col("id"), | |
+ table.col("data"), | |
+ table.col("index"), | |
+ table.col("_partition") | |
+ ) | |
+ | |
+ checkAnswer( | |
+ dfQuery, | |
+ Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")) | |
+ ) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-34923: propagate metadata columns through Project") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ checkAnswer( | |
+ spark.table(tbl).select("id", "data").select("index", "_partition"), | |
+ Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3")) | |
+ ) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-34923: do not propagate metadata columns through View") { | |
+ val view = "view" | |
+ withTable(tbl) { | |
+ withTempView(view) { | |
+ prepareTable() | |
+ sql(s"CACHE TABLE $view AS SELECT * FROM $tbl") | |
+ assertThrows[AnalysisException] { | |
+ sql(s"SELECT index, _partition FROM $view") | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-34923: propagate metadata columns through Filter") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl WHERE id > 1") | |
+ val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition") | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-34923: propagate metadata columns through Sort") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ val sqlQuery = sql(s"SELECT id, data, index, _partition FROM $tbl ORDER BY id") | |
+ val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition") | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-34923: propagate metadata columns through RepartitionBy") { | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ val sqlQuery = sql( | |
+ s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $tbl") | |
+ val dfQuery = spark.table(tbl).repartitionByRange(3, $"id") | |
+ .select("id", "data", "index", "_partition") | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-34923: propagate metadata columns through SubqueryAlias if child is leaf node") { | |
+ val sbq = "sbq" | |
+ withTable(tbl) { | |
+ prepareTable() | |
+ val sqlQuery = sql( | |
+ s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $tbl $sbq") | |
+ val dfQuery = spark.table(tbl).as(sbq).select( | |
+ s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition") | |
+ | |
+ Seq(sqlQuery, dfQuery).foreach { query => | |
+ checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) | |
+ } | |
+ | |
+ assertThrows[AnalysisException] { | |
+ sql(s"SELECT $sbq.index FROM (SELECT id FROM $tbl) $sbq") | |
+ } | |
+ assertThrows[AnalysisException] { | |
+ spark.table(tbl).select($"id").as(sbq).select(s"$sbq.index") | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-40149: select outer join metadata columns with DataFrame API") { | |
+ val df1 = Seq(1 -> "a").toDF("k", "v").as("left") | |
+ val df2 = Seq(1 -> "b").toDF("k", "v").as("right") | |
+ val dfQuery = df1.join(df2, Seq("k"), "outer") | |
+ .withColumn("left_all", struct($"left.*")) | |
+ .withColumn("right_all", struct($"right.*")) | |
+ checkAnswer(dfQuery, Row(1, "a", "b", Row(1, "a"), Row(1, "b"))) | |
+ } | |
+ | |
+ test("SPARK-40429: Only set KeyGroupedPartitioning when the referenced column is in the output") { | |
+ withTable(tbl) { | |
+ sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (id)") | |
+ sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')") | |
+ checkAnswer( | |
+ spark.table(tbl).select("index", "_partition"), | |
+ Seq(Row(0, "3"), Row(0, "2"), Row(0, "1")) | |
+ ) | |
+ | |
+ checkAnswer( | |
+ spark.table(tbl).select("id", "index", "_partition"), | |
+ Seq(Row(3, 0, "3"), Row(2, 0, "2"), Row(1, 0, "1")) | |
+ ) | |
+ } | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala | |
index bb2acecc78..64c893ed74 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala | |
@@ -18,7 +18,6 @@ | |
package org.apache.spark.sql.connector | |
import java.io.{BufferedReader, InputStreamReader, IOException} | |
-import java.util | |
import scala.collection.JavaConverters._ | |
@@ -138,8 +137,8 @@ class SimpleWritableDataSource extends TestingV2Source { | |
new MyWriteBuilder(path, info) | |
} | |
- override def capabilities(): util.Set[TableCapability] = | |
- Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava | |
+ override def capabilities(): java.util.Set[TableCapability] = | |
+ java.util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE) | |
} | |
override def getTable(options: CaseInsensitiveStringMap): Table = { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala | |
index 076dad7530..8d771b0736 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala | |
@@ -17,15 +17,19 @@ | |
package org.apache.spark.sql.connector | |
+import java.util.Optional | |
+ | |
+import scala.concurrent.duration.MICROSECONDS | |
import scala.language.implicitConversions | |
import scala.util.Try | |
import org.scalatest.BeforeAndAfter | |
import org.apache.spark.SparkException | |
-import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} | |
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, SaveMode} | |
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException | |
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression} | |
+import org.apache.spark.sql.catalyst.util.DateTimeUtils | |
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, SupportsCatalogOptions, TableCatalog} | |
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME | |
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} | |
@@ -35,6 +39,7 @@ import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION | |
import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.sql.types.{LongType, StructType} | |
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, QueryExecutionListener} | |
+import org.apache.spark.unsafe.types.UTF8String | |
class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { | |
@@ -71,7 +76,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with | |
saveMode: SaveMode, | |
withCatalogOption: Option[String], | |
partitionBy: Seq[String]): Unit = { | |
- val df = spark.range(10).withColumn("part", 'id % 5) | |
+ val df = spark.range(10).withColumn("part", Symbol("id") % 5) | |
val dfw = df.write.format(format).mode(saveMode).option("name", "t1") | |
withCatalogOption.foreach(cName => dfw.option("catalog", cName)) | |
dfw.partitionBy(partitionBy: _*).save() | |
@@ -136,7 +141,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with | |
test("Ignore mode if table exists - session catalog") { | |
sql(s"create table t1 (id bigint) using $format") | |
- val df = spark.range(10).withColumn("part", 'id % 5) | |
+ val df = spark.range(10).withColumn("part", Symbol("id") % 5) | |
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") | |
dfw.save() | |
@@ -148,7 +153,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with | |
test("Ignore mode if table exists - testcat catalog") { | |
sql(s"create table $catalogName.t1 (id bigint) using $format") | |
- val df = spark.range(10).withColumn("part", 'id % 5) | |
+ val df = spark.range(10).withColumn("part", Symbol("id") % 5) | |
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") | |
dfw.option("catalog", catalogName).save() | |
@@ -271,6 +276,69 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with | |
} | |
} | |
+ test("time travel") { | |
+    // The testing in-memory table simply appends the version/timestamp to the table name when | 
+ // looking up tables. | |
+ val t1 = s"$catalogName.tSnapshot123456789" | |
+ val t2 = s"$catalogName.t2345678910" | |
+ withTable(t1, t2) { | |
+ sql(s"create table $t1 (id bigint) using $format") | |
+ sql(s"create table $t2 (id bigint) using $format") | |
+ | |
+ val df1 = spark.range(10) | |
+ df1.write.format(format).option("name", "tSnapshot123456789").option("catalog", catalogName) | |
+ .mode(SaveMode.Append).save() | |
+ | |
+ val df2 = spark.range(10, 20) | |
+ df2.write.format(format).option("name", "t2345678910").option("catalog", catalogName) | |
+ .mode(SaveMode.Overwrite).save() | |
+ | |
+ // load with version | |
+ checkAnswer(load("t", Some(catalogName), version = Some("Snapshot123456789")), df1.toDF()) | |
+ checkAnswer(load("t", Some(catalogName), version = Some("2345678910")), df2.toDF()) | |
+ } | |
+ | |
+ val ts1 = DateTimeUtils.stringToTimestampAnsi( | |
+ UTF8String.fromString("2019-01-29 00:37:58"), | |
+ DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) | |
+ val ts2 = DateTimeUtils.stringToTimestampAnsi( | |
+ UTF8String.fromString("2021-01-29 00:37:58"), | |
+ DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) | |
+ val t3 = s"$catalogName.t$ts1" | |
+ val t4 = s"$catalogName.t$ts2" | |
+ withTable(t3, t4) { | |
+ sql(s"create table $t3 (id bigint) using $format") | |
+ sql(s"create table $t4 (id bigint) using $format") | |
+ | |
+ val df3 = spark.range(30, 40) | |
+ df3.write.format(format).option("name", s"t$ts1").option("catalog", catalogName) | |
+ .mode(SaveMode.Append).save() | |
+ | |
+ val df4 = spark.range(50, 60) | |
+ df4.write.format(format).option("name", s"t$ts2").option("catalog", catalogName) | |
+ .mode(SaveMode.Overwrite).save() | |
+ | |
+ // load with timestamp | |
+ checkAnswer(load("t", Some(catalogName), version = None, | |
+ timestamp = Some("2019-01-29 00:37:58")), df3.toDF()) | |
+ checkAnswer(load("t", Some(catalogName), version = None, | |
+ timestamp = Some("2021-01-29 00:37:58")), df4.toDF()) | |
+ | |
+ // load with timestamp in number format | |
+ checkAnswer(load("t", Some(catalogName), version = None, | |
+ timestamp = Some(MICROSECONDS.toSeconds(ts1).toString)), df3.toDF()) | |
+ checkAnswer(load("t", Some(catalogName), version = None, | |
+ timestamp = Some(MICROSECONDS.toSeconds(ts2).toString)), df4.toDF()) | |
+ } | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ load("t", Some(catalogName), version = Some("12345678"), | |
+ timestamp = Some("2019-01-29 00:37:58")) | |
+ } | |
+ assert(e.getMessage | |
+ .contains("Cannot specify both version and timestamp when time travelling the table.")) | |
+ } | |
+ | |
private def checkV2Identifiers( | |
plan: LogicalPlan, | |
identifier: String = "t1", | |
@@ -281,9 +349,19 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with | |
assert(v2.catalog.exists(_ == catalogPlugin)) | |
} | |
- private def load(name: String, catalogOpt: Option[String]): DataFrame = { | |
+ private def load( | |
+ name: String, | |
+ catalogOpt: Option[String], | |
+ version: Option[String] = None, | |
+ timestamp: Option[String] = None): DataFrame = { | |
val dfr = spark.read.format(format).option("name", name) | |
catalogOpt.foreach(cName => dfr.option("catalog", cName)) | |
+ if (version.nonEmpty) { | |
+ dfr.option("versionAsOf", version.get) | |
+ } | |
+ if (timestamp.nonEmpty) { | |
+ dfr.option("timestampAsOf", timestamp.get) | |
+ } | |
dfr.load() | |
} | |
@@ -312,4 +390,20 @@ class CatalogSupportingInMemoryTableProvider | |
override def extractCatalog(options: CaseInsensitiveStringMap): String = { | |
options.get("catalog") | |
} | |
+ | |
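+  // expose the "versionAsOf"/"timestampAsOf" read options as the time-travel version and | 
+  // timestamp for this test provider | 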
+ override def extractTimeTravelVersion(options: CaseInsensitiveStringMap): Optional[String] = { | |
+ if (options.get("versionAsOf") != null) { | |
+ Optional.of(options.get("versionAsOf")) | |
+ } else { | |
+ Optional.empty[String] | |
+ } | |
+ } | |
+ | |
+ override def extractTimeTravelTimestamp(options: CaseInsensitiveStringMap): Optional[String] = { | |
+ if (options.get("timestampAsOf") != null) { | |
+ Optional.of(options.get("timestampAsOf")) | |
+ } else { | |
+ Optional.empty[String] | |
+ } | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala | |
index ce94d3b5c2..5f2e0b28ae 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala | |
@@ -17,10 +17,6 @@ | |
package org.apache.spark.sql.connector | |
-import java.util | |
- | |
-import scala.collection.JavaConverters._ | |
- | |
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} | |
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} | |
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} | |
@@ -217,7 +213,11 @@ private case object TestRelation extends LeafNode with NamedRelation { | |
private case class CapabilityTable(_capabilities: TableCapability*) extends Table { | |
override def name(): String = "capability_test_table" | |
override def schema(): StructType = TableCapabilityCheckSuite.schema | |
- override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava | |
+ override def capabilities(): java.util.Set[TableCapability] = { | |
+ val set = java.util.EnumSet.noneOf(classOf[TableCapability]) | |
+ _capabilities.foreach(set.add) | |
+ set | |
+ } | |
} | |
private class TestStreamSourceProvider extends StreamSourceProvider { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala | |
index bf2749d1af..0a0aaa8021 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala | |
@@ -17,7 +17,6 @@ | |
package org.apache.spark.sql.connector | |
-import java.util | |
import java.util.concurrent.ConcurrentHashMap | |
import java.util.concurrent.atomic.AtomicBoolean | |
@@ -35,7 +34,7 @@ import org.apache.spark.sql.types.StructType | |
*/ | |
private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension { | |
- protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() | |
+ protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() | |
private val tableCreated: AtomicBoolean = new AtomicBoolean(false) | |
@@ -48,7 +47,7 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating | |
name: String, | |
schema: StructType, | |
partitions: Array[Transform], | |
- properties: util.Map[String, String]): T | |
+ properties: java.util.Map[String, String]): T | |
override def loadTable(ident: Identifier): Table = { | |
if (tables.containsKey(ident)) { | |
@@ -69,12 +68,12 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating | |
ident: Identifier, | |
schema: StructType, | |
partitions: Array[Transform], | |
- properties: util.Map[String, String]): Table = { | |
+ properties: java.util.Map[String, String]): Table = { | |
val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY | |
val propsWithLocation = if (properties.containsKey(key)) { | |
// Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified. | |
if (!properties.containsKey(TableCatalog.PROP_LOCATION)) { | |
- val newProps = new util.HashMap[String, String]() | |
+ val newProps = new java.util.HashMap[String, String]() | |
newProps.putAll(properties) | |
newProps.put(TableCatalog.PROP_LOCATION, "file:/abc") | |
newProps | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala | |
index 847953e09c..c5be222645 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala | |
@@ -17,10 +17,6 @@ | |
package org.apache.spark.sql.connector | |
-import java.util | |
- | |
-import scala.collection.JavaConverters._ | |
- | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext} | |
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} | |
@@ -106,7 +102,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { | |
ident: Identifier, | |
schema: StructType, | |
partitions: Array[Transform], | |
- properties: util.Map[String, String]): Table = { | |
+ properties: java.util.Map[String, String]): Table = { | |
// To simplify the test implementation, only support fixed schema. | |
if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) { | |
throw new UnsupportedOperationException | |
@@ -131,8 +127,8 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp | |
override def schema(): StructType = V1ReadFallbackCatalog.schema | |
- override def capabilities(): util.Set[TableCapability] = { | |
- Set(TableCapability.BATCH_READ).asJava | |
+ override def capabilities(): java.util.Set[TableCapability] = { | |
+ java.util.EnumSet.of(TableCapability.BATCH_READ) | |
} | |
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala | |
index 7effc747ab..992c46cc6c 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala | |
@@ -17,8 +17,6 @@ | |
package org.apache.spark.sql.connector | |
-import java.util | |
- | |
import scala.collection.JavaConverters._ | |
import scala.collection.mutable | |
@@ -223,7 +221,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV | |
name: String, | |
schema: StructType, | |
partitions: Array[Transform], | |
- properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { | |
+ properties: java.util.Map[String, String]): InMemoryTableWithV1Fallback = { | |
val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) | |
InMemoryV1Provider.tables.put(name, t) | |
tables.put(Identifier.of(Array("default"), name), t) | |
@@ -321,7 +319,7 @@ class InMemoryTableWithV1Fallback( | |
override val name: String, | |
override val schema: StructType, | |
override val partitioning: Array[Transform], | |
- override val properties: util.Map[String, String]) | |
+ override val properties: java.util.Map[String, String]) | |
extends Table | |
with SupportsWrite with SupportsRead { | |
@@ -331,11 +329,11 @@ class InMemoryTableWithV1Fallback( | |
} | |
} | |
- override def capabilities: util.Set[TableCapability] = Set( | |
+ override def capabilities: java.util.Set[TableCapability] = java.util.EnumSet.of( | |
TableCapability.BATCH_READ, | |
TableCapability.V1_BATCH_WRITE, | |
TableCapability.OVERWRITE_BY_FILTER, | |
- TableCapability.TRUNCATE).asJava | |
+ TableCapability.TRUNCATE) | |
@volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty | |
private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala | |
index f262cf152c..ea30a6f25c 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala | |
@@ -17,8 +17,8 @@ | |
package org.apache.spark.sql.connector | |
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition} | |
-import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect} | |
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedDBObjectName, UnresolvedFieldName, UnresolvedFieldPosition} | |
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect, TableSpec} | |
import org.apache.spark.sql.catalyst.rules.Rule | |
import org.apache.spark.sql.connector.catalog.Identifier | |
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition | |
@@ -46,12 +46,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes | |
Seq(true, false).foreach { caseSensitive => | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { | |
Seq("ID", "iD").foreach { ref => | |
+ val tableSpec = TableSpec(Map.empty, None, Map.empty, | |
+ None, None, None, false) | |
val plan = CreateTableAsSelect( | |
- catalog, | |
- Identifier.of(Array(), "table_name"), | |
+ UnresolvedDBObjectName(Array("table_name"), isNamespace = false), | |
Expressions.identity(ref) :: Nil, | |
TestRelation2, | |
- Map.empty, | |
+ tableSpec, | |
Map.empty, | |
ignoreIfExists = false) | |
@@ -69,12 +70,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes | |
Seq(true, false).foreach { caseSensitive => | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { | |
Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => | |
+ val tableSpec = TableSpec(Map.empty, None, Map.empty, | |
+ None, None, None, false) | |
val plan = CreateTableAsSelect( | |
- catalog, | |
- Identifier.of(Array(), "table_name"), | |
+ UnresolvedDBObjectName(Array("table_name"), isNamespace = false), | |
Expressions.bucket(4, ref) :: Nil, | |
TestRelation2, | |
- Map.empty, | |
+ tableSpec, | |
Map.empty, | |
ignoreIfExists = false) | |
@@ -93,12 +95,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes | |
Seq(true, false).foreach { caseSensitive => | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { | |
Seq("ID", "iD").foreach { ref => | |
+ val tableSpec = TableSpec(Map.empty, None, Map.empty, | |
+ None, None, None, false) | |
val plan = ReplaceTableAsSelect( | |
- catalog, | |
- Identifier.of(Array(), "table_name"), | |
+ UnresolvedDBObjectName(Array("table_name"), isNamespace = false), | |
Expressions.identity(ref) :: Nil, | |
TestRelation2, | |
- Map.empty, | |
+ tableSpec, | |
Map.empty, | |
orCreate = true) | |
@@ -116,12 +119,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes | |
Seq(true, false).foreach { caseSensitive => | |
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { | |
Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => | |
+ val tableSpec = TableSpec(Map.empty, None, Map.empty, | |
+ None, None, None, false) | |
val plan = ReplaceTableAsSelect( | |
- catalog, | |
- Identifier.of(Array(), "table_name"), | |
+ UnresolvedDBObjectName(Array("table_name"), isNamespace = false), | |
Expressions.bucket(4, ref) :: Nil, | |
TestRelation2, | |
- Map.empty, | |
+ tableSpec, | |
Map.empty, | |
orCreate = true) | |
@@ -273,10 +277,21 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes | |
test("AlterTable: drop column resolution") { | |
Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => | |
- alterTableTest( | |
- DropColumns(table, Seq(UnresolvedFieldName(ref))), | |
- Seq("Missing field " + ref.quoted) | |
- ) | |
+ Seq(true, false).foreach { ifExists => | |
+ val expectedErrors = if (ifExists) { | |
+ Seq.empty[String] | |
+ } else { | |
+ Seq("Missing field " + ref.quoted) | |
+ } | |
+ val alter = DropColumns(table, Seq(UnresolvedFieldName(ref)), ifExists) | |
+ if (ifExists) { | |
+ // using IF EXISTS will silence all errors for missing columns | |
+ assertAnalysisSuccess(alter, caseSensitive = true) | |
+ assertAnalysisSuccess(alter, caseSensitive = false) | |
+ } else { | |
+ alterTableTest(alter, expectedErrors, expectErrorOnCaseSensitive = true) | |
+ } | |
+ } | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala | |
index db4a9c153c..36efe5ec1d 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala | |
@@ -19,39 +19,31 @@ package org.apache.spark.sql.connector | |
import java.util.Collections | |
-import org.scalatest.BeforeAndAfter | |
- | |
-import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, QueryTest} | |
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute | |
+import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row} | |
import org.apache.spark.sql.catalyst.plans.physical | |
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning} | |
-import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog} | |
+import org.apache.spark.sql.connector.catalog.Identifier | |
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} | |
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder} | |
import org.apache.spark.sql.connector.expressions.LogicalExpressions._ | |
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan} | |
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | |
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec | |
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike | |
+import org.apache.spark.sql.execution.streaming.MemoryStream | |
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream | |
import org.apache.spark.sql.functions.lit | |
-import org.apache.spark.sql.test.SharedSparkSession | |
+import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger} | |
import org.apache.spark.sql.types.{IntegerType, StringType, StructType} | |
import org.apache.spark.sql.util.QueryExecutionListener | |
-class WriteDistributionAndOrderingSuite | |
- extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { | |
- | |
- import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | |
- | |
- before { | |
- spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) | |
- } | |
+class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase { | |
+ import testImplicits._ | |
after { | |
spark.sessionState.catalogManager.reset() | |
- spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") | |
} | |
+ private val microBatchPrefix = "micro_batch_" | |
private val namespace = Array("ns1") | |
private val ident = Identifier.of(namespace, "test_table") | |
private val tableNameAsString = "testcat." + ident.toString | |
@@ -60,8 +52,6 @@ class WriteDistributionAndOrderingSuite | |
.add("id", IntegerType) | |
.add("data", StringType) | |
- private val resolver = conf.resolver | |
- | |
test("ordered distribution and sort with same exprs: append") { | |
checkOrderedDistributionAndSortWithSameExprs("append") | |
} | |
@@ -74,6 +64,18 @@ class WriteDistributionAndOrderingSuite | |
checkOrderedDistributionAndSortWithSameExprs("overwriteDynamic") | |
} | |
+ test("ordered distribution and sort with same exprs: micro-batch append") { | |
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "append") | |
+ } | |
+ | |
+ test("ordered distribution and sort with same exprs: micro-batch update") { | |
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "update") | |
+ } | |
+ | |
+ test("ordered distribution and sort with same exprs: micro-batch complete") { | |
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "complete") | |
+ } | |
+ | |
test("ordered distribution and sort with same exprs with numPartitions: append") { | |
checkOrderedDistributionAndSortWithSameExprs("append", Some(10)) | |
} | |
@@ -86,6 +88,18 @@ class WriteDistributionAndOrderingSuite | |
checkOrderedDistributionAndSortWithSameExprs("overwriteDynamic", Some(10)) | |
} | |
+ test("ordered distribution and sort with same exprs with numPartitions: micro-batch append") { | |
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "append", Some(10)) | |
+ } | |
+ | |
+ test("ordered distribution and sort with same exprs with numPartitions: micro-batch update") { | |
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "update", Some(10)) | |
+ } | |
+ | |
+ test("ordered distribution and sort with same exprs with numPartitions: micro-batch complete") { | |
+ checkOrderedDistributionAndSortWithSameExprs(microBatchPrefix + "complete", Some(10)) | |
+ } | |
+ | |
private def checkOrderedDistributionAndSortWithSameExprs(command: String): Unit = { | |
checkOrderedDistributionAndSortWithSameExprs(command, None) | |
} | |
@@ -129,6 +143,18 @@ class WriteDistributionAndOrderingSuite | |
checkClusteredDistributionAndSortWithSameExprs("overwriteDynamic") | |
} | |
+ test("clustered distribution and sort with same exprs: micro-batch append") { | |
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "append") | |
+ } | |
+ | |
+ test("clustered distribution and sort with same exprs: micro-batch update") { | |
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "update") | |
+ } | |
+ | |
+ test("clustered distribution and sort with same exprs: micro-batch complete") { | |
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "complete") | |
+ } | |
+ | |
test("clustered distribution and sort with same exprs with numPartitions: append") { | |
checkClusteredDistributionAndSortWithSameExprs("append", Some(10)) | |
} | |
@@ -141,6 +167,18 @@ class WriteDistributionAndOrderingSuite | |
checkClusteredDistributionAndSortWithSameExprs("overwriteDynamic", Some(10)) | |
} | |
+ test("clustered distribution and sort with same exprs with numPartitions: micro-batch append") { | |
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "append", Some(10)) | |
+ } | |
+ | |
+ test("clustered distribution and sort with same exprs with numPartitions: micro-batch update") { | |
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "update", Some(10)) | |
+ } | |
+ | |
+ test("clustered distribution and sort with same exprs with numPartitions: micro-batch complete") { | |
+ checkClusteredDistributionAndSortWithSameExprs(microBatchPrefix + "complete", Some(10)) | |
+ } | |
+ | |
private def checkClusteredDistributionAndSortWithSameExprs(command: String): Unit = { | |
checkClusteredDistributionAndSortWithSameExprs(command, None) | |
} | |
@@ -193,6 +231,18 @@ class WriteDistributionAndOrderingSuite | |
checkClusteredDistributionAndSortWithExtendedExprs("overwriteDynamic") | |
} | |
+ test("clustered distribution and sort with extended exprs: micro-batch append") { | |
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "append") | |
+ } | |
+ | |
+ test("clustered distribution and sort with extended exprs: micro-batch update") { | |
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "update") | |
+ } | |
+ | |
+ test("clustered distribution and sort with extended exprs: micro-batch complete") { | |
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "complete") | |
+ } | |
+ | |
test("clustered distribution and sort with extended exprs with numPartitions: append") { | |
checkClusteredDistributionAndSortWithExtendedExprs("append", Some(10)) | |
} | |
@@ -206,6 +256,21 @@ class WriteDistributionAndOrderingSuite | |
checkClusteredDistributionAndSortWithExtendedExprs("overwriteDynamic", Some(10)) | |
} | |
+ test("clustered distribution and sort with extended exprs with numPartitions: " + | |
+ "micro-batch append") { | |
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "append", Some(10)) | |
+ } | |
+ | |
+ test("clustered distribution and sort with extended exprs with numPartitions: " + | |
+ "micro-batch update") { | |
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "update", Some(10)) | |
+ } | |
+ | |
+ test("clustered distribution and sort with extended exprs with numPartitions: " + | |
+ "micro-batch complete") { | |
+ checkClusteredDistributionAndSortWithExtendedExprs(microBatchPrefix + "complete", Some(10)) | |
+ } | |
+ | |
private def checkClusteredDistributionAndSortWithExtendedExprs(command: String): Unit = { | |
checkClusteredDistributionAndSortWithExtendedExprs(command, None) | |
} | |
@@ -258,6 +323,18 @@ class WriteDistributionAndOrderingSuite | |
checkUnspecifiedDistributionAndLocalSort("overwriteDynamic") | |
} | |
+ test("unspecified distribution and local sort: micro-batch append") { | |
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "append") | |
+ } | |
+ | |
+ test("unspecified distribution and local sort: micro-batch update") { | |
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "update") | |
+ } | |
+ | |
+ test("unspecified distribution and local sort: micro-batch complete") { | |
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "complete") | |
+ } | |
+ | |
test("unspecified distribution and local sort with numPartitions: append") { | |
checkUnspecifiedDistributionAndLocalSort("append", Some(10)) | |
} | |
@@ -270,6 +347,18 @@ class WriteDistributionAndOrderingSuite | |
checkUnspecifiedDistributionAndLocalSort("overwriteDynamic", Some(10)) | |
} | |
+ test("unspecified distribution and local sort with numPartitions: micro-batch append") { | |
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "append", Some(10)) | |
+ } | |
+ | |
+ test("unspecified distribution and local sort with numPartitions: micro-batch update") { | |
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "update", Some(10)) | |
+ } | |
+ | |
+ test("unspecified distribution and local sort with numPartitions: micro-batch complete") { | |
+ checkUnspecifiedDistributionAndLocalSort(microBatchPrefix + "complete", Some(10)) | |
+ } | |
+ | |
private def checkUnspecifiedDistributionAndLocalSort(command: String): Unit = { | |
checkUnspecifiedDistributionAndLocalSort(command, None) | |
} | |
@@ -316,6 +405,18 @@ class WriteDistributionAndOrderingSuite | |
checkUnspecifiedDistributionAndNoSort("overwriteDynamic") | |
} | |
+ test("unspecified distribution and no sort: micro-batch append") { | |
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "append") | |
+ } | |
+ | |
+ test("unspecified distribution and no sort: micro-batch update") { | |
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "update") | |
+ } | |
+ | |
+ test("unspecified distribution and no sort: micro-batch complete") { | |
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "complete") | |
+ } | |
+ | |
test("unspecified distribution and no sort with numPartitions: append") { | |
checkUnspecifiedDistributionAndNoSort("append", Some(10)) | |
} | |
@@ -328,6 +429,18 @@ class WriteDistributionAndOrderingSuite | |
checkUnspecifiedDistributionAndNoSort("overwriteDynamic", Some(10)) | |
} | |
+ test("unspecified distribution and no sort with numPartitions: micro-batch append") { | |
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "append", Some(10)) | |
+ } | |
+ | |
+ test("unspecified distribution and no sort with numPartitions: micro-batch update") { | |
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "update", Some(10)) | |
+ } | |
+ | |
+ test("unspecified distribution and no sort with numPartitions: micro-batch complete") { | |
+ checkUnspecifiedDistributionAndNoSort(microBatchPrefix + "complete", Some(10)) | |
+ } | |
+ | |
private def checkUnspecifiedDistributionAndNoSort(command: String): Unit = { | |
checkUnspecifiedDistributionAndNoSort(command, None) | |
} | |
@@ -677,7 +790,95 @@ class WriteDistributionAndOrderingSuite | |
writeCommand = command) | |
} | |
+ test("continuous mode does not support write distribution and ordering") { | |
+ val ordering = Array[SortOrder]( | |
+ sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) | |
+ ) | |
+ val distribution = Distributions.ordered(ordering) | |
+ | |
+ catalog.createTable(ident, schema, Array.empty, emptyProps, distribution, ordering, None) | |
+ | |
+ withTempDir { checkpointDir => | |
+ val inputData = ContinuousMemoryStream[(Long, String)] | |
+ val inputDF = inputData.toDF().toDF("id", "data") | |
+ | |
+ val writer = inputDF | |
+ .writeStream | |
+ .trigger(Trigger.Continuous(100)) | |
+ .option("checkpointLocation", checkpointDir.getAbsolutePath) | |
+ .outputMode("append") | |
+ | |
+ val analysisException = intercept[AnalysisException] { | |
+ val query = writer.toTable(tableNameAsString) | |
+ | |
+ inputData.addData((1, "a"), (2, "b")) | |
+ | |
+ query.processAllAvailable() | |
+ query.stop() | |
+ } | |
+ | |
+ assert(analysisException.message.contains("Sinks cannot request distribution and ordering")) | |
+ } | |
+ } | |
+ | |
+ test("continuous mode allows unspecified distribution and empty ordering") { | |
+ catalog.createTable(ident, schema, Array.empty, emptyProps) | |
+ | |
+ withTempDir { checkpointDir => | |
+ val inputData = ContinuousMemoryStream[(Long, String)] | |
+ val inputDF = inputData.toDF().toDF("id", "data") | |
+ | |
+ val writer = inputDF | |
+ .writeStream | |
+ .trigger(Trigger.Continuous(100)) | |
+ .option("checkpointLocation", checkpointDir.getAbsolutePath) | |
+ .outputMode("append") | |
+ | |
+ val query = writer.toTable(tableNameAsString) | |
+ | |
+ inputData.addData((1, "a"), (2, "b")) | |
+ | |
+ query.processAllAvailable() | |
+ query.stop() | |
+ | |
+ checkAnswer(spark.table(tableNameAsString), Row(1, "a") :: Row(2, "b") :: Nil) | |
+ } | |
+ } | |
+ | |
private def checkWriteRequirements( | |
+ tableDistribution: Distribution, | |
+ tableOrdering: Array[SortOrder], | |
+ tableNumPartitions: Option[Int], | |
+ expectedWritePartitioning: physical.Partitioning, | |
+ expectedWriteOrdering: Seq[catalyst.expressions.SortOrder], | |
+ writeTransform: DataFrame => DataFrame = df => df, | |
+ writeCommand: String, | |
+ expectAnalysisException: Boolean = false): Unit = { | |
+ | |
+ if (writeCommand.startsWith(microBatchPrefix)) { | |
+ checkMicroBatchWriteRequirements( | |
+ tableDistribution, | |
+ tableOrdering, | |
+ tableNumPartitions, | |
+ expectedWritePartitioning, | |
+ expectedWriteOrdering, | |
+ writeTransform, | |
+ outputMode = writeCommand.stripPrefix(microBatchPrefix), | |
+ expectAnalysisException) | |
+ } else { | |
+ checkBatchWriteRequirements( | |
+ tableDistribution, | |
+ tableOrdering, | |
+ tableNumPartitions, | |
+ expectedWritePartitioning, | |
+ expectedWriteOrdering, | |
+ writeTransform, | |
+ writeCommand, | |
+ expectAnalysisException) | |
+ } | |
+ } | |
+ | |
+ private def checkBatchWriteRequirements( | |
tableDistribution: Distribution, | |
tableOrdering: Array[SortOrder], | |
tableNumPartitions: Option[Int], | |
@@ -712,15 +913,84 @@ class WriteDistributionAndOrderingSuite | |
} | |
} | |
+ private def checkMicroBatchWriteRequirements( | |
+ tableDistribution: Distribution, | |
+ tableOrdering: Array[SortOrder], | |
+ tableNumPartitions: Option[Int], | |
+ expectedWritePartitioning: physical.Partitioning, | |
+ expectedWriteOrdering: Seq[catalyst.expressions.SortOrder], | |
+ writeTransform: DataFrame => DataFrame = df => df, | |
+ outputMode: String = "append", | |
+ expectAnalysisException: Boolean = false): Unit = { | |
+ | |
+ catalog.createTable(ident, schema, Array.empty, emptyProps, tableDistribution, | |
+ tableOrdering, tableNumPartitions) | |
+ | |
+ withTempDir { checkpointDir => | |
+ val inputData = MemoryStream[(Long, String)] | |
+ val inputDF = inputData.toDF().toDF("id", "data") | |
+ | |
+ val queryDF = outputMode match { | |
+ case "append" | "update" => | |
+ inputDF | |
+ case "complete" => | |
+ // add an aggregate for complete mode | |
+ inputDF | |
+ .groupBy("id") | |
+ .agg(Map("data" -> "count")) | |
+ .select($"id", $"count(data)".cast("string").as("data")) | |
+ } | |
+ | |
+ val writer = writeTransform(queryDF) | |
+ .writeStream | |
+ .option("checkpointLocation", checkpointDir.getAbsolutePath) | |
+ .outputMode(outputMode) | |
+ | |
+ def executeCommand(): SparkPlan = execute { | |
+ val query = writer.toTable(tableNameAsString) | |
+ | |
+ inputData.addData((1, "a"), (2, "b")) | |
+ | |
+ query.processAllAvailable() | |
+ query.stop() | |
+ } | |
+ | |
+ if (expectAnalysisException) { | |
+ val streamingQueryException = intercept[StreamingQueryException] { | |
+ executeCommand() | |
+ } | |
+ val cause = streamingQueryException.cause | |
+ assert(cause.getMessage.contains("number of partitions can't be specified")) | |
+ | |
+ } else { | |
+ val executedPlan = executeCommand() | |
+ | |
+ checkPartitioningAndOrdering( | |
+ executedPlan, | |
+ expectedWritePartitioning, | |
+ expectedWriteOrdering, | |
+ // there is an extra shuffle for groupBy in complete mode | |
+ maxNumShuffles = if (outputMode != "complete") 1 else 2) | |
+ | |
+ val expectedRows = outputMode match { | |
+ case "append" | "update" => Row(1, "a") :: Row(2, "b") :: Nil | |
+ case "complete" => Row(1, "1") :: Row(2, "1") :: Nil | |
+ } | |
+ checkAnswer(spark.table(tableNameAsString), expectedRows) | |
+ } | |
+ } | |
+ } | |
+ | |
private def checkPartitioningAndOrdering( | |
plan: SparkPlan, | |
partitioning: physical.Partitioning, | |
- ordering: Seq[catalyst.expressions.SortOrder]): Unit = { | |
+ ordering: Seq[catalyst.expressions.SortOrder], | |
+ maxNumShuffles: Int = 1): Unit = { | |
val sorts = collect(plan) { case s: SortExec => s } | |
assert(sorts.size <= 1, "must be at most one sort") | |
val shuffles = collect(plan) { case s: ShuffleExchangeLike => s } | |
- assert(shuffles.size <= 1, "must be at most one shuffle") | |
+ assert(shuffles.size <= maxNumShuffles, s"must be at most $maxNumShuffles shuffles") | |
val actualPartitioning = plan.outputPartitioning | |
val expectedPartitioning = partitioning match { | |
@@ -730,6 +1000,9 @@ class WriteDistributionAndOrderingSuite | |
case p: physical.HashPartitioning => | |
val resolvedExprs = p.expressions.map(resolveAttrs(_, plan)) | |
p.copy(expressions = resolvedExprs) | |
+ case _: UnknownPartitioning => | |
+ // don't check partitioning if no particular one is expected | |
+ actualPartitioning | |
case other => other | |
} | |
assert(actualPartitioning == expectedPartitioning, "partitioning must match") | |
@@ -739,28 +1012,6 @@ class WriteDistributionAndOrderingSuite | |
assert(actualOrdering == expectedOrdering, "ordering must match") | |
} | |
- private def resolveAttrs( | |
- expr: catalyst.expressions.Expression, | |
- plan: SparkPlan): catalyst.expressions.Expression = { | |
- | |
- expr.transform { | |
- case UnresolvedAttribute(Seq(attrName)) => | |
- plan.output.find(attr => resolver(attr.name, attrName)).get | |
- case UnresolvedAttribute(nameParts) => | |
- val attrName = nameParts.mkString(".") | |
- fail(s"cannot resolve a nested attr: $attrName") | |
- } | |
- } | |
- | |
- private def attr(name: String): UnresolvedAttribute = { | |
- UnresolvedAttribute(name) | |
- } | |
- | |
- private def catalog: InMemoryTableCatalog = { | |
- val catalog = spark.sessionState.catalogManager.catalog("testcat") | |
- catalog.asTableCatalog.asInstanceOf[InMemoryTableCatalog] | |
- } | |
- | |
// executes a write operation and keeps the executed physical plan | |
private def execute(writeFunc: => Unit): SparkPlan = { | |
var executedPlan: SparkPlan = null | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala | |
new file mode 100644 | |
index 0000000000..1994874d32 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala | |
@@ -0,0 +1,78 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+package org.apache.spark.sql.connector.catalog.functions | |
+ | |
+import org.apache.spark.sql.types._ | |
+ | |
+object UnboundYearsFunction extends UnboundFunction { | |
+ override def bind(inputType: StructType): BoundFunction = { | |
+ if (inputType.size == 1 && isValidType(inputType.head.dataType)) YearsFunction | |
+ else throw new UnsupportedOperationException( | |
+ "'years' only take date or timestamp as input type") | |
+ } | |
+ | |
+ private def isValidType(dt: DataType): Boolean = dt match { | |
+ case DateType | TimestampType => true | |
+ case _ => false | |
+ } | |
+ | |
+ override def description(): String = name() | |
+ override def name(): String = "years" | |
+} | |
+ | |
+object YearsFunction extends BoundFunction { | |
+ override def inputTypes(): Array[DataType] = Array(TimestampType) | |
+ override def resultType(): DataType = LongType | |
+ override def name(): String = "years" | |
+ override def canonicalName(): String = name() | |
+} | |
+ | |
+object DaysFunction extends BoundFunction { | |
+ override def inputTypes(): Array[DataType] = Array(TimestampType) | |
+ override def resultType(): DataType = LongType | |
+ override def name(): String = "days" | |
+ override def canonicalName(): String = name() | |
+} | |
+ | |
+object UnboundDaysFunction extends UnboundFunction { | |
+ override def bind(inputType: StructType): BoundFunction = { | |
+ if (inputType.size == 1 && isValidType(inputType.head.dataType)) DaysFunction | |
+ else throw new UnsupportedOperationException( | |
+ "'days' only take date or timestamp as input type") | |
+ } | |
+ | |
+ private def isValidType(dt: DataType): Boolean = dt match { | |
+ case DateType | TimestampType => true | |
+ case _ => false | |
+ } | |
+ | |
+ override def description(): String = name() | |
+ override def name(): String = "days" | |
+} | |
+ | |
+object UnboundBucketFunction extends UnboundFunction { | |
+ override def bind(inputType: StructType): BoundFunction = BucketFunction | |
+ override def description(): String = name() | |
+ override def name(): String = "bucket" | |
+} | |
+ | |
+object BucketFunction extends BoundFunction { | |
+ override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType) | |
+ override def resultType(): DataType = IntegerType | |
+ override def name(): String = "bucket" | |
+ override def canonicalName(): String = name() | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala | |
new file mode 100644 | |
index 0000000000..a8fbfc49a5 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala | |
@@ -0,0 +1,52 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.errors | |
+ | |
+import org.apache.spark.sql.{AnalysisException, QueryTest} | |
+import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2Provider} | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+ | |
+class QueryCompilationErrorsDSv2Suite | |
+ extends QueryTest | |
+ with SharedSparkSession | |
+ with DatasourceV2SQLBase { | |
+ | |
+ test("UNSUPPORTED_FEATURE: IF PARTITION NOT EXISTS not supported by INSERT") { | |
+ val v2Format = classOf[FakeV2Provider].getName | |
+ val tbl = "testcat.ns1.ns2.tbl" | |
+ | |
+ withTable(tbl) { | |
+ val view = "tmp_view" | |
+ val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") | |
+ df.createOrReplaceTempView(view) | |
+ withTempView(view) { | |
+ sql(s"CREATE TABLE $tbl (id bigint, data string) USING $v2Format PARTITIONED BY (id)") | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ sql(s"INSERT OVERWRITE TABLE $tbl PARTITION (id = 1) IF NOT EXISTS SELECT * FROM $view") | |
+ } | |
+ | |
+ checkAnswer(spark.table(tbl), spark.emptyDataFrame) | |
+ assert(e.getMessage === "The feature is not supported: " + | |
+ s"""IF NOT EXISTS for the table `testcat`.`ns1`.`ns2`.`tbl` by INSERT INTO.""") | |
+ assert(e.getErrorClass === "UNSUPPORTED_FEATURE") | |
+ assert(e.getSqlState === "0A000") | |
+ } | |
+ } | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala | |
new file mode 100644 | |
index 0000000000..9e18e4e669 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala | |
@@ -0,0 +1,254 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.errors | |
+ | |
+import org.apache.spark.sql.{AnalysisException, ClassData, IntegratedUDFTestUtils, QueryTest} | |
+import org.apache.spark.sql.functions.{grouping, grouping_id, sum} | |
+import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+ | |
+case class StringLongClass(a: String, b: Long) | |
+ | |
+case class StringIntClass(a: String, b: Int) | |
+ | |
+case class ComplexClass(a: Long, b: StringLongClass) | |
+ | |
+case class ArrayClass(arr: Seq[StringIntClass]) | |
+ | |
+class QueryCompilationErrorsSuite extends QueryTest with SharedSparkSession { | |
+ import testImplicits._ | |
+ | |
+ test("CANNOT_UP_CAST_DATATYPE: invalid upcast data type") { | |
+ val msg1 = intercept[AnalysisException] { | |
+ sql("select 'value1' as a, 1L as b").as[StringIntClass] | |
+ }.message | |
+ assert(msg1 === | |
+ s""" | |
+ |Cannot up cast b from "BIGINT" to "INT". | |
+ |The type path of the target object is: | |
+ |- field (class: "scala.Int", name: "b") | |
+ |- root class: "org.apache.spark.sql.errors.StringIntClass" | |
+ |You can either add an explicit cast to the input data or choose a higher precision type | |
+ """.stripMargin.trim + " of the field in the target object") | |
+ | |
+ val msg2 = intercept[AnalysisException] { | |
+ sql("select 1L as a," + | |
+ " named_struct('a', 'value1', 'b', cast(1.0 as decimal(38,18))) as b") | |
+ .as[ComplexClass] | |
+ }.message | |
+ assert(msg2 === | |
+ s""" | |
+ |Cannot up cast b.`b` from "DECIMAL(38,18)" to "BIGINT". | |
+ |The type path of the target object is: | |
+ |- field (class: "scala.Long", name: "b") | |
+ |- field (class: "org.apache.spark.sql.errors.StringLongClass", name: "b") | |
+ |- root class: "org.apache.spark.sql.errors.ComplexClass" | |
+ |You can either add an explicit cast to the input data or choose a higher precision type | |
+ """.stripMargin.trim + " of the field in the target object") | |
+ } | |
+ | |
+ test("UNSUPPORTED_GROUPING_EXPRESSION: filter with grouping/grouping_Id expression") { | |
+ val df = Seq( | |
+ (536361, "85123A", 2, 17850), | |
+ (536362, "85123B", 4, 17850), | |
+ (536363, "86123A", 6, 17851) | |
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") | |
+ Seq("grouping", "grouping_id").foreach { grouping => | |
+ val errMsg = intercept[AnalysisException] { | |
+ df.groupBy("CustomerId").agg(Map("Quantity" -> "max")) | |
+ .filter(s"$grouping(CustomerId)=17850") | |
+ } | |
+ assert(errMsg.message === | |
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") | |
+ assert(errMsg.errorClass === Some("UNSUPPORTED_GROUPING_EXPRESSION")) | |
+ } | |
+ } | |
+ | |
+ test("UNSUPPORTED_GROUPING_EXPRESSION: Sort with grouping/grouping_Id expression") { | |
+ val df = Seq( | |
+ (536361, "85123A", 2, 17850), | |
+ (536362, "85123B", 4, 17850), | |
+ (536363, "86123A", 6, 17851) | |
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") | |
+ Seq(grouping("CustomerId"), grouping_id("CustomerId")).foreach { grouping => | |
+ val errMsg = intercept[AnalysisException] { | |
+ df.groupBy("CustomerId").agg(Map("Quantity" -> "max")). | |
+ sort(grouping) | |
+ } | |
+ assert(errMsg.errorClass === Some("UNSUPPORTED_GROUPING_EXPRESSION")) | |
+ assert(errMsg.message === | |
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") | |
+ } | |
+ } | |
+ | |
+ test("INVALID_PARAMETER_VALUE: the argument_index of string format is invalid") { | |
+ withSQLConf(SQLConf.ALLOW_ZERO_INDEX_IN_FORMAT_STRING.key -> "false") { | |
+ val e = intercept[AnalysisException] { | |
+ sql("select format_string('%0$s', 'Hello')") | |
+ } | |
+ assert(e.errorClass === Some("INVALID_PARAMETER_VALUE")) | |
+ assert(e.message === "The value of parameter(s) 'strfmt' in `format_string` is invalid: " + | |
+ "expects %1$, %2$ and so on, but got %0$.") | |
+ } | |
+ } | |
+ | |
+ test("CANNOT_USE_MIXTURE: Using aggregate function with grouped aggregate pandas UDF") { | |
+ import IntegratedUDFTestUtils._ | |
+ assume(shouldTestGroupedAggPandasUDFs) | |
+ | |
+ val df = Seq( | |
+ (536361, "85123A", 2, 17850), | |
+ (536362, "85123B", 4, 17850), | |
+ (536363, "86123A", 6, 17851) | |
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") | |
+ val e = intercept[AnalysisException] { | |
+ val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf") | |
+ df.groupBy("CustomerId") | |
+ .agg(pandasTestUDF(df("Quantity")), sum(df("Quantity"))).collect() | |
+ } | |
+ | |
+ assert(e.errorClass === Some("CANNOT_USE_MIXTURE")) | |
+ assert(e.message === | |
+ "Cannot use a mixture of aggregate function and group aggregate pandas UDF") | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: Using Python UDF with unsupported join condition") { | |
+ import IntegratedUDFTestUtils._ | |
+ | |
+ val df1 = Seq( | |
+ (536361, "85123A", 2, 17850), | |
+ (536362, "85123B", 4, 17850), | |
+ (536363, "86123A", 6, 17851) | |
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") | |
+ val df2 = Seq( | |
+ ("Bob", 17850), | |
+ ("Alice", 17850), | |
+ ("Tom", 17851) | |
+ ).toDF("CustomerName", "CustomerID") | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ val pythonTestUDF = TestPythonUDF(name = "python_udf") | |
+ df1.join( | |
+ df2, pythonTestUDF(df1("CustomerID") === df2("CustomerID")), "leftouter").collect() | |
+ } | |
+ | |
+ assert(e.errorClass === Some("UNSUPPORTED_FEATURE")) | |
+ assert(e.getSqlState === "0A000") | |
+ assert(e.message === | |
+ "The feature is not supported: " + | |
+ "Using PythonUDF in join condition of join type LEFT OUTER is not supported.") | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") { | |
+ import IntegratedUDFTestUtils._ | |
+ assume(shouldTestGroupedAggPandasUDFs) | |
+ | |
+ val df = Seq( | |
+ (536361, "85123A", 2, 17850), | |
+ (536362, "85123B", 4, 17850), | |
+ (536363, "86123A", 6, 17851) | |
+ ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID") | |
+ | |
+ val e = intercept[AnalysisException] { | |
+ val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf") | |
+ df.groupBy(df("CustomerID")).pivot(df("CustomerID")).agg(pandasTestUDF(df("Quantity"))) | |
+ } | |
+ | |
+ assert(e.errorClass === Some("UNSUPPORTED_FEATURE")) | |
+ assert(e.getSqlState === "0A000") | |
+ assert(e.message === | |
+ "The feature is not supported: " + | |
+ "Pandas UDF aggregate expressions don't support pivot.") | |
+ } | |
+ | |
+ test("UNSUPPORTED_DESERIALIZER: data type mismatch") { | |
+ val e = intercept[AnalysisException] { | |
+ sql("select 1 as arr").as[ArrayClass] | |
+ } | |
+ assert(e.errorClass === Some("UNSUPPORTED_DESERIALIZER")) | |
+ assert(e.message === | |
+ """The deserializer is not supported: need a(n) "ARRAY" field but got "INT".""") | |
+ } | |
+ | |
+ test("UNSUPPORTED_DESERIALIZER:" + | |
+ "the real number of fields doesn't match encoder schema") { | |
+ val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() | |
+ | |
+ val e1 = intercept[AnalysisException] { | |
+ ds.as[(String, Int, Long)] | |
+ } | |
+ assert(e1.errorClass === Some("UNSUPPORTED_DESERIALIZER")) | |
+ assert(e1.message === | |
+ "The deserializer is not supported: try to map \"STRUCT<a: STRING, b: INT>\" " + | |
+ "to Tuple3, but failed as the number of fields does not line up.") | |
+ | |
+ val e2 = intercept[AnalysisException] { | |
+ ds.as[Tuple1[String]] | |
+ } | |
+ assert(e2.errorClass === Some("UNSUPPORTED_DESERIALIZER")) | |
+ assert(e2.message === | |
+ "The deserializer is not supported: try to map \"STRUCT<a: STRING, b: INT>\" " + | |
+ "to Tuple1, but failed as the number of fields does not line up.") | |
+ } | |
+ | |
+ test("UNSUPPORTED_GENERATOR: " + | |
+ "generators are not supported when it's nested in expressions") { | |
+ val e = intercept[AnalysisException]( | |
+ sql("""select explode(Array(1, 2, 3)) + 1""").collect() | |
+ ) | |
+ assert(e.errorClass === Some("UNSUPPORTED_GENERATOR")) | |
+ assert(e.message === | |
+ """The generator is not supported: """ + | |
+ """nested in expressions "(explode(array(1, 2, 3)) + 1)"""") | |
+ } | |
+ | |
+ test("UNSUPPORTED_GENERATOR: only one generator allowed") { | |
+ val e = intercept[AnalysisException]( | |
+ sql("""select explode(Array(1, 2, 3)), explode(Array(1, 2, 3))""").collect() | |
+ ) | |
+ assert(e.errorClass === Some("UNSUPPORTED_GENERATOR")) | |
+ assert(e.message === | |
+ "The generator is not supported: only one generator allowed per select clause " + | |
+ """but found 2: "explode(array(1, 2, 3))", "explode(array(1, 2, 3))"""") | |
+ } | |
+ | |
+ test("UNSUPPORTED_GENERATOR: generators are not supported outside the SELECT clause") { | |
+ val e = intercept[AnalysisException]( | |
+ sql("""select 1 from t order by explode(Array(1, 2, 3))""").collect() | |
+ ) | |
+ assert(e.errorClass === Some("UNSUPPORTED_GENERATOR")) | |
+ assert(e.message === | |
+ "The generator is not supported: outside the SELECT clause, found: " + | |
+ "'Sort [explode(array(1, 2, 3)) ASC NULLS FIRST], true") | |
+ } | |
+ | |
+ test("UNSUPPORTED_GENERATOR: not a generator") { | |
+ val e = intercept[AnalysisException]( | |
+ sql( | |
+ """ | |
+ |SELECT explodedvalue.* | |
+ |FROM VALUES array(1, 2, 3) AS (value) | |
+ |LATERAL VIEW array_contains(value, 1) AS explodedvalue""".stripMargin).collect() | |
+ ) | |
+ assert(e.errorClass === Some("UNSUPPORTED_GENERATOR")) | |
+ assert(e.message === | |
+ """The generator is not supported: `array_contains` is expected to be a generator. """ + | |
+ "However, its class is org.apache.spark.sql.catalyst.expressions.ArrayContains, " + | |
+ "which is not a generator.") | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala | |
new file mode 100644 | |
index 0000000000..21acea53ed | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala | |
@@ -0,0 +1,314 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.errors | |
+ | |
+import java.io.File | |
+import java.net.URI | |
+ | |
+import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} | |
+import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} | |
+import org.apache.spark.sql.execution.datasources.orc.OrcTest | |
+import org.apache.spark.sql.execution.datasources.parquet.ParquetTest | |
+import org.apache.spark.sql.functions.{lit, lower, struct, sum} | |
+import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy.EXCEPTION | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+import org.apache.spark.util.Utils | |
+ | |
+class QueryExecutionErrorsSuite extends QueryTest | |
+ with ParquetTest with OrcTest with SharedSparkSession { | |
+ | |
+ import testImplicits._ | |
+ | |
+ private def getAesInputs(): (DataFrame, DataFrame) = { | |
+ val encryptedText16 = "4Hv0UKCx6nfUeAoPZo1z+w==" | |
+ val encryptedText24 = "NeTYNgA+PCQBN50DA//O2w==" | |
+ val encryptedText32 = "9J3iZbIxnmaG+OIA9Amd+A==" | |
+ val encryptedEmptyText16 = "jmTOhz8XTbskI/zYFFgOFQ==" | |
+ val encryptedEmptyText24 = "9RDK70sHNzqAFRcpfGM5gQ==" | |
+ val encryptedEmptyText32 = "j9IDsCvlYXtcVJUf4FAjQQ==" | |
+ | |
+ val df1 = Seq("Spark", "").toDF | |
+ val df2 = Seq( | |
+ (encryptedText16, encryptedText24, encryptedText32), | |
+ (encryptedEmptyText16, encryptedEmptyText24, encryptedEmptyText32) | |
+ ).toDF("value16", "value24", "value32") | |
+ | |
+ (df1, df2) | |
+ } | |
+ | |
+ test("INVALID_PARAMETER_VALUE: invalid key lengths in AES functions") { | |
+ val (df1, df2) = getAesInputs() | |
+ def checkInvalidKeyLength(df: => DataFrame): Unit = { | |
+ val e = intercept[SparkException] { | |
+ df.collect | |
+ }.getCause.asInstanceOf[SparkRuntimeException] | |
+ assert(e.getErrorClass === "INVALID_PARAMETER_VALUE") | |
+ assert(e.getSqlState === "22023") | |
+ assert(e.getMessage.matches( | |
+ "The value of parameter\\(s\\) 'key' in the `aes_encrypt`/`aes_decrypt` function " + | |
+ "is invalid: expects a binary value with 16, 24 or 32 bytes, but got \\d+ bytes.")) | |
+ } | |
+ | |
+ // Encryption failure - invalid key length | |
+ checkInvalidKeyLength(df1.selectExpr("aes_encrypt(value, '12345678901234567')")) | |
+ checkInvalidKeyLength(df1.selectExpr("aes_encrypt(value, binary('123456789012345'))")) | |
+ checkInvalidKeyLength(df1.selectExpr("aes_encrypt(value, binary(''))")) | |
+ | |
+ // Decryption failure - invalid key length | |
+ Seq("value16", "value24", "value32").foreach { colName => | |
+ checkInvalidKeyLength(df2.selectExpr( | |
+ s"aes_decrypt(unbase64($colName), '12345678901234567')")) | |
+ checkInvalidKeyLength(df2.selectExpr( | |
+ s"aes_decrypt(unbase64($colName), binary('123456789012345'))")) | |
+ checkInvalidKeyLength(df2.selectExpr( | |
+ s"aes_decrypt(unbase64($colName), '')")) | |
+ checkInvalidKeyLength(df2.selectExpr( | |
+ s"aes_decrypt(unbase64($colName), binary(''))")) | |
+ } | |
+ } | |
+ | |
+ test("INVALID_PARAMETER_VALUE: AES decrypt failure - key mismatch") { | |
+ val (_, df2) = getAesInputs() | |
+ Seq( | |
+ ("value16", "1234567812345678"), | |
+ ("value24", "123456781234567812345678"), | |
+ ("value32", "12345678123456781234567812345678")).foreach { case (colName, key) => | |
+ val e = intercept[SparkException] { | |
+ df2.selectExpr(s"aes_decrypt(unbase64($colName), binary('$key'), 'ECB')").collect | |
+ }.getCause.asInstanceOf[SparkRuntimeException] | |
+ assert(e.getErrorClass === "INVALID_PARAMETER_VALUE") | |
+ assert(e.getSqlState === "22023") | |
+ assert(e.getMessage === | |
+ "The value of parameter(s) 'expr, key' in the `aes_encrypt`/`aes_decrypt` function " + | |
+ "is invalid: Detail message: " + | |
+ "Given final block not properly padded. " + | |
+ "Such issues can arise if a bad key is used during decryption.") | |
+ } | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: unsupported combinations of AES modes and padding") { | |
+ val key16 = "abcdefghijklmnop" | |
+ val key32 = "abcdefghijklmnop12345678ABCDEFGH" | |
+ val (df1, df2) = getAesInputs() | |
+ def checkUnsupportedMode(df: => DataFrame): Unit = { | |
+ val e = intercept[SparkException] { | |
+ df.collect | |
+ }.getCause.asInstanceOf[SparkRuntimeException] | |
+ assert(e.getErrorClass === "UNSUPPORTED_FEATURE") | |
+ assert(e.getSqlState === "0A000") | |
+ assert(e.getMessage.matches("""The feature is not supported: AES-\w+ with the padding \w+""" + | |
+ " by the `aes_encrypt`/`aes_decrypt` function.")) | |
+ } | |
+ | |
+ // Unsupported AES mode and padding in encrypt | |
+ checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'CBC')")) | |
+ checkUnsupportedMode(df1.selectExpr(s"aes_encrypt(value, '$key16', 'ECB', 'NoPadding')")) | |
+ | |
+ // Unsupported AES mode and padding in decrypt | |
+ checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GSM')")) | |
+ checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value16, '$key16', 'GCM', 'PKCS')")) | |
+ checkUnsupportedMode(df2.selectExpr(s"aes_decrypt(value32, '$key32', 'ECB', 'None')")) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") { | |
+ def checkUnsupportedTypeInLiteral(v: Any): Unit = { | |
+ val e1 = intercept[SparkRuntimeException] { lit(v) } | |
+ assert(e1.getErrorClass === "UNSUPPORTED_FEATURE") | |
+ assert(e1.getSqlState === "0A000") | |
+ assert(e1.getMessage.matches("""The feature is not supported: literal for '.+' of .+\.""")) | |
+ } | |
+ checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) | |
+ checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) | |
+ | |
+ val e2 = intercept[SparkRuntimeException] { | |
+ trainingSales | |
+ .groupBy($"sales.year") | |
+ .pivot(struct(lower(trainingSales("sales.course")), trainingSales("training"))) | |
+ .agg(sum($"sales.earnings")) | |
+ .collect() | |
+ } | |
+ assert(e2.getMessage === "The feature is not supported: pivoting by the value" + | |
+ """ '[dotnet,Dummies]' of the column data type "STRUCT<col1: STRING, training: STRING>".""") | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: unsupported pivot operations") { | |
+ val e1 = intercept[SparkUnsupportedOperationException] { | |
+ trainingSales | |
+ .groupBy($"sales.year") | |
+ .pivot($"sales.course") | |
+ .pivot($"training") | |
+ .agg(sum($"sales.earnings")) | |
+ .collect() | |
+ } | |
+ assert(e1.getErrorClass === "UNSUPPORTED_FEATURE") | |
+ assert(e1.getSqlState === "0A000") | |
+ assert(e1.getMessage === """The feature is not supported: Repeated PIVOTs.""") | |
+ | |
+ val e2 = intercept[SparkUnsupportedOperationException] { | |
+ trainingSales | |
+ .rollup($"sales.year") | |
+ .pivot($"training") | |
+ .agg(sum($"sales.earnings")) | |
+ .collect() | |
+ } | |
+ assert(e2.getErrorClass === "UNSUPPORTED_FEATURE") | |
+ assert(e2.getSqlState === "0A000") | |
+ assert(e2.getMessage === """The feature is not supported: PIVOT not after a GROUP BY.""") | |
+ } | |
+ | |
+ test("INCONSISTENT_BEHAVIOR_CROSS_VERSION: " + | |
+ "compatibility with Spark 2.4/3.2 in reading/writing dates") { | |
+ | |
+ // Fail to read ancient datetime values. | |
+ withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_READ.key -> EXCEPTION.toString) { | |
+ val fileName = "before_1582_date_v2_4_5.snappy.parquet" | |
+ val filePath = getResourceParquetFilePath("test-data/" + fileName) | |
+ val e = intercept[SparkException] { | |
+ spark.read.parquet(filePath).collect() | |
+ }.getCause.asInstanceOf[SparkUpgradeException] | |
+ | |
+ val format = "Parquet" | |
+ val config = "\"" + SQLConf.PARQUET_REBASE_MODE_IN_READ.key + "\"" | |
+ val option = "\"" + "datetimeRebaseMode" + "\"" | |
+ assert(e.getErrorClass === "INCONSISTENT_BEHAVIOR_CROSS_VERSION") | |
+ assert(e.getMessage === | |
+ "You may get a different result due to the upgrading to Spark >= 3.0: " + | |
+ s""" | |
+ |reading dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z | |
+ |from $format files can be ambiguous, as the files may be written by | |
+ |Spark 2.x or legacy versions of Hive, which uses a legacy hybrid calendar | |
+ |that is different from Spark 3.0+'s Proleptic Gregorian calendar. | |
+ |See more details in SPARK-31404. You can set the SQL config $config or | |
+ |the datasource option $option to "LEGACY" to rebase the datetime values | |
+ |w.r.t. the calendar difference during reading. To read the datetime values | |
+ |as it is, set the SQL config $config or the datasource option $option | |
+ |to "CORRECTED". | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ // Fail to write ancient datetime values. | |
+ withSQLConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> EXCEPTION.toString) { | |
+ withTempPath { dir => | |
+ val df = Seq(java.sql.Date.valueOf("1001-01-01")).toDF("dt") | |
+ val e = intercept[SparkException] { | |
+ df.write.parquet(dir.getCanonicalPath) | |
+ }.getCause.getCause.getCause.asInstanceOf[SparkUpgradeException] | |
+ | |
+ val format = "Parquet" | |
+ val config = "\"" + SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key + "\"" | |
+ assert(e.getErrorClass === "INCONSISTENT_BEHAVIOR_CROSS_VERSION") | |
+ assert(e.getMessage === | |
+ "You may get a different result due to the upgrading to Spark >= 3.0: " + | |
+ s""" | |
+ |writing dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z | |
+ |into $format files can be dangerous, as the files may be read by Spark 2.x | |
+ |or legacy versions of Hive later, which uses a legacy hybrid calendar that | |
+ |is different from Spark 3.0+'s Proleptic Gregorian calendar. See more | |
+ |details in SPARK-31404. You can set $config to "LEGACY" to rebase the | |
+ |datetime values w.r.t. the calendar difference during writing, to get maximum | |
+ |interoperability. Or set $config to "CORRECTED" to write the datetime | |
+ |values as it is, if you are 100% sure that the written files will only be read by | |
+ |Spark 3.0+ or other systems that use Proleptic Gregorian calendar. | |
+ |""".stripMargin) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("UNSUPPORTED_OPERATION - SPARK-36346: can't read Timestamp as TimestampNTZ") { | |
+ withTempPath { file => | |
+ sql("select timestamp_ltz'2019-03-21 00:02:03'").write.orc(file.getCanonicalPath) | |
+ withAllNativeOrcReaders { | |
+ val e = intercept[SparkException] { | |
+ spark.read.schema("time timestamp_ntz").orc(file.getCanonicalPath).collect() | |
+ }.getCause.asInstanceOf[SparkUnsupportedOperationException] | |
+ | |
+ assert(e.getErrorClass === "UNSUPPORTED_OPERATION") | |
+ assert(e.getMessage === "The operation is not supported: " + | |
+ "Unable to convert \"TIMESTAMP\" of Orc to data type \"TIMESTAMP_NTZ\".") | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("UNSUPPORTED_OPERATION - SPARK-38504: can't read TimestampNTZ as TimestampLTZ") { | |
+ withTempPath { file => | |
+ sql("select timestamp_ntz'2019-03-21 00:02:03'").write.orc(file.getCanonicalPath) | |
+ withAllNativeOrcReaders { | |
+ val e = intercept[SparkException] { | |
+ spark.read.schema("time timestamp_ltz").orc(file.getCanonicalPath).collect() | |
+ }.getCause.asInstanceOf[SparkUnsupportedOperationException] | |
+ | |
+ assert(e.getErrorClass === "UNSUPPORTED_OPERATION") | |
+ assert(e.getMessage === "The operation is not supported: " + | |
+ "Unable to convert \"TIMESTAMP_NTZ\" of Orc to data type \"TIMESTAMP\".") | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("DATETIME_OVERFLOW: timestampadd() overflows its input timestamp") { | |
+ val e = intercept[SparkArithmeticException] { | |
+ sql("select timestampadd(YEAR, 1000000, timestamp'2022-03-09 01:02:03')").collect() | |
+ } | |
+ assert(e.getErrorClass === "DATETIME_OVERFLOW") | |
+ assert(e.getSqlState === "22008") | |
+ assert(e.getMessage === | |
+ "Datetime operation overflow: add 1000000 YEAR to TIMESTAMP '2022-03-09 01:02:03'.") | |
+ } | |
+ | |
+ test("UNSUPPORTED_SAVE_MODE: unsupported null saveMode whether the path exists or not") { | |
+ withTempPath { path => | |
+ val e1 = intercept[SparkIllegalArgumentException] { | |
+ val saveMode: SaveMode = null | |
+ Seq(1, 2).toDS().write.mode(saveMode).parquet(path.getAbsolutePath) | |
+ } | |
+ assert(e1.getErrorClass === "UNSUPPORTED_SAVE_MODE") | |
+ assert(e1.getMessage === "The save mode NULL is not supported for: a non-existent path.") | |
+ | |
+ Utils.createDirectory(path) | |
+ | |
+ val e2 = intercept[SparkIllegalArgumentException] { | |
+ val saveMode: SaveMode = null | |
+ Seq(1, 2).toDS().write.mode(saveMode).parquet(path.getAbsolutePath) | |
+ } | |
+ assert(e2.getErrorClass === "UNSUPPORTED_SAVE_MODE") | |
+ assert(e2.getMessage === "The save mode NULL is not supported for: an existent path.") | |
+ } | |
+ } | |
+ | |
+ test("INVALID_BUCKET_FILE: error if there exists any malformed bucket files") { | |
+ val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)). | |
+ toDF("i", "j", "k").as("df1") | |
+ | |
+ withTable("bucketed_table") { | |
+ df1.write.format("parquet").bucketBy(8, "i"). | |
+ saveAsTable("bucketed_table") | |
+ val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath | |
+ val tableDir = new File(warehouseFilePath, "bucketed_table") | |
+ Utils.deleteRecursively(tableDir) | |
+ df1.write.parquet(tableDir.getAbsolutePath) | |
+ | |
+ val aggregated = spark.table("bucketed_table").groupBy("i").count() | |
+ | |
+ val e = intercept[SparkException] { | |
+ aggregated.count() | |
+ } | |
+ assert(e.getErrorClass === "INVALID_BUCKET_FILE") | |
+ assert(e.getMessage.matches("Invalid bucket file: .+")) | |
+ } | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala | |
new file mode 100644 | |
index 0000000000..6494e541d4 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala | |
@@ -0,0 +1,308 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.errors | |
+ | |
+import org.apache.spark.sql.QueryTest | |
+import org.apache.spark.sql.catalyst.parser.ParseException | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+ | |
+// Turn off the length check because most of the tests check entire error messages | 
+// scalastyle:off line.size.limit | |
+class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession { | |
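+  // Runs the given SQL text and verifies the ParseException's error class, SQL state and full message. | 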
+ def validateParsingError( | |
+ sqlText: String, | |
+ errorClass: String, | |
+ sqlState: String, | |
+ message: String): Unit = { | |
+ val e = intercept[ParseException] { | |
+ sql(sqlText) | |
+ } | |
+ assert(e.getErrorClass === errorClass) | |
+ assert(e.getSqlState === sqlState) | |
+ assert(e.getMessage === message) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: LATERAL join with NATURAL join not supported") { | |
+ validateParsingError( | |
+ sqlText = "SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2)", | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ """ | |
+ |The feature is not supported: LATERAL join with NATURAL join.(line 1, pos 14) | |
+ | | |
+ |== SQL == | |
+ |SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2) | |
+ |--------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: LATERAL join with USING join not supported") { | |
+ validateParsingError( | |
+ sqlText = "SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2)", | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ """ | |
+ |The feature is not supported: LATERAL join with USING join.(line 1, pos 14) | |
+ | | |
+ |== SQL == | |
+ |SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2) | |
+ |--------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: Unsupported LATERAL join type") { | |
+ Seq("RIGHT OUTER", "FULL OUTER", "LEFT SEMI", "LEFT ANTI").foreach { joinType => | |
+ validateParsingError( | |
+ sqlText = s"SELECT * FROM t1 $joinType JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3", | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ s""" | |
+ |The feature is not supported: LATERAL join type $joinType.(line 1, pos 14) | |
+ | | |
+ |== SQL == | |
+ |SELECT * FROM t1 $joinType JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3 | |
+ |--------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ } | |
+ | |
+ test("INVALID_SQL_SYNTAX: LATERAL can only be used with subquery") { | |
+ Seq( | |
+ "SELECT * FROM t1, LATERAL t2" -> 26, | |
+ "SELECT * FROM t1 JOIN LATERAL t2" -> 30, | |
+ "SELECT * FROM t1, LATERAL (t2 JOIN t3)" -> 26, | |
+ "SELECT * FROM t1, LATERAL (LATERAL t2)" -> 26, | |
+ "SELECT * FROM t1, LATERAL VALUES (0, 1)" -> 26, | |
+ "SELECT * FROM t1, LATERAL RANGE(0, 1)" -> 26 | |
+ ).foreach { case (sqlText, pos) => | |
+ validateParsingError( | |
+ sqlText = sqlText, | |
+ errorClass = "INVALID_SQL_SYNTAX", | |
+ sqlState = "42000", | |
+ message = | |
+ s""" | |
+ |Invalid SQL syntax: LATERAL can only be used with subquery.(line 1, pos $pos) | |
+ | | |
+ |== SQL == | |
+ |$sqlText | |
+ |${"-" * pos}^^^ | |
+ |""".stripMargin) | |
+ } | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: NATURAL CROSS JOIN is not supported") { | |
+ validateParsingError( | |
+ sqlText = "SELECT * FROM a NATURAL CROSS JOIN b", | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ """ | |
+ |The feature is not supported: NATURAL CROSS JOIN.(line 1, pos 14) | |
+ | | |
+ |== SQL == | |
+ |SELECT * FROM a NATURAL CROSS JOIN b | |
+ |--------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("INVALID_SQL_SYNTAX: redefine window") { | |
+ validateParsingError( | |
+ sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win, win AS win2", | |
+ errorClass = "INVALID_SQL_SYNTAX", | |
+ sqlState = "42000", | |
+ message = | |
+ """ | |
+ |Invalid SQL syntax: The definition of window `win` is repetitive.(line 1, pos 31) | |
+ | | |
+ |== SQL == | |
+ |SELECT min(a) OVER win FROM t1 WINDOW win AS win, win AS win2 | |
+ |-------------------------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("INVALID_SQL_SYNTAX: invalid window reference") { | |
+ validateParsingError( | |
+ sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win", | |
+ errorClass = "INVALID_SQL_SYNTAX", | |
+ sqlState = "42000", | |
+ message = | |
+ """ | |
+ |Invalid SQL syntax: Window reference `win` is not a window specification.(line 1, pos 31) | |
+ | | |
+ |== SQL == | |
+ |SELECT min(a) OVER win FROM t1 WINDOW win AS win | |
+ |-------------------------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("INVALID_SQL_SYNTAX: window reference cannot be resolved") { | |
+ validateParsingError( | |
+ sqlText = "SELECT min(a) OVER win FROM t1 WINDOW win AS win2", | |
+ errorClass = "INVALID_SQL_SYNTAX", | |
+ sqlState = "42000", | |
+ message = | |
+ """ | |
+ |Invalid SQL syntax: Cannot resolve window reference `win2`.(line 1, pos 31) | |
+ | | |
+ |== SQL == | |
+ |SELECT min(a) OVER win FROM t1 WINDOW win AS win2 | |
+ |-------------------------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: TRANSFORM does not support DISTINCT/ALL") { | |
+ validateParsingError( | |
+ sqlText = "SELECT TRANSFORM(DISTINCT a) USING 'a' FROM t", | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ """ | |
+ |The feature is not supported: TRANSFORM does not support DISTINCT/ALL in inputs(line 1, pos 17) | |
+ | | |
+ |== SQL == | |
+ |SELECT TRANSFORM(DISTINCT a) USING 'a' FROM t | |
+ |-----------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: In-memory mode does not support TRANSFORM with serde") { | |
+ validateParsingError( | |
+ sqlText = "SELECT TRANSFORM(a) ROW FORMAT SERDE " + | |
+ "'org.apache.hadoop.hive.serde2.OpenCSVSerde' USING 'a' FROM t", | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ """ | |
+ |The feature is not supported: TRANSFORM with serde is only supported in hive mode(line 1, pos 0) | |
+ | | |
+ |== SQL == | |
+ |SELECT TRANSFORM(a) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' USING 'a' FROM t | |
+ |^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("INVALID_SQL_SYNTAX: Too many arguments for transform") { | |
+ validateParsingError( | |
+ sqlText = "CREATE TABLE table(col int) PARTITIONED BY (years(col,col))", | |
+ errorClass = "INVALID_SQL_SYNTAX", | |
+ sqlState = "42000", | |
+ message = | |
+ """ | |
+ |Invalid SQL syntax: Too many arguments for transform `years`(line 1, pos 44) | |
+ | | |
+ |== SQL == | |
+ |CREATE TABLE table(col int) PARTITIONED BY (years(col,col)) | |
+ |--------------------------------------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: cannot set reserved namespace property") { | |
+ val sql = "CREATE NAMESPACE IF NOT EXISTS a.b.c WITH PROPERTIES ('location'='/home/user/db')" | |
+ val msg = """The feature is not supported: location is a reserved namespace property, """ + | |
+ """please use the LOCATION clause to specify it.(line 1, pos 0)""" | |
+ validateParsingError( | |
+ sqlText = sql, | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ s""" | |
+ |$msg | |
+ | | |
+ |== SQL == | |
+ |$sql | |
+ |^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: cannot set reserved table property") { | |
+ val sql = "CREATE TABLE student (id INT, name STRING, age INT) " + | |
+ "USING PARQUET TBLPROPERTIES ('provider'='parquet')" | |
+ val msg = """The feature is not supported: provider is a reserved table property, """ + | |
+ """please use the USING clause to specify it.(line 1, pos 66)""" | |
+ validateParsingError( | |
+ sqlText = sql, | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ s""" | |
+ |$msg | |
+ | | |
+ |== SQL == | |
+ |$sql | |
+ |------------------------------------------------------------------^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("INVALID_PROPERTY_KEY: invalid property key for set quoted configuration") { | |
+ val sql = "set =`value`" | |
+ val msg = """"" is an invalid property key, please use quotes, """ + | |
+ """e.g. SET ""="value"(line 1, pos 0)""" | |
+ validateParsingError( | |
+ sqlText = sql, | |
+ errorClass = "INVALID_PROPERTY_KEY", | |
+ sqlState = null, | |
+ message = | |
+ s""" | |
+ |$msg | |
+ | | |
+ |== SQL == | |
+ |$sql | |
+ |^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("INVALID_PROPERTY_VALUE: invalid property value for set quoted configuration") { | |
+ val sql = "set `key`=1;2;;" | |
+ val msg = """"1;2;;" is an invalid property value, please use quotes, """ + | |
+ """e.g. SET "key"="1;2;;"(line 1, pos 0)""" | |
+ validateParsingError( | |
+ sqlText = sql, | |
+ errorClass = "INVALID_PROPERTY_VALUE", | |
+ sqlState = null, | |
+ message = | |
+ s""" | |
+ |$msg | |
+ | | |
+ |== SQL == | |
+ |$sql | |
+ |^^^ | |
+ |""".stripMargin) | |
+ } | |
+ | |
+ test("UNSUPPORTED_FEATURE: cannot set Properties and DbProperties at the same time") { | |
+ val sql = "CREATE NAMESPACE IF NOT EXISTS a.b.c WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c') " + | |
+ "WITH DBPROPERTIES('a'='a', 'b'='b', 'c'='c')" | |
+ val msg = """The feature is not supported: set PROPERTIES and DBPROPERTIES at the same time.""" + | |
+ """(line 1, pos 0)""" | |
+ validateParsingError( | |
+ sqlText = sql, | |
+ errorClass = "UNSUPPORTED_FEATURE", | |
+ sqlState = "0A000", | |
+ message = | |
+ s""" | |
+ |$msg | |
+ | | |
+ |== SQL == | |
+ |$sql | |
+ |^^^ | |
+ |""".stripMargin) | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala | |
index a33b9fad7f..06fc2022c0 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala | |
@@ -35,9 +35,9 @@ class AggregatingAccumulatorSuite | |
extends SparkFunSuite | |
with SharedSparkSession | |
with ExpressionEvalHelper { | |
- private val a = 'a.long | |
- private val b = 'b.string | |
- private val c = 'c.double | |
+ private val a = Symbol("a").long | |
+ private val b = Symbol("b").string | |
+ private val c = Symbol("c").double | |
private val inputAttributes = Seq(a, b, c) | |
private def str(s: String): UTF8String = UTF8String.fromString(s) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala | |
index 80a2ee03ca..09a880a706 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala | |
@@ -133,8 +133,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
""".stripMargin) | |
checkAnswer(query, identity, df.select( | |
- 'a.cast("string"), | |
- 'b.cast("string"), | |
+ Symbol("a").cast("string"), | |
+ Symbol("b").cast("string"), | |
'c.cast("string"), | |
'd.cast("string"), | |
'e.cast("string")).collect()) | |
@@ -164,7 +164,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
'b.cast("string").as("value")).collect()) | |
checkAnswer( | |
- df.select('a, 'b), | |
+ df.select(Symbol("a"), Symbol("b")), | |
(child: SparkPlan) => createScriptTransformationExec( | |
script = "cat", | |
output = Seq( | |
@@ -178,7 +178,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
'b.cast("string").as("value")).collect()) | |
checkAnswer( | |
- df.select('a), | |
+ df.select(Symbol("a")), | |
(child: SparkPlan) => createScriptTransformationExec( | |
script = "cat", | |
output = Seq( | |
@@ -242,7 +242,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
child = child, | |
ioschema = serde | |
), | |
- df.select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j).collect()) | |
+ df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e"), | |
+ Symbol("f"), Symbol("g"), Symbol("h"), Symbol("i"), Symbol("j")).collect()) | |
} | |
} | |
} | |
@@ -282,7 +283,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
child = child, | |
ioschema = defaultIOSchema | |
), | |
- df.select('a, 'b, 'c, 'd, 'e).collect()) | |
+ df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e")).collect()) | |
} | |
} | |
@@ -304,7 +305,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
|USING 'cat' AS (a timestamp, b date) | |
|FROM v | |
""".stripMargin) | |
- checkAnswer(query, identity, df.select('a, 'b).collect()) | |
+ checkAnswer(query, identity, df.select(Symbol("a"), Symbol("b")).collect()) | |
} | |
} | |
} | |
@@ -379,7 +380,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) | |
checkAnswer( | |
- df.select('a, 'b), | |
+ df.select(Symbol("a"), Symbol("b")), | |
(child: SparkPlan) => createScriptTransformationExec( | |
script = "cat", | |
output = Seq( | |
@@ -452,10 +453,10 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
(Array(6, 7, 8), Array(Array(6, 7), Array(8)), | |
Map("c" -> 3), Map("d" -> Array("e", "f"))) | |
).toDF("a", "b", "c", "d") | |
- .select('a, 'b, 'c, 'd, | |
- struct('a, 'b).as("e"), | |
- struct('a, 'd).as("f"), | |
- struct(struct('a, 'b), struct('a, 'd)).as("g") | |
+ .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), | |
+ struct(Symbol("a"), Symbol("b")).as("e"), | |
+ struct(Symbol("a"), Symbol("d")).as("f"), | |
+ struct(struct(Symbol("a"), Symbol("b")), struct(Symbol("a"), Symbol("d"))).as("g") | |
) | |
checkAnswer( | |
@@ -483,7 +484,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
child = child, | |
ioschema = defaultIOSchema | |
), | |
- df.select('a, 'b, 'c, 'd, 'e, 'f, 'g).collect()) | |
+ df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e"), | |
+ Symbol("f"), Symbol("g")).collect()) | |
} | |
} | |
@@ -654,6 +656,20 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | |
df.select($"ym", $"dt").collect()) | |
} | |
} | |
+ | |
+ test("SPARK-36675: TRANSFORM should support timestamp_ntz (no serde)") { | |
+ val df = spark.sql("SELECT timestamp_ntz'2021-09-06 20:19:13' col") | |
+ checkAnswer( | |
+ df, | |
+ (child: SparkPlan) => createScriptTransformationExec( | |
+ script = "cat", | |
+ output = Seq( | |
+ AttributeReference("col", TimestampNTZType)()), | |
+ child = child, | |
+ ioschema = defaultIOSchema | |
+ ), | |
+ df.select($"col").collect()) | |
+ } | |
} | |
case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala | |
index 4ff96e6574..e4f17eb601 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala | |
@@ -26,9 +26,11 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { | |
test("basic") { | |
val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator | |
val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator | |
- val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) | |
- val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) | |
- val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) | |
+ val leftGrouped = GroupedIterator(leftInput, Seq(Symbol("i").int.at(0)), | |
+ Seq(Symbol("i").int, Symbol("s").string)) | |
+ val rightGrouped = GroupedIterator(rightInput, Seq(Symbol("i").int.at(0)), | |
+ Seq(Symbol("i").int, Symbol("l").long)) | |
+ val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq(Symbol("i").int)) | |
val result = cogrouped.map { | |
case (key, leftData, rightData) => | |
@@ -52,7 +54,8 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { | |
test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") { | |
val leftInput = Seq(create_row(2, "a")).iterator | |
val rightInput = Seq(create_row(1, 2L)).iterator | |
- val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) | |
+ val leftGrouped = GroupedIterator(leftInput, Seq(Symbol("i").int.at(0)), | |
+ Seq(Symbol("i").int, Symbol("s").string)) | |
val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) | |
val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala | |
index 2a28517868..ddf4d421f3 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala | |
@@ -410,14 +410,14 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl | |
QueryTest.checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i))) | |
+ // Shuffle partition coalescing of the join is performed independently of the non-grouping | 
+ // aggregate on the other side of the union. | |
val finalPlan = resultDf.queryExecution.executedPlan | |
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan | |
- // As the pre-shuffle partition number are different, we will skip reducing | |
- // the shuffle partition numbers. | |
assert( | |
finalPlan.collect { | |
case r @ CoalescedShuffleRead() => r | |
- }.isEmpty) | |
+ }.size == 2) | |
} | |
withSparkSession(test, 100, None) | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala | |
index 612cd6f0d8..e29b7f579f 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala | |
@@ -18,13 +18,11 @@ package org.apache.spark.sql.execution | |
import java.io.File | |
-import scala.collection.mutable | |
import scala.util.Random | |
import org.apache.hadoop.fs.Path | |
import org.apache.spark.SparkConf | |
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} | |
import org.apache.spark.sql.{DataFrame, QueryTest} | |
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec | |
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan | |
@@ -215,33 +213,4 @@ class DataSourceV2ScanExecRedactionSuite extends DataSourceScanRedactionTest { | |
} | |
} | |
} | |
- | |
- test("SPARK-30362: test input metrics for DSV2") { | |
- withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { | |
- Seq("json", "orc", "parquet").foreach { format => | |
- withTempPath { path => | |
- val dir = path.getCanonicalPath | |
- spark.range(0, 10).write.format(format).save(dir) | |
- val df = spark.read.format(format).load(dir) | |
- val bytesReads = new mutable.ArrayBuffer[Long]() | |
- val recordsRead = new mutable.ArrayBuffer[Long]() | |
- val bytesReadListener = new SparkListener() { | |
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { | |
- bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead | |
- recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead | |
- } | |
- } | |
- sparkContext.addSparkListener(bytesReadListener) | |
- try { | |
- df.collect() | |
- sparkContext.listenerBus.waitUntilEmpty() | |
- assert(bytesReads.sum > 0) | |
- assert(recordsRead.sum == 10) | |
- } finally { | |
- sparkContext.removeSparkListener(bytesReadListener) | |
- } | |
- } | |
- } | |
- } | |
- } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala | |
index b27a940c36..635c794338 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala | |
@@ -36,9 +36,9 @@ class DeprecatedWholeStageCodegenSuite extends QueryTest | |
.groupByKey(_._1).agg(typed.sum(_._2)) | |
val plan = ds.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec])) | |
assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala | |
index 4b2a2b439c..06c51cee02 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala | |
@@ -32,7 +32,7 @@ class GroupedIteratorSuite extends SparkFunSuite { | |
val fromRow = encoder.createDeserializer() | |
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) | |
val grouped = GroupedIterator(input.iterator.map(toRow), | |
- Seq('i.int.at(0)), schema.toAttributes) | |
+ Seq(Symbol("i").int.at(0)), schema.toAttributes) | |
val result = grouped.map { | |
case (key, data) => | |
@@ -59,7 +59,7 @@ class GroupedIteratorSuite extends SparkFunSuite { | |
Row(3, 2L, "e")) | |
val grouped = GroupedIterator(input.iterator.map(toRow), | |
- Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes) | |
+ Seq(Symbol("i").int.at(0), Symbol("l").long.at(1)), schema.toAttributes) | |
val result = grouped.map { | |
case (key, data) => | |
@@ -80,7 +80,7 @@ class GroupedIteratorSuite extends SparkFunSuite { | |
val toRow = encoder.createSerializer() | |
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) | |
val grouped = GroupedIterator(input.iterator.map(toRow), | |
- Seq('i.int.at(0)), schema.toAttributes) | |
+ Seq(Symbol("i").int.at(0)), schema.toAttributes) | |
assert(grouped.length == 2) | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala | |
index 5bcec9b1e5..743ec41dbe 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala | |
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag | |
import org.apache.spark.sql.TPCDSQuerySuite | |
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} | |
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window} | |
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window, WithCTE} | |
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite | |
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} | |
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} | |
@@ -48,6 +48,7 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv | |
// A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes. | |
// We may add `ColumnarToRowExec` and `InputAdapter` above the scan node after planning. | |
+ @scala.annotation.tailrec | |
private def isScanPlanTree(plan: SparkPlan): Boolean = plan match { | |
case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child) | |
case p: ProjectExec => isScanPlanTree(p.child) | |
@@ -107,7 +108,11 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv | |
// logical = Project(Filter(Scan A)) | |
// physical = ProjectExec(ScanExec A) | |
 // we only check that leaf nodes match between logical and physical plan. | 
- val logicalLeaves = getLogicalPlan(actualPlan).collectLeaves() | |
+ val logicalPlan = getLogicalPlan(actualPlan) match { | |
+ case w: WithCTE => w.plan | |
+ case o => o | |
+ } | |
+ val logicalLeaves = logicalPlan.collectLeaves() | |
val physicalLeaves = plan.collectLeaves() | |
assert(logicalLeaves.length == 1) | |
assert(physicalLeaves.length == 1) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | |
index df310cbaee..000bd8c84f 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | |
@@ -59,18 +59,21 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { | |
} | |
test("count is partially aggregated") { | |
- val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed | |
+ val query = testData.groupBy(Symbol("value")).agg(count(Symbol("key"))).queryExecution.analyzed | |
testPartialAggregationPlan(query) | |
} | |
test("count distinct is partially aggregated") { | |
- val query = testData.groupBy('value).agg(count_distinct('key)).queryExecution.analyzed | |
+ val query = testData.groupBy(Symbol("value")).agg(count_distinct(Symbol("key"))) | |
+ .queryExecution.analyzed | |
testPartialAggregationPlan(query) | |
} | |
test("mixed aggregates are partially aggregated") { | |
val query = | |
- testData.groupBy('value).agg(count('value), count_distinct('key)).queryExecution.analyzed | |
+ testData.groupBy(Symbol("value")) | |
+ .agg(count(Symbol("value")), count_distinct(Symbol("key"))) | |
+ .queryExecution.analyzed | |
testPartialAggregationPlan(query) | |
} | |
@@ -193,47 +196,49 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { | |
} | |
test("efficient terminal limit -> sort should use TakeOrderedAndProject") { | |
- val query = testData.select('key, 'value).sort('key).limit(2) | |
+ val query = testData.select(Symbol("key"), Symbol("value")).sort(Symbol("key")).limit(2) | |
val planned = query.queryExecution.executedPlan | |
assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) | |
- assert(planned.output === testData.select('key, 'value).logicalPlan.output) | |
+ assert(planned.output === testData.select(Symbol("key"), Symbol("value")).logicalPlan.output) | |
} | |
test("terminal limit -> project -> sort should use TakeOrderedAndProject") { | |
- val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2) | |
+ val query = testData.select(Symbol("key"), Symbol("value")).sort(Symbol("key")) | |
+ .select(Symbol("value"), Symbol("key")).limit(2) | |
val planned = query.queryExecution.executedPlan | |
assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) | |
- assert(planned.output === testData.select('value, 'key).logicalPlan.output) | |
+ assert(planned.output === testData.select(Symbol("value"), Symbol("key")).logicalPlan.output) | |
} | |
test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { | |
- val query = testData.select('value).limit(2) | |
+ val query = testData.select(Symbol("value")).limit(2) | |
val planned = query.queryExecution.sparkPlan | |
assert(planned.isInstanceOf[CollectLimitExec]) | |
- assert(planned.output === testData.select('value).logicalPlan.output) | |
+ assert(planned.output === testData.select(Symbol("value")).logicalPlan.output) | |
} | |
test("TakeOrderedAndProject can appear in the middle of plans") { | |
- val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3) | |
+ val query = testData.select(Symbol("key"), Symbol("value")) | |
+ .sort(Symbol("key")).limit(2).filter('key === 3) | |
val planned = query.queryExecution.executedPlan | |
- assert(planned.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) | |
+ assert(planned.exists(_.isInstanceOf[TakeOrderedAndProjectExec])) | |
} | |
test("CollectLimit can appear in the middle of a plan when caching is used") { | |
- val query = testData.select('key, 'value).limit(2).cache() | |
+ val query = testData.select(Symbol("key"), Symbol("value")).limit(2).cache() | |
val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] | |
assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) | |
} | |
test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { | |
withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { | |
- val query0 = testData.select('value).orderBy('key).limit(100) | |
+ val query0 = testData.select(Symbol("value")).orderBy(Symbol("key")).limit(100) | |
val planned0 = query0.queryExecution.executedPlan | |
- assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) | |
+ assert(planned0.exists(_.isInstanceOf[TakeOrderedAndProjectExec])) | |
- val query1 = testData.select('value).orderBy('key).limit(2000) | |
+ val query1 = testData.select(Symbol("value")).orderBy(Symbol("key")).limit(2000) | |
val planned1 = query1.queryExecution.executedPlan | |
- assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) | |
+ assert(!planned1.exists(_.isInstanceOf[TakeOrderedAndProjectExec])) | |
} | |
} | |
@@ -432,7 +437,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { | |
} | |
test("EnsureRequirements should respect ClusteredDistribution's num partitioning") { | |
- val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13)) | |
+ val distribution = ClusteredDistribution(Literal(1) :: Nil, requiredNumPartitions = Some(13)) | |
// Number of partitions differ | |
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13) | |
val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala | |
index 9636f33ff6..41a1cd9b29 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala | |
@@ -20,10 +20,11 @@ import scala.io.Source | |
import org.apache.spark.sql.{AnalysisException, FastOperator} | |
import org.apache.spark.sql.catalyst.analysis.UnresolvedNamespace | |
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow | |
import org.apache.spark.sql.catalyst.plans.QueryPlan | |
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LogicalPlan, OneRowRelation, Project, ShowTables, SubqueryAlias} | |
import org.apache.spark.sql.catalyst.trees.TreeNodeTag | |
-import org.apache.spark.sql.execution.command.{ExecutedCommandExec, ShowTablesCommand} | |
+import org.apache.spark.sql.execution.datasources.v2.ShowTablesExec | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.util.Utils | |
@@ -227,7 +228,8 @@ class QueryExecutionSuite extends SharedSparkSession { | |
} | |
Seq("=== Applying Rule org.apache.spark.sql.execution", | |
"=== Result of Batch Preparations ===").foreach { expectedMsg => | |
- assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg))) | |
+ assert(testAppender.loggingEvents.exists( | |
+ _.getMessage.getFormattedMessage.contains(expectedMsg))) | |
} | |
} | |
@@ -247,9 +249,7 @@ class QueryExecutionSuite extends SharedSparkSession { | |
assert(showTablesQe.commandExecuted.isInstanceOf[CommandResult]) | |
assert(showTablesQe.executedPlan.isInstanceOf[CommandResultExec]) | |
val showTablesResultExec = showTablesQe.executedPlan.asInstanceOf[CommandResultExec] | |
- assert(showTablesResultExec.commandPhysicalPlan.isInstanceOf[ExecutedCommandExec]) | |
- assert(showTablesResultExec.commandPhysicalPlan.asInstanceOf[ExecutedCommandExec] | |
- .cmd.isInstanceOf[ShowTablesCommand]) | |
+ assert(showTablesResultExec.commandPhysicalPlan.isInstanceOf[ShowTablesExec]) | |
val project = Project(showTables.output, SubqueryAlias("s", showTables)) | |
val projectQe = qe(project) | |
@@ -260,9 +260,15 @@ class QueryExecutionSuite extends SharedSparkSession { | |
assert(projectQe.commandExecuted.children(0).children(0).isInstanceOf[CommandResult]) | |
assert(projectQe.executedPlan.isInstanceOf[CommandResultExec]) | |
val cmdResultExec = projectQe.executedPlan.asInstanceOf[CommandResultExec] | |
- assert(cmdResultExec.commandPhysicalPlan.isInstanceOf[ExecutedCommandExec]) | |
- assert(cmdResultExec.commandPhysicalPlan.asInstanceOf[ExecutedCommandExec] | |
- .cmd.isInstanceOf[ShowTablesCommand]) | |
+ assert(cmdResultExec.commandPhysicalPlan.isInstanceOf[ShowTablesExec]) | |
+ } | |
+ | |
+ test("SPARK-35378: Return UnsafeRow in CommandResultExecCheck execute methods") { | |
+ val plan = spark.sql("SHOW FUNCTIONS").queryExecution.executedPlan | |
+ assert(plan.isInstanceOf[CommandResultExec]) | |
+ plan.executeCollect().foreach { row => assert(row.isInstanceOf[UnsafeRow]) } | |
+ plan.executeTake(10).foreach { row => assert(row.isInstanceOf[UnsafeRow]) } | |
+ plan.executeTail(10).foreach { row => assert(row.isInstanceOf[UnsafeRow]) } | |
} | |
test("SPARK-38198: check specify maxFields when call toFile method") { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala | |
index 751078d08f..21702b6cf5 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala | |
@@ -51,7 +51,7 @@ abstract class RemoveRedundantSortsSuiteBase | |
test("remove redundant sorts with limit") { | |
withTempView("t") { | |
- spark.range(100).select('id as "key").createOrReplaceTempView("t") | |
+ spark.range(100).select(Symbol("id") as "key").createOrReplaceTempView("t") | |
val query = | |
""" | |
|SELECT key FROM | |
@@ -64,8 +64,8 @@ abstract class RemoveRedundantSortsSuiteBase | |
test("remove redundant sorts with broadcast hash join") { | |
withTempView("t1", "t2") { | |
- spark.range(1000).select('id as "key").createOrReplaceTempView("t1") | |
- spark.range(1000).select('id as "key").createOrReplaceTempView("t2") | |
+ spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t1") | |
+ spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t2") | |
val queryTemplate = """ | |
|SELECT /*+ BROADCAST(%s) */ t1.key FROM | |
@@ -100,8 +100,8 @@ abstract class RemoveRedundantSortsSuiteBase | |
test("remove redundant sorts with sort merge join") { | |
withTempView("t1", "t2") { | |
- spark.range(1000).select('id as "key").createOrReplaceTempView("t1") | |
- spark.range(1000).select('id as "key").createOrReplaceTempView("t2") | |
+ spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t1") | |
+ spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t2") | |
val query = """ | |
|SELECT /*+ MERGE(t1) */ t1.key FROM | |
| (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1 | |
@@ -123,8 +123,8 @@ abstract class RemoveRedundantSortsSuiteBase | |
test("cached sorted data doesn't need to be re-sorted") { | |
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") { | |
- val df = spark.range(1000).select('id as "key").sort('key.desc).cache() | |
- val resorted = df.sort('key.desc) | |
+ val df = spark.range(1000).select(Symbol("id") as "key").sort(Symbol("key").desc).cache() | |
+ val resorted = df.sort(Symbol("key").desc) | |
val sortedAsc = df.sort('key.asc) | |
checkNumSorts(df, 0) | |
checkNumSorts(resorted, 0) | |
@@ -140,7 +140,7 @@ abstract class RemoveRedundantSortsSuiteBase | |
test("SPARK-33472: shuffled join with different left and right side partition numbers") { | |
withTempView("t1", "t2") { | |
- spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1") | |
+ spark.range(0, 100, 1, 2).select(Symbol("id") as "key").createOrReplaceTempView("t1") | |
(0 to 100).toDF("key").createOrReplaceTempView("t2") | |
val queryTemplate = """ | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala | |
new file mode 100644 | |
index 0000000000..47679ed786 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala | |
@@ -0,0 +1,141 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution | |
+ | |
+import org.apache.spark.sql.{DataFrame, QueryTest} | |
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} | |
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} | |
+import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.sql.test.SharedSparkSession | |
+ | |
+abstract class ReplaceHashWithSortAggSuiteBase | |
+ extends QueryTest | |
+ with SharedSparkSession | |
+ with AdaptiveSparkPlanHelper { | |
+ | |
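+  // Asserts the number of hash-based (HashAggregateExec/ObjectHashAggregateExec) and sort-based aggregates in the executed plan. | 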
+ private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = { | |
+ val plan = df.queryExecution.executedPlan | |
+ assert(collectWithSubqueries(plan) { | |
+ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s | |
+ }.length == hashAggCount) | |
+ assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount) | |
+ } | |
+ | |
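+  // Runs the query with hash-to-sort-aggregate replacement enabled and disabled, checking the expected aggregate counts and that both runs return the same rows. | 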
+ private def checkAggs( | |
+ query: String, | |
+ enabledHashAggCount: Int, | |
+ enabledSortAggCount: Int, | |
+ disabledHashAggCount: Int, | |
+ disabledSortAggCount: Int): Unit = { | |
+ withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "true") { | |
+ val df = sql(query) | |
+ checkNumAggs(df, enabledHashAggCount, enabledSortAggCount) | |
+ val result = df.collect() | |
+ withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "false") { | |
+ val df = sql(query) | |
+ checkNumAggs(df, disabledHashAggCount, disabledSortAggCount) | |
+ checkAnswer(df, result) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("replace partial hash aggregate with sort aggregate") { | |
+ withTempView("t") { | |
+ spark.range(100).selectExpr("id as key").repartition(10).createOrReplaceTempView("t") | |
+ Seq("FIRST", "COLLECT_LIST").foreach { aggExpr => | |
+ val query = | |
+ s""" | |
+ |SELECT key, $aggExpr(key) | |
+ |FROM | |
+ |( | |
+ | SELECT key | |
+ | FROM t | |
+ | WHERE key > 10 | |
+ | SORT BY key | |
+ |) | |
+ |GROUP BY key | |
+ """.stripMargin | |
+ checkAggs(query, 1, 1, 2, 0) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("replace partial and final hash aggregate together with sort aggregate") { | |
+ withTempView("t1", "t2") { | |
+ spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1") | |
+ spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2") | |
+ Seq("COUNT", "COLLECT_LIST").foreach { aggExpr => | |
+ val query = | |
+ s""" | |
+ |SELECT key, $aggExpr(key) | |
+ |FROM | |
+ |( | |
+ | SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key | |
+ | FROM t1 | |
+ | JOIN t2 | |
+ | ON t1.key = t2.key | |
+ |) | |
+ |GROUP BY key | |
+ """.stripMargin | |
+ checkAggs(query, 0, 1, 2, 0) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("do not replace hash aggregate if child does not have sort order") { | |
+ withTempView("t1", "t2") { | |
+ spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1") | |
+ spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2") | |
+ Seq("COUNT", "COLLECT_LIST").foreach { aggExpr => | |
+ val query = | |
+ s""" | |
+ |SELECT key, $aggExpr(key) | |
+ |FROM | |
+ |( | |
+ | SELECT /*+ BROADCAST(t1) */ t1.key AS key | |
+ | FROM t1 | |
+ | JOIN t2 | |
+ | ON t1.key = t2.key | |
+ |) | |
+ |GROUP BY key | |
+ """.stripMargin | |
+ checkAggs(query, 2, 0, 2, 0) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("do not replace hash aggregate if there is no group-by column") { | |
+ withTempView("t1") { | |
+ spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1") | |
+ Seq("COUNT", "COLLECT_LIST").foreach { aggExpr => | |
+ val query = | |
+ s""" | |
+ |SELECT $aggExpr(key) | |
+ |FROM t1 | |
+ """.stripMargin | |
+ checkAggs(query, 2, 0, 2, 0) | |
+ } | |
+ } | |
+ } | |
+} | |
+ | |
+class ReplaceHashWithSortAggSuite extends ReplaceHashWithSortAggSuiteBase | |
+ with DisableAdaptiveExecutionSuite | |
+ | |
+class ReplaceHashWithSortAggSuiteAE extends ReplaceHashWithSortAggSuiteBase | |
+ with EnableAdaptiveExecutionSuite | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala | |
index 81e692076b..740c10f17b 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala | |
@@ -17,17 +17,21 @@ | |
package org.apache.spark.sql.execution | |
+import java.util.Locale | |
import java.util.concurrent.Executors | |
+import java.util.concurrent.atomic.AtomicInteger | |
import scala.collection.parallel.immutable.ParRange | |
import scala.concurrent.{ExecutionContext, Future} | |
import scala.concurrent.duration._ | |
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} | |
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} | |
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} | |
import org.apache.spark.sql.{Row, SparkSession} | |
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart | |
import org.apache.spark.sql.types._ | |
import org.apache.spark.util.ThreadUtils | |
+import org.apache.spark.util.Utils.REDACTION_REPLACEMENT_TEXT | |
class SQLExecutionSuite extends SparkFunSuite { | |
@@ -157,6 +161,45 @@ class SQLExecutionSuite extends SparkFunSuite { | |
} | |
} | |
} | |
+ | |
+ test("SPARK-34735: Add modified configs for SQL execution in UI") { | |
+ val spark = SparkSession.builder() | |
+ .master("local[*]") | |
+ .appName("test") | |
+ .config("k1", "v1") | |
+ .getOrCreate() | |
+ | |
+ try { | |
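+      // Counts how many SQL executions containing a Project have been verified by the listener below. | 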
+ val index = new AtomicInteger(0) | |
+ spark.sparkContext.addSparkListener(new SparkListener { | |
+ override def onOtherEvent(event: SparkListenerEvent): Unit = event match { | |
+ case start: SparkListenerSQLExecutionStart => | |
+ if (index.get() == 0 && hasProject(start)) { | |
+ assert(!start.modifiedConfigs.contains("k1")) | |
+ index.incrementAndGet() | |
+ } else if (index.get() == 1 && hasProject(start)) { | |
+ assert(start.modifiedConfigs.contains("k2")) | |
+ assert(start.modifiedConfigs("k2") == "v2") | |
+ assert(start.modifiedConfigs.contains("redaction.password")) | |
+ assert(start.modifiedConfigs("redaction.password") == REDACTION_REPLACEMENT_TEXT) | |
+ index.incrementAndGet() | |
+ } | |
+ case _ => | |
+ } | |
+ | |
+ private def hasProject(start: SparkListenerSQLExecutionStart): Boolean = | |
+ start.physicalPlanDescription.toLowerCase(Locale.ROOT).contains("project") | |
+ }) | |
+ spark.sql("SELECT 1").collect() | |
+ spark.sql("SET k2 = v2") | |
+ spark.sql("SET redaction.password = 123") | |
+ spark.sql("SELECT 1").collect() | |
+ spark.sparkContext.listenerBus.waitUntilEmpty() | |
+ assert(index.get() == 2) | |
+ } finally { | |
+ spark.stop() | |
+ } | |
+ } | |
} | |
object SQLExecutionSuite { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala | |
index 08789e63fa..55f1713422 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala | |
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution | |
import org.json4s.jackson.JsonMethods._ | |
import org.apache.spark.SparkFunSuite | |
+import org.apache.spark.scheduler.SparkListenerEvent | |
import org.apache.spark.sql.LocalSparkSession | |
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} | |
import org.apache.spark.sql.test.TestSparkSession | |
@@ -28,28 +29,46 @@ import org.apache.spark.util.JsonProtocol | |
class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession { | |
test("SparkPlanGraph backward compatibility: metadata") { | |
- val SQLExecutionStartJsonString = | |
- """ | |
- |{ | |
- | "Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", | |
- | "executionId":0, | |
- | "description":"test desc", | |
- | "details":"test detail", | |
- | "physicalPlanDescription":"test plan", | |
- | "sparkPlanInfo": { | |
- | "nodeName":"TestNode", | |
- | "simpleString":"test string", | |
- | "children":[], | |
- | "metadata":{}, | |
- | "metrics":[] | |
- | }, | |
- | "time":0 | |
- |} | |
+ Seq(true, false).foreach { newExecutionStartEvent => | |
+ val event = if (newExecutionStartEvent) { | |
+ "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart" | |
+ } else { | |
+ "org.apache.spark.sql.execution.OldVersionSQLExecutionStart" | |
+ } | |
+ val SQLExecutionStartJsonString = | |
+ s""" | |
+ |{ | |
+ | "Event":"$event", | |
+ | "executionId":0, | |
+ | "description":"test desc", | |
+ | "details":"test detail", | |
+ | "physicalPlanDescription":"test plan", | |
+ | "sparkPlanInfo": { | |
+ | "nodeName":"TestNode", | |
+ | "simpleString":"test string", | |
+ | "children":[], | |
+ | "metadata":{}, | |
+ | "metrics":[] | |
+ | }, | |
+ | "time":0, | |
+ | "modifiedConfigs": { | |
+ | "k1":"v1" | |
+ | } | |
+ |} | |
""".stripMargin | |
- val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) | |
- val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan", | |
- new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0) | |
- assert(reconstructedEvent == expectedEvent) | |
+ | |
+ val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) | |
+ if (newExecutionStartEvent) { | |
+ val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", | |
+ "test plan", new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0, | |
+ Map("k1" -> "v1")) | |
+ assert(reconstructedEvent == expectedEvent) | |
+ } else { | |
+ val expectedOldEvent = OldVersionSQLExecutionStart(0, "test desc", "test detail", | |
+ "test plan", new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0) | |
+ assert(reconstructedEvent == expectedOldEvent) | |
+ } | |
+ } | |
} | |
test("SparkListenerSQLExecutionEnd backward compatibility") { | |
@@ -77,3 +96,12 @@ class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession { | |
assert(readBack == event) | |
} | |
} | |
+ | |
+private case class OldVersionSQLExecutionStart( | |
+ executionId: Long, | |
+ description: String, | |
+ details: String, | |
+ physicalPlanDescription: String, | |
+ sparkPlanInfo: SparkPlanInfo, | |
+ time: Long) | |
+ extends SparkListenerEvent | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala | |
index cf699d3234..52aa1066f5 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala | |
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution | |
import org.apache.spark.SparkException | |
import org.apache.spark.sql._ | |
import org.apache.spark.sql.catalyst.TableIdentifier | |
-import org.apache.spark.sql.catalyst.analysis.NoSuchTableException | |
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Divide} | |
import org.apache.spark.sql.catalyst.parser.ParseException | |
+import org.apache.spark.sql.catalyst.plans.logical.Project | |
+import org.apache.spark.sql.catalyst.trees.Origin | |
import org.apache.spark.sql.internal.SQLConf._ | |
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} | |
@@ -220,12 +222,6 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { | |
} | |
} | |
- private def assertNoSuchTable(query: String): Unit = { | |
- intercept[NoSuchTableException] { | |
- sql(query) | |
- } | |
- } | |
- | |
private def assertAnalysisError(query: String, message: String): Unit = { | |
val e = intercept[AnalysisException](sql(query)) | |
assert(e.message.contains(message)) | |
@@ -600,7 +596,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { | |
spark.range(10).write.saveAsTable("add_col") | |
withView("v") { | |
sql("CREATE VIEW v AS SELECT * FROM add_col") | |
- spark.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") | |
+ spark.range(10).select(Symbol("id"), 'id as Symbol("a")) | |
+ .write.mode("overwrite").saveAsTable("add_col") | |
checkAnswer(sql("SELECT * FROM v"), spark.range(10).toDF()) | |
} | |
} | |
@@ -772,7 +769,9 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { | |
withTable("t") { | |
Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t") | |
withTempView("v1") { | |
- sql("CREATE TEMPORARY VIEW v1 AS SELECT 1/0") | |
+ withSQLConf(ANSI_ENABLED.key -> "false") { | |
+ sql("CREATE TEMPORARY VIEW v1 AS SELECT 1/0") | |
+ } | |
withSQLConf( | |
USE_CURRENT_SQL_CONFIGS_FOR_VIEW.key -> "true", | |
ANSI_ENABLED.key -> "true") { | |
@@ -845,7 +844,9 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { | |
sql("CREATE VIEW v2 (c1) AS SELECT c1 FROM t ORDER BY 1 ASC, c1 DESC") | |
sql("CREATE VIEW v3 (c1, count) AS SELECT c1, count(c1) AS cnt FROM t GROUP BY 1") | |
sql("CREATE VIEW v4 (a, count) AS SELECT c1 as a, count(c1) AS cnt FROM t GROUP BY a") | |
- sql("CREATE VIEW v5 (c1) AS SELECT 1/0 AS invalid") | |
+ withSQLConf(ANSI_ENABLED.key -> "false") { | |
+ sql("CREATE VIEW v5 (c1) AS SELECT 1/0 AS invalid") | |
+ } | |
withSQLConf(CASE_SENSITIVE.key -> "true") { | |
checkAnswer(sql("SELECT * FROM v1"), Seq(Row(2), Row(3), Row(1))) | |
@@ -869,9 +870,9 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { | |
withSQLConf(CASE_SENSITIVE.key -> "true") { | |
val e = intercept[AnalysisException] { | |
sql("SELECT * FROM v1") | |
- }.getMessage | |
- assert(e.contains("cannot resolve 'C1' given input columns: " + | |
- "[spark_catalog.default.t.c1]")) | |
+ } | |
+ assert(e.getErrorClass == "MISSING_COLUMN") | |
+ assert(e.messageParameters.sameElements(Array("C1", "spark_catalog.default.t.c1"))) | |
} | |
withSQLConf(ORDER_BY_ORDINAL.key -> "false") { | |
checkAnswer(sql("SELECT * FROM v2"), Seq(Row(3), Row(2), Row(1))) | |
@@ -888,15 +889,15 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { | |
withSQLConf(GROUP_BY_ALIASES.key -> "false") { | |
val e = intercept[AnalysisException] { | |
sql("SELECT * FROM v4") | |
- }.getMessage | |
- assert(e.contains("cannot resolve 'a' given input columns: " + | |
- "[spark_catalog.default.t.c1]")) | |
+ } | |
+ assert(e.getErrorClass == "MISSING_COLUMN") | |
+ assert(e.messageParameters.sameElements(Array("a", "spark_catalog.default.t.c1"))) | |
} | |
withSQLConf(ANSI_ENABLED.key -> "true") { | |
val e = intercept[ArithmeticException] { | |
sql("SELECT * FROM v5").collect() | |
}.getMessage | |
- assert(e.contains("divide by zero")) | |
+ assert(e.contains("Division by zero")) | |
} | |
} | |
@@ -906,7 +907,59 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { | |
val e = intercept[ArithmeticException] { | |
sql("SELECT * FROM v1").collect() | |
}.getMessage | |
- assert(e.contains("divide by zero")) | |
+ assert(e.contains("Division by zero")) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("CurrentOrigin is correctly set in and out of the View") { | |
+ withTable("t") { | |
+ Seq((1, 1), (2, 2)).toDF("a", "b").write.format("parquet").saveAsTable("t") | |
+ Seq("VIEW", "TEMPORARY VIEW").foreach { viewType => | |
+ val viewId = "v" | |
+ withView(viewId) { | |
+ val viewText = "SELECT a + b c FROM t" | |
+ sql( | |
+ s""" | |
+ |CREATE $viewType $viewId AS | |
+ |-- the body of the view | |
+ |$viewText | |
+ |""".stripMargin) | |
+ val plan = sql("select c / 2.0D d from v").logicalPlan | |
+ val add = plan.collectFirst { | |
+ case Project(Seq(Alias(a: Add, _)), _) => a | |
+ } | |
+ assert(add.isDefined) | |
+ val qualifiedName = if (viewType == "VIEW") { | |
+ s"default.$viewId" | |
+ } else { | |
+ viewId | |
+ } | |
+ val expectedAddOrigin = Origin( | |
+ line = Some(1), | |
+ startPosition = Some(7), | |
+ startIndex = Some(7), | |
+ stopIndex = Some(11), | |
+ sqlText = Some("SELECT a + b c FROM t"), | |
+ objectType = Some("VIEW"), | |
+ objectName = Some(qualifiedName) | |
+ ) | |
+ assert(add.get.origin == expectedAddOrigin) | |
+ | |
+ val divide = plan.collectFirst { | |
+ case Project(Seq(Alias(d: Divide, _)), _) => d | |
+ } | |
+ assert(divide.isDefined) | |
+ val expectedDivideOrigin = Origin( | |
+ line = Some(1), | |
+ startPosition = Some(7), | |
+ startIndex = Some(7), | |
+ stopIndex = Some(14), | |
+ sqlText = Some("select c / 2.0D d from v"), | |
+ objectType = None, | |
+ objectName = None) | |
+ assert(divide.get.origin == expectedDivideOrigin) | |
+ } | |
} | |
} | |
} | |
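A side note on the Origin assertions added in the hunk above: the start/stop values are 0-based offsets into the recorded sqlText, with stopIndex inclusive, so the expected numbers can be sanity-checked directly. A minimal sketch (plain Scala, not part of the diff):

// Hypothetical sanity check of the Origin offsets asserted in the test above.
val viewText = "SELECT a + b c FROM t"
assert(viewText.substring(7, 11 + 1) == "a + b")     // Add node: startIndex 7, stopIndex 11 (inclusive)
val outerText = "select c / 2.0D d from v"
assert(outerText.substring(7, 14 + 1) == "c / 2.0D") // Divide node: startIndex 7, stopIndex 14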
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala | |
index 4ba2b1703c..1874aa15f8 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala | |
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} | |
import org.apache.spark.sql.catalyst.catalog.CatalogFunction | |
import org.apache.spark.sql.catalyst.expressions.Expression | |
import org.apache.spark.sql.catalyst.plans.logical.Repartition | |
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.withDefaultTimeZone | |
import org.apache.spark.sql.connector.catalog._ | |
import org.apache.spark.sql.internal.SQLConf._ | |
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} | |
@@ -45,10 +46,12 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { | |
viewName: String, | |
sqlText: String, | |
columnNames: Seq[String] = Seq.empty, | |
+ others: Seq[String] = Seq.empty, | |
replace: Boolean = false): String = { | |
val replaceString = if (replace) "OR REPLACE" else "" | |
val columnString = if (columnNames.nonEmpty) columnNames.mkString("(", ",", ")") else "" | |
- sql(s"CREATE $replaceString $viewTypeString $viewName $columnString AS $sqlText") | |
+ val othersString = if (others.nonEmpty) others.mkString(" ") else "" | |
+ sql(s"CREATE $replaceString $viewTypeString $viewName $columnString $othersString AS $sqlText") | |
formattedViewName(viewName) | |
} | |
@@ -117,11 +120,13 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { | |
test("change SQLConf should not change view behavior - ansiEnabled") { | |
withTable("t") { | |
Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t") | |
- val viewName = createView("v1", "SELECT 1/0 AS invalid", Seq("c1")) | |
- withView(viewName) { | |
- Seq("true", "false").foreach { flag => | |
- withSQLConf(ANSI_ENABLED.key -> flag) { | |
- checkViewOutput(viewName, Seq(Row(null))) | |
+ withSQLConf(ANSI_ENABLED.key -> "false") { | |
+ val viewName = createView("v1", "SELECT 1/0 AS invalid", Seq("c1")) | |
+ withView(viewName) { | |
+ Seq("true", "false").foreach { flag => | |
+ withSQLConf(ANSI_ENABLED.key -> flag) { | |
+ checkViewOutput(viewName, Seq(Row(null))) | |
+ } | |
} | |
} | |
} | |
@@ -378,6 +383,31 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { | |
} | |
} | |
} | |
+ | |
+ test("SPARK-37219: time travel is unsupported") { | |
+ val viewName = createView("testView", "SELECT 1 col") | |
+ withView(viewName) { | |
+ val e1 = intercept[AnalysisException]( | |
+ sql(s"SELECT * FROM $viewName VERSION AS OF 1").collect() | |
+ ) | |
+ assert(e1.message.contains("Cannot time travel views")) | |
+ | |
+ val e2 = intercept[AnalysisException]( | |
+ sql(s"SELECT * FROM $viewName TIMESTAMP AS OF '2000-10-10'").collect() | |
+ ) | |
+ assert(e2.message.contains("Cannot time travel views")) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37569: view should report correct nullability information for nested fields") { | |
+ val sql = "SELECT id, named_struct('a', id) AS nested FROM RANGE(10)" | |
+ val viewName = createView("testView", sql) | |
+ withView(viewName) { | |
+ val df = spark.sql(sql) | |
+ val dfFromView = spark.table(viewName) | |
+ assert(df.schema == dfFromView.schema) | |
+ } | |
+ } | |
} | |
abstract class TempViewTestSuite extends SQLViewTestSuite { | |
@@ -400,6 +430,18 @@ abstract class TempViewTestSuite extends SQLViewTestSuite { | |
} | |
} | |
+ test("show create table does not support temp view") { | |
+ val viewName = "spark_28383" | |
+ withView(viewName) { | |
+ createView(viewName, "SELECT 1 AS a") | |
+ val ex = intercept[AnalysisException] { | |
+ sql(s"SHOW CREATE TABLE ${formattedViewName(viewName)}") | |
+ } | |
+ assert(ex.getMessage.contains( | |
+ s"$viewName is a temp view. 'SHOW CREATE TABLE' expects a table or permanent view.")) | |
+ } | |
+ } | |
+ | |
test("back compatibility: skip cyclic reference check if view is stored as logical plan") { | |
val viewName = formattedViewName("v") | |
withSQLConf(STORE_ANALYZED_PLAN_FOR_VIEW.key -> "false") { | |
@@ -581,4 +623,84 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { | |
spark.sessionState.conf.clear() | |
} | |
} | |
+ | |
+ test("SPARK-37266: View text can only be SELECT queries") { | |
+ withView("v") { | |
+ sql("CREATE VIEW v AS SELECT 1") | |
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("v")) | |
+ val dropView = "DROP VIEW v" | |
+ // Simulate the behavior of hackers | |
+ val tamperedTable = table.copy(viewText = Some(dropView)) | |
+ spark.sessionState.catalog.alterTable(tamperedTable) | |
+ val message = intercept[AnalysisException] { | |
+ sql("SELECT * FROM v") | |
+ }.getMessage | |
+ assert(message.contains(s"Invalid view text: $dropView." + | |
+ s" The view ${table.qualifiedName} may have been tampered with")) | |
+ } | |
+ } | |
+ | |
+ test("show create table for persisted simple view") { | |
+ val viewName = "v1" | |
+ Seq(true, false).foreach { serde => | |
+ withView(viewName) { | |
+ createView(viewName, "SELECT 1 AS a") | |
+ val expected = s"CREATE VIEW ${formattedViewName(viewName)} ( a) AS SELECT 1 AS a" | |
+ assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("show create table for persisted view with output columns") { | |
+ val viewName = "v1" | |
+ Seq(true, false).foreach { serde => | |
+ withView(viewName) { | |
+ createView(viewName, "SELECT 1 AS a, 2 AS b", Seq("a", "b COMMENT 'b column'")) | |
+ val expected = s"CREATE VIEW ${formattedViewName(viewName)}" + | |
+ s" ( a, b COMMENT 'b column') AS SELECT 1 AS a, 2 AS b" | |
+ assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("show create table for persisted simple view with table comment and properties") { | |
+ val viewName = "v1" | |
+ Seq(true, false).foreach { serde => | |
+ withView(viewName) { | |
+ createView(viewName, "SELECT 1 AS c1, '2' AS c2", Seq("c1 COMMENT 'bla'", "c2"), | |
+ Seq("COMMENT 'table comment'", "TBLPROPERTIES ( 'prop2' = 'value2', 'prop1' = 'value1')")) | |
+ | |
+ val expected = s"CREATE VIEW ${formattedViewName(viewName)} ( c1 COMMENT 'bla', c2)" + | |
+ " COMMENT 'table comment'" + | |
+ " TBLPROPERTIES ( 'prop1' = 'value1', 'prop2' = 'value2')" + | |
+ " AS SELECT 1 AS c1, '2' AS c2" | |
+ assert(getShowCreateDDL(formattedViewName(viewName), serde) == expected) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("capture the session time zone config while creating a view") { | |
+ val viewName = "v1_capture_test" | |
+ withView(viewName) { | |
+ assert(get.sessionLocalTimeZone === "America/Los_Angeles") | |
+ createView(viewName, | |
+ """select hour(ts) as H from ( | |
+ | select cast('2022-01-01T00:00:00.000 America/Los_Angeles' as timestamp) as ts | |
+ |)""".stripMargin, Seq("H")) | |
+ withDefaultTimeZone(java.time.ZoneId.of("UTC-09:00")) { | |
+ withSQLConf(SESSION_LOCAL_TIMEZONE.key -> "UTC-10:00") { | |
+ checkAnswer(sql(s"select H from $viewName"), Row(0)) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ def getShowCreateDDL(view: String, serde: Boolean = false): String = { | |
+ val result = if (serde) { | |
+ sql(s"SHOW CREATE TABLE $view AS SERDE") | |
+ } else { | |
+ sql(s"SHOW CREATE TABLE $view") | |
+ } | |
+ result.head().getString(0).split("\n").map(_.trim).mkString(" ") | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala | |
index 9f70c8aeca..da05373125 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala | |
@@ -703,27 +703,55 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { | |
test("splitSizeListByTargetSize") { | |
val targetSize = 100 | |
+ val smallPartitionFactor1 = ShufflePartitionsUtil.SMALL_PARTITION_FACTOR | |
// merge the small partitions at the beginning/end | |
- val sizeList1 = Seq[Long](15, 90, 15, 15, 15, 90, 15) | |
- assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList1, targetSize).toSeq == | |
+ val sizeList1 = Array[Long](15, 90, 15, 15, 15, 90, 15) | |
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize( | |
+ sizeList1, targetSize, smallPartitionFactor1).toSeq == | |
Seq(0, 2, 5)) | |
// merge the small partitions in the middle | |
- val sizeList2 = Seq[Long](30, 15, 90, 10, 90, 15, 30) | |
- assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList2, targetSize).toSeq == | |
+ val sizeList2 = Array[Long](30, 15, 90, 10, 90, 15, 30) | |
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize( | |
+ sizeList2, targetSize, smallPartitionFactor1).toSeq == | |
Seq(0, 2, 4, 5)) | |
// merge small partitions if the partition itself is smaller than | |
// targetSize * SMALL_PARTITION_FACTOR | |
- val sizeList3 = Seq[Long](15, 1000, 15, 1000) | |
- assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList3, targetSize).toSeq == | |
+ val sizeList3 = Array[Long](15, 1000, 15, 1000) | |
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize( | |
+ sizeList3, targetSize, smallPartitionFactor1).toSeq == | |
Seq(0, 3)) | |
// merge small partitions if the combined size is smaller than | |
// targetSize * MERGED_PARTITION_FACTOR | |
- val sizeList4 = Seq[Long](35, 75, 90, 20, 35, 25, 35) | |
- assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList4, targetSize).toSeq == | |
+ val sizeList4 = Array[Long](35, 75, 90, 20, 35, 25, 35) | |
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize( | |
+ sizeList4, targetSize, smallPartitionFactor1).toSeq == | |
Seq(0, 2, 3)) | |
+ | |
+ val smallPartitionFactor2 = 0.5 | |
+ // merge last two partition if their size is not bigger than smallPartitionFactor * target | |
+ val sizeList5 = Array[Long](50, 50, 40, 5) | |
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize( | |
+ sizeList5, targetSize, smallPartitionFactor2).toSeq == | |
+ Seq(0)) | |
+ | |
+ val sizeList6 = Array[Long](40, 5, 50, 45) | |
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize( | |
+ sizeList6, targetSize, smallPartitionFactor2).toSeq == | |
+ Seq(0)) | |
+ | |
+ // do not merge | |
+ val sizeList7 = Array[Long](50, 50, 10, 40, 5) | |
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize( | |
+ sizeList7, targetSize, smallPartitionFactor2).toSeq == | |
+ Seq(0, 2)) | |
+ | |
+ val sizeList8 = Array[Long](10, 40, 5, 50, 50) | |
+ assert(ShufflePartitionsUtil.splitSizeListByTargetSize( | |
+ sizeList8, targetSize, smallPartitionFactor2).toSeq == | |
+ Seq(0, 3)) | |
} | |
test("SPARK-35923: Coalesce empty partition with mixed CoalescedPartitionSpec and" + | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala | |
index 6a4f3f6264..7c74423af6 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala | |
@@ -22,6 +22,7 @@ import scala.util.Random | |
import org.apache.spark.AccumulatorSuite | |
import org.apache.spark.sql.{RandomDataGenerator, Row} | |
import org.apache.spark.sql.catalyst.dsl.expressions._ | |
+import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
import org.apache.spark.sql.types._ | |
@@ -43,13 +44,15 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { | |
checkAnswer( | |
input.toDF("a", "b", "c"), | |
- (child: SparkPlan) => SortExec('a.asc :: 'b.asc :: Nil, global = true, child = child), | |
+ (child: SparkPlan) => SortExec(Symbol("a").asc :: Symbol("b").asc :: Nil, | |
+ global = true, child = child), | |
input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), | |
sortAnswers = false) | |
checkAnswer( | |
input.toDF("a", "b", "c"), | |
- (child: SparkPlan) => SortExec('b.asc :: 'a.asc :: Nil, global = true, child = child), | |
+ (child: SparkPlan) => SortExec(Symbol("b").asc :: Symbol("a").asc :: Nil, | |
+ global = true, child = child), | |
input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), | |
sortAnswers = false) | |
} | |
@@ -58,9 +61,9 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { | |
checkThatPlansAgree( | |
(1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"), | |
(child: SparkPlan) => | |
- GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)), | |
+ GlobalLimitExec(10, SortExec(Symbol("a").asc :: Nil, global = true, child = child)), | |
(child: SparkPlan) => | |
- GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)), | |
+ GlobalLimitExec(10, ReferenceSort(Symbol("a").asc :: Nil, global = true, child)), | |
sortAnswers = false | |
) | |
} | |
@@ -69,15 +72,15 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { | |
checkThatPlansAgree( | |
(1 to 100).map(v => Tuple1(v)).toDF("a"), | |
(child: SparkPlan) => | |
- GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)), | |
+ GlobalLimitExec(10, SortExec(Symbol("a").asc :: Nil, global = true, child = child)), | |
(child: SparkPlan) => | |
- GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)), | |
+ GlobalLimitExec(10, ReferenceSort(Symbol("a").asc :: Nil, global = true, child)), | |
sortAnswers = false | |
) | |
} | |
test("sorting does not crash for large inputs") { | |
- val sortOrder = 'a.asc :: Nil | |
+ val sortOrder = Symbol("a").asc :: Nil | |
val stringLength = 1024 * 1024 * 2 | |
checkThatPlansAgree( | |
Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), | |
@@ -91,8 +94,8 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { | |
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { | |
checkThatPlansAgree( | |
(1 to 100).map(v => Tuple1(v)).toDF("a"), | |
- (child: SparkPlan) => SortExec('a.asc :: Nil, global = true, child = child), | |
- (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child), | |
+ (child: SparkPlan) => SortExec(Symbol("a").asc :: Nil, global = true, child = child), | |
+ (child: SparkPlan) => ReferenceSort(Symbol("a").asc :: Nil, global = true, child), | |
sortAnswers = false) | |
} | |
} | |
@@ -105,11 +108,31 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { | |
) | |
checkAnswer( | |
input.toDF("a", "b", "c"), | |
- (child: SparkPlan) => SortExec(Stream('a.asc, 'b.asc, 'c.asc), global = true, child = child), | |
+ (child: SparkPlan) => SortExec(Stream(Symbol("a").asc, 'b.asc, 'c.asc), | |
+ global = true, child = child), | |
input.sortBy(t => (t._1, t._2, t._3)).map(Row.fromTuple), | |
sortAnswers = false) | |
} | |
+ test("SPARK-40089: decimal values sort correctly") { | |
+ val input = Seq( | |
+ BigDecimal("999999999999999999.50"), | |
+ BigDecimal("1.11"), | |
+ BigDecimal("999999999999999999.49") | |
+ ) | |
+ // The range partitioner does the right thing. If there are too many | |
+ // shuffle partitions the error might not always show up. | |
+ withSQLConf("spark.sql.shuffle.partitions" -> "1") { | |
+ val inputDf = spark.createDataFrame(sparkContext.parallelize(input.map(v => Row(v)), 1), | |
+ StructType(StructField("a", DecimalType(20, 2)) :: Nil)) | |
+ checkAnswer( | |
+ inputDf, | |
+ (child: SparkPlan) => SortExec('a.asc :: Nil, global = true, child = child), | |
+ input.sorted.map(Row(_)), | |
+ sortAnswers = false) | |
+ } | |
+ } | |
+ | |
// Test sorting on different data types | |
for ( | |
dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); | |
@@ -124,12 +147,16 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { | |
sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), | |
StructType(StructField("a", dataType, nullable = true) :: Nil) | |
) | |
- checkThatPlansAgree( | |
- inputDf, | |
- p => SortExec(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23), | |
- ReferenceSort(sortOrder, global = true, _: SparkPlan), | |
- sortAnswers = false | |
- ) | |
+ Seq(true, false).foreach { enableRadix => | |
+ withSQLConf(SQLConf.RADIX_SORT_ENABLED.key -> enableRadix.toString) { | |
+ checkThatPlansAgree( | |
+ inputDf, | |
+ p => SortExec(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23), | |
+ ReferenceSort(sortOrder, global = true, _: SparkPlan), | |
+ sortAnswers = false | |
+ ) | |
+ } | |
+ } | |
} | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala | |
index 04b589de7c..12d311d683 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala | |
@@ -18,11 +18,16 @@ | |
package org.apache.spark.sql.execution | |
import org.apache.spark.SparkEnv | |
+import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.QueryTest | |
+import org.apache.spark.sql.catalyst.InternalRow | |
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} | |
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate | |
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
+import org.apache.spark.sql.types.IntegerType | |
+import org.apache.spark.sql.vectorized.ColumnarBatch | |
class SparkPlanSuite extends QueryTest with SharedSparkSession { | |
@@ -110,6 +115,16 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession { | |
"should have been replaced by aggregate in the optimizer")) | |
} | |
+ test("SPARK-37221: The collect-like API in SparkPlan should support columnar output") { | |
+ val emptyResults = ColumnarOp(LocalTableScanExec(Nil, Nil)).toRowBased.executeCollect() | |
+ assert(emptyResults.isEmpty) | |
+ | |
+ val relation = LocalTableScanExec( | |
+ Seq(AttributeReference("val", IntegerType)()), Seq(InternalRow(1))) | |
+ val nonEmpty = ColumnarOp(relation).toRowBased.executeCollect() | |
+ assert(nonEmpty === relation.executeCollect()) | |
+ } | |
+ | |
test("SPARK-37779: ColumnarToRowExec should be canonicalizable after being (de)serialized") { | |
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") { | |
withTempPath { path => | |
@@ -129,3 +144,13 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession { | |
} | |
} | |
} | |
+ | |
+case class ColumnarOp(child: SparkPlan) extends UnaryExecNode { | |
+ override val supportsColumnar: Boolean = true | |
+ override protected def doExecuteColumnar(): RDD[ColumnarBatch] = | |
+ RowToColumnarExec(child).executeColumnar() | |
+ override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() | |
+ override def output: Seq[Attribute] = child.output | |
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarOp = | |
+ copy(child = newChild) | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala | |
index ba6dd170d8..49f65ab51c 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala | |
@@ -47,7 +47,7 @@ class SparkSqlParserSuite extends AnalysisTest { | |
} | |
private def intercept(sqlCommand: String, messages: String*): Unit = | |
- interceptParseException(parser.parsePlan)(sqlCommand, messages: _*) | |
+ interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)() | |
test("Checks if SET/RESET can parse all the configurations") { | |
// Force to build static SQL configurations | |
@@ -168,11 +168,11 @@ class SparkSqlParserSuite extends AnalysisTest { | |
intercept("SET a=1;2;;", expectedErrMsg) | |
intercept("SET a b=`1;;`", | |
- "'a b' is an invalid property key, please use quotes, e.g. SET `a b`=`1;;`") | |
+ "\"a b\" is an invalid property key, please use quotes, e.g. SET \"a b\"=\"1;;\"") | |
intercept("SET `a`=1;2;;", | |
- "'1;2;;' is an invalid property value, please use quotes, e.g." + | |
- " SET `a`=`1;2;;`") | |
+ "\"1;2;;\" is an invalid property value, please use quotes, e.g." + | |
+ " SET \"a\"=\"1;2;;\"") | |
} | |
test("refresh resource") { | |
@@ -312,7 +312,7 @@ class SparkSqlParserSuite extends AnalysisTest { | |
Seq(AttributeReference("a", StringType)(), | |
AttributeReference("b", StringType)(), | |
AttributeReference("c", StringType)()), | |
- Project(Seq('a, 'b, 'c), | |
+ Project(Seq(Symbol("a"), Symbol("b"), Symbol("c")), | |
UnresolvedRelation(TableIdentifier("testData"))), | |
ioSchema)) | |
@@ -336,9 +336,9 @@ class SparkSqlParserSuite extends AnalysisTest { | |
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), | |
Literal(10)), | |
Aggregate( | |
- Seq('a), | |
+ Seq(Symbol("a")), | |
Seq( | |
- 'a, | |
+ Symbol("a"), | |
UnresolvedAlias( | |
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), None), | |
UnresolvedAlias( | |
@@ -363,12 +363,12 @@ class SparkSqlParserSuite extends AnalysisTest { | |
AttributeReference("c", StringType)()), | |
WithWindowDefinition( | |
Map("w" -> WindowSpecDefinition( | |
- Seq('a), | |
- Seq(SortOrder('b, Ascending, NullsFirst, Seq.empty)), | |
+ Seq(Symbol("a")), | |
+ Seq(SortOrder(Symbol("b"), Ascending, NullsFirst, Seq.empty)), | |
UnspecifiedFrame)), | |
Project( | |
Seq( | |
- 'a, | |
+ Symbol("a"), | |
UnresolvedAlias( | |
UnresolvedWindowExpression( | |
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), | |
@@ -403,9 +403,9 @@ class SparkSqlParserSuite extends AnalysisTest { | |
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), | |
Literal(10)), | |
Aggregate( | |
- Seq('a, 'myCol, 'myCol2), | |
+ Seq(Symbol("a"), Symbol("myCol"), Symbol("myCol2")), | |
Seq( | |
- 'a, | |
+ Symbol("a"), | |
UnresolvedAlias( | |
UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), None), | |
UnresolvedAlias( | |
@@ -415,7 +415,7 @@ class SparkSqlParserSuite extends AnalysisTest { | |
UnresolvedGenerator( | |
FunctionIdentifier("explode"), | |
Seq(UnresolvedAttribute("myTable.myCol"))), | |
- Nil, false, Option("mytable2"), Seq('myCol2), | |
+ Nil, false, Option("mytable2"), Seq(Symbol("myCol2")), | |
Generate( | |
UnresolvedGenerator( | |
FunctionIdentifier("explode"), | |
@@ -423,7 +423,7 @@ class SparkSqlParserSuite extends AnalysisTest { | |
Seq( | |
UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), false)), | |
false))), | |
- Nil, false, Option("mytable"), Seq('myCol), | |
+ Nil, false, Option("mytable"), Seq(Symbol("myCol")), | |
UnresolvedRelation(TableIdentifier("testData")))))), | |
ioSchema)) | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala | |
index c025670fb8..3718b3a3c3 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala | |
@@ -49,7 +49,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { | |
val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols) | |
val cols = (0 until numCols).map { idx => | |
- from_json('value, schema).getField(s"col$idx") | |
+ from_json(Symbol("value"), schema).getField(s"col$idx") | |
} | |
Seq( | |
@@ -88,7 +88,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { | |
val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols) | |
val predicate = (0 until numCols).map { idx => | |
- (from_json('value, schema).getField(s"col$idx") >= Literal(100000)).expr | |
+ (from_json(Symbol("value"), schema).getField(s"col$idx") >= Literal(100000)).expr | |
}.asInstanceOf[Seq[Expression]].reduce(Or) | |
Seq( | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala | |
index 6ec5c6287e..ce48945e52 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala | |
@@ -58,7 +58,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSparkSession { | |
private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan) | |
val limit = 250 | |
- val sortOrder = 'a.desc :: 'b.desc :: Nil | |
+ val sortOrder = Symbol("a").desc :: Symbol("b").desc :: Nil | |
test("TakeOrderedAndProject.doExecute without project") { | |
withClue(s"seed = $seed") { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala | |
index 2f626f7769..73c4e4c3e1 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala | |
@@ -40,15 +40,26 @@ class WholeStageCodegenSparkSubmitSuite extends SparkSubmitTestUtils | |
val unusedJar = TestUtils.createJarWithClasses(Seq.empty) | |
// HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched | |
- // settings of UseCompressedOops JVM option. | |
+ // settings of UseCompressedClassPointers JVM option. | |
val argsForSparkSubmit = Seq( | |
"--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"), | |
"--master", "local-cluster[1,1,1024]", | |
"--driver-memory", "1g", | |
"--conf", "spark.ui.enabled=false", | |
"--conf", "spark.master.rest.enabled=false", | |
- "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops", | |
- "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops", | |
+ // SPARK-37008: `Platform.BYTE_ARRAY_OFFSET` differs across Java versions and JVM args, |
+ // as the following table shows: |
+ // +---------------------------------+--------+---------+ |
+ // |                                 | Java 8 | Java 17 | |
+ // +---------------------------------+--------+---------+ |
+ // | -XX:-UseCompressedOops          |   24   |   16    | |
+ // | -XX:+UseCompressedOops          |   16   |   16    | |
+ // | -XX:-UseCompressedClassPointers |   24   |   24    | |
+ // | -XX:+UseCompressedClassPointers |   16   |   16    | |
+ // +---------------------------------+--------+---------+ |
+ // So SPARK-37008 replaces `UseCompressedOops` with `UseCompressedClassPointers`. |
+ "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedClassPointers", | |
+ "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedClassPointers", | |
"--conf", "spark.sql.adaptive.enabled=false", | |
unusedJar.toString) | |
runSparkSubmit(argsForSparkSubmit, timeout = 3.minutes) | |
@@ -60,7 +71,7 @@ object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging { | |
var spark: SparkSession = _ | |
def main(args: Array[String]): Unit = { | |
- TestUtils.configTestLog4j("INFO") | |
+ TestUtils.configTestLog4j2("INFO") | |
spark = SparkSession.builder().getOrCreate() | |
@@ -73,7 +84,7 @@ object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging { | |
val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v") | |
.groupBy(array(col("v"))).agg(count(col("*"))) | |
val plan = df.queryExecution.executedPlan | |
- assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) | |
+ assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec])) | |
val expectedAnswer = | |
Row(Array(0), 7178) :: | |
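The comment table in the hunk above concerns org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET. A tiny check (not part of the diff) to observe the value under a given JVM flag could look like:

// Run with e.g. -XX:-UseCompressedClassPointers to see 24 (per the table above).
import org.apache.spark.unsafe.Platform
println(s"BYTE_ARRAY_OFFSET = ${Platform.BYTE_ARRAY_OFFSET}")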
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | |
index c2b22125cb..2be915f000 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | |
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution | |
import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} | |
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} | |
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite | |
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec | |
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} | |
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec | |
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} | |
import org.apache.spark.sql.functions._ | |
@@ -37,19 +37,30 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
test("range/filter should be combined") { | |
val df = spark.range(10).filter("id = 1").selectExpr("id + 1") | |
val plan = df.queryExecution.executedPlan | |
- assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) | |
+ assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec])) | |
assert(df.collect() === Array(Row(2))) | |
} | |
- test("Aggregate should be included in WholeStageCodegen") { | |
+ test("HashAggregate should be included in WholeStageCodegen") { | |
val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id"))) | |
val plan = df.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec])) | |
assert(df.collect() === Array(Row(9, 4.5))) | |
} | |
+ test("SortAggregate should be included in WholeStageCodegen") { | |
+ val df = spark.range(10).agg(max(col("id")), avg(col("id"))) | |
+ withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") { | |
+ val plan = df.queryExecution.executedPlan | |
+ assert(plan.exists(p => | |
+ p.isInstanceOf[WholeStageCodegenExec] && | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec])) | |
+ assert(df.collect() === Array(Row(9, 4.5))) | |
+ } | |
+ } | |
+ | |
testWithWholeStageCodegenOnAndOff("GenerateExec should be" + | |
" included in WholeStageCodegen") { codegenEnabled => | |
import testImplicits._ | |
@@ -59,22 +70,22 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
// Array - explode | |
var expDF = df.select($"name", explode($"knownLanguages"), $"properties") | |
var plan = expDF.queryExecution.executedPlan | |
- assert(plan.find { | |
+ assert(plan.exists { | |
case stage: WholeStageCodegenExec => | |
- stage.find(_.isInstanceOf[GenerateExec]).isDefined | |
+ stage.exists(_.isInstanceOf[GenerateExec]) | |
case _ => !codegenEnabled.toBoolean | |
- }.isDefined) | |
+ }) | |
checkAnswer(expDF, Array(Row("James", "Java", Map("hair" -> "black", "eye" -> "brown")), | |
Row("James", "Scala", Map("hair" -> "black", "eye" -> "brown")))) | |
// Map - explode | |
expDF = df.select($"name", $"knownLanguages", explode($"properties")) | |
plan = expDF.queryExecution.executedPlan | |
- assert(plan.find { | |
+ assert(plan.exists { | |
case stage: WholeStageCodegenExec => | |
- stage.find(_.isInstanceOf[GenerateExec]).isDefined | |
+ stage.exists(_.isInstanceOf[GenerateExec]) | |
case _ => !codegenEnabled.toBoolean | |
- }.isDefined) | |
+ }) | |
checkAnswer(expDF, | |
Array(Row("James", List("Java", "Scala"), "hair", "black"), | |
Row("James", List("Java", "Scala"), "eye", "brown"))) | |
@@ -82,33 +93,33 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
// Array - posexplode | |
expDF = df.select($"name", posexplode($"knownLanguages")) | |
plan = expDF.queryExecution.executedPlan | |
- assert(plan.find { | |
+ assert(plan.exists { | |
case stage: WholeStageCodegenExec => | |
- stage.find(_.isInstanceOf[GenerateExec]).isDefined | |
+ stage.exists(_.isInstanceOf[GenerateExec]) | |
case _ => !codegenEnabled.toBoolean | |
- }.isDefined) | |
+ }) | |
checkAnswer(expDF, | |
Array(Row("James", 0, "Java"), Row("James", 1, "Scala"))) | |
// Map - posexplode | |
expDF = df.select($"name", posexplode($"properties")) | |
plan = expDF.queryExecution.executedPlan | |
- assert(plan.find { | |
+ assert(plan.exists { | |
case stage: WholeStageCodegenExec => | |
- stage.find(_.isInstanceOf[GenerateExec]).isDefined | |
+ stage.exists(_.isInstanceOf[GenerateExec]) | |
case _ => !codegenEnabled.toBoolean | |
- }.isDefined) | |
+ }) | |
checkAnswer(expDF, | |
Array(Row("James", 0, "hair", "black"), Row("James", 1, "eye", "brown"))) | |
// Array - explode , selecting all columns | |
expDF = df.select($"*", explode($"knownLanguages")) | |
plan = expDF.queryExecution.executedPlan | |
- assert(plan.find { | |
+ assert(plan.exists { | |
case stage: WholeStageCodegenExec => | |
- stage.find(_.isInstanceOf[GenerateExec]).isDefined | |
+ stage.exists(_.isInstanceOf[GenerateExec]) | |
case _ => !codegenEnabled.toBoolean | |
- }.isDefined) | |
+ }) | |
checkAnswer(expDF, | |
Array(Row("James", Seq("Java", "Scala"), Map("hair" -> "black", "eye" -> "brown"), "Java"), | |
Row("James", Seq("Java", "Scala"), Map("hair" -> "black", "eye" -> "brown"), "Scala"))) | |
@@ -116,11 +127,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
// Map - explode, selecting all columns | |
expDF = df.select($"*", explode($"properties")) | |
plan = expDF.queryExecution.executedPlan | |
- assert(plan.find { | |
+ assert(plan.exists { | |
case stage: WholeStageCodegenExec => | |
- stage.find(_.isInstanceOf[GenerateExec]).isDefined | |
+ stage.exists(_.isInstanceOf[GenerateExec]) | |
case _ => !codegenEnabled.toBoolean | |
- }.isDefined) | |
+ }) | |
checkAnswer(expDF, | |
Array( | |
Row("James", List("Java", "Scala"), | |
@@ -129,12 +140,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
Map("hair" -> "black", "eye" -> "brown"), "eye", "brown"))) | |
} | |
- test("Aggregate with grouping keys should be included in WholeStageCodegen") { | |
+ test("HashAggregate with grouping keys should be included in WholeStageCodegen") { | |
val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) | |
val plan = df.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec])) | |
assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1))) | |
} | |
@@ -143,13 +154,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
val schema = new StructType().add("k", IntegerType).add("v", StringType) | |
val smallDF = spark.createDataFrame(rdd, schema) | |
val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id")) | |
- assert(df.queryExecution.executedPlan.find(p => | |
+ assert(df.queryExecution.executedPlan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec])) | |
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) | |
} | |
- test("ShuffledHashJoin should be included in WholeStageCodegen") { | |
+ test("Inner ShuffledHashJoin should be included in WholeStageCodegen") { | |
val df1 = spark.range(5).select($"id".as("k1")) | |
val df2 = spark.range(15).select($"id".as("k2")) | |
val df3 = spark.range(6).select($"id".as("k3")) | |
@@ -171,6 +182,56 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4))) | |
} | |
+ test("Full Outer ShuffledHashJoin and SortMergeJoin should be included in WholeStageCodegen") { | |
+ val df1 = spark.range(5).select($"id".as("k1")) | |
+ val df2 = spark.range(10).select($"id".as("k2")) | |
+ val df3 = spark.range(3).select($"id".as("k3")) | |
+ | |
+ Seq("SHUFFLE_HASH", "SHUFFLE_MERGE").foreach { hint => | |
+ // test one join with unique key from build side | |
+ val joinUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer") | |
+ assert(joinUniqueDF.queryExecution.executedPlan.collect { | |
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true | |
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true | |
+ }.size === 1) | |
+ checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4), | |
+ Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9))) | |
+ assert(joinUniqueDF.count() === 10) | |
+ | |
+ // test one join with non-unique key from build side | |
+ val joinNonUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2" % 3, "full_outer") | |
+ assert(joinNonUniqueDF.queryExecution.executedPlan.collect { | |
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true | |
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true | |
+ }.size === 1) | |
+ checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1), | |
+ Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null))) | |
+ | |
+ // test one join with non-equi condition | |
+ val joinWithNonEquiDF = df1.join(df2.hint(hint), | |
+ $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer") | |
+ assert(joinWithNonEquiDF.queryExecution.executedPlan.collect { | |
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true | |
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true | |
+ }.size === 1) | |
+ checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1), | |
+ Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4), | |
+ Row(null, 5))) | |
+ | |
+ // test two joins | |
+ val twoJoinsDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer") | |
+ .join(df3.hint(hint), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer") | |
+ assert(twoJoinsDF.queryExecution.executedPlan.collect { | |
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true | |
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true | |
+ }.size === 2) | |
+ checkAnswer(twoJoinsDF, | |
+ Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null), | |
+ Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null), | |
+ Row(null, 9, null), Row(null, null, 1))) | |
+ } | |
+ } | |
+ | |
test("Left/Right Outer SortMergeJoin should be included in WholeStageCodegen") { | |
val df1 = spark.range(10).select($"id".as("k1")) | |
val df2 = spark.range(4).select($"id".as("k2")) | |
@@ -373,9 +434,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
test("Sort should be included in WholeStageCodegen") { | |
val df = spark.range(3, 0, -1).toDF().sort(col("id")) | |
val plan = df.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined) | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec])) | |
assert(df.collect() === Array(Row(1), Row(2), Row(3))) | |
} | |
@@ -384,27 +445,27 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
val ds = spark.range(10).map(_.toString) | |
val plan = ds.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec])) | |
assert(ds.collect() === 0.until(10).map(_.toString).toArray) | |
} | |
test("typed filter should be included in WholeStageCodegen") { | |
val ds = spark.range(10).filter(_ % 2 == 0) | |
val plan = ds.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec])) | |
assert(ds.collect() === Array(0, 2, 4, 6, 8)) | |
} | |
test("back-to-back typed filter should be included in WholeStageCodegen") { | |
val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) | |
val plan = ds.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) | |
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec])) | |
assert(ds.collect() === Array(0, 6)) | |
} | |
@@ -417,7 +478,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
val planInt = dsIntFilter.queryExecution.executedPlan | |
assert(planInt.collect { | |
case WholeStageCodegenExec(FilterExec(_, | |
- ColumnarToRowExec(InputAdapter(_: InMemoryTableScanExec)))) => () | |
+ InputAdapter(_: InMemoryTableScanExec))) => () | |
}.length == 1) | |
assert(dsIntFilter.collect() === Array(1, 2)) | |
@@ -456,10 +517,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
.select("int") | |
val plan = df.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(!plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
p.asInstanceOf[WholeStageCodegenExec].child.children(0) | |
- .isInstanceOf[SortMergeJoinExec]).isEmpty) | |
+ .isInstanceOf[SortMergeJoinExec])) | |
assert(df.collect() === Array(Row(1), Row(2))) | |
} | |
} | |
@@ -512,7 +573,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
import testImplicits._ | |
withTempPath { dir => | |
val path = dir.getCanonicalPath | |
- val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*) | |
+ val df = spark.range(10).select(Seq.tabulate(201) {i => (Symbol("id") + i).as(s"c$i")} : _*) | |
df.write.mode(SaveMode.Overwrite).parquet(path) | |
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202", | |
@@ -529,7 +590,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
test("Control splitting consume function by operators with config") { | |
import testImplicits._ | |
- val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) | |
+ val df = spark.range(10).select(Seq.tabulate(2) {i => (Symbol("id") + i).as(s"c$i")} : _*) | |
Seq(true, false).foreach { config => | |
withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") { | |
@@ -578,9 +639,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
val df = spark.range(100) | |
val join = df.join(df, "id") | |
val plan = join.queryExecution.executedPlan | |
- assert(plan.find(p => | |
+ assert(!plan.exists(p => | |
p.isInstanceOf[WholeStageCodegenExec] && | |
- p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isEmpty, | |
+ p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0), | |
"codegen stage IDs should be preserved through ReuseExchange") | |
checkAnswer(join, df.toDF) | |
} | |
@@ -592,9 +653,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { | |
// the same query run twice should produce identical code, which would imply a hit in | |
// the generated code cache. | |
- val ds1 = spark.range(3).select('id + 2) | |
+ val ds1 = spark.range(3).select(Symbol("id") + 2) | |
val code1 = genCode(ds1) | |
- val ds2 = spark.range(3).select('id + 2) | |
+ val ds2 = spark.range(3).select(Symbol("id") + 2) | |
val code2 = genCode(ds2) // same query shape as above, deliberately | |
assert(code1 == code2, "Should produce same code") | |
} | |
@@ -639,10 +700,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
// BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions | |
val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id()) | |
.join(baseTable, "idx") | |
- assert(distinctWithId.queryExecution.executedPlan.collectFirst { | |
+ assert(distinctWithId.queryExecution.executedPlan.exists { | |
case WholeStageCodegenExec( | |
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true | |
- }.isDefined) | |
+ case _ => false | |
+ }) | |
checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0))) | |
// BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate | |
@@ -650,10 +712,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
val groupByWithId = | |
baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) | |
.join(baseTable, "idx") | |
- assert(groupByWithId.queryExecution.executedPlan.collectFirst { | |
+ assert(groupByWithId.queryExecution.executedPlan.exists { | |
case WholeStageCodegenExec( | |
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true | |
- }.isDefined) | |
+ case _ => false | |
+ }) | |
checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0))) | |
} | |
} | |
@@ -679,11 +742,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
// HashAggregateExec supports WholeStageCodegen and it's the parent of | |
// LocalTableScanExec so LocalTableScanExec should be within a WholeStageCodegen domain. | |
assert( | |
- executedPlan.find { | |
+ executedPlan.exists { | |
case WholeStageCodegenExec( | |
HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec)) => true | |
case _ => false | |
- }.isDefined, | |
+ }, | |
"LocalTableScanExec should be within a WholeStageCodegen domain.") | |
} | |
@@ -699,10 +762,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
"SELECT AVG(v) FROM VALUES(1) t(v)", | |
// Test case with keys |
"SELECT k, AVG(v) FROM VALUES((1, 1)) t(k, v) GROUP BY k").foreach { query => | |
- val errMsg = intercept[IllegalStateException] { | |
+ val e = intercept[IllegalStateException] { | |
sql(query).collect | |
- }.getMessage | |
- assert(errMsg.contains(expectedErrMsg)) | |
+ } | |
+ assert(e.getMessage.contains(expectedErrMsg)) | |
} | |
} | |
} | |
@@ -721,10 +784,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession | |
// Test case with keys |
"SELECT k, AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1, 1)) t(k, a, b, c) " + | |
"GROUP BY k").foreach { query => | |
- val e = intercept[Exception] { | |
+ val e = intercept[IllegalStateException] { | |
sql(query).collect | |
} | |
- assert(e.isInstanceOf[IllegalStateException]) | |
assert(e.getMessage.contains(expectedErrMsg)) | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | |
index 7ae162ca8a..0055b94fa0 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | |
@@ -20,19 +20,22 @@ package org.apache.spark.sql.execution.adaptive | |
import java.io.File | |
import java.net.URI | |
-import org.apache.log4j.Level | |
+import org.apache.logging.log4j.Level | |
import org.scalatest.PrivateMethodTester | |
+import org.scalatest.time.SpanSugar._ | |
+import org.apache.spark.SparkException | |
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} | |
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} | |
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} | |
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} | |
-import org.apache.spark.sql.execution.{CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode} | |
+import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec} | |
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec | |
import org.apache.spark.sql.execution.command.DataWritingCommandExec | |
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource | |
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec | |
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} | |
-import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} | |
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} | |
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter | |
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate | |
import org.apache.spark.sql.functions._ | |
@@ -102,6 +105,12 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
+ def findTopLevelBroadcastNestedLoopJoin(plan: SparkPlan): Seq[BaseJoinExec] = { | |
+ collect(plan) { | |
+ case j: BroadcastNestedLoopJoinExec => j | |
+ } | |
+ } | |
+ | |
private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { | |
collect(plan) { | |
case j: SortMergeJoinExec => j | |
@@ -126,6 +135,18 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
+ private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = { | |
+ collect(plan) { | |
+ case agg: BaseAggregateExec => agg | |
+ } | |
+ } | |
+ | |
+ private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = { | |
+ collect(plan) { | |
+ case l: CollectLimitExec => l | |
+ } | |
+ } | |
+ | |
private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = { | |
collectWithSubqueries(plan) { | |
case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e | |
@@ -267,11 +288,12 @@ class AdaptiveQueryExecSuite | |
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", | |
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { | |
- val df1 = spark.range(10).withColumn("a", 'id) | |
- val df2 = spark.range(10).withColumn("b", 'id) | |
+ val df1 = spark.range(10).withColumn("a", Symbol("id")) | |
+ val df2 = spark.range(10).withColumn("b", Symbol("id")) | |
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { | |
- val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") | |
- .groupBy('a).count() | |
+ val testDf = df1.where(Symbol("a") > 10) | |
+ .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer") | |
+ .groupBy(Symbol("a")).count() | |
checkAnswer(testDf, Seq()) | |
val plan = testDf.queryExecution.executedPlan | |
assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) | |
@@ -283,8 +305,9 @@ class AdaptiveQueryExecSuite | |
} | |
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { | |
- val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") | |
- .groupBy('a).count() | |
+ val testDf = df1.where(Symbol("a") > 10) | |
+ .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer") | |
+ .groupBy(Symbol("a")).count() | |
checkAnswer(testDf, Seq()) | |
val plan = testDf.queryExecution.executedPlan | |
assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
@@ -668,6 +691,47 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
} | |
+ test("SPARK-37753: Allow changing outer join to broadcast join even if too many empty" + | |
+ " partitions on broadcast side") { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
+ SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { | |
+ // `testData` is small enough to be broadcast, but its ratio of empty partitions exceeds the config. | 
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { | |
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( | |
+ "SELECT * FROM (select * from testData where value = '1') td" + | |
+ " right outer join testData2 ON key = a") | |
+ val smj = findTopLevelSortMergeJoin(plan) | |
+ assert(smj.size == 1) | |
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | |
+ assert(bhj.size == 1) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37753: Inhibit broadcast in left outer join when there are many empty" + | |
+ " partitions on outer/left side") { | |
+ // If the right side completes first while the left side is still executing, the right side | 
+ // does not know whether the left side has many empty partitions, so the broadcast is not | 
+ // demoted and the right side is broadcast in the planning stage. | 
+ // Retry several times here to avoid flaky unit test failures. | 
+ eventually(timeout(15.seconds), interval(500.milliseconds)) { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
+ SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { | |
+ // `testData` is small enough to be broadcast, but its ratio of empty partitions exceeds the config. | 
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") { | |
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( | |
+ "SELECT * FROM (select * from testData where value = '1') td" + | |
+ " left outer join testData2 ON key = a") | |
+ val smj = findTopLevelSortMergeJoin(plan) | |
+ assert(smj.size == 1) | |
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | |
+ assert(bhj.isEmpty) | |
+ } | |
+ } | |
+ } | |
+ } | |
test("SPARK-29906: AQE should not introduce extra shuffle for outermost limit") { | |
var numStages = 0 | |
@@ -738,17 +802,17 @@ class AdaptiveQueryExecSuite | |
spark | |
.range(0, 1000, 1, 10) | |
.select( | |
- when('id < 250, 249) | |
- .when('id >= 750, 1000) | |
- .otherwise('id).as("key1"), | |
- 'id as "value1") | |
+ when(Symbol("id") < 250, 249) | |
+ .when(Symbol("id") >= 750, 1000) | |
+ .otherwise(Symbol("id")).as("key1"), | |
+ Symbol("id") as "value1") | |
.createOrReplaceTempView("skewData1") | |
spark | |
.range(0, 1000, 1, 10) | |
.select( | |
- when('id < 250, 249) | |
- .otherwise('id).as("key2"), | |
- 'id as "value2") | |
+ when(Symbol("id") < 250, 249) | |
+ .otherwise(Symbol("id")).as("key2"), | |
+ Symbol("id") as "value2") | |
.createOrReplaceTempView("skewData2") | |
def checkSkewJoin( | |
@@ -806,11 +870,11 @@ class AdaptiveQueryExecSuite | |
df1.write.parquet(tableDir.getAbsolutePath) | |
val aggregated = spark.table("bucketed_table").groupBy("i").count() | |
- val error = intercept[Exception] { | |
+ val error = intercept[SparkException] { | |
aggregated.count() | |
} | |
- assert(error.toString contains "Invalid bucket file") | |
- assert(error.getSuppressed.size === 0) | |
+ assert(error.getErrorClass === "INVALID_BUCKET_FILE") | |
+ assert(error.getMessage contains "Invalid bucket file") | |
} | |
} | |
} | |
@@ -842,7 +906,7 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
assert(!testAppender.loggingEvents | |
- .exists(msg => msg.getRenderedMessage.contains( | |
+ .exists(msg => msg.getMessage.getFormattedMessage.contains( | |
s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" + | |
s" enabled but is not supported for"))) | |
} | |
@@ -850,6 +914,7 @@ class AdaptiveQueryExecSuite | |
test("test log level") { | |
def verifyLog(expectedLevel: Level): Unit = { | |
val logAppender = new LogAppender("adaptive execution") | |
+ logAppender.setThreshold(expectedLevel) | |
withLogAppender( | |
logAppender, | |
loggerNames = Seq(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)), | |
@@ -863,7 +928,7 @@ class AdaptiveQueryExecSuite | |
Seq("Plan changed", "Final plan").foreach { msg => | |
assert( | |
logAppender.loggingEvents.exists { event => | |
- event.getRenderedMessage.contains(msg) && event.getLevel == expectedLevel | |
+ event.getMessage.getFormattedMessage.contains(msg) && event.getLevel == expectedLevel | |
}) | |
} | |
} | |
@@ -982,17 +1047,17 @@ class AdaptiveQueryExecSuite | |
spark | |
.range(0, 1000, 1, 10) | |
.select( | |
- when('id < 250, 249) | |
- .when('id >= 750, 1000) | |
- .otherwise('id).as("key1"), | |
- 'id as "value1") | |
+ when(Symbol("id") < 250, 249) | |
+ .when(Symbol("id") >= 750, 1000) | |
+ .otherwise(Symbol("id")).as("key1"), | |
+ Symbol("id") as "value1") | |
.createOrReplaceTempView("skewData1") | |
spark | |
.range(0, 1000, 1, 10) | |
.select( | |
- when('id < 250, 249) | |
- .otherwise('id).as("key2"), | |
- 'id as "value2") | |
+ when(Symbol("id") < 250, 249) | |
+ .otherwise(Symbol("id")).as("key2"), | |
+ Symbol("id") as "value2") | |
.createOrReplaceTempView("skewData2") | |
val (_, adaptivePlan) = runAdaptiveAndVerifyResult( | |
"SELECT * FROM skewData1 join skewData2 ON key1 = key2") | |
@@ -1070,7 +1135,7 @@ class AdaptiveQueryExecSuite | |
test("AQE should set active session during execution") { | |
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { | |
- val df = spark.range(10).select(sum('id)) | |
+ val df = spark.range(10).select(sum(Symbol("id"))) | |
assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) | |
SparkSession.setActiveSession(null) | |
checkAnswer(df, Seq(Row(45))) | |
@@ -1097,7 +1162,7 @@ class AdaptiveQueryExecSuite | |
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { | |
try { | |
spark.experimental.extraStrategies = TestStrategy :: Nil | |
- val df = spark.range(10).groupBy('id).count() | |
+ val df = spark.range(10).groupBy(Symbol("id")).count() | |
df.collect() | |
} finally { | |
spark.experimental.extraStrategies = Nil | |
@@ -1422,6 +1487,56 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
+ test("SPARK-35442: Support propagate empty relation through aggregate") { | |
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { | |
+ val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( | |
+ "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key") | |
+ assert(!plan1.isInstanceOf[LocalTableScanExec]) | |
+ assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) | |
+ | |
+ val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( | |
+ "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key limit 1") | |
+ assert(!plan2.isInstanceOf[LocalTableScanExec]) | |
+ assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) | |
+ | |
+ val (plan3, adaptivePlan3) = runAdaptiveAndVerifyResult( | |
+ "SELECT count(*) FROM testData WHERE value = 'no_match'") | |
+ assert(!plan3.isInstanceOf[LocalTableScanExec]) | |
+ assert(!stripAQEPlan(adaptivePlan3).isInstanceOf[LocalTableScanExec]) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-35442: Support propagate empty relation through union") { | |
+ def checkNumUnion(plan: SparkPlan, numUnion: Int): Unit = { | |
+ assert( | |
+ collect(plan) { | |
+ case u: UnionExec => u | |
+ }.size == numUnion) | |
+ } | |
+ | |
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { | |
+ val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key | |
+ |UNION ALL | |
+ |SELECT key, 1 FROM testData | |
+ |""".stripMargin) | |
+ checkNumUnion(plan1, 1) | |
+ checkNumUnion(adaptivePlan1, 0) | |
+ assert(!stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) | |
+ | |
+ val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key | |
+ |UNION ALL | |
+ |SELECT /*+ REPARTITION */ key, 1 FROM testData WHERE value = 'no_match' | |
+ |""".stripMargin) | |
+ checkNumUnion(plan2, 1) | |
+ checkNumUnion(adaptivePlan2, 0) | |
+ assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) | |
+ } | |
+ } | |
+ | |
test("SPARK-32753: Only copy tags to node with no tags") { | |
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { | |
withTempView("v1") { | |
@@ -1450,7 +1565,8 @@ class AdaptiveQueryExecSuite | |
"=== Result of Batch AQE Post Stage Creation ===", | |
"=== Result of Batch AQE Replanning ===", | |
"=== Result of Batch AQE Query Stage Optimization ===").foreach { expectedMsg => | |
- assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg))) | |
+ assert(testAppender.loggingEvents.exists( | |
+ _.getMessage.getFormattedMessage.contains(expectedMsg))) | |
} | |
} | |
} | |
@@ -1502,7 +1618,7 @@ class AdaptiveQueryExecSuite | |
test("SPARK-33494: Do not use local shuffle read for repartition") { | |
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { | |
- val df = spark.table("testData").repartition('key) | |
+ val df = spark.table("testData").repartition(Symbol("key")) | |
df.collect() | |
// local shuffle read breaks partitioning and shouldn't be used for repartition operation | |
// which is specified by users. | |
@@ -1581,28 +1697,28 @@ class AdaptiveQueryExecSuite | |
| SELECT * FROM testData WHERE key = 1 | |
|) | |
|RIGHT OUTER JOIN testData2 | |
- |ON value = b | |
+ |ON CAST(value AS INT) = b | |
""".stripMargin) | |
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { | |
// Repartition with no partition num specified. | |
- checkBHJ(df.repartition('b), | |
+ checkBHJ(df.repartition(Symbol("b")), | |
// The top shuffle from repartition is optimized out. | |
optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true) | |
// Repartition with default partition num (5 in test env) specified. | |
- checkBHJ(df.repartition(5, 'b), | |
+ checkBHJ(df.repartition(5, Symbol("b")), | |
// The top shuffle from repartition is optimized out | |
// The final plan must have 5 partitions, no optimization can be made to the probe side. | |
optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false) | |
// Repartition with non-default partition num specified. | |
- checkBHJ(df.repartition(4, 'b), | |
+ checkBHJ(df.repartition(4, Symbol("b")), | |
// The top shuffle from repartition is not optimized out | |
optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) | |
// Repartition by col and project away the partition cols | |
- checkBHJ(df.repartition('b).select('key), | |
+ checkBHJ(df.repartition(Symbol("b")).select(Symbol("key")), | |
// The top shuffle from repartition is not optimized out | |
optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) | |
} | |
@@ -1614,23 +1730,23 @@ class AdaptiveQueryExecSuite | |
SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", | |
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { | |
// Repartition with no partition num specified. | |
- checkSMJ(df.repartition('b), | |
+ checkSMJ(df.repartition(Symbol("b")), | |
// The top shuffle from repartition is optimized out. | |
optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true) | |
// Repartition with default partition num (5 in test env) specified. | |
- checkSMJ(df.repartition(5, 'b), | |
+ checkSMJ(df.repartition(5, Symbol("b")), | |
// The top shuffle from repartition is optimized out. | |
// The final plan must have 5 partitions, can't do coalesced read. | |
optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false) | |
// Repartition with non-default partition num specified. | |
- checkSMJ(df.repartition(4, 'b), | |
+ checkSMJ(df.repartition(4, Symbol("b")), | |
// The top shuffle from repartition is not optimized out. | |
optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) | |
// Repartition by col and project away the partition cols | |
- checkSMJ(df.repartition('b).select('key), | |
+ checkSMJ(df.repartition(Symbol("b")).select(Symbol("key")), | |
// The top shuffle from repartition is not optimized out. | |
optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) | |
} | |
@@ -1667,6 +1783,7 @@ class AdaptiveQueryExecSuite | |
test("SPARK-33933: Materialize BroadcastQueryStage first in AQE") { | |
val testAppender = new LogAppender("aqe query stage materialization order test") | |
+ testAppender.setThreshold(Level.DEBUG) | |
val df = spark.range(1000).select($"id" % 26, $"id" % 10) | |
.toDF("index", "pv") | |
val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString)) | |
@@ -1683,7 +1800,7 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
val materializeLogs = testAppender.loggingEvents | |
- .map(_.getRenderedMessage) | |
+ .map(_.getMessage.getFormattedMessage) | |
.filter(_.startsWith("Materialize query stage")) | |
.toArray | |
assert(materializeLogs(0).startsWith("Materialize query stage BroadcastQueryStageExec")) | |
@@ -1722,10 +1839,94 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
+ test("SPARK-34980: Support coalesce partition through union") { | |
+ def checkResultPartition( | |
+ df: Dataset[Row], | |
+ numUnion: Int, | |
+ numShuffleReader: Int, | |
+ numPartition: Int): Unit = { | |
+ df.collect() | |
+ assert(collect(df.queryExecution.executedPlan) { | |
+ case u: UnionExec => u | |
+ }.size == numUnion) | |
+ assert(collect(df.queryExecution.executedPlan) { | |
+ case r: AQEShuffleReadExec => r | |
+ }.size === numShuffleReader) | |
+ assert(df.rdd.partitions.length === numPartition) | |
+ } | |
+ | |
+ Seq(true, false).foreach { combineUnionEnabled => | |
+ val combineUnionConfig = if (combineUnionEnabled) { | |
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "" | |
+ } else { | |
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> | |
+ "org.apache.spark.sql.catalyst.optimizer.CombineUnions" | |
+ } | |
+ // advisory partition size 1048576 has no special meaning, just a big enough value | |
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", | |
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1048576", | |
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", | |
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10", | |
+ combineUnionConfig) { | |
+ withTempView("t1", "t2") { | |
+ spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 2) | |
+ .toDF().createOrReplaceTempView("t1") | |
+ spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 4) | |
+ .toDF().createOrReplaceTempView("t2") | |
+ | |
+ // positive test that could be coalesced | |
+ checkResultPartition( | |
+ sql(""" | |
+ |SELECT key, count(*) FROM t1 GROUP BY key | |
+ |UNION ALL | |
+ |SELECT * FROM t2 | |
+ """.stripMargin), | |
+ numUnion = 1, | |
+ numShuffleReader = 1, | |
+ numPartition = 1 + 4) | |
+ | |
+ checkResultPartition( | |
+ sql(""" | |
+ |SELECT key, count(*) FROM t1 GROUP BY key | |
+ |UNION ALL | |
+ |SELECT * FROM t2 | |
+ |UNION ALL | |
+ |SELECT * FROM t1 | |
+ """.stripMargin), | |
+ numUnion = if (combineUnionEnabled) 1 else 2, | |
+ numShuffleReader = 1, | |
+ numPartition = 1 + 4 + 2) | |
+ | |
+ checkResultPartition( | |
+ sql(""" | |
+ |SELECT /*+ merge(t2) */ t1.key, t2.key FROM t1 JOIN t2 ON t1.key = t2.key | |
+ |UNION ALL | |
+ |SELECT key, count(*) FROM t2 GROUP BY key | |
+ |UNION ALL | |
+ |SELECT * FROM t1 | |
+ """.stripMargin), | |
+ numUnion = if (combineUnionEnabled) 1 else 2, | |
+ numShuffleReader = 3, | |
+ numPartition = 1 + 1 + 2) | |
+ | |
+ // negative test | |
+ checkResultPartition( | |
+ sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2"), | |
+ numUnion = if (combineUnionEnabled) 1 else 1, | |
+ numShuffleReader = 0, | |
+ numPartition = 2 + 4 | |
+ ) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") { | |
withTable("t") { | |
withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", | |
- SQLConf.SHUFFLE_PARTITIONS.key -> "2") { | |
+ SQLConf.SHUFFLE_PARTITIONS.key -> "2", | |
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { | |
spark.sql("CREATE TABLE t (c1 int) USING PARQUET") | |
val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1") | |
assert( | |
@@ -1889,8 +2090,8 @@ class AdaptiveQueryExecSuite | |
withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "150") { | |
// partition size [0,258,72,72,72] | |
checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4) | |
- // partition size [72,216,216,144,72] | |
- checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 4, 7) | |
+ // partition size [144,72,144,72,72,144,72] | |
+ checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 6, 7) | |
} | |
// no skewed partition should be optimized | |
@@ -1925,6 +2126,74 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
+ test("SPARK-33832: Support optimize skew join even if introduce extra shuffle") { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "false", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", | |
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", | |
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", | |
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", | |
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10", | |
+ SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN.key -> "true") { | |
+ withTempView("skewData1", "skewData2") { | |
+ spark | |
+ .range(0, 1000, 1, 10) | |
+ .selectExpr("id % 3 as key1", "id as value1") | |
+ .createOrReplaceTempView("skewData1") | |
+ spark | |
+ .range(0, 1000, 1, 10) | |
+ .selectExpr("id % 1 as key2", "id as value2") | |
+ .createOrReplaceTempView("skewData2") | |
+ | |
+ // check if optimized skewed join does not satisfy the required distribution | |
+ Seq(true, false).foreach { hasRequiredDistribution => | |
+ Seq(true, false).foreach { hasPartitionNumber => | |
+ val repartition = if (hasRequiredDistribution) { | |
+ s"/*+ repartition(${ if (hasPartitionNumber) "10," else ""}key1) */" | |
+ } else { | |
+ "" | |
+ } | |
+ | |
+ // check required distribution and extra shuffle | |
+ val (_, adaptive1) = | |
+ runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + | |
+ s"JOIN skewData2 ON key1 = key2 GROUP BY key1") | |
+ val shuffles1 = collect(adaptive1) { | |
+ case s: ShuffleExchangeExec => s | |
+ } | |
+ assert(shuffles1.size == 3) | |
+ // shuffles1.head is the top-level shuffle under the Aggregate operator | |
+ assert(shuffles1.head.shuffleOrigin == ENSURE_REQUIREMENTS) | |
+ val smj1 = findTopLevelSortMergeJoin(adaptive1) | |
+ assert(smj1.size == 1 && smj1.head.isSkewJoin) | |
+ | |
+ // only check required distribution | |
+ val (_, adaptive2) = | |
+ runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + | |
+ s"JOIN skewData2 ON key1 = key2") | |
+ val shuffles2 = collect(adaptive2) { | |
+ case s: ShuffleExchangeExec => s | |
+ } | |
+ if (hasRequiredDistribution) { | |
+ assert(shuffles2.size == 3) | |
+ val finalShuffle = shuffles2.head | |
+ if (hasPartitionNumber) { | |
+ assert(finalShuffle.shuffleOrigin == REPARTITION_BY_NUM) | |
+ } else { | |
+ assert(finalShuffle.shuffleOrigin == REPARTITION_BY_COL) | |
+ } | |
+ } else { | |
+ assert(shuffles2.size == 2) | |
+ } | |
+ val smj2 = findTopLevelSortMergeJoin(adaptive2) | |
+ assert(smj2.size == 1 && smj2.head.isSkewJoin) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
test("SPARK-35968: AQE coalescing should not produce too small partitions by default") { | |
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { | |
val (_, adaptive) = | |
@@ -2033,10 +2302,108 @@ class AdaptiveQueryExecSuite | |
} | |
} | |
+ test("SPARK-36424: Support eliminate limits in AQE Optimizer") { | |
+ withTempView("v") { | |
+ spark.sparkContext.parallelize( | |
+ (1 to 10).map(i => TestData(i, if (i > 2) "2" else i.toString)), 2) | |
+ .toDF("c1", "c2").createOrReplaceTempView("v") | |
+ | |
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
+ SQLConf.SHUFFLE_PARTITIONS.key -> "3") { | |
+ val (origin1, adaptive1) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT c2, sum(c1) FROM v GROUP BY c2 LIMIT 5 | |
+ """.stripMargin) | |
+ assert(findTopLevelLimit(origin1).size == 1) | |
+ assert(findTopLevelLimit(adaptive1).isEmpty) | |
+ | |
+ // eliminate limit through filter | |
+ val (origin2, adaptive2) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT c2, sum(c1) FROM v GROUP BY c2 HAVING sum(c1) > 1 LIMIT 5 | |
+ """.stripMargin) | |
+ assert(findTopLevelLimit(origin2).size == 1) | |
+ assert(findTopLevelLimit(adaptive2).isEmpty) | |
+ | |
+ // The strategy of Eliminate Limits batch should be fixedPoint | |
+ val (origin3, adaptive3) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT * FROM (SELECT c1 + c2 FROM (SELECT DISTINCT * FROM v LIMIT 10086)) LIMIT 20 | |
+ """.stripMargin | |
+ ) | |
+ assert(findTopLevelLimit(origin3).size == 1) | |
+ assert(findTopLevelLimit(adaptive3).isEmpty) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize non-root node") { | |
+ withTempView("v") { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", | |
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1", | |
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { | |
+ spark.sparkContext.parallelize( | |
+ (1 to 10).map(i => TestData(if (i > 2) 2 else i, i.toString)), 2) | |
+ .toDF("c1", "c2").createOrReplaceTempView("v") | |
+ | |
+ def checkRebalance(query: String, numShufflePartitions: Int): Unit = { | |
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query) | |
+ assert(adaptive.collect { | |
+ case sort: SortExec => sort | |
+ }.size == 1) | |
+ val read = collect(adaptive) { | |
+ case read: AQEShuffleReadExec => read | |
+ } | |
+ assert(read.size == 1) | |
+ assert(read.head.partitionSpecs.forall(_.isInstanceOf[PartialReducerPartitionSpec])) | |
+ assert(read.head.partitionSpecs.size == numShufflePartitions) | |
+ } | |
+ | |
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "50") { | |
+ checkRebalance("SELECT /*+ REBALANCE(c1) */ * FROM v SORT BY c1", 2) | |
+ checkRebalance("SELECT /*+ REBALANCE */ * FROM v SORT BY c1", 2) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37357: Add small partition factor for rebalance partitions") { | |
+ withTempView("v") { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", | |
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1") { | |
+ spark.sparkContext.parallelize( | |
+ (1 to 8).map(i => TestData(if (i > 2) 2 else i, i.toString)), 3) | |
+ .toDF("c1", "c2").createOrReplaceTempView("v") | |
+ | |
+ def checkAQEShuffleReadExists(query: String, exists: Boolean): Unit = { | |
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query) | |
+ assert( | |
+ collect(adaptive) { | |
+ case read: AQEShuffleReadExec => read | |
+ }.nonEmpty == exists) | |
+ } | |
+ | |
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "200") { | |
+ withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.5") { | |
+ // block size: [88, 97, 97] | |
+ checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", false) | |
+ } | |
+ withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.2") { | |
+ // block size: [88, 97, 97] | |
+ checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", true) | |
+ } | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
test("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly plans BHJ") { | |
withSQLConf( | |
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584") { | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584", | |
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { | |
// Spark estimates a string column as 20 bytes so with 60k rows, these relations should be | |
// estimated at ~120m bytes which is greater than the broadcast join threshold. | |
val joinKeyOne = "00112233445566778899" | |
@@ -2085,6 +2452,203 @@ class AdaptiveQueryExecSuite | |
assert(bhj.length == 1) | |
} | |
} | |
+ | |
+ test("SPARK-37328: skew join with 3 tables") { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", | |
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", | |
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", | |
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", | |
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10") { | |
+ withTempView("skewData1", "skewData2", "skewData3") { | |
+ spark | |
+ .range(0, 1000, 1, 10) | |
+ .selectExpr("id % 3 as key1", "id % 3 as value1") | |
+ .createOrReplaceTempView("skewData1") | |
+ spark | |
+ .range(0, 1000, 1, 10) | |
+ .selectExpr("id % 1 as key2", "id as value2") | |
+ .createOrReplaceTempView("skewData2") | |
+ spark | |
+ .range(0, 1000, 1, 10) | |
+ .selectExpr("id % 1 as key3", "id as value3") | |
+ .createOrReplaceTempView("skewData3") | |
+ | |
+ // the skewed join doesn't happen in the last stage | 
+ val (_, adaptive1) = | |
+ runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + | |
+ "JOIN skewData3 ON value2 = value3") | |
+ val shuffles1 = collect(adaptive1) { | |
+ case s: ShuffleExchangeExec => s | |
+ } | |
+ assert(shuffles1.size == 4) | |
+ val smj1 = findTopLevelSortMergeJoin(adaptive1) | |
+ assert(smj1.size == 2 && smj1.last.isSkewJoin && !smj1.head.isSkewJoin) | |
+ | |
+ // The query has two skew joins in two consecutive stages. | 
+ val (_, adaptive2) = | |
+ runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + | |
+ "JOIN skewData3 ON value1 = value3") | |
+ val shuffles2 = collect(adaptive2) { | |
+ case s: ShuffleExchangeExec => s | |
+ } | |
+ assert(shuffles2.size == 4) | |
+ val smj2 = findTopLevelSortMergeJoin(adaptive2) | |
+ assert(smj2.size == 2 && smj2.forall(_.isSkewJoin)) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37652: optimize skewed join through union") { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", | |
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", | |
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { | |
+ withTempView("skewData1", "skewData2") { | |
+ spark | |
+ .range(0, 1000, 1, 10) | |
+ .selectExpr("id % 3 as key1", "id as value1") | |
+ .createOrReplaceTempView("skewData1") | |
+ spark | |
+ .range(0, 1000, 1, 10) | |
+ .selectExpr("id % 1 as key2", "id as value2") | |
+ .createOrReplaceTempView("skewData2") | |
+ | |
+ def checkSkewJoin(query: String, joinNums: Int, optimizeSkewJoinNums: Int): Unit = { | |
+ val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) | |
+ val joins = findTopLevelSortMergeJoin(innerAdaptivePlan) | |
+ val optimizeSkewJoins = joins.filter(_.isSkewJoin) | |
+ assert(joins.size == joinNums && optimizeSkewJoins.size == optimizeSkewJoinNums) | |
+ } | |
+ | |
+ // skewJoin union skewJoin | |
+ checkSkewJoin( | |
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + | |
+ "UNION ALL SELECT key2 FROM skewData1 JOIN skewData2 ON key1 = key2", 2, 2) | |
+ | |
+ // skewJoin union aggregate | |
+ checkSkewJoin( | |
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + | |
+ "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2", 1, 1) | |
+ | |
+ // skewJoin1 union (skewJoin2 join aggregate) | |
+ // skewJoin2 will lead to extra shuffles, but skewJoin1 cannot be optimized | 
+ checkSkewJoin( | |
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " + | |
+ "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " + | |
+ "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 0) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-38162: Optimize one row plan in AQE Optimizer") { | |
+ withTempView("v") { | |
+ spark.sparkContext.parallelize( | |
+ (1 to 4).map(i => TestData(i, i.toString)), 2) | |
+ .toDF("c1", "c2").createOrReplaceTempView("v") | |
+ | |
+ // remove sort | |
+ val (origin1, adaptive1) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT * FROM v where c1 = 1 order by c1, c2 | |
+ |""".stripMargin) | |
+ assert(findTopLevelSort(origin1).size == 1) | |
+ assert(findTopLevelSort(adaptive1).isEmpty) | |
+ | |
+ // convert group only aggregate to project | |
+ val (origin2, adaptive2) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT distinct c1 FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1) | |
+ |""".stripMargin) | |
+ assert(findTopLevelAggregate(origin2).size == 2) | |
+ assert(findTopLevelAggregate(adaptive2).isEmpty) | |
+ | |
+ // remove distinct in aggregate | |
+ val (origin3, adaptive3) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT sum(distinct c1) FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1) | |
+ |""".stripMargin) | |
+ assert(findTopLevelAggregate(origin3).size == 4) | |
+ assert(findTopLevelAggregate(adaptive3).size == 2) | |
+ | |
+ // do not optimize if the aggregate is inside query stage | |
+ val (origin4, adaptive4) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT distinct c1 FROM v where c1 = 1 | |
+ |""".stripMargin) | |
+ assert(findTopLevelAggregate(origin4).size == 2) | |
+ assert(findTopLevelAggregate(adaptive4).size == 2) | |
+ | |
+ val (origin5, adaptive5) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT sum(distinct c1) FROM v where c1 = 1 | |
+ |""".stripMargin) | |
+ assert(findTopLevelAggregate(origin5).size == 4) | |
+ assert(findTopLevelAggregate(adaptive5).size == 4) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39551: Invalid plan check - invalid broadcast query stage") { | |
+ withSQLConf( | |
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { | |
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult( | |
+ """ | |
+ |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1 | |
+ |INNER JOIN testData2 t2 | |
+ |ON t1.b = t2.b AND t1.a = 0 | |
+ |RIGHT OUTER JOIN testData2 t3 | |
+ |ON t1.a > t3.a | |
+ |GROUP BY t3.b | |
+ """.stripMargin | |
+ ) | |
+ assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39915: Dataset.repartition(N) may not create N partitions") { | |
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") { | |
+ // partitioning: HashPartitioning | |
+ // shuffleOrigin: REPARTITION_BY_NUM | |
+ assert(spark.range(0).repartition(5, $"id").rdd.getNumPartitions == 5) | |
+ // shuffleOrigin: REPARTITION_BY_COL | |
+ // The minimum partition number after AQE coalesce is 1 | |
+ assert(spark.range(0).repartition($"id").rdd.getNumPartitions == 1) | |
+ // through project | |
+ assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2") | |
+ .repartition(5, $"c1").select($"c2").rdd.getNumPartitions == 5) | |
+ | |
+ // partitioning: RangePartitioning | |
+ // shuffleOrigin: REPARTITION_BY_NUM | |
+ // The minimum partition number of RangePartitioner is 1 | |
+ assert(spark.range(0).repartitionByRange(5, $"id").rdd.getNumPartitions == 1) | |
+ // shuffleOrigin: REPARTITION_BY_COL | |
+ assert(spark.range(0).repartitionByRange($"id").rdd.getNumPartitions == 1) | |
+ | |
+ // partitioning: RoundRobinPartitioning | |
+ // shuffleOrigin: REPARTITION_BY_NUM | |
+ assert(spark.range(0).repartition(5).rdd.getNumPartitions == 5) | |
+ // shuffleOrigin: REBALANCE_PARTITIONS_BY_NONE | |
+ assert(spark.range(0).repartition().rdd.getNumPartitions == 0) | |
+ // through project | |
+ assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2") | |
+ .repartition(5).select($"c2").rdd.getNumPartitions == 5) | |
+ | |
+ // partitioning: SinglePartition | |
+ assert(spark.range(0).repartition(1).rdd.getNumPartitions == 1) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-39915: Ensure the output partitioning is user-specified") { | |
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3", | |
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { | |
+ val df1 = spark.range(1).selectExpr("id as c1") | |
+ val df2 = spark.range(1).selectExpr("id as c2") | |
+ val df = df1.join(df2, col("c1") === col("c2")).repartition(3, col("c1")) | |
+ assert(df.rdd.getNumPartitions == 3) | |
+ } | |
+ } | |
} | |
/** | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala | |
index a5ac2d5aa7..e876e9d6ff 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala | |
@@ -1377,7 +1377,7 @@ class ArrowConvertersSuite extends SharedSparkSession { | |
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) | |
val ctx = TaskContext.empty() | |
- val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) | |
+ val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, ctx) | |
val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) | |
var count = 0 | |
@@ -1398,7 +1398,7 @@ class ArrowConvertersSuite extends SharedSparkSession { | |
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) | |
val ctx = TaskContext.empty() | |
- val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) | |
+ val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, ctx) | |
// Write batches to Arrow stream format as a byte array | |
val out = new ByteArrayOutputStream() | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala | |
index 146d9fcdb9..a88f423ae0 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala | |
@@ -17,8 +17,6 @@ | |
package org.apache.spark.sql.execution.arrow | |
-import org.apache.arrow.vector.IntervalDayVector | |
- | |
import org.apache.spark.SparkFunSuite | |
import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.util._ | |
@@ -30,12 +28,12 @@ class ArrowWriterSuite extends SparkFunSuite { | |
test("simple") { | |
def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { | |
- val avroDatatype = dt match { | |
+ val datatype = dt match { | |
case _: DayTimeIntervalType => DayTimeIntervalType() | |
case _: YearMonthIntervalType => YearMonthIntervalType() | |
case tpe => tpe | |
} | |
- val schema = new StructType().add("value", avroDatatype, nullable = true) | |
+ val schema = new StructType().add("value", datatype, nullable = true) | |
val writer = ArrowWriter.create(schema, timeZoneId) | |
assert(writer.schema === schema) | |
@@ -61,6 +59,7 @@ class ArrowWriterSuite extends SparkFunSuite { | |
case BinaryType => reader.getBinary(rowId) | |
case DateType => reader.getInt(rowId) | |
case TimestampType => reader.getLong(rowId) | |
+ case TimestampNTZType => reader.getLong(rowId) | |
case _: YearMonthIntervalType => reader.getInt(rowId) | |
case _: DayTimeIntervalType => reader.getLong(rowId) | |
} | |
@@ -81,6 +80,7 @@ class ArrowWriterSuite extends SparkFunSuite { | |
check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) | |
check(DateType, Seq(0, 1, 2, null, 4)) | |
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") | |
+ check(TimestampNTZType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong)) | |
check(NullType, Seq(null, null, null)) | |
DataTypeTestUtils.yearMonthIntervalTypes | |
.foreach(check(_, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue))) | |
@@ -88,30 +88,6 @@ class ArrowWriterSuite extends SparkFunSuite { | |
Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L), (Long.MinValue + 808L)))) | |
} | |
- test("long overflow for DayTimeIntervalType") | |
- { | |
- val schema = new StructType().add("value", DayTimeIntervalType(), nullable = true) | |
- val writer = ArrowWriter.create(schema, null) | |
- val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) | |
- val valueVector = writer.root.getFieldVectors().get(0).asInstanceOf[IntervalDayVector] | |
- | |
- valueVector.set(0, 106751992, 0) | |
- valueVector.set(1, 106751991, Int.MaxValue) | |
- | |
- // first long overflow for test Math.multiplyExact() | |
- val msg = intercept[java.lang.ArithmeticException] { | |
- reader.getLong(0) | |
- }.getMessage | |
- assert(msg.equals("long overflow")) | |
- | |
- // second long overflow for test Math.addExact() | |
- val msg1 = intercept[java.lang.ArithmeticException] { | |
- reader.getLong(1) | |
- }.getMessage | |
- assert(msg1.equals("long overflow")) | |
- writer.root.close() | |
- } | |
- | |
test("get multiple") { | |
def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { | |
val avroDatatype = dt match { | |
@@ -139,6 +115,7 @@ class ArrowWriterSuite extends SparkFunSuite { | |
case DoubleType => reader.getDoubles(0, data.size) | |
case DateType => reader.getInts(0, data.size) | |
case TimestampType => reader.getLongs(0, data.size) | |
+ case TimestampNTZType => reader.getLongs(0, data.size) | |
case _: YearMonthIntervalType => reader.getInts(0, data.size) | |
case _: DayTimeIntervalType => reader.getLongs(0, data.size) | |
} | |
@@ -155,6 +132,7 @@ class ArrowWriterSuite extends SparkFunSuite { | |
check(DoubleType, (0 until 10).map(_.toDouble)) | |
check(DateType, (0 until 10)) | |
check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles") | |
+ check(TimestampNTZType, (0 until 10).map(_ * 4.32e10.toLong)) | |
DataTypeTestUtils.yearMonthIntervalTypes.foreach(check(_, (0 until 14))) | |
DataTypeTestUtils.dayTimeIntervalTypes.foreach(check(_, (-10 until 10).map(_ * 1000.toLong))) | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnsiIntervalSortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnsiIntervalSortBenchmark.scala | |
new file mode 100644 | |
index 0000000000..0537527b85 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnsiIntervalSortBenchmark.scala | |
@@ -0,0 +1,73 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.benchmark | |
+ | |
+import org.apache.spark.benchmark.Benchmark | |
+import org.apache.spark.sql.internal.SQLConf | |
+ | |
+/** | |
+ * Benchmark to measure performance for interval sort. | |
+ * To run this benchmark: | |
+ * {{{ | |
+ * 1. without sbt: | |
+ * bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar> | |
+ * 2. build/sbt "sql/test:runMain <this class>" | |
+ * 3. generate result: | |
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>" | |
+ *      Results will be written to "benchmarks/AnsiIntervalSortBenchmark-results.txt". | 
+ * }}} | |
+ */ | |
+object AnsiIntervalSortBenchmark extends SqlBasedBenchmark { | |
+ private val numRows = 100 * 1000 * 1000 | |
+ | |
+ private def radixBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { | |
+ val benchmark = new Benchmark(name, cardinality, output = output) | |
+ benchmark.addCase(s"$name enable radix", 3) { _ => | |
+ withSQLConf(SQLConf.RADIX_SORT_ENABLED.key -> "true") { | |
+ f | |
+ } | |
+ } | |
+ | |
+ benchmark.addCase(s"$name disable radix", 3) { _ => | |
+ withSQLConf(SQLConf.RADIX_SORT_ENABLED.key -> "false") { | |
+ f | |
+ } | |
+ } | |
+ benchmark.run() | |
+ } | |
+ | |
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { | |
+ val dt = spark.range(numRows).selectExpr("make_dt_interval(id % 24) as c1", "id as c2") | |
+ radixBenchmark("day time interval one column", numRows) { | 
+ dt.sortWithinPartitions("c1").select("c2").noop() | |
+ } | |
+ | |
+ radixBenchmark("day time interval two columns", numRows) { | 
+ dt.sortWithinPartitions("c1", "c2").select("c2").noop() | |
+ } | |
+ | |
+ val ym = spark.range(numRows).selectExpr("make_ym_interval(id % 2000) as c1", "id as c2") | |
+ radixBenchmark("year month interval one column", numRows) { | 
+ ym.sortWithinPartitions("c1").select("c2").noop() | |
+ } | |
+ | |
+ radixBenchmark("year month interval two columns", numRows) { | 
+ ym.sortWithinPartitions("c1", "c2").select("c2").noop() | |
+ } | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala | |
new file mode 100644 | |
index 0000000000..eb0b896574 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/Base64Benchmark.scala | |
@@ -0,0 +1,78 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.benchmark | |
+ | |
+import org.apache.spark.benchmark.Benchmark | |
+ | |
+/** | |
+ * Benchmark for measuring perf of different Base64 implementations | |
+ * To run this benchmark: | |
+ * {{{ | |
+ * 1. without sbt: | |
+ * bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar> | |
+ * 2. build/sbt "sql/test:runMain <this class>" | |
+ * 3. generate result: | |
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>" | |
+ * Results will be written to "benchmarks/Base64Benchmark-results.txt". | |
+ * }}} | |
+ */ | |
+object Base64Benchmark extends SqlBasedBenchmark { | |
+ import spark.implicits._ | |
+ private val N = 20L * 1000 * 1000 | |
+ | |
+ private def doEncode(len: Int, f: Array[Byte] => Array[Byte]): Unit = { | |
+ spark.range(N).map(_ => "Spark" * len).foreach { s => | |
+ f(s.getBytes) | |
+ () | |
+ } | |
+ } | |
+ | |
+ private def doDecode(len: Int, f: Array[Byte] => Array[Byte]): Unit = { | |
+ spark.range(N).map(_ => "Spark" * len).map { s => | |
+ // using the same encode func | |
+ java.util.Base64.getMimeEncoder.encode(s.getBytes) | |
+ }.foreach { s => | |
+ f(s) | |
+ () | |
+ } | |
+ } | |
+ | |
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { | |
+ Seq(1, 3, 5, 7).map { len => | |
+ val benchmark = new Benchmark(s"encode for $len", N, output = output) | |
+ benchmark.addCase("java", 3) { _ => | |
+ doEncode(len, x => java.util.Base64.getMimeEncoder().encode(x)) | |
+ } | |
+ benchmark.addCase(s"apache", 3) { _ => | |
+ doEncode(len, org.apache.commons.codec.binary.Base64.encodeBase64) | |
+ } | |
+ benchmark | |
+ }.foreach(_.run()) | |
+ | |
+ Seq(1, 3, 5, 7).map { len => | |
+ val benchmark = new Benchmark(s"decode for $len", N, output = output) | |
+ benchmark.addCase("java", 3) { _ => | |
+ doDecode(len, x => java.util.Base64.getMimeDecoder.decode(x)) | |
+ } | |
+ benchmark.addCase(s"apache", 3) { _ => | |
+ doDecode(len, org.apache.commons.codec.binary.Base64.decodeBase64) | |
+ } | |
+ benchmark | |
+ }.foreach(_.run()) | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala | |
index f78ccf9569..ccb65c7d3a 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala | |
@@ -19,14 +19,13 @@ package org.apache.spark.sql.execution.benchmark | |
import scala.util.Random | |
+import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat} | |
+ | |
import org.apache.spark.benchmark.Benchmark | |
/** | |
* Benchmark to measure read performance with Bloom filters. | |
* | |
- * Currently, only ORC supports bloom filters, we will add Parquet BM as soon as it becomes | |
- * available. | |
- * | |
* To run this benchmark: | |
* {{{ | |
* 1. without sbt: bin/spark-submit --class <this class> | |
@@ -43,7 +42,7 @@ object BloomFilterBenchmark extends SqlBasedBenchmark { | |
private val N = scaleFactor * 1000 * 1000 | |
private val df = spark.range(N).map(_ => Random.nextInt) | |
- private def writeBenchmark(): Unit = { | |
+ private def writeORCBenchmark(): Unit = { | |
withTempPath { dir => | |
val path = dir.getCanonicalPath | |
@@ -61,7 +60,7 @@ object BloomFilterBenchmark extends SqlBasedBenchmark { | |
} | |
} | |
- private def readBenchmark(): Unit = { | |
+ private def readORCBenchmark(): Unit = { | |
withTempPath { dir => | |
val path = dir.getCanonicalPath | |
@@ -81,8 +80,55 @@ object BloomFilterBenchmark extends SqlBasedBenchmark { | |
} | |
} | |
+ private def writeParquetBenchmark(): Unit = { | |
+ withTempPath { dir => | |
+ val path = dir.getCanonicalPath | |
+ | |
+ runBenchmark("Parquet Write") { | |
+ val benchmark = new Benchmark(s"Write ${scaleFactor}M rows", N, output = output) | |
+ benchmark.addCase("Without bloom filter") { _ => | |
+ df.write.mode("overwrite").parquet(path + "/withoutBF") | |
+ } | |
+ benchmark.addCase("With bloom filter") { _ => | |
+ df.write.mode("overwrite") | |
+ .option(ParquetOutputFormat.BLOOM_FILTER_ENABLED + "#value", true) | |
+ .parquet(path + "/withBF") | |
+ } | |
+ benchmark.run() | |
+ } | |
+ } | |
+ } | |
+ | |
+ private def readParquetBenchmark(): Unit = { | |
+ val blockSizes = Seq(2 * 1024 * 1024, 4 * 1024 * 1024, 6 * 1024 * 1024, 8 * 1024 * 1024, | |
+ 12 * 1024 * 1024, 16 * 1024 * 1024, 32 * 1024 * 1024) | |
+ for (blocksize <- blockSizes) { | |
+ withTempPath { dir => | |
+ val path = dir.getCanonicalPath | |
+ df.write.option("parquet.block.size", blocksize).parquet(path + "/withoutBF") | |
+ df.write.option(ParquetOutputFormat.BLOOM_FILTER_ENABLED + "#value", true) | |
+ .option("parquet.block.size", blocksize) | |
+ .parquet(path + "/withBF") | |
+ | |
+ runBenchmark("Parquet Read") { | |
+ val benchmark = new Benchmark(s"Read a row from ${scaleFactor}M rows", N, output = output) | |
+ benchmark.addCase("Without bloom filter, blocksize: " + blocksize) { _ => | |
+ spark.read.parquet(path + "/withoutBF").where("value = 0").noop() | |
+ } | |
+ benchmark.addCase("With bloom filter, blocksize: " + blocksize) { _ => | |
+ spark.read.option(ParquetInputFormat.BLOOM_FILTERING_ENABLED, true) | |
+ .parquet(path + "/withBF").where("value = 0").noop() | |
+ } | |
+ benchmark.run() | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { | |
- writeBenchmark() | |
- readBenchmark() | |
+ writeORCBenchmark() | |
+ readORCBenchmark() | |
+ writeParquetBenchmark() | |
+ readParquetBenchmark() | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala | |
index 361deb0d3e..45d50b5e11 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala | |
@@ -16,6 +16,9 @@ | |
*/ | |
package org.apache.spark.sql.execution.benchmark | |
+import org.apache.parquet.column.ParquetProperties | |
+import org.apache.parquet.hadoop.ParquetOutputFormat | |
+ | |
import org.apache.spark.sql.internal.SQLConf | |
/** | |
@@ -53,7 +56,16 @@ object BuiltInDataSourceWriteBenchmark extends DataSourceWriteBenchmark { | |
formats.foreach { format => | |
runBenchmark(s"$format writer benchmark") { | |
- runDataSourceBenchmark(format) | |
+ if (format.equals("Parquet")) { | |
+ ParquetProperties.WriterVersion.values().foreach { | |
+ writeVersion => | |
+ withSQLConf(ParquetOutputFormat.WRITER_VERSION -> writeVersion.toString) { | |
+ runDataSourceBenchmark("Parquet", Some(writeVersion.toString)) | |
+ } | |
+ } | |
+ } else { | |
+ runDataSourceBenchmark(format) | |
+ } | |
} | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala | |
new file mode 100644 | |
index 0000000000..99016842d8 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ByteArrayBenchmark.scala | |
@@ -0,0 +1,124 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.benchmark | |
+ | |
+import scala.util.Random | |
+ | |
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} | |
+import org.apache.spark.unsafe.array.ByteArrayMethods | |
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String} | |
+ | |
+/** | |
+ * Benchmark to measure performance for byte array operators. | |
+ * {{{ | |
+ * To run this benchmark: | |
+ * 1. without sbt: | |
+ * bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar> | |
+ * 2. build/sbt "sql/test:runMain <this class>" | |
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>" | |
+ * Results will be written to "benchmarks/<this class>-results.txt". | |
+ * }}} | |
+ */ | |
+object ByteArrayBenchmark extends BenchmarkBase { | |
+ private val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" | |
+ private val randomChar = new Random(0) | |
+ | |
+ def randomBytes(min: Int, max: Int): Array[Byte] = { | |
+ val len = randomChar.nextInt(max - min) + min | |
+ val bytes = new Array[Byte](len) | |
+ var i = 0 | |
+ while (i < len) { | |
+ bytes(i) = chars.charAt(randomChar.nextInt(chars.length())).toByte | |
+ i += 1 | |
+ } | |
+ bytes | |
+ } | |
+ | |
+ def byteArrayComparisons(iters: Long): Unit = { | |
+ val count = 16 * 1000 | |
+ val dataTiny = Seq.fill(count)(randomBytes(2, 7)).toArray | |
+ val dataSmall = Seq.fill(count)(randomBytes(8, 16)).toArray | |
+ val dataMedium = Seq.fill(count)(randomBytes(16, 32)).toArray | |
+ val dataLarge = Seq.fill(count)(randomBytes(512, 1024)).toArray | |
+ val dataLargeSlow = Seq.fill(count)( | |
+ Array.tabulate(512) {i => if (i < 511) 0.toByte else 1.toByte}).toArray | |
+ | |
+ def compareBinary(data: Array[Array[Byte]]) = { _: Int => | |
+ var sum = 0L | |
+ for (_ <- 0L until iters) { | |
+ var i = 0 | |
+ while (i < count) { | |
+ sum += ByteArray.compareBinary(data(i), data((i + 1) % count)) | |
+ i += 1 | |
+ } | |
+ } | |
+ } | |
+ | |
+ val benchmark = new Benchmark("Byte Array compareTo", count * iters, 25, output = output) | |
+ benchmark.addCase("2-7 byte")(compareBinary(dataTiny)) | |
+ benchmark.addCase("8-16 byte")(compareBinary(dataSmall)) | |
+ benchmark.addCase("16-32 byte")(compareBinary(dataMedium)) | |
+ benchmark.addCase("512-1024 byte")(compareBinary(dataLarge)) | |
+ benchmark.addCase("512 byte slow")(compareBinary(dataLargeSlow)) | |
+ benchmark.addCase("2-7 byte")(compareBinary(dataTiny)) | |
+ benchmark.run() | |
+ } | |
+ | |
+ def byteArrayEquals(iters: Long): Unit = { | |
+ def binaryEquals(inputs: Array[BinaryEqualInfo]) = { _: Int => | |
+ var res = false | |
+ for (_ <- 0L until iters) { | |
+ inputs.foreach { input => | |
+ res = ByteArrayMethods.arrayEquals( | |
+ input.s1.getBaseObject, input.s1.getBaseOffset, | |
+ input.s2.getBaseObject, input.s2.getBaseOffset + input.deltaOffset, | |
+ input.len) | |
+ } | |
+ } | |
+ } | |
+ val count = 16 * 1000 | |
+ val rand = new Random(0) | |
+ val inputs = (0 until count).map { _ => | |
+ val s1 = UTF8String.fromBytes(randomBytes(1, 16)) | |
+ val s2 = UTF8String.fromBytes(randomBytes(1, 16)) | |
+ val len = s1.numBytes().min(s2.numBytes()) | |
+ val deltaOffset = rand.nextInt(len) | |
+ BinaryEqualInfo(s1, s2, deltaOffset, len) | |
+ }.toArray | |
+ | |
+ val benchmark = new Benchmark("Byte Array equals", count * iters, 25, output = output) | |
+ benchmark.addCase("Byte Array equals")(binaryEquals(inputs)) | |
+ benchmark.run() | |
+ } | |
+ | |
+ case class BinaryEqualInfo( | |
+ s1: UTF8String, | |
+ s2: UTF8String, | |
+ deltaOffset: Int, | |
+ len: Int) | |
+ | |
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { | |
+ runBenchmark("byte array comparisons") { | |
+ byteArrayComparisons(1024 * 4) | |
+ } | |
+ | |
+ runBenchmark("byte array equals") { | |
+ byteArrayEquals(1000 * 10) | |
+ } | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala | |
index 0fc43c7052..b35aa73e14 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala | |
@@ -21,11 +21,14 @@ import java.io.File | |
import scala.collection.JavaConverters._ | |
import scala.util.Random | |
-import org.apache.spark.SparkConf | |
+import org.apache.parquet.column.ParquetProperties | |
+import org.apache.parquet.hadoop.ParquetOutputFormat | |
+ | |
+import org.apache.spark.{SparkConf, TestUtils} | |
import org.apache.spark.benchmark.Benchmark | |
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} | |
import org.apache.spark.sql.catalyst.InternalRow | |
-import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} | |
+import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.types._ | |
import org.apache.spark.sql.vectorized.ColumnVector | |
@@ -66,16 +69,23 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
try f finally tableNames.foreach(spark.catalog.dropTempView) | |
} | |
- private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { | |
+ private def prepareTable( | |
+ dir: File, | |
+ df: DataFrame, | |
+ partition: Option[String] = None, | |
+ onlyParquetOrc: Boolean = false): Unit = { | |
val testDf = if (partition.isDefined) { | |
df.write.partitionBy(partition.get) | |
} else { | |
df.write | |
} | |
- saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") | |
- saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") | |
- saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet") | |
+ if (!onlyParquetOrc) { | |
+ saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") | |
+ saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") | |
+ } | |
+ saveAsParquetV1Table(testDf, dir.getCanonicalPath + "/parquetV1") | |
+ saveAsParquetV2Table(testDf, dir.getCanonicalPath + "/parquetV2") | |
saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc") | |
} | |
@@ -89,9 +99,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
spark.read.json(dir).createOrReplaceTempView("jsonTable") | |
} | |
- private def saveAsParquetTable(df: DataFrameWriter[Row], dir: String): Unit = { | |
+ private def saveAsParquetV1Table(df: DataFrameWriter[Row], dir: String): Unit = { | |
df.mode("overwrite").option("compression", "snappy").parquet(dir) | |
- spark.read.parquet(dir).createOrReplaceTempView("parquetTable") | |
+ spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table") | |
+ } | |
+ | |
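+ // Same as the V1 variant, but writes with the Parquet v2 writer so data pages use the V2 format. | |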
+ private def saveAsParquetV2Table(df: DataFrameWriter[Row], dir: String): Unit = { | |
+ withSQLConf(ParquetOutputFormat.WRITER_VERSION -> | |
+ ParquetProperties.WriterVersion.PARQUET_2_0.toString) { | |
+ df.mode("overwrite").option("compression", "snappy").parquet(dir) | |
+ spark.read.parquet(dir).createOrReplaceTempView("parquetV2Table") | |
+ } | |
} | |
private def saveAsOrcTable(df: DataFrameWriter[Row], dir: String): Unit = { | |
@@ -99,6 +117,8 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
spark.read.orc(dir).createOrReplaceTempView("orcTable") | |
} | |
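+ // Runs the given block once per Parquet data page version suffix ("V1" and "V2"). | |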
+ private def withParquetVersions(f: String => Unit): Unit = Seq("V1", "V2").foreach(f) | |
+ | |
def numericScanBenchmark(values: Int, dataType: DataType): Unit = { | |
// Benchmarks running through spark sql. | |
val sqlBenchmark = new Benchmark( | |
@@ -113,112 +133,267 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
output = output) | |
withTempPath { dir => | |
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { | |
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { | |
import spark.implicits._ | |
spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") | |
prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) | |
+ val query = dataType match { | |
+ case BooleanType => "sum(cast(id as bigint))" | |
+ case _ => "sum(id)" | |
+ } | |
+ | |
sqlBenchmark.addCase("SQL CSV") { _ => | |
- spark.sql("select sum(id) from csvTable").noop() | |
+ spark.sql(s"select $query from csvTable").noop() | |
} | |
sqlBenchmark.addCase("SQL Json") { _ => | |
- spark.sql("select sum(id) from jsonTable").noop() | |
+ spark.sql(s"select $query from jsonTable").noop() | |
} | |
- sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => | |
- spark.sql("select sum(id) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ sqlBenchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => | |
+ spark.sql(s"select $query from parquet${version}Table").noop() | |
+ } | |
} | |
- sqlBenchmark.addCase("SQL Parquet MR") { _ => | |
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql("select sum(id) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ sqlBenchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select $query from parquet${version}Table").noop() | |
+ } | |
} | |
} | |
sqlBenchmark.addCase("SQL ORC Vectorized") { _ => | |
- spark.sql("SELECT sum(id) FROM orcTable").noop() | |
+ spark.sql(s"SELECT $query FROM orcTable").noop() | |
} | |
sqlBenchmark.addCase("SQL ORC MR") { _ => | |
withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql("SELECT sum(id) FROM orcTable").noop() | |
+ spark.sql(s"SELECT $query FROM orcTable").noop() | |
} | |
} | |
sqlBenchmark.run() | |
- // Driving the parquet reader in batch mode directly. | |
- val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray | |
val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled | |
val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize | |
- parquetReaderBenchmark.addCase("ParquetReader Vectorized") { _ => | |
- var longSum = 0L | |
- var doubleSum = 0.0 | |
- val aggregateValue: (ColumnVector, Int) => Unit = dataType match { | |
- case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) | |
- case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) | |
- case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) | |
- case LongType => (col: ColumnVector, i: Int) => longSum += col.getLong(i) | |
- case FloatType => (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) | |
- case DoubleType => (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) | |
- } | |
- | |
- files.map(_.asInstanceOf[String]).foreach { p => | |
- val reader = new VectorizedParquetRecordReader( | |
- enableOffHeapColumnVector, vectorizedReaderBatchSize) | |
- try { | |
- reader.initialize(p, ("id" :: Nil).asJava) | |
- val batch = reader.resultBatch() | |
- val col = batch.column(0) | |
- while (reader.nextBatch()) { | |
- val numRows = batch.numRows() | |
- var i = 0 | |
- while (i < numRows) { | |
- if (!col.isNullAt(i)) aggregateValue(col, i) | |
- i += 1 | |
+ withParquetVersions { version => | |
+ // Driving the parquet reader in batch mode directly. | |
+ val files = TestUtils.listDirectory(new File(dir, s"parquet$version")) | |
+ parquetReaderBenchmark.addCase(s"ParquetReader Vectorized: DataPage$version") { _ => | |
+ var longSum = 0L | |
+ var doubleSum = 0.0 | |
+ val aggregateValue: (ColumnVector, Int) => Unit = dataType match { | |
+ case BooleanType => | |
+ (col: ColumnVector, i: Int) => if (col.getBoolean(i)) longSum += 1L | |
+ case ByteType => | |
+ (col: ColumnVector, i: Int) => longSum += col.getByte(i) | |
+ case ShortType => | |
+ (col: ColumnVector, i: Int) => longSum += col.getShort(i) | |
+ case IntegerType => | |
+ (col: ColumnVector, i: Int) => longSum += col.getInt(i) | |
+ case LongType => | |
+ (col: ColumnVector, i: Int) => longSum += col.getLong(i) | |
+ case FloatType => | |
+ (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) | |
+ case DoubleType => | |
+ (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) | |
+ } | |
+ | |
+ files.foreach { p => | |
+ val reader = new VectorizedParquetRecordReader( | |
+ enableOffHeapColumnVector, vectorizedReaderBatchSize) | |
+ try { | |
+ reader.initialize(p, ("id" :: Nil).asJava) | |
+ val batch = reader.resultBatch() | |
+ val col = batch.column(0) | |
+ while (reader.nextBatch()) { | |
+ val numRows = batch.numRows() | |
+ var i = 0 | |
+ while (i < numRows) { | |
+ if (!col.isNullAt(i)) aggregateValue(col, i) | |
+ i += 1 | |
+ } | |
} | |
+ } finally { | |
+ reader.close() | |
} | |
- } finally { | |
- reader.close() | |
} | |
} | |
} | |
- // Decoding in vectorized but having the reader return rows. | |
- parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => | |
- var longSum = 0L | |
- var doubleSum = 0.0 | |
- val aggregateValue: (InternalRow) => Unit = dataType match { | |
- case ByteType => (col: InternalRow) => longSum += col.getByte(0) | |
- case ShortType => (col: InternalRow) => longSum += col.getShort(0) | |
- case IntegerType => (col: InternalRow) => longSum += col.getInt(0) | |
- case LongType => (col: InternalRow) => longSum += col.getLong(0) | |
- case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) | |
- case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) | |
- } | |
- | |
- files.map(_.asInstanceOf[String]).foreach { p => | |
- val reader = new VectorizedParquetRecordReader( | |
- enableOffHeapColumnVector, vectorizedReaderBatchSize) | |
- try { | |
- reader.initialize(p, ("id" :: Nil).asJava) | |
- val batch = reader.resultBatch() | |
- while (reader.nextBatch()) { | |
- val it = batch.rowIterator() | |
- while (it.hasNext) { | |
- val record = it.next() | |
- if (!record.isNullAt(0)) aggregateValue(record) | |
+ withParquetVersions { version => | |
+ // Driving the parquet reader in batch mode directly. | |
+ val files = TestUtils.listDirectory(new File(dir, s"parquet$version")) | |
+ // Decoding in vectorized but having the reader return rows. | |
+ parquetReaderBenchmark | |
+ .addCase(s"ParquetReader Vectorized -> Row: DataPage$version") { _ => | |
+ var longSum = 0L | |
+ var doubleSum = 0.0 | |
+ val aggregateValue: (InternalRow) => Unit = dataType match { | |
+ case BooleanType => (col: InternalRow) => if (col.getBoolean(0)) longSum += 1L | |
+ case ByteType => (col: InternalRow) => longSum += col.getByte(0) | |
+ case ShortType => (col: InternalRow) => longSum += col.getShort(0) | |
+ case IntegerType => (col: InternalRow) => longSum += col.getInt(0) | |
+ case LongType => (col: InternalRow) => longSum += col.getLong(0) | |
+ case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) | |
+ case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) | |
+ } | |
+ | |
+ files.foreach { p => | |
+ val reader = new VectorizedParquetRecordReader( | |
+ enableOffHeapColumnVector, vectorizedReaderBatchSize) | |
+ try { | |
+ reader.initialize(p, ("id" :: Nil).asJava) | |
+ val batch = reader.resultBatch() | |
+ while (reader.nextBatch()) { | |
+ val it = batch.rowIterator() | |
+ while (it.hasNext) { | |
+ val record = it.next() | |
+ if (!record.isNullAt(0)) aggregateValue(record) | |
+ } | |
+ } | |
+ } finally { | |
+ reader.close() | |
} | |
} | |
- } finally { | |
- reader.close() | |
+ } | |
+ } | |
+ } | |
+ | |
+ parquetReaderBenchmark.run() | |
+ } | |
+ } | |
+ | |
+ /** | |
+ * Similar to [[numericScanBenchmark]] but accessed column is a struct field. | |
+ */ | |
+ def nestedNumericScanBenchmark(values: Int, dataType: DataType): Unit = { | |
+ val sqlBenchmark = new Benchmark( | |
+ s"SQL Single ${dataType.sql} Column Scan in Struct", | |
+ values, | |
+ output = output) | |
+ | |
+ withTempPath { dir => | |
+ withTempTable("t1", "parquetV1Table", "parquetV2Table", "orcTable") { | |
+ import spark.implicits._ | |
+ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") | |
+ | |
+ prepareTable(dir, | |
+ spark.sql(s"SELECT named_struct('f', CAST(value as ${dataType.sql})) as col FROM t1"), | |
+ onlyParquetOrc = true) | |
+ | |
+ sqlBenchmark.addCase(s"SQL ORC MR") { _ => | |
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(col.f) from orcTable").noop() | |
+ } | |
+ } | |
+ | |
+ sqlBenchmark.addCase(s"SQL ORC Vectorized (Nested Column Disabled)") { _ => | |
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(col.f) from orcTable").noop() | |
+ } | |
+ } | |
+ | |
+ sqlBenchmark.addCase(s"SQL ORC Vectorized (Nested Column Enabled)") { _ => | |
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { | |
+ spark.sql(s"select sum(col.f) from orcTable").noop() | |
+ } | |
+ } | |
+ | |
+ withParquetVersions { version => | |
+ sqlBenchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(col.f) from parquet${version}Table").noop() | |
+ } | |
+ } | |
+ | |
+ sqlBenchmark.addCase(s"SQL Parquet Vectorized: DataPage$version " + | |
+ "(Nested Column Disabled)") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(col.f) from parquet${version}Table").noop() | |
+ } | |
+ } | |
+ | |
+ sqlBenchmark.addCase(s"SQL Parquet Vectorized: DataPage$version " + | |
+ "(Nested Column Enabled)") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { | |
+ spark.sql(s"select sum(col.f) from parquet${version}Table").noop() | |
+ } | |
+ } | |
+ } | |
+ | |
+ sqlBenchmark.run() | |
+ } | |
+ } | |
+ } | |
+ | |
+ def nestedColumnScanBenchmark(values: Int): Unit = { | |
+ val benchmark = new Benchmark(s"SQL Nested Column Scan", values, minNumIters = 10, | |
+ output = output) | |
+ | |
+ withTempPath { dir => | |
+ withTempTable("t1", "parquetV1Table", "parquetV2Table", "orcTable") { | |
+ import spark.implicits._ | |
+ spark.range(values).map(_ => Random.nextLong).map { x => | |
+ val arrayOfStructColumn = (0 until 5).map(i => (x + i, s"$x" * 5)) | |
+ val mapOfStructColumn = Map( | |
+ s"$x" -> (x * 0.1, (x, s"$x" * 100)), | |
+ (s"$x" * 2) -> (x * 0.2, (x, s"$x" * 200)), | |
+ (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300))) | |
+ (arrayOfStructColumn, mapOfStructColumn) | |
+ }.toDF("col1", "col2").createOrReplaceTempView("t1") | |
+ | |
+ prepareTable(dir, spark.sql(s"SELECT * FROM t1"), onlyParquetOrc = true) | |
+ | |
+ benchmark.addCase("SQL ORC MR") { _ => | |
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop() | |
+ } | |
+ } | |
+ | |
+ benchmark.addCase("SQL ORC Vectorized (Nested Column Disabled)") { _ => | |
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") { | |
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop() | |
+ } | |
+ } | |
+ | |
+ benchmark.addCase("SQL ORC Vectorized (Nested Column Enabled)") { _ => | |
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { | |
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop() | |
+ } | |
+ } | |
+ | |
+ | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquet${version}Table") | |
+ .noop() | |
+ } | |
+ } | |
+ | |
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version " + | |
+ s"(Nested Column Disabled)") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") { | |
+ spark.sql(s"SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquet${version}Table") | |
+ .noop() | |
+ } | |
+ } | |
+ | |
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version " + | |
+ s"(Nested Column Enabled)") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { | |
+ spark.sql(s"SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquet${version}Table") | |
+ .noop() | |
} | |
} | |
} | |
- parquetReaderBenchmark.run() | |
+ benchmark.run() | |
} | |
} | |
} | |
@@ -227,7 +402,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
val benchmark = new Benchmark("Int and String Scan", values, output = output) | |
withTempPath { dir => | |
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { | |
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { | |
import spark.implicits._ | |
spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") | |
@@ -243,13 +418,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
spark.sql("select sum(c1), sum(length(c2)) from jsonTable").noop() | |
} | |
- benchmark.addCase("SQL Parquet Vectorized") { _ => | |
- spark.sql("select sum(c1), sum(length(c2)) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => | |
+ spark.sql(s"select sum(c1), sum(length(c2)) from parquet${version}Table").noop() | |
+ } | |
} | |
- benchmark.addCase("SQL Parquet MR") { _ => | |
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql("select sum(c1), sum(length(c2)) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(c1), sum(length(c2)) from parquet${version}Table").noop() | |
+ } | |
} | |
} | |
@@ -272,7 +451,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
val benchmark = new Benchmark("Repeated String", values, output = output) | |
withTempPath { dir => | |
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { | |
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { | |
import spark.implicits._ | |
spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") | |
@@ -288,13 +467,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
spark.sql("select sum(length(c1)) from jsonTable").noop() | |
} | |
- benchmark.addCase("SQL Parquet Vectorized") { _ => | |
- spark.sql("select sum(length(c1)) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => | |
+ spark.sql(s"select sum(length(c1)) from parquet${version}Table").noop() | |
+ } | |
} | |
- benchmark.addCase("SQL Parquet MR") { _ => | |
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql("select sum(length(c1)) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(length(c1)) from parquet${version}Table").noop() | |
+ } | |
} | |
} | |
@@ -317,7 +500,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
val benchmark = new Benchmark("Partitioned Table", values, output = output) | |
withTempPath { dir => | |
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { | |
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { | |
import spark.implicits._ | |
spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") | |
@@ -331,13 +514,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
spark.sql("select sum(id) from jsonTable").noop() | |
} | |
- benchmark.addCase("Data column - Parquet Vectorized") { _ => | |
- spark.sql("select sum(id) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"Data column - Parquet Vectorized: DataPage$version") { _ => | |
+ spark.sql(s"select sum(id) from parquet${version}Table").noop() | |
+ } | |
} | |
- benchmark.addCase("Data column - Parquet MR") { _ => | |
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql("select sum(id) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"Data column - Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(id) from parquet${version}Table").noop() | |
+ } | |
} | |
} | |
@@ -359,13 +546,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
spark.sql("select sum(p) from jsonTable").noop() | |
} | |
- benchmark.addCase("Partition column - Parquet Vectorized") { _ => | |
- spark.sql("select sum(p) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"Partition column - Parquet Vectorized: DataPage$version") { _ => | |
+ spark.sql(s"select sum(p) from parquet${version}Table").noop() | |
+ } | |
} | |
- benchmark.addCase("Partition column - Parquet MR") { _ => | |
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql("select sum(p) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"Partition column - Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(p) from parquet${version}Table").noop() | |
+ } | |
} | |
} | |
@@ -387,13 +578,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
spark.sql("select sum(p), sum(id) from jsonTable").noop() | |
} | |
- benchmark.addCase("Both columns - Parquet Vectorized") { _ => | |
- spark.sql("select sum(p), sum(id) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"Both columns - Parquet Vectorized: DataPage$version") { _ => | |
+ spark.sql(s"select sum(p), sum(id) from parquet${version}Table").noop() | |
+ } | |
} | |
- benchmark.addCase("Both columns - Parquet MR") { _ => | |
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql("select sum(p), sum(id) from parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"Both columns - Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(p), sum(id) from parquet${version}Table").noop() | |
+ } | |
} | |
} | |
@@ -418,7 +613,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) | |
withTempPath { dir => | |
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { | |
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { | |
spark.range(values).createOrReplaceTempView("t1") | |
prepareTable( | |
@@ -437,39 +632,45 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
"not NULL and c2 is not NULL").noop() | |
} | |
- benchmark.addCase("SQL Parquet Vectorized") { _ => | |
- spark.sql("select sum(length(c2)) from parquetTable where c1 is " + | |
- "not NULL and c2 is not NULL").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => | |
+ spark.sql(s"select sum(length(c2)) from parquet${version}Table where c1 is " + | |
+ "not NULL and c2 is not NULL").noop() | |
+ } | |
} | |
- benchmark.addCase("SQL Parquet MR") { _ => | |
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql("select sum(length(c2)) from parquetTable where c1 is " + | |
- "not NULL and c2 is not NULL").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"select sum(length(c2)) from parquet${version}Table where c1 is " + | |
+ "not NULL and c2 is not NULL").noop() | |
+ } | |
} | |
} | |
- val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray | |
- val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled | |
- val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize | |
- benchmark.addCase("ParquetReader Vectorized") { num => | |
- var sum = 0 | |
- files.map(_.asInstanceOf[String]).foreach { p => | |
- val reader = new VectorizedParquetRecordReader( | |
- enableOffHeapColumnVector, vectorizedReaderBatchSize) | |
- try { | |
- reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) | |
- val batch = reader.resultBatch() | |
- while (reader.nextBatch()) { | |
- val rowIterator = batch.rowIterator() | |
- while (rowIterator.hasNext) { | |
- val row = rowIterator.next() | |
- val value = row.getUTF8String(0) | |
- if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() | |
+ withParquetVersions { version => | |
+ val files = TestUtils.listDirectory(new File(dir, s"parquet$version")) | |
+ val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled | |
+ val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize | |
+ benchmark.addCase(s"ParquetReader Vectorized: DataPage$version") { _ => | |
+ var sum = 0 | |
+ files.foreach { p => | |
+ val reader = new VectorizedParquetRecordReader( | |
+ enableOffHeapColumnVector, vectorizedReaderBatchSize) | |
+ try { | |
+ reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) | |
+ val batch = reader.resultBatch() | |
+ while (reader.nextBatch()) { | |
+ val rowIterator = batch.rowIterator() | |
+ while (rowIterator.hasNext) { | |
+ val row = rowIterator.next() | |
+ val value = row.getUTF8String(0) | |
+ if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() | |
+ } | |
} | |
+ } finally { | |
+ reader.close() | |
} | |
- } finally { | |
- reader.close() | |
} | |
} | |
} | |
@@ -498,7 +699,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
output = output) | |
withTempPath { dir => | |
- withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { | |
+ withTempTable("t1", "csvTable", "jsonTable", "parquetV1Table", "parquetV2Table", "orcTable") { | |
import spark.implicits._ | |
val middle = width / 2 | |
val selectExpr = (1 to width).map(i => s"value as c$i") | |
@@ -515,13 +716,17 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
spark.sql(s"SELECT sum(c$middle) FROM jsonTable").noop() | |
} | |
- benchmark.addCase("SQL Parquet Vectorized") { _ => | |
- spark.sql(s"SELECT sum(c$middle) FROM parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet Vectorized: DataPage$version") { _ => | |
+ spark.sql(s"SELECT sum(c$middle) FROM parquet${version}Table").noop() | |
+ } | |
} | |
- benchmark.addCase("SQL Parquet MR") { _ => | |
- withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
- spark.sql(s"SELECT sum(c$middle) FROM parquetTable").noop() | |
+ withParquetVersions { version => | |
+ benchmark.addCase(s"SQL Parquet MR: DataPage$version") { _ => | |
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { | |
+ spark.sql(s"SELECT sum(c$middle) FROM parquet${version}Table").noop() | |
+ } | |
} | |
} | |
@@ -542,10 +747,18 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { | |
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { | |
runBenchmark("SQL Single Numeric Column Scan") { | |
- Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { | |
+ Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { | |
dataType => numericScanBenchmark(1024 * 1024 * 15, dataType) | |
} | |
} | |
+ runBenchmark("SQL Single Numeric Column Scan in Struct") { | |
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { | |
+ dataType => nestedNumericScanBenchmark(1024 * 1024 * 15, dataType) | |
+ } | |
+ } | |
+ runBenchmark("SQL Nested Column Scan") { | |
+ nestedColumnScanBenchmark(1024 * 1024) | |
+ } | |
runBenchmark("Int and String Scan") { | |
intStringScanBenchmark(1024 * 1024 * 10) | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala | |
index 405d60794e..77e26048e0 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala | |
@@ -66,7 +66,7 @@ trait DataSourceWriteBenchmark extends SqlBasedBenchmark { | |
} | |
} | |
- def runDataSourceBenchmark(format: String): Unit = { | |
+ def runDataSourceBenchmark(format: String, extraInfo: Option[String] = None): Unit = { | |
val tableInt = "tableInt" | |
val tableDouble = "tableDouble" | |
val tableIntString = "tableIntString" | |
@@ -75,7 +75,12 @@ trait DataSourceWriteBenchmark extends SqlBasedBenchmark { | |
withTempTable(tempTable) { | |
spark.range(numRows).createOrReplaceTempView(tempTable) | |
withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { | |
- val benchmark = new Benchmark(s"$format writer benchmark", numRows, output = output) | |
+ val writerName = extraInfo match { | |
+ case Some(extra) => s"$format($extra)" | |
+ case _ => format | |
+ } | |
+ val benchmark = | |
+ new Benchmark(s"$writerName writer benchmark", numRows, output = output) | |
writeNumeric(tableInt, format, benchmark, "Int") | |
writeNumeric(tableDouble, format, benchmark, "Double") | |
writeIntString(tableIntString, format, benchmark) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala | |
index 4e42330088..918f665238 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala | |
@@ -61,7 +61,8 @@ object DateTimeBenchmark extends SqlBasedBenchmark { | |
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { | |
withDefaultTimeZone(LA) { | |
- withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId) { | |
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId, | |
+ SQLConf.LEGACY_INTERVAL_ENABLED.key -> true.toString) { | |
val N = 10000000 | |
runBenchmark("datetime +/- interval") { | |
val benchmark = new Benchmark("datetime +/- interval", N, output = output) | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala | |
index 849c413072..787fdc7b59 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala | |
@@ -44,7 +44,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v")) | |
codegenBenchmark("Join w long", N) { | |
val df = spark.range(N).join(dim, (col("id") % M) === col("k")) | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -55,7 +55,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val dim = broadcast(spark.range(M).selectExpr("cast(id/10 as long) as k")) | |
codegenBenchmark("Join w long duplicated", N) { | |
val df = spark.range(N).join(dim, (col("id") % M) === col("k")) | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -70,7 +70,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val df = spark.range(N).join(dim2, | |
(col("id") % M).cast(IntegerType) === col("k1") | |
&& (col("id") % M).cast(IntegerType) === col("k2")) | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -84,7 +84,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
codegenBenchmark("Join w 2 longs", N) { | |
val df = spark.range(N).join(dim3, | |
(col("id") % M) === col("k1") && (col("id") % M) === col("k2")) | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -98,7 +98,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
codegenBenchmark("Join w 2 longs duplicated", N) { | |
val df = spark.range(N).join(dim4, | |
(col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -109,7 +109,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v")) | |
codegenBenchmark("outer join w long", N) { | |
val df = spark.range(N).join(dim, (col("id") % M) === col("k"), "left") | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -120,7 +120,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v")) | |
codegenBenchmark("semi join w long", N) { | |
val df = spark.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi") | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastHashJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -131,7 +131,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val df1 = spark.range(N).selectExpr(s"id * 2 as k1") | |
val df2 = spark.range(N).selectExpr(s"id * 3 as k2") | |
val df = df1.join(df2, col("k1") === col("k2")) | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[SortMergeJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -144,7 +144,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val df2 = spark.range(N) | |
.selectExpr(s"(id * 15485867) % ${N*10} as k2") | |
val df = df1.join(df2, col("k1") === col("k2")) | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[SortMergeJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -159,7 +159,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val df1 = spark.range(N).selectExpr(s"id as k1") | |
val df2 = spark.range(N / 3).selectExpr(s"id * 3 as k2") | |
val df = df1.join(df2, col("k1") === col("k2")) | |
- assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[ShuffledHashJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[ShuffledHashJoinExec])) | |
df.noop() | |
} | |
} | |
@@ -172,8 +172,7 @@ object JoinBenchmark extends SqlBasedBenchmark { | |
val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v")) | |
codegenBenchmark("broadcast nested loop join", N) { | |
val df = spark.range(N).join(dim) | |
- assert(df.queryExecution.sparkPlan.find( | |
- _.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined) | |
+ assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[BroadcastNestedLoopJoinExec])) | |
df.noop() | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala | |
index e9bdff5853..31d5fd9ffd 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala | |
@@ -49,7 +49,7 @@ object RangeBenchmark extends SqlBasedBenchmark { | |
} | |
benchmark.addCase("filter after range", numIters = 4) { _ => | |
- spark.range(N).filter('id % 100 === 0).noop() | |
+ spark.range(N).filter(Symbol("id") % 100 === 0).noop() | |
} | |
benchmark.addCase("count after range", numIters = 4) { _ => | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala | |
index f84172278b..78d6b01580 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala | |
@@ -18,6 +18,7 @@ | |
package org.apache.spark.sql.execution.benchmark | |
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} | |
+import org.apache.spark.internal.config.MAX_RESULT_SIZE | |
import org.apache.spark.internal.config.UI.UI_ENABLED | |
import org.apache.spark.sql.{Dataset, SparkSession} | |
import org.apache.spark.sql.SaveMode.Overwrite | |
@@ -41,6 +42,7 @@ trait SqlBasedBenchmark extends BenchmarkBase with SQLHelper { | |
.config(SQLConf.SHUFFLE_PARTITIONS.key, 1) | |
.config(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, 1) | |
.config(UI_ENABLED.key, false) | |
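+ // Raise spark.driver.maxResultSize for benchmarks that collect large results back to the driver. | |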
+ .config(MAX_RESULT_SIZE.key, "3g") | |
.getOrCreate() | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala | |
new file mode 100644 | |
index 0000000000..a98c8d8a23 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala | |
@@ -0,0 +1,370 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.benchmark | |
+ | |
+import scala.util.Random | |
+ | |
+import org.apache.hadoop.conf.Configuration | |
+ | |
+import org.apache.spark.benchmark.Benchmark | |
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} | |
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProvider} | |
+import org.apache.spark.sql.internal.SQLConf | |
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampType} | |
+import org.apache.spark.util.Utils | |
+ | |
+/** | |
+ * Synthetic benchmark for State Store basic operations. | |
+ * To run this benchmark: | |
+ * {{{ | |
+ * 1. without sbt: | |
+ * bin/spark-submit --class <this class> | |
+ * --jars <spark core test jar>,<spark catalyst test jar> <sql core test jar> | |
+ * 2. build/sbt "sql/test:runMain <this class>" | |
+ * 3. generate result: | |
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>" | |
+ * Results will be written to "benchmarks/StateStoreBasicOperationsBenchmark-results.txt". | |
+ * }}} | |
+ */ | |
+object StateStoreBasicOperationsBenchmark extends SqlBasedBenchmark { | |
+ | |
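+ // key2 holds an event-time timestamp in microseconds; the evict benchmark filters on it. | |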
+ private val keySchema = StructType( | |
+ Seq(StructField("key1", IntegerType, true), StructField("key2", TimestampType, true))) | |
+ private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) | |
+ | |
+ private val keyProjection = UnsafeProjection.create(keySchema) | |
+ private val valueProjection = UnsafeProjection.create(valueSchema) | |
+ | |
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { | |
+ runPutBenchmark() | |
+ runDeleteBenchmark() | |
+ runEvictBenchmark() | |
+ } | |
+ | |
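+ // Marks a benchmark as skipped: writes a note to the output instead of running the body. | |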
+ final def skip(benchmarkName: String)(func: => Any): Unit = { | |
+ output.foreach(_.write(s"$benchmarkName is skipped".getBytes)) | |
+ } | |
+ | |
+ private def runPutBenchmark(): Unit = { | |
+ def registerPutBenchmarkCase( | |
+ benchmark: Benchmark, | |
+ testName: String, | |
+ provider: StateStoreProvider, | |
+ version: Long, | |
+ rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = { | |
+ benchmark.addTimerCase(testName) { timer => | |
+ val store = provider.getStore(version) | |
+ | |
+ timer.startTiming() | |
+ updateRows(store, rows) | |
+ timer.stopTiming() | |
+ | |
+ store.abort() | |
+ } | |
+ } | |
+ | |
+ runBenchmark("put rows") { | |
+ val numOfRows = Seq(10000) | |
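+ // Percentage of puts that overwrite keys already present in the store; the rest insert new keys. | |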
+ val overwriteRates = Seq(100, 75, 50, 25, 10, 5, 0) | |
+ | |
+ numOfRows.foreach { numOfRow => | |
+ val testData = constructRandomizedTestData(numOfRow, | |
+ (1 to numOfRow).map(_ * 1000L).toList, 0) | |
+ | |
+ val inMemoryProvider = newHDFSBackedStateStoreProvider() | |
+ val rocksDBProvider = newRocksDBStateProvider() | |
+ val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false) | |
+ | |
+ val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData) | |
+ val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData) | |
+ val committedRocksDBWithNoTrackVersion = loadInitialData( | |
+ rocksDBWithNoTrackProvider, testData) | |
+ | |
+ overwriteRates.foreach { overwriteRate => | |
+ val numOfRowsToOverwrite = numOfRow * overwriteRate / 100 | |
+ | |
+ val numOfNewRows = numOfRow - numOfRowsToOverwrite | |
+ val newRows = if (numOfNewRows > 0) { | |
+ constructRandomizedTestData(numOfNewRows, | |
+ (1 to numOfNewRows).map(_ * 1000L).toList, 0) | |
+ } else { | |
+ Seq.empty[(UnsafeRow, UnsafeRow)] | |
+ } | |
+ val existingRows = if (numOfRowsToOverwrite > 0) { | |
+ Random.shuffle(testData).take(numOfRowsToOverwrite) | |
+ } else { | |
+ Seq.empty[(UnsafeRow, UnsafeRow)] | |
+ } | |
+ val rowsToPut = Random.shuffle(newRows ++ existingRows) | |
+ | |
+ val benchmark = new Benchmark(s"putting $numOfRow rows " + | |
+ s"($numOfRowsToOverwrite rows to overwrite - rate $overwriteRate)", | |
+ numOfRow, minNumIters = 10000, output = output) | |
+ | |
+ registerPutBenchmarkCase(benchmark, "In-memory", inMemoryProvider, | |
+ committedInMemoryVersion, rowsToPut) | |
+ registerPutBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)", | |
+ rocksDBProvider, committedRocksDBVersion, rowsToPut) | |
+ registerPutBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)", | |
+ rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion, rowsToPut) | |
+ | |
+ benchmark.run() | |
+ } | |
+ | |
+ inMemoryProvider.close() | |
+ rocksDBProvider.close() | |
+ rocksDBWithNoTrackProvider.close() | |
+ } | |
+ } | |
+ } | |
+ | |
+ private def runDeleteBenchmark(): Unit = { | |
+ def registerDeleteBenchmarkCase( | |
+ benchmark: Benchmark, | |
+ testName: String, | |
+ provider: StateStoreProvider, | |
+ version: Long, | |
+ keys: Seq[UnsafeRow]): Unit = { | |
+ benchmark.addTimerCase(testName) { timer => | |
+ val store = provider.getStore(version) | |
+ | |
+ timer.startTiming() | |
+ deleteRows(store, keys) | |
+ timer.stopTiming() | |
+ | |
+ store.abort() | |
+ } | |
+ } | |
+ | |
+ runBenchmark("delete rows") { | |
+ val numOfRows = Seq(10000) | |
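+ // Percentage of delete attempts that target keys absent from the store. | |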
+ val nonExistRates = Seq(100, 75, 50, 25, 10, 5, 0) | |
+ numOfRows.foreach { numOfRow => | |
+ val testData = constructRandomizedTestData(numOfRow, | |
+ (1 to numOfRow).map(_ * 1000L).toList, 0) | |
+ | |
+ val inMemoryProvider = newHDFSBackedStateStoreProvider() | |
+ val rocksDBProvider = newRocksDBStateProvider() | |
+ val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false) | |
+ | |
+ val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData) | |
+ val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData) | |
+ val committedRocksDBWithNoTrackVersion = loadInitialData( | |
+ rocksDBWithNoTrackProvider, testData) | |
+ | |
+ nonExistRates.foreach { nonExistRate => | |
+ val numOfRowsNonExist = numOfRow * nonExistRate / 100 | |
+ | |
+ val numOfExistingRows = numOfRow - numOfRowsNonExist | |
+ val nonExistingRows = if (numOfRowsNonExist > 0) { | |
+ constructRandomizedTestData(numOfRowsNonExist, | |
+ (numOfRow + 1 to numOfRow + numOfRowsNonExist).map(_ * 1000L).toList, 0) | |
+ } else { | |
+ Seq.empty[(UnsafeRow, UnsafeRow)] | |
+ } | |
+ val existingRows = if (numOfExistingRows > 0) { | |
+ Random.shuffle(testData).take(numOfExistingRows) | |
+ } else { | |
+ Seq.empty[(UnsafeRow, UnsafeRow)] | |
+ } | |
+ val keysToDelete = Random.shuffle(nonExistingRows ++ existingRows).map(_._1) | |
+ | |
+ val benchmark = new Benchmark(s"trying to delete $numOfRow rows " + | |
+ s"from $numOfRow rows " + | |
+ s"($numOfRowsNonExist rows are non-existing - rate $nonExistRate)", | |
+ numOfRow, minNumIters = 10000, output = output) | |
+ | |
+ registerDeleteBenchmarkCase(benchmark, "In-memory", inMemoryProvider, | |
+ committedInMemoryVersion, keysToDelete) | |
+ registerDeleteBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)", | |
+ rocksDBProvider, committedRocksDBVersion, keysToDelete) | |
+ registerDeleteBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)", | |
+ rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion, keysToDelete) | |
+ | |
+ benchmark.run() | |
+ } | |
+ | |
+ inMemoryProvider.close() | |
+ rocksDBProvider.close() | |
+ rocksDBWithNoTrackProvider.close() | |
+ } | |
+ } | |
+ } | |
+ | |
+ private def runEvictBenchmark(): Unit = { | |
+ def registerEvictBenchmarkCase( | |
+ benchmark: Benchmark, | |
+ testName: String, | |
+ provider: StateStoreProvider, | |
+ version: Long, | |
+ maxTimestampToEvictInMillis: Long, | |
+ expectedNumOfRows: Long): Unit = { | |
+ benchmark.addTimerCase(testName) { timer => | |
+ val store = provider.getStore(version) | |
+ | |
+ timer.startTiming() | |
+ evictAsFullScanAndRemove(store, maxTimestampToEvictInMillis, | |
+ expectedNumOfRows) | |
+ timer.stopTiming() | |
+ | |
+ store.abort() | |
+ } | |
+ } | |
+ | |
+ runBenchmark("evict rows") { | |
+ val numOfRows = Seq(10000) | |
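+ // Percentage of rows whose timestamps fall at or below the eviction threshold. | |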
+ val numOfEvictionRates = Seq(100, 75, 50, 25, 10, 5, 0) | |
+ | |
+ numOfRows.foreach { numOfRow => | |
+ val timestampsInMicros = (0L until numOfRow).map(ts => ts * 1000L).toList | |
+ | |
+ val testData = constructRandomizedTestData(numOfRow, timestampsInMicros, 0) | |
+ | |
+ val inMemoryProvider = newHDFSBackedStateStoreProvider() | |
+ val rocksDBProvider = newRocksDBStateProvider() | |
+ val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false) | |
+ | |
+ val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData) | |
+ val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData) | |
+ val committedRocksDBWithNoTrackVersion = loadInitialData( | |
+ rocksDBWithNoTrackProvider, testData) | |
+ | |
+ numOfEvictionRates.foreach { numOfEvictionRate => | |
+ val numOfRowsToEvict = numOfRow * numOfEvictionRate / 100 | |
+ val maxTimestampToEvictInMillis = timestampsInMicros | |
+ .take(numOfRow * numOfEvictionRate / 100) | |
+ .lastOption.map(_ / 1000).getOrElse(-1L) | |
+ | |
+ val benchmark = new Benchmark(s"evicting $numOfRowsToEvict rows " + | |
+ s"(maxTimestampToEvictInMillis: $maxTimestampToEvictInMillis) " + | |
+ s"from $numOfRow rows", | |
+ numOfRow, minNumIters = 10000, output = output) | |
+ | |
+ registerEvictBenchmarkCase(benchmark, "In-memory", inMemoryProvider, | |
+ committedInMemoryVersion, maxTimestampToEvictInMillis, numOfRowsToEvict) | |
+ | |
+ registerEvictBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)", | |
+ rocksDBProvider, committedRocksDBVersion, maxTimestampToEvictInMillis, | |
+ numOfRowsToEvict) | |
+ | |
+ registerEvictBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)", | |
+ rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion, | |
+ maxTimestampToEvictInMillis, numOfRowsToEvict) | |
+ | |
+ benchmark.run() | |
+ } | |
+ | |
+ inMemoryProvider.close() | |
+ rocksDBProvider.close() | |
+ rocksDBWithNoTrackProvider.close() | |
+ } | |
+ } | |
+ } | |
+ | |
+ private def getRows(store: StateStore, keys: Seq[UnsafeRow]): Seq[UnsafeRow] = { | |
+ keys.map(store.get) | |
+ } | |
+ | |
+ private def loadInitialData( | |
+ provider: StateStoreProvider, | |
+ data: Seq[(UnsafeRow, UnsafeRow)]): Long = { | |
+ val store = provider.getStore(0) | |
+ updateRows(store, data) | |
+ store.commit() | |
+ } | |
+ | |
+ private def updateRows( | |
+ store: StateStore, | |
+ rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = { | |
+ rows.foreach { case (key, value) => | |
+ store.put(key, value) | |
+ } | |
+ } | |
+ | |
+ private def deleteRows( | |
+ store: StateStore, | |
+ rows: Seq[UnsafeRow]): Unit = { | |
+ rows.foreach { key => | |
+ store.remove(key) | |
+ } | |
+ } | |
+ | |
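+ // Evicts by scanning all rows and removing those whose timestamp key (in microseconds) is at or below the given millisecond threshold. | |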
+ private def evictAsFullScanAndRemove( | |
+ store: StateStore, | |
+ maxTimestampToEvictMillis: Long, | |
+ expectedNumOfRows: Long): Unit = { | |
+ var removedRows: Long = 0 | |
+ store.iterator().foreach { r => | |
+ if (r.key.getLong(1) <= maxTimestampToEvictMillis * 1000L) { | |
+ store.remove(r.key) | |
+ removedRows += 1 | |
+ } | |
+ } | |
+ assert(removedRows == expectedNumOfRows, | |
+ s"expected: $expectedNumOfRows actual: $removedRows") | |
+ } | |
+ | |
+ // This prevents the created keys from being in order, which may affect performance on RocksDB. | |
+ private def constructRandomizedTestData( | |
+ numRows: Int, | |
+ timestamps: List[Long], | |
+ minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = { | |
+ assert(numRows >= timestamps.length) | |
+ assert(numRows % timestamps.length == 0) | |
+ | |
+ (1 to numRows).map { idx => | |
+ val keyRow = new GenericInternalRow(2) | |
+ keyRow.setInt(0, Random.nextInt(Int.MaxValue)) | |
+ keyRow.setLong(1, timestamps((minIdx + idx) % timestamps.length)) // microseconds | |
+ val valueRow = new GenericInternalRow(1) | |
+ valueRow.setInt(0, minIdx + idx) | |
+ | |
+ val keyUnsafeRow = keyProjection(keyRow).copy() | |
+ val valueUnsafeRow = valueProjection(valueRow).copy() | |
+ | |
+ (keyUnsafeRow, valueUnsafeRow) | |
+ } | |
+ } | |
+ | |
+ private def newHDFSBackedStateStoreProvider(): StateStoreProvider = { | |
+ val storeId = StateStoreId(newDir(), Random.nextInt(), 0) | |
+ val provider = new HDFSBackedStateStoreProvider() | |
+ val storeConf = new StateStoreConf(new SQLConf()) | |
+ provider.init( | |
+ storeId, keySchema, valueSchema, 0, | |
+ storeConf, new Configuration) | |
+ provider | |
+ } | |
+ | |
+ private def newRocksDBStateProvider( | |
+ trackTotalNumberOfRows: Boolean = true): StateStoreProvider = { | |
+ val storeId = StateStoreId(newDir(), Random.nextInt(), 0) | |
+ val provider = new RocksDBStateStoreProvider() | |
+ val sqlConf = new SQLConf() | |
+ sqlConf.setConfString("spark.sql.streaming.stateStore.rocksdb.trackTotalNumberOfRows", | |
+ trackTotalNumberOfRows.toString) | |
+ val storeConf = new StateStoreConf(sqlConf) | |
+ | |
+ provider.init( | |
+ storeId, keySchema, valueSchema, 0, | |
+ storeConf, new Configuration) | |
+ provider | |
+ } | |
+ | |
+ private def newDir(): String = Utils.createTempDir().toString | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala | |
index 099a1aa996..645dc870d2 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala | |
@@ -17,11 +17,13 @@ | |
package org.apache.spark.sql.execution.columnar | |
+import scala.collection.JavaConverters._ | |
+ | |
import org.apache.spark.SparkConf | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.{QueryTest, Row} | |
import org.apache.spark.sql.catalyst.InternalRow | |
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} | |
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection} | |
import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer} | |
import org.apache.spark.sql.execution.columnar.InMemoryRelation.clearSerializer | |
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector | |
@@ -101,7 +103,13 @@ class TestSingleIntColumnarCachedBatchSerializer extends CachedBatchSerializer { | |
cacheAttributes: Seq[Attribute], | |
selectedAttributes: Seq[Attribute], | |
conf: SQLConf): RDD[InternalRow] = { | |
- throw new IllegalStateException("This does not work. This is only for testing") | |
+ convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) | |
+ .mapPartitionsInternal { batches => | |
+ val toUnsafe = UnsafeProjection.create(selectedAttributes, selectedAttributes) | |
+ batches.flatMap { batch => | |
+ batch.rowIterator().asScala.map(toUnsafe) | |
+ } | |
+ } | |
} | |
override def buildFilter( | |
@@ -138,8 +146,9 @@ class CachedBatchSerializerSuite extends QueryTest with SharedSparkSession { | |
input.write.parquet(workDirPath) | |
val data = spark.read.parquet(workDirPath) | |
data.cache() | |
- assert(data.count() == 3) | |
- checkAnswer(data, Row(100) :: Row(200) :: Row(300) :: Nil) | |
+ val df = data.union(data) | |
+ assert(df.count() == 6) | |
+ checkAnswer(df, Row(100) :: Row(200) :: Row(300) :: Row(100) :: Row(200) :: Row(300) :: Nil) | |
} | |
} | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala | |
new file mode 100644 | |
index 0000000000..55d9fb2731 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarBenchmark.scala | |
@@ -0,0 +1,68 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+package org.apache.spark.sql.execution.columnar | |
+ | |
+import org.apache.spark.benchmark.Benchmark | |
+import org.apache.spark.sql.execution.ColumnarToRowExec | |
+import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark | |
+ | |
+/** | |
+ * Benchmark of in-memory columnar cache scans, comparing columnar deserialization with row-based deserialization. | |
+ * To run this benchmark: | |
+ * {{{ | |
+ * 1. without sbt: | |
+ * bin/spark-submit --class <this class> | |
+ * --jars <spark core test jar>,<spark catalyst test jar> <spark sql test jar> <rowsNum> | |
+ * 2. build/sbt "sql/Test/runMain <this class> <rowsNum>" | |
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class> | |
+ * <rowsNum>" | |
+ * Results will be written to "benchmarks/InMemoryColumnarBenchmark-results.txt". | |
+ * }}} | |
+ */ | |
+object InMemoryColumnarBenchmark extends SqlBasedBenchmark { | |
+ def intCache(rowsNum: Long, numIters: Int): Unit = { | |
+ val data = spark.range(0, rowsNum, 1, 1).toDF("i").cache() | |
+ | |
+ val inMemoryScan = data.queryExecution.executedPlan.collect { | |
+ case m: InMemoryTableScanExec => m | |
+ } | |
+ | |
+ assert(inMemoryScan.size == 1) | |
+ val columnarScan = ColumnarToRowExec(inMemoryScan(0)) | |
+ val rowScan = inMemoryScan(0) | |
+ | |
+ | |
+ val benchmark = new Benchmark("Int In-Memory scan", rowsNum, output = output) | |
+ | |
+ benchmark.addCase("columnar deserialization + columnar-to-row", numIters) { _ => | |
+ columnarScan.executeCollect() | |
+ } | |
+ | |
+ benchmark.addCase("row-based deserialization", numIters) { _ => | |
+ rowScan.executeCollect() | |
+ } | |
+ | |
+ benchmark.run() | |
+ } | |
+ | |
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { | |
+ val rowsNum = if (mainArgs.length > 0) mainArgs(0).toLong else 1000000 | |
+ runBenchmark(s"Int In-memory with $rowsNum rows") { | |
+ intCache(rowsNum = rowsNum, numIters = 3) | |
+ } | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala | |
index b8f73f4563..779aa49a34 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala | |
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar | |
import java.nio.charset.StandardCharsets | |
import java.sql.{Date, Timestamp} | |
+import java.util.concurrent.atomic.AtomicInteger | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.{DataFrame, QueryTest, Row} | |
@@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, In} | |
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning | |
import org.apache.spark.sql.columnar.CachedBatch | |
-import org.apache.spark.sql.execution.{ColumnarToRowExec, FilterExec, InputAdapter, WholeStageCodegenExec} | |
+import org.apache.spark.sql.execution.{FilterExec, InputAdapter, WholeStageCodegenExec} | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.internal.SQLConf | |
import org.apache.spark.sql.test.SharedSparkSession | |
@@ -152,7 +153,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession { | |
} | |
test("projection") { | |
- val logicalPlan = testData.select('value, 'key).logicalPlan | |
+ val logicalPlan = testData.select(Symbol("value"), Symbol("key")).logicalPlan | |
val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan | |
val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5), | |
MEMORY_ONLY, plan, None, logicalPlan) | |
@@ -504,8 +505,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession { | |
val df2 = df1.where("y = 3") | |
val planBeforeFilter = df2.queryExecution.executedPlan.collect { | |
- case FilterExec(_, c: ColumnarToRowExec) => c.child | |
- case WholeStageCodegenExec(FilterExec(_, ColumnarToRowExec(i: InputAdapter))) => i.child | |
+ case f: FilterExec => f.child | |
+ case WholeStageCodegenExec(FilterExec(_, i: InputAdapter)) => i.child | |
} | |
assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) | |
@@ -563,4 +564,56 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession { | |
} | |
} | |
} | |
+ | |
+ test("SPARK-39104: InMemoryRelation#isCachedColumnBuffersLoaded should be thread-safe") { | |
+ val plan = spark.range(1).queryExecution.executedPlan | |
+ val serializer = new TestCachedBatchSerializer(true, 1) | |
+ val cachedRDDBuilder = CachedRDDBuilder(serializer, MEMORY_ONLY, plan, None) | |
+ | |
+ @volatile var isCachedColumnBuffersLoaded = false | |
+ @volatile var stopped = false | |
+ | |
+ val th1 = new Thread { | |
+ override def run(): Unit = { | |
+ while (!isCachedColumnBuffersLoaded && !stopped) { | |
+ cachedRDDBuilder.cachedColumnBuffers | |
+ cachedRDDBuilder.clearCache() | |
+ } | |
+ } | |
+ } | |
+ | |
+ val th2 = new Thread { | |
+ override def run(): Unit = { | |
+ while (!isCachedColumnBuffersLoaded && !stopped) { | |
+ isCachedColumnBuffersLoaded = cachedRDDBuilder.isCachedColumnBuffersLoaded | |
+ } | |
+ } | |
+ } | |
+ | |
+ val th3 = new Thread { | |
+ override def run(): Unit = { | |
+ Thread.sleep(3000L) | |
+ stopped = true | |
+ } | |
+ } | |
+ | |
+ val exceptionCnt = new AtomicInteger | |
+ val exceptionHandler: Thread.UncaughtExceptionHandler = (_: Thread, cause: Throwable) => { | |
+ exceptionCnt.incrementAndGet | |
+ fail(cause) | |
+ } | |
+ | |
+ th1.setUncaughtExceptionHandler(exceptionHandler) | |
+ th2.setUncaughtExceptionHandler(exceptionHandler) | |
+ th1.start() | |
+ th2.start() | |
+ th3.start() | |
+ th1.join() | |
+ th2.join() | |
+ th3.join() | |
+ | |
+ cachedRDDBuilder.clearCache() | |
+ | |
+ assert(exceptionCnt.get == 0) | |
+ } | |
} | |
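The SPARK-39104 test above races cachedColumnBuffers/clearCache on one thread against isCachedColumnBuffersLoaded polling on another, so the probe must tolerate the buffer reference being nulled out at any moment. A minimal sketch of that kind of null-safe, lock-protected check follows; it illustrates the pattern being tested, not necessarily Spark's exact fix:

  class CacheHolder {
    // Stands in for CachedRDDBuilder's internal reference to the cached buffers.
    @volatile private var buffers: AnyRef = _

    def load(): Unit = synchronized { buffers = new Object }

    def clear(): Unit = synchronized { buffers = null }

    def isLoaded: Boolean = {
      // Fast path skips the lock when nothing is cached; the locked re-check keeps a
      // concurrent clear() from being observed half-way through the probe.
      if (buffers != null) synchronized { buffers != null } else false
    }
  }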
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala | |
new file mode 100644 | |
index 0000000000..bc1ffb93fe | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala | |
@@ -0,0 +1,41 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.command | |
+ | |
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedNamespace} | |
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan | |
+import org.apache.spark.sql.catalyst.plans.logical.SetNamespaceLocation | |
+ | |
+class AlterNamespaceSetLocationParserSuite extends AnalysisTest { | |
+ test("set namespace location") { | |
+ comparePlans( | |
+ parsePlan("ALTER DATABASE a.b.c SET LOCATION '/home/user/db'"), | |
+ SetNamespaceLocation( | |
+ UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) | |
+ | |
+ comparePlans( | |
+ parsePlan("ALTER SCHEMA a.b.c SET LOCATION '/home/user/db'"), | |
+ SetNamespaceLocation( | |
+ UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) | |
+ | |
+ comparePlans( | |
+ parsePlan("ALTER NAMESPACE a.b.c SET LOCATION '/home/user/db'"), | |
+ SetNamespaceLocation( | |
+ UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) | |
+ } | |
+} | |
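The three assertions above differ only in the DATABASE/SCHEMA/NAMESPACE keyword. Inside the same suite (reusing its imports), the coverage could equally be written as a loop over the tokens, mirroring the SET PROPERTIES parser suite added later in this patch; a sketch:

  Seq("DATABASE", "SCHEMA", "NAMESPACE").foreach { nsToken =>
    comparePlans(
      parsePlan(s"ALTER $nsToken a.b.c SET LOCATION '/home/user/db'"),
      SetNamespaceLocation(UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db"))
  }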
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala | |
new file mode 100644 | |
index 0000000000..25bae01821 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala | |
@@ -0,0 +1,83 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.command | |
+ | |
+import org.apache.spark.sql.{AnalysisException, QueryTest} | |
+import org.apache.spark.sql.connector.catalog.SupportsNamespaces | |
+ | |
+/** | |
+ * This base suite contains unified tests for the `ALTER NAMESPACE ... SET LOCATION` command that | |
+ * check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are located | |
+ * in more specific test suites: | |
+ * | |
+ * - V2 table catalog tests: | |
+ * `org.apache.spark.sql.execution.command.v2.AlterNamespaceSetLocationSuite` | |
+ * - V1 table catalog tests: | |
+ * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetLocationSuiteBase` | |
+ * - V1 In-Memory catalog: | |
+ * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetLocationSuite` | |
+ * - V1 Hive External catalog: | |
+ * `org.apache.spark.sql.hive.execution.command.AlterNamespaceSetLocationSuite` | |
+ */ | |
+trait AlterNamespaceSetLocationSuiteBase extends QueryTest with DDLCommandTestUtils { | |
+ override val command = "ALTER NAMESPACE ... SET LOCATION" | |
+ | |
+ protected def namespace: String | |
+ | |
+ protected def notFoundMsgPrefix: String | |
+ | |
+ test("Empty location string") { | |
+ val ns = s"$catalog.$namespace" | |
+ withNamespace(ns) { | |
+ sql(s"CREATE NAMESPACE $ns") | |
+ val message = intercept[IllegalArgumentException] { | |
+ sql(s"ALTER NAMESPACE $ns SET LOCATION ''") | |
+ }.getMessage | |
+ assert(message.contains("Can not create a Path from an empty string")) | |
+ } | |
+ } | |
+ | |
+ test("Namespace does not exist") { | |
+ val ns = "not_exist" | |
+ val message = intercept[AnalysisException] { | |
+ sql(s"ALTER DATABASE $catalog.$ns SET LOCATION 'loc'") | |
+ }.getMessage | |
+ assert(message.contains(s"$notFoundMsgPrefix '$ns' not found")) | |
+ } | |
+ | |
+ // Hive catalog does not support "ALTER NAMESPACE ... SET LOCATION", thus | |
+ // this is called from non-Hive v1 and v2 tests. | |
+ protected def runBasicTest(): Unit = { | |
+ val ns = s"$catalog.$namespace" | |
+ withNamespace(ns) { | |
+ sql(s"CREATE NAMESPACE IF NOT EXISTS $ns COMMENT " + | |
+ "'test namespace' LOCATION '/tmp/loc_test_1'") | |
+ sql(s"ALTER NAMESPACE $ns SET LOCATION '/tmp/loc_test_2'") | |
+ assert(getLocation(ns).contains("file:/tmp/loc_test_2")) | |
+ } | |
+ } | |
+ | |
+ protected def getLocation(namespace: String): String = { | |
+ val locationRow = sql(s"DESCRIBE NAMESPACE EXTENDED $namespace") | |
+ .toDF("key", "value") | |
+ .where(s"key like '${SupportsNamespaces.PROP_LOCATION.capitalize}%'") | |
+ .collect() | |
+ assert(locationRow.length == 1) | |
+ locationRow(0).getString(1) | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala | |
new file mode 100644 | |
index 0000000000..868dc275b8 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala | |
@@ -0,0 +1,49 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.command | |
+ | |
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedNamespace} | |
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan | |
+import org.apache.spark.sql.catalyst.parser.ParseException | |
+import org.apache.spark.sql.catalyst.plans.logical.SetNamespaceProperties | |
+ | |
+class AlterNamespaceSetPropertiesParserSuite extends AnalysisTest { | |
+ test("set namespace properties") { | |
+ Seq("DATABASE", "SCHEMA", "NAMESPACE").foreach { nsToken => | |
+ Seq("PROPERTIES", "DBPROPERTIES").foreach { propToken => | |
+ comparePlans( | |
+ parsePlan(s"ALTER $nsToken a.b.c SET $propToken ('a'='a', 'b'='b', 'c'='c')"), | |
+ SetNamespaceProperties( | |
+ UnresolvedNamespace(Seq("a", "b", "c")), Map("a" -> "a", "b" -> "b", "c" -> "c"))) | |
+ | |
+ comparePlans( | |
+ parsePlan(s"ALTER $nsToken a.b.c SET $propToken ('a'='a')"), | |
+ SetNamespaceProperties( | |
+ UnresolvedNamespace(Seq("a", "b", "c")), Map("a" -> "a"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("property values must be set") { | |
+ val e = intercept[ParseException] { | |
+ parsePlan("ALTER NAMESPACE my_db SET PROPERTIES('key_without_value', 'key_with_value'='x')") | |
+ } | |
+ assert(e.getMessage.contains( | |
+ "Operation not allowed: Values must be specified for key(s): [key_without_value]")) | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala | |
new file mode 100644 | |
index 0000000000..1351d09e03 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala | |
@@ -0,0 +1,117 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.command | |
+ | |
+import org.apache.spark.sql.{AnalysisException, QueryTest} | |
+import org.apache.spark.sql.catalyst.parser.ParseException | |
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces} | |
+import org.apache.spark.sql.internal.SQLConf | |
+ | |
+/** | |
+ * This base suite contains unified tests for the `ALTER NAMESPACE ... SET PROPERTIES` command that | |
+ * check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are located | |
+ * in more specific test suites: | |
+ * | |
+ * - V2 table catalog tests: | |
+ * `org.apache.spark.sql.execution.command.v2.AlterNamespaceSetPropertiesSuite` | |
+ * - V1 table catalog tests: | |
+ * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetPropertiesSuiteBase` | |
+ * - V1 In-Memory catalog: | |
+ * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetPropertiesSuite` | |
+ * - V1 Hive External catalog: | |
+ * `org.apache.spark.sql.hive.execution.command.AlterNamespaceSetPropertiesSuite` | |
+ */ | |
+trait AlterNamespaceSetPropertiesSuiteBase extends QueryTest with DDLCommandTestUtils { | |
+ override val command = "ALTER NAMESPACE ... SET PROPERTIES" | |
+ | |
+ protected def namespace: String | |
+ | |
+ protected def notFoundMsgPrefix: String | |
+ | |
+ test("Namespace does not exist") { | |
+ val ns = "not_exist" | |
+ val message = intercept[AnalysisException] { | |
+ sql(s"ALTER DATABASE $catalog.$ns SET PROPERTIES ('d'='d')") | |
+ }.getMessage | |
+ assert(message.contains(s"$notFoundMsgPrefix '$ns' not found")) | |
+ } | |
+ | |
+ test("basic test") { | |
+ val ns = s"$catalog.$namespace" | |
+ withNamespace(ns) { | |
+ sql(s"CREATE NAMESPACE $ns") | |
+ assert(getProperties(ns) === "") | |
+ sql(s"ALTER NAMESPACE $ns SET PROPERTIES ('a'='a', 'b'='b', 'c'='c')") | |
+ assert(getProperties(ns) === "((a,a), (b,b), (c,c))") | |
+ sql(s"ALTER DATABASE $ns SET PROPERTIES ('d'='d')") | |
+ assert(getProperties(ns) === "((a,a), (b,b), (c,c), (d,d))") | |
+ } | |
+ } | |
+ | |
+ test("test with properties set while creating namespace") { | |
+ val ns = s"$catalog.$namespace" | |
+ withNamespace(ns) { | |
+ sql(s"CREATE NAMESPACE $ns WITH PROPERTIES ('a'='a','b'='b','c'='c')") | |
+ assert(getProperties(ns) === "((a,a), (b,b), (c,c))") | |
+ sql(s"ALTER NAMESPACE $ns SET PROPERTIES ('a'='b', 'b'='a')") | |
+ assert(getProperties(ns) === "((a,b), (b,a), (c,c))") | |
+ } | |
+ } | |
+ | |
+ test("test reserved properties") { | |
+ import SupportsNamespaces._ | |
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | |
+ val ns = s"$catalog.$namespace" | |
+ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) { | |
+ CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => | |
+ withNamespace(ns) { | |
+ sql(s"CREATE NAMESPACE $ns") | |
+ val exception = intercept[ParseException] { | |
+ sql(s"ALTER NAMESPACE $ns SET PROPERTIES ('$key'='dummyVal')") | |
+ } | |
+ assert(exception.getMessage.contains(s"$key is a reserved namespace property")) | |
+ } | |
+ } | |
+ } | |
+ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "true")) { | |
+ CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => | |
+ withNamespace(ns) { | |
+ // Set the location explicitly because v2 catalog may not set the default location. | |
+ // Without this, `meta.get(key)` below may return null. | |
+ sql(s"CREATE NAMESPACE $ns LOCATION 'tmp/prop_test'") | |
+ assert(getProperties(ns) === "") | |
+ sql(s"ALTER NAMESPACE $ns SET PROPERTIES ('$key'='foo')") | |
+ assert(getProperties(ns) === "", s"$key is a reserved namespace property and ignored") | |
+ val meta = spark.sessionState.catalogManager.catalog(catalog) | |
+ .asNamespaceCatalog.loadNamespaceMetadata(namespace.split('.')) | |
+ assert(!meta.get(key).contains("foo"), | |
+ "reserved properties should not have side effects") | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ protected def getProperties(namespace: String): String = { | |
+ val propsRow = sql(s"DESCRIBE NAMESPACE EXTENDED $namespace") | |
+ .toDF("key", "value") | |
+ .where("key like 'Properties%'") | |
+ .collect() | |
+ assert(propsRow.length == 1) | |
+ propsRow(0).getString(1) | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddColumnsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddColumnsSuiteBase.scala | |
new file mode 100644 | |
index 0000000000..9a9b3378e8 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddColumnsSuiteBase.scala | |
@@ -0,0 +1,53 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.command | |
+ | |
+import java.time.{Duration, Period} | |
+ | |
+import org.apache.spark.sql.{QueryTest, Row} | |
+ | |
+/** | |
+ * This base suite contains unified tests for the `ALTER TABLE .. ADD COLUMNS` command that | |
+ * check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are | |
+ * located in more specific test suites: | |
+ * | |
+ * - V2 table catalog tests: | |
+ * `org.apache.spark.sql.execution.command.v2.AlterTableAddColumnsSuite` | |
+ * - V1 table catalog tests: | |
+ * `org.apache.spark.sql.execution.command.v1.AlterTableAddColumnsSuiteBase` | |
+ * - V1 In-Memory catalog: | |
+ * `org.apache.spark.sql.execution.command.v1.AlterTableAddColumnsSuite` | |
+ * - V1 Hive External catalog: | |
+ * `org.apache.spark.sql.hive.execution.command.AlterTableAddColumnsSuite` | |
+ */ | |
+trait AlterTableAddColumnsSuiteBase extends QueryTest with DDLCommandTestUtils { | |
+ override val command = "ALTER TABLE .. ADD COLUMNS" | |
+ | |
+ test("add an ANSI interval columns") { | |
+ assume(!catalogVersion.contains("Hive")) // Hive catalog doesn't support the interval types | |
+ | |
+ withNamespaceAndTable("ns", "tbl") { t => | |
+ sql(s"CREATE TABLE $t (id bigint) $defaultUsing") | |
+ sql(s"ALTER TABLE $t ADD COLUMNS (ym INTERVAL YEAR, dt INTERVAL HOUR)") | |
+ sql(s"INSERT INTO $t SELECT 0, INTERVAL '100' YEAR, INTERVAL '10' HOUR") | |
+ checkAnswer( | |
+ sql(s"SELECT id, ym, dt data FROM $t"), | |
+ Seq(Row(0, Period.ofYears(100), Duration.ofHours(10)))) | |
+ } | |
+ } | |
+} | |
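The expected Row above relies on the external type mapping for ANSI intervals: year-month intervals surface as java.time.Period and day-time intervals as java.time.Duration when rows are collected. A small self-contained illustration of that round trip (assuming Spark 3.2+ and a plain local session; the alias names are arbitrary):

  import java.time.{Duration, Period}

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  val row = spark.sql("SELECT INTERVAL '100' YEAR AS ym, INTERVAL '10' HOUR AS dt").head()
  assert(row.getAs[Period]("ym") == Period.ofYears(100))
  assert(row.getAs[Duration]("dt") == Duration.ofHours(10))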
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala | |
index e2e15917d7..dee14953c0 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala | |
@@ -17,7 +17,9 @@ | |
package org.apache.spark.sql.execution.command | |
-import org.apache.spark.sql.{AnalysisException, QueryTest} | |
+import java.time.{Duration, Period} | |
+ | |
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row} | |
import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException | |
import org.apache.spark.sql.internal.SQLConf | |
@@ -189,4 +191,40 @@ trait AlterTableAddPartitionSuiteBase extends QueryTest with DDLCommandTestUtils | |
checkPartitions(t, Map("part" ->"2020-01-01")) | |
} | |
} | |
+ | |
+ test("SPARK-37261: Add ANSI intervals as partition values") { | |
+ assume(!catalogVersion.contains("Hive")) // Hive catalog doesn't support the interval types | |
+ | |
+ withNamespaceAndTable("ns", "tbl") { t => | |
+ sql( | |
+ s"""CREATE TABLE $t ( | |
+ | ym INTERVAL YEAR, | |
+ | dt INTERVAL DAY, | |
+ | data STRING) $defaultUsing | |
+ |PARTITIONED BY (ym, dt)""".stripMargin) | |
+ sql( | |
+ s"""ALTER TABLE $t ADD PARTITION ( | |
+ | ym = INTERVAL '100' YEAR, | |
+ | dt = INTERVAL '10' DAY | |
+ |) LOCATION 'loc'""".stripMargin) | |
+ | |
+ checkPartitions(t, Map("ym" -> "INTERVAL '100' YEAR", "dt" -> "INTERVAL '10' DAY")) | |
+ checkLocation(t, Map("ym" -> "INTERVAL '100' YEAR", "dt" -> "INTERVAL '10' DAY"), "loc") | |
+ | |
+ sql( | |
+ s"""INSERT INTO $t PARTITION ( | |
+ | ym = INTERVAL '100' YEAR, | |
+ | dt = INTERVAL '10' DAY) SELECT 'aaa'""".stripMargin) | |
+ sql( | |
+ s"""INSERT INTO $t PARTITION ( | |
+ | ym = INTERVAL '1' YEAR, | |
+ | dt = INTERVAL '-1' DAY) SELECT 'bbb'""".stripMargin) | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT ym, dt, data FROM $t"), | |
+ Seq( | |
+ Row(Period.ofYears(100), Duration.ofDays(10), "aaa"), | |
+ Row(Period.ofYears(1), Duration.ofDays(-1), "bbb"))) | |
+ } | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala | |
index ebc1bd3468..394392299b 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala | |
@@ -29,7 +29,7 @@ class AlterTableRecoverPartitionsParserSuite extends AnalysisTest with SharedSpa | |
val errMsg = intercept[ParseException] { | |
parsePlan("ALTER TABLE RECOVER PARTITIONS") | |
}.getMessage | |
- assert(errMsg.contains("no viable alternative at input 'ALTER TABLE RECOVER PARTITIONS'")) | |
+ assert(errMsg.contains("Syntax error at or near 'PARTITIONS'")) | |
} | |
test("recover partitions of a table") { | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala | |
index 1803ec0469..2942d61f7f 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala | |
@@ -136,4 +136,12 @@ trait AlterTableRenameSuiteBase extends QueryTest with DDLCommandTestUtils { | |
checkAnswer(spark.table(dst), Row(1, 2)) | |
} | |
} | |
+ | |
+ test("SPARK-38587: use formatted names") { | |
+ withNamespaceAndTable("CaseUpperCaseLower", "CaseUpperCaseLower") { t => | |
+ sql(s"CREATE TABLE ${t}_Old (i int) $defaultUsing") | |
+ sql(s"ALTER TABLE ${t}_Old RENAME TO CaseUpperCaseLower.CaseUpperCaseLower") | |
+ assert(spark.table(t).isEmpty) | |
+ } | |
+ } | |
} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableReplaceColumnsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableReplaceColumnsSuiteBase.scala | |
new file mode 100644 | |
index 0000000000..fed40767c2 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableReplaceColumnsSuiteBase.scala | |
@@ -0,0 +1,54 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.command | |
+ | |
+import java.time.{Duration, Period} | |
+ | |
+import org.apache.spark.sql.{QueryTest, Row} | |
+ | |
+/** | |
+ * This base suite contains unified tests for the `ALTER TABLE .. REPLACE COLUMNS` command that | |
+ * check the V2 table catalog. The tests that cannot run for all supported catalogs are | |
+ * located in more specific test suites: | |
+ * | |
+ * - V2 table catalog tests: | |
+ * `org.apache.spark.sql.execution.command.v2.AlterTableReplaceColumnsSuite` | |
+ */ | |
+trait AlterTableReplaceColumnsSuiteBase extends QueryTest with DDLCommandTestUtils { | |
+ override val command = "ALTER TABLE .. REPLACE COLUMNS" | |
+ | |
+ test("SPARK-37304: Replace columns by ANSI intervals") { | |
+ withNamespaceAndTable("ns", "tbl") { t => | |
+ sql(s"CREATE TABLE $t (ym INTERVAL MONTH, dt INTERVAL HOUR, data STRING) $defaultUsing") | |
+ // TODO(SPARK-37303): Uncomment the command below after REPLACE COLUMNS is fixed | |
+ // sql(s"INSERT INTO $t SELECT INTERVAL '1' MONTH, INTERVAL '2' HOUR, 'abc'") | |
+ sql( | |
+ s""" | |
+ |ALTER TABLE $t REPLACE COLUMNS ( | |
+ | new_ym INTERVAL YEAR, | |
+ | new_dt INTERVAL MINUTE, | |
+ | new_data INT)""".stripMargin) | |
+ sql(s"INSERT INTO $t SELECT INTERVAL '3' YEAR, INTERVAL '4' MINUTE, 5") | |
+ | |
+ checkAnswer( | |
+ sql(s"SELECT new_ym, new_dt, new_data FROM $t"), | |
+ Seq( | |
+ Row(Period.ofYears(3), Duration.ofMinutes(4), 5))) | |
+ } | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala | |
index ba683c049a..f77b6336b8 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala | |
@@ -98,23 +98,75 @@ trait CharVarcharDDLTestBase extends QueryTest with SQLTestUtils { | |
} | |
} | |
- def checkTableSchemaTypeStr(expected: Seq[Row]): Unit = { | |
- checkAnswer(sql("desc t").selectExpr("data_type").where("data_type like '%char%'"), expected) | |
+ def checkTableSchemaTypeStr(table: String, expected: Seq[Row]): Unit = { | |
+ checkAnswer( | |
+ sql(s"desc $table").selectExpr("data_type").where("data_type like '%char%'"), | |
+ expected) | |
} | |
test("SPARK-33901: alter table add columns should not change original table's schema") { | |
withTable("t") { | |
sql(s"CREATE TABLE t(i CHAR(5), c VARCHAR(4)) USING $format") | |
sql("ALTER TABLE t ADD COLUMNS (d VARCHAR(5))") | |
- checkTableSchemaTypeStr(Seq(Row("char(5)"), Row("varchar(4)"), Row("varchar(5)"))) | |
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(4)"), Row("varchar(5)"))) | |
} | |
} | |
test("SPARK-33901: ctas should should not change table's schema") { | |
- withTable("t", "tt") { | |
- sql(s"CREATE TABLE tt(i CHAR(5), c VARCHAR(4)) USING $format") | |
- sql(s"CREATE TABLE t USING $format AS SELECT * FROM tt") | |
- checkTableSchemaTypeStr(Seq(Row("char(5)"), Row("varchar(4)"))) | |
+ withTable("t1", "t2") { | |
+ sql(s"CREATE TABLE t1(i CHAR(5), c VARCHAR(4)) USING $format") | |
+ sql(s"CREATE TABLE t2 USING $format AS SELECT * FROM t1") | |
+ checkTableSchemaTypeStr("t2", Seq(Row("char(5)"), Row("varchar(4)"))) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37160: CREATE TABLE with CHAR_AS_VARCHAR") { | |
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") { | |
+ withTable("t") { | |
+ sql(s"CREATE TABLE t(col CHAR(5)) USING $format") | |
+ checkTableSchemaTypeStr("t", Seq(Row("varchar(5)"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37160: CREATE TABLE AS SELECT with CHAR_AS_VARCHAR") { | |
+ withTable("t1", "t2") { | |
+ sql(s"CREATE TABLE t1(col CHAR(5)) USING $format") | |
+ checkTableSchemaTypeStr("t1", Seq(Row("char(5)"))) | |
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") { | |
+ sql(s"CREATE TABLE t2 USING $format AS SELECT * FROM t1") | |
+ checkTableSchemaTypeStr("t2", Seq(Row("varchar(5)"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37160: ALTER TABLE ADD COLUMN with CHAR_AS_VARCHAR") { | |
+ withTable("t") { | |
+ sql(s"CREATE TABLE t(col CHAR(5)) USING $format") | |
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"))) | |
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") { | |
+ sql("ALTER TABLE t ADD COLUMN c2 CHAR(10)") | |
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(10)"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-33892: DESCRIBE COLUMN w/ char/varchar") { | |
+ withTable("t") { | |
+ sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format") | |
+ checkAnswer(sql("desc t v").selectExpr("info_value").where("info_value like '%char%'"), | |
+ Row("varchar(3)")) | |
+ checkAnswer(sql("desc t c").selectExpr("info_value").where("info_value like '%char%'"), | |
+ Row("char(5)")) | |
+ } | |
+ } | |
+ | |
+ test("SPARK-33892: SHOW CREATE TABLE w/ char/varchar") { | |
+ withTable("t") { | |
+ sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format") | |
+ val rest = sql("SHOW CREATE TABLE t").head().getString(0) | |
+ assert(rest.contains("VARCHAR(3)")) | |
+ assert(rest.contains("CHAR(5)")) | |
} | |
} | |
} | |
@@ -130,7 +182,7 @@ class FileSourceCharVarcharDDLTestSuite extends CharVarcharDDLTestBase with Shar | |
withTable("t", "tt") { | |
sql(s"CREATE TABLE tt(i CHAR(5), c VARCHAR(4)) USING $format") | |
sql("CREATE TABLE t LIKE tt") | |
- checkTableSchemaTypeStr(Seq(Row("char(5)"), Row("varchar(4)"))) | |
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(4)"))) | |
} | |
} | |
@@ -140,7 +192,19 @@ class FileSourceCharVarcharDDLTestSuite extends CharVarcharDDLTestBase with Shar | |
sql(s"CREATE TABLE tt(i CHAR(5), c VARCHAR(4)) USING $format") | |
withView("t") { | |
sql("CREATE VIEW t AS SELECT * FROM tt") | |
- checkTableSchemaTypeStr(Seq(Row("char(5)"), Row("varchar(4)"))) | |
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(4)"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ // TODO(SPARK-33902): MOVE TO SUPER CLASS AFTER THE TARGET TICKET RESOLVED | |
+ test("SPARK-37160: CREATE TABLE LIKE with CHAR_AS_VARCHAR") { | |
+ withTable("t1", "t2") { | |
+ sql(s"CREATE TABLE t1(col CHAR(5)) USING $format") | |
+ checkTableSchemaTypeStr("t1", Seq(Row("char(5)"))) | |
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") { | |
+ sql(s"CREATE TABLE t2 LIKE t1") | |
+ checkTableSchemaTypeStr("t2", Seq(Row("varchar(5)"))) | |
} | |
} | |
} | |
@@ -196,4 +260,40 @@ class DSV2CharVarcharDDLTestSuite extends CharVarcharDDLTestBase | |
assert(e.getMessage contains "char(4) cannot be cast to varchar(3)") | |
} | |
} | |
+ | |
+ test("SPARK-37160: REPLACE TABLE with CHAR_AS_VARCHAR") { | |
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") { | |
+ withTable("t") { | |
+ sql(s"CREATE TABLE t(col INT) USING $format") | |
+ sql(s"REPLACE TABLE t(col CHAR(5)) USING $format") | |
+ checkTableSchemaTypeStr("t", Seq(Row("varchar(5)"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37160: REPLACE TABLE AS SELECT with CHAR_AS_VARCHAR") { | |
+ withTable("t1", "t2") { | |
+ sql(s"CREATE TABLE t1(col CHAR(5)) USING $format") | |
+ checkTableSchemaTypeStr("t1", Seq(Row("char(5)"))) | |
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") { | |
+ sql(s"CREATE TABLE t2(col INT) USING $format") | |
+ sql(s"REPLACE TABLE t2 AS SELECT * FROM t1") | |
+ checkTableSchemaTypeStr("t2", Seq(Row("varchar(5)"))) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("SPARK-37160: ALTER TABLE ALTER/REPLACE COLUMN with CHAR_AS_VARCHAR") { | |
+ withTable("t") { | |
+ sql(s"CREATE TABLE t(col CHAR(5), c2 VARCHAR(10)) USING $format") | |
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(10)"))) | |
+ withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") { | |
+ sql("ALTER TABLE t ALTER c2 TYPE CHAR(20)") | |
+ checkTableSchemaTypeStr("t", Seq(Row("char(5)"), Row("varchar(20)"))) | |
+ | |
+ sql("ALTER TABLE t REPLACE COLUMNS (col CHAR(5))") | |
+ checkTableSchemaTypeStr("t", Seq(Row("varchar(5)"))) | |
+ } | |
+ } | |
+ } | |
} | |
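All of the SPARK-37160 tests above follow one shape: enable CHAR_AS_VARCHAR, run a DDL statement, and check that CHAR(n) surfaces as varchar(n). A standalone sketch of that interaction outside the test harness (the table name is illustrative and the session is a plain local one):

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.internal.SQLConf

  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  spark.conf.set(SQLConf.CHAR_AS_VARCHAR.key, "true")
  spark.sql("CREATE TABLE demo_t(col CHAR(5)) USING parquet")
  // With the conf enabled, the column is recorded and described as varchar(5).
  spark.sql("DESC demo_t").show()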
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala | |
new file mode 100644 | |
index 0000000000..6c59512148 | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala | |
@@ -0,0 +1,112 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.command | |
+ | |
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedDBObjectName} | |
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan | |
+import org.apache.spark.sql.catalyst.plans.logical.CreateNamespace | |
+ | |
+class CreateNamespaceParserSuite extends AnalysisTest { | |
+ test("create namespace -- backward compatibility with DATABASE/DBPROPERTIES") { | |
+ val expected = CreateNamespace( | |
+ UnresolvedDBObjectName(Seq("a", "b", "c"), true), | |
+ ifNotExists = true, | |
+ Map( | |
+ "a" -> "a", | |
+ "b" -> "b", | |
+ "c" -> "c", | |
+ "comment" -> "namespace_comment", | |
+ "location" -> "/home/user/db")) | |
+ | |
+ comparePlans( | |
+ parsePlan( | |
+ """ | |
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c | |
+ |WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c') | |
+ |COMMENT 'namespace_comment' LOCATION '/home/user/db' | |
+ """.stripMargin), | |
+ expected) | |
+ | |
+ comparePlans( | |
+ parsePlan( | |
+ """ | |
+ |CREATE DATABASE IF NOT EXISTS a.b.c | |
+ |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') | |
+ |COMMENT 'namespace_comment' LOCATION '/home/user/db' | |
+ """.stripMargin), | |
+ expected) | |
+ } | |
+ | |
+ test("create namespace -- check duplicates") { | |
+ def createNamespace(duplicateClause: String): String = { | |
+ s""" | |
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c | |
+ |$duplicateClause | |
+ |$duplicateClause | |
+ """.stripMargin | |
+ } | |
+ val sql1 = createNamespace("COMMENT 'namespace_comment'") | |
+ val sql2 = createNamespace("LOCATION '/home/user/db'") | |
+ val sql3 = createNamespace("WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')") | |
+ val sql4 = createNamespace("WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") | |
+ | |
+ intercept(sql1, "Found duplicate clauses: COMMENT") | |
+ intercept(sql2, "Found duplicate clauses: LOCATION") | |
+ intercept(sql3, "Found duplicate clauses: WITH PROPERTIES") | |
+ intercept(sql4, "Found duplicate clauses: WITH DBPROPERTIES") | |
+ } | |
+ | |
+ test("create namespace - property values must be set") { | |
+ intercept( | |
+ "CREATE NAMESPACE a.b.c WITH PROPERTIES('key_without_value', 'key_with_value'='x')", | |
+ "Operation not allowed: Values must be specified for key(s): [key_without_value]") | |
+ } | |
+ | |
+ test("create namespace -- either PROPERTIES or DBPROPERTIES is allowed") { | |
+ val sql = | |
+ s""" | |
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c | |
+ |WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c') | |
+ |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') | |
+ """.stripMargin | |
+ intercept(sql, "The feature is not supported: " + | |
+ "set PROPERTIES and DBPROPERTIES at the same time.") | |
+ } | |
+ | |
+ test("create namespace - support for other types in PROPERTIES") { | |
+ val sql = | |
+ """ | |
+ |CREATE NAMESPACE a.b.c | |
+ |LOCATION '/home/user/db' | |
+ |WITH PROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE) | |
+ """.stripMargin | |
+ comparePlans( | |
+ parsePlan(sql), | |
+ CreateNamespace( | |
+ UnresolvedDBObjectName(Seq("a", "b", "c"), true), | |
+ ifNotExists = false, | |
+ Map( | |
+ "a" -> "1", | |
+ "b" -> "0.1", | |
+ "c" -> "true", | |
+ "location" -> "/home/user/db"))) | |
+ } | |
+ | |
+ private def intercept(sqlCommand: String, messages: String*): Unit = | |
+ interceptParseException(parsePlan)(sqlCommand, messages: _*)() | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala | |
new file mode 100644 | |
index 0000000000..7db8fba8ac | |
--- /dev/null | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala | |
@@ -0,0 +1,140 @@ | |
+/* | |
+ * Licensed to the Apache Software Foundation (ASF) under one or more | |
+ * contributor license agreements. See the NOTICE file distributed with | |
+ * this work for additional information regarding copyright ownership. | |
+ * The ASF licenses this file to You under the Apache License, Version 2.0 | |
+ * (the "License"); you may not use this file except in compliance with | |
+ * the License. You may obtain a copy of the License at | |
+ * | |
+ * http://www.apache.org/licenses/LICENSE-2.0 | |
+ * | |
+ * Unless required by applicable law or agreed to in writing, software | |
+ * distributed under the License is distributed on an "AS IS" BASIS, | |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
+ * See the License for the specific language governing permissions and | |
+ * limitations under the License. | |
+ */ | |
+ | |
+package org.apache.spark.sql.execution.command | |
+ | |
+import scala.collection.JavaConverters._ | |
+ | |
+import org.apache.hadoop.fs.Path | |
+ | |
+import org.apache.spark.sql.QueryTest | |
+import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException | |
+import org.apache.spark.sql.catalyst.parser.ParseException | |
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Util, SupportsNamespaces} | |
+import org.apache.spark.sql.execution.command.DDLCommandTestUtils.V1_COMMAND_VERSION | |
+import org.apache.spark.sql.internal.SQLConf | |
+ | |
+/** | |
+ * This base suite contains unified tests for the `CREATE NAMESPACE` command that check V1 and V2 | |
+ * table catalogs. The tests that cannot run for all supported catalogs are located in more | |
+ * specific test suites: | |
+ * | |
+ * - V2 table catalog tests: `org.apache.spark.sql.execution.command.v2.CreateNamespaceSuite` | |
+ * - V1 table catalog tests: | |
+ * `org.apache.spark.sql.execution.command.v1.CreateNamespaceSuiteBase` | |
+ * - V1 In-Memory catalog: `org.apache.spark.sql.execution.command.v1.CreateNamespaceSuite` | |
+ * - V1 Hive External catalog: | |
+ * `org.apache.spark.sql.hive.execution.command.CreateNamespaceSuite` | |
+ */ | |
+trait CreateNamespaceSuiteBase extends QueryTest with DDLCommandTestUtils { | |
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | |
+ | |
+ override val command = "CREATE NAMESPACE" | |
+ | |
+ protected def namespace: String | |
+ | |
+ protected def namespaceArray: Array[String] = namespace.split('.') | |
+ | |
+ protected def notFoundMsgPrefix: String = | |
+ if (commandVersion == V1_COMMAND_VERSION) "Database" else "Namespace" | |
+ | |
+ test("basic") { | |
+ val ns = s"$catalog.$namespace" | |
+ withNamespace(ns) { | |
+ sql(s"CREATE NAMESPACE $ns") | |
+ assert(getCatalog(catalog).asNamespaceCatalog.namespaceExists(namespaceArray)) | |
+ } | |
+ } | |
+ | |
+ test("namespace with location") { | |
+ val ns = s"$catalog.$namespace" | |
+ withNamespace(ns) { | |
+ withTempDir { tmpDir => | |
+ // The generated temp path is not qualified. | |
+ val path = tmpDir.getCanonicalPath | |
+ assert(!path.startsWith("file:/")) | |
+ | |
+ val e = intercept[IllegalArgumentException] { | |
+ sql(s"CREATE NAMESPACE $ns LOCATION ''") | |
+ } | |
+ assert(e.getMessage.contains("Can not create a Path from an empty string")) | |
+ | |
+ val uri = new Path(path).toUri | |
+ sql(s"CREATE NAMESPACE $ns LOCATION '$uri'") | |
+ | |
+ // Make sure the location is qualified. | |
+ val expected = makeQualifiedPath(tmpDir.toString) | |
+ assert("file" === expected.getScheme) | |
+ assert(new Path(getNamespaceLocation(catalog, namespaceArray)).toUri === expected) | |
+ } | |
+ } | |
+ } | |
+ | |
+ test("Namespace already exists") { | |
+ val ns = s"$catalog.$namespace" | |
+ withNamespace(ns) { | |
+ sql(s"CREATE NAMESPACE $ns") | |
+ | |
+ val e = intercept[NamespaceAlreadyExistsException] { | |
+ sql(s"CREATE NAMESPACE $ns") | |
+ } | |
+ assert(e.getMessage.contains(s"$notFoundMsgPrefix '$namespace' already exists")) | |
+ | |
+ // The following will be no-op since the namespace already exists. | |
+ sql(s"CREATE NAMESPACE IF NOT EXISTS $ns") | |
+ } | |
+ } | |
+ | |
+ test("CreateNameSpace: reserved properties") { | |
+ import SupportsNamespaces._ | |
+ val ns = s"$catalog.$namespace" | |
+ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) { | |
+ CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => | |
+ val exception = intercept[ParseException] { | |
+ sql(s"CREATE NAMESPACE $ns WITH DBPROPERTIES('$key'='dummyVal')") | |
+ } | |
+ assert(exception.getMessage.contains(s"$key is a reserved namespace property")) | |
+ } | |
+ } | |
+ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "true")) { | |
+ CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filterNot(_ == PROP_COMMENT).foreach { key => | |
+ withNamespace(ns) { | |
+ sql(s"CREATE NAMESPACE $ns WITH DBPROPERTIES('$key'='foo')") | |
+ assert(sql(s"DESC NAMESPACE EXTENDED $ns") | |
+ .toDF("k", "v") | |
+ .where("k='Properties'") | |
+ .where("v=''") | |
+ .count == 1, s"$key is a reserved namespace property and ignored") | |
+ val meta = | |
+ getCatalog(catalog).asNamespaceCatalog.loadNamespaceMetadata(namespaceArray) | |
+ assert(meta.get(key) == null || !meta.get(key).contains("foo"), | |
+ "reserved properties should not have side effects") | |
+ } | |
+ } | |
+ } | |
+ } | |
+ | |
+ protected def getNamespaceLocation(catalog: String, namespace: Array[String]): String = { | |
+ val metadata = getCatalog(catalog).asNamespaceCatalog | |
+ .loadNamespaceMetadata(namespace).asScala | |
+ metadata(SupportsNamespaces.PROP_LOCATION) | |
+ } | |
+ | |
+ private def getCatalog(name: String): CatalogPlugin = { | |
+ spark.sessionState.catalogManager.catalog(name) | |
+ } | |
+} | |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala | |
index f9e26f8277..39f2abd35c 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala | |
@@ -39,7 +39,9 @@ import org.apache.spark.sql.test.SQLTestUtils | |
*/ | |
trait DDLCommandTestUtils extends SQLTestUtils { | |
// The version of the catalog under testing such as "V1", "V2", "Hive V1". | |
- protected def version: String | |
+ protected def catalogVersion: String | |
+ // The version of the SQL command under testing such as "V1", "V2". | |
+ protected def commandVersion: String | |
// Name of the command as SQL statement, for instance "SHOW PARTITIONS" | |
protected def command: String | |
// The catalog name which can be used in SQL statements under testing | |
@@ -51,7 +53,8 @@ trait DDLCommandTestUtils extends SQLTestUtils { | |
// the failed test in logs belongs to. | |
override def test(testName: String, testTags: Tag*)(testFun: => Any) | |
(implicit pos: Position): Unit = { | |
- super.test(s"$command $version: " + testName, testTags: _*)(testFun) | |
+ val testNamePrefix = s"$command using $catalogVersion catalog $commandVersion command" | |
+ super.test(s"$testNamePrefix: $testName", testTags: _*)(testFun) | |
} | |
protected def withNamespaceAndTable(ns: String, tableName: String, cat: String = catalog) | |
@@ -170,3 +173,8 @@ trait DDLCommandTestUtils extends SQLTestUtils { | |
part1Loc | |
} | |
} | |
+ | |
+object DDLCommandTestUtils { | |
+ val V1_COMMAND_VERSION = "V1" | |
+ val V2_COMMAND_VERSION = "V2" | |
+} | |
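With the renamed prefix introduced above, a test declared as test("basic") now reports both the catalog and the command version. A tiny worked example of the resulting name, using illustrative values for the two new fields:

  val command = "CREATE NAMESPACE"
  val catalogVersion = "V1"    // e.g. the V1 in-memory catalog
  val commandVersion = "V1"    // DDLCommandTestUtils.V1_COMMAND_VERSION
  val testNamePrefix = s"$command using $catalogVersion catalog $commandVersion command"
  assert(s"$testNamePrefix: basic" == "CREATE NAMESPACE using V1 catalog V1 command: basic")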
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala | |
index 6c337e3e82..7bf2b8ff04 100644 | |
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala | |
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala | |
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.command | |
import java.util.Locale | |
import org.apache.spark.sql.AnalysisException | |
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute} | |
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, GlobalTempView, LocalTempView, UnresolvedAttribute, UnresolvedDBObjectName, UnresolvedFunc} | |
+import org.apache.spark.sql.catalyst.catalog.{ArchiveResource, FileResource, FunctionResource, JarResource} | |
import org.apache.spark.sql.catalyst.dsl.expressions._ | |
import org.apache.spark.sql.catalyst.dsl.plans | |
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan | |
@@ -31,6 +32,7 @@ import org.apache.spark.sql.execution.SparkSqlParser | |
import org.apache.spark.sql.test.SharedSparkSession | |
class DDLParserSuite extends AnalysisTest with SharedSparkSession { | |
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | |
private lazy val parser = new SparkSqlParser() | |
private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { | |
@@ -43,15 +45,18 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { | |
} | |
} | |
+ private def intercept(sqlCommand: String, messages: String*): Unit = | |
+ interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)() | |
+ | |
private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { | |
val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) | |
comparePlans(plan, expected, checkAnalysis = false) | |
} | |
- test("alter database - property values must be set") { | |
- assertUnsupported( | |
- sql = "ALTER DATABASE my_db SET DBPROPERTIES('key_without_value', 'key_with_value'='x')", | |
- containsThesePhrases = Seq("key_without_value")) | |
+ test("show current namespace") { | |
+ comparePlans( | |
+ parser.parsePlan("SHOW CURRENT NAMESPACE"), | |
+ ShowCurrentNamespaceCommand()) | |
} | |
test("insert overwrite directory") { | |
@@ -195,21 +200,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { | |
assert(parsed.isInstanceOf[Project]) | |
} | |
- test("duplicate keys in table properties") { | |
- val e = intercept[ParseException] { | |
- parser.parsePlan("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('key1' = '1', 'key1' = '2')") | |
- }.getMessage | |
- assert(e.contains("Found duplicate keys 'key1'")) | |
- } | |
- | |
- test("duplicate columns in partition specs") { | |
- val e = intercept[ParseException] { | |
- parser.parsePlan( | |
- "ALTER TABLE dbx.tab1 PARTITION (a='1', a='2') RENAME TO PARTITION (a='100', a='200')") | |
- }.getMessage | |
- assert(e.contains("Found duplicate keys 'a'")) | |
- } | |
- | |
test("unsupported operations") { | |
intercept[ParseException] { | |
parser.parsePlan( | |
@@ -283,12 +273,12 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { | |
val s = ScriptTransformation("func", Seq.empty, p, null) | |
compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", | |
- s.copy(child = p.copy(child = p.child.where('f < 10)), | |
- output = Seq('key.string, 'value.string))) | |
+ s.copy(child = p.copy(child = p.child.where(Symbol("f") < 10)), | |
+ output = Seq(Symbol("key").string, Symbol("value").string))) | |
compareTransformQuery("map a, b using 'func' as c, d from e", | |
- s.copy(output = Seq('c.string, 'd.string))) | |
+ s.copy(output = Seq(Symbol("c").string, Symbol("d").string))) | |
compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", | |
- s.copy(output = Seq('c.int, 'd.decimal(10, 0)))) | |
+ s.copy(output = Seq(Symbol("c").int, Symbol("d").decimal(10, 0)))) | |
} | |
test("use backticks in output of Script Transform") { | |
@@ -319,6 +309,182 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { | |
""".stripMargin) | |
} | |
+ test("create view -- basic") { | |
+ val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" | |
+ val parsed1 = parser.parsePlan(v1) | |
+ | |
+ val expected1 = CreateView( | |
+ UnresolvedDBObjectName(Seq("view1"), false), | |
+ Seq.empty[(String, |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment