Skip to content

[spmd_rules] Add SPMD rules for ArgSort operator #72388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions paddle/phi/infermeta/spmd_rules/argsort.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/argsort.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi::distributed {

SpmdInfo ArgSortInferSpmd(const DistMetaTensor& x,
                          int axis,
                          bool descending,
                          bool stable) {
  // Forward SPMD rule for argsort.
  // Sorting requires the whole `axis` dimension to reside on one device, so
  // that dimension is forced to replicated (-1) on the input and on both
  // outputs; every other dimension keeps the input's sharding.
  auto x_shape = common::vectorize(x.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  auto x_dist_attr_src = x.dist_attr();
  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      x_ndim,
      static_cast<int>(x_dims_mapping.size()),  // avoid int/size_t compare
      errors::InvalidArgument(
          "ArgSort input rank [%d] should be equal to dims_mapping size [%d].",
          x_ndim,
          x_dims_mapping.size()));

  // Normalize a negative axis into [0, x_ndim).
  axis = axis < 0 ? axis + x_ndim : axis;

  PADDLE_ENFORCE_EQ(
      0 <= axis && axis < x_ndim,
      true,
      errors::InvalidArgument(
          "The axis of argsort should be in range [0, %d), but got %d.",
          x_ndim,
          axis));

  // Replicate the sorted axis; `out` and `indices` have x's shape, so all
  // three tensors share the same destination dims_mapping.
  std::vector<int64_t> x_dims_mapping_dst(x_dims_mapping);
  x_dims_mapping_dst[axis] = -1;
  auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto y_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto indices_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
  y_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
  indices_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);

  VLOG(4) << "ArgSortInferSpmd:" << std::endl;
  VLOG(4) << "x_dist_attr_src: " << x_dist_attr_src.to_string()
          << " x_dist_attr_dst: " << x_dist_attr_dst.to_string() << std::endl;
  VLOG(4) << "y_dist_attr_dst: " << y_dist_attr_dst.to_string() << std::endl;
  VLOG(4) << "indices_dist_attr_dst: " << indices_dist_attr_dst.to_string()
          << std::endl;

  return {{x_dist_attr_dst}, {y_dist_attr_dst, indices_dist_attr_dst}};
}

SpmdInfo ArgSortGradInferSpmd(const DistMetaTensor& indices,
                              const DistMetaTensor& x,
                              const DistMetaTensor& out_grad,
                              int axis,
                              bool descending,
                              bool stable) {
  // Backward SPMD rule for argsort.
  // All inputs (indices, x, out_grad) and the output x_grad share x's shape;
  // as in the forward rule, the sorted axis is forced to replicated (-1) and
  // every other dimension keeps x's sharding.

  // step 0: check invalidation of parameters
  auto x_shape = common::vectorize(x.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  auto x_dist_attr_src = x.dist_attr();
  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      x_ndim,
      static_cast<int>(x_dims_mapping.size()),
      errors::InvalidArgument("ArgSortGrad input x rank [%d] should be equal "
                              "to dims_mapping size [%d].",
                              x_ndim,
                              x_dims_mapping.size()));

  auto ind_shape = common::vectorize(indices.dims());
  int ind_ndim = static_cast<int>(ind_shape.size());
  auto ind_dist_attr_src = indices.dist_attr();
  std::vector<int64_t> ind_dims_mapping = ind_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      ind_ndim,
      static_cast<int>(ind_dims_mapping.size()),
      errors::InvalidArgument("ArgSortGrad indices rank [%d] should be equal "
                              "to dims_mapping size [%d].",
                              ind_ndim,
                              ind_dims_mapping.size()));

  auto out_grad_shape = common::vectorize(out_grad.dims());
  int out_grad_ndim = static_cast<int>(out_grad_shape.size());
  auto out_grad_dist_attr_src = out_grad.dist_attr();
  std::vector<int64_t> out_grad_dims_mapping =
      out_grad_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      out_grad_ndim,
      static_cast<int>(out_grad_dims_mapping.size()),
      errors::InvalidArgument("ArgSortGrad out_grad rank [%d] should be equal "
                              "to dims_mapping size [%d].",
                              out_grad_ndim,
                              out_grad_dims_mapping.size()));

  // All three inputs must have the same rank.
  PADDLE_ENFORCE_EQ(
      x_ndim == ind_ndim && x_ndim == out_grad_ndim,
      true,
      errors::InvalidArgument("ArgSortGrad x rank [%d] should be equal to "
                              "indices rank [%d] and out_grad rank [%d].",
                              x_ndim,
                              ind_ndim,
                              out_grad_ndim));

  // x and indices must already agree dimension-by-dimension.
  for (int i = 0; i < x_ndim; ++i) {
    PADDLE_ENFORCE_EQ(
        x_dims_mapping[i] == ind_dims_mapping[i],
        true,
        errors::InvalidArgument("ArgSortGrad x dims_mapping[%d]=[%d] should be "
                                "equal to indices dims_mapping[%d]=[%d].",
                                i,
                                x_dims_mapping[i],
                                i,
                                ind_dims_mapping[i]));
  }

  // Normalize a negative axis into [0, x_ndim).
  axis = axis < 0 ? axis + x_ndim : axis;

  PADDLE_ENFORCE_EQ(
      0 <= axis && axis < x_ndim,
      true,
      errors::InvalidArgument(
          "The axis of argsort should be in range [0, %d), but got %d.",
          x_ndim,
          axis));

  // step 1: infer spmd info — replicate the sorted axis everywhere.
  std::vector<int64_t> x_dims_mapping_dst(x_dims_mapping);
  x_dims_mapping_dst[axis] = -1;
  std::vector<int64_t> out_grad_dims_mapping_dst(x_dims_mapping_dst);
  std::vector<int64_t> indices_dims_mapping_dst(x_dims_mapping_dst);
  std::vector<int64_t> x_grad_dims_mapping_dst(x_dims_mapping_dst);

  auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto out_grad_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto indices_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  auto x_grad_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);

  x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
  out_grad_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping_dst);
  indices_dist_attr_dst.set_dims_mapping(indices_dims_mapping_dst);
  x_grad_dist_attr_dst.set_dims_mapping(x_grad_dims_mapping_dst);

  VLOG(4) << "ArgSortGradInferSpmd:" << std::endl;
  VLOG(4) << "indices_dist_attr_src: " << ind_dist_attr_src.to_string()
          << " indices_dist_attr_dst: " << indices_dist_attr_dst.to_string()
          << std::endl;
  VLOG(4) << "x_dist_attr_src: " << x_dist_attr_src.to_string()
          << " x_dist_attr_dst: " << x_dist_attr_dst.to_string() << std::endl;
  // NOTE: was logging the *dst* attr twice; log the real src attr here.
  VLOG(4) << "out_grad_dist_attr_src: " << out_grad_dist_attr_src.to_string()
          << " out_grad_dist_attr_dst: " << out_grad_dist_attr_dst.to_string()
          << std::endl;
  VLOG(4) << "x_grad_dist_attr_dst: " << x_grad_dist_attr_dst.to_string()
          << std::endl;
  return {{indices_dist_attr_dst, x_dist_attr_dst, out_grad_dist_attr_dst},
          {x_grad_dist_attr_dst}};
}

} // namespace phi::distributed
37 changes: 37 additions & 0 deletions paddle/phi/infermeta/spmd_rules/argsort.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

