Commit 2ed6eea

add hybrid expert parallel communication group
1 parent e9f1826

7 files changed, with 382 additions and 17 deletions.

paddle/fluid/framework/distributed_strategy.proto (+2)

@@ -119,6 +119,8 @@ message HybridConfig {
   optional DygraphShardingConfig sharding_configs = 8;
   optional bool enable_optimizer_timer = 9 [ default = false ];
   optional bool split_norm_comm = 10 [ default = false ];
+  optional int32 ep_degree = 11 [ default = 1 ];
+  optional int32 moe_sharding_degree = 12 [ default = 1 ];
 }

 message AMPConfig {
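The two new fields sit next to the other degree settings in HybridConfig, so they should be settable the same way. A minimal sketch of configuring them, assuming the new proto fields are surfaced through fleet.DistributedStrategy().hybrid_configs exactly like dp_degree/mp_degree/pp_degree (the key names mirror the proto fields; how fleet.init consumes them is handled by the other files of this commit and is not shown here):

# Hedged sketch: hybrid_configs keys are validated against HybridConfig's proto
# fields, so "ep_degree" / "moe_sharding_degree" are assumed to become legal
# keys once this commit lands. Degrees below are illustrative for 8 GPUs.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,          # dp and sep must stay 1 when MoE groups are built
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 8,
    "ep_degree": 4,            # new field: expert-parallel degree
    "moe_sharding_degree": 2,  # new field: sharding degree inside MoE layers
}
# ep_degree * moe_sharding_degree == mp_degree * sharding_degree (4 * 2 == 1 * 8),
# which matches the world-size check in EPHybridCommunicateGroup below.
fleet.init(is_collective=True, strategy=strategy)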

python/paddle/distributed/fleet/base/topology.py (+280 -4)
@@ -14,6 +14,7 @@
 from __future__ import annotations

 import collections
+import math
 import os
 from functools import reduce
 from itertools import product
@@ -341,11 +342,13 @@ def _check_sep_exist(self) -> None:
         assert self._sep_degree > 1, "sep not exist"

     def _set_comm_group(
-        self, parallel_method: str = "data"
+        self, parallel_method: str = "data", topo: CommunicateTopology = None
     ) -> tuple[list[int], Group]:
         parallel_group = []
         parallel_comm_group = None
-        parallel_groups = self._topo.get_comm_list(parallel_method)
+        if topo is None:
+            topo = self._topo
+        parallel_groups = topo.get_comm_list(parallel_method)

         group_nccl_comm_init_option = (
             g_pipeline_nccl_comm_init_option
@@ -370,11 +373,13 @@ def _set_comm_group(
         return parallel_group, parallel_comm_group

     def _set_check_group(
-        self, parallel_method: str = "data"
+        self, parallel_method: str = "data", topo: CommunicateTopology = None
     ) -> tuple[list[int], Group]:
         parallel_group = []
         parallel_comm_group = None
-        parallel_size = self._topo.get_dim(parallel_method)
+        if topo is None:
+            topo = self._topo
+        parallel_size = topo.get_dim(parallel_method)
         for idx in range(parallel_size):
             parallel_groups = self._topo.get_axis_list(parallel_method, idx)
             comm_group = paddle.distributed.new_group(ranks=parallel_groups)
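Both helpers now take an optional topo so a subclass can create groups from a topology other than self._topo; the EPHybridCommunicateGroup added below passes its MoE or dense CommunicateTopology explicitly. A minimal sketch of what such a topology yields, assuming an illustrative 8-rank layout (pp=1, moe_sharding=2, ep=4); CommunicateTopology is pure bookkeeping, so no distributed initialization is needed to inspect it:

# Illustration only: the dims are assumptions, expected outputs shown as comments.
from paddle.distributed.fleet.base.topology import CommunicateTopology

# rank = moe_sharding_id * 4 + expert_id when the pipe degree is 1
moe_topo = CommunicateTopology(["pipe", "moe_sharding", "expert"], [1, 2, 4])

print(moe_topo.get_dim("expert"))        # expected: 4
print(moe_topo.get_comm_list("expert"))
# expected: [[0, 1, 2, 3], [4, 5, 6, 7]]  -> groups _set_comm_group("expert", moe_topo) builds
print(moe_topo.get_comm_list("moe_sharding"))
# expected: [[0, 4], [1, 5], [2, 6], [3, 7]]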
@@ -563,6 +568,9 @@ def get_pp_mp_parallel_group(self) -> Group:
         self._check_sep_exist()
         return self._pp_mp_comm_group

+    def get_moe_sharding_parallel_world_size(self) -> int:
+        return 0
+
     def create_fuse_group(
         self, fused_strategy_list: list[str]
     ) -> tuple[list[list[int]], list[Group]] | tuple[list[int], Group]:
@@ -593,6 +601,274 @@ def create_fuse_group(
         return parallel_group[0], parallel_comm_group[0]


+class EPHybridCommunicateGroup(HybridCommunicateGroup):
+    def __init__(
+        self,
+        hybrid_group_names: list[str] = [
+            "pipe",
+            "moe_sharding",
+            "expert",
+            "data",
+            "sharding",
+            "sep",
+            "model",
+        ],
+        dims: list[int] = [1, 1, 1, 1, 1, 1, 1],
+    ) -> None:
+        self.nranks = paddle.distributed.get_world_size()
+        self.global_rank = paddle.distributed.get_rank()
+
+        dim_dict = dict(zip(hybrid_group_names, dims))
+        self._ep_degree = dim_dict['expert']
+        self._moe_sharding_degree = dim_dict['moe_sharding']
+        self._moe_pp_degree = dim_dict['pipe']
+        self._dp_degree = dim_dict['data']
+        self._mp_degree = dim_dict['model']
+        self._pp_degree = dim_dict['pipe']
+        self._sharding_degree = dim_dict['sharding']
+        self._sep_degree = dim_dict['sep']
+
+        moe_hybrid_group_names = []
+        moe_dims = []
+        for name, dim in zip(hybrid_group_names, dims):
+            if name in ["pipe", "moe_sharding", "expert"]:
+                moe_hybrid_group_names.append(name)
+                moe_dims.append(dim)
+
+        self._moe_topo = CommunicateTopology(moe_hybrid_group_names, moe_dims)
+        dim_dict["dense_sharding"] = (
+            dim_dict["sharding"] // dim_dict["moe_sharding"]
+        )
+        dense_group_names = [
+            "moe_sharding",
+            "pipe",
+            "dense_sharding",
+            "data",
+            "sep",
+            "model",
+        ]
+        dense_dims = [dim_dict[name] for name in dense_group_names]
+        self._dense_topo = CommunicateTopology(dense_group_names, dense_dims)
+        self._moe_topo._parent_hcg = self
+        self._dense_topo._parent_hcg = self
+        self._topo = self._dense_topo
+
+        self._data_parallel_id = self._get_parallel_id(self._dense_topo, "data")
+        self._model_parallel_id = self._get_parallel_id(
+            self._dense_topo, "model"
+        )
+        self._sharding_parallel_id = self._get_sharding_parallel_id()
+        self._sep_parallel_id = self._get_parallel_id(self._dense_topo, "sep")
+        self.stage_id = self._get_parallel_id(self._moe_topo, "pipe")
+        self._expert_parallel_id = self._get_parallel_id(
+            self._moe_topo, "expert"
+        )
+        self._moe_sharding_parallel_id = self._get_parallel_id(
+            self._moe_topo, "moe_sharding"
+        )
+
+        assert (
+            self._moe_pp_degree == self._pp_degree
+        ), f"Mismatch moe_pp_degree:{self._moe_pp_degree}, pp_degree:{self._pp_degree}."
+        assert (
+            self._topo._world_size == self._moe_topo._world_size
+        ), f"Mismatch world_size:{self._topo._world_size}, moe_world_size:{self._moe_topo._world_size}."
+        assert (
+            self._sep_degree == 1 and self._dp_degree == 1
+        ), f"sep_degree {self._sep_degree} and dp_degree {self._dp_degree} must be 1 in MoE."
+
+        self._pp_group, self._pp_comm_group = self._set_comm_group(
+            "pipe", self._moe_topo
+        )
+        paddle.distributed.all_reduce(
+            paddle.zeros([1], dtype="int32"),
+            op=paddle.distributed.ReduceOp.SUM,
+            group=self._pp_comm_group,
+        )
+        env_name = "FLAGS_eager_communication_connection"
+        if paddle.get_flags(env_name)[env_name]:
+            if self._pp_comm_group.nranks > 1:
+                self._pp_comm_group.process_group.eager_connect_ring_exchange()
+
+        # create comm group for expert parallel
+        self._ep_group, self._ep_comm_group = self._set_comm_group(
+            "expert", self._moe_topo
+        )
+
+        # create comm group for sharding parallel in MoE layer
+        self._moe_sharding_group, self._moe_sharding_comm_group = (
+            self._set_comm_group("moe_sharding", self._moe_topo)
+        )
+
+        # create comm group for data parallel
+        self._dp_group, self._dp_comm_group = self._set_comm_group(
+            "data", self._dense_topo
+        )
+
+        # create comm group for sep parallel
+        self._sep_group, self._sep_comm_group = self._set_comm_group(
+            "sep", self._dense_topo
+        )
+
+        # create comm group for model parallel
+        self._mp_group, self._mp_comm_group = self._set_comm_group(
+            "model", self._dense_topo
+        )
+
+        # create comm group for sharding parallel
+        self._sharding_group, self._sharding_comm_group = (
+            self.build_sharding_group(self._dense_topo)
+        )
+
+        # create global group for check inf_nan / clip global norm
+        self._check_group, self._check_comm_group = self._set_check_group(
+            "data", self._dense_topo
+        )
+        self.sharding_group, self.sharding_check_comm_group = (
+            self.build_sharding_group(self._dense_topo)
+        )
+
+        # (
+        #     self.sharding_check_group,
+        #     self.sharding_check_comm_group,
+        # ) = self._set_check_group("sharding")
+
+        # create p2p group
+        self.is_first_stage = self.stage_id == 0
+        self.is_last_stage = self.stage_id == (self._pp_degree - 1)
+
+        # create p2p_groups
+        if self._pp_degree > 1:
+            if paddle.framework.core.is_compiled_with_nccl():
+                check_nccl_version_for_p2p()
+            self._set_p2p_prev_next()
+            if _use_four_directions:
+                self._set_four_directions_p2p_group()
+
+        debug_str = (
+            f"HybridParallelInfo: rank_id: {self.global_rank}, mp_degree: {self._mp_degree}, "
+            f"sharding_degree: {self._sharding_degree}, pp_degree: {self._pp_degree}, dp_degree: {self._dp_degree}, sep_degree: {self._sep_degree}, "
+            f"ep_degree: {self._ep_degree}, moe_sharding_degree: {self._moe_sharding_degree}"
+        )
+        debug_str += f", mp_group: {self._mp_group}, sharding_group: {self._sharding_group}, pp_group: {self._pp_group}, dp_group: {self._dp_group}, sep_group: {self._sep_group}, check/clip group: {self._check_group}, ep_group: {self._ep_group}, moe_sharding_group: {self._moe_sharding_group}."
+        logger.info(debug_str)
+
+        global _HYBRID_PARALLEL_GROUP
+        _HYBRID_PARALLEL_GROUP = self
+
+    def build_sharding_group(self, topo):
+        parallel_group = []
+        parallel_comm_group = None
+
+        parallel_groups = self.merge_inner_comm_list(
+            topo, "moe_sharding", "dense_sharding"
+        )
+
+        group_nccl_comm_init_option = 0
+
+        for group in parallel_groups:
+            comm_group = paddle.distributed.new_group(
+                ranks=group,
+                nccl_comm_init_option=group_nccl_comm_init_option,
+            )
+            if self.global_rank in group:
+                parallel_group = group
+                parallel_comm_group = comm_group
+
+        assert len(parallel_group) > 0
+        assert parallel_comm_group is not None
+
+        logger.info(
+            f"Total {len(parallel_groups)} sharding comm group(s) create successfully!"
+        )
+        return parallel_group, parallel_comm_group
+
+    def merge_inner_comm_list(self, topo, outer_name, inner_name):
+        """
+        merge all inner communication list whose rank-id are in
+        the same outer communication list. E.g.:
+        outer_comm_list: [[0, 4], [1, 5]]
+        inner_comm_list: [[0, 2], [1, 3], [4, 6], [5, 7]]
+        => merged_inner_comm_list: [[0, 2, 4, 6], [1, 3, 5, 7]]
+        """
+        inner_axis = topo._parallel_names.index(inner_name)
+        outer_axis = topo._parallel_names.index(outer_name)
+        inner_comm_list = topo.get_comm_list(inner_name)
+
+        num_merged_groups = len(inner_comm_list) // topo._dims[outer_axis]
+        interval = (
+            math.prod(topo._dims[(outer_axis + 1) :]) // topo._dims[inner_axis]
+        )
+        assert num_merged_groups > 0 and interval > 0
+
+        merged_comm_list = []
+        for i in range(num_merged_groups):
+            comm = []
+            for j in range(topo._dims[outer_axis]):
+                assert i + j * interval < len(
+                    inner_comm_list
+                ), f"Unexpected error in merge_inner_comm_list, {i}, {j}, {interval}, {len(inner_comm_list)}"
+                comm += inner_comm_list[i + j * interval]
+            merged_comm_list.append(comm)
+
+        return merged_comm_list
+
+    def find_col_idx(self, comm_list, global_rank):
+        rows = len(comm_list)
+        cols = len(comm_list[0])
+        r = rows - 1
+        c = 0
+
+        while r >= 0 and c < cols:
+            current = comm_list[r][c]
+            if current == global_rank:
+                return c
+            elif current < global_rank:
+                c += 1
+            else:
+                r -= 1
+
+        return None
+
+    def _get_parallel_id(self, topo, parallel_type):
+        comm_list = topo.get_comm_list(parallel_type)
+        parallel_id = self.find_col_idx(comm_list, self.global_rank)
+        assert parallel_id is not None
+        return parallel_id
+
+    def _get_sharding_parallel_id(self):
+        sharding_comm_list = self.merge_inner_comm_list(
+            self._dense_topo, "moe_sharding", "dense_sharding"
+        )
+        parallel_id = self.find_col_idx(sharding_comm_list, self.global_rank)
+        assert parallel_id is not None
+        return parallel_id
+
+    def get_expert_parallel_rank(self) -> int:
+        return self._expert_parallel_id
+
+    def get_expert_parallel_world_size(self) -> int:
+        return self._ep_degree
+
+    def get_expert_parallel_group(self) -> Group:
+        return self._ep_comm_group
+
+    def get_expert_parallel_group_src_rank(self) -> int:
+        return self._ep_comm_group.ranks[0]
+
+    def get_moe_sharding_parallel_rank(self) -> int:
+        return self._moe_sharding_parallel_id
+
+    def get_moe_sharding_parallel_world_size(self) -> int:
+        return self._moe_sharding_degree
+
+    def get_moe_sharding_parallel_group(self) -> Group:
+        return self._moe_sharding_comm_group
+
+    def get_moe_sharding_parallel_group_src_rank(self) -> int:
+        return self._moe_sharding_comm_group.ranks[0]
+
+
 class _CommunicateGroup:
     """tmp for static"""

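One part of the new class that is easy to misread is build_sharding_group: the full sharding group is not a single axis of the dense topology, it is the moe_sharding axis and the dense_sharding axis stitched together by merge_inner_comm_list. A small stand-alone sketch that replays the docstring example above in plain Python (the topology here is hand-rolled, the dims are assumptions, and the merge arithmetic mirrors merge_inner_comm_list for illustration only):

import math
from itertools import product

# Assumed dense topology: moe_sharding=2, dense_sharding=2, model=2,
# with pipe/data/sep collapsed to 1 (8 ranks, row-major rank numbering).
names = ["moe_sharding", "pipe", "dense_sharding", "data", "sep", "model"]
dims = [2, 1, 2, 1, 1, 2]
strides = [math.prod(dims[i + 1:]) for i in range(len(dims))]

def comm_list(axis_name):
    """Mimic CommunicateTopology.get_comm_list: group ranks that differ
    only along `axis_name`."""
    axis = names.index(axis_name)
    groups = []
    for coord in product(*(range(d) for d in dims)):
        if coord[axis] != 0:
            continue  # one group per combination of the other axes
        base = sum(c * s for c, s in zip(coord, strides))
        groups.append([base + k * strides[axis] for k in range(dims[axis])])
    return groups

def merge_inner(outer_name, inner_name):
    """Same index arithmetic as EPHybridCommunicateGroup.merge_inner_comm_list."""
    outer_axis, inner_axis = names.index(outer_name), names.index(inner_name)
    inner = comm_list(inner_name)
    interval = math.prod(dims[outer_axis + 1:]) // dims[inner_axis]
    return [
        sum((inner[i + j * interval] for j in range(dims[outer_axis])), [])
        for i in range(len(inner) // dims[outer_axis])
    ]

print(comm_list("moe_sharding"))    # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(comm_list("dense_sharding"))  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(merge_inner("moe_sharding", "dense_sharding"))
# [[0, 2, 4, 6], [1, 3, 5, 7]] -> each merged list is one full sharding group

find_col_idx then locates the calling rank's column inside such a merged table, which is how _get_sharding_parallel_id derives the sharding rank without a dedicated topology axis.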