
Commit aba7ffd

Fix
2 parents: 7de71de + 98becdc


86 files changed: +1258 -1225 lines changed


paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc

-15
@@ -56,7 +56,6 @@
 #include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
 #include "paddle/fluid/pir/transforms/build_cinn_pass.h"
 #include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
-#include "paddle/fluid/pir/transforms/general/identity_op_clean_pass.h"
 #include "paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.h"

 COMMON_DECLARE_bool(cinn_specify_input_dynamic_dim);
@@ -95,19 +94,6 @@ bool HasDynamicShape(const pir::Program& program) {
 }
 }  // namespace

-void ApplyIdentityOpCleanPass(
-    ::pir::Program* program,
-    const std::function<std::shared_ptr<::pir::PassManager>()>&
-        CreatePassManager) {
-  // NOTE(gongshaotian): Before Paddle 3.0, useless full op and scale op are
-  // inserted at the end of the Program when export models using Paddle. This
-  // Pass is designed to address the above-mentioned issues encountered when
-  // open CINN execution in some models that cannot be reexported.
-  std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
-  pass_manager->AddPass(pir::CreateIdentityOpCleanPass());
-  pass_manager->Run(program);
-}
-
 void ApplyShapeOptimizationPass(
     ::pir::Program* program,
     const std::function<std::shared_ptr<::pir::PassManager>()>&
@@ -278,7 +264,6 @@ void ApplyCinnPass(::pir::Program* program,
       .file_name("original_programs.py")
      .dump_symbolic_shape(FLAGS_logging_pir_py_code_dump_symbolic_dims)
      .SaveIfFlagEnabled();
-  ApplyIdentityOpCleanPass(program, CreatePassManager);
  ApplyShapeOptimizationPass(program, CreatePassManager);
  ApplyPdToCinnPass(program, CreatePassManager);
  ApplyCinnPreprocessPass(program, CreatePassManager);

paddle/cinn/hlir/framework/pir/trivial_op_impl.cc

+5-3
@@ -36,7 +36,7 @@
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
 #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

-PD_DECLARE_bool(group_schedule_tiling_first);
+PD_DECLARE_bool(cinn_enable_grid_reduce);

 namespace cinn {
 namespace hlir {
@@ -808,8 +808,10 @@ std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
     });
   }

-  group_info->can_apply_grid_reduce =
-      GetCanApplyGridReduce(op_compute_bodies, group_info->reduce_axis);
+  if (FLAGS_cinn_enable_grid_reduce) {
+    group_info->can_apply_grid_reduce =
+        GetCanApplyGridReduce(op_compute_bodies, group_info->reduce_axis);
+  }

   VLOG(4) << group_info->DebugPrint();
   return group_info;
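
Note (not part of the diff): the hunk above makes the grid-reduce analysis opt-in via the new flag. A minimal standalone sketch of that flag-gating pattern, with hypothetical stand-ins for FusionGroupInfo, GetCanApplyGridReduce, and the gflag (the real code declares it via PD_DECLARE_bool):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for the CINN types referenced in the hunk above.
struct FusionGroupInfo {
  bool can_apply_grid_reduce = false;  // keeps its default when the flag is off
  std::vector<int64_t> reduce_axis;
};

// Simplified stand-in; the real GetCanApplyGridReduce walks op compute bodies.
bool GetCanApplyGridReduce(const std::vector<int64_t>& reduce_axis) {
  return !reduce_axis.empty();
}

// Stand-in for the gflag declared via PD_DECLARE_bool(cinn_enable_grid_reduce).
bool FLAGS_cinn_enable_grid_reduce = true;

int main() {
  FusionGroupInfo group_info;
  group_info.reduce_axis = {1};

  // Same shape as the diff: the analysis only runs when the flag is enabled.
  if (FLAGS_cinn_enable_grid_reduce) {
    group_info.can_apply_grid_reduce =
        GetCanApplyGridReduce(group_info.reduce_axis);
  }
  std::cout << std::boolalpha << group_info.can_apply_grid_reduce << "\n";
  return 0;
}

With the flag off, can_apply_grid_reduce simply keeps its default and the downstream tiling never takes the grid-reduce path.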

paddle/cinn/hlir/framework/pir/trivial_op_util.cc

+33-20
@@ -471,10 +471,9 @@ ExprTransformer UnsqueezeForTransformer(
       schedule_block->body =
           WrapForTransformer(to_append_var)(schedule_block->body);
     } else {
-      PADDLE_THROW(
+      PADDLE_THROW(::common::errors::PreconditionNotMet(
           "UnsqueezeForTransformer: only support insert after a (For / "
-          "ScheduleBlockRealizer): %s",
-          followed_expr);
+          "ScheduleBlockRealizer)"));
     }
     VLOG(6) << "UnsqueezeForTransformer: After changed: " << copied_e;
     return copied_e;
@@ -547,19 +546,24 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
   VLOG(4) << "Start RemoveVarInScheduleBlockRealize(" << target_vars << ", "
           << replaced_expr << ")";
   VLOG(4) << " Input is " << e;
-  PADDLE_ENFORCE(e.As<ir::ScheduleBlockRealize>() != nullptr,
-                 "RemoveVarInScheduleBlockRealize: input expr is not a "
-                 "ScheduleBlockRealize.");
+  PADDLE_ENFORCE_NE(
+      e.As<ir::ScheduleBlockRealize>(),
+      nullptr,
+      ::common::errors::InvalidArgument(
+          "RemoveVarInScheduleBlockRealize: input expr is not a "
+          "ScheduleBlockRealize."));
   auto copied_ir = ir::ir_utils::IRCopy(e);
   auto schedule_block_iter_vars =
       copied_ir.As<ir::ScheduleBlockRealize>()->iter_values;
   auto block_bound_vars = copied_ir.As<ir::ScheduleBlockRealize>()
                               ->schedule_block.As<ir::ScheduleBlock>()
                               ->iter_vars;
   for (const auto& i_var : schedule_block_iter_vars) {
-    PADDLE_ENFORCE(
+    PADDLE_ENFORCE_EQ(
         i_var.is_var(),
-        "RemoveVarInScheduleBlockRealize: axes.bind rhs is is not a Var.");
+        true,
+        ::common::errors::InvalidArgument("RemoveVarInScheduleBlockRealize: "
+                                          "axes.bind rhs is is not a Var."));
   }
   // find replace idx
   int target_idx = -1;
@@ -686,10 +690,11 @@ ExprTransformer RemoveOneTransformer(int one) {
       VLOG(4) << "RemoveOneTransformer: father block is root realize";
       ir::Expr shedule_block =
           target_block.As<ir::ScheduleBlockRealize>()->schedule_block;
-      PADDLE_ENFORCE_EQ(shedule_block.As<ir::ScheduleBlock>()->body,
-                        target_for,
-                        ::common::errors::PreconditionNotMet(
-                            "Root realize body should be equal to target for"));
+      PADDLE_ENFORCE_EQ(
+          shedule_block.As<ir::ScheduleBlock>()->body,
+          target_for,
+          ::common::errors::InvalidArgument(
+              "Root realize body should be equal to target for."));
       const auto for_body = target_for.As<ir::For>()->body;
       const auto for_body_stmts = for_body.As<ir::Block>()->stmts;
       if (for_body_stmts.size() == 1 &&
@@ -747,12 +752,17 @@ ExprTransformer RemoveOnesTransformer(const std::vector<int32_t>& ones) {
 ExprTransformer TransposeForsTransformer(const std::vector<int32_t>& perm) {
   const auto& f = [=](const ir::Expr& root) -> ir::Expr {
     const auto& iters = GetNonReduceLoopVars(root);
-    PADDLE_ENFORCE_EQ(iters.size(),
-                      perm.size(),
-                      "Transposed iters size and perm size should be equal.");
+    PADDLE_ENFORCE_EQ(
+        iters.size(),
+        perm.size(),
+        ::common::errors::InvalidArgument(
+            "Transposed iters size and perm size should be equal."));
     for (size_t i = 0; i < perm.size(); ++i) {
       if (iters[i]->is_reduce_axis) {
-        PADDLE_ENFORCE_EQ(i, perm[i], "Can only transpose non reduce iters.");
+        PADDLE_ENFORCE_EQ(i,
+                          perm[i],
+                          ::common::errors::InvalidArgument(
+                              "Can only transpose non reduce iters."));
       }
     }
     const auto transposed_iters = cinn::fusion::TransposeVector(iters, perm);
@@ -773,7 +783,7 @@ ExprTransformer InsertForsTransformer(const std::vector<int32_t>& axis,
         axis.size(),
         vars.size(),
         ::common::errors::InvalidArgument(
-            "The number of axis to insert and vars should be equal"));
+            "The number of axis to insert and vars should be equal."));
     const size_t reduce_size =
         std::count_if(iters.begin(), iters.end(), [](const ir::Var& v) {
           return v->is_reduce_axis;
@@ -782,7 +792,7 @@ ExprTransformer InsertForsTransformer(const std::vector<int32_t>& axis,
       PADDLE_ENFORCE_LE(axis[i],
                         iters.size() - reduce_size,
                         ::common::errors::OutOfRange(
-                            "Insert axis should not be behind reduce axis"));
+                            "Insert axis should not be behind reduce axis."));
       iters.insert(iters.begin() + axis[i], vars[i]);
     }
     const auto non_reduce_iters =
@@ -837,7 +847,7 @@ int InplaceMutateSingleExpr(ir::Expr* root,
   PADDLE_ENFORCE_EQ(
       source.size(),
       1,
-      ::common::errors::InvalidArgument("Only one expr should be found"));
+      ::common::errors::InvalidArgument("Only one expr should be found."));
   const auto& target = transformer(source[0]);
   ComposeUtils::MappingTargetExprToDestExprMutator(source[0], target)(root);
   return 1;  // operation number.
@@ -880,7 +890,10 @@ void CheckFusionInputValid(const std::vector<ir::Expr>& op_compute_bodies,
   VLOG(4) << " op_patterns.size() = " << op_compute_bodies.size();
   VLOG(4) << "op_compute_bodies.size() = " << op_patterns.size();
   PADDLE_ENFORCE_EQ(
-      op_patterns.size(), op_compute_bodies.size(), "ops and size not equal");
+      op_patterns.size(),
+      op_compute_bodies.size(),
+      ::common::errors::InvalidArgument(
+          "The number of op_compute_bodies and op_patterns should be equal."));
 }

 std::vector<ir::Var> AppendBound(const std::vector<ir::Var> vars,
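
Note (not part of the diff): the recurring change in this file replaces bare PADDLE_ENFORCE(cond, msg) calls with the comparison variants (PADDLE_ENFORCE_EQ / _NE / _LE) plus a typed ::common::errors payload. A rough standalone analogue of why the typed, operand-aware form is preferred; the macro and error type below are hypothetical simplifications, not Paddle's actual implementation:

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical analogue of ::common::errors::InvalidArgument.
struct InvalidArgument {
  std::string msg;
  explicit InvalidArgument(std::string m) : msg(std::move(m)) {}
};

// Hypothetical analogue of PADDLE_ENFORCE_EQ: compare two values and, on
// failure, throw with the typed message plus both operand values.
#define ENFORCE_EQ(lhs, rhs, err)                                    \
  do {                                                               \
    if (!((lhs) == (rhs))) {                                         \
      std::ostringstream os;                                         \
      os << (err).msg << " (" #lhs " = " << (lhs) << ", " #rhs " = " \
         << (rhs) << ")";                                            \
      throw std::invalid_argument(os.str());                         \
    }                                                                \
  } while (0)

int main() {
  std::size_t iters = 3, perm = 4;
  try {
    ENFORCE_EQ(iters, perm,
               InvalidArgument(
                   "Transposed iters size and perm size should be equal."));
  } catch (const std::invalid_argument& e) {
    std::cerr << e.what() << "\n";  // message also reports what was compared
  }
  return 0;
}

The payload carries an error category and a message, so a failed check reports both operands and a classified error rather than just a free-form string.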

paddle/cinn/ir/group_schedule/config/group_tile_util.cc

+1-6
@@ -34,7 +34,6 @@ bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
   // A tensor is downstream of reduce either if it is produced by a reduce, or
   // if it has data dependency on another tensor that is downstream of reduce.
   std::unordered_set<std::string> reduce_downstream_tensor_names;
-  int reduce_count = 0;

   const auto IsReduceDownstream = [&](const ir::Expr& expr_block) {
     for (auto& expr_load : ChildTensorLoads(expr_block)) {
@@ -90,9 +89,6 @@ bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
     bool is_reduce_downstream = IsReduceDownstream(expr_block);
     bool output_has_reduce_axis = CheckOutputHasReduceAxis(body, expr_block);

-    if (is_reduce) {
-      ++reduce_count;
-    }
     if (is_reduce_downstream || is_reduce) {
       AddReduceDownstream(expr_block);
     }
@@ -105,8 +101,7 @@ bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
       return false;
     }
   }
-
-  return reduce_count == 1;
+  return true;
 }

 }  // namespace ir
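
Note (not part of the diff): the net effect here is that GetCanApplyGridReduce no longer requires exactly one reduce in the group; it only rejects groups where a reduce result is later broadcast, i.e. a downstream-of-reduce tensor is written with the reduce axis restored. A toy standalone sketch of that propagation idea, using a hypothetical flattened Op list instead of the CINN IR compute bodies the real code walks:

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical flattened view of a fusion group: each op reads some tensors,
// writes one tensor, and either is a reduce or keeps/drops the reduce axis.
struct Op {
  std::vector<std::string> inputs;
  std::string output;
  bool is_reduce = false;
  bool output_has_reduce_axis = true;
};

// Mirrors the structure of the updated GetCanApplyGridReduce: propagate a
// "downstream of reduce" set and reject only reduce-then-broadcast patterns.
bool CanApplyGridReduce(const std::vector<Op>& ops) {
  std::unordered_set<std::string> reduce_downstream;
  for (const Op& op : ops) {
    bool is_reduce_downstream = false;
    for (const std::string& in : op.inputs) {
      if (reduce_downstream.count(in) > 0) is_reduce_downstream = true;
    }
    if (is_reduce_downstream || op.is_reduce) {
      reduce_downstream.insert(op.output);
    }
    // A downstream-of-reduce tensor written with the reduce axis present
    // means the reduce result is broadcast: grid reduce is not applicable.
    if (is_reduce_downstream && op.output_has_reduce_axis) {
      return false;
    }
  }
  return true;  // no reduce-then-broadcast dependency found
}

int main() {
  std::vector<Op> group = {
      {{"x"}, "sum", /*is_reduce=*/true, /*output_has_reduce_axis=*/false},
      {{"sum"}, "scaled", /*is_reduce=*/false, /*output_has_reduce_axis=*/false},
  };
  std::cout << std::boolalpha << CanApplyGridReduce(group) << "\n";  // true
  return 0;
}

Under the old rule this group would also have to contain exactly one reduce; the new rule admits multiple reduces as long as none of their results is broadcast before the output.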

paddle/cinn/ir/group_schedule/config/group_tile_util.h

+2-2
@@ -19,8 +19,8 @@ namespace cinn {
 namespace ir {

 // Check whether we can apply grid reduce in this group.
-// We can apply grid reduce if there is exactly one reduce, and whose result is
-// not broadcasted before output.
+// We can apply grid reduce if there is no reduce-then-broadcast dependency
+// in this group.
 bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
                            const std::vector<int64_t>& reduce_axis);

paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc

+49-5
@@ -79,6 +79,7 @@ class TileFirstGeneralTactic final : public ScheduleTactic {
   std::vector<int32_t> vec_flatten_axis_;
   std::vector<int32_t> vec_reduce_axis_;
   std::unordered_map<std::string, std::string> map_rf_block_;
+  std::unordered_map<std::string, std::string> map_global_rf_block_;
 };

 void TileFirstGeneralTactic::Init(ScheduleContext* context) {
@@ -381,19 +382,46 @@ void TileFirstGeneralTactic::SplitSptialInner(ir::IRSchedule* sch,

 void TileFirstGeneralTactic::SplitReduceInner(ir::IRSchedule* sch,
                                               const std::string& block_id) {
+  const int64_t rd_block = context_->config.tile_config.grid_reduce_num;
+  const int64_t rd_thread = 16;
+  const int cur_reduce_axis = 2;
+
+  // [ R ] => [ rd_block*rd_thread, rd_inner ]
   auto loops = sch->GetLoops(block_id);
-  // [S(-1), S(32), R] => [S(-1), S(32), R(16), R(-1)]
-  sch->Split(loops[2], std::vector<int>{16, -1});
+  sch->Split(loops[cur_reduce_axis],
+             std::vector<int>{-1, rd_block * rd_thread});
+  loops = sch->GetLoops(block_id);
+  sch->Reorder({loops[cur_reduce_axis + 1], loops[cur_reduce_axis]});

   loops = sch->GetLoops(block_id);
   if (IsReductionSBlock(sch->GetBlock(block_id)) &&
       ir::GetLoopExtent(loops[2]) != 1) {
     ir::Expr rf_tensor =
-        sch->FactorizeReduction(loops[2],
-                                0,
+        sch->FactorizeReduction(loops[cur_reduce_axis],
+                                /* rf_axis = */ 0,
                                 /* with_write_back_block_init = */ false);
     map_rf_block_[block_id] = rf_tensor.as_tensor_ref()->name;
   }
+
+  // [ rd_block*rd_thread ] => [ rd_block, rd_thread ]
+  if (rd_block > 1) {
+    loops = sch->GetLoops(block_id);
+    sch->Split(loops[cur_reduce_axis], {rd_block, rd_thread});
+
+    if (IsReductionSBlock(sch->GetBlock(block_id))) {
+      loops = sch->GetLoops(map_rf_block_[block_id]);
+      sch->Split(loops[cur_reduce_axis], {rd_block, rd_thread});
+
+      loops = sch->GetLoops(block_id);
+      ir::Expr rf_tensor =
+          sch->FactorizeReduction(loops[cur_reduce_axis],
+                                  /* rf_axis = */ 0,
+                                  /* with_write_back_block_init = */ false);
+      std::string tensor_name = rf_tensor.as_tensor_ref()->name;
+      map_global_rf_block_[block_id] = tensor_name;
+      rf_tensor.as_tensor_ref()->WithBuffer("global", "_" + tensor_name);
+    }
+  }
 }

 void TileFirstGeneralTactic::VariableTypeAssignment(
@@ -435,6 +463,12 @@ void TileFirstGeneralTactic::SetDiscreteReduceType(
                      ->schedule_block.As<ir::ScheduleBlock>();
     block->reduce_method = cinn::ir::DiscreteReduceMethod();
   }
+  if (map_global_rf_block_.count(block_id) > 0) {
+    auto block = sch->GetBlock(map_global_rf_block_[block_id])
+                     .As<ir::ScheduleBlockRealize>()
+                     ->schedule_block.As<ir::ScheduleBlock>();
+    block->reduce_method = cinn::ir::DiscreteReduceMethod();
+  }
 }

 void TileFirstGeneralTactic::BindCudaInfo(ir::IRSchedule* sch,
@@ -446,14 +480,24 @@ void TileFirstGeneralTactic::BindCudaInfo(ir::IRSchedule* sch,
   const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
     sch->Bind(loops[0], "blockIdx.x");
     sch->Bind(loops[1], "threadIdx.x");
-    sch->Bind(loops[2], "threadIdx.y");
+    if (context_->config.tile_config.grid_reduce_num > 1) {
+      sch->Bind(loops[2], "blockIdx.y");
+      if (loops.size() > 3) {
+        sch->Bind(loops[3], "threadIdx.y");
+      }
+    } else {
+      sch->Bind(loops[2], "threadIdx.y");
+    }
   };

   DoBind(sch->GetLoops(block_id));

   if (map_rf_block_.count(block_id) > 0) {
     DoBind(sch->GetLoops(map_rf_block_[block_id]));
   }
+  if (map_global_rf_block_.count(block_id) > 0) {
+    DoBind(sch->GetLoops(map_global_rf_block_[block_id]));
+  }
 }

 std::unique_ptr<ScheduleTactic> CreateTileFirstGeneralTactic() {
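
Note (not part of the diff): to make the new tiling in SplitReduceInner / BindCudaInfo concrete, a small standalone sketch of how the reduce extent is decomposed. The values below are assumed for illustration only: the reduce loop is split into [rd_inner, rd_block*rd_thread] and reordered, then, when grid_reduce_num > 1, the outer factor is split again into [rd_block, rd_thread] and those two loops end up bound to blockIdx.y / threadIdx.y, with partial results staged through the global rf buffer.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Assumed example sizes; none of these values come from the commit itself.
  const long long R = 4096;        // reduce extent of the fused loop
  const long long rd_block = 8;    // grid_reduce_num from the tile config
  const long long rd_thread = 16;  // threads cooperating on the reduce axis

  // Step 1 (SplitReduceInner): [R] => [rd_block*rd_thread, rd_inner] after
  // the split + reorder, with the two spatial loops in front.
  const long long rd_inner = R / (rd_block * rd_thread);
  std::vector<std::pair<std::string, long long>> loops = {
      {"S(outer)  -> blockIdx.x", -1},
      {"S(inner)  -> threadIdx.x", -1},
      {"R(outer)", rd_block * rd_thread},
      {"R(inner)", rd_inner},
  };

  // Step 2 (rd_block > 1): split R(outer) into [rd_block, rd_thread] and bind
  // them to blockIdx.y / threadIdx.y, mirroring the updated BindCudaInfo.
  if (rd_block > 1) {
    loops[2] = {"R(grid)   -> blockIdx.y", rd_block};
    loops.insert(loops.begin() + 3, {"R(thread) -> threadIdx.y", rd_thread});
  }

  for (const auto& loop : loops) {
    std::cout << loop.first << "  extent=" << loop.second << "\n";
  }
  return 0;
}

When grid_reduce_num is 1, step 2 is skipped and the loop nest and bindings match the previous single-block behavior (reduce threads on threadIdx.y only).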

paddle/cinn/operator_fusion/pir_graph_analyzing/fusion_iters.h

+5-2
@@ -36,8 +36,11 @@ struct FusionItersManager {
                       ShardableAxesInfoManager* axes_info)
       : shape_analysis_(shape_analysis), axes_info_(axes_info) {
     PADDLE_ENFORCE_NOT_NULL(shape_analysis,
-                            "shape_analysis should not be nullptr.");
-    PADDLE_ENFORCE_NOT_NULL(axes_info, "axes_info should not be nullptr.");
+                            ::common::errors::InvalidArgument(
+                                "shape_analysis should not be nullptr."));
+    PADDLE_ENFORCE_NOT_NULL(
+        axes_info,
+        ::common::errors::InvalidArgument("axes_info should not be nullptr."));
   }
   FusionItersSignature GetItersSignature(pir::Operation* op);

paddle/cinn/operator_fusion/policy/iters_fusion_policy.h

+2-1
@@ -25,7 +25,8 @@ struct ItersFusionPolicy final : public PolicyBase {
   ItersFusionPolicy(std::shared_ptr<FusionItersManager> iters_manager)
       : iters_manager_(iters_manager) {
     PADDLE_ENFORCE_NOT_NULL(iters_manager,
-                            "iters_manager should not be nullptr.");
+                            ::common::errors::InvalidArgument(
+                                "iters_manager should not be nullptr."));
   }
   static constexpr PolicyKind Kind = PolicyKind::ItersFusion;
   std::string Name() { return "ItersFusionPolicy"; }
