
Commit 319494b

cherry-pick static save load (#60033)
1 parent 5236d47 commit 319494b

12 files changed: +495, -232 lines

python/paddle/base/dygraph/tensor_patch_methods.py (+16, -4)

@@ -229,10 +229,22 @@ def set_value(self, value):
         # if self is Tensor, method value() return self that defined in this file, get_tensor() defined in eager_method.cc
         # this Interface behavior will be unifed in the future.
         if self.is_dist():
-            # calling set method bound for DistTensor
-            value = paddle.distributed.shard_tensor(
-                value, self.value().process_mesh, self.value().placements
-            )
+            if isinstance(value, paddle.Tensor) and value.is_dist():
+                from paddle.distributed.auto_parallel.placement_type import (
+                    check_placements_equal,
+                )
+
+                # TODO: support reshard later
+                assert value.process_mesh == self.value().process_mesh or check_placements_equal(
+                    value.placements, self.value().placements
+                ), f"process_mesh:{value.process_mesh} != {self.value().process_mesh} or placements:{value.placements} != {self.value().placements} not match"
+            else:
+                # calling set method bound for DistTensor
+                value = paddle.distributed.shard_tensor(
+                    value,
+                    self.value().process_mesh,
+                    self.value().placements,
+                )
             self.value().get_tensor().set(value.get_tensor())
             return
         self.value().get_tensor().set(
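
With this change, Tensor.set_value() accepts a value that is itself a distributed tensor: when its process_mesh and placements already match the destination (resharding is still a TODO), the local tensor is copied directly, while a dense value keeps the old path and is wrapped with shard_tensor first. A minimal sketch of the intended behaviour, assuming two ranks launched with paddle.distributed.launch (shapes and names below are illustrative, not part of the commit):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Destination: a tensor sharded along axis 0 across the two ranks.
w = dist.shard_tensor(paddle.zeros([8, 4]), mesh, [dist.Shard(0)])

# New branch: the value is already distributed with a matching mesh and
# placements, so set_value() copies its local tensor directly.
w_new = dist.shard_tensor(paddle.ones([8, 4]), mesh, [dist.Shard(0)])
w.set_value(w_new)

# Existing branch: a dense value is still wrapped with shard_tensor() to
# match the destination before being set.
w.set_value(paddle.full([8, 4], 2.0))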

python/paddle/distributed/auto_parallel/api.py (+104, -1)

@@ -15,6 +15,8 @@
 from collections import defaultdict
 from typing import Callable
 
+import numpy as np
+
 import paddle
 import paddle.distributed as dist
 from paddle import nn
@@ -28,6 +30,7 @@
 from paddle.distributed.auto_parallel.interface import (
     shard_tensor as shard_tensor_static,
 )
+from paddle.distributed.auto_parallel.placement_type import to_placements
 from paddle.distributed.auto_parallel.static.completion import (
     mark_as_sharding_propagation_skip_op,
 )
@@ -37,10 +40,11 @@
 from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
 from paddle.distributed.auto_parallel.static.utils import (
     convert_to_dims_mapping,
+    get_dist_attr,
 )
 from paddle.framework import core
 
-from .placement_type import get_shard_spec
+from .placement_type import check_placements_equal, get_shard_spec
 
 # There are the auto parallel API of the unified version of dynamic and static mode.
 # Some APIs have the same name with the previous APIs implementation, which are
@@ -321,6 +325,100 @@ def __call__(self, *args):
         else:
             return None
 
+    def state_dict(self, mode="all"):
+        """
+        Get the state dict of model and optimizer.
+
+        Args:
+            mode (str): Can be ['opt', 'param', 'all'],
+                'opt' : The return value only contains the variable in the optimizer.
+                'param' : The return value only contains the variable in the network, not the variable in the optimizer.
+                'all' : The return value contains the variable in the network and optimizer.
+                Default: 'all'
+        """
+        local_state_dict = self.dist_main_program(
+            mode=self._engine._mode
+        ).state_dict(mode)
+        dist_state_dict = self._build_distributed_state_dict(local_state_dict)
+        return dist_state_dict
+
+    def _build_distributed_state_dict(self, local_state_dict):
+        """
+        Args:
+            local_state_dict(Dict[str, libpaddle.Tensor]): The state dict from program.
+        """
+        dist_main_program = self.dist_main_program(mode=self._engine._mode)
+        dist_context = self._engine._dist_contexts[self._mode]
+        # Dict[var.name, Dict["process_shape": process_mesh.shape, "process_group": process_mesh.process_ids, "dims_mapping": dims_mapping]]
+        dist_attrs = get_dist_attr(dist_main_program, dist_context)
+
+        def build_distributed_tensor(local_tensor, dist_attr):
+            assert isinstance(
+                local_tensor, (paddle.Tensor, np.ndarray, paddle.base.Tensor)
+            )
+            if not isinstance(local_tensor, paddle.Tensor):
+                local_tensor = paddle.Tensor(local_tensor)
+            assert isinstance(
+                local_tensor, paddle.Tensor
+            ), f"local tensor:{local_tensor} type {type(local_tensor)} is not paddle.Tensor."
+            assert len(local_tensor.shape) == len(
+                dist_attr["dims_mapping"]
+            ), f"local tensor shape {local_tensor.shape} not equal to dims_mapping shape {dist_attr['dims_mapping']}."
+            global_shape = local_tensor.shape
+            for i, dim in enumerate(dist_attr["dims_mapping"]):
+                assert dim >= -1 and dim < len(
+                    local_tensor.shape
+                ), f"dim {dim} out of range."
+                if dim == -1:
+                    continue
+                elif dim >= 0:
+                    global_shape[i] = (
+                        dist_attr["process_shape"][dim] * local_tensor.shape[i]
+                    )
+                else:
+                    raise ValueError(f"dim {dim} is not supported.")
+            # TODO(pangengzheng): construct dist_tensor with _dtensor_from_local api when it is ready.
+            global_tensor = paddle.zeros(global_shape, dtype=local_tensor.dtype)
+            mesh = dist.ProcessMesh(
+                np.array(dist_attr["process_group"]).reshape(
+                    dist_attr["process_shape"]
+                )
+            )
+            placements = to_placements(dist_attr["dims_mapping"], mesh)
+            dist_tensor = dist.shard_tensor(global_tensor, mesh, placements)
+            assert (
+                dist_tensor._local_value().shape == local_tensor.shape
+            ), f"local tensor shape {dist_tensor._local_value().shape} not equal to local_tensor.shape:{local_tensor.shape}"
+            paddle.assign(local_tensor, dist_tensor._local_value())
+            return dist_tensor
+
+        global_state_dict = {}
+        with paddle.base.dygraph.guard():
+            for var_name, tensor in local_state_dict.items():
+                assert (
+                    var_name in dist_attrs
+                ), f"var {var_name} not in dist attrs:{dist_attrs}."
+                global_state_dict[var_name] = build_distributed_tensor(
+                    tensor, dist_attrs[var_name]
+                )
+        return global_state_dict
+
+    def set_state_dict(self, state_dict):
+        local_state_dict = {}
+        dist_main_program = self.dist_main_program(mode=self._engine._mode)
+        cur_state_dict = self.state_dict()
+        for k, v in state_dict.items():
+            assert v.is_dist(), f"key {k} value:{v} is not a dist tensor."
+            if k in cur_state_dict:
+                cur_v = cur_state_dict[k]
+                assert v.process_mesh == cur_state_dict[
+                    k
+                ].process_mesh or check_placements_equal(
+                    v.placements, cur_v.placements
+                ), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match"
+            local_state_dict[k] = v._local_value()
+        dist_main_program.set_state_dict(local_state_dict)
+
 
 # Part2: DistTensor construction related APIs
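
The shape logic in build_distributed_tensor can be read as: an axis whose dims_mapping entry is -1 keeps its local size, while an axis mapped to mesh dimension d is multiplied by process_shape[d]. A small standalone illustration of that rule (plain Python; the function name is mine, not part of the commit):

def infer_global_shape(local_shape, dims_mapping, process_shape):
    # Mirrors the loop in build_distributed_tensor: -1 means the axis is
    # replicated; dim >= 0 means the axis is split over mesh axis `dim`.
    global_shape = list(local_shape)
    for i, dim in enumerate(dims_mapping):
        if dim >= 0:
            global_shape[i] = process_shape[dim] * local_shape[i]
    return global_shape

# A [2, 8] local shard with dims_mapping [-1, 0] on a 4-process mesh axis
# reconstructs a [2, 32] global tensor.
print(infer_global_shape([2, 8], [-1, 0], [4]))  # [2, 32]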

@@ -437,6 +535,7 @@ def sharding(self):
 
         Examples:
             .. code-block:: python
+
                 >>> import paddle
                 >>> import paddle.distributed as dist
@@ -462,6 +561,7 @@ def gradient_merge(self):
 
         Examples:
             .. code-block:: python
+
                 >>> import paddle
                 >>> import paddle.distributed as dist
@@ -488,6 +588,7 @@ def fused_passes(self):
 
         Examples:
             .. code-block:: python
+
                 >>> import paddle
                 >>> import paddle.distributed as dist
@@ -515,6 +616,7 @@ def pipeline(self):
 
         Examples:
             .. code-block:: python
+
                 >>> import paddle
                 >>> import paddle.distributed as dist
@@ -563,6 +665,7 @@ def to_static(
 
     Examples:
         .. code-block:: python
+
             >>> import numpy as np
             >>> import paddle
             >>> import paddle.distributed as dist
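
Taken together, DistModel.state_dict() lifts each rank's local program variables into global dist tensors using the dist attributes recorded for the static program, and set_state_dict() feeds the _local_value() of each entry back into the program. A rough usage sketch, not part of this commit, assuming layer, loader, loss_fn and opt are set up as in the dist.to_static docstring example:

import paddle.distributed as dist

# `layer`, `loader`, `loss_fn`, `opt` as in the to_static example (assumed).
dist_model = dist.to_static(layer, loader, loss_fn, opt)
dist_model.train()  # select the mode whose program backs the state dict

# Global view of parameters and optimizer state, as dist tensors.
state = dist_model.state_dict()  # mode: "all" (default), "param", or "opt"

# Round-trip; mesh/placements must match, since resharding is not supported yet.
dist_model.set_state_dict(state)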

python/paddle/distributed/auto_parallel/placement_type.py (+14, -0)

@@ -50,6 +50,20 @@ def to_placements(dim_map, mesh, partial_idx=[]):
     return placements
 
 
+def check_placements_equal(this, that):
+    assert isinstance(this, list) and isinstance(that, list)
+    small_placemets = this if len(this) < len(that) else that
+    large_placements = that if len(this) < len(that) else this
+    for i in range(len(large_placements)):
+        if i < len(small_placemets):
+            if small_placemets[i] != large_placements[i]:
+                return False
+        else:
+            if large_placements[i] != Replicate():
+                return False
+    return True
+
+
 def to_dim_map(placements, tensor_dims):
     """
     convert placements to dim_map.
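
check_placements_equal treats any missing trailing placement as Replicate(), so a shorter placements list compares equal to a longer one that only appends Replicate(). A quick illustration (assuming the placement classes exported by paddle.distributed):

from paddle.distributed import Replicate, Shard
from paddle.distributed.auto_parallel.placement_type import (
    check_placements_equal,
)

assert check_placements_equal([Shard(0)], [Shard(0), Replicate()])   # equal
assert not check_placements_equal([Shard(0)], [Shard(1)])            # differ
assert not check_placements_equal([Shard(0)], [Shard(0), Shard(1)])  # differ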
