Commit 67834ee

[SPARK-53148][CONNECT][SQL] Make SqlCommand in SparkConnectPlanner side effect free
1 parent: ea8b6fd

4 files changed: +173 −81 lines

python/pyspark/sql/tests/test_sql.py

Lines changed: 11 additions & 0 deletions
@@ -168,6 +168,17 @@ def test_nested_dataframe(self):
         self.assertEqual(df3.take(1), [Row(id=4)])
         self.assertEqual(df3.tail(1), [Row(id=9)])

+    def test_nested_dataframe_for_sql_scripting(self):
+        with self.sql_conf({"spark.sql.scripting.enabled": True}):
+            df0 = self.spark.range(10)
+            df1 = self.spark.sql(
+                "BEGIN SELECT * FROM {df} WHERE id > 1; END;",
+                df=df0,
+            )
+            self.assertEqual(df1.count(), 8)
+            self.assertEqual(df1.take(1), [Row(id=2)])
+            self.assertEqual(df1.tail(1), [Row(id=9)])
+
     def test_lit_time(self):
         import datetime
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 7 additions & 7 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
+import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -83,13 +83,13 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     val command = request.getPlan.getCommand
     planner.transformCommand(command) match {
       case Some(transformer) =>
-        val qe = new QueryExecution(
-          session,
-          transformer(tracker),
+        val plan = transformer(tracker)
+        planner.runCommand(
+          plan,
           tracker,
-          shuffleCleanupMode = shuffleCleanupMode)
-        qe.assertCommandExecuted()
-        executeHolder.eventsManager.postFinished()
+          responseObserver,
+          command,
+          shuffleCleanupMode = Some(shuffleCleanupMode))
       case None =>
         planner.process(command, responseObserver)
     }
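
This hunk is the core of the refactor: `planner.transformCommand` is now side effect free and only builds a logical plan, while execution and event posting move into `planner.runCommand`. A minimal, self-contained sketch of the same two-phase pattern; every name in it (`Tracker`, `Plan`, `TwoPhaseCommand`, the `sql:` prefix) is an illustrative stand-in, not Spark's actual API:

// Sketch of the transform/run split. Stand-in types only; the real code
// uses QueryPlanningTracker, LogicalPlan, and SparkConnectPlanner.
final case class Tracker(var planningDone: Boolean = false)
sealed trait Plan
final case class SqlPlan(query: String) extends Plan

object TwoPhaseCommand {
  // Phase 1: side effect free. Only returns a plan builder; nothing runs yet.
  def transformCommand(raw: String): Option[Tracker => Plan] =
    if (raw.startsWith("sql:")) Some { tracker =>
      tracker.planningDone = true
      SqlPlan(raw.stripPrefix("sql:"))
    } else None

  // Phase 2: all side effects (execution, event posting) happen here, once.
  def runCommand(plan: Plan, tracker: Tracker): Unit =
    println(s"executing $plan (planningDone=${tracker.planningDone})")

  def process(raw: String): Unit = {
    val tracker = Tracker()
    transformCommand(raw) match {
      case Some(transformer) => runCommand(transformer(tracker), tracker)
      case None => println(s"falling back to a legacy handler for: $raw")
    }
  }
}

Passing the transformer around as a `QueryPlanningTracker => LogicalPlan` function, as the real `transformCommand` does, lets the caller create the tracker once, feed it to planning, and then reuse it for execution.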

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 114 additions & 52 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -69,7 +69,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -343,13 +343,6 @@ class SparkConnectPlanner(
     }
   }

-  private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-    executeSQLWithRefs(query).logicalPlan
-  }
-
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
     val aliasIdentifier =
       if (alias.getQualifierCount > 0) {
@@ -2651,6 +2644,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
@@ -2661,7 +2656,8 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
@@ -2675,8 +2671,6 @@ class SparkConnectPlanner(
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2781,12 +2775,8 @@ class SparkConnectPlanner(
         .build())
   }

-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2803,19 +2793,47 @@ class SparkConnectPlanner(
           .build())
         .build()
     }
+  }

-    val df = relation.getRelTypeCase match {
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
+
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoSQLCommand: proto.SqlCommand,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val isSqlScript = command.isInstanceOf[CompoundBody]
+    val refs = if (isSqlScript && protoSQLCommand.getInput.hasWithRelations) {
+      protoSQLCommand.getInput.getWithRelations.getReferencesList.asScala
+        .map(_.getSubqueryAlias)
+        .toSeq
+    } else {
+      Seq.empty
+    }
+
+    val df = runWithRefs(refs) {
+      if (shuffleCleanupMode.isDefined) {
+        Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+      } else {
+        Dataset.ofRows(session, command, tracker)
+      }
+    }

     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
-    val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
     val rows = df.logicalPlan match {
       case lr: LocalRelation => lr.data
       case cr: CommandResult => cr.rows
@@ -2867,7 +2885,7 @@ class SparkConnectPlanner(
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoSQLCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
// Exactly one SQL Command Result Batch
@@ -2909,59 +2927,82 @@ class SparkConnectPlanner(
       true
     }
   }

-  private def executeSQLWithRefs(
-      query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
+  private def runWithRefs[T](refs: Seq[proto.SubqueryAlias])(f: => T): T = {
+    if (refs.isEmpty) {
+      return f
     }
-
-    // Eagerly execute commands of the provided SQL string, with given references.
-    val sql = query.getRoot.getSql
     this.synchronized {
       try {
-        query.getReferencesList.asScala.foreach { ref =>
+        refs.foreach { ref =>
           Dataset
-            .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
-            .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
+            .ofRows(session, transformRelation(ref.getInput))
+            .createOrReplaceTempView(ref.getAlias)
         }
-        executeSQL(sql, tracker)
+        f
       } finally {
         // drop all temporary views
-        query.getReferencesList.asScala.foreach { ref =>
-          session.catalog.dropTempView(ref.getSubqueryAlias.getAlias)
+        refs.foreach { ref =>
+          session.catalog.dropTempView(ref.getAlias)
         }
       }
     }
   }

-  private def executeSQL(
-      sql: proto.SQL,
+  private def transformSQLWithRefs(
+      query: proto.WithRelations,
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    if (!isValidSQLWithRefs(query)) {
+      throw InvalidInputErrors.invalidSQLWithReferences(query)
+    }
+
+    transformSQL(
+      query.getRoot.getSql,
+      tracker,
+      query.getReferencesList.asScala.map(_.getSubqueryAlias).toSeq)
+  }
+
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
       tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
-    // Eagerly execute commands of the provided SQL string.
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker,
+      refsToAnalyze: Seq[proto.SubqueryAlias] = Seq.empty): LogicalPlan = {
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
     val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
-    if (!namedArguments.isEmpty) {
-      session.sql(
+    val parsedPlan = if (!namedArguments.isEmpty) {
+      session.sqlParsedPlan(
         sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
+    }
+    if (parsedPlan.isInstanceOf[CompoundBody]) {
+      // If the parsed plan is a CompoundBody, skip analysis and return it.
+      // SQL scripting is a special case as execution occurs during the analysis phase.
+      parsedPlan
+    } else {
+      runWithRefs(refsToAnalyze) {
+        new QueryExecution(session, parsedPlan, tracker).analyzed
+      }
     }
   }
@@ -3157,11 +3198,32 @@ class SparkConnectPlanner(
     }
   }

-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(
+        command,
+        tracker,
+        responseObserver,
+        protoCommand.getSqlCommand,
+        shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }

   /**
@@ -4105,7 +4167,7 @@ class SparkConnectPlanner(

   private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
     if (isValidSQLWithRefs(getWithRelations)) {
-      transformSqlWithRefs(getWithRelations)
+      executeSQLWithRefs(getWithRelations).logicalPlan
     } else {
       // Wrap the plan to keep the original planId.
       val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))
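
A note on the `runWithRefs` helper introduced above: it is a loan pattern that registers the referenced plans as temp views, runs the given computation, and always drops the views afterwards, under `synchronized` because temp views are visible session-wide. A self-contained sketch of the same shape, with a plain map standing in for the session catalog (all names here are illustrative):

import scala.collection.mutable

// Sketch of the runWithRefs loan pattern. A mutable map stands in for the
// session's temp-view catalog; the real code registers Datasets as temp views.
object RefScope {
  private val catalog = mutable.Map.empty[String, String]

  def runWithRefs[T](refs: Seq[(String, String)])(f: => T): T = {
    if (refs.isEmpty) {
      return f // nothing to register, no locking needed
    }
    this.synchronized {
      try {
        // Register every reference under its alias before running f.
        refs.foreach { case (alias, plan) => catalog.update(alias, plan) }
        f
      } finally {
        // Always drop the views, even if f throws.
        refs.foreach { case (alias, _) => catalog.remove(alias) }
      }
    }
  }
}

Factoring this out lets both eager execution (`executeSQLWithRefs`) and analysis of parameterized plans (`transformSQL`) share the same register/run/drop discipline.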

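One subtlety in the new `transformSQL`: SQL scripts parse to a `CompoundBody`, which is returned unanalyzed because scripting executes during the analysis phase, while ordinary statements are analyzed immediately with their references registered. A stand-in sketch of that branch, using illustrative types rather than Spark's real plan classes:

// Sketch of transformSQL's script short-circuit. Stand-in plan types only.
object SqlTransformSketch {
  sealed trait Plan
  final case class Parsed(sql: String) extends Plan
  final case class Analyzed(sql: String) extends Plan
  final case class CompoundBody(statements: Seq[String]) extends Plan

  def transform(parsed: Plan): Plan = parsed match {
    // Scripts execute while being analyzed, so defer analysis to the run
    // phase; returning them here keeps the transform step side effect free.
    case script: CompoundBody => script
    // Everything else can be analyzed eagerly without triggering execution.
    case Parsed(sql) => Analyzed(sql)
    case other => other
  }
}

This is also why `runSQLCommand` re-extracts the references for scripts: a script's analysis happens at run time, so its temp views must be registered then, not during transformation.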