Commit c5afa97

[Auto Parallel] Add spmd rule for topk and topk_grad ops
1 parent 441816a commit c5afa97

8 files changed: +209 -0 lines changed

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 6 additions & 0 deletions
@@ -705,6 +705,12 @@ PD_REGISTER_SPMD_RULE(
     PD_INFER_SPMD(phi::distributed::ArgMaxInferSpmdBase),
     PD_INFER_SPMD(phi::distributed::ArgMaxInferSpmdReverseBase));
 
+// topk
+PD_REGISTER_SPMD_RULE(
+    topk,
+    PD_INFER_SPMD(phi::distributed::TopkInferSpmd),
+    PD_INFER_SPMD(phi::distributed::TopkGradInferSpmd));
+
 // unbind
 PD_REGISTER_SPMD_RULE(unbind,
                       PD_INFER_SPMD(phi::distributed::UnbindInferSpmd),

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/squeeze.h"
 #include "paddle/phi/infermeta/spmd_rules/stack.h"
 #include "paddle/phi/infermeta/spmd_rules/tile.h"
+#include "paddle/phi/infermeta/spmd_rules/topk.h"
 #include "paddle/phi/infermeta/spmd_rules/transpose.h"
 #include "paddle/phi/infermeta/spmd_rules/triu.h"
 #include "paddle/phi/infermeta/spmd_rules/unbind.h"
paddle/phi/infermeta/spmd_rules/topk.cc

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/topk.h"

#include <limits>
#include <set>

#include "paddle/phi/infermeta/spmd_rules/elementwise.h"
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
#include "paddle/phi/infermeta/spmd_rules/stack.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {

SpmdInfo TopkInferSpmd(
    const DistMetaTensor& x, const Scalar& k, int axis, bool largest, bool sorted) {
  // Verify input args and normalize a negative axis.
  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  axis = axis < 0 ? x_ndim + axis : axis;

  // Infer the output dims mappings from the input dims mapping: the top-k
  // axis cannot stay sharded, so it is reset to -1 (replicated); out and
  // indices reuse the resulting mapping.
  std::vector<int64_t> x_dims_mapping_dst(x_dims_mapping_src);
  std::vector<int64_t> out_dims_mapping;
  std::vector<int64_t> indices_dims_mapping;
  x_dims_mapping_dst[axis] = -1;
  out_dims_mapping.assign(x_dims_mapping_dst.begin(), x_dims_mapping_dst.end());
  indices_dims_mapping = out_dims_mapping;

  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
  TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
  out_dist_attr.set_dims_mapping(out_dims_mapping);
  TensorDistAttr indices_dist_attr =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  indices_dist_attr.set_dims_mapping(indices_dims_mapping);

  return {{x_dist_attr_dst}, {out_dist_attr, indices_dist_attr}};
}

SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
                           const DistMetaTensor& indices,
                           const DistMetaTensor& out_grad,
                           Scalar k,
                           int axis,
                           bool largest,
                           bool sorted) {
  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  EXTRACT_SHAPE_AND_DIST_ATTR(indices);
  EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);

  // Propagate out_grad's dims mapping to x, indices and x_grad so that every
  // tensor in the backward pass shares the same distribution.
  TensorDistAttr out_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(out_grad_dist_attr_src);
  out_grad_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping_src);

  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping_src);

  TensorDistAttr indices_dist_attr_dst =
      CopyTensorDistAttrForOutput(indices_dist_attr_src);
  indices_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping_src);

  TensorDistAttr x_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_grad_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping_src);

  return {{x_dist_attr_dst, indices_dist_attr_dst, out_grad_dist_attr_dst},
          {x_grad_dist_attr_dst}};
}

}  // namespace distributed
}  // namespace phi
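
For reference, the dims-mapping propagation implemented above can be mirrored in a few lines of plain Python. This is an illustrative sketch only (the helper names are hypothetical and not part of this commit): the forward rule clears the sharding on the top-k axis and reuses that mapping for out and indices, while the backward rule copies out_grad's mapping onto x, indices, and x_grad.

# Illustrative sketch: mirrors the dims-mapping logic of TopkInferSpmd /
# TopkGradInferSpmd above. Helper names are hypothetical; -1 means "replicated".
from typing import List, Tuple


def topk_forward_mapping(
    x_dims_mapping: List[int], axis: int
) -> Tuple[List[int], List[int], List[int]]:
    """Return the (x_dst, out, indices) dims mappings."""
    ndim = len(x_dims_mapping)
    axis = axis + ndim if axis < 0 else axis
    x_dst = list(x_dims_mapping)
    x_dst[axis] = -1  # the top-k axis must be replicated
    return x_dst, list(x_dst), list(x_dst)


def topk_backward_mapping(
    out_grad_dims_mapping: List[int],
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Return (x, indices, out_grad, x_grad) dims mappings, all copied from out_grad."""
    m = list(out_grad_dims_mapping)
    return m, list(m), list(m), list(m)


if __name__ == "__main__":
    # Forward, matching the unit test below: [0, 1, -1], axis=1 -> [0, -1, -1] for all three.
    print(topk_forward_mapping([0, 1, -1], axis=1))
    # Backward: [0, -1, 1] is propagated unchanged to x, indices, out_grad and x_grad.
    print(topk_backward_mapping([0, -1, 1]))
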
paddle/phi/infermeta/spmd_rules/topk.h

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

SpmdInfo TopkInferSpmd(
    const DistMetaTensor& x, const Scalar& k, int axis, bool largest, bool sorted);

SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
                           const DistMetaTensor& indices,
                           const DistMetaTensor& out_grad,
                           Scalar k,
                           int axis,
                           bool largest,
                           bool sorted);

}  // namespace distributed
}  // namespace phi

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
@@ -3516,6 +3516,7 @@
   infer_meta :
     func : UnchangedInferMeta
     param : [x]
+    spmd_rule : TopkGradInferSpmd
   kernel :
     func : topk_grad
     data_type : out_grad

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
@@ -5128,6 +5128,7 @@
   output : Tensor(out), Tensor(indices)
   infer_meta :
     func : TopKInferMeta
+    spmd_rule : TopkInferSpmd
   kernel :
     func : topk
     data_type : x
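
With the two yaml bindings above, a distributed paddle.topk call is what ultimately exercises these rules. The snippet below is only a usage sketch, assuming the Paddle auto-parallel eager API (dist.ProcessMesh, dist.shard_tensor, dist.Shard/Replicate) and a 4-process launch; it is not part of this commit.

# Usage sketch (not part of this commit). Launch with 4 processes, e.g.
#   python -m paddle.distributed.launch --devices=0,1,2,3 topk_spmd_demo.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

x = paddle.randn([16, 16, 16])
# Shard dim 0 along the "dp" mesh axis and replicate along "mp",
# i.e. dims_mapping [0, -1, -1] in spmd-rule terms.
dist_x = dist.shard_tensor(x, mesh, [dist.Shard(0), dist.Replicate()])

# TopkInferSpmd keeps the dim-0 sharding and replicates the top-k axis,
# so values and indices also carry dims_mapping [0, -1, -1].
values, indices = paddle.topk(dist_x, k=2, axis=1)
print(values.shape, indices.shape)  # global shapes: [16, 2, 16] and [16, 2, 16]
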

test/auto_parallel/spmd_rules/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ if(WITH_DISTRIBUTE)
   py_test_modules(test_gather_rule MODULES test_gather_rule)
   py_test_modules(test_cumsum_rule MODULES test_cumsum_rule)
   py_test_modules(test_argmax_rule MODULES test_argmax_rule)
+  py_test_modules(test_topk_rule MODULES test_topk_rule)
   py_test_modules(test_unbind_rule MODULES test_unbind_rule)
   py_test_modules(test_stack_rule MODULES test_stack_rule)
   py_test_modules(test_gather_nd_rule MODULES test_gather_nd_rule)
test/auto_parallel/spmd_rules/test_topk_rule.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from collections import OrderedDict

from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto
from paddle.framework import core


class TestTopkSPMDRule(unittest.TestCase):
    def setUp(self):
        x_shape = [16, 16, 16]
        out_shape = [16, 2, 16]
        process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])

        x_tensor_dist_attr = TensorDistAttr()
        x_tensor_dist_attr.dims_mapping = [-1, -1, -1]
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)

        out_tensor_dist_attr = TensorDistAttr()
        out_tensor_dist_attr.dims_mapping = [-1, -1, -1]
        out_tensor_dist_attr.process_mesh = process_mesh
        self.out_dist_tensor_spec = DistTensorSpec(
            out_shape, out_tensor_dist_attr
        )

        self.rule = core.get_phi_spmd_rule("topk")
        self.attrs = OrderedDict()
        self.attrs['k'] = 2
        self.attrs['axis'] = 1
        self.attrs['largest'] = True
        self.attrs['sorted'] = True

    def test_topk_forward(self):
        # axis = 1
        # x: [0, 1, -1] --> x: [0, -1, -1], out/indices: [0, -1, -1]
        self.attrs['axis'] = 1
        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
        result_dist_attrs = self.rule.infer_forward(
            self.x_dist_tensor_spec,
            self.attrs['k'],
            self.attrs['axis'],
            self.attrs['largest'],
            self.attrs['sorted'],
        )
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(result_dist_attrs), 2)
        self.assertEqual(len(inferred_input_dist_attrs), 1)
        self.assertEqual(len(inferred_output_dist_attrs), 2)

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [0, -1, -1])
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [0, -1, -1]
        )
        self.assertEqual(
            inferred_output_dist_attrs[1].dims_mapping, [0, -1, -1]
        )

    def test_topk_backward(self):
        # axis = 1
        # out_grad: [0, -1, 1] --> x/indices/out_grad: [0, -1, 1], x_grad: [0, -1, 1]
        self.attrs['axis'] = 1
        self.out_dist_tensor_spec.shape = [16, 2, 16]
        self.out_dist_tensor_spec.set_dims_mapping([0, -1, 1])
        result_dist_attrs = self.rule.infer_backward(
            self.x_dist_tensor_spec,
            self.out_dist_tensor_spec,
            self.out_dist_tensor_spec,
            self.attrs['k'],
            self.attrs['axis'],
            self.attrs['largest'],
            self.attrs['sorted'],
        )
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(result_dist_attrs), 2)
        self.assertEqual(len(inferred_input_dist_attrs), 3)
        self.assertEqual(len(inferred_output_dist_attrs), 1)

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_input_dist_attrs[1].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_input_dist_attrs[2].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_output_dist_attrs[0].dims_mapping, [0, -1, 1])


if __name__ == "__main__":
    unittest.main()
