Skip to content

Commit cd0b795

Browse files
【Comm】add more tests for flagcx (#72009)
1 parent 92e60e0 commit cd0b795

9 files changed

+106
-8
lines changed

test/collective/CMakeLists.txt

+8-8
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
8787
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;FLAGS_enable_pir_api=0"
8888
)
8989
set_tests_properties(test_collective_barrier_api
90-
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
90+
PROPERTIES TIMEOUT "450" LABELS "RUN_TYPE=DIST")
9191
endif()
9292
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
9393
bash_test_modules(
@@ -170,7 +170,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
170170
test_collective_isend_irecv_api MODULES test_collective_isend_irecv_api
171171
ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
172172
set_tests_properties(test_collective_isend_irecv_api
173-
PROPERTIES TIMEOUT "160" LABELS "RUN_TYPE=DIST")
173+
PROPERTIES TIMEOUT "320" LABELS "RUN_TYPE=DIST")
174174
endif()
175175
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
176176
py_test_modules(
@@ -268,7 +268,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
268268
test_collective_gather_api MODULES test_collective_gather_api ENVS
269269
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
270270
set_tests_properties(test_collective_gather_api
271-
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
271+
PROPERTIES TIMEOUT "360" LABELS "RUN_TYPE=DIST")
272272
endif()
273273
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
274274
py_test_modules(
@@ -279,7 +279,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
279279
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;FLAGS_enable_pir_api=0"
280280
)
281281
set_tests_properties(test_collective_sendrecv_api
282-
PROPERTIES TIMEOUT "500" LABELS "RUN_TYPE=DIST")
282+
PROPERTIES TIMEOUT "1000" LABELS "RUN_TYPE=DIST")
283283
endif()
284284
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
285285
py_test_modules(
@@ -335,31 +335,31 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
335335
test_communication_stream_reduce_api ENVS
336336
"PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
337337
set_tests_properties(test_communication_stream_reduce_api
338-
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
338+
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
339339
endif()
340340
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
341341
py_test_modules(
342342
test_communication_stream_reduce_scatter_api MODULES
343343
test_communication_stream_reduce_scatter_api ENVS
344344
"PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
345345
set_tests_properties(test_communication_stream_reduce_scatter_api
346-
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
346+
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
347347
endif()
348348
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
349349
py_test_modules(
350350
test_communication_stream_scatter_api MODULES
351351
test_communication_stream_scatter_api ENVS
352352
"PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
353353
set_tests_properties(test_communication_stream_scatter_api
354-
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
354+
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
355355
endif()
356356
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
357357
py_test_modules(
358358
test_communication_stream_sendrecv_api MODULES
359359
test_communication_stream_sendrecv_api ENVS
360360
"PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
361361
set_tests_properties(test_communication_stream_sendrecv_api
362-
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
362+
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
363363
endif()
364364
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
365365
py_test_modules(

test/collective/test_collective_barrier_api.py

+9
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ def test_barrier_gloo(self):
3333
"collective_barrier_api.py", "barrier", "gloo", "5"
3434
)
3535

36+
def test_barrier_flagcx(self):
37+
if paddle.base.core.is_compiled_with_flagcx():
38+
self.check_with_place(
39+
"collective_barrier_api.py",
40+
"barrier",
41+
"flagcx",
42+
static_mode="0",
43+
)
44+
3645

3746
if __name__ == '__main__':
3847
unittest.main()

test/collective/test_collective_gather_api.py

+21
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,27 @@ def test_gather_nccl_dygraph(self):
4747
dtype=dtype,
4848
)
4949

50+
def test_gather_flagcx_dygraph(self):
51+
dtypes_to_test = [
52+
"float16",
53+
"float32",
54+
"float64",
55+
"int32",
56+
"int64",
57+
"int8",
58+
"uint8",
59+
"bool",
60+
]
61+
if paddle.base.core.is_compiled_with_flagcx():
62+
for dtype in dtypes_to_test:
63+
self.check_with_place(
64+
"collective_gather_api_dygraph.py",
65+
"gather",
66+
"flagcx",
67+
static_mode="0",
68+
dtype=dtype,
69+
)
70+
5071

5172
if __name__ == "__main__":
5273
unittest.main()

test/collective/test_collective_isend_irecv_api.py

+23
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import legacy_test.test_collective_api_base as test_base
1818

19+
import paddle
20+
1921

2022
class TestCollectiveIsendIrecvAPI(test_base.TestDistBase):
2123
def _setup_config(self):
@@ -43,6 +45,27 @@ def test_isend_irecv_nccl_dygraph(self):
4345
dtype=dtype,
4446
)
4547

48+
def test_isend_irecv_flagcx_dygraph(self):
49+
dtypes_to_test = [
50+
"float16",
51+
"float32",
52+
"float64",
53+
"int32",
54+
"int64",
55+
"int8",
56+
"uint8",
57+
"bool",
58+
]
59+
if paddle.base.core.is_compiled_with_flagcx():
60+
for dtype in dtypes_to_test:
61+
self.check_with_place(
62+
"collective_isend_irecv_api_dygraph.py",
63+
"sendrecv",
64+
"flagcx",
65+
static_mode="0",
66+
dtype=dtype,
67+
)
68+
4669

4770
if __name__ == "__main__":
4871
unittest.main()

test/collective/test_collective_sendrecv_api.py

+21
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,27 @@ def test_sendrecv_nccl_dygraph(self):
7575
dtype=dtype,
7676
)
7777

78+
def test_sendrecv_flagcx_dygraph(self):
79+
dtypes_to_test = [
80+
"float16",
81+
"float32",
82+
"float64",
83+
"int32",
84+
"int64",
85+
"int8",
86+
"uint8",
87+
"bool",
88+
]
89+
if paddle.base.core.is_compiled_with_flagcx():
90+
for dtype in dtypes_to_test:
91+
self.check_with_place(
92+
"collective_sendrecv_api_dygraph.py",
93+
"sendrecv",
94+
"flagcx",
95+
static_mode="0",
96+
dtype=dtype,
97+
)
98+
7899

79100
if __name__ == "__main__":
80101
unittest.main()

test/collective/test_communication_stream_reduce_api.py

+6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import test_communication_api_base as test_base
1818

19+
import paddle
20+
1921

2022
class TestCommunicationStreamReduceAPI(test_base.CommunicationTestDistBase):
2123
def setUp(self):
@@ -26,7 +28,11 @@ def setUp(self):
2628
"dtype": "float32",
2729
"seeds": str(self._seeds),
2830
}
31+
backend_list = ["nccl"]
32+
if paddle.base.core.is_compiled_with_flagcx():
33+
backend_list.append("flagcx")
2934
self._changeable_envs = {
35+
"backend": backend_list,
3036
"sync_op": ["True", "False"],
3137
"use_calc_stream": ["True", "False"],
3238
}

test/collective/test_communication_stream_reduce_scatter_api.py

+6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import test_communication_api_base as test_base
1818

19+
import paddle
20+
1921

2022
class TestCommunicationStreamReduceScatterAPI(
2123
test_base.CommunicationTestDistBase
@@ -28,7 +30,11 @@ def setUp(self):
2830
"dtype": "float32",
2931
"seeds": str(self._seeds),
3032
}
33+
backend_list = ["nccl"]
34+
if paddle.base.core.is_compiled_with_flagcx():
35+
backend_list.append("flagcx")
3136
self._changeable_envs = {
37+
"backend": backend_list,
3238
"sync_op": ["True", "False"],
3339
"use_calc_stream": ["True", "False"],
3440
}

test/collective/test_communication_stream_scatter_api.py

+6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import test_communication_api_base as test_base
1818

19+
import paddle
20+
1921

2022
class TestCommunicationStreamScatterAPI(test_base.CommunicationTestDistBase):
2123
def setUp(self):
@@ -26,7 +28,11 @@ def setUp(self):
2628
"dtype": "float32",
2729
"seeds": str(self._seeds),
2830
}
31+
backend_list = ["nccl"]
32+
if paddle.base.core.is_compiled_with_flagcx():
33+
backend_list.append("flagcx")
2934
self._changeable_envs = {
35+
"backend": backend_list,
3036
"sync_op": ["True", "False"],
3137
"use_calc_stream": ["True", "False"],
3238
}

test/collective/test_communication_stream_sendrecv_api.py

+6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import test_communication_api_base as test_base
1818

19+
import paddle
20+
1921

2022
class TestCommunicationStreamSendRecvAPI(test_base.CommunicationTestDistBase):
2123
def setUp(self):
@@ -26,7 +28,11 @@ def setUp(self):
2628
"dtype": "float32",
2729
"seeds": str(self._seeds),
2830
}
31+
backend_list = ["nccl"]
32+
if paddle.base.core.is_compiled_with_flagcx():
33+
backend_list.append("flagcx")
2934
self._changeable_envs = {
35+
"backend": backend_list,
3036
"sync_op": ["True", "False"],
3137
"use_calc_stream": ["True", "False"],
3238
}

0 commit comments

Comments
 (0)