
Commit d2997b7

zty-king authored and xuxinyi389 committed

Add test cases for inputs with distributed attributes

1 parent ea4bc20 commit d2997b7

File tree

4 files changed, +261 -143 lines changed


python/setup.py.in

Lines changed: 1 addition & 0 deletions

@@ -837,6 +837,7 @@ packages=['paddle',
     'paddle.distributed.fleet.meta_parallel.parallel_layers',
     'paddle.distributed.auto_parallel',
     'paddle.distributed.auto_parallel.intermediate',
+    'paddle.distributed.auto_parallel.pipelining',
     'paddle.distributed.auto_parallel.dygraph',
     'paddle.distributed.auto_parallel.static',
     'paddle.distributed.auto_parallel.static.operators',

setup.py

Lines changed: 1 addition & 0 deletions

@@ -2259,6 +2259,7 @@ def get_setup_parameters():
     'paddle.distributed.fleet.meta_parallel.parallel_layers',
     'paddle.distributed.auto_parallel',
     'paddle.distributed.auto_parallel.intermediate',
+    'paddle.distributed.auto_parallel.pipelining',
     'paddle.distributed.auto_parallel.dygraph',
     'paddle.distributed.auto_parallel.static',
     'paddle.distributed.auto_parallel.static.operators',
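
Both packaging manifests must list the new subpackage so paddle.distributed.auto_parallel.pipelining ships in built wheels. As a quick sanity check (a sketch assuming a Paddle build produced from this commit), the module exercised by the test below should then be importable:

import importlib

# Hypothetical smoke check: confirm the newly registered subpackage is
# packaged and exposes the microbatch helpers used by the new test.
mod = importlib.import_module(
    "paddle.distributed.auto_parallel.pipelining.microbatch"
)
assert hasattr(mod, "TensorChunkSpec")
assert hasattr(mod, "split_args_kwargs_into_chunks")
assert hasattr(mod, "merge_chunks")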

test/auto_parallel/microbatch_demo.py

Lines changed: 239 additions & 0 deletions
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.pipelining.microbatch import (
    TensorChunkSpec,
    merge_chunks,
    split_args_kwargs_into_chunks,
)


class TestMicrobatch:
    def __init__(self):
        paddle.seed(2024)
        paddle.distributed.init_parallel_env()
        self.batch_size = 8
        self.feature_size = 4
        self.tensor = paddle.randn([self.batch_size, self.feature_size])
        self.rank = paddle.distributed.get_rank()

    def test_tensor_chunk_spec(self):
        # Test creation and string representation of TensorChunkSpec
        spec = TensorChunkSpec(0)
        assert spec.split_axis == 0
        assert str(spec) == "TensorChunkSpec(0)"
        assert "TensorChunkSpec(0)" in repr(spec)

    def test_split_args_kwargs(self):
        # Test basic parameter splitting
        args = (self.tensor,)
        kwargs = {"input": self.tensor}
        num_chunks = 2

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            args, kwargs, num_chunks
        )

        assert len(args_split) == num_chunks
        assert len(kwargs_split) == num_chunks
        assert args_split[0][0].shape[0] == self.batch_size // num_chunks

        # Test splitting with non-tensor parameters
        args = (self.tensor, 42, "string")
        kwargs = {"tensor": self.tensor, "number": 42}
        num_chunks = 2

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            args, kwargs, num_chunks
        )

        # Verify non-tensor parameters remain unchanged in each chunk
        assert args_split[0][1] == 42
        assert args_split[0][2] == "string"
        assert kwargs_split[0]["number"] == 42

        # Test splitting with custom specification
        tensor_2d = paddle.randn([4, 6])
        args = (tensor_2d,)
        args_chunk_spec = (TensorChunkSpec(1),)  # Split on second dimension

        args_split, _ = split_args_kwargs_into_chunks(
            args, None, 2, args_chunk_spec
        )

        assert args_split[0][0].shape[1] == 3

    def test_merge_chunks(self):
        # Test merging chunks
        chunk1 = paddle.randn([4, 4])
        chunk2 = paddle.randn([4, 4])
        chunks = [chunk1, chunk2]
        chunk_spec = [TensorChunkSpec(0)]

        merged = merge_chunks(chunks, chunk_spec)
        assert merged.shape[0] == 8

        # Test merging chunks containing non-tensor values
        chunks = [(paddle.randn([4, 4]), 42)] * 2
        chunk_spec = [TensorChunkSpec(0), None]

        merged = merge_chunks(chunks, chunk_spec)
        assert merged[1] == 42

        # Test error cases
        try:
            # Test error when tensor size is smaller than number of chunks
            small_tensor = paddle.randn([1, 4])
            split_args_kwargs_into_chunks((small_tensor,), None, 2)
            raise AssertionError("Expected ValueError")
        except ValueError:
            pass

        # Test error when parameter count doesn't match chunk_spec length
        raised = False
        try:
            split_args_kwargs_into_chunks(
                (self.tensor,),
                None,
                2,
                (TensorChunkSpec(0), TensorChunkSpec(1)),
            )
        except (ValueError, AssertionError):
            raised = True
        assert raised, "Expected an error for mismatched chunk_spec length"

        # Test merging empty chunks
        empty_chunks = []
        result = merge_chunks(empty_chunks, None)
        assert result == []

        # Test tensor size smaller than the number of chunks
        small_tensor = paddle.randn([1, 4])
        try:
            split_args_kwargs_into_chunks((small_tensor,), None, 2)
            raise AssertionError("Expected ValueError")
        except ValueError:
            pass

        # Test merging non-tensor values against a tensor spec
        chunks = [(42,), (42,)]
        chunk_spec = (TensorChunkSpec(0),)
        result = merge_chunks(chunks, chunk_spec)
        assert result[0] == 42

    def test_nested_structure(self):
        # Test splitting a nested structure of tensors
        nested_tensor = [
            [paddle.randn([4, 2]), paddle.randn([4, 2])],
            [paddle.randn([4, 2]), paddle.randn([4, 2])],
        ]

        args = (nested_tensor,)
        kwargs = {"nested": nested_tensor}

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            args, kwargs, 2
        )

        assert len(args_split) == 2
        assert len(args_split[0][0]) == 2
        assert len(args_split[0][0][0]) == 2
        assert args_split[0][0][0][0].shape == [2, 2]

        assert len(kwargs_split) == 2
        assert len(kwargs_split[0]["nested"]) == 2
        assert len(kwargs_split[0]["nested"][0]) == 2
        assert kwargs_split[0]["nested"][0][0].shape == [2, 2]

        merged_args = merge_chunks(
            args_split,
            [
                [TensorChunkSpec(0), TensorChunkSpec(0)],
                [TensorChunkSpec(0), TensorChunkSpec(0)],
            ],
        )

        assert merged_args[0][0][0].shape == [4, 2]
        assert merged_args[0][1][1].shape == [4, 2]

        assert len(merged_args[0]) == 2
        assert len(merged_args[0][0]) == 2

    def test_dist_tensor_split_and_merge(self):
        # Test splitting and merging a tensor that carries distributed
        # attributes (sharded along dim 0 across a 2-rank mesh)
        base_tensor = self.tensor

        # Dense reference: split the plain tensor into 2 chunks along dim 0
        dense_tensor, _ = split_args_kwargs_into_chunks(
            (base_tensor,),
            None,
            2,
        )

        # Shard the same tensor along dim 0: rank 0 holds rows 0-3 locally,
        # rank 1 holds rows 4-7
        mesh = paddle.distributed.ProcessMesh([0, 1], dim_names=["dp"])
        dist_tensor = paddle.distributed.shard_tensor(
            self.tensor,
            mesh,
            [paddle.distributed.Shard(0)],
        )
        dist_tensor_split, _ = split_args_kwargs_into_chunks(
            (dist_tensor,),
            None,
            2,
        )

        # Splitting chunks each rank's local shard, so every chunk's local
        # value is 2 rows; compare them against the matching rows of the
        # dense reference split
        if self.rank == 0:
            is_equal = (
                dist_tensor_split[0][0]
                ._local_value()
                .equal_all(dense_tensor[0][0][:2])
            )
            assert is_equal.item()
            is_equal = (
                dist_tensor_split[1][0]
                ._local_value()
                .equal_all(dense_tensor[0][0][2:])
            )
            assert is_equal.item()
        else:
            is_equal = (
                dist_tensor_split[0][0]
                ._local_value()
                .equal_all(dense_tensor[1][0][:2])
            )
            assert is_equal.item()
            is_equal = (
                dist_tensor_split[1][0]
                ._local_value()
                .equal_all(dense_tensor[1][0][2:])
            )
            assert is_equal.item()

        # Merging the two distributed chunks must reproduce this rank's
        # local shard of the original tensor
        chunk1 = dist_tensor_split[0][0]
        chunk2 = dist_tensor_split[1][0]
        chunk_spec = [TensorChunkSpec(0)]
        merged_chunk = merge_chunks([chunk1, chunk2], chunk_spec)
        if self.rank == 0:
            is_equal = merged_chunk._local_value().equal_all(base_tensor[:4])
            assert is_equal.item()
        else:
            is_equal = merged_chunk._local_value().equal_all(base_tensor[4:])
            assert is_equal.item()

    def run_all_tests(self):
        """Run all test methods"""
        self.test_tensor_chunk_spec()
        self.test_split_args_kwargs()
        self.test_merge_chunks()
        self.test_nested_structure()
        self.test_dist_tensor_split_and_merge()


if __name__ == "__main__":
    TestMicrobatch().run_all_tests()
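
For context, a minimal single-process sketch of the split/merge round trip the new test exercises. It assumes only the three names imported by microbatch_demo.py above, and the shapes mirror the test's 8x4 batch:

import paddle
from paddle.distributed.auto_parallel.pipelining.microbatch import (
    TensorChunkSpec,
    merge_chunks,
    split_args_kwargs_into_chunks,
)

batch = paddle.randn([8, 4])  # 8 samples, 4 features, as in the test

# Split positional and keyword args into 2 chunks along dim 0 (the default).
args_split, kwargs_split = split_args_kwargs_into_chunks(
    (batch,), {"label": batch}, 2
)
assert args_split[0][0].shape == [4, 4]  # each chunk holds half the batch
assert kwargs_split[1]["label"].shape == [4, 4]

# Merge the per-chunk tensors back into the full batch along dim 0.
merged = merge_chunks(
    [args_split[0][0], args_split[1][0]], [TensorChunkSpec(0)]
)
assert merged.shape == [8, 4]

The distributed case (test_dist_tensor_split_and_merge) needs two ranks, e.g. python -m paddle.distributed.launch --gpus 0,1 test/auto_parallel/microbatch_demo.py.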
