@@ -1152,6 +1152,21 @@ def _get_pir_program_and_executor(self, cached_data):
1152
1152
place = cached_data .place
1153
1153
scope = cached_data .scope
1154
1154
1155
+ def cinn_process (program ):
1156
+ from paddle .decomposition import decomp
1157
+
1158
+ if core ._enable_dist_prim_all ():
1159
+ logging .info ("apply decompose in executor" )
1160
+ with decomp .prim_guard ():
1161
+ decomp .decompose_dist_program (program )
1162
+
1163
+ if core ._enable_auto_recompute ():
1164
+ logging .info ("apply auto_recompute in executor" )
1165
+ program = decomp .auto_recompute_pir_program (program , None )
1166
+
1167
+ apply_cinn_pass (program )
1168
+ return program
1169
+
1155
1170
if cached_data .plan is None :
1156
1171
value_map = pir .IrMapping ()
1157
1172
_ , is_startup_program = has_fetch_operations_and_is_startup_program (
@@ -1174,6 +1189,10 @@ def _get_pir_program_and_executor(self, cached_data):
1174
1189
fetch_var_name = fetch_var_name ,
1175
1190
)
1176
1191
default_job = core .Job ("default" )
1192
+
1193
+ if not is_startup_program and in_cinn_mode ():
1194
+ cinn_process (program )
1195
+
1177
1196
type_to_program = {"default" : program }
1178
1197
plan = core .Plan ([default_job ], type_to_program )
1179
1198
else :
@@ -1200,6 +1219,11 @@ def _get_pir_program_and_executor(self, cached_data):
1200
1219
value .block .program , value , fetch_var_name + str (i ), i
1201
1220
)
1202
1221
1222
+ if in_cinn_mode ():
1223
+ for job_type in plan .job_types ():
1224
+ ir_program = plan .ir_program (job_type )
1225
+ cinn_process (ir_program )
1226
+
1203
1227
new_exe = _StandaloneExecutor (place , plan , scope )
1204
1228
1205
1229
data_op_infos = []
@@ -1216,18 +1240,7 @@ def _get_pir_program_and_executor(self, cached_data):
1216
1240
op .result (0 ).persistable ,
1217
1241
)
1218
1242
data_op_infos .append (tup )
1219
- from paddle .decomposition import decomp
1220
-
1221
- if core ._enable_dist_prim_all ():
1222
- with decomp .prim_guard ():
1223
- decomp .decompose_dist_program (program )
1224
1243
1225
- if core ._enable_auto_recompute ():
1226
- logging .info ("apply auto_recompute in executor" )
1227
- program = decomp .auto_recompute_pir_program (program , None )
1228
-
1229
- if in_cinn_mode ():
1230
- apply_cinn_pass (program )
1231
1244
return program , new_exe , data_op_infos
1232
1245
1233
1246
0 commit comments