Skip to content

Commit ea2bc4d

Browse files
authored
[CINN / Fusion] Align leaf reshape to input shape (#69478)
1 parent 609ba01 commit ea2bc4d

File tree

17 files changed

+387
-97
lines changed

17 files changed

+387
-97
lines changed

paddle/cinn/hlir/framework/pir/trivial_op_util.cc

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h"
1616

17+
#include "paddle/cinn/common/dim_expr_converter.h"
1718
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
1819
#include "paddle/cinn/hlir/framework/compile_error.h"
1920
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
@@ -547,9 +548,6 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
547548
* remove it in axes.bind()
548549
*/
549550
const auto& f = [=](const ir::Expr& e) -> ir::Expr {
550-
VLOG(4) << "Start RemoveVarInScheduleBlockRealize(" << target_vars << ", "
551-
<< replaced_expr << ")";
552-
VLOG(4) << " Input is " << e;
553551
PADDLE_ENFORCE_NE(
554552
e.As<ir::ScheduleBlockRealize>(),
555553
nullptr,
@@ -562,22 +560,11 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
562560
auto block_bound_vars = copied_ir.As<ir::ScheduleBlockRealize>()
563561
->schedule_block.As<ir::ScheduleBlock>()
564562
->iter_vars;
565-
for (const auto& i_var : schedule_block_iter_vars) {
566-
PADDLE_ENFORCE_EQ(
567-
i_var.is_var(),
568-
true,
569-
::common::errors::InvalidArgument("RemoveVarInScheduleBlockRealize: "
570-
"axes.bind rhs is is not a Var."));
571-
}
572563
// find replace idx
573564
int target_idx = -1;
574565
for (int i = 0; i < schedule_block_iter_vars.size(); ++i) {
575-
VLOG(4) << "RemoveVarInScheduleBlockRealize: compare with "
576-
<< schedule_block_iter_vars[i] << " vs " << target_vars
577-
<< ", and equality is: "
578-
<< (schedule_block_iter_vars[i].as_var()->name ==
579-
target_vars->name);
580-
if (schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
566+
if (schedule_block_iter_vars[i].is_var() &&
567+
schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
581568
target_idx = i;
582569
}
583570
}
@@ -688,8 +675,6 @@ ExprTransformer RemoveOneTransformer(int one) {
688675
.GetSingle(copied);
689676
const ir::Expr& target_block =
690677
ExprSetFinderUtils::DirectlyFather(copied).GetSingle(target_for);
691-
VLOG(4) << "RemoveOneTransformer: directly target_block of for is "
692-
<< target_block;
693678
if (target_block.As<ir::ScheduleBlockRealize>() != nullptr) {
694679
VLOG(4) << "RemoveOneTransformer: father block is root realize";
695680
ir::Expr shedule_block =
@@ -708,7 +693,6 @@ ExprTransformer RemoveOneTransformer(int one) {
708693
shedule_block.As<ir::ScheduleBlock>()->body = for_body;
709694
}
710695
} else if (target_block.As<ir::Block>() != nullptr) {
711-
VLOG(4) << "RemoveOneTransformer: father block is Block";
712696
std::vector<ir::Expr> new_bodies;
713697
for (const auto& expr : target_block.As<ir::Block>()->stmts) {
714698
if (expr != target_for) {
@@ -728,7 +712,6 @@ ExprTransformer RemoveOneTransformer(int one) {
728712
"RemoveOneTransformer: target for father should be a ir::Block or "
729713
"ir::ScheduleBlockRealize."));
730714
}
731-
VLOG(4) << "Remove Var to 0 in ScheduleBlockRealizer: " << copied;
732715
// Remove var to 0 in ScheduleBlockRealizer
733716
InplaceMutateSingleExpr(
734717
&copied,
@@ -949,6 +932,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root) {
949932

950933
ir::Expr GetBodyBlock(const ir::Expr& root) {
951934
const auto& iters = GetNonReduceLoopVars(root);
935+
if (iters.empty()) {
936+
return ir::Block::Make(
937+
{ExprSetFinderUtils::ChildScheduleBlockRealizes.GetSingle(root)});
938+
}
952939
const size_t reduce_size =
953940
std::count_if(iters.begin(), iters.end(), [](const ir::Var& v) {
954941
return v->is_reduce_axis;
@@ -965,6 +952,74 @@ ir::Expr GetBodyBlock(const ir::Expr& root) {
965952
->body;
966953
}
967954

955+
// Builds and returns a transformed copy of `root` for a reshape from
// `in_shape` to `out_shape`. All work is done on an IRCopy of `root`; `root`
// itself is never assigned to.
//
// Steps, as written below:
//   1. fusion::PartionReshapeAxes(in_shape, out_shape) yields a list of
//      (input index, output index) pairs. Consecutive entries are used as
//      [start, end) boundaries of one axis group; groups are visited from the
//      last one to the first.
//   2. Inside a group, input axes equal to DimExpr(1) are removed with
//      RemoveOneTransformer, the remaining input loops are fused (only when
//      there is more than one), and the loop at position `in_s` is split by
//      the group's output extents that are not DimExpr(1) (only when there is
//      more than one such extent).
//   3. For every output axis equal to DimExpr(1) a loop is inserted through
//      InsertForsTransformer.
//
// NOTE(review): GetSingle suggests `root` is expected to contain exactly one
// ScheduleBlockRealize - confirm against the helper and the callers.
ir::Expr ReshapeLoop(const ir::Expr& root,
956+
const std::vector<symbol::DimExpr>& in_shape,
957+
const std::vector<symbol::DimExpr>& out_shape) {
958+
auto copied = ir::ir_utils::IRCopy(root);
959+
960+
// The schedule object is built over the copy and is used below for
// Fuse / Split / GetLoops.
ir::ModuleExpr mod_expr({copied});
961+
ir::IRSchedule ir_sch(
962+
mod_expr, -1, false, cinn::utils::ErrorMessageLevel::kGeneral, true);
963+
964+
// The block name is looked up once and reused to address the loops in every
// Fuse / GetLoops call.
const auto block_realize =
965+
(ExprSetFinderUtils::ChildScheduleBlockRealizes).GetSingle(copied);
966+
const auto block_name = block_realize.As<ir::ScheduleBlockRealize>()
967+
->schedule_block.As<ir::ScheduleBlock>()
968+
->name;
969+
const auto shape_partion = fusion::PartionReshapeAxes(in_shape, out_shape);
970+
971+
// Visit groups back to front; entry idx-1 is the group's start and entry idx
// is its (exclusive) end. The loop stops at idx == 1, so entry 0 only ever
// serves as a start boundary.
for (int idx = shape_partion.size() - 1; idx > 0; --idx) {
972+
const auto& in_s = shape_partion[idx - 1].first;
973+
const auto& in_e = shape_partion[idx].first;
974+
const auto& out_s = shape_partion[idx - 1].second;
975+
const auto& out_e = shape_partion[idx].second;
976+
977+
// Scan the group's input axes from high to low. Axes that are not 1 are
// prepended, so fuse_indices stays in ascending order.
std::vector<int> fuse_indices;
978+
for (int i = in_e - 1; i >= in_s; --i) {
979+
if (in_shape[i] != symbol::DimExpr(1)) {
980+
fuse_indices.insert(fuse_indices.begin(), i);
981+
} else {
982+
VLOG(4) << "Remove index[" << i << "]: " << in_shape[i]
983+
<< " for expr: \n"
984+
<< copied;
985+
copied = ExprTransformerUtils::RemoveOneTransformer(i)(copied);
986+
// `copied` was reassigned, so the schedule is pointed at the new expr.
ir_sch.SetExprs({copied});
987+
// Every index collected so far is larger than i (the scan runs
// downwards), so each one shifts down by one after axis i is removed.
for (auto& index : fuse_indices) {
988+
index--;
989+
}
990+
}
991+
}
992+
if (fuse_indices.size() > 1) {
993+
VLOG(4) << "fuse_indices: " << cinn::utils::Join(fuse_indices, ",");
994+
ir_sch.Fuse(block_name, fuse_indices);
995+
}
996+
997+
// Collect the group's output extents that are not 1, converted to ir::Expr.
std::vector<ir::Expr> split_shapes;
998+
for (int i = out_s; i < out_e; ++i) {
999+
if (out_shape[i] != symbol::DimExpr(1)) {
1000+
split_shapes.push_back(
1001+
cinn::common::DimExprConverter().ConvertToIrExpr(out_shape[i]));
1002+
}
1003+
}
1004+
if (split_shapes.size() > 1) {
1005+
// The trailing [0] only indexes the value returned by Split; that value
// is not stored or used.
ir_sch.Split(ir_sch.GetLoops(block_name)[in_s], split_shapes)[0];
1006+
}
1007+
}
1008+
1009+
// Record the position of every output axis equal to 1 and create one
// ir::Var(1, "one_<k>") per such axis, where k is the running count.
std::vector<int> insert_axis;
1010+
std::vector<ir::Var> ones_var;
1011+
for (int i = 0; i < out_shape.size(); ++i) {
1012+
if (out_shape[i] == symbol::DimExpr(1)) {
1013+
insert_axis.push_back(i);
1014+
ones_var.push_back(ir::Var(1, "one_" + std::to_string(ones_var.size())));
1015+
}
1016+
}
1017+
copied = ExprTransformerUtils::InsertForsTransformer(insert_axis,
1018+
ones_var)(copied);
1019+
1020+
return copied;
1021+
}
1022+
9681023
} // namespace trivial_fusion_detail
9691024
} // namespace pir
9701025
} // namespace framework

paddle/cinn/hlir/framework/pir/trivial_op_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root);
297297

298298
ir::Expr GetBodyBlock(const ir::Expr& root);
299299

300+
// Returns an ir::Expr produced from `root` for a reshape whose input shape is
// `in_shape` and whose output shape is `out_shape`. `root` is taken by const
// reference; the result is returned by value.
ir::Expr ReshapeLoop(const ir::Expr& root,
301+
const std::vector<symbol::DimExpr>& in_shape,
302+
const std::vector<symbol::DimExpr>& out_shape);
303+
300304
} // namespace trivial_fusion_detail
301305
} // namespace pir
302306
} // namespace framework

paddle/cinn/ir/group_schedule/config/group_tile_util.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
6363
auto* block = expr_block.As<ir::ScheduleBlockRealize>();
6464
auto& iter_vars = block->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
6565
for (int i = 0; i < iter_vars.size(); i++) {
66-
ir::Var loop_var = block->iter_values[i].as_var_ref();
67-
if (reduce_loop_vars.count(loop_var->name) > 0) {
66+
if (block->iter_values[i].is_var() &&
67+
reduce_loop_vars.count(block->iter_values[i].as_var()->name) > 0) {
6868
reduce_iter_vars.insert(iter_vars[i]->name);
6969
}
7070
}

paddle/cinn/operator_fusion/fusion_tracker/expr_utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ ir::Expr ApplyItersTransform::operator()(const TransposeItersTransform& trans) {
3131

3232
ir::Expr ApplyItersTransform::operator()(const RemoveOnesTransform& trans) {
3333
VLOG(4) << "[ItersTransform] Before RemoveOnesTransform("
34-
<< utils::Join(trans.ones_, ",") << "'): " << expr_;
34+
<< utils::Join(trans.ones_, ",") << "): " << expr_;
3535
auto result = RemoveOnesTransformer(trans.ones_)(expr_);
3636
VLOG(4) << "[ItersTransform] After RemoveOnesTransform: " << result;
3737
return result;

paddle/cinn/operator_fusion/fusion_tracker/interpreter.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,20 @@ void RunItersTransformInstr(const std::shared_ptr<ItersTransformInstr>& instr,
159159
interpreter->scope[instr->target_] = new_pattern;
160160
}
161161

162+
// Executes a ReshapeAlignInstr: reads the scope element named
// `instr->input_`, runs ReshapeLoop on its expression with the instruction's
// in/out shapes, and stores the result in the scope under `instr->result_`
// as a new ScopeElement holding a single TrivialOp.
//
// Only fusion_ops[0] of the input element, and only element [0] of what
// FusibleOp2Expr returns for it, are used.
// NOTE(review): there is no check that `instr->input_` is present in the
// scope or that its fusion_ops is non-empty - confirm the caller guarantees it.
void RunReshapeAlignInstr(const std::shared_ptr<ReshapeAlignInstr>& instr,
163+
FusionInterpreter* interpreter) {
164+
const auto expr = std::visit(
165+
FusibleOp2Expr(), interpreter->scope[instr->input_]->fusion_ops[0])[0];
166+
VLOG(4) << "Before RunReshapeAlignInstr: \n" << expr;
167+
auto result = cinn::hlir::framework::pir::trivial_fusion_detail::ReshapeLoop(
168+
expr, instr->in_shape_, instr->out_shape_);
169+
170+
// Publish the reshaped expr under the result name; any element previously
// stored under that name is replaced by the assignment.
auto new_pattern = std::make_shared<ScopeElement>();
171+
new_pattern->fusion_ops.emplace_back(TrivialOp(result));
172+
interpreter->scope[instr->result_] = new_pattern;
173+
VLOG(4) << "After ReshapeAlignInstr: \n" << result;
174+
}
175+
162176
void RunPaddingInstr(const std::shared_ptr<PaddingInstr>& instr,
163177
FusionInterpreter* interpreter) {
164178
ScopeElementPtr new_pattern = std::make_shared<ScopeElement>();
@@ -229,6 +243,10 @@ std::vector<ir::Expr> FusionInterpreter::Run() {
229243
RunItersTransformInstr(
230244
dynamic_cast_instr_with_err<ItersTransformInstr>(instr), this);
231245
break;
246+
case T_ReshapeAlign:
247+
RunReshapeAlignInstr(
248+
dynamic_cast_instr_with_err<ReshapeAlignInstr>(instr), this);
249+
break;
232250
default:
233251
PADDLE_THROW(
234252
::common::errors::Unavailable("Unsupported Fusion Instrution"));

paddle/cinn/operator_fusion/fusion_tracker/tracker.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ enum InstructionType {
2525
T_Return,
2626
T_InitPattern,
2727
T_TrivialInline,
28+
T_ReshapeAlign,
2829
T_TmpTransform,
2930
T_TrivialLoopAlign,
3031
T_ItersTransform,
@@ -143,6 +144,32 @@ struct TrivialInlineInstr : public FusionInstruction {
143144
}
144145
};
145146

147+
// Fusion instruction of type T_ReshapeAlign. It carries the name of an input,
// an input shape, an output shape and the name under which the result is
// stored; RunReshapeAlignInstr in the interpreter reads exactly these four
// fields.
struct ReshapeAlignInstr : public FusionInstruction {
148+
// Copies all four arguments into the corresponding members.
ReshapeAlignInstr(const std::string& input,
149+
const std::vector<symbol::DimExpr>& in_shape,
150+
const std::vector<symbol::DimExpr>& out_shape,
151+
const std::string& result)
152+
: input_(input),
153+
in_shape_(in_shape),
154+
out_shape_(out_shape),
155+
result_(result) {}
156+
virtual InstructionType type() const { return T_ReshapeAlign; }
157+
// Returns a shared_ptr to a copy-constructed ReshapeAlignInstr.
virtual FusionInstrPtr Clone() {
158+
return std::make_shared<ReshapeAlignInstr>(*this);
159+
}
160+
161+
// Name of the input this instruction reads.
std::string input_;
162+
// Shape before the reshape.
std::vector<symbol::DimExpr> in_shape_;
163+
// Shape after the reshape.
std::vector<symbol::DimExpr> out_shape_;
164+
// Name under which the result is stored.
std::string result_;
165+
166+
// Formats as: ReshapeAlignInstr || <input>(<in_shape>) => <result>(<out_shape>)
// with the shape entries joined by ",".
virtual std::string DebugStr() const {
167+
return "ReshapeAlignInstr || " + input_ + "(" +
168+
cinn::utils::Join(in_shape_, ",") + ") => " + result_ + "(" +
169+
cinn::utils::Join(out_shape_, ",") + ")";
170+
}
171+
};
172+
146173
struct TmpTransformInstr : public FusionInstruction {
147174
TmpTransformInstr(const std::string& upstream,
148175
const std::string& downstream,

paddle/cinn/operator_fusion/graph_transformer/matcher.h

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,39 @@ struct AlwaysTrue {
2424
}
2525
};
2626

27+
// Matches a node that has at least one downstream node. `graph` is unused.
struct NonSinkNodeMatcher {
28+
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
29+
return !node->downstream().empty();
30+
}
31+
};
32+
33+
// Matches when IsAnyFirstInSecond(results of the node's sink op,
// graph.outputs()) returns true.
// NOTE(review): presumably this means "some result of the sink op is a graph
// output"; IsAnyFirstInSecond is defined elsewhere - confirm.
struct IsOutputNodeMatcher {
34+
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
35+
bool res = IsAnyFirstInSecond(node->sink_op()->results(), graph.outputs());
36+
return res;
37+
}
38+
};
39+
40+
// Matches a node whose number of downstream nodes is strictly less than the
// template argument N. `graph` is unused.
template <int N>
41+
struct DownstreamSmallerThan {
42+
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
43+
return node->downstream().size() < N;
44+
}
45+
};
46+
47+
// Matches a node whose number of downstream nodes is strictly greater than
// the template argument N. `graph` is unused.
template <int N>
48+
struct DownstreamGreaterThan {
49+
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
50+
return node->downstream().size() > N;
51+
}
52+
};
53+
54+
// Matches a node that has exactly one downstream node. `graph` is unused.
struct OnlyOneDownstreamMatcher {
55+
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
56+
return node->downstream().size() == 1;
57+
}
58+
};
59+
2760
template <typename StmtPattern>
2861
struct StmtPatternGraphMatcher {
2962
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
@@ -139,7 +172,15 @@ struct RecomputeNodeMatcher {
139172
if (node->fusion_iters().output_values.size() > 1) {
140173
return false;
141174
}
142-
175+
bool has_combine_fusion =
176+
std::any_of(node->fusion_tracker()->instructions_.begin(),
177+
node->fusion_tracker()->instructions_.end(),
178+
[](const FusionInstrPtr& instr) {
179+
return instr->type() == T_Combine;
180+
});
181+
if (has_combine_fusion) {
182+
return false;
183+
}
143184
for (const auto& op : GetOpsInPattern(node->stmt_pattern())) {
144185
const auto& op_kind = GetOpPatternKind(op);
145186
if (op_kind >= hlir::framework::kReduction) {
@@ -183,9 +224,50 @@ struct TransposeOpMatcher {
183224
}
184225
};
185226

186-
struct NonSinkNodeMatcher {
227+
struct ReshapeOpMatcher {
187228
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
188-
return !node->downstream().empty();
229+
return (node->sink_op()->name() == "cinn_op.reshape");
230+
}
231+
};
232+
233+
// Matches either side of a connection that involves a "cinn_op.reshape" node.
// `graph` is unused. The node matches when one of the following holds:
//   - upstream side: the node has exactly one downstream node, that node's
//     sink op is "cinn_op.reshape", and that node itself has exactly one
//     downstream node;
//   - downstream side: the node's own sink op is "cinn_op.reshape" and it has
//     exactly one downstream node.
struct ReshapeConnectionMatcher {
234+
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
235+
bool upstream_match =
236+
node->downstream().size() == 1 &&
237+
node->downstream()[0]->sink_op()->name() == "cinn_op.reshape" &&
238+
node->downstream()[0]->downstream().size() == 1;
239+
bool downstream_match = node->sink_op()->name() == "cinn_op.reshape" &&
240+
node->downstream().size() == 1;
241+
return upstream_match || downstream_match;
242+
}
243+
};
244+
245+
// Matches either end of an (upstream -> reshape) pair in which the reshape
// node feeds a node that has no downstream of its own.
//
// The node matches when it satisfies `match_upstream`, has exactly one
// downstream node, and that node satisfies `match_downstream`; or when it
// satisfies `match_downstream`, has exactly one upstream node, and that node
// satisfies `match_upstream`.
struct LeafReshapeConnectionMatcher {
246+
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
247+
// Upstream side: accepted by StmtPatternGraphMatcher<TrivialPattern>, has
// exactly one downstream node, has at least one upstream node, and at
// least one of those upstream nodes has more than one downstream node.
const auto match_upstream = [&graph](const PatternNodePtr& upstream) {
248+
return StmtPatternGraphMatcher<TrivialPattern>()(graph, upstream) &&
249+
upstream->downstream().size() == 1 &&
250+
!upstream->upstream().empty() &&
251+
std::any_of(upstream->upstream().begin(),
252+
upstream->upstream().end(),
253+
[&graph](const PatternNodePtr& node) {
254+
return DownstreamGreaterThan<1>()(graph, node);
255+
});
256+
};
257+
// Downstream side: sink op is "cinn_op.reshape", it has exactly one
// downstream node, that node has no downstream nodes, and both nodes
// report equal fusion_iters().loop_iters. `graph` is captured but not
// used inside this lambda.
const auto match_downstream = [&graph](const PatternNodePtr& downstream) {
258+
return downstream->sink_op()->name() == "cinn_op.reshape" &&
259+
downstream->downstream().size() == 1 &&
260+
downstream->downstream()[0]->downstream().empty() &&
261+
downstream->fusion_iters().loop_iters ==
262+
downstream->downstream()[0]->fusion_iters().loop_iters;
263+
};
264+
bool upstream_match = match_upstream(node) &&
265+
node->downstream().size() == 1 &&
266+
match_downstream(node->downstream()[0]);
267+
bool downstream_match = match_downstream(node) &&
268+
node->upstream().size() == 1 &&
269+
match_upstream(node->upstream()[0]);
270+
return upstream_match || downstream_match;
189271
}
190272
};
191273

@@ -206,26 +288,6 @@ struct NotAllElementWiseDownstreamMatcher {
206288
}
207289
};
208290

209-
struct IsOutputNodeMatcher {
210-
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
211-
bool res = IsAnyFirstInSecond(node->sink_op()->results(), graph.outputs());
212-
return res;
213-
}
214-
};
215-
216-
template <int N>
217-
struct DownstreamSmallerThan {
218-
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
219-
return node->downstream().size() < N;
220-
}
221-
};
222-
223-
template <int N>
224-
struct DownstreamGreaterThan {
225-
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
226-
return node->downstream().size() > N;
227-
}
228-
};
229291
template <typename... Args>
230292
struct And {};
231293

0 commit comments

Comments
 (0)