Skip to content

Commit 486d259

Browse files
authored
[CINN] Add cinn process in auto_parallel (#69362)
* refine * fix * refine code
1 parent 6d939df commit 486d259

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

python/paddle/base/executor.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,21 @@ def _get_pir_program_and_executor(self, cached_data):
11521152
place = cached_data.place
11531153
scope = cached_data.scope
11541154

1155+
def cinn_process(program):
1156+
from paddle.decomposition import decomp
1157+
1158+
if core._enable_dist_prim_all():
1159+
logging.info("apply decompose in executor")
1160+
with decomp.prim_guard():
1161+
decomp.decompose_dist_program(program)
1162+
1163+
if core._enable_auto_recompute():
1164+
logging.info("apply auto_recompute in executor")
1165+
program = decomp.auto_recompute_pir_program(program, None)
1166+
1167+
apply_cinn_pass(program)
1168+
return program
1169+
11551170
if cached_data.plan is None:
11561171
value_map = pir.IrMapping()
11571172
_, is_startup_program = has_fetch_operations_and_is_startup_program(
@@ -1174,6 +1189,10 @@ def _get_pir_program_and_executor(self, cached_data):
11741189
fetch_var_name=fetch_var_name,
11751190
)
11761191
default_job = core.Job("default")
1192+
1193+
if not is_startup_program and in_cinn_mode():
1194+
cinn_process(program)
1195+
11771196
type_to_program = {"default": program}
11781197
plan = core.Plan([default_job], type_to_program)
11791198
else:
@@ -1200,6 +1219,11 @@ def _get_pir_program_and_executor(self, cached_data):
12001219
value.block.program, value, fetch_var_name + str(i), i
12011220
)
12021221

1222+
if in_cinn_mode():
1223+
for job_type in plan.job_types():
1224+
ir_program = plan.ir_program(job_type)
1225+
cinn_process(ir_program)
1226+
12031227
new_exe = _StandaloneExecutor(place, plan, scope)
12041228

12051229
data_op_infos = []
@@ -1216,18 +1240,7 @@ def _get_pir_program_and_executor(self, cached_data):
12161240
op.result(0).persistable,
12171241
)
12181242
data_op_infos.append(tup)
1219-
from paddle.decomposition import decomp
1220-
1221-
if core._enable_dist_prim_all():
1222-
with decomp.prim_guard():
1223-
decomp.decompose_dist_program(program)
12241243

1225-
if core._enable_auto_recompute():
1226-
logging.info("apply auto_recompute in executor")
1227-
program = decomp.auto_recompute_pir_program(program, None)
1228-
1229-
if in_cinn_mode():
1230-
apply_cinn_pass(program)
12311244
return program, new_exe, data_op_infos
12321245

12331246

python/paddle/distributed/auto_parallel/static/engine.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from paddle import pir, static, utils
3030
from paddle.base.executor import _to_name_str
3131
from paddle.base.framework import auto_complete_op_role
32+
from paddle.decomposition import decomp
3233
from paddle.distributed.fleet.meta_optimizers.common import OpRole
3334
from paddle.distributed.passes.pass_base import new_pass
3435
from paddle.distributed.passes.pass_utils import (
@@ -897,20 +898,15 @@ def _parallel_pir(self, mode):
897898
remove_unuseful_comm_op_pass(dense_program)
898899

899900
if core._enable_dist_prim_all():
900-
from paddle.decomposition import decomp
901-
901+
logging.info("apply decompose in auto parallel")
902902
with decomp.prim_guard():
903903
decomp.decompose_dist_program(dense_program)
904904

905905
if core._enable_auto_recompute():
906-
from paddle.decomposition import decomp
907-
908906
logging.info("apply auto_recompute in auto parallel")
909907
dense_program = decomp.auto_recompute_pir_program(
910908
dense_program,
911-
lambda op: bool(
912-
op.has_attr('op_role') and op.attrs()["op_role"] == 0
913-
),
909+
lambda op: bool(op.has_attr('op_role') and op.op_role == 0),
914910
)
915911

916912
if self._strategy.pipeline.enable:

0 commit comments

Comments
 (0)