Skip to content

Commit 8ee1483

Browse files
ooooo-createGITD245
authored and committed
[Auto Parallel] Add spmd rule No.6 for mean_all and mean_all_grad ops. (PaddlePaddle#72479)
* add spmd rule for pd_op.mean_all * Fix test config for shape and dims_mapping * refine and add cpp tests * reuse ReductionInfer to MeanAll
1 parent a295f82 commit 8ee1483

File tree

10 files changed

+196
-1
lines changed

10 files changed

+196
-1
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/phi/infermeta/spmd_rules/mean_all.h"
#include "glog/logging.h"

namespace phi {
namespace distributed {

// Forward SPMD rule for mean_all: the mean over every element of x,
// producing a 0-D output.
// Delegates to ReductionInferSpmdBase with an empty axis list and
// ReduceType::kRedAvg; presumably the empty axis list means "reduce all
// dims" and the third argument is keep_dim=false — confirm against the
// ReductionInferSpmdBase signature in reduction.h.
SpmdInfo MeanAllInferSpmd(const DistMetaTensor& x) {
  VLOG(4) << "MeanAllInferSpmd Call ReductionInferSpmdBase";
  return ReductionInferSpmdBase(
      x, {}, false, static_cast<int>(ReduceType::kRedAvg));
}

// Backward SPMD rule for mean_all_grad: infers distributed attributes for
// (x, out_grad) -> x_grad by reusing the generic reduction gradient rule.
// NOTE(review): the trailing {} /false/true arguments mirror the forward
// call's "reduce all dims, no keep_dim" configuration; the final `true`
// looks like a reduce_all flag — verify against ReductionGradInferSpmd.
SpmdInfo MeanAllGradInferSpmd(const DistMetaTensor& x,
                              const DistMetaTensor& out_grad) {
  VLOG(4) << "MeanAllGradInferSpmd Call ReductionGradInferSpmd";
  return ReductionGradInferSpmd(x, out_grad, {}, false, true);
}

}  // namespace distributed
}  // namespace phi
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"
#include "paddle/phi/infermeta/spmd_rules/reduction.h"

namespace phi {
namespace distributed {

// Forward SPMD rule for the mean_all op (mean over all elements of x).
// Implemented by delegating to the generic reduction rule in reduction.h.
SpmdInfo MeanAllInferSpmd(const DistMetaTensor& x);

// Backward SPMD rule for mean_all_grad: infers the distributed attributes
// of x_grad from x and out_grad via the generic reduction gradient rule.
SpmdInfo MeanAllGradInferSpmd(const DistMetaTensor& x,
                              const DistMetaTensor& out_grad);

}  // namespace distributed
}  // namespace phi

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,11 @@ PD_REGISTER_SPMD_RULE(
521521
PD_INFER_SPMD(phi::distributed::ReductionInferSpmd),
522522
PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse));
523523

524+
// mean_all
525+
PD_REGISTER_SPMD_RULE(mean_all,
526+
PD_INFER_SPMD(phi::distributed::MeanAllInferSpmd),
527+
PD_INFER_SPMD(phi::distributed::MeanAllGradInferSpmd));
528+
524529
// layer_norm
525530
PD_REGISTER_SPMD_RULE(
526531
layer_norm,

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ limitations under the License. */
4545
#include "paddle/phi/infermeta/spmd_rules/layer_norm.h"
4646
#include "paddle/phi/infermeta/spmd_rules/logsumexp.h"
4747
#include "paddle/phi/infermeta/spmd_rules/matmul.h"
48+
#include "paddle/phi/infermeta/spmd_rules/mean_all.h"
4849
#include "paddle/phi/infermeta/spmd_rules/moe_combine.h"
4950
#include "paddle/phi/infermeta/spmd_rules/moe_gate_dispatch.h"
5051
#include "paddle/phi/infermeta/spmd_rules/nonzero.h"

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,6 +2173,7 @@
21732173
infer_meta :
21742174
func : UnchangedExceptLayoutInferMeta
21752175
param: [x]
2176+
spmd_rule : MeanAllGradInferSpmd
21762177
kernel :
21772178
func : mean_all_grad
21782179
data_type: out_grad

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3451,6 +3451,7 @@
34513451
output : Tensor
34523452
infer_meta :
34533453
func : MeanAllInferMeta
3454+
spmd_rule : MeanAllInferSpmd
34543455
kernel :
34553456
func : mean_all
34563457
backward : mean_all_grad

test/auto_parallel/spmd_rules/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ if(WITH_DISTRIBUTE)
4545
py_test_modules(test_nonzero_rule MODULES test_nonzero_rule)
4646
if(NOT WITH_ROCM)
4747
py_test_modules(test_add_n_rule MODULES test_add_n_rule)
48+
py_test_modules(test_mean_all_rule MODULES test_mean_all_rule)
4849
py_test_modules(test_argmin_rule MODULES test_argmin_rule)
4950
endif()
5051
# End of unittests WITH single card WITHOUT timeout
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from paddle.distributed.auto_parallel.static.dist_attribute import (
18+
DistTensorSpec,
19+
TensorDistAttr,
20+
)
21+
from paddle.distributed.fleet import auto
22+
from paddle.framework import core
23+
24+
25+
class TestMeanAllSPMDRule(unittest.TestCase):
    """Exercises the SPMD rule registered for mean_all.

    mean_all reduces every axis of x down to a 0-D output, so the forward
    rule should mark the scalar output as partial on each mesh dim x is
    sharded along, and the backward rule should replicate x_grad.
    """

    def setUp(self):
        self.rule = core.get_phi_spmd_rule("mean_all")
        mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])

        # Input spec: a [4, 8] tensor, initially sharded on both mesh dims.
        x_attr = TensorDistAttr()
        x_attr.dims_mapping = [1, 0]
        x_attr.process_mesh = mesh
        self.x_dist_tensor_spec = DistTensorSpec([4, 8], x_attr)

        # Output spec: 0-D (scalar), so an empty dims_mapping.
        out_attr = TensorDistAttr()
        out_attr.dims_mapping = []
        out_attr.process_mesh = mesh
        self.out_dist_tensor_spec = DistTensorSpec([], out_attr)

    def test_infer_forward(self):
        # Shard x as [0, -1]; the 0-D output keeps no dims and becomes
        # partial (pending average) on mesh dim 0.
        self.x_dist_tensor_spec.set_dims_mapping([0, -1])

        results = self.rule.infer_forward(self.x_dist_tensor_spec)
        self.assertEqual(len(results), 2)

        input_attrs = results[0]
        output_attrs = results[1]
        self.assertEqual(len(input_attrs), 1)
        self.assertEqual(len(output_attrs), 1)

        self.assertEqual(input_attrs[0].dims_mapping, [0, -1])
        self.assertEqual(output_attrs[0].dims_mapping, [])
        self.assertTrue(output_attrs[0]._is_partial())
        self.assertEqual(output_attrs[0]._partial_dims(), {0})

    def test_infer_backward(self):
        # From a 0-D out with empty dims_mapping, x is inferred fully
        # replicated ([-1, -1]); the grad rule sees two inputs
        # (x, out_grad) and one output (x_grad).
        self.out_dist_tensor_spec.shape = []
        self.out_dist_tensor_spec.set_dims_mapping([])

        results = self.rule.infer_backward(
            self.x_dist_tensor_spec,
            self.out_dist_tensor_spec,
        )
        self.assertEqual(len(results), 2)

        input_attrs = results[0]
        output_attrs = results[1]
        self.assertEqual(len(input_attrs), 2)
        self.assertEqual(len(output_attrs), 1)

        self.assertEqual(input_attrs[0].dims_mapping, [-1, -1])
        self.assertEqual(input_attrs[1].dims_mapping, [])
        self.assertEqual(output_attrs[0].dims_mapping, [-1, -1])


