Skip to content

Commit 7942f46

Browse files
authored
[AutoParallel] Add expand spmd (#71603)
1 parent 0a8a0be commit 7942f46

File tree

7 files changed

+179
-0
lines changed

7 files changed

+179
-0
lines changed
+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/phi/infermeta/spmd_rules/expand.h"
13+
14+
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
15+
#include "paddle/phi/infermeta/spmd_rules/utils.h"
16+
17+
namespace phi::distributed {
18+
19+
SpmdInfo ExpandInferSpmd(const DistMetaTensor& x, const IntArray& shape) {
  // Infer the SPMD (sharding) attributes of expand's output.
  // Rules:
  //   * Axes whose target size differs from x's size are truly broadcast and
  //     cannot remain sharded -> mapping -1 (replicated).
  //   * An entry of -1 in `shape` keeps x's size, so the axis inherits x's
  //     dims mapping.
  //   * Newly inserted leading axes have no counterpart in x -> -1.
  // The input's dist attr is returned unchanged (no reshard required).
  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  const std::vector<int64_t>& expand_shape = shape.GetData();
  std::vector<int64_t> out_dims_mapping(expand_shape.size());
  // Number of new leading dimensions prepended by this expand. Cast both
  // operands explicitly: subtracting size_t values and narrowing the result
  // into int relies on unsigned wraparound semantics if ranks ever disagree.
  const int x_rank = static_cast<int>(x_shape.size());
  const int out_rank = static_cast<int>(expand_shape.size());
  const int diff = out_rank - x_rank;
  for (int i = out_rank - 1; i >= diff; --i) {
    if (expand_shape[i] != -1 && expand_shape[i] != x_shape[i - diff]) {
      // Axis is broadcast to a new size: must be replicated.
      out_dims_mapping[i] = -1;
    } else {
      // Size unchanged (or -1 placeholder): sharding carries over from x.
      out_dims_mapping[i] = x_dims_mapping_src[i - diff];
    }
  }
  // Leading axes introduced by the expand are replicated.
  for (int i = 0; i < diff; ++i) {
    out_dims_mapping[i] = -1;
  }
  TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
  out_dist_attr.set_dims_mapping(out_dims_mapping);
  return {{x_dist_attr_src}, {out_dist_attr}};
}
38+
39+
SpmdInfo ExpandGradInferSpmd(const DistMetaTensor& x,
                             const DistMetaTensor& out_grad,
                             const IntArray& shape) {
  // Infer the SPMD (sharding) attributes of expand_grad. x_grad always has
  // the same shape (and rank) as x.
  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);
  if (x_shape.size() == out_grad_shape.size()) {
    // Same rank: no leading axes were inserted, so x_grad simply follows
    // x's placement.
    return {{x_dist_attr_src, out_grad_dist_attr_src}, {x_dist_attr_src}};
  }
  // Number of leading axes present in out_grad but absent from x; the
  // gradient is reduced over them. DDim::size() is signed, so the abs() is
  // purely defensive (out_grad's rank can never be smaller than x's here).
  const size_t axis = static_cast<size_t>(
      std::abs(static_cast<int>(out_grad.dims().size() - x.dims().size())));
  std::vector<int64_t> x_grad_dims_mapping;
  x_grad_dims_mapping.reserve(out_grad_dims_mapping_src.size() - axis);
  for (size_t i = axis; i < out_grad_dims_mapping_src.size(); ++i) {
    if (out_grad.dims()[i] != x.dims()[i - axis]) {
      // Broadcast axis (x's size-1 axis was expanded): the gradient is
      // summed along it, so the corresponding x_grad axis is replicated.
      // Previously this axis was skipped entirely, which produced a
      // dims_mapping shorter than x's rank — an invalid dist attr.
      x_grad_dims_mapping.push_back(-1);
    } else {
      // Size matches: x_grad's axis inherits out_grad's sharding.
      x_grad_dims_mapping.push_back(out_grad_dims_mapping_src[i]);
    }
  }
  TensorDistAttr x_grad_dist_attr =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_grad_dist_attr.set_dims_mapping(x_grad_dims_mapping);
  return {{x_dist_attr_src, out_grad_dist_attr_src}, {x_grad_dist_attr}};
}
62+
63+
} // namespace phi::distributed
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/common/int_array.h"
18+
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
19+
#include "paddle/phi/core/distributed/type_defs.h"
20+
21+
namespace phi {
namespace distributed {

// Infers the sharding (dims mapping) of expand's output from the input
// tensor's dist attr and the target `shape`. Axes whose size actually
// changes, and newly inserted leading axes, are replicated (-1); axes whose
// size is kept (including `-1` entries in `shape`) inherit x's mapping.
SpmdInfo ExpandInferSpmd(const DistMetaTensor& x, const IntArray& shape);

// Infers the sharding of expand's gradient. x_grad has the same shape as x;
// when out_grad has extra leading axes they are reduced away and the
// remaining axes derive their mapping from out_grad.
SpmdInfo ExpandGradInferSpmd(const DistMetaTensor& x,
                             const DistMetaTensor& out_grad,
                             const IntArray& shape);

}  // namespace distributed
}  // namespace phi

paddle/phi/infermeta/spmd_rules/rules.h

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License. */
3131
#include "paddle/phi/infermeta/spmd_rules/dropout.h"
3232
#include "paddle/phi/infermeta/spmd_rules/elementwise.h"
3333
#include "paddle/phi/infermeta/spmd_rules/embedding.h"
34+
#include "paddle/phi/infermeta/spmd_rules/expand.h"
3435
#include "paddle/phi/infermeta/spmd_rules/expand_as.h"
3536
#include "paddle/phi/infermeta/spmd_rules/flash_attention.h"
3637
#include "paddle/phi/infermeta/spmd_rules/flatten.h"

paddle/phi/ops/yaml/backward.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,7 @@
997997
infer_meta :
998998
func : UnchangedInferMeta
999999
param : [x]
1000+
spmd_rule : ExpandGradInferSpmd
10001001
kernel :
10011002
func : expand_grad
10021003
data_type : out_grad

paddle/phi/ops/yaml/ops.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -1691,6 +1691,7 @@
16911691
infer_meta :
16921692
func : ExpandInferMeta
16931693
local_shape: out
1694+
spmd_rule : ExpandInferSpmd
16941695
kernel :
16951696
func : expand
16961697
data_type : x

test/cpp/auto_parallel/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ if(WITH_DISTRIBUTE)
3434
cross_entropy_softmax_spmd_rule_test SRCS
3535
cross_entropy_softmax_spmd_rule_test.cc DEPS spmd_rule_test_util phi)
3636

37+
paddle_test(expand_spmd_rule_test SRCS expand_spmd_rule_test.cc DEPS
38+
spmd_rule_test_util phi)
39+
3740
paddle_test(expand_as_spmd_rule_test SRCS expand_as_spmd_rule_test.cc DEPS
3841
spmd_rule_test_util phi)
3942

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "glog/logging.h"
16+
#include "test/cpp/auto_parallel/spmd_rule_test_util.h"
17+
namespace paddle {
18+
namespace distributed {
19+
namespace auto_parallel {
20+
21+
// Builds the 2x3 process mesh (axes "x" and "y", ranks 0..5) shared by all
// cases in this test.
ProcessMesh CreateProcessMesh() {
  const std::vector<int64_t> shape{2, 3};
  const std::vector<int64_t> ranks{0, 1, 2, 3, 4, 5};
  const std::vector<std::string> axis_names{"x", "y"};
  return ProcessMesh(shape, ranks, axis_names);
}
27+
28+
// Convenience helper: wraps a tensor shape and dims mapping into a
// DistMetaTensor bound to the given process mesh.
phi::distributed::DistMetaTensor CreateDistMetaTensor(
    const std::vector<int64_t>& shape,
    const std::vector<int64_t>& dims_mapping,
    const ProcessMesh& process_mesh) {
  TensorDistAttr attr;
  attr.set_process_mesh(process_mesh);
  attr.set_dims_mapping(dims_mapping);
  return phi::distributed::DistMetaTensor(phi::make_ddim(shape), attr);
}
37+
38+
TEST(ExpandInferSpmd, Ctor) {
  ProcessMesh process_mesh = CreateProcessMesh();

  // Test case forward 1: Expand with shape {8, 2, 6, 1024, -1}
  // Only axis 2 is truly broadcast (1 -> 6); it is already replicated (-1),
  // so input and output mappings coincide. -1 in the target shape keeps the
  // input size, so sharded axes 0 and 3 carry over.
  auto x = CreateDistMetaTensor(
      {8, 2, 1, 1024, 128}, {0, -1, -1, 1, -1}, process_mesh);
  phi::IntArray shape = {8, 2, 6, 1024, -1};
  auto spmdinfo = ExpandInferSpmd(x, shape);
  EXPECT_EQ(get_dims_mapping(spmdinfo.first[0]),
            std::vector<int64_t>({0, -1, -1, 1, -1}));
  EXPECT_EQ(get_dims_mapping(spmdinfo.second[0]),
            std::vector<int64_t>({0, -1, -1, 1, -1}));

  // Test case forward 2: Expand with shape {2, -1}
  // A new leading axis is inserted -> replicated (-1); the existing axis
  // keeps its sharding on mesh dim 1.
  auto x1 = CreateDistMetaTensor({8}, {1}, process_mesh);
  phi::IntArray shape1 = {2, -1};
  auto spmdinfo1 = ExpandInferSpmd(x1, shape1);
  EXPECT_EQ(get_dims_mapping(spmdinfo1.first[0]), std::vector<int64_t>({1}));
  EXPECT_EQ(get_dims_mapping(spmdinfo1.second[0]),
            std::vector<int64_t>({-1, 1}));

  // Test case forward 3: Expand with shape {0, -1}
  // Same layout expectation as case 2: the new leading axis is replicated
  // regardless of the size value used for it.
  auto x2 = CreateDistMetaTensor({8}, {1}, process_mesh);
  phi::IntArray shape2 = {0, -1};
  auto spmdinfo2 = ExpandInferSpmd(x2, shape2);
  EXPECT_EQ(get_dims_mapping(spmdinfo2.first[0]), std::vector<int64_t>({1}));
  EXPECT_EQ(get_dims_mapping(spmdinfo2.second[0]),
            std::vector<int64_t>({-1, 1}));

  // Test case backward 1: ExpandGrad with shape {0, -1}
  // out_grad has one extra leading axis; it is dropped and x_grad's single
  // axis inherits out_grad's mapping (1), matching x's own placement.
  auto x3 = CreateDistMetaTensor({8}, {1}, process_mesh);
  auto out3 = CreateDistMetaTensor({2, 8}, {-1, 1}, process_mesh);
  phi::IntArray shape3 = {0, -1};
  auto spmdinfo3 = ExpandGradInferSpmd(x3, out3, shape3);
  EXPECT_EQ(get_dims_mapping(spmdinfo3.first[0]), std::vector<int64_t>({1}));
  EXPECT_EQ(get_dims_mapping(spmdinfo3.first[1]),
            std::vector<int64_t>({-1, 1}));
  EXPECT_EQ(get_dims_mapping(spmdinfo3.second[0]), std::vector<int64_t>({1}));
}
77+
78+
} // namespace auto_parallel
79+
} // namespace distributed
80+
} // namespace paddle

0 commit comments

Comments
 (0)