
Commit b4026b7

[Auto Parallel] Add SPMD rules for ArgSort operator
1 parent b724233 commit b4026b7
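
The forward rule keeps the input's sharding on every dimension except the one being sorted: a global sort order along a sharded axis cannot be produced locally, so dims_mapping[axis] is reset to -1 (replicated) for the input, the sorted values, and the indices. The backward rule enforces matching mappings across indices, x, and out_grad and applies the same replication when inferring x_grad.

Below is a minimal, hedged usage sketch of the new forward rule. It is not part of this commit; the 2x2 mesh, the [8, 16] shape, and the initial [0, -1] mapping are illustrative assumptions.

#include "paddle/common/ddim.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/infermeta/spmd_rules/argsort.h"

// Forward-rule sketch; mesh, shape, and mapping values are assumptions.
void ArgSortSpmdSketch() {
  using phi::distributed::DistMetaTensor;
  using phi::distributed::ProcessMesh;
  using phi::distributed::TensorDistAttr;

  // Assumed 2x2 mesh over ranks 0..3.
  ProcessMesh mesh({2, 2}, {0, 1, 2, 3}, {"dp", "mp"});

  // Input of shape [8, 16], sharded along dim 0, replicated along dim 1.
  TensorDistAttr x_attr;
  x_attr.set_process_mesh(mesh);
  x_attr.set_dims_mapping({0, -1});
  DistMetaTensor x(common::make_ddim({8, 16}), x_attr);

  // Sort along axis 1: dim 0 may stay sharded, so the inferred
  // input/output dims_mapping remain [0, -1].
  auto keep_shard = phi::distributed::ArgSortInferSpmd(
      x, /*axis=*/1, /*descending=*/false, /*stable=*/false);

  // Sort along axis 0: the sorted axis is replicated, so the inferred
  // dims_mapping become [-1, -1].
  auto replicate_axis = phi::distributed::ArgSortInferSpmd(
      x, /*axis=*/0, /*descending=*/false, /*stable=*/false);

  (void)keep_shard;
  (void)replicate_axis;
}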

File tree

9 files changed, +595 -0 lines changed

paddle/phi/infermeta/spmd_rules/argsort.cc

+179

@@ -0,0 +1,179 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/argsort.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi::distributed {

SpmdInfo ArgSortInferSpmd(const DistMetaTensor& x,
                          int axis,
                          bool descending,
                          bool stable) {
  auto x_shape = common::vectorize(x.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  auto x_dist_attr_src = x.dist_attr();
  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      x_ndim,
      x_dims_mapping.size(),
      errors::InvalidArgument(
          "ArgSort input rank [%d] should be equal to dims_mapping size [%d].",
          x_ndim,
          x_dims_mapping.size()));

  axis = axis < 0 ? axis + x_ndim : axis;

  PADDLE_ENFORCE_EQ(
      0 <= axis && axis < x_ndim,
      true,
      phi::errors::InvalidArgument(
          "The axis of argsort should be in range [0, %d), but got %d.",
          x_ndim,
          axis));

  // The sorted axis cannot remain sharded: reset its mapping to -1
  // (replicated) for the input and both outputs.
  std::vector<int64_t> x_dims_mapping_dst(x_dims_mapping);
  x_dims_mapping_dst[axis] = -1;
  std::vector<int64_t> y_dims_mapping_dst(x_dims_mapping_dst);
  std::vector<int64_t> indices_dims_mapping_dst(x_dims_mapping_dst);
  auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto y_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto indices_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
  y_dist_attr_dst.set_dims_mapping(y_dims_mapping_dst);
  indices_dist_attr_dst.set_dims_mapping(indices_dims_mapping_dst);

  VLOG(4) << "ArgSortInferSpmd:" << std::endl;
  VLOG(4) << "x_dist_attr_src: " << x_dist_attr_src.to_string()
          << " x_dist_attr_dst: " << x_dist_attr_dst.to_string() << std::endl;
  VLOG(4) << "y_dist_attr_dst: " << y_dist_attr_dst.to_string() << std::endl;

  return {{x_dist_attr_dst}, {y_dist_attr_dst, indices_dist_attr_dst}};
}

SpmdInfo ArgSortGradInferSpmd(const DistMetaTensor& indices,
                              const DistMetaTensor& x,
                              const DistMetaTensor& out_grad,
                              int axis,
                              bool descending,
                              bool stable) {
  // step 0: validate the input parameters
  auto x_shape = common::vectorize(x.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  auto x_dist_attr_src = x.dist_attr();
  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      x_ndim,
      x_dims_mapping.size(),
      errors::InvalidArgument("ArgSortGrad input x rank [%d] should be equal "
                              "to dims_mapping size [%d].",
                              x_ndim,
                              x_dims_mapping.size()));

  auto ind_shape = common::vectorize(indices.dims());
  int ind_ndim = static_cast<int>(ind_shape.size());
  auto ind_dist_attr_src = indices.dist_attr();
  std::vector<int64_t> ind_dims_mapping = ind_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      ind_ndim,
      ind_dims_mapping.size(),
      errors::InvalidArgument("ArgSortGrad indices rank [%d] should be equal "
                              "to dims_mapping size [%d].",
                              ind_ndim,
                              ind_dims_mapping.size()));

  auto out_grad_shape = common::vectorize(out_grad.dims());
  int out_grad_ndim = static_cast<int>(out_grad_shape.size());
  auto out_grad_dist_attr_src = out_grad.dist_attr();
  std::vector<int64_t> out_grad_dims_mapping =
      out_grad_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      out_grad_ndim,
      out_grad_dims_mapping.size(),
      errors::InvalidArgument("ArgSortGrad out_grad rank [%d] should be equal "
                              "to dims_mapping size [%d].",
                              out_grad_ndim,
                              out_grad_dims_mapping.size()));

  PADDLE_ENFORCE_EQ(
      x_ndim == ind_ndim && x_ndim == out_grad_ndim,
      1,
      errors::InvalidArgument("ArgSortGrad x rank [%d] should be equal to "
                              "indices rank [%d] and out_grad rank [%d].",
                              x_ndim,
                              ind_ndim,
                              out_grad_ndim));

  for (int i = 0; i < x_ndim; ++i) {
    PADDLE_ENFORCE_EQ(
        x_dims_mapping[i] == ind_dims_mapping[i],
        1,
        errors::InvalidArgument("ArgSortGrad x dims_mapping[%d]=[%d] should be "
                                "equal to indices dims_mapping[%d]=[%d].",
                                i,
                                x_dims_mapping[i],
                                i,
                                ind_dims_mapping[i]));
  }

  axis = axis < 0 ? axis + x_ndim : axis;

  PADDLE_ENFORCE_EQ(
      0 <= axis && axis < x_ndim,
      true,
      phi::errors::InvalidArgument(
          "The axis of argsort should be in range [0, %d), but got %d.",
          x_ndim,
          axis));

  // step 1: infer spmd info (replicate the sorted axis for all tensors)
  std::vector<int64_t> x_dims_mapping_dst(x_dims_mapping);
  x_dims_mapping_dst[axis] = -1;
  std::vector<int64_t> out_grad_dims_mapping_dst(x_dims_mapping_dst);
  std::vector<int64_t> indices_dims_mapping_dst(x_dims_mapping_dst);
  std::vector<int64_t> x_grad_dims_mapping_dst(x_dims_mapping_dst);

  auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto out_grad_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto indices_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto x_grad_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);

  x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
  out_grad_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping_dst);
  indices_dist_attr_dst.set_dims_mapping(indices_dims_mapping_dst);
  x_grad_dist_attr_dst.set_dims_mapping(x_grad_dims_mapping_dst);

  VLOG(4) << "ArgSortGradInferSpmd:" << std::endl;
  VLOG(4) << "indices_dist_attr_src: " << ind_dist_attr_src.to_string()
          << " indices_dist_attr_dst: " << indices_dist_attr_dst.to_string()
          << std::endl;
  VLOG(4) << "x_dist_attr_src: " << x_dist_attr_src.to_string()
          << " x_dist_attr_dst: " << x_dist_attr_dst.to_string() << std::endl;
  VLOG(4) << "out_grad_dist_attr_src: " << out_grad_dist_attr_src.to_string()
          << " out_grad_dist_attr_dst: " << out_grad_dist_attr_dst.to_string()
          << std::endl;
  VLOG(4) << "x_grad_dist_attr_dst: " << x_grad_dist_attr_dst.to_string()
          << std::endl;
  return {{indices_dist_attr_dst, x_dist_attr_dst, out_grad_dist_attr_dst},
          {x_grad_dist_attr_dst}};
}

}  // namespace phi::distributed
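
A matching sketch for the backward rule, under the same assumptions (and includes) as the forward sketch above and likewise not part of this commit: indices, x, and out_grad must carry identical dims_mapping, and the inferred x_grad replicates the sort axis.

// Backward-rule sketch; reuses the includes from the forward sketch above.
void ArgSortGradSpmdSketch() {
  using phi::distributed::DistMetaTensor;
  using phi::distributed::ProcessMesh;
  using phi::distributed::TensorDistAttr;

  // Assumed 2x2 mesh; all three inputs carry the same [0, -1] mapping,
  // which the rule's per-dimension checks require.
  ProcessMesh mesh({2, 2}, {0, 1, 2, 3}, {"dp", "mp"});
  TensorDistAttr attr;
  attr.set_process_mesh(mesh);
  attr.set_dims_mapping({0, -1});

  DistMetaTensor indices(common::make_ddim({8, 16}), attr);
  DistMetaTensor x(common::make_ddim({8, 16}), attr);
  DistMetaTensor out_grad(common::make_ddim({8, 16}), attr);

  // Sorting along axis 1 leaves dim 0 sharded, so the inferred x_grad
  // mapping stays [0, -1]; along axis 0 it would become [-1, -1].
  auto grad_info = phi::distributed::ArgSortGradInferSpmd(
      indices, x, out_grad, /*axis=*/1,
      /*descending=*/false, /*stable=*/false);
  (void)grad_info;
}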
paddle/phi/infermeta/spmd_rules/argsort.h

+37

@@ -0,0 +1,37 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

SpmdInfo ArgSortInferSpmd(const DistMetaTensor& x,
                          int axis,
                          bool descending,
                          bool stable);

SpmdInfo ArgSortGradInferSpmd(const DistMetaTensor& indices,
                              const DistMetaTensor& x,
                              const DistMetaTensor& out_grad,
                              int axis,
                              bool descending,
                              bool stable);

}  // namespace distributed
}  // namespace phi

paddle/phi/infermeta/spmd_rules/rules.cc

+5

@@ -733,4 +733,9 @@ PD_REGISTER_SPMD_RULE(nonzero,
 
 // add_n
 PD_REGISTER_SPMD_RULE(add_n, PD_INFER_SPMD(phi::distributed::AddNInferSpmd));
+
+// argsort
+PD_REGISTER_SPMD_RULE(argsort,
+                      PD_INFER_SPMD(phi::distributed::ArgSortInferSpmd),
+                      PD_INFER_SPMD(phi::distributed::ArgSortGradInferSpmd));
 }  // namespace phi::distributed

paddle/phi/infermeta/spmd_rules/rules.h

+1

@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/add_n.h"
 #include "paddle/phi/infermeta/spmd_rules/amp_ops.h"
 #include "paddle/phi/infermeta/spmd_rules/argmax.h"
+#include "paddle/phi/infermeta/spmd_rules/argsort.h"
 #include "paddle/phi/infermeta/spmd_rules/c_embedding.h"
 #include "paddle/phi/infermeta/spmd_rules/c_softmax_with_cross_entropy.h"
 #include "paddle/phi/infermeta/spmd_rules/c_softmax_with_multi_label_cross_entropy.h"

paddle/phi/ops/yaml/backward.yaml

+1

@@ -154,6 +154,7 @@
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
+    spmd_rule : ArgSortGradInferSpmd
     param : [x]
   kernel :
     func : argsort_grad

paddle/phi/ops/yaml/ops.yaml

+1

@@ -316,6 +316,7 @@
   output : Tensor(out), Tensor(indices)
   infer_meta :
     func : ArgsortInferMeta
+    spmd_rule : ArgSortInferSpmd
   kernel :
     func : argsort
   backward : argsort_grad

test/auto_parallel/spmd_rules/CMakeLists.txt

+1

@@ -45,6 +45,7 @@ if(WITH_DISTRIBUTE)
   py_test_modules(test_nonzero_rule MODULES test_nonzero_rule)
   if(NOT WITH_ROCM)
     py_test_modules(test_add_n_rule MODULES test_add_n_rule)
+    py_test_modules(test_argsort_rule MODULES test_argsort_rule)
   endif()
   # End of unittests WITH single card WITHOUT timeout
