Skip to content

Commit 029c365

Browse files
committed
fix
1 parent dcf2e37 commit 029c365

File tree

1 file changed

+20
-1
lines changed
  • python/paddle/distributed/auto_parallel/static

1 file changed

+20
-1
lines changed

python/paddle/distributed/auto_parallel/static/utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
wrap_decorator,
2929
)
3030
from paddle.framework import core
31+
from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
3132
from paddle.static import Variable
3233

3334
from ..process_mesh import ProcessMesh, merge_process_meshes
@@ -873,7 +874,25 @@ def get_dist_attr(program, dist_context=None):
873874
"dim_names": process_mesh.dim_names,
874875
}
875876
else:
876-
raise NotImplementedError("get_dist_attr() not support old IR")
877+
from .dist_context import get_default_distributed_context
878+
879+
assert isinstance(program, paddle.static.Program)
880+
if dist_context is None:
881+
dist_context = get_default_distributed_context()
882+
for var in program.list_vars():
883+
if is_parameter(var) or is_belong_to_optimizer(var):
884+
tensor_dist_attr = (
885+
dist_context.get_tensor_dist_attr_for_program(var)
886+
)
887+
process_mesh = tensor_dist_attr.process_mesh
888+
dims_mapping = tensor_dist_attr.dims_mapping
889+
dim_names = tensor_dist_attr.process_mesh.dim_names
890+
dist_attr[var.name] = {
891+
"process_shape": process_mesh.shape,
892+
"process_group": process_mesh.process_ids,
893+
"dims_mapping": dims_mapping,
894+
"dim_names": dim_names,
895+
}
877896
return dist_attr
878897

879898

0 commit comments

Comments
 (0)