We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b2b2721 commit de9ab1fCopy full SHA for de9ab1f
paddle/phi/infermeta/spmd_rules/expand.cc
@@ -41,6 +41,9 @@ SpmdInfo ExpandGradInferSpmd(const DistMetaTensor& x,
41
const IntArray& shape) {
42
EXTRACT_SHAPE_AND_DIST_ATTR(x);
43
EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);
44
+ if (x_shape.size() == out_grad_shape.size()) {
45
+ return {{x_dist_attr_src, out_grad_dist_attr_src}, {x_dist_attr_src}};
46
+ }
47
size_t axis =
48
std::abs(static_cast<int>(out_grad.dims().size() - x.dims().size()));
49
std::vector<int64_t> x_grad_dims_mapping;
0 commit comments