Skip to content

Commit db96ae5

Browse files
zhoutianzi666zhoukangkang
and
zhoukangkang
authored
[paddle-trt] x and y 's rank should be same in trt_skip_layernorm_pass (#56007)
* commit * commit --------- Co-authored-by: zhoukangkang <zhoukangkang@baidu.com>
1 parent 5ada98b commit db96ae5

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,13 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
152152
}
153153

154154
VLOG(4) << "handle TrtSkipLayerNorm fuse";
155+
156+
// x and y 's rank must be same
157+
if (subgraph.at(x)->Var()->GetShape().size() !=
158+
subgraph.at(y)->Var()->GetShape().size()) {
159+
return;
160+
}
161+
155162
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
156163
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
157164
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);

0 commit comments

Comments
 (0)