Skip to content

Commit eb90714

Browse files
committed
add_test
1 parent b589534 commit eb90714

File tree

3 files changed

+175
-0
lines changed

3 files changed

+175
-0
lines changed

test/auto_parallel/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
4545
set_tests_properties(test_pipeline_scheduler
4646
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 400)
4747
endif()
48+
py_test_modules(test_process_mesh MODULES test_process_mesh)
4849
py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
4950
set_tests_properties(test_reshard_r_to_p
5051
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 200)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import paddle
17+
import paddle.distributed as dist
18+
19+
20+
class TestProcessMesh:
21+
def init_dist_env(self):
22+
dist.init_parallel_env()
23+
paddle.seed(2025)
24+
25+
def test_get_submesh_with_dim(self):
26+
curr_rank = dist.get_rank()
27+
28+
# Test 2D mesh
29+
mesh_2d = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "tp"])
30+
31+
# Test case 1: Get submesh for dp dimension
32+
dp_mesh = mesh_2d.get_submesh_with_dim("dp")
33+
if curr_rank == 0:
34+
assert dp_mesh.process_ids == [0, 2]
35+
elif curr_rank == 1:
36+
assert dp_mesh.process_ids == [1, 3]
37+
38+
# Test case 2: Get submesh for tp dimension
39+
tp_mesh = mesh_2d.get_submesh_with_dim("tp")
40+
if curr_rank == 0:
41+
assert tp_mesh.process_ids == [0, 1]
42+
elif curr_rank == 1:
43+
assert tp_mesh.process_ids == [0, 1]
44+
45+
# Test case 3: 3D mesh with 8 cards (2x2x2)
46+
mesh_3d = dist.ProcessMesh(
47+
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["pp", "dp", "tp"]
48+
)
49+
50+
# Test each dimension
51+
pp_mesh = mesh_3d.get_submesh_with_dim("pp")
52+
dp_mesh = mesh_3d.get_submesh_with_dim("dp")
53+
tp_mesh = mesh_3d.get_submesh_with_dim("tp")
54+
55+
# Verify pp dimension results
56+
if curr_rank == 0:
57+
assert pp_mesh.process_ids == [0, 4]
58+
elif curr_rank == 1:
59+
assert pp_mesh.process_ids == [1, 5]
60+
61+
# Verify dp dimension results
62+
if curr_rank == 0:
63+
assert dp_mesh.process_ids == [0, 2]
64+
elif curr_rank == 1:
65+
assert dp_mesh.process_ids == [1, 3]
66+
67+
# Verify tp dimension results
68+
if curr_rank == 0:
69+
assert tp_mesh.process_ids == [0, 1]
70+
elif curr_rank == 1:
71+
assert tp_mesh.process_ids == [0, 1]
72+
73+
# Test case 4: When rank is not in the mesh
74+
mesh_small = dist.ProcessMesh([0, 1], dim_names=["x"])
75+
if curr_rank not in [0, 1]:
76+
assert mesh_small.get_submesh_with_dim("x") is None
77+
78+
def test_get_group(self):
79+
curr_rank = dist.get_rank()
80+
81+
# Test case 1: Single dimension mesh without dim_name
82+
mesh_1d = dist.ProcessMesh([0, 1], dim_names=["x"])
83+
if curr_rank in [0, 1]:
84+
group_1d = mesh_1d.get_group()
85+
assert isinstance(group_1d, dist.communication.group.Group)
86+
87+
# Test case 2: Single dimension mesh with correct dim_name
88+
group_1d_with_name = mesh_1d.get_group(dim_name="x")
89+
assert isinstance(
90+
group_1d_with_name, dist.communication.group.Group
91+
)
92+
93+
# Test case 3: Single dimension mesh with wrong dim_name
94+
try:
95+
mesh_1d.get_group(dim_name="wrong_name")
96+
raise AssertionError("Should raise ValueError")
97+
except ValueError:
98+
pass
99+
100+
# Test case 4: Multi-dimension mesh
101+
mesh_2d = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "tp"])
102+
if curr_rank in [0, 1, 2, 3]:
103+
# Test without dim_name
104+
try:
105+
mesh_2d.get_group()
106+
raise AssertionError("Should raise ValueError")
107+
except ValueError:
108+
pass
109+
110+
# Test with correct dim_name
111+
group_2d = mesh_2d.get_group(dim_name="dp")
112+
assert isinstance(group_2d, dist.communication.group.Group)
113+
114+
# Test with wrong dim_name
115+
try:
116+
mesh_2d.get_group(dim_name="wrong_name")
117+
raise AssertionError("Should raise ValueError")
118+
except ValueError:
119+
pass
120+
121+
def test_process_mesh(self):
122+
self.init_dist_env()
123+
self.test_get_submesh_with_dim()
124+
self.test_get_group()
125+
126+
127+
if __name__ == '__main__':
128+
TestProcessMesh().test_process_mesh()
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import collective.test_communication_api_base as test_base
18+
19+
20+
class TestProcessMeshPass(test_base.CommunicationTestDistBase):
21+
def setUp(self):
22+
super().setUp(
23+
num_of_devices=2,
24+
timeout=50,
25+
)
26+
self._default_envs = {
27+
"FLAGS_cudnn_deterministic": "1",
28+
"FLAGS_enable_pir_api": "1",
29+
}
30+
self._changeable_envs = {
31+
"backend": ["gpu"],
32+
}
33+
34+
def test_process_mesh(self):
35+
envs_list = test_base.gen_product_envs_list(
36+
self._default_envs, self._changeable_envs
37+
)
38+
for envs in envs_list:
39+
self.run_test_case(
40+
"process_mesh_demo_unittest.py",
41+
user_defined_envs=envs,
42+
)
43+
44+
45+
if __name__ == "__main__":
46+
unittest.main()

0 commit comments

Comments
 (0)