 from collections import defaultdict
 from typing import Callable
 
+import numpy as np
+
 import paddle
 import paddle.distributed as dist
 from paddle import nn
 from paddle.distributed.auto_parallel.interface import (
     shard_tensor as shard_tensor_static,
 )
+from paddle.distributed.auto_parallel.placement_type import to_placements
 from paddle.distributed.auto_parallel.static.completion import (
     mark_as_sharding_propagation_skip_op,
 )
 from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
 from paddle.distributed.auto_parallel.static.utils import (
     convert_to_dims_mapping,
+    get_dist_attr,
 )
 from paddle.framework import core
 
-from .placement_type import get_shard_spec
+from .placement_type import check_placements_equal, get_shard_spec
 
 # There are the auto parallel API of the unified version of dynamic and static mode.
 # Some APIs have the same name with the previous APIs implementation, which are
@@ -321,6 +325,100 @@ def __call__(self, *args):
             else:
                 return None
 
+    def state_dict(self, mode="all"):
+        """
+        Get the state dict of the model and the optimizer.
+
+        Args:
+            mode (str): Can be ['opt', 'param', 'all'],
+                'opt': The return value only contains the variables in the optimizer.
+                'param': The return value only contains the variables in the network, not the variables in the optimizer.
+                'all': The return value contains the variables in both the network and the optimizer.
+                Default: 'all'
+        """
+        local_state_dict = self.dist_main_program(
+            mode=self._engine._mode
+        ).state_dict(mode)
+        dist_state_dict = self._build_distributed_state_dict(local_state_dict)
+        return dist_state_dict
+
+    def _build_distributed_state_dict(self, local_state_dict):
+        """
+        Args:
+            local_state_dict(Dict[str, libpaddle.Tensor]): The state dict from the program.
+        """
+        dist_main_program = self.dist_main_program(mode=self._engine._mode)
+        dist_context = self._engine._dist_contexts[self._mode]
+        # Dict[var.name, Dict["process_shape": process_mesh.shape, "process_group": process_mesh.process_ids, "dims_mapping": dims_mapping]]
+        dist_attrs = get_dist_attr(dist_main_program, dist_context)
+
+        def build_distributed_tensor(local_tensor, dist_attr):
+            assert isinstance(
+                local_tensor, (paddle.Tensor, np.ndarray, paddle.base.Tensor)
+            )
+            if not isinstance(local_tensor, paddle.Tensor):
+                local_tensor = paddle.Tensor(local_tensor)
+            assert isinstance(
+                local_tensor, paddle.Tensor
+            ), f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor."
+            assert len(local_tensor.shape) == len(
+                dist_attr["dims_mapping"]
+            ), f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
+            global_shape = local_tensor.shape
+            for i, dim in enumerate(dist_attr["dims_mapping"]):
+                assert dim >= -1 and dim < len(
+                    local_tensor.shape
+                ), f"dim {dim} out of range."
+                if dim == -1:
+                    continue
+                elif dim >= 0:
+                    global_shape[i] = (
+                        dist_attr["process_shape"][dim] * local_tensor.shape[i]
+                    )
+                else:
+                    raise ValueError(f"dim {dim} is not supported.")
+            # TODO(pangengzheng): construct dist_tensor with _dtensor_from_local api when it is ready.
+            global_tensor = paddle.zeros(global_shape, dtype=local_tensor.dtype)
+            mesh = dist.ProcessMesh(
+                np.array(dist_attr["process_group"]).reshape(
+                    dist_attr["process_shape"]
+                )
+            )
+            placements = to_placements(dist_attr["dims_mapping"], mesh)
+            dist_tensor = dist.shard_tensor(global_tensor, mesh, placements)
+            assert (
+                dist_tensor._local_value().shape == local_tensor.shape
+            ), f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
+            paddle.assign(local_tensor, dist_tensor._local_value())
+            return dist_tensor
+
+        global_state_dict = {}
+        with paddle.base.dygraph.guard():
+            for var_name, tensor in local_state_dict.items():
+                assert (
+                    var_name in dist_attrs
+                ), f"var {var_name} not in dist attrs:{dist_attrs}."
+                global_state_dict[var_name] = build_distributed_tensor(
+                    tensor, dist_attrs[var_name]
+                )
+        return global_state_dict
+
+    def set_state_dict(self, state_dict):
+        local_state_dict = {}
+        dist_main_program = self.dist_main_program(mode=self._engine._mode)
+        cur_state_dict = self.state_dict()
+        for k, v in state_dict.items():
+            assert v.is_dist(), f"key {k} value:{v} is not a dist tensor."
+            if k in cur_state_dict:
+                cur_v = cur_state_dict[k]
+                assert v.process_mesh == cur_state_dict[
+                    k
+                ].process_mesh or check_placements_equal(
+                    v.placements, cur_v.placements
+                ), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} do not match"
+            local_state_dict[k] = v._local_value()
+        dist_main_program.set_state_dict(local_state_dict)
+
 
 # Part2: DistTensor construction related APIs
 
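The core of `_build_distributed_state_dict` above is the shape bookkeeping: `dims_mapping[i] == -1` means tensor dim `i` is replicated, while `dims_mapping[i] == d` means it is split across mesh dim `d`, so the global extent is the local extent times `process_shape[d]`. A minimal, framework-free sketch of that arithmetic (the helper name and sample shapes are illustrative, not part of this diff):

```python
def infer_global_shape(local_shape, dims_mapping, process_shape):
    """Recover a tensor's global shape from one local shard.

    dims_mapping[i] == -1: tensor dim i is replicated across the mesh.
    dims_mapping[i] == d : tensor dim i is split over mesh dim d, so the
                           global extent is local_shape[i] * process_shape[d].
    """
    assert len(local_shape) == len(dims_mapping)
    global_shape = list(local_shape)
    for i, dim in enumerate(dims_mapping):
        if dim == -1:
            continue  # replicated: local extent already equals the global one
        global_shape[i] = process_shape[dim] * local_shape[i]
    return global_shape


# A 2x4 process mesh; the weight's first dim is split over mesh dim 1 (4 ranks).
print(infer_global_shape([128, 1024], [1, -1], [2, 4]))  # [512, 1024]
```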
@@ -437,6 +535,7 @@ def sharding(self):
 
         Examples:
             .. code-block:: python
+
                 >>> import paddle
                 >>> import paddle.distributed as dist
 
@@ -462,6 +561,7 @@ def gradient_merge(self):
 
         Examples:
             .. code-block:: python
+
                 >>> import paddle
                 >>> import paddle.distributed as dist
 
@@ -488,6 +588,7 @@ def fused_passes(self):
 
         Examples:
             .. code-block:: python
+
                 >>> import paddle
                 >>> import paddle.distributed as dist
 
@@ -515,6 +616,7 @@ def pipeline(self):
 
         Examples:
             .. code-block:: python
+
                 >>> import paddle
                 >>> import paddle.distributed as dist
 
@@ -563,6 +665,7 @@ def to_static(
 
     Examples:
         .. code-block:: python
+
             >>> import numpy as np
             >>> import paddle
             >>> import paddle.distributed as dist
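Taken together, the changes above give the static-graph `DistModel` a checkpoint-style interface. A hedged usage sketch, assuming a multi-card launch (e.g. via `python -m paddle.distributed.launch`); the network, data, and hyperparameters below are placeholders, not taken from this diff:

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Toy network whose weight is sharded column-wise over the mesh.
layer = paddle.nn.Linear(8, 8)
layer.weight = dist.shard_tensor(
    layer.weight, mesh, [dist.Shard(1)], stop_gradient=False
)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
loss_fn = paddle.nn.MSELoss()

dataset = paddle.io.TensorDataset([paddle.rand([16, 8]), paddle.rand([16, 8])])
loader = paddle.io.DataLoader(dataset, batch_size=4)

# to_static returns the DistModel that gains state_dict()/set_state_dict().
dist_model, dist_loader = dist.to_static(layer, loader, loss_fn, opt)
dist_model.train()
for x, y in dist_loader():
    loss = dist_model(x, y)

# Every value in the dict is a dist tensor rebuilt from this rank's local shard.
ckpt = dist_model.state_dict()   # mode defaults to "all" (params + optimizer)
dist_model.set_state_dict(ckpt)  # push the (possibly reloaded) values back
```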