Skip to content

Commit f678ec6

Browse files
committed
split the general func for topk, cummax and cummin;add spmd_rule for cummax and cummin.
1 parent e14eecb commit f678ec6

File tree

9 files changed

+173
-10
lines changed

9 files changed

+173
-10
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/infermeta/spmd_rules/cummax.h"
16+
#include "paddle/phi/infermeta/spmd_rules/topk.h"
17+
18+
namespace phi {
19+
namespace distributed {
20+
21+
SpmdInfo CummaxInferSpmd(const DistMetaTensor& x, int axis, DataType dtype) {
22+
return TopkInferSpmdBase(x, axis);
23+
}
24+
25+
SpmdInfo CummaxGradInferSpmd(const DistMetaTensor& x,
26+
const DistMetaTensor& indices,
27+
const DistMetaTensor& out_grad,
28+
int axis,
29+
DataType dtype) {
30+
return TopkGradInferSpmdBase(x, indices, out_grad, axis);
31+
}
32+
33+
} // namespace distributed
34+
} // namespace phi
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
18+
#include "paddle/phi/core/distributed/type_defs.h"
19+
20+
namespace phi {
21+
namespace distributed {
22+
23+
SpmdInfo CummaxInferSpmd(const DistMetaTensor& x, int axis, DataType dtype);
24+
25+
SpmdInfo CummaxGradInferSpmd(const DistMetaTensor& x,
26+
const DistMetaTensor& indices,
27+
const DistMetaTensor& out_grad,
28+
int axis,
29+
DataType dtype);
30+
31+
} // namespace distributed
32+
} // namespace phi
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/infermeta/spmd_rules/cummin.h"
16+
#include "paddle/phi/infermeta/spmd_rules/topk.h"
17+
18+
namespace phi {
19+
namespace distributed {
20+
21+
SpmdInfo CumminInferSpmd(const DistMetaTensor& x, int axis, DataType dtype) {
22+
return TopkInferSpmdBase(x, axis);
23+
}
24+
25+
SpmdInfo CumminGradInferSpmd(const DistMetaTensor& x,
26+
const DistMetaTensor& indices,
27+
const DistMetaTensor& out_grad,
28+
int axis,
29+
DataType dtype) {
30+
return TopkGradInferSpmdBase(x, indices, out_grad, axis);
31+
}
32+
33+
} // namespace distributed
34+
} // namespace phi
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
18+
#include "paddle/phi/core/distributed/type_defs.h"
19+
20+
namespace phi {
21+
namespace distributed {
22+
23+
SpmdInfo CumminInferSpmd(const DistMetaTensor& x, int axis, DataType dtype);
24+
25+
SpmdInfo CumminGradInferSpmd(const DistMetaTensor& x,
26+
const DistMetaTensor& indices,
27+
const DistMetaTensor& out_grad,
28+
int axis,
29+
DataType dtype);
30+
31+
} // namespace distributed
32+
} // namespace phi

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/phi/infermeta/spmd_rules/rules.h"
16+
#include "paddle/phi/infermeta/spmd_rules/topk.h"
1617

1718
/**
1819
* Design Notes:
@@ -752,4 +753,14 @@ PD_REGISTER_SPMD_RULE(nonzero,
752753

753754
// add_n
754755
PD_REGISTER_SPMD_RULE(add_n, PD_INFER_SPMD(phi::distributed::AddNInferSpmd));
756+
757+
// cummax
758+
PD_REGISTER_SPMD_RULE(cummax,
759+
PD_INFER_SPMD(phi::distributed::TopkInferSpmdBase),
760+
PD_INFER_SPMD(phi::distributed::TopkGradInferSpmdBase));
761+
// cummin
762+
PD_REGISTER_SPMD_RULE(cummin,
763+
PD_INFER_SPMD(phi::distributed::TopkInferSpmdBase),
764+
PD_INFER_SPMD(phi::distributed::TopkGradInferSpmdBase));
765+
755766
} // namespace phi::distributed

paddle/phi/infermeta/spmd_rules/topk.cc

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ limitations under the License. */
2020
namespace phi {
2121
namespace distributed {
2222

23-
SpmdInfo TopkInferSpmd(
24-
const DistMetaTensor& x, int k, int axis, bool largest, bool sorted) {
23+
SpmdInfo TopkInferSpmdBase(const DistMetaTensor& x, int axis) {
2524
// Verify input args
2625
EXTRACT_SHAPE_AND_DIST_ATTR(x);
2726
axis = axis < 0 ? x_ndim + axis : axis;
@@ -60,13 +59,10 @@ SpmdInfo TopkInferSpmd(
6059
return {{x_dist_attr_dst}, {out_dist_attr_dst, indices_dist_attr_dst}};
6160
}
6261

63-
SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
64-
const DistMetaTensor& indices,
65-
const DistMetaTensor& out_grad,
66-
int k,
67-
int axis,
68-
bool largest,
69-
bool sorted) {
62+
SpmdInfo TopkGradInferSpmdBase(const DistMetaTensor& x,
63+
const DistMetaTensor& indices,
64+
const DistMetaTensor& out_grad,
65+
int axis) {
7066
// Verify input args
7167
EXTRACT_SHAPE_AND_DIST_ATTR(x);
7268
EXTRACT_SHAPE_AND_DIST_ATTR(indices);
@@ -141,6 +137,22 @@ SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
141137
return {{x_dist_attr_dst, indices_dist_attr_dst, out_grad_dist_attr_dst},
142138
{x_grad_dist_attr_dst}};
143139
}
140+
141+
SpmdInfo TopkInferSpmd(
142+
const DistMetaTensor& x, int k, int axis, bool largest, bool sorted) {
143+
return TopkInferSpmdBase(x, axis);
144+
}
145+
146+
SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
147+
const DistMetaTensor& indices,
148+
const DistMetaTensor& out_grad,
149+
int k,
150+
int axis,
151+
bool largest,
152+
bool sorted) {
153+
return TopkGradInferSpmdBase(x, indices, out_grad, axis);
154+
}
155+
144156
SpmdInfo TopkInferSpmdDynamic(const DistMetaTensor& x,
145157
const Scalar& k,
146158
int axis,

paddle/phi/infermeta/spmd_rules/topk.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ limitations under the License. */
2020

2121
namespace phi {
2222
namespace distributed {
23-
23+
SpmdInfo TopkInferSpmdBase(const DistMetaTensor& x, int axis);
24+
SpmdInfo TopkGradInferSpmdBase(const DistMetaTensor& x,
25+
const DistMetaTensor& indices,
26+
const DistMetaTensor& out_grad,
27+
int axis);
2428
SpmdInfo TopkInferSpmd(
2529
const DistMetaTensor& x, int k, int axis, bool largest, bool sorted);
2630

paddle/phi/ops/yaml/backward.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,7 @@
697697
infer_meta :
698698
func : UnchangedInferMeta
699699
param: [x]
700+
spmd_rule : CummaxGradInferSpmd
700701
kernel :
701702
func : cummax_grad
702703
data_type : out_grad
@@ -708,6 +709,7 @@
708709
infer_meta :
709710
func : UnchangedInferMeta
710711
param: [x]
712+
spmd_rule : CumminGradInferSpmd
711713
kernel :
712714
func : cummin_grad
713715
data_type : out_grad

paddle/phi/ops/yaml/ops.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,7 @@
12751275
output : Tensor(out), Tensor(indices)
12761276
infer_meta :
12771277
func : CumWithIndicesInferMeta
1278+
spmd_rule : CummaxInferSpmd
12781279
kernel :
12791280
func : cummax
12801281
data_type : x
@@ -1286,6 +1287,7 @@
12861287
output : Tensor(out), Tensor(indices)
12871288
infer_meta :
12881289
func : CumWithIndicesInferMeta
1290+
spmd_rule : CumminInferSpmd
12891291
kernel :
12901292
func : cummin
12911293
data_type : x

0 commit comments

Comments
 (0)