14
14
15
15
#include " paddle/cinn/hlir/framework/pir/trivial_op_util.h"
16
16
17
+ #include " paddle/cinn/common/dim_expr_converter.h"
17
18
#include " paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
18
19
#include " paddle/cinn/hlir/framework/compile_error.h"
19
20
#include " paddle/cinn/hlir/framework/pir/op_lowering_util.h"
@@ -547,9 +548,6 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
547
548
* remove it in axes.bind()
548
549
*/
549
550
const auto & f = [=](const ir::Expr& e) -> ir::Expr {
550
- VLOG (4 ) << " Start RemoveVarInScheduleBlockRealize(" << target_vars << " , "
551
- << replaced_expr << " )" ;
552
- VLOG (4 ) << " Input is " << e;
553
551
PADDLE_ENFORCE_NE (
554
552
e.As <ir::ScheduleBlockRealize>(),
555
553
nullptr ,
@@ -562,22 +560,11 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
562
560
auto block_bound_vars = copied_ir.As <ir::ScheduleBlockRealize>()
563
561
->schedule_block .As <ir::ScheduleBlock>()
564
562
->iter_vars ;
565
- for (const auto & i_var : schedule_block_iter_vars) {
566
- PADDLE_ENFORCE_EQ (
567
- i_var.is_var (),
568
- true ,
569
- ::common::errors::InvalidArgument (" RemoveVarInScheduleBlockRealize: "
570
- " axes.bind rhs is is not a Var." ));
571
- }
572
563
// find replace idx
573
564
int target_idx = -1 ;
574
565
for (int i = 0 ; i < schedule_block_iter_vars.size (); ++i) {
575
- VLOG (4 ) << " RemoveVarInScheduleBlockRealize: compare with "
576
- << schedule_block_iter_vars[i] << " vs " << target_vars
577
- << " , and equality is: "
578
- << (schedule_block_iter_vars[i].as_var ()->name ==
579
- target_vars->name );
580
- if (schedule_block_iter_vars[i].as_var ()->name == target_vars->name ) {
566
+ if (schedule_block_iter_vars[i].is_var () &&
567
+ schedule_block_iter_vars[i].as_var ()->name == target_vars->name ) {
581
568
target_idx = i;
582
569
}
583
570
}
@@ -688,8 +675,6 @@ ExprTransformer RemoveOneTransformer(int one) {
688
675
.GetSingle (copied);
689
676
const ir::Expr& target_block =
690
677
ExprSetFinderUtils::DirectlyFather (copied).GetSingle (target_for);
691
- VLOG (4 ) << " RemoveOneTransformer: directly target_block of for is "
692
- << target_block;
693
678
if (target_block.As <ir::ScheduleBlockRealize>() != nullptr ) {
694
679
VLOG (4 ) << " RemoveOneTransformer: father block is root realize" ;
695
680
ir::Expr shedule_block =
@@ -708,7 +693,6 @@ ExprTransformer RemoveOneTransformer(int one) {
708
693
shedule_block.As <ir::ScheduleBlock>()->body = for_body;
709
694
}
710
695
} else if (target_block.As <ir::Block>() != nullptr ) {
711
- VLOG (4 ) << " RemoveOneTransformer: father block is Block" ;
712
696
std::vector<ir::Expr> new_bodies;
713
697
for (const auto & expr : target_block.As <ir::Block>()->stmts ) {
714
698
if (expr != target_for) {
@@ -728,7 +712,6 @@ ExprTransformer RemoveOneTransformer(int one) {
728
712
" RemoveOneTransformer: target for father should be a ir::Block or "
729
713
" ir::ScheduleBlockRealize." ));
730
714
}
731
- VLOG (4 ) << " Remove Var to 0 in ScheduleBlockRealizer: " << copied;
732
715
// Remove var to 0 in ScheduleBlockRealizer
733
716
InplaceMutateSingleExpr (
734
717
&copied,
@@ -949,6 +932,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root) {
949
932
950
933
ir::Expr GetBodyBlock (const ir::Expr& root) {
951
934
const auto & iters = GetNonReduceLoopVars (root);
935
+ if (iters.empty ()) {
936
+ return ir::Block::Make (
937
+ {ExprSetFinderUtils::ChildScheduleBlockRealizes.GetSingle (root)});
938
+ }
952
939
const size_t reduce_size =
953
940
std::count_if (iters.begin (), iters.end (), [](const ir::Var& v) {
954
941
return v->is_reduce_axis ;
@@ -965,6 +952,74 @@ ir::Expr GetBodyBlock(const ir::Expr& root) {
965
952
->body ;
966
953
}
967
954
955
+ ir::Expr ReshapeLoop (const ir::Expr& root,
956
+ const std::vector<symbol::DimExpr>& in_shape,
957
+ const std::vector<symbol::DimExpr>& out_shape) {
958
+ auto copied = ir::ir_utils::IRCopy (root);
959
+
960
+ ir::ModuleExpr mod_expr ({copied});
961
+ ir::IRSchedule ir_sch (
962
+ mod_expr, -1 , false , cinn::utils::ErrorMessageLevel::kGeneral , true );
963
+
964
+ const auto block_realize =
965
+ (ExprSetFinderUtils::ChildScheduleBlockRealizes).GetSingle (copied);
966
+ const auto block_name = block_realize.As <ir::ScheduleBlockRealize>()
967
+ ->schedule_block .As <ir::ScheduleBlock>()
968
+ ->name ;
969
+ const auto shape_partion = fusion::PartionReshapeAxes (in_shape, out_shape);
970
+
971
+ for (int idx = shape_partion.size () - 1 ; idx > 0 ; --idx) {
972
+ const auto & in_s = shape_partion[idx - 1 ].first ;
973
+ const auto & in_e = shape_partion[idx].first ;
974
+ const auto & out_s = shape_partion[idx - 1 ].second ;
975
+ const auto & out_e = shape_partion[idx].second ;
976
+
977
+ std::vector<int > fuse_indices;
978
+ for (int i = in_e - 1 ; i >= in_s; --i) {
979
+ if (in_shape[i] != symbol::DimExpr (1 )) {
980
+ fuse_indices.insert (fuse_indices.begin (), i);
981
+ } else {
982
+ VLOG (4 ) << " Remove index[" << i << " ]: " << in_shape[i]
983
+ << " for expr: \n "
984
+ << copied;
985
+ copied = ExprTransformerUtils::RemoveOneTransformer (i)(copied);
986
+ ir_sch.SetExprs ({copied});
987
+ for (auto & index : fuse_indices) {
988
+ index --;
989
+ }
990
+ }
991
+ }
992
+ if (fuse_indices.size () > 1 ) {
993
+ VLOG (4 ) << " fuse_indices: " << cinn::utils::Join (fuse_indices, " ," );
994
+ ir_sch.Fuse (block_name, fuse_indices);
995
+ }
996
+
997
+ std::vector<ir::Expr> split_shapes;
998
+ for (int i = out_s; i < out_e; ++i) {
999
+ if (out_shape[i] != symbol::DimExpr (1 )) {
1000
+ split_shapes.push_back (
1001
+ cinn::common::DimExprConverter ().ConvertToIrExpr (out_shape[i]));
1002
+ }
1003
+ }
1004
+ if (split_shapes.size () > 1 ) {
1005
+ ir_sch.Split (ir_sch.GetLoops (block_name)[in_s], split_shapes)[0 ];
1006
+ }
1007
+ }
1008
+
1009
+ std::vector<int > insert_axis;
1010
+ std::vector<ir::Var> ones_var;
1011
+ for (int i = 0 ; i < out_shape.size (); ++i) {
1012
+ if (out_shape[i] == symbol::DimExpr (1 )) {
1013
+ insert_axis.push_back (i);
1014
+ ones_var.push_back (ir::Var (1 , " one_" + std::to_string (ones_var.size ())));
1015
+ }
1016
+ }
1017
+ copied = ExprTransformerUtils::InsertForsTransformer (insert_axis,
1018
+ ones_var)(copied);
1019
+
1020
+ return copied;
1021
+ }
1022
+
968
1023
} // namespace trivial_fusion_detail
969
1024
} // namespace pir
970
1025
} // namespace framework
0 commit comments