@@ -35,7 +35,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -69,7 +69,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -343,13 +343,6 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-    executeSQLWithRefs(query).logicalPlan
-  }
-
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
     val aliasIdentifier =
       if (alias.getQualifierCount > 0) {
@@ -2651,6 +2644,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
@@ -2661,7 +2656,8 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
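The hunk above changes the contract of `transformCommand` to return an optional `QueryPlanningTracker => LogicalPlan` transformer, so the caller now owns tracker creation and applies the transformer itself. A rough, self-contained sketch of that shape follows; `Tracker`, `Plan`, and `Command` are illustrative stand-ins for `QueryPlanningTracker`, `LogicalPlan`, and `proto.Command`, not the real Spark classes:

```scala
// Sketch of the curried-transformer dispatch; all types here are stand-ins.
object DispatchSketch {
  final case class Tracker(id: Long)
  sealed trait Plan
  final case class SqlPlan(query: String) extends Plan
  sealed trait Command
  final case class SqlCommand(query: String) extends Command
  final case class OtherCommand(name: String) extends Command

  // Like the new transformCommand: translation is deferred until a tracker exists.
  def transformCommand(cmd: Command): Option[Tracker => Plan] = cmd match {
    case SqlCommand(q) => Some(_ => SqlPlan(q))
    case _ => None // commands with dedicated side-effecting handlers fall through
  }

  // Like the caller in this hunk: create the tracker, then apply the transformer.
  def process(cmd: Command): Unit = transformCommand(cmd) match {
    case Some(transformer) =>
      val tracker = Tracker(System.nanoTime())
      println(s"running ${transformer(tracker)} with $tracker")
    case None =>
      println(s"dedicated handler for $cmd")
  }

  def main(args: Array[String]): Unit = {
    process(SqlCommand("SELECT 1"))
    process(OtherCommand("createDataFrameView"))
  }
}
```

Keeping the translation behind a function means no planning work happens before the tracker that should observe it exists.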
@@ -2675,8 +2671,6 @@ class SparkConnectPlanner(
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2781,12 +2775,8 @@ class SparkConnectPlanner(
        .build())
   }
 
-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
      command.getInput
    } else {
      // for backward compatibility
@@ -2803,19 +2793,47 @@ class SparkConnectPlanner(
            .build())
        .build()
    }
+  }
 
-    val df = relation.getRelTypeCase match {
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
+
+    relation.getRelTypeCase match {
      case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
      case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
      case other =>
        throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
    }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoSQLCommand: proto.SqlCommand,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val isSqlScript = command.isInstanceOf[CompoundBody]
+    val refs = if (isSqlScript && protoSQLCommand.getInput.hasWithRelations) {
+      protoSQLCommand.getInput.getWithRelations.getReferencesList.asScala
+        .map(_.getSubqueryAlias)
+        .toSeq
+    } else {
+      Seq.empty
+    }
+
+    val df = runWithRefs(refs) {
+      if (shuffleCleanupMode.isDefined) {
+        Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+      } else {
+        Dataset.ofRows(session, command, tracker)
+      }
+    }
 
    // Check if command or SQL Script has been executed.
    val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
-    val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
    val rows = df.logicalPlan match {
      case lr: LocalRelation => lr.data
      case cr: CommandResult => cr.rows
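In the new `runSQLCommand`, references are re-read from the proto only when the plan is a SQL script (`CompoundBody`): a script executes during analysis, so its referenced temp views must be registered before execution starts, whereas plain SQL has its refs resolved earlier in `transformSQL`. A standalone sketch of that extraction follows; the case classes are hypothetical stand-ins for the generated protobuf messages, not the real API:

```scala
// Hypothetical stand-ins for the proto messages; the real code reads
// protoSQLCommand.getInput.getWithRelations.getReferencesList.
object RefsSketch {
  final case class SubqueryAlias(alias: String, input: String)
  final case class WithRelations(root: String, references: List[SubqueryAlias])
  final case class SqlInput(withRelations: Option[WithRelations])

  // Refs are pre-registered only for SQL scripts: a script runs during
  // analysis, so its referenced views must already exist at that point.
  def refsFor(isSqlScript: Boolean, input: SqlInput): Seq[SubqueryAlias] =
    if (isSqlScript) input.withRelations.map(_.references).getOrElse(Nil)
    else Seq.empty // plain SQL resolves refs inside transformSQL instead

  def main(args: Array[String]): Unit = {
    val input = SqlInput(Some(WithRelations(
      root = "BEGIN SELECT * FROM t; END",
      references = List(SubqueryAlias("t", "<client-side relation>")))))
    println(refsFor(isSqlScript = true, input))  // List(SubqueryAlias(t,...))
    println(refsFor(isSqlScript = false, input)) // List()
  }
}
```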
@@ -2867,7 +2885,7 @@ class SparkConnectPlanner(
    } else {
      // No execution triggered for relations. Manually set ready
      tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoSQLCommand))
    }
    executeHolder.eventsManager.postFinished(Some(rows.size))
    // Exactly one SQL Command Result Batch
@@ -2909,59 +2927,83 @@ class SparkConnectPlanner(
    true
  }
 
-  private def executeSQLWithRefs(
-      query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
+  private def runWithRefs[T](refs: Seq[proto.SubqueryAlias])(f: => T): T = {
+    if (refs.isEmpty) {
+      return f
    }
-
-    // Eagerly execute commands of the provided SQL string, with given references.
-    val sql = query.getRoot.getSql
    this.synchronized {
      try {
-        query.getReferencesList.asScala.foreach { ref =>
+        refs.foreach { ref =>
          Dataset
-            .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
-            .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
+            .ofRows(session, transformRelation(ref.getInput))
+            .createOrReplaceTempView(ref.getAlias)
        }
-        executeSQL(sql, tracker)
+        f
      } finally {
        // drop all temporary views
-        query.getReferencesList.asScala.foreach { ref =>
-          session.catalog.dropTempView(ref.getSubqueryAlias.getAlias)
+        refs.foreach { ref =>
+          session.catalog.dropTempView(ref.getAlias)
        }
      }
    }
  }
 
-  private def executeSQL(
-      sql: proto.SQL,
+  private def transformSQLWithRefs(
+      query: proto.WithRelations,
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    if (!isValidSQLWithRefs(query)) {
+      throw InvalidInputErrors.invalidSQLWithReferences(query)
+    }
+
+    transformSQL(
+      query.getRoot.getSql,
+      tracker,
+      query.getReferencesList.asScala.map(_.getSubqueryAlias).toSeq)
+  }
+
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
    // Eagerly execute commands of the provided SQL string.
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker,
+      refsToAnalyze: Seq[proto.SubqueryAlias] = Seq.empty): LogicalPlan = {
    val args = sql.getArgsMap
    val namedArguments = sql.getNamedArgumentsMap
    val posArgs = sql.getPosArgsList
    val posArguments = sql.getPosArgumentsList
-    if (!namedArguments.isEmpty) {
-      session.sql(
+    val parsedPlan = if (!namedArguments.isEmpty) {
+      session.sqlParsedPlan(
        sql.getQuery,
        namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
        tracker)
    } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
        sql.getQuery,
        posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
        tracker)
    } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
        sql.getQuery,
        args.asScala.toMap.transform((_, v) => transformLiteral(v)),
        tracker)
    } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
    } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
+    }
+    if (parsedPlan.isInstanceOf[CompoundBody]) {
+      // If the parsed plan is a CompoundBody, skip analysis and return it.
+      // SQL scripting is a special case as execution occurs during the analysis phase.
+      parsedPlan
+    } else {
+      runWithRefs(refsToAnalyze) {
+        new QueryExecution(session, parsedPlan, tracker).analyzed
+      }
    }
  }
 
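The new `runWithRefs` is a loan pattern: register each reference as a temp view, run the thunk, and always drop the views again, with an early return when there is nothing to register. A minimal self-contained sketch, where a mutable map stands in for the session catalog and the `synchronized` block is elided:

```scala
// Loan-pattern sketch of runWithRefs; tempViews stands in for the catalog.
object RunWithRefsSketch {
  private val tempViews = scala.collection.mutable.Map.empty[String, String]

  def runWithRefs[T](refs: Seq[(String, String)])(f: => T): T = {
    if (refs.isEmpty) return f // fast path, mirrors the diff
    try {
      refs.foreach { case (alias, data) => tempViews(alias) = data } // register views
      f
    } finally {
      refs.foreach { case (alias, _) => tempViews.remove(alias) } // drop all temp views
    }
  }

  def main(args: Array[String]): Unit = {
    val out = runWithRefs(Seq("t" -> "rows")) { s"analyzed with ${tempViews.size} view(s)" }
    println(out)               // analyzed with 1 view(s)
    println(tempViews.isEmpty) // true: cleanup runs on success and failure alike
  }
}
```

Factoring the view lifecycle out of `executeSQLWithRefs` is what lets `transformSQL` reuse it around analysis only, while `runSQLCommand` reuses it around full execution.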
@@ -3157,11 +3199,32 @@ class SparkConnectPlanner(
    }
  }
 
-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(
+        command,
+        tracker,
+        responseObserver,
+        protoCommand.getSqlCommand,
+        shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
  }
 
  /**
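The replacement `runCommand` splits on the command type: a `SQL_COMMAND` goes through `runSQLCommand` so a `SqlCommandResult` can be streamed back to the observer, while everything else executes eagerly via `QueryExecution`, with `shuffleCleanupMode` forwarded only when the caller supplied one. A toy sketch of that split; `Execution` and `CleanupMode` are stand-ins for `QueryExecution` and `ShuffleCleanupMode`, not Spark's API:

```scala
// Toy sketch of runCommand's two-way split; all types are stand-ins.
object RunCommandSketch {
  sealed trait CleanupMode
  case object RemoveShuffleFiles extends CleanupMode

  final case class Execution(plan: String, cleanup: Option[CleanupMode]) {
    def assertCommandExecuted(): Unit = println(s"executed $plan, cleanup=$cleanup")
  }

  def runCommand(
      plan: String,
      isSqlCommand: Boolean,
      cleanup: Option[CleanupMode] = None): Unit = {
    if (isSqlCommand) {
      // SQL commands take the result-streaming path (runSQLCommand in the diff).
      println(s"streaming SqlCommandResult for $plan")
    } else {
      // Everything else executes eagerly, like qe.assertCommandExecuted().
      Execution(plan, cleanup).assertCommandExecuted()
    }
  }

  def main(args: Array[String]): Unit = {
    runCommand("WriteOperation", isSqlCommand = false)
    runCommand("SELECT 1", isSqlCommand = true, cleanup = Some(RemoveShuffleFiles))
  }
}
```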
@@ -4105,7 +4168,7 @@ class SparkConnectPlanner(
 
  private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
    if (isValidSQLWithRefs(getWithRelations)) {
-      transformSqlWithRefs(getWithRelations)
+      executeSQLWithRefs(getWithRelations).logicalPlan
    } else {
      // Wrap the plan to keep the original planId.
      val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))