@@ -35,7 +35,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -69,7 +69,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -343,13 +343,6 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-    executeSQLWithRefs(query).logicalPlan
-  }
-
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
     val aliasIdentifier =
       if (alias.getQualifierCount > 0) {
@@ -2651,6 +2644,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
@@ -2661,7 +2656,8 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
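
The two hunks above change the dispatch shape: transformCommand now also covers SQL_COMMAND and yields an Option[QueryPlanningTracker => LogicalPlan], so the caller creates the tracker and applies the transformer before running the plan. Below is a minimal, self-contained sketch of that pattern; all types here are hypothetical stand-ins, not Spark Connect's real ones:

// Hypothetical stand-ins for Spark's tracker, plan, and command types.
final case class Tracker(sessionId: String)
sealed trait Plan
final case class SqlPlan(query: String, tracker: Tracker) extends Plan

sealed trait Command
final case class SqlCmd(query: String) extends Command
final case class StreamCmd(name: String) extends Command

object CommandDispatch {
  // Mirrors transformCommand: a pure transformation, deferred on the tracker.
  def transformCommand(c: Command): Option[Tracker => Plan] = c match {
    case SqlCmd(q) => Some(tracker => SqlPlan(q, tracker))
    case _ => None
  }

  // Mirrors the caller: it owns tracker creation and runs the result.
  def handle(c: Command): String = transformCommand(c) match {
    case Some(transformer) =>
      val tracker = Tracker("demo-session")
      s"run ${transformer(tracker)}"
    case None =>
      s"side-effecting handler for $c"
  }
}

Currying the transformer over the tracker lets transformCommand stay side-effect free while the executor keeps ownership of the tracker's lifecycle.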
@@ -2675,8 +2671,6 @@ class SparkConnectPlanner(
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2781,12 +2775,8 @@ class SparkConnectPlanner(
         .build())
   }
 
-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2803,19 +2793,47 @@ class SparkConnectPlanner(
            .build())
         .build()
     }
+  }
 
-    val df = relation.getRelTypeCase match {
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
+
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val isSqlScript = command.isInstanceOf[CompoundBody]
+    val refs = if (isSqlScript && protoCommand.getSqlCommand.getInput.hasWithRelations) {
+      protoCommand.getSqlCommand.getInput.getWithRelations.getReferencesList.asScala
+        .map(_.getSubqueryAlias)
+        .toSeq
+    } else {
+      Seq.empty
+    }
+
+    val df = runWithRefs(refs) {
+      if (shuffleCleanupMode.isDefined) {
+        Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+      } else {
+        Dataset.ofRows(session, command, tracker)
+      }
+    }
 
     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
-    val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
     val rows = df.logicalPlan match {
       case lr: LocalRelation => lr.data
       case cr: CommandResult => cr.rows
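
The new runSQLCommand forwards shuffleCleanupMode only when it is set, because the underlying factory exposes separate arities rather than a single Option parameter. A toy model of that Option-to-overload forwarding, with hypothetical names in place of Dataset.ofRows and ShuffleCleanupMode:

object OverloadForwarding {
  sealed trait CleanupMode
  case object RemoveShuffleFiles extends CleanupMode

  final case class Frame(plan: String, mode: Option[CleanupMode])

  // Two factory overloads, standing in for ofRows with and without a mode.
  def ofRows(plan: String): Frame = Frame(plan, None)
  def ofRows(plan: String, mode: CleanupMode): Frame = Frame(plan, Some(mode))

  // Branch on the Option once, as the diff does, instead of threading
  // an Option through every overload.
  def run(plan: String, mode: Option[CleanupMode]): Frame = mode match {
    case Some(m) => ofRows(plan, m)
    case None => ofRows(plan)
  }
}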
@@ -2867,7 +2885,7 @@ class SparkConnectPlanner(
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoCommand.getSqlCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
@@ -2909,59 +2927,79 @@ class SparkConnectPlanner(
       true
     }
 
-  private def executeSQLWithRefs(
-      query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-
-    // Eagerly execute commands of the provided SQL string, with given references.
-    val sql = query.getRoot.getSql
+  private def runWithRefs[T](refs: Seq[proto.SubqueryAlias])(f: => T): T = {
     this.synchronized {
       try {
-        query.getReferencesList.asScala.foreach { ref =>
+        refs.foreach { ref =>
           Dataset
-            .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
-            .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
+            .ofRows(session, transformRelation(ref.getInput))
+            .createOrReplaceTempView(ref.getAlias)
         }
-        executeSQL(sql, tracker)
+        f
       } finally {
         // drop all temporary views
-        query.getReferencesList.asScala.foreach { ref =>
-          session.catalog.dropTempView(ref.getSubqueryAlias.getAlias)
+        refs.foreach { ref =>
+          session.catalog.dropTempView(ref.getAlias)
         }
       }
     }
   }
 
-  private def executeSQL(
-      sql: proto.SQL,
+  private def transformSQLWithRefs(
+      query: proto.WithRelations,
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    if (!isValidSQLWithRefs(query)) {
+      throw InvalidInputErrors.invalidSQLWithReferences(query)
+    }
+
+    transformSQL(
+      query.getRoot.getSql,
+      tracker,
+      runWithRefs(query.getReferencesList.asScala.map(_.getSubqueryAlias).toSeq))
+  }
+
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
       tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
-    // Eagerly execute commands of the provided SQL string.
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker,
+      withAnalysis: (=> LogicalPlan) => LogicalPlan = x => x): LogicalPlan = {
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
     val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
-    if (!namedArguments.isEmpty) {
-      session.sql(
+    val parsedPlan = if (!namedArguments.isEmpty) {
+      session.sqlParsedPlan(
         sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
+    }
+    if (parsedPlan.isInstanceOf[CompoundBody]) {
+      // If the parsed plan is a CompoundBody, skip analysis and return it.
+      // SQL scripting is a special case as execution occurs during the analysis phase.
+      parsedPlan
+    } else {
+      withAnalysis {
+        new QueryExecution(session, parsedPlan, tracker).analyzed
+      }
     }
   }
 
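runWithRefs distills the old executeSQLWithRefs body into a loan pattern: register the temp views, evaluate the by-name block, and always drop the views in the finally clause, all under the planner's lock. Because the block is by-name, the partially applied runWithRefs(refs) also fits transformSQL's withAnalysis hook, whose default x => x is a no-op. A self-contained sketch of the same shape, with a plain mutable map standing in for the session catalog:

import scala.collection.mutable

object TempViewLoan {
  // Stand-in for the session catalog's temp-view registry.
  private val views = mutable.Map.empty[String, String]

  // Loan pattern mirroring runWithRefs: the views exist only while f runs.
  def runWithRefs[T](refs: Seq[(String, String)])(f: => T): T =
    this.synchronized {
      try {
        refs.foreach { case (alias, definition) => views(alias) = definition }
        f // evaluated here, with all views registered
      } finally {
        // Always drop the views, even if f throws.
        refs.foreach { case (alias, _) => views.remove(alias) }
      }
    }

  def main(args: Array[String]): Unit = {
    val seen = runWithRefs(Seq("t1" -> "SELECT 1")) { views.keySet.toList }
    println(seen)          // List(t1)
    println(views.isEmpty) // true: dropped in the finally block
  }
}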
@@ -3157,11 +3195,27 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(command, tracker, responseObserver, protoCommand, shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }
 
   /**
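
runCommand is private[connect] rather than private, so other members of the connect package can reach it while it stays hidden from external callers. A minimal illustration of that access modifier, using a hypothetical package and members:

package org.example.connect

object PlannerInternals {
  // Callable from anywhere under the connect package, including code
  // compiled into the same package, but not from outside it.
  private[connect] def runCommand(plan: String): Unit =
    println(s"executing $plan")
}

// A collaborator in the same package can call it directly.
object ExecutePipeline {
  def run(): Unit = PlannerInternals.runCommand("demo plan")
}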
@@ -4105,7 +4159,7 @@ class SparkConnectPlanner(
 
   private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
     if (isValidSQLWithRefs(getWithRelations)) {
-      transformSqlWithRefs(getWithRelations)
+      executeSQLWithRefs(getWithRelations).logicalPlan
     } else {
       // Wrap the plan to keep the original planId.
       val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))