Skip to content

Commit de9ab1f

Browse files
authored
Update expand.cc
1 parent b2b2721 commit de9ab1f

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

paddle/phi/infermeta/spmd_rules/expand.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ SpmdInfo ExpandGradInferSpmd(const DistMetaTensor& x,
4141
const IntArray& shape) {
4242
EXTRACT_SHAPE_AND_DIST_ATTR(x);
4343
EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);
44+
if (x_shape.size() == out_grad_shape.size()) {
45+
return {{x_dist_attr_src, out_grad_dist_attr_src}, {x_dist_attr_src}};
46+
}
4447
size_t axis =
4548
std::abs(static_cast<int>(out_grad.dims().size() - x.dims().size()));
4649
std::vector<int64_t> x_grad_dims_mapping;

0 commit comments

Comments
 (0)