 from typing import Any

 import paddle
-from paddle.utils import flatten, pack_sequence_as
+from paddle.utils import flatten, map_structure, pack_sequence_as

 logger = logging.getLogger(__name__)

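The new import pulls in `map_structure` alongside the nest helpers the module already used. For orientation, a minimal sketch of how the three `paddle.utils` helpers behave (toy structure, not part of this diff):

```python
from paddle.utils import flatten, map_structure, pack_sequence_as

nested = {"a": (1, 2), "b": [3]}

leaves = flatten(nested)                          # [1, 2, 3]
doubled = map_structure(lambda x: x * 2, nested)  # {"a": (2, 4), "b": [6]}
repacked = pack_sequence_as(nested, leaves)       # same structure as `nested`
```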
@@ -101,8 +101,8 @@ def _split_args_helper(
             )
             chunk_args[key] = arg_of_curr_chunk

-        # pack chunk_args as the origin args_dict
-        chunk_args = pack_sequence_as(args_dict, chunk_args)
+        # flatten chunk_args first, and then pack chunk_args as the origin args_dict
+        chunk_args = pack_sequence_as(args_dict, flatten(chunk_args))
         args_split.append(chunk_args)
     return args_split

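Why the added `flatten()` matters: `chunk_args` collects one value per top-level key, so when `args_dict` nests containers, the origin structure has more leaves than `chunk_args` has entries. A standalone sketch of the repacking under that assumption (structures are illustrative, not from the PR):

```python
from paddle.utils import flatten, pack_sequence_as

args_dict = {"x": 0, "y": {"y0": 0, "y1": 0}}  # origin structure, 3 leaves
chunk_args = {"x": 1, "y": [2, 3]}             # one entry per top-level key

# Flattening chunk_args first yields the 3 leaves [1, 2, 3] that the origin
# structure expects, so repacking restores the nesting of args_dict.
packed = pack_sequence_as(args_dict, flatten(chunk_args))
# -> {"x": 1, "y": {"y0": 2, "y1": 3}}
```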
@@ -111,8 +111,24 @@ def split_args_kwargs_into_chunks(
     args: tuple[Any, ...],
     kwargs: dict[str, Any] | None,
     chunks: int,
-    args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
-    kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
+    args_chunk_spec: (
+        tuple[
+            tuple[TensorChunkSpec, ...]
+            | list[TensorChunkSpec, ...]
+            | TensorChunkSpec,
+            ...,
+        ]
+        | None
+    ) = None,
+    kwargs_chunk_spec: (
+        dict[
+            str,
+            tuple[TensorChunkSpec, ...]
+            | list[TensorChunkSpec, ...]
+            | TensorChunkSpec,
+        ]
+        | None
+    ) = None,
 ) -> tuple[list[tuple], list[dict]]:
     """
     Given a sequence of args and kwargs, split them into a number of chunks
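The widened annotations let a per-argument spec itself be a container of `TensorChunkSpec`s mirroring a nested argument. A hypothetical call showing this (the import path and the `TensorChunkSpec(0)` constructor usage are assumptions, not confirmed by this diff):

```python
import paddle
from paddle.distributed.auto_parallel.pipelining.microbatch import (  # assumed path
    TensorChunkSpec,
    split_args_kwargs_into_chunks,
)

x = paddle.rand([8, 4])
pair = (paddle.rand([8, 2]), paddle.rand([8, 2]))  # a nested tuple argument

args_split, kwargs_split = split_args_kwargs_into_chunks(
    args=(x, pair),
    kwargs=None,
    chunks=2,
    # one spec for x, and a tuple of specs mirroring pair
    args_chunk_spec=(
        TensorChunkSpec(0),
        (TensorChunkSpec(0), TensorChunkSpec(0)),
    ),
)
```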
@@ -134,11 +150,13 @@ def split_args_kwargs_into_chunks(
         kwargs = {}

     if args_chunk_spec is None:
-        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
+        args_chunk_spec = map_structure(
+            lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), args
+        )

     if kwargs_chunk_spec is None:
-        kwargs_chunk_spec = dict.fromkeys(
-            kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM)
+        kwargs_chunk_spec = map_structure(
+            lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), kwargs
         )

     args_split_dict = _split_args_helper(
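The switch from tuple multiplication and `dict.fromkeys` to `map_structure` changes the shape of the defaults: one spec per leaf instead of one per top-level entry. A toy sketch of the difference (placeholder strings stand in for `TensorChunkSpec`):

```python
from paddle.utils import map_structure

args = (1, (2, 3))  # the second argument is itself nested

old_style = ("spec",) * len(args)                  # ("spec", "spec")
new_style = map_structure(lambda _: "spec", args)  # ("spec", ("spec", "spec"))
```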
@@ -186,20 +204,21 @@ def merge_chunks(
         return chunks

     if chunk_spec is None:
-        chunk0_flat = flatten(chunks[0])
-        # the number of args need to be merged
-        num_args = len(chunk0_flat)
-        chunk_spec = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * num_args
+        chunk_spec = map_structure(
+            lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), chunks[0]
+        )

     chunks_flat = []
+    # flatten chunk_spec first
+    chunk_spec = flatten(chunk_spec)
     for chunk in chunks:
         chunk_flat = flatten(chunk)
         assert len(chunk_flat) == len(
             chunk_spec
         ), f"Chunk {chunk} did not match chunk spec {chunk_spec}"
         chunks_flat.append(chunk_flat)

-    def merge_non_tensor_type_arg(chunks, idx, chunk_spec_of_arg=None):
+    def _merge_non_tensor_type_arg(chunks, idx, chunk_spec_of_arg=None):
         # use the first chunk's value as the merged result
         arg_0 = chunks[0][idx]
         for chunk_idx in range(1, len(chunks)):
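Because the default `chunk_spec` now mirrors the nesting of `chunks[0]`, it must be flattened before being compared leaf-for-leaf against each flattened chunk. A toy sketch of that invariant (placeholder strings again stand in for real specs):

```python
from paddle.utils import flatten, map_structure

chunk0 = {"out": (1, 2), "aux": 3}
chunk_spec = map_structure(lambda _: "spec", chunk0)  # nested, like chunk0

# Both flatten to 3 leaves, so they can be zipped index by index.
assert len(flatten(chunk_spec)) == len(flatten(chunk0))
```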
@@ -226,11 +245,11 @@ def merge_non_tensor_type_arg(chunks, idx, chunk_spec_of_arg=None):
                     "The TensorChunkSpec only supports paddle.Tensor type."
                 )

-            merged_arg = merge_non_tensor_type_arg(
+            merged_arg = _merge_non_tensor_type_arg(
                 chunks_flat, arg_idx, chunk_spec_of_arg
             )
         else:
-            merged_arg = merge_non_tensor_type_arg(
+            merged_arg = _merge_non_tensor_type_arg(
                 chunks_flat, arg_idx, chunk_spec_of_arg
             )

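Putting the two public helpers touched by this PR together, a hypothetical round trip (import path assumed; shapes are illustrative and presume `DEFAULT_CHUNK_DIM` is 0):

```python
import paddle
from paddle.distributed.auto_parallel.pipelining.microbatch import (  # assumed path
    merge_chunks,
    split_args_kwargs_into_chunks,
)

x = paddle.rand([8, 4])
args_split, _ = split_args_kwargs_into_chunks((x,), None, chunks=4)
assert len(args_split) == 4
assert args_split[0][0].shape == [2, 4]  # split along dim 0 if the default is 0

merged = merge_chunks(args_split, chunk_spec=None)  # defaults rebuilt per leaf
assert merged[0].shape == [8, 4]
```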