【CINN】Enable AutoLayoutPass flag in train process #71891

Merged · 13 commits · Apr 15, 2025
18 changes: 16 additions & 2 deletions paddle/common/flags.cc
@@ -1395,15 +1395,29 @@ PHI_DEFINE_EXPORTED_bool(
 * Performance related FLAG
 * Name: enable_auto_layout_pass
 * Since Version: 3.0.0
- * Value Range: bool, default=false
+ * Value Range: bool, default=true
 * Example:
 * Note: If True, using AutoLayoutInsertPass and AutuLayoutSimplifyPass by
 * default
 */
PHI_DEFINE_EXPORTED_bool(enable_auto_layout_pass,
-                         false,
+                         true,
                         "Whether enable auto_layout_pass.");
+
+/**
+ * Performance related FLAG
+ * Name: enable_auto_layout_pass_in_inference
+ * Since Version: 3.0.0
+ * Value Range: bool, default=false
+ * Example:
+ * Note: This is a temporary flag, When enabled by default in the inference
+ * process, this flag will be removed and enabled or disabled by the
+ * `enable_auto_layout_pass` flag.
+ */
+PHI_DEFINE_EXPORTED_bool(enable_auto_layout_pass_in_inference,
+                         false,
+                         "Whether enable auto_layout_pass_in_inference.");

/**
 * JitLayer related FLAG
 * Name: FLAGS_jit_engine_type
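For reference, both flags are exported, so they can be inspected or overridden from Python. A minimal sketch using paddle.get_flags / paddle.set_flags (flag names and defaults are taken from the diff above; assumes a build that includes this change):

import paddle

# Read the current values of the two layout-pass flags.
print(
    paddle.get_flags(
        [
            "FLAGS_enable_auto_layout_pass",
            "FLAGS_enable_auto_layout_pass_in_inference",
        ]
    )
)

# Opt out of the pass for training, or opt in for inference, before building programs.
paddle.set_flags({"FLAGS_enable_auto_layout_pass": False})
paddle.set_flags({"FLAGS_enable_auto_layout_pass_in_inference": True})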
6 changes: 3 additions & 3 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -141,7 +141,7 @@
#include "paddle/pir/include/pass/pass_registry.h"

COMMON_DECLARE_bool(pir_apply_inplace_pass);
-COMMON_DECLARE_bool(enable_auto_layout_pass);
+COMMON_DECLARE_bool(enable_auto_layout_pass_in_inference);
namespace paddle {
namespace {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -944,7 +944,7 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {

  if (config_.enable_gpu_mixed_) {
    AddAutoMixedPrecisionPass(fused_op_pm);
-    if (FLAGS_enable_auto_layout_pass) {
+    if (FLAGS_enable_auto_layout_pass_in_inference) {
      AddAutoLayoutPasses(fused_op_pm);
    } else {
      fused_op_pm.AddPass(
@@ -1104,7 +1104,7 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
      AddAutoMixedPrecisionPass(basic_pass_pm);
    }
  }
-  if (FLAGS_enable_auto_layout_pass) {
+  if (FLAGS_enable_auto_layout_pass_in_inference) {
    AddAutoLayoutPasses(basic_pass_pm);
  } else {
    auto transfer_layout_pass = ::pir::CreateTransferLayoutPass();
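The inference branches above only take the AddAutoLayoutPasses path when the new flag is switched on. A hedged sketch of opting in from the Python inference API (model and parameter file names are placeholders; the assumption is that the flag is read when the predictor optimizes its PIR program, as in the hunks above):

import paddle
from paddle.inference import Config, create_predictor

paddle.set_flags({"FLAGS_enable_auto_layout_pass_in_inference": True})

config = Config("inference.json", "inference.pdiparams")  # placeholder model files
config.enable_use_gpu(256, 0)  # the auto-layout branch sits on the GPU path
predictor = create_predictor(config)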
63 changes: 35 additions & 28 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -38,6 +38,7 @@
    auto_layout_is_enabled,
    backend_guard,
    cse_is_enabled,
+    train_guards,
)

if TYPE_CHECKING:
@@ -720,20 +721,23 @@ def __call__(self, inputs):
        out_vars = self._prepare_outputs()
        attrs = self._prepare_attributes(in_sot_mode=False)
        inputs = self._valid_vars(in_vars)
-        _C_ops.run_program(
-            inputs,
-            self._valid_vars(self._params),
-            self._valid_vars(out_vars),
-            self._create_scope_vec(
-                cache_key=(
-                    hash_with_seed(
-                        self.program_id, self._calc_input_places_hash(inputs)
-                    )
-                ),
-                use_scope_cache=True,
-            ),
-            *attrs,
-        )
+
+        with train_guards(self._backend):
+            _C_ops.run_program(
+                inputs,
+                self._valid_vars(self._params),
+                self._valid_vars(out_vars),
+                self._create_scope_vec(
+                    cache_key=(
+                        hash_with_seed(
+                            self.program_id,
+                            self._calc_input_places_hash(inputs),
+                        )
+                    ),
+                    use_scope_cache=True,
+                ),
+                *attrs,
+            )
        restored_nest_out = self._restore_out(out_vars)
        return self._remove_no_value(restored_nest_out)

@@ -744,20 +748,23 @@ def sot_call(self, inputs):
        out_vars = self._prepare_outputs()
        attrs = self._prepare_attributes(in_sot_mode=True)
        inputs = self._valid_vars(inputs)
-        _C_ops.run_program(
-            inputs,
-            self._valid_vars(self._params),
-            self._valid_vars(out_vars),
-            self._create_scope_vec(
-                cache_key=(
-                    hash_with_seed(
-                        self.program_id, self._calc_input_places_hash(inputs)
-                    )
-                ),
-                use_scope_cache=True,
-            ),
-            *attrs,
-        )
+
+        with train_guards(self._backend):
+            _C_ops.run_program(
+                inputs,
+                self._valid_vars(self._params),
+                self._valid_vars(out_vars),
+                self._create_scope_vec(
+                    cache_key=(
+                        hash_with_seed(
+                            self.program_id,
+                            self._calc_input_places_hash(inputs),
+                        )
+                    ),
+                    use_scope_cache=True,
+                ),
+                *attrs,
+            )
        return self._outputs.quick_restore(out_vars)

    @cached_property
@@ -820,7 +827,7 @@ def pass_fn(forward_program, backward_program, program_name_attr):

        # TODO(xiongkun) who to transfer the pruning program?
        infer_program = self.origin_runnable_program.clone()
-        if auto_layout_is_enabled():
+        if auto_layout_is_enabled() and self._backend.is_cinn():
            pm = paddle.pir.PassManager(2)
            pm.add_pass("auto_layout_pass", {})
            pm.run(infer_program.program)
@@ -835,7 +842,7 @@ def pass_fn(forward_program, backward_program, program_name_attr):
        train_program.apply_dist_pass_for_origin_program()

        # Author(liujinnan): auto_layout_pass should be applied to the original_program, before append backward. So we put it here.
-        if auto_layout_is_enabled():
+        if auto_layout_is_enabled() and self._backend.is_cinn():
            pm = paddle.pir.PassManager(2)
            pm.add_pass("auto_layout_pass", {})
            pm.run(train_program.program)
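The same pass can also be applied by hand with the PIR pass manager, mirroring the calls in the hunks above (a sketch only; `program` stands for a PIR program you already hold, and the gating follows this PR's flag-plus-CINN condition):

import paddle

def maybe_apply_auto_layout(program, backend_is_cinn):
    # Mirrors the gating added in this file: run the pass only when the flag is
    # on and the backend is CINN.
    enabled = paddle.get_flags(["FLAGS_enable_auto_layout_pass"])[
        "FLAGS_enable_auto_layout_pass"
    ]
    if enabled and backend_is_cinn:
        pm = paddle.pir.PassManager(2)  # opt level 2, as in the diff
        pm.add_pass("auto_layout_pass", {})  # pass name taken from the diff
        pm.run(program)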
30 changes: 30 additions & 0 deletions python/paddle/jit/dy2static/utils.py
@@ -745,6 +745,36 @@ def is_api_in_module_helper(obj, module_prefix):
    return m is not None and m.__name__.startswith(module_prefix)


+def auto_layout_guard(backend, guard_creators):

Review comment (Member): Why is this called a "guard"? Perhaps it should be named add_auto_layout_guard.
Author reply: Thanks, this will be fixed in the next PR.

+    # AutoLayoutPass may change layout of bn to NHWC, if not enable `FLAGS_cudnn_batchnorm_spatial_persistent`, it will revert to NCHW. So if the user does not set this Flag, we set it to True.
+    if (
+        auto_layout_is_enabled()
+        and backend.is_cinn()
+        and paddle.is_compiled_with_cuda()
+        and os.getenv("FLAGS_cudnn_batchnorm_spatial_persistent") is None
+    ):
+        guard_creators.append(
+            lambda: paddle.base.framework.flag_guard(
+                "FLAGS_cudnn_batchnorm_spatial_persistent",
+                True,
+            )
+        )
+
+
+@contextmanager
+def train_guards(backend):

Review comment (Member): run_program_op runs in both train and eval mode, so "train_guards" may not be the right name; perhaps it should be called runtime_guard.
Author reply: Thanks, this will be fixed in the next PR.

"""
train_guards is the guard method before program execution, which can integrate and add various guards.
"""
guard_creators = []
# Add FLAGS_cudnn_batchnorm_spatial_persistent guard
auto_layout_guard(backend, guard_creators)
# add more guards here

with compose_guards(*guard_creators)():
yield


def auto_layout_is_enabled():
    return paddle.get_flags(["FLAGS_enable_auto_layout_pass"])[
        "FLAGS_enable_auto_layout_pass"
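compose_guards itself is not shown in this diff; as an illustration of the pattern train_guards builds on, here is a generic sketch that stacks a list of zero-argument guard creators with contextlib.ExitStack (everything below is a stand-alone example, not Paddle's implementation):

from contextlib import ExitStack, contextmanager

@contextmanager
def compose(*guard_creators):
    # Enter each guard in order; ExitStack unwinds them in reverse on exit.
    with ExitStack() as stack:
        for creator in guard_creators:
            stack.enter_context(creator())
        yield

@contextmanager
def dummy_guard(name):
    # Stand-in for a real guard such as a flag_guard over
    # FLAGS_cudnn_batchnorm_spatial_persistent in the code above.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

with compose(lambda: dummy_guard("a"), lambda: dummy_guard("b")):
    pass  # both guards are active inside this block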