Skip to content

[AutoParallel] Enhance processmesh #72052

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 118 additions & 2 deletions python/paddle/distributed/auto_parallel/process_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
from __future__ import annotations

import copy
import logging
from typing import TYPE_CHECKING, Any, SupportsIndex, Union

import numpy as np

import paddle
from paddle.distributed.communication.group import is_initialized
from paddle.framework import core

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from types import TracebackType
Expand Down Expand Up @@ -204,7 +208,7 @@ def unique_id(self) -> int:
return self._unique_id

def __getitem__(
self, index: slice | tuple[slice, ...] | SupportsIndex
self, index: slice | tuple[slice, ...] | str | SupportsIndex
) -> ProcessMesh:
if isinstance(index, tuple):
new_dim_names = []
Expand All @@ -221,6 +225,8 @@ def __getitem__(
new_mesh = self._mesh[index]
new_dim_names = self._dim_names
return ProcessMesh(new_mesh, new_dim_names)
elif isinstance(index, str):
return self.get_submesh_with_dim(index)
else:
new_mesh = self._mesh[index]
new_dim_names = self._dim_names[1:]
Expand Down Expand Up @@ -281,9 +287,119 @@ def get_mesh_with_dim(
new_mesh = self._mesh.transpose(new_order)

if index is not None:
return ProcessMesh(new_mesh[index], new_dim_names[1:])
if len(new_dim_names[1:]) > 0:
return ProcessMesh(new_mesh[index], new_dim_names[1:])
# satisfy the single dimension mesh case
else:
return ProcessMesh([new_mesh[index]], new_dim_names)
return ProcessMesh(new_mesh, new_dim_names)

def get_submesh_with_dim(
    self,
    dim_name: str,
) -> ProcessMesh | None:
    """
    Slice the current ProcessMesh along ``dim_name`` to create the 1-D
    submesh that contains the current rank.

    Unlike ``get_mesh_with_dim``, which only transposes the mesh axes and
    keeps every process id, this returns the submesh of the communication
    group of the *current* rank along ``dim_name``.

    Args:
        dim_name (str): the name of the mesh dimension of the ProcessMesh
            to create the submesh for.

    Returns:
        A :class:`ProcessMesh` object, or ``None`` when the current rank is
        not part of this mesh (a warning is logged in that case).

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.distributed as dist

            >>> dist.init_parallel_env()
            >>> mesh_2d = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "tp"])

            >>> dp_mesh = mesh_2d.get_submesh_with_dim("dp")
            >>> # ProcessMesh:([0, 4]) on rank 0, 4
            >>> # ProcessMesh:([1, 5]) on rank 1, 5
            >>> # ProcessMesh:([2, 6]) on rank 2, 6
            >>> # ProcessMesh:([3, 7]) on rank 3, 7

            >>> tp_mesh = mesh_2d.get_submesh_with_dim("tp")
            >>> # ProcessMesh:([0, 1, 2, 3]) on rank 0, 1, 2, 3
            >>> # ProcessMesh:([4, 5, 6, 7]) on rank 4, 5, 6, 7

            >>> mesh_3d = dist.ProcessMesh([[[0, 1],[2, 3]], [[4, 5], [6, 7]]], dim_names=["pp","dp","tp"])

            >>> pp_mesh = mesh_3d.get_submesh_with_dim("pp")
            >>> # ProcessMesh:([0, 4]) on rank 0, 4
            >>> # ProcessMesh:([1, 5]) on rank 1, 5
            >>> # ProcessMesh:([2, 6]) on rank 2, 6
            >>> # ProcessMesh:([3, 7]) on rank 3, 7

            >>> dp_mesh = mesh_3d.get_submesh_with_dim("dp")
            >>> # ProcessMesh:([0, 2]) on rank 0, 2
            >>> # ProcessMesh:([1, 3]) on rank 1, 3
            >>> # ProcessMesh:([4, 6]) on rank 4, 6
            >>> # ProcessMesh:([5, 7]) on rank 5, 7

            >>> tp_mesh = mesh_3d.get_submesh_with_dim("tp")
            >>> # ProcessMesh:([0, 1]) on rank 0, 1
            >>> # ProcessMesh:([2, 3]) on rank 2, 3
            >>> # ProcessMesh:([4, 5]) on rank 4, 5
            >>> # ProcessMesh:([6, 7]) on rank 6, 7
    """
    curr_rank = paddle.distributed.get_rank()
    # Bail out early (before any reshape work) when this rank does not
    # belong to the mesh at all.
    if curr_rank not in self._process_ids:
        logger.warning(
            f"Rank {curr_rank} is not in the process mesh, just return None"
        )
        return None

    # Move dim_name to the outermost axis, then flatten the remaining
    # axes: each column of reorder_mesh holds the ranks that differ only
    # along dim_name.
    reorder_mesh = self.get_mesh_with_dim(dim_name)._mesh.reshape(
        self.get_dim_size(dim_name), -1
    )
    # Locate curr_rank in reorder_mesh and keep only its column.
    col_idx = np.argmax(reorder_mesh == curr_rank) % reorder_mesh.shape[-1]
    return ProcessMesh(reorder_mesh[:, col_idx], [dim_name])

def get_group(
    self,
    dim_name: str | None = None,
) -> paddle.distributed.communication.group.Group:
    """
    Build the communication :class:`Group` corresponding to one dimension
    of this ProcessMesh.

    Args:
        dim_name (str, optional): it can be the name of the mesh dimension.
            Default is None.

    Returns:
        A :class:`Group` object.
    """

    # The distributed environment must already be initialized before any
    # group can be created.
    assert is_initialized(), (
        "When you want to get a group from the ProcessMesh."
        " Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )

    if len(self._dim_names) == 1:
        # Single-dimension mesh: dim_name is optional, but when given it
        # must match the mesh's only dimension.
        if dim_name is not None and dim_name not in self._dim_names:
            raise ValueError(
                f"{dim_name} not in the dimension names {self._dim_names}"
            )
        return paddle.distributed.new_group(self._process_ids)

    # Multi-dimensional mesh: a valid dim_name is mandatory.
    if dim_name is None:
        raise ValueError(
            "You should specify the dim_name when the ProcessMesh has more than one dimensions."
        )
    if dim_name not in self._dim_names:
        raise ValueError(
            f"{dim_name} not in the dimension names {self._dim_names}"
        )
    # Reduce to the current rank's 1-D submesh, then recurse into the
    # single-dimension branch above.
    return self.get_submesh_with_dim(dim_name).get_group(dim_name)

def __enter__(self) -> None:
set_current_process_mesh(self)
default_prog = paddle.static.default_main_program()
Expand Down
7 changes: 7 additions & 0 deletions test/auto_parallel/hybrid_strategy/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,10 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_parallel_api_with_llama_lora
PROPERTIES TIMEOUT "360" LABELS "RUN_TYPE=HYBRID")
endif()
# Register the ProcessMesh API test (get_submesh_with_dim / get_group);
# only built for GPU builds on Linux, run via the hybrid-parallel runner.
if((WITH_GPU) AND (LINUX))
  py_test_modules(
    test_process_mesh MODULES test_process_mesh ENVS
    "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
  set_tests_properties(test_process_mesh PROPERTIES TIMEOUT "60" LABELS
                                                    "RUN_TYPE=HYBRID")
endif()
138 changes: 138 additions & 0 deletions test/auto_parallel/hybrid_strategy/process_mesh_demo_unittest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle
import paddle.distributed as dist


class TestProcessMesh:
    """Per-rank checks for ProcessMesh.get_submesh_with_dim / get_group.

    Launched on multiple ranks; each rank asserts only the expectations
    that apply to its own rank id.
    """

    def init_dist_env(self):
        # Bring up the distributed environment and fix the seed per rank.
        dist.init_parallel_env()
        paddle.seed(2025)

    @staticmethod
    def _expect_value_error(fn):
        # Assert that calling fn() raises ValueError.
        try:
            fn()
        except ValueError:
            return
        raise AssertionError("Should raise ValueError")

    def test_get_submesh_with_dim(self):
        rank = dist.get_rank()

        # --- 2-D mesh: 2x2, dims ("dp", "tp") ---
        mesh_2d = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "tp"])

        dp_mesh = mesh_2d.get_submesh_with_dim("dp")
        # Indexing by dim name must be an alias for get_submesh_with_dim.
        assert dp_mesh == mesh_2d["dp"]
        dp_expected = {0: [0, 2], 1: [1, 3]}
        if rank in dp_expected:
            assert dp_mesh.process_ids == dp_expected[rank]

        tp_mesh = mesh_2d.get_submesh_with_dim("tp")
        assert tp_mesh == mesh_2d["tp"]
        tp_expected = {0: [0, 1], 1: [0, 1]}
        if rank in tp_expected:
            assert tp_mesh.process_ids == tp_expected[rank]

        # --- 3-D mesh: 2x2x2, dims ("pp", "dp", "tp") ---
        mesh_3d = dist.ProcessMesh(
            [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["pp", "dp", "tp"]
        )

        pp_mesh = mesh_3d.get_submesh_with_dim("pp")
        assert pp_mesh == mesh_3d["pp"]
        dp_mesh = mesh_3d.get_submesh_with_dim("dp")
        assert dp_mesh == mesh_3d["dp"]
        tp_mesh = mesh_3d.get_submesh_with_dim("tp")
        assert tp_mesh == mesh_3d["tp"]

        # Expected column per rank and per dimension of the 3-D mesh.
        expected_3d = {
            0: {"pp": [0, 4], "dp": [0, 2], "tp": [0, 1]},
            1: {"pp": [1, 5], "dp": [1, 3], "tp": [0, 1]},
        }
        if rank in expected_3d:
            assert pp_mesh.process_ids == expected_3d[rank]["pp"]
            assert dp_mesh.process_ids == expected_3d[rank]["dp"]
            assert tp_mesh.process_ids == expected_3d[rank]["tp"]

        # --- rank outside the mesh -> None ---
        mesh_small = dist.ProcessMesh([0, 1], dim_names=["x"])
        if rank not in [0, 1]:
            assert mesh_small.get_submesh_with_dim("x") is None

    def test_get_group(self):
        rank = dist.get_rank()

        # --- single-dimension mesh ---
        mesh_1d = dist.ProcessMesh([0, 1], dim_names=["x"])
        if rank in [0, 1]:
            # Without dim_name and with the correct dim_name a Group comes back.
            assert isinstance(
                mesh_1d.get_group(), dist.communication.group.Group
            )
            assert isinstance(
                mesh_1d.get_group(dim_name="x"),
                dist.communication.group.Group,
            )
            # An unknown dim_name must be rejected.
            self._expect_value_error(
                lambda: mesh_1d.get_group(dim_name="wrong_name")
            )

        # --- multi-dimension mesh ---
        mesh_2d = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "tp"])
        if rank in [0, 1, 2, 3]:
            # dim_name is mandatory for a multi-dimensional mesh.
            self._expect_value_error(mesh_2d.get_group)
            assert isinstance(
                mesh_2d.get_group(dim_name="dp"),
                dist.communication.group.Group,
            )
            self._expect_value_error(
                lambda: mesh_2d.get_group(dim_name="wrong_name")
            )

    def test_process_mesh(self):
        self.init_dist_env()
        self.test_get_submesh_with_dim()
        self.test_get_group()


# Entry point: each spawned rank runs the full per-rank check suite.
if __name__ == '__main__':
    TestProcessMesh().test_process_mesh()
46 changes: 46 additions & 0 deletions test/auto_parallel/hybrid_strategy/test_process_mesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import collective.test_communication_api_base as test_base


class TestProcessMeshPass(test_base.CommunicationTestDistBase):
    """Launcher that runs process_mesh_demo_unittest.py on 2 devices."""

    def setUp(self):
        super().setUp(num_of_devices=2, timeout=50)
        # Fixed environment shared by every launched case.
        self._default_envs = {
            "FLAGS_cudnn_deterministic": "1",
            "FLAGS_enable_pir_api": "1",
        }
        # Axes over which the cartesian product of environments is built.
        self._changeable_envs = {"backend": ["gpu"]}

    def test_process_mesh(self):
        # One launch per combination of default and changeable envs.
        for envs in test_base.gen_product_envs_list(
            self._default_envs, self._changeable_envs
        ):
            self.run_test_case(
                "process_mesh_demo_unittest.py",
                user_defined_envs=envs,
            )


# Standard unittest entry point for the CTest-driven launcher above.
if __name__ == "__main__":
    unittest.main()
1 change: 1 addition & 0 deletions test/auto_parallel/hybrid_strategy/testslist.csv
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ test_parallel_api_with_llama_2d,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy
test_parallel_api_with_llama_3d,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_to_distributed_api_for_llama,LINUX,GPU,180,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_api_with_llama_lora,LINUX,GPU,360,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_process_mesh,LINUX,GPU,60,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
Loading