Commit a0aa9f4

[SPARK-53738][SQL] PlannedWrite should preserve custom sort order when query output contains literal
1 parent: 6cdc62e

12 files changed, +142 -26 lines
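
The scenario behind the fix, sketched with illustrative names (the table, columns, and query below are not taken from the commit): when the query output holds a literal for a dynamic partition column, the planned write used to require an ordering that begins with that constant column. The query's own ORDER BY then fails the ordering check and an extra sort is inserted, clobbering the user's custom sort order, even though ordering by a constant is a no-op.

// Hypothetical reproduction; table and column names are made up:
spark.sql("CREATE TABLE t (a INT, b INT, p INT) USING parquet PARTITIONED BY (p)")

// p is a constant in the query output, so sorting the write by p is a
// no-op; ORDER BY a is the custom sort order that should reach the files.
spark.sql("INSERT OVERWRITE TABLE t SELECT a, b, 1 AS p FROM src ORDER BY a")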

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf
  */
 trait AliasAwareOutputExpression extends SQLConfHelper {
   protected val aliasCandidateLimit = conf.getConf(SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT)
-  protected def outputExpressions: Seq[NamedExpression]
+  def outputExpressions: Seq[NamedExpression]
   /**
    * This method can be used to strip expression which does not affect the result, for example:
    * strip the expression which is ordering agnostic for output ordering.
@@ -88,7 +88,7 @@ trait AliasAwareOutputExpression extends SQLConfHelper {
  */
 trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]]
   extends AliasAwareOutputExpression { self: QueryPlan[T] =>
-  protected def orderingExpressions: Seq[SortOrder]
+  def orderingExpressions: Seq[SortOrder]

   override protected def strip(expr: Expression): Expression = expr match {
     case e: Empty2Null => strip(e.child)
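
Both members go from protected to public here; the motivation shows up further down, where FileFormatWriter pattern-matches on OrderPreservingUnaryExecNode from outside the trait hierarchy and needs to read outputExpressions directly.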

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,16 @@ trait BinaryNode extends LogicalPlan with BinaryLike[LogicalPlan]
293293

294294
trait OrderPreservingUnaryNode extends UnaryNode
295295
with AliasAwareQueryOutputOrdering[LogicalPlan] {
296-
override protected def outputExpressions: Seq[NamedExpression] = child.output
297-
override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
296+
297+
override def outputExpressions: Seq[NamedExpression] = child match {
298+
case o: OrderPreservingUnaryNode => o.outputExpressions
299+
case _ => child.output
300+
}
301+
302+
override def orderingExpressions: Seq[SortOrder] = child match {
303+
case o: OrderPreservingUnaryNode => o.orderingExpressions
304+
case _ => child.outputOrdering
305+
}
298306
}
299307

300308
object LogicalPlanIntegrity {
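
A minimal, self-contained sketch of the recursion above, using simplified stand-in types rather than the real Catalyst classes: pass-through nodes keep forwarding the question to an order-preserving child, so a wrapper chain still surfaces the innermost projection list, which is where aliased literals such as 1 AS p live.

sealed trait Plan { def output: Seq[String] }

trait OrderPreservingUnary extends Plan {
  def child: Plan
  override def output: Seq[String] = child.output
  // Recurse through order-preserving children, as OrderPreservingUnaryNode does:
  def outputExpressions: Seq[String] = child match {
    case o: OrderPreservingUnary => o.outputExpressions
    case _ => child.output
  }
}

case class Scan(output: Seq[String]) extends Plan

// A Project-like node that defines its own expression list:
case class Proj(exprs: Seq[String], child: Plan) extends OrderPreservingUnary {
  override def outputExpressions: Seq[String] = exprs
}

// A pure pass-through node, like Sort or the codegen wrapper:
case class Pass(child: Plan) extends OrderPreservingUnary

// The wrapper chain resolves to the projection list, aliases intact:
assert(Pass(Proj(Seq("1 AS p", "a"), Scan(Seq("a")))).outputExpressions ==
  Seq("1 AS p", "a"))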

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 3 additions & 3 deletions
@@ -73,7 +73,7 @@ object Subquery {
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
   extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
-  override protected def outputExpressions: Seq[NamedExpression] = projectList
+  override def outputExpressions: Seq[NamedExpression] = projectList
   override def maxRows: Option[Long] = child.maxRows
   override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition

@@ -906,13 +906,13 @@ case class Sort(
     order: Seq[SortOrder],
     global: Boolean,
     child: LogicalPlan,
-    hint: Option[SortHint] = None) extends UnaryNode {
+    hint: Option[SortHint] = None) extends UnaryNode with OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = child.output
+  override def orderingExpressions: Seq[SortOrder] = order
   override def maxRows: Option[Long] = child.maxRows
   override def maxRowsPerPartition: Option[Long] = {
     if (global) maxRows else child.maxRowsPerPartition
   }
-  override def outputOrdering: Seq[SortOrder] = order
   final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
   override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
 }
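
Sort previously published its ordering by overriding outputOrdering directly, which bypassed the alias-aware machinery. Expressing the same information as orderingExpressions through OrderPreservingUnaryNode lets the ordering participate in alias resolution, and the trait's recursive outputExpressions lets callers above a Sort still see the projection list beneath it.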

sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala

Lines changed: 7 additions & 2 deletions
@@ -42,11 +42,16 @@ case class SortExec(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryExecNode with BlockingOperatorWithCodegen {
+  extends UnaryExecNode with BlockingOperatorWithCodegen with OrderPreservingUnaryExecNode {

   override def output: Seq[Attribute] = child.output

-  override def outputOrdering: Seq[SortOrder] = sortOrder
+  override def outputExpressions: Seq[NamedExpression] = child match {
+    case o: OrderPreservingUnaryExecNode => o.outputExpressions
+    case _ => child.output
+  }
+
+  override def orderingExpressions: Seq[SortOrder] = sortOrder

   // sort performed is local within a given partition so will retain
   // child operator's partitioning
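
The physical plan gets the same treatment: SortExec still owns its ordering (orderingExpressions = sortOrder) but now forwards outputExpressions through an order-preserving child, so the write path can discover aliased literals sitting below the sort.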

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 10 additions & 2 deletions
@@ -633,13 +633,21 @@ object WholeStageCodegenExec {
  * used to generated code for [[BoundReference]].
  */
 case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport with OrderPreservingUnaryExecNode {

   override def output: Seq[Attribute] = child.output

   override def outputPartitioning: Partitioning = child.outputPartitioning

-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+  override def outputExpressions: Seq[NamedExpression] = child match {
+    case o: OrderPreservingUnaryExecNode => o.outputExpressions
+    case _ => child.output
+  }
+
+  override def orderingExpressions: Seq[SortOrder] = child match {
+    case o: OrderPreservingUnaryExecNode => o.orderingExpressions
+    case _ => child.outputOrdering
+  }

   // This is not strictly needed because the codegen transformation happens after the columnar
   // transformation but just for consistency
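
WholeStageCodegenExec is a pure pass-through, so both members simply delegate to the child. This appears to matter because the codegen wrapper typically sits between the write path and the ProjectExec or SortExec that actually carries the expressions, and would otherwise hide them.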

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ trait BaseAggregateExec extends UnaryExecNode with PartitioningPreservingUnaryEx

   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
+  override def outputExpressions: Seq[NamedExpression] = resultExpressions

   override def requiredChildDistribution: List[Distribution] = {
     requiredChildDistributionExpressions match {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ case class SortAggregateExec(
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }

-  override protected def orderingExpressions: Seq[SortOrder] = {
+  override def orderingExpressions: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
   }


sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 2 additions & 2 deletions
@@ -102,9 +102,9 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
     }
   }

-  override protected def outputExpressions: Seq[NamedExpression] = projectList
+  override def outputExpressions: Seq[NamedExpression] = projectList

-  override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
+  override def orderingExpressions: Seq[SortOrder] = child.outputOrdering

   override def verboseStringWithOperatorId(): String = {
     s"""

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 30 additions & 8 deletions
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources

 import java.util.{Date, UUID}

+import scala.annotation.tailrec
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
@@ -37,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.write.WriterCommitMessage
-import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
+import org.apache.spark.sql.execution.{OrderPreservingUnaryExecNode, ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 import org.apache.spark.util.ArrayImplicits._
@@ -138,10 +140,6 @@ object FileFormatWriter extends Logging {
       statsTrackers = statsTrackers
     )

-    // We should first sort by dynamic partition columns, then bucket id, and finally sorting
-    // columns.
-    val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
-      writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
     val writeFilesOpt = V1WritesUtils.getWriteFilesOpt(plan)

     // SPARK-40588: when planned writing is disabled and AQE is enabled,
@@ -153,10 +151,34 @@ object FileFormatWriter extends Logging {
       case p: SparkPlan => p.withNewChildren(p.children.map(materializeAdaptiveSparkPlan))
     }

+    val query = writeFilesOpt.map(_.child).getOrElse(materializeAdaptiveSparkPlan(plan))
+
     // the sort order doesn't matter
-    val actualOrdering = writeFilesOpt.map(_.child)
-      .getOrElse(materializeAdaptiveSparkPlan(plan))
-      .outputOrdering
+    val actualOrdering = query.outputOrdering
+
+    val queryOutput = query match {
+      case o: OrderPreservingUnaryExecNode => o.outputExpressions
+      case _ => query.output
+    }
+
+    @tailrec
+    def isLiteral(e: Expression, name: String): Option[String] =
+      e match {
+        case Alias(child, n) => isLiteral(child, n)
+        case _: Literal => Some(name)
+        case _ => None
+      }
+
+    val literalColumns = queryOutput.flatMap { ne => isLiteral(ne, ne.name) }
+
+    // We should first sort by dynamic partition columns, then bucket id, and finally sorting
+    // columns, then drop literal columns
+    val requiredOrdering = (partitionColumns.drop(numStaticPartitionCols) ++
+      writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns).dropWhile {
+      case attr: Attribute => literalColumns.contains(attr.name)
+      case _ => false
+    }
+
     val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering)

     SQLExecution.checkSQLExecutionId(sparkSession)
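
A self-contained sketch of the new pruning step, with simplified stand-in types (Attr, Lit, and Alias below are illustrative, not the Catalyst classes). Two details worth seeing in isolation: isLiteral walks through nested aliases to find a constant, and the pruning uses dropWhile rather than filter, so only the leading run of literal columns is removed, which suffices because dynamic partition columns head the required ordering.

sealed trait Expr { def name: String = "" }
case class Attr(override val name: String) extends Expr
case class Lit(value: Any) extends Expr
case class Alias(child: Expr, override val name: String) extends Expr

// Mirrors the helper above: unwrap aliases until a literal (or not) is found.
@annotation.tailrec
def isLiteral(e: Expr, name: String): Option[String] = e match {
  case Alias(child, n) => isLiteral(child, n)
  case _: Lit => Some(name)
  case _ => None
}

// Output of: SELECT 1 AS p, a FROM src ORDER BY a
val queryOutput: Seq[Expr] = Seq(Alias(Lit(1), "p"), Attr("a"))
val literalColumns = queryOutput.flatMap(e => isLiteral(e, e.name))

// The required write ordering starts with the dynamic partition column p.
val requiredOrdering: Seq[Expr] = Seq(Attr("p"))
val pruned = requiredOrdering.dropWhile {
  case attr: Attr => literalColumns.contains(attr.name)
  case _ => false
}

assert(literalColumns == Seq("p"))
assert(pruned.isEmpty) // ordering trivially matched; ORDER BY a is preserved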

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 2 additions & 2 deletions
@@ -385,9 +385,9 @@ case class TakeOrderedAndProjectExec(
     }
   }

-  override protected def outputExpressions: Seq[NamedExpression] = projectList
+  override def outputExpressions: Seq[NamedExpression] = projectList

-  override protected def orderingExpressions: Seq[SortOrder] = sortOrder
+  override def orderingExpressions: Seq[SortOrder] = sortOrder

   override def outputPartitioning: Partitioning = SinglePartition