if __name__ == "__main__":
    unittest.main()

test/cpp/auto_parallel/spmd_rule_test.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2426,6 +2426,49 @@ TEST(Dropout, Ctor) {
24262426
check_dim_mapping(backward_info.second[0], {0, -1, -1});
24272427
}
24282428

2429+
// C++ counterpart of the Python mean_all rule test: checks forward and
// backward SPMD inference directly against MeanAllInferSpmd /
// MeanAllGradInferSpmd on a 2x2 process mesh.
TEST(MeanAll, Ctor) {
  // 2x2 mesh over ranks {0,1,2,3} with axes named "x" and "y".
  std::vector<int64_t> mesh_shape = {2, 2};
  std::vector<int64_t> process_ids = {0, 1, 2, 3};
  std::vector<std::string> dim_names = {"x", "y"};
  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);

  // test forward
  // x sharded [0, -1] --> 0-D output (empty mapping), partial_on_dim:[0]
  auto t_dist_attr = TensorDistAttr();
  t_dist_attr.set_process_mesh(process_mesh);
  t_dist_attr.set_dims_mapping({0, -1});
  t_dist_attr.set_dynamic_dims({false, false});
  phi::distributed::DistMetaTensor x =
      phi::distributed::DistMetaTensor(common::make_ddim({4, 8}), t_dist_attr);
  phi::distributed::SpmdInfo forward_info =
      phi::distributed::MeanAllInferSpmd(x);

  // One input attr (x), one output attr (the 0-D mean).
  EXPECT_EQ(forward_info.first.size(), 1UL);
  EXPECT_EQ(forward_info.second.size(), 1UL);

  check_dim_mapping(forward_info.first[0], {0, -1});
  check_dim_mapping(forward_info.second[0], {});  // 0-D: no dims to map
  check_partial_dims(forward_info.second[0], {0});

  // test backward
  // 0-D out_grad [] --> x_grad replicated [-1, -1], out_grad stays []
  auto out_grad_dist_attr = TensorDistAttr();
  out_grad_dist_attr.set_process_mesh(process_mesh);
  out_grad_dist_attr.set_dims_mapping({});
  out_grad_dist_attr.set_dynamic_dims({});
  phi::distributed::DistMetaTensor out_grad = phi::distributed::DistMetaTensor(
      common::make_ddim({}), out_grad_dist_attr);
  phi::distributed::SpmdInfo backward_info =
      phi::distributed::MeanAllGradInferSpmd(x, out_grad);

  // Grad rule takes two inputs (x, out_grad) and produces one output
  // (x_grad).
  EXPECT_EQ(backward_info.first.size(), 2UL);
  EXPECT_EQ(backward_info.second.size(), 1UL);

  check_dim_mapping(backward_info.first[0], {-1, -1});
  check_dim_mapping(backward_info.first[1], {});
  check_dim_mapping(backward_info.second[0], {-1, -1});
}
2471+
24292472
} // namespace auto_parallel
24302473
} // namespace distributed
24312474
} // namespace paddle

third_party/zlib

Submodule zlib updated 171 files

0 commit comments

Comments
 (0)