
Commit f84a368

support_nested_structure
1 parent 76847bd commit f84a368

2 files changed: +35 -16 lines changed

python/paddle/distributed/auto_parallel/pipelining/microbatch.py

Lines changed: 34 additions & 15 deletions
@@ -18,7 +18,7 @@
 from typing import Any
 
 import paddle
-from paddle.utils import flatten, pack_sequence_as
+from paddle.utils import flatten, map_structure, pack_sequence_as
 
 logger = logging.getLogger(__name__)
 
@@ -101,8 +101,8 @@ def _split_args_helper(
             )
             chunk_args[key] = arg_of_curr_chunk
 
-        # pack chunk_args as the origin args_dict
-        chunk_args = pack_sequence_as(args_dict, chunk_args)
+        # flatten chunk_args first, and then pack chunk_args as the origin args_dict
+        chunk_args = pack_sequence_as(args_dict, flatten(chunk_args))
         args_split.append(chunk_args)
     return args_split
 
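Why the extra flatten matters: pack_sequence_as pairs a flat list of leaves with a target structure, but when the values of args_dict are themselves nested, the per-key values gathered into chunk_args keep that nesting. A minimal sketch of the repacking, using made-up scalar leaves in place of tensors (only flatten and pack_sequence_as from paddle.utils come from this commit's imports):

```python
from paddle.utils import flatten, pack_sequence_as

# Hypothetical args_dict whose values are nested, as this commit now allows.
args_dict = {"x": 1, "y": [2, 3]}

# Per-chunk values are collected key by key, so each value keeps its nesting.
chunk_args = {"x": 10, "y": [20, 30]}

# pack_sequence_as expects one entry per leaf of args_dict, so the nested
# chunk values are flattened before being packed back into the original shape.
packed = pack_sequence_as(args_dict, flatten(chunk_args))
print(packed)  # {'x': 10, 'y': [20, 30]}
```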
@@ -111,8 +111,24 @@ def split_args_kwargs_into_chunks(
     args: tuple[Any, ...],
     kwargs: dict[str, Any] | None,
     chunks: int,
-    args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
-    kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
+    args_chunk_spec: (
+        tuple[
+            tuple[TensorChunkSpec, ...]
+            | list[TensorChunkSpec, ...]
+            | TensorChunkSpec,
+            ...,
+        ]
+        | None
+    ) = None,
+    kwargs_chunk_spec: (
+        dict[
+            str,
+            tuple[TensorChunkSpec, ...]
+            | list[TensorChunkSpec, ...]
+            | TensorChunkSpec,
+        ]
+        | None
+    ) = None,
 ) -> tuple[list[tuple], list[dict]]:
     """
     Given a sequence of args and kwargs, split them into a number of chunks
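With the widened annotations, a chunk spec can mirror nested positional or keyword arguments instead of being a flat tuple or dict of TensorChunkSpec. A sketch of such a call, assuming TensorChunkSpec and split_args_kwargs_into_chunks are importable from the module this commit modifies and that TensorChunkSpec takes the dimension to split along:

```python
import paddle

# Import path assumed from the file touched by this commit.
from paddle.distributed.auto_parallel.pipelining.microbatch import (
    TensorChunkSpec,
    split_args_kwargs_into_chunks,
)

x = paddle.ones([8, 4])
y = paddle.ones([8, 2])

# The first positional argument is itself a tuple of tensors,
# and the spec nests the same way: one TensorChunkSpec per leaf.
args = ((x, y),)
args_chunk_spec = ((TensorChunkSpec(0), TensorChunkSpec(0)),)

args_split, kwargs_split = split_args_kwargs_into_chunks(
    args, None, 4, args_chunk_spec=args_chunk_spec
)
print(len(args_split))  # 4 micro-batches, each keeping the ((x, y),) nesting
```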
@@ -134,11 +150,13 @@ def split_args_kwargs_into_chunks(
         kwargs = {}
 
     if args_chunk_spec is None:
-        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
+        args_chunk_spec = map_structure(
+            lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), args
+        )
 
     if kwargs_chunk_spec is None:
-        kwargs_chunk_spec = dict.fromkeys(
-            kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM)
+        kwargs_chunk_spec = map_structure(
+            lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), kwargs
         )
 
     args_split_dict = _split_args_helper(
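The map_structure defaults build a spec with exactly the nesting of the inputs, rather than one spec per top-level entry. A small illustration of that behaviour, with placeholder string leaves standing in for tensors and TensorChunkSpec (map_structure is the paddle.utils helper this commit imports; the sample structures are invented):

```python
from paddle.utils import map_structure

# Hypothetical nested args/kwargs; strings stand in for paddle.Tensor leaves.
args = ("input", ["mask", "ids"])
kwargs = {"labels": {"start": "s", "end": "e"}}

# One spec leaf per tensor leaf, preserving every container level.
args_chunk_spec = map_structure(lambda _: "TensorChunkSpec(0)", args)
kwargs_chunk_spec = map_structure(lambda _: "TensorChunkSpec(0)", kwargs)

print(args_chunk_spec)   # ('TensorChunkSpec(0)', ['TensorChunkSpec(0)', 'TensorChunkSpec(0)'])
print(kwargs_chunk_spec) # nested dict mirroring kwargs, one spec per leaf
```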
@@ -186,20 +204,21 @@ def merge_chunks(
         return chunks
 
     if chunk_spec is None:
-        chunk0_flat = flatten(chunks[0])
-        # the number of args need to be merged
-        num_args = len(chunk0_flat)
-        chunk_spec = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * num_args
+        chunk_spec = map_structure(
+            lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), chunks[0]
+        )
 
     chunks_flat = []
+    # flatten chunk_spec first
+    chunk_spec = flatten(chunk_spec)
     for chunk in chunks:
         chunk_flat = flatten(chunk)
         assert len(chunk_flat) == len(
             chunk_spec
         ), f"Chunk {chunk} did not match chunk spec {chunk_spec}"
         chunks_flat.append(chunk_flat)
 
-    def merge_non_tensor_type_arg(chunks, idx, chunk_spec_of_arg=None):
+    def _merge_non_tensor_type_arg(chunks, idx, chunk_spec_of_arg=None):
         # use the first chunk's value as the merged result
         arg_0 = chunks[0][idx]
         for chunk_idx in range(1, len(chunks)):
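Because every chunk is flattened before merging, a nested chunk_spec has to be flattened as well so that spec and value line up leaf by leaf. A sketch of that alignment with placeholder leaves (flatten and map_structure are the paddle.utils helpers used above; the sample chunk is invented):

```python
from paddle.utils import flatten, map_structure

# Hypothetical output chunk containing nested containers.
chunk = {"logits": [1.0, 2.0], "loss": 3.0}

# A spec that mirrors the chunk's nesting, as the new default produces.
chunk_spec = map_structure(lambda _: "TensorChunkSpec(0)", chunk)

# After flattening both, spec_flat[i] describes chunk_flat[i] for every leaf.
chunk_flat = flatten(chunk)
spec_flat = flatten(chunk_spec)
assert len(chunk_flat) == len(spec_flat)
print(list(zip(spec_flat, chunk_flat)))
```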
@@ -226,11 +245,11 @@ def merge_non_tensor_type_arg(chunks, idx, chunk_spec_of_arg=None):
                     "The TensorChunkSpec only supports paddle.Tensor type."
                 )
 
-                merged_arg = merge_non_tensor_type_arg(
+                merged_arg = _merge_non_tensor_type_arg(
                     chunks_flat, arg_idx, chunk_spec_of_arg
                 )
         else:
-            merged_arg = merge_non_tensor_type_arg(
+            merged_arg = _merge_non_tensor_type_arg(
                 chunks_flat, arg_idx, chunk_spec_of_arg
             )
 

python/paddle/utils/layers_utils.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ def _packed_nest_with_indices(structure, flat, index):
             packed.append(_sequence_like(s, child))
             index = new_index
         else:
-            # Paddle requires python version > 3.7, so dict is always OrderedDict
+            # Paddle requires python version > 3.7, so dict is
             packed.append(
                 flat[index]
                 if not isinstance(flat, dict)
