Skip to content

Commit 6546507

Browse files
committed
[CINN] Fix eliminate if bug
1 parent 9b228fd commit 6546507

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

paddle/cinn/operator_fusion/fusion_tracker/interpreter.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,10 @@ void RunReturnInstr(const std::shared_ptr<ReturnInstr>& instr,
202202
auto exprs = std::visit(FusibleOp2Expr(), fusion_op);
203203
for (auto expr : exprs) {
204204
std::string output_var_name = GetOutputTensor(expr)->name;
205-
if (interpreter->global_var_names.count(output_var_name)) {
206-
expr = ExprTransformerUtils::EliminateUselessIfTransformer()(expr);
207-
} else {
205+
if (!interpreter->global_var_names.count(output_var_name)) {
208206
expr = ExprTransformerUtils::RemoveAllAppendIfTransformer()(expr);
209207
}
208+
expr = ExprTransformerUtils::EliminateUselessIfTransformer()(expr);
210209
result.push_back(expr);
211210
}
212211
}

test/ir/pir/cinn/test_horizontal_fusion.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,20 @@ def init():
152152

153153
self.check_accuracy_and_kernel_num(init, func, kernel_num=2)
154154

155+
def test_reduce_horizontal_fusion(self):
156+
def func(x):
157+
a = paddle.sum(x, axis=[0], keepdim=True)
158+
a = paddle.reshape(a, [6])
159+
b = paddle.sum(x, axis=[0], keepdim=True)
160+
b = paddle.reshape(b, [6])
161+
return a, b
162+
163+
def init():
164+
x = paddle.rand((2, 2, 3))
165+
return (x,)
166+
167+
self.check_accuracy_and_kernel_num(init, func, kernel_num=1)
168+
155169

156170
if __name__ == "__main__":
157171
unittest.main()

0 commit comments

Comments
 (0)