Skip to content

Commit bb59a58

Browse files
authored
[CINN] Fix TileBroadcastTactic matching conditions (#72348)
1 parent e8638db commit bb59a58

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ bool IsElementwiseOrBroadcast(const ir::Store& dst, const ir::Load& src) {
206206

207207
bool CheckAllElementwiseOrBroadcast(ir::IRSchedule* sch) {
208208
for (auto& block : sch->GetAllBlocks()) {
209+
if (ir::analyzer::IsReductionSBlock(block)) return false;
209210
ir::Expr store = ir::analyzer::GetStoreOfSBlock(block);
210211
auto* store_node = store.As<ir::Store>();
211212
for (auto& load : CollectLoads(store_node->value)) {
@@ -564,6 +565,7 @@ void TileBroadcastTactic::Apply(ir::IRSchedule* sch,
564565
block_size = CalcNumWarps(preserved_size_ >> 5);
565566
if (block_size == -1) {
566567
applied_layout_ = BroadcastLayout::Invalid;
568+
return;
567569
}
568570
block_size = std::clamp(block_size << 5, 128, 1024);
569571
}

test/dygraph_to_static/test_pir_selectedrows.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from dygraph_to_static_utils import (
1919
Dy2StTestBase,
20-
test_phi_only,
2120
)
2221

2322
import paddle
@@ -74,7 +73,6 @@ def forward_static():
7473

7574

7675
class TestSimnet(Dy2StTestBase):
77-
@test_phi_only
7876
def test_dygraph_static_same_loss(self):
7977
dygraph_value = forward_dygraph()
8078
static_value = forward_static()

0 commit comments

Comments
 (0)