// Infers the SPMD (auto-parallel) distributed attributes for the forward
// argsort op: input `x`, attributes (axis, descending, stable), and two
// outputs (sorted values, indices). The sorted axis is replicated across
// devices in the returned dist attrs.
SpmdInfo ArgSortInferSpmd(const DistMetaTensor& x,
                          int axis,
                          bool descending,
                          bool stable);

// Infers the SPMD distributed attributes for the argsort backward op:
// inputs `indices`, `x`, `out_grad` (all sharing x's rank) and one output
// `x_grad`. Validates that x and indices carry matching dims mappings.
SpmdInfo ArgSortGradInferSpmd(const DistMetaTensor& indices,
                              const DistMetaTensor& x,
                              const DistMetaTensor& out_grad,
                              int axis,
                              bool descending,
                              bool stable);

}  // namespace distributed
}  // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,4 +752,9 @@ PD_REGISTER_SPMD_RULE(nonzero,

// add_n
PD_REGISTER_SPMD_RULE(add_n, PD_INFER_SPMD(phi::distributed::AddNInferSpmd));

// argsort: register forward and backward SPMD inference rules
PD_REGISTER_SPMD_RULE(argsort,
                      PD_INFER_SPMD(phi::distributed::ArgSortInferSpmd),
                      PD_INFER_SPMD(phi::distributed::ArgSortGradInferSpmd));
} // namespace phi::distributed
1 change: 1 addition & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/spmd_rules/amp_ops.h"
#include "paddle/phi/infermeta/spmd_rules/argmax.h"
#include "paddle/phi/infermeta/spmd_rules/argmin.h"
#include "paddle/phi/infermeta/spmd_rules/argsort.h"
#include "paddle/phi/infermeta/spmd_rules/c_embedding.h"
#include "paddle/phi/infermeta/spmd_rules/c_softmax_with_cross_entropy.h"
#include "paddle/phi/infermeta/spmd_rules/c_softmax_with_multi_label_cross_entropy.h"
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
spmd_rule : ArgSortGradInferSpmd
param : [x]
kernel :
func : argsort_grad
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@
output : Tensor(out), Tensor(indices)
infer_meta :
func : ArgsortInferMeta
spmd_rule : ArgSortInferSpmd
kernel :
func : argsort
backward : argsort_grad
Expand Down
1 change: 1 addition & 0 deletions test/auto_parallel/spmd_rules/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_add_n_rule MODULES test_add_n_rule)
py_test_modules(test_mean_all_rule MODULES test_mean_all_rule)
py_test_modules(test_argmin_rule MODULES test_argmin_rule)
py_test_modules(test_argsort_rule MODULES test_argsort_rule)
endif()
# End of unittests WITH single card WITHOUT timeout

Expand Down
Loading
Loading