Skip to content

Commit e81a049

Browse files
authored
update (#63324)
1 parent 636efd6 commit e81a049

File tree

3 files changed

+3
-2
lines changed

3 files changed

+3
-2
lines changed

paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class AddYieldStoreInFusionOpPattern
3939
continue;
4040
}
4141

42+
rewriter.SetInsertionPointAfter(op->operand_source(i).defining_op());
4243
auto store_op = rewriter.Build<cinn::dialect::YieldStoreOp>(
4344
op->operand_source(i), op->operand_source(i).type());
4445
auto orignal_base = op->operand_source(i);

test/ir/pir/cinn/sub_graphs/test_sub_graph_47.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def train(self, net, to_static, with_prim=False, with_cinn=False):
6363
def test_ast_prim_cinn(self):
6464
st_out = self.train(self.net, to_static=True)
6565
cinn_out = self.train(
66-
self.net, to_static=True, with_prim=True, with_cinn=False
66+
self.net, to_static=True, with_prim=True, with_cinn=True
6767
)
6868
for st, cinn in zip(
6969
paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)

test/ir/pir/cinn/sub_graphs/test_sub_graph_77.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def train(self, net, to_static, with_prim=False, with_cinn=False):
174174
def test_ast_prim_cinn(self):
175175
st_out = self.train(self.net, to_static=True)
176176
cinn_out = self.train(
177-
self.net, to_static=True, with_prim=True, with_cinn=False
177+
self.net, to_static=True, with_prim=True, with_cinn=True
178178
)
179179
for st, cinn in zip(
180180
paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)

0 commit comments

Comments
 (0)