28  | 28  |     wrap_decorator,
29  | 29  | )
30  | 30  | from paddle.framework import core
    | 31  | +from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
31  | 32  | from paddle.static import Variable
32  | 33  |
33  | 34  | from ..process_mesh import ProcessMesh, merge_process_meshes

@@ -873,7 +874,25 @@ def get_dist_attr(program, dist_context=None):

873 | 874 |             "dim_names": process_mesh.dim_names,
874 | 875 |         }
875 | 876 |     else:
876 |     | -        raise NotImplementedError("get_dist_attr() not support old IR")
    | 877 | +        from .dist_context import get_default_distributed_context
    | 878 | +
    | 879 | +        assert isinstance(program, paddle.static.Program)
    | 880 | +        if dist_context is None:
    | 881 | +            dist_context = get_default_distributed_context()
    | 882 | +        for var in program.list_vars():
    | 883 | +            if is_parameter(var) or is_belong_to_optimizer(var):
    | 884 | +                tensor_dist_attr = (
    | 885 | +                    dist_context.get_tensor_dist_attr_for_program(var)
    | 886 | +                )
    | 887 | +                process_mesh = tensor_dist_attr.process_mesh
    | 888 | +                dims_mapping = tensor_dist_attr.dims_mapping
    | 889 | +                dim_names = tensor_dist_attr.process_mesh.dim_names
    | 890 | +                dist_attr[var.name] = {
    | 891 | +                    "process_shape": process_mesh.shape,
    | 892 | +                    "process_group": process_mesh.process_ids,
    | 893 | +                    "dims_mapping": dims_mapping,
    | 894 | +                    "dim_names": dim_names,
    | 895 | +                }
877 | 896 |     return dist_attr
878 | 897 |
879 | 898 |