Commit b273aca

add unit test and refine EPHybridCommunicateGroup
1 parent 16135d7 commit b273aca

2 files changed (+75, −17 lines)

python/paddle/distributed/fleet/base/topology.py

Lines changed: 32 additions & 16 deletions
@@ -619,34 +619,50 @@ def __init__(
         self.global_rank = paddle.distributed.get_rank()

         dim_dict = dict(zip(hybrid_group_names, dims))
-        self._ep_degree = dim_dict['expert']
-        self._moe_sharding_degree = dim_dict['moe_sharding']
-        self._moe_pp_degree = dim_dict['pipe']
-        self._dp_degree = dim_dict['data']
-        self._mp_degree = dim_dict['model']
-        self._pp_degree = dim_dict['pipe']
-        self._sharding_degree = dim_dict['sharding']
-        self._sep_degree = dim_dict['sep']
+        self._ep_degree = dim_dict.get('expert', 1)
+        self._moe_sharding_degree = dim_dict.get('moe_sharding', 1)
+        self._moe_pp_degree = dim_dict.get('pipe', 1)
+        self._dp_degree = dim_dict.get('data', 1)
+        self._mp_degree = dim_dict.get('model', 1)
+        self._pp_degree = dim_dict.get('pipe', 1)
+        self._sharding_degree = dim_dict.get('sharding', 1)
+        self._sep_degree = dim_dict.get('sep', 1)

         moe_hybrid_group_names = []
         moe_dims = []
         for name, dim in zip(hybrid_group_names, dims):
             if name in ["pipe", "moe_sharding", "expert"]:
                 moe_hybrid_group_names.append(name)
                 moe_dims.append(dim)
+        assert (
+            "moe_sharding" in moe_hybrid_group_names
+            and "expert" in moe_hybrid_group_names
+        )

         self._moe_topo = CommunicateTopology(moe_hybrid_group_names, moe_dims)
         dim_dict["dense_sharding"] = (
             dim_dict["sharding"] // dim_dict["moe_sharding"]
         )
-        dense_group_names = [
-            "moe_sharding",
-            "pipe",
-            "dense_sharding",
-            "data",
-            "sep",
-            "model",
-        ]
+        if hybrid_group_names.index("pipe") > hybrid_group_names.index(
+            "moe_sharding"
+        ):
+            dense_group_names = [
+                "moe_sharding",
+                "pipe",
+                "dense_sharding",
+                "data",
+                "sep",
+                "model",
+            ]
+        else:
+            dense_group_names = [
+                "pipe",
+                "moe_sharding",
+                "dense_sharding",
+                "data",
+                "sep",
+                "model",
+            ]
         dense_dims = [dim_dict[name] for name in dense_group_names]
         self._dense_topo = CommunicateTopology(dense_group_names, dense_dims)
         self._moe_topo._parent_hcg = self
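
Two refinements are visible in this hunk: parallel axes absent from hybrid_group_names now fall back to a degree of 1 (dim_dict.get(name, 1) instead of a hard dim_dict[name] lookup), and the dense sub-topology keeps "pipe" and "moe_sharding" in whatever relative order the caller listed them, rather than always placing "moe_sharding" first; the derived "dense_sharding" degree remains sharding // moe_sharding. A minimal standalone sketch of the ordering rule (illustrative only, not the Paddle source; dense_group_order is a hypothetical helper):

# Illustrative sketch of the ordering rule added in this commit; not Paddle code.
def dense_group_order(hybrid_group_names):
    # Preserve the caller's relative order of "pipe" and "moe_sharding".
    if hybrid_group_names.index("pipe") > hybrid_group_names.index("moe_sharding"):
        head = ["moe_sharding", "pipe"]
    else:
        head = ["pipe", "moe_sharding"]
    return head + ["dense_sharding", "data", "sep", "model"]

# "moe_sharding" is listed before "pipe", so it also leads the dense topology:
print(dense_group_order(
    ["moe_sharding", "sharding", "pipe", "sep", "data", "expert", "model"]
))  # ['moe_sharding', 'pipe', 'dense_sharding', 'data', 'sep', 'model']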

test/collective/fleet/hybrid_parallel_communicate_group.py

Lines changed: 43 additions & 1 deletion
@@ -16,13 +16,14 @@

 import paddle
 from paddle.distributed import fleet
+from paddle.distributed.fleet.base import topology as tp


 class TestNewGroupAPI:
     def __init__(self):
         paddle.distributed.init_parallel_env()
         topo = fleet.CommunicateTopology(
-            ["data", "model", "sharding", "pipe"], [2, 1, 1, 1]
+            ["data", "sep", "model", "sharding", "pipe"], [2, 1, 1, 1, 1]
         )
         self.hcg = fleet.HybridCommunicateGroup(topo)

@@ -101,6 +102,47 @@ def test_all(self):
         print("test barrier api ok")


+class TestHybridEPGroup:
+    def __init__(self):
+        paddle.distributed.init_parallel_env()
+        group_names = [
+            "moe_sharding",
+            "sharding",
+            "pipe",
+            "sep",
+            "data",
+            "expert",
+            "model",
+        ]
+        dims = [1, 1, 1, 1, 1, 2, 2]
+
+        self.hcg = tp.EPHybridCommunicateGroup(group_names, dims)
+
+    def test_all(self):
+        global_rank = paddle.distributed.get_rank()
+
+        dp_rank = self.hcg.get_data_parallel_rank()
+        assert dp_rank == 0
+        assert self.hcg.get_expert_parallel_world_size() == 2
+        assert self.hcg.get_moe_sharding_parallel_world_size() == 1
+        assert self.hcg.get_model_parallel_world_size() == 2
+        assert self.hcg.get_expert_parallel_rank() == global_rank
+        assert self.hcg.get_moe_sharding_parallel_rank() == 0
+        assert self.hcg.get_expert_parallel_group_src_rank() == 0
+        assert (
+            self.hcg.get_moe_sharding_parallel_group_src_rank() == global_rank
+        )
+
+        moe_sharding_group = self.hcg.get_moe_sharding_parallel_group()
+        ep_group = self.hcg.get_expert_parallel_group()
+        mp_group = self.hcg.get_model_parallel_group()
+        assert moe_sharding_group.ranks == [global_rank]
+        assert ep_group.ranks == [0, 1]
+        assert mp_group.ranks == [0, 1]
+
+
 if __name__ == "__main__":
     gpt = TestNewGroupAPI()
     gpt.test_all()
+    ep_test = TestHybridEPGroup()
+    ep_test.test_all()
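
The assertions in TestHybridEPGroup read as written for a two-rank run: with expert degree 2 and all MoE-side degrees equal to 1, the expert-parallel group spans both ranks while each rank sits alone in its moe_sharding group. A plain-Python mimic of a row-major rank layout over the MoE sub-topology axes reproduces those groups (illustrative only; groups_along_axis is a hypothetical helper, not Paddle API):

# Illustrative only: lay ranks out row-major over the MoE sub-topology axes
# ("pipe", "moe_sharding", "expert") with dims [1, 1, 2] and list the groups
# obtained by varying a single axis.
import numpy as np

def groups_along_axis(names, dims, axis_name):
    ranks = np.arange(int(np.prod(dims))).reshape(dims)
    axis = names.index(axis_name)
    moved = np.moveaxis(ranks, axis, -1)
    return [[int(r) for r in row] for row in moved.reshape(-1, dims[axis])]

names, dims = ["pipe", "moe_sharding", "expert"], [1, 1, 2]
print(groups_along_axis(names, dims, "expert"))        # [[0, 1]]  -> ep_group.ranks == [0, 1]
print(groups_along_axis(names, dims, "moe_sharding"))  # [[0], [1]] -> each rank alone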
