diff --git a/paddle/ap/include/paddle/pass/add_pcc_pass.h b/paddle/ap/include/paddle/pass/add_pcc_pass.h new file mode 100644 index 0000000000000..24f8e272a688e --- /dev/null +++ b/paddle/ap/include/paddle/pass/add_pcc_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace pir { + +class PassManager; +class Program; + +} // namespace pir + +namespace ap::paddle { + +void ApplyPccPass( + ::pir::Program* program, + const std::function()>& CreatePassManager, + bool is_train_mode = false); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pass/ap_drr_helper.h b/paddle/ap/include/paddle/pass/ap_drr_helper.h index a251cff82a66b..c23b02a46b4eb 100644 --- a/paddle/ap/include/paddle/pass/ap_drr_helper.h +++ b/paddle/ap/include/paddle/pass/ap_drr_helper.h @@ -19,7 +19,7 @@ #include "paddle/ap/include/drr/value.h" #include "paddle/ap/include/registry/abstract_drr_pass_registry_item.h" -namespace cinn::dialect::ir { +namespace ap::paddle { struct ApDrrHelper { public: @@ -53,4 +53,4 @@ struct ApDrrHelper { mutable ap::drr::DrrInterpreter drr_interpreter_; }; -} // namespace cinn::dialect::ir +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pass/ap_generic_drr_pass.h b/paddle/ap/include/paddle/pass/ap_generic_drr_pass.h index 9a5667c21521b..244dc2ef14c91 100644 --- a/paddle/ap/include/paddle/pass/ap_generic_drr_pass.h +++ b/paddle/ap/include/paddle/pass/ap_generic_drr_pass.h @@ -30,9 +30,7 @@ struct Value; } -namespace cinn { -namespace dialect { -namespace ir { +namespace ap::paddle { std::optional> CreateApGenericAbstractDrrPass( const std::weak_ptr& circlable_ref_list); @@ -50,6 +48,4 @@ std::optional> CreateCustomAccessTopoDrrPass( std::optional steps_limit, const ap::axpr::Value& mut_matched_pattern_as_programs); -} // namespace ir -} // namespace dialect -} // namespace cinn +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pass/ap_kernel_define_helper.h b/paddle/ap/include/paddle/pass/ap_kernel_define_helper.h index 69f4fa40ad6f8..fb2b8fc8af646 100644 --- a/paddle/ap/include/paddle/pass/ap_kernel_define_helper.h +++ b/paddle/ap/include/paddle/pass/ap_kernel_define_helper.h @@ -21,7 +21,7 @@ #include "paddle/ap/include/code_module/code_module.h" #include "paddle/ap/include/paddle/pir_node.h" -namespace cinn::dialect::ir { +namespace ap::paddle { struct ApKernelDefineHelper { std::weak_ptr circlable_ref_list_; @@ -41,4 +41,4 @@ struct ApKernelDefineHelper { const CodeGenCtx& code_gen_ctx); }; -} // namespace cinn::dialect::ir +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pass/ap_registry_helper.h b/paddle/ap/include/paddle/pass/ap_registry_helper.h index b18b3d628e566..8954a92d3eacd 100644 --- a/paddle/ap/include/paddle/pass/ap_registry_helper.h +++ b/paddle/ap/include/paddle/pass/ap_registry_helper.h @@ -17,10 +17,10 @@ #include "paddle/ap/include/adt/adt.h" #include "paddle/ap/include/registry/registry.h" -namespace cinn::dialect::ir { +namespace ap::paddle { struct ApRegistryHelper { ap::adt::Result SingletonRegistry(); }; -} // namespace cinn::dialect::ir +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h b/paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h index 79786f072f1a5..038dfe9a23f45 100644 --- a/paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h +++ b/paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h @@ -30,12 +30,10 @@ struct Value; } -namespace cinn { -namespace dialect { -namespace ir { +namespace ap { +namespace paddle { std::unique_ptr<::pir::Pass> CreateConvertPdFacadeToApFacadePass(); -} // namespace ir -} // namespace dialect -} // namespace cinn +} // namespace paddle +} // namespace ap diff --git a/paddle/ap/include/paddle/pass/fallback_fusion_op_to_phi_pass.h b/paddle/ap/include/paddle/pass/fallback_fusion_op_to_phi_pass.h new file mode 100644 index 0000000000000..1f8576f31dc82 --- /dev/null +++ b/paddle/ap/include/paddle/pass/fallback_fusion_op_to_phi_pass.h @@ -0,0 +1,41 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/pir/include/pass/pass.h" + +namespace ap::memory { + +class CirclableRefListBase; + +} + +namespace ap::axpr { + +struct Value; + +} + +namespace ap { +namespace paddle { + +std::unique_ptr<::pir::Pass> CreateFallbackFusionOpToPhiPass(); + +std::unique_ptr<::pir::Pass> CreateFallbackNestedFusionOpToPhiPass(); + +} // namespace paddle +} // namespace ap diff --git a/paddle/ap/include/paddle/pass/fuse_ap_trivial_pass.h b/paddle/ap/include/paddle/pass/fuse_ap_trivial_pass.h new file mode 100644 index 0000000000000..42a1e253ac53f --- /dev/null +++ b/paddle/ap/include/paddle/pass/fuse_ap_trivial_pass.h @@ -0,0 +1,39 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/pir/include/pass/pass.h" + +namespace ap::memory { + +class CirclableRefListBase; + +} + +namespace ap::axpr { + +struct Value; + +} + +namespace ap { +namespace paddle { + +std::unique_ptr<::pir::Pass> CreateFuseApTrivialPass(); + +} // namespace paddle +} // namespace ap diff --git a/paddle/ap/include/paddle/pass/move_trivial_fusion_range_to_fusion_op_pass.h b/paddle/ap/include/paddle/pass/move_trivial_fusion_range_to_fusion_op_pass.h new file mode 100644 index 0000000000000..b0e2285beba6e --- /dev/null +++ b/paddle/ap/include/paddle/pass/move_trivial_fusion_range_to_fusion_op_pass.h @@ -0,0 +1,39 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/pir/include/pass/pass.h" + +namespace ap::memory { + +class CirclableRefListBase; + +} + +namespace ap::axpr { + +struct Value; + +} + +namespace ap { +namespace paddle { + +std::unique_ptr<::pir::Pass> CreateMoveTrivialFusionRangeToFusionOpPass(); + +} // namespace paddle +} // namespace ap diff --git a/paddle/ap/src/paddle/pass/add_pcc_pass.cc b/paddle/ap/src/paddle/pass/add_pcc_pass.cc new file mode 100644 index 0000000000000..0fc3606adb65b --- /dev/null +++ b/paddle/ap/src/paddle/pass/add_pcc_pass.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pass/add_pcc_pass.h" + +#include +#include "paddle/common/errors.h" +#include "paddle/common/flags.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/shape_analysis_utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/include/core/ir_context.h" +#include "paddle/pir/include/core/program.h" +#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h" +#include "paddle/pir/include/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" +#include "paddle/pir/include/pass/pass_manager.h" + +#include "paddle/ap/include/memory/guard.h" +#include "paddle/ap/include/paddle/pass/ap_generic_drr_pass.h" +#include "paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h" +#include "paddle/ap/include/paddle/pass/fallback_fusion_op_to_phi_pass.h" +#include "paddle/ap/include/paddle/pass/fuse_ap_trivial_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/remove_redundant_full_int_array_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h" +#include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h" +#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h" +#include "paddle/pir/include/core/ir_printer.h" + +namespace ap::paddle { + +void ApplyShapeOptimizationPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + pir::OriginalAttributesFilter::Instance().SetOriginalAttributesMap( + ::paddle::dialect::GetAllOpOriginalAttributes()); + + pass_manager->AddPass(pir::CreateShapeOptimizationPass()); + pass_manager->Run(program); +} + +void ApplyGenerateShapePass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass( + cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass()); + pass_manager->AddPass( + cinn::dialect::ir::CreateMoveGenerateShapeOpsToProloguePass()); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pass_manager->Run(program); +} + +void ApplyApGenericDrrPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + ap::memory::Guard guard{}; + if (auto pass = ap::paddle::CreateApGenericClassicDrrPass( + guard.circlable_ref_list())) { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(std::move(pass.value())); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pass_manager->Run(program); + } + if (auto pass = ap::paddle::CreateApGenericAbstractDrrPass( + guard.circlable_ref_list())) { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(std::move(pass.value())); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pass_manager->Run(program); + } +} + +void ApplyApFacadePass(::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(CreateConvertPdFacadeToApFacadePass()); + pass_manager->Run(program); +} + +void ApplyFuseApTrivialPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(CreateFuseApTrivialPass()); + pass_manager->Run(program); + } + { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(CreateFallbackNestedFusionOpToPhiPass()); + pass_manager->Run(program); + } +} + +void ApplyFallbackToPhiPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(CreateFallbackFusionOpToPhiPass()); + pass_manager->Run(program); + } + { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass( + cinn::dialect::ir::CreateSplitGenerateShapeIntoShapeOpsPass()); + pass_manager->AddPass( + cinn::dialect::ir::CreateRemoveRedundantFullIntArrayPass()); + pass_manager->Run(program); + } +} + +namespace { + +struct FinishLogger { + pir::Program* program_; + int seq_no_{0}; + + void operator()(const std::string& stage_name) { + pir::IrPrinter(LOG(ERROR) + << seq_no_++ << ") after " << stage_name << "():\n") + .PrintProgram(program_); + } +}; + +} // namespace + +void ApplyPccPass( + ::pir::Program* program, + const std::function()>& CreatePassManager, + bool is_train_mode) { + LOG_FIRST_N(INFO, 1) << "Compiling subgraph with PCC backend ..."; + const uint32_t origin_num_ops = program->num_ops(); + if (origin_num_ops == 0) return; + + if (is_train_mode) { + // Skip infer symbol shape in inference, because we have run this pass in + // the previous process + ApplyShapeOptimizationPass(program, CreatePassManager); + } + FinishLogger Logger{program}; + ApplyApFacadePass(program, CreatePassManager); + Logger("ApplyApFacadePass"); + ApplyFuseApTrivialPass(program, CreatePassManager); + Logger("ApplyFuseApTrivialPass"); + ApplyGenerateShapePass(program, CreatePassManager); + Logger("ApplyGenerateShapePass"); + ApplyApGenericDrrPass(program, CreatePassManager); + Logger("ApplyApGenericDrrPass"); + ApplyFallbackToPhiPass(program, CreatePassManager); + Logger("ApplyFallbackToPhiPass"); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/ap_drr_helper.cc b/paddle/ap/src/paddle/pass/ap_drr_helper.cc index f4b596c59cc53..fcb7680159366 100644 --- a/paddle/ap/src/paddle/pass/ap_drr_helper.cc +++ b/paddle/ap/src/paddle/pass/ap_drr_helper.cc @@ -24,7 +24,7 @@ #include "paddle/ap/include/drr/value_method_class.h" #include "paddle/ap/include/paddle/pir/pir_method_class.h" -namespace cinn::dialect::ir { +namespace ap::paddle { namespace adt = ap::adt; @@ -61,4 +61,4 @@ adt::Result ApDrrHelper::Interpret( return drr_interpreter_.InterpretPass(cls); } -} // namespace cinn::dialect::ir +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc index 17aa852358b88..7b12fa33568ff 100644 --- a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc @@ -61,7 +61,7 @@ #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/pass/pass_registry.h" -namespace cinn::dialect::ir { +namespace ap::paddle { namespace adt = ap::adt; @@ -788,7 +788,7 @@ struct ApRewriter { } try { pir::Operation* op = - paddle::drr::OperationFactory::Instance().CreateOperation( + ::paddle::drr::OperationFactory::Instance().CreateOperation( res_ptn_ir_op->op_declare->op_name, inputs, attrs, *rewriter); return op->results(); } catch (const std::exception& e) { @@ -1166,7 +1166,7 @@ struct ApRewriter { const std::string& infer_meta_lambda_str, const std::string& kernel_dispatch_lambda_str, const std::string& kernel_dispatch_const_data_lambda_str) const { - auto ap_variadic = rewriter->Build( + auto ap_variadic = rewriter->Build<::paddle::dialect::ApVariadicOp>( input, num_outputs, code_gen_lambda_str, @@ -3422,4 +3422,4 @@ std::optional> CreateCustomAccessTopoDrrPass( return std::move(pass); } -} // namespace cinn::dialect::ir +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc b/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc index 9d0c68ab83f05..18cd5ace39706 100644 --- a/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc +++ b/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc @@ -23,7 +23,7 @@ #include "paddle/ap/include/paddle/pir_node_descriptor.h" #include "paddle/ap/include/paddle/pir_node_method_class.h" -namespace cinn::dialect::ir { +namespace ap::paddle { namespace adt = ap::adt; @@ -63,4 +63,4 @@ adt::Result ApKernelDefineHelper::Interpret( return m; } -} // namespace cinn::dialect::ir +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/ap_registry_helper.cc b/paddle/ap/src/paddle/pass/ap_registry_helper.cc index 97c28316454c4..04086cda55065 100644 --- a/paddle/ap/src/paddle/pass/ap_registry_helper.cc +++ b/paddle/ap/src/paddle/pass/ap_registry_helper.cc @@ -15,7 +15,7 @@ #include "paddle/ap/include/paddle/pass/ap_registry_helper.h" #include "paddle/ap/include/registry/registry_mgr.h" -namespace cinn::dialect::ir { +namespace ap::paddle { namespace { @@ -31,4 +31,4 @@ ap::adt::Result ApRegistryHelper::SingletonRegistry() { return registry; } -} // namespace cinn::dialect::ir +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/convert_pd_facade_to_ap_facade.cc b/paddle/ap/src/paddle/pass/convert_pd_facade_to_ap_facade.cc index 6046cd76c97b5..dffabc414d8d5 100644 --- a/paddle/ap/src/paddle/pass/convert_pd_facade_to_ap_facade.cc +++ b/paddle/ap/src/paddle/pass/convert_pd_facade_to_ap_facade.cc @@ -29,18 +29,18 @@ #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/pass/pass_registry.h" -namespace cinn::dialect::ir { +namespace ap::paddle { namespace adt = ap::adt; namespace { class ConvertPdFacadeToApFacadePattern - : public pir::OpRewritePattern { + : public pir::OpRewritePattern<::paddle::dialect::ApFacadeOp> { public: - using pir::OpRewritePattern::OpRewritePattern; + using pir::OpRewritePattern<::paddle::dialect::ApFacadeOp>::OpRewritePattern; - bool MatchAndRewrite(paddle::dialect::ApFacadeOp pd_facade_op, + bool MatchAndRewrite(::paddle::dialect::ApFacadeOp pd_facade_op, pir::PatternRewriter& rewriter) const override { const auto& ret = TryMatchAndRewrite(pd_facade_op, &rewriter); PADDLE_ENFORCE_EQ( @@ -56,8 +56,9 @@ class ConvertPdFacadeToApFacadePattern return ret.GetOkValue(); } - adt::Result TryMatchAndRewrite(paddle::dialect::ApFacadeOp pd_facade_op, - pir::PatternRewriter* rewriter) const { + adt::Result TryMatchAndRewrite( + ::paddle::dialect::ApFacadeOp pd_facade_op, + pir::PatternRewriter* rewriter) const { std::vector inputs{}; pir::Operation* upstream_op = nullptr; if (pd_facade_op->operand_source(0)) { @@ -93,7 +94,7 @@ class ConvertPdFacadeToApFacadePattern } adt::Result GetFacadeOpAttributes( - paddle::dialect::ApFacadeOp pd_facade_op) const { + ::paddle::dialect::ApFacadeOp pd_facade_op) const { ADT_LET_CONST_REF(serialized_attributes, GetFacadeOpSerializedAttributes(pd_facade_op)); ADT_LET_CONST_REF(lambda, CastStrToLambda(serialized_attributes)); @@ -102,7 +103,7 @@ class ConvertPdFacadeToApFacadePattern } adt::Result GetFacadeOpSerializedAttributes( - paddle::dialect::ApFacadeOp op) const { + ::paddle::dialect::ApFacadeOp op) const { const auto& iter = op->attributes().find("serialized_attributes"); ADT_CHECK(iter != op->attributes().end()); ADT_CHECK(iter->second.template isa()); @@ -132,7 +133,7 @@ class ConvertPdFacadeToApFacadePattern } adt::Result CastToPirAttributeMap( - paddle::dialect::ApFacadeOp pd_facade_op, + ::paddle::dialect::ApFacadeOp pd_facade_op, const ap::axpr::AttrMap& attr_map, const std::string& serialized_attributes) const { pir::AttributeMap attributes{}; @@ -180,4 +181,4 @@ std::unique_ptr<::pir::Pass> CreateConvertPdFacadeToApFacadePass() { return std::make_unique(); } -} // namespace cinn::dialect::ir +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/fallback_fusion_op_to_phi_pass.cc b/paddle/ap/src/paddle/pass/fallback_fusion_op_to_phi_pass.cc new file mode 100644 index 0000000000000..1ead38c0d39a1 --- /dev/null +++ b/paddle/ap/src/paddle/pass/fallback_fusion_op_to_phi_pass.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pass/fallback_fusion_op_to_phi_pass.h" +#include "paddle/ap/include/paddle/hlir/manual_op.h" + +#include "paddle/ap/include/axpr/abstract_list.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/paddle/builtin_frame_util.h" +#include "paddle/ap/include/paddle/pir/ap_pir_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/utils/general_functions.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/core/ir_mapping.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace ap::paddle { + +namespace adt = ap::adt; + +namespace { + +template +class FallbackFusionOpToPhiPattern + : public pir::OpRewritePattern<::cinn::dialect::FusionOp> { + public: + using pir::OpRewritePattern<::cinn::dialect::FusionOp>::OpRewritePattern; + + bool MatchAndRewrite(::cinn::dialect::FusionOp fusion_op, + pir::PatternRewriter& rewriter) const override { + const auto& ret = TryMatchAndRewrite(fusion_op, &rewriter); + PADDLE_ENFORCE_EQ( + ret.HasError(), + false, + phi::errors::Fatal( + "FallbackFusionOpToPhiPattern::MatchAndRewrite failed. " + "\nTraceback (most recent call " + "last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg())); + return ret.GetOkValue(); + } + + adt::Result TryMatchAndRewrite(::cinn::dialect::FusionOp fusion_op, + pir::PatternRewriter* rewriter) const { + auto* mut_block = fusion_op->GetParent(); + if constexpr (parent_must_be_fusion_op) { + if (!fusion_op->GetParentOp()->isa<::cinn::dialect::FusionOp>()) + return false; + } + pir::IrMapping ir_mapping{}; + for (pir::Value free_value : pir::GetUsedExternalValue(*fusion_op)) { + ir_mapping.Add(free_value, free_value); + } + std::vector yield_inputs{}; + { + yield_inputs.reserve(fusion_op->num_results()); + auto clone_options = pir::CloneOptions(true, true, true); + for (auto& op : *fusion_op.block()) { + if (op.isa()) { + yield_inputs = op.operands_source(); + } else { + rewriter->Insert(op.Clone(ir_mapping, clone_options)); + } + } + } + for (int i = 0; i < fusion_op->num_results(); ++i) { + rewriter->ReplaceAllUsesWith(fusion_op->result(i), + ir_mapping.Lookup(yield_inputs.at(i))); + } + rewriter->EraseOp(fusion_op); + return true; + } +}; + +template +class FallbackFusionOpToPhiPass : public pir::PatternRewritePass { + public: + FallbackFusionOpToPhiPass() + : pir::PatternRewritePass("fallback_fusion_op_to_phi_pass", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add>(context); + return ps; + } +}; + +} // namespace + +std::unique_ptr<::pir::Pass> CreateFallbackFusionOpToPhiPass() { + return std::make_unique< + FallbackFusionOpToPhiPass>(); +} + +std::unique_ptr<::pir::Pass> CreateFallbackNestedFusionOpToPhiPass() { + return std::make_unique< + FallbackFusionOpToPhiPass>(); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/fuse_ap_trivial_pass.cc b/paddle/ap/src/paddle/pass/fuse_ap_trivial_pass.cc new file mode 100644 index 0000000000000..7672d79760c93 --- /dev/null +++ b/paddle/ap/src/paddle/pass/fuse_ap_trivial_pass.cc @@ -0,0 +1,24 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pass/fuse_ap_trivial_pass.h" +#include "paddle/ap/include/paddle/pass/move_trivial_fusion_range_to_fusion_op_pass.h" + +namespace ap::paddle { + +std::unique_ptr<::pir::Pass> CreateFuseApTrivialPass() { + return CreateMoveTrivialFusionRangeToFusionOpPass(); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/ir_helper_method_class.cc b/paddle/ap/src/paddle/pass/ir_helper_method_class.cc index d3715e86b9a82..5160cc62780a0 100644 --- a/paddle/ap/src/paddle/pass/ir_helper_method_class.cc +++ b/paddle/ap/src/paddle/pass/ir_helper_method_class.cc @@ -46,12 +46,12 @@ struct PirHelperMethodClass { if (args.at(0).template CastableTo()) { ADT_LET_CONST_REF(drr_pass_tag_name, args.at(0).template CastTo()); - opt_pass = cinn::dialect::ir::CreateAccessTopoDrrPass( - interpreter->circlable_ref_list(), - drr_pass_tag_name, - /*steps_limit=*/std::nullopt); + opt_pass = + ap::paddle::CreateAccessTopoDrrPass(interpreter->circlable_ref_list(), + drr_pass_tag_name, + /*steps_limit=*/std::nullopt); } else { - opt_pass = cinn::dialect::ir::CreateCustomAccessTopoDrrPass( + opt_pass = ap::paddle::CreateCustomAccessTopoDrrPass( interpreter->circlable_ref_list(), args.at(0), /*steps_limit=*/std::nullopt, @@ -76,17 +76,17 @@ struct PirHelperMethodClass { if (args->at(0).template CastableTo()) { ADT_LET_CONST_REF(drr_pass_tag_name, args->at(0).template CastTo()); - opt_pass = cinn::dialect::ir::CreateAccessTopoDrrPass( - interpreter->circlable_ref_list(), - drr_pass_tag_name, - /*steps_limit=*/1); + opt_pass = + ap::paddle::CreateAccessTopoDrrPass(interpreter->circlable_ref_list(), + drr_pass_tag_name, + /*steps_limit=*/1); } else { std::optional matched_pattern_mut_list{ kwargs->OptGet("matched_pattern_mut_list")}; if (!matched_pattern_mut_list.has_value()) { matched_pattern_mut_list = adt::Nothing{}; } - opt_pass = cinn::dialect::ir::CreateCustomAccessTopoDrrPass( + opt_pass = ap::paddle::CreateCustomAccessTopoDrrPass( interpreter->circlable_ref_list(), args->at(0), /*steps_limit=*/1, @@ -335,10 +335,9 @@ struct PirHelperMethodClass { args.at(1)}; ADT_LET_CONST_REF(lambda, This{}.GetDrrCtxMaker()); axpr::Function function{lambda, std::nullopt}; - ADT_LET_CONST_REF( - drr_ctx, - cinn::dialect::ir::ApDrrHelper{interpreter->circlable_ref_list()} - .InterpretDrrCtxMaker(function, src_ptn_func_args)); + ADT_LET_CONST_REF(drr_ctx, + ap::paddle::ApDrrHelper{interpreter->circlable_ref_list()} + .InterpretDrrCtxMaker(function, src_ptn_func_args)); ADT_CHECK(drr_ctx->source_pattern_ctx.has_value()); ap::paddle::PackedIrOpInnerSourcePatternHelper src_pattern_helper{}; ADT_LET_CONST_REF( diff --git a/paddle/ap/src/paddle/pass/move_trivial_fusion_range_to_fusion_op_pass.cc b/paddle/ap/src/paddle/pass/move_trivial_fusion_range_to_fusion_op_pass.cc new file mode 100644 index 0000000000000..81f0458d7f4a7 --- /dev/null +++ b/paddle/ap/src/paddle/pass/move_trivial_fusion_range_to_fusion_op_pass.cc @@ -0,0 +1,257 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pass/move_trivial_fusion_range_to_fusion_op_pass.h" +#include "paddle/ap/include/paddle/hlir/manual_op.h" + +#include "paddle/ap/include/axpr/abstract_list.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/paddle/builtin_frame_util.h" +#include "paddle/ap/include/paddle/pir/ap_pir_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/utils/general_functions.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/core/ir_mapping.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace ap::paddle { + +namespace adt = ap::adt; + +namespace { + +class MoveTrivialFusionRangeToFusionOpPattern + : public pir::OpRewritePattern<::paddle::dialect::ApTrivialFusionEndOp> { + public: + using pir::OpRewritePattern< + ::paddle::dialect::ApTrivialFusionEndOp>::OpRewritePattern; + + bool MatchAndRewrite( + ::paddle::dialect::ApTrivialFusionEndOp ap_trivial_fusion_end_op, + pir::PatternRewriter& rewriter) const override { + const auto& ret = TryMatchAndRewrite(ap_trivial_fusion_end_op, &rewriter); + PADDLE_ENFORCE_EQ( + ret.HasError(), + false, + phi::errors::Fatal( + "MoveTrivialFusionRangeToFusionOpPattern::MatchAndRewrite failed. " + "\nTraceback (most recent call " + "last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg())); + return ret.GetOkValue(); + } + + adt::Result TryMatchAndRewrite( + ::paddle::dialect::ApTrivialFusionEndOp ap_trivial_fusion_end_op, + pir::PatternRewriter* rewriter) const { + rewriter->SetInsertionPointAfter(ap_trivial_fusion_end_op); + ADT_LET_CONST_REF(ap_trivial_fusion_begin_op, + GetApTrivialFusionBeginOp(ap_trivial_fusion_end_op)); + ADT_LET_CONST_REF( + old_outputs, + GetUsedOutputs(ap_trivial_fusion_begin_op, ap_trivial_fusion_end_op)); + auto fusion_op = rewriter->Build<::cinn::dialect::FusionOp>([&] { + std::vector output_types{}; + output_types.reserve(old_outputs.size()); + for (pir::Value output : old_outputs) { + output_types.emplace_back(output.type()); + } + return output_types; + }()); + pir::IrMapping ir_mapping{}; + { + ADT_LET_CONST_REF(external_inputs, + GetExternalInputs(ap_trivial_fusion_begin_op, + ap_trivial_fusion_end_op)); + for (pir::Value value : external_inputs) { + ir_mapping.Add(value, value); + } + } + std::list reversed_old_ops; + { + auto clone_options = pir::CloneOptions(true, true, true); + for (auto iter = + ++ap_trivial_fusion_begin_op->operator pir::Block::Iterator(); + iter != ap_trivial_fusion_end_op->operator pir::Block::Iterator(); + ++iter) { + fusion_op.block()->push_back(iter->Clone(ir_mapping, clone_options)); + reversed_old_ops.push_front(iter); + } + } + { + std::vector yield_inputs{}; + yield_inputs.reserve(fusion_op->num_results()); + for (int i = 0; i < fusion_op->num_results(); ++i) { + yield_inputs.push_back(ir_mapping.Lookup(old_outputs.at(i))); + } + pir::Builder builder{pir::IrContext::Instance(), fusion_op.block()}; + builder.Build(yield_inputs); + } + for (int i = 0; i < fusion_op->num_results(); ++i) { + rewriter->ReplaceAllUsesWith(old_outputs.at(i), fusion_op->result(i)); + } + { + rewriter->EraseOp(ap_trivial_fusion_end_op); + for (auto* old_op : reversed_old_ops) { + rewriter->EraseOp(old_op); + } + rewriter->EraseOp(ap_trivial_fusion_begin_op); + } + return true; + } + + adt::Result<::paddle::dialect::ApTrivialFusionBeginOp> + GetApTrivialFusionBeginOp( + ::paddle::dialect::ApTrivialFusionEndOp ap_trivial_fusion_end_op) const { + auto* block = ap_trivial_fusion_end_op->GetParent(); + auto ap_trivial_fusion_end_iter = + ap_trivial_fusion_end_op->operator pir::Block::Iterator(); + ADT_CHECK(ap_trivial_fusion_end_iter != block->begin()); + std::size_t depth = 1; + auto iter = ap_trivial_fusion_end_iter; + do { + --iter; + if (iter->isa<::paddle::dialect::ApTrivialFusionEndOp>()) { + ++depth; + } else if (iter->isa<::paddle::dialect::ApTrivialFusionBeginOp>()) { + --depth; + if (depth == 0) + return iter->dyn_cast<::paddle::dialect::ApTrivialFusionBeginOp>(); + } else { + // Do nothing. + } + } while (iter != block->begin()); + return adt::errors::NotImplementedError{ + "no pd_op.ap_trivial_fusion_begin matched."}; + } + + adt::Result> GetExternalInputs( + ::paddle::dialect::ApTrivialFusionBeginOp ap_trivial_fusion_begin_op, + ::paddle::dialect::ApTrivialFusionEndOp ap_trivial_fusion_end_op) const { + using IterT = pir::Block::Iterator; + ADT_LET_CONST_REF( + all_inputs, + GetInputsInRange( + ++ap_trivial_fusion_begin_op->operator IterT(), + ap_trivial_fusion_end_op->operator pir::Block::Iterator())); + ADT_LET_CONST_REF( + all_outputs, + GetOutputsInRange( + ++ap_trivial_fusion_begin_op->operator IterT(), + ap_trivial_fusion_end_op->operator pir::Block::Iterator())); + std::unordered_set ret; + for (pir::Value input : all_inputs) { + if (std::find(all_outputs.begin(), all_outputs.end(), input) == + all_outputs.end()) { + ret.insert(input); + } + } + return ret; + } + + adt::Result> GetUsedOutputs( + ::paddle::dialect::ApTrivialFusionBeginOp ap_trivial_fusion_begin_op, + ::paddle::dialect::ApTrivialFusionEndOp ap_trivial_fusion_end_op) const { + using IterT = pir::Block::Iterator; + ADT_LET_CONST_REF( + tmp_and_outputs, + GetOutputsInRange( + (++ap_trivial_fusion_begin_op->operator IterT()), + ap_trivial_fusion_end_op->operator pir::Block::Iterator())); + ADT_LET_CONST_REF(inputs_after_end, + GetInputsAfter(ap_trivial_fusion_end_op)); + std::vector used_outputs; + used_outputs.reserve(tmp_and_outputs.size()); + for (pir::Value value : tmp_and_outputs) { + if (inputs_after_end.count(value)) { + used_outputs.push_back(value); + } + } + return used_outputs; + } + + adt::Result> GetOutputsBefore( + pir::Operation* op) const { + auto* block = op->GetParent(); + pir::Block::Iterator end = op->operator pir::Block::Iterator(); + return GetOutputsInRange(block->begin(), end); + } + + adt::Result> GetOutputsInRange( + pir::Block::Iterator begin, pir::Block::Iterator end) const { + std::vector outputs; + for (auto iter = begin; iter != end; ++iter) { + for (int i = 0; i < iter->num_results(); ++i) { + outputs.push_back(iter->result(i)); + } + } + return outputs; + } + + adt::Result> GetInputsAfter( + pir::Operation* op) const { + auto* block = op->GetParent(); + using IterT = pir::Block::Iterator; + auto pos = op->operator IterT(); + return GetInputsInRange((++pos), block->end()); + } + + adt::Result> GetInputsInRange( + pir::Block::Iterator begin, pir::Block::Iterator end) const { + std::unordered_set inputs; + auto TryInsertInput = [&](pir::Value input) { + if (!input) return; + inputs.insert(input); + }; + for (auto iter = begin; iter != end; ++iter) { + for (int i = 0; i < iter->num_operands(); ++i) { + TryInsertInput(iter->operand_source(i)); + } + for (pir::Value free_value : pir::GetUsedExternalValue(*iter)) { + TryInsertInput(free_value); + } + } + return inputs; + } +}; + +class MoveTrivialFusionRangeToFusionOpPass : public pir::PatternRewritePass { + public: + MoveTrivialFusionRangeToFusionOpPass() + : pir::PatternRewritePass("move_trivial_fusion_range_to_if_block", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + return ps; + } +}; + +} // namespace + +std::unique_ptr<::pir::Pass> CreateMoveTrivialFusionRangeToFusionOpPass() { + return std::make_unique(); +} + +} // namespace ap::paddle diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index d86ef00658580..b08325ce5c72e 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -232,15 +232,15 @@ void ApplyApGenericDrrPass( pir::IrPrinter(LOG(ERROR) << "before ConvertPdFacadeToApFacadePass:\n") .PrintProgram(program); std::shared_ptr pass_manager = CreatePassManager(); - pass_manager->AddPass(CreateConvertPdFacadeToApFacadePass()); + pass_manager->AddPass(ap::paddle::CreateConvertPdFacadeToApFacadePass()); pass_manager->Run(program); pir::IrPrinter(LOG(ERROR) << "after ConvertPdFacadeToApFacadePass:\n") .PrintProgram(program); } ap::memory::Guard guard{}; - if (auto pass = CreateApGenericClassicDrrPass(guard.circlable_ref_list())) { + if (auto pass = ap::paddle::CreateApGenericClassicDrrPass( + guard.circlable_ref_list())) { std::shared_ptr pass_manager = CreatePassManager(); - pass_manager->AddPass(CreateConvertPdFacadeToApFacadePass()); pass_manager->AddPass(std::move(pass.value())); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); pir::IrPrinter(LOG(ERROR) << "before ApGenericClassicDrrPass:\n") @@ -249,7 +249,8 @@ void ApplyApGenericDrrPass( pir::IrPrinter(LOG(ERROR) << "after ApGenericClassicDrrPass:\n") .PrintProgram(program); } - if (auto pass = CreateApGenericAbstractDrrPass(guard.circlable_ref_list())) { + if (auto pass = ap::paddle::CreateApGenericAbstractDrrPass( + guard.circlable_ref_list())) { std::shared_ptr pass_manager = CreatePassManager(); pass_manager->AddPass(std::move(pass.value())); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 9b29cdadcabe5..6883715d5bfbc 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -72,6 +72,8 @@ 'AddNInferMeta', 'ApVariadicInferMeta', 'ApFacadeInferMeta', + 'ApTrivialFusionBeginInferMeta', + 'ApTrivialFusionEndInferMeta', 'AddNTensorArrayInferMeta', 'AttentionLstmInferMeta', 'AucInferMeta', diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.cc index ce4702ec34306..e47b8f6f37d13 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.cc @@ -13,18 +13,44 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h" - +#ifdef PADDLE_WITH_CINN +#include "paddle/ap/include/paddle/pir/infer_symbolic_shape_util.h" +#endif #include "paddle/common/ddim.h" #include "paddle/common/enforce.h" #include "paddle/common/layout.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +namespace paddle::dialect { + +bool ApTrivialFusionBeginOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { #ifdef PADDLE_WITH_CINN -#include "paddle/ap/include/paddle/pir/infer_symbolic_shape_util.h" + symbol::ShapeOrDataDimExprs empty_shape{ + symbol::TensorShapeOrDataDimExprs{std::vector{}}}; + infer_context->SetShapeOrDataForValue(op->result(0), empty_shape); + return true; +#else + PADDLE_THROW(phi::errors::Unimplemented( + "ap_trivial_fusion_begin is not implemented when cinn is not enabled.")); + return false; #endif +} -namespace paddle::dialect { +bool ApTrivialFusionEndOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { +#ifdef PADDLE_WITH_CINN + symbol::ShapeOrDataDimExprs empty_shape{ + symbol::TensorShapeOrDataDimExprs{std::vector{}}}; + infer_context->SetShapeOrDataForValue(op->result(0), empty_shape); + return true; +#else + PADDLE_THROW(phi::errors::Unimplemented( + "ap_trivial_fusion_end is not implemented when cinn is not enabled.")); + return false; +#endif +} bool ApFacadeOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h index 65f76357ede43..e0645046e6d21 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h @@ -18,6 +18,8 @@ namespace paddle::dialect { +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ApTrivialFusionBegin) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ApTrivialFusionEnd) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ApFacade) } // namespace paddle::dialect diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index aa75119ba556a..54eecc360adc9 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -91,6 +91,7 @@ #ifdef PADDLE_WITH_CINN #include "paddle/ap/include/paddle/hlir/op_dialect.h" +#include "paddle/ap/include/paddle/pass/add_pcc_pass.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_util.h" @@ -2775,6 +2776,32 @@ void ApplyCinnPass(Program &program) { // NOLINT #endif } +void ApplyPccPass(Program &program) { // NOLINT +#ifdef PADDLE_WITH_CINN + auto CreatePassManager = [&]() -> std::shared_ptr { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto pass_manager = std::make_shared(ctx); + if (FLAGS_print_ir && VLOG_IS_ON(4)) { + pass_manager->EnableIRPrinting(); + } + auto &shape_analysis = pir::ShapeAnalysisManager::Instance().Get(&program); + pass_manager->SetValueReplacedHook([&](pir::Value from, pir::Value to) { + shape_analysis.ShareShapeOrData(from, to); + }); + return pass_manager; + }; + ap::paddle::ApplyPccPass(&program, CreatePassManager); +#else + PADDLE_THROW(common::errors::Unimplemented( + "Currently we only support CINN Pass for Pir under @to_static, please " + "compile PaddlePaddle with CINN")); +#endif +} + void CheckInferSymbolicIfNeed(Program &program) { // NOLINT #ifdef PADDLE_WITH_CINN auto CreatePassManager = [&]() -> std::shared_ptr { @@ -2848,6 +2875,7 @@ std::shared_ptr ApplyFusedBnAddActPass( void BindIrPass(pybind11::module *m) { m->def("apply_cinn_pass", ApplyCinnPass); + m->def("apply_pcc_pass", ApplyPccPass); m->def("check_infer_symbolic_if_need", CheckInferSymbolicIfNeed); m->def("infer_symbolic_shape_pass", InferSymbolicShapePass); m->def("apply_cse_pass", ApplyCommonSubexpressionEliminationPass); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 0fb3a776edcb5..830c7a44acf0f 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -523,6 +523,22 @@ void ApFacadeInferMeta( #endif } +void ApTrivialFusionBeginInferMeta( + const paddle::optional>& xs, + MetaTensor* out, + MetaConfig config) { + out->set_dims(common::make_ddim({})); + out->set_dtype(phi::DataType::BOOL); +} + +void ApTrivialFusionEndInferMeta( + const paddle::optional>& xs, + MetaTensor* out, + MetaConfig config) { + out->set_dims(common::make_ddim({})); + out->set_dtype(phi::DataType::BOOL); +} + // TODO(YuanRisheng) This InferMeta is used in Fluid // and will be deleted in the future. void AddNTensorArrayInferMeta(const std::vector& x, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 435bc8f5ece1a..84f1d668ce0a7 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -140,6 +140,16 @@ void AddNInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void ApTrivialFusionBeginInferMeta( + const paddle::optional>& xs, + MetaTensor* out, + MetaConfig config = MetaConfig()); + +void ApTrivialFusionEndInferMeta( + const paddle::optional>& xs, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void ApFacadeInferMeta( const paddle::optional>& xs, int64_t num_outputs, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 19b9f35bbf5a2..ac91d0092c3de 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -242,9 +242,11 @@ endif() # Remove AP kernel when CINN is not enabled. if(NOT WITH_CINN) list(REMOVE_ITEM kernel_cu "gpu/ap_facade_kernel.cu" - "gpu/ap_variadic_kernel.cu") + "gpu/ap_trivial_fusion_begin_kernel.cu" + "gpu/ap_trivial_fusion_end_kernel.cu" "gpu/ap_variadic_kernel.cu") list(REMOVE_ITEM kernel_gpu "gpu/ap_facade_kernel.cu" - "gpu/ap_variadic_kernel.cu") + "gpu/ap_trivial_fusion_begin_kernel.cu" + "gpu/ap_trivial_fusion_end_kernel.cu" "gpu/ap_variadic_kernel.cu") endif() set(cc_search_pattern diff --git a/paddle/phi/kernels/gpu/ap_trivial_fusion_begin_kernel.cu b/paddle/phi/kernels/gpu/ap_trivial_fusion_begin_kernel.cu new file mode 100644 index 0000000000000..98f22de0fab2b --- /dev/null +++ b/paddle/phi/kernels/gpu/ap_trivial_fusion_begin_kernel.cu @@ -0,0 +1,44 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/common/enforce.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ApTrivialFusionBeginKernel( + const Context& dev_ctx, + const paddle::optional>& xs, + DenseTensor* out) { + PADDLE_THROW(common::errors::Unimplemented( + "pd_op.ap_trivial_fusion_begin has no kernel registered.")); +} + +} // namespace phi + +PD_REGISTER_KERNEL(ap_trivial_fusion_begin, + GPU, + ALL_LAYOUT, + phi::ApTrivialFusionBeginKernel, + float, + double, + int, + phi::dtype::bfloat16, + phi::dtype::float16, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/ap_trivial_fusion_end_kernel.cu b/paddle/phi/kernels/gpu/ap_trivial_fusion_end_kernel.cu new file mode 100644 index 0000000000000..9fb985f40f2a6 --- /dev/null +++ b/paddle/phi/kernels/gpu/ap_trivial_fusion_end_kernel.cu @@ -0,0 +1,44 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/common/enforce.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ApTrivialFusionEndKernel( + const Context& dev_ctx, + const paddle::optional>& xs, + DenseTensor* out) { + PADDLE_THROW(common::errors::Unimplemented( + "pd_op.ap_trivial_fusion_end has no kernel registered.")); +} + +} // namespace phi + +PD_REGISTER_KERNEL(ap_trivial_fusion_end, + GPU, + ALL_LAYOUT, + phi::ApTrivialFusionEndKernel, + float, + double, + int, + phi::dtype::bfloat16, + phi::dtype::float16, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 7b4044b5aa1c8..7c142cb4419cd 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -289,6 +289,28 @@ func : ap_facade traits : paddle::dialect::ForwardOnlyTrait +- op : ap_trivial_fusion_begin + args : (Tensor[] xs) + output : Tensor(out) + optional : xs + infer_meta : + func : ApTrivialFusionBeginInferMeta + interfaces : paddle::dialect::InferSymbolicShapeInterface + kernel : + func : ap_trivial_fusion_begin + traits : pir::SideEffectTrait, paddle::dialect::ForwardOnlyTrait + +- op : ap_trivial_fusion_end + args : (Tensor[] xs) + output : Tensor(out) + optional : xs + infer_meta : + func : ApTrivialFusionEndInferMeta + interfaces : paddle::dialect::InferSymbolicShapeInterface + kernel : + func : ap_trivial_fusion_end + traits : pir::SideEffectTrait, paddle::dialect::ForwardOnlyTrait + - op : ap_variadic args : (Tensor[] xs, int num_outputs, str code_module_lambda, str infer_meta_lambda, str rnel_dispatch_lambda, str kernel_dispatch_const_data_lambda) output : Tensor[](out){num_outputs} diff --git a/python/paddle/incubate/cc/compiler.py b/python/paddle/incubate/cc/compiler.py index e0b835fbae019..51b40bc4297bf 100644 --- a/python/paddle/incubate/cc/compiler.py +++ b/python/paddle/incubate/cc/compiler.py @@ -69,17 +69,20 @@ def _compile( ap_workspace_dir='/tmp/paddle/ap', backend_device='cuda', target_framework='paddle', + compile_engine='PCC', ): assert ap_path is not None + assert not train, "only support inference now" os.makedirs(ap_workspace_dir, exist_ok=True) build_strategy = paddle.static.BuildStrategy() + assert compile_engine in ('CINN', 'PCC') with _ap_envs(ap_path, ap_workspace_dir): static_fn = paddle.jit.to_static( func, input_spec=input_specs, build_strategy=build_strategy, full_graph=True, - backend='CINN', + backend=compile_engine, ) if not train: static_fn.eval() @@ -92,7 +95,10 @@ def _compile( ) partial_program_layer.training = static_fn._is_train_mode() # Force to generate the program immediately. - _ = partial_program_layer.train_program.forward_program + if train: + _ = partial_program_layer.train_program.forward_program + else: + _ = partial_program_layer.infer_program.forward_program return partial_program_layer diff --git a/python/paddle/incubate/cc/fuse.py b/python/paddle/incubate/cc/fuse.py index 3a17419e5273f..f6bbd55f11735 100644 --- a/python/paddle/incubate/cc/fuse.py +++ b/python/paddle/incubate/cc/fuse.py @@ -12,11 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager + import paddle -__all__ = ['matmul'] +__all__ = ['matmul', 'by_register'] + + +@contextmanager +def by_register(): + paddle._C_ops.ap_trivial_fusion_begin(None) + yield + paddle._C_ops.ap_trivial_fusion_end(None) def matmul(x, w, epilogue, **kwargs): x = paddle.matmul(x, w, **kwargs) - return epilogue(x) + with by_register(): + return epilogue(x) diff --git a/python/paddle/jit/api.py b/python/paddle/jit/api.py index 580bb2b529e88..05dff0fde5bc0 100644 --- a/python/paddle/jit/api.py +++ b/python/paddle/jit/api.py @@ -284,11 +284,12 @@ def to_static( f"Required type(build_strategy) shall be `paddle.static.BuildStrategy`, but received {type(build_strategy).__name__}" ) backend = Backend.from_arg(backend) - backend = ( - Backend.CINN - if infer_use_cinn_backend(backend, build_strategy) - else Backend.PHI - ) + if infer_use_cinn_backend(backend, build_strategy): + backend = Backend.CINN + elif backend.is_pcc(): + pass + else: + backend = Backend.PHI def decorated(python_func): """ diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index 04db9542c272f..8d2b954dfea14 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -819,19 +819,28 @@ def _create_program(self, is_infer_mode=False) -> RunnableProgram: if is_infer_mode: def pass_fn(forward_program, backward_program, program_name_attr): - apply_general_passes( - forward_program, - enable_cse=cse_is_enabled(), - enable_delete_assert_op=self._backend.is_cinn(), - ) # if-else pass if self._backend.is_cinn(): + apply_general_passes( + forward_program, + enable_cse=cse_is_enabled(), + enable_delete_assert_op=self._backend.is_cinn(), + ) paddle.base.libpaddle.pir.bind_symbolic_constraints( forward_program, self._constraints ) paddle.base.libpaddle.pir.apply_cinn_pass(forward_program) - + elif self._backend.is_pcc(): + paddle.base.libpaddle.pir.bind_symbolic_constraints( + forward_program, self._constraints + ) + paddle.base.libpaddle.pir.apply_pcc_pass(forward_program) else: + apply_general_passes( + forward_program, + enable_cse=cse_is_enabled(), + enable_delete_assert_op=self._backend.is_cinn(), + ) paddle.base.libpaddle.pir.check_infer_symbolic_if_need( forward_program ) @@ -933,7 +942,11 @@ def get_kwargs_forward_matched_value(kw_name, kw_value): forward_program, backward_program ) paddle.base.libpaddle.pir.apply_cinn_pass(backward_program) - + elif self._backend.is_pcc(): + paddle.base.libpaddle.pir.bind_symbolic_constraints( + forward_program, self._constraints + ) + paddle.base.libpaddle.pir.apply_pcc_pass(forward_program) else: paddle.base.libpaddle.pir.check_infer_symbolic_if_need( forward_program diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index ce9ce2df3e069..52adc8eb9d8a6 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -89,6 +89,7 @@ class Backend(Enum): CINN = auto() PHI = auto() + PCC = auto() @staticmethod def from_arg(arg: str | Backend | None): @@ -98,6 +99,8 @@ def from_arg(arg: str | Backend | None): return Backend.PHI if arg.upper() == "CINN": return Backend.CINN + if arg.upper() == "PCC": + return Backend.PCC raise ValueError( f"Unknown backend {arg}. Only support 'CINN' or None for PHI." ) @@ -105,6 +108,9 @@ def from_arg(arg: str | Backend | None): def is_cinn(self): return self == Backend.CINN + def is_pcc(self): + return self == Backend.PCC + def is_phi(self): return self == Backend.PHI