from __future__ import annotations

import collections
+import math
import os
from functools import reduce
from itertools import product
@@ -341,11 +342,13 @@ def _check_sep_exist(self) -> None:
        assert self._sep_degree > 1, "sep not exist"

    def _set_comm_group(
-        self, parallel_method: str = "data"
+        self, parallel_method: str = "data", topo: CommunicateTopology = None
    ) -> tuple[list[int], Group]:
        parallel_group = []
        parallel_comm_group = None
-        parallel_groups = self._topo.get_comm_list(parallel_method)
+        if topo is None:
+            topo = self._topo
+        parallel_groups = topo.get_comm_list(parallel_method)

        group_nccl_comm_init_option = (
            g_pipeline_nccl_comm_init_option
@@ -370,11 +373,13 @@ def _set_comm_group(
        return parallel_group, parallel_comm_group

    def _set_check_group(
-        self, parallel_method: str = "data"
+        self, parallel_method: str = "data", topo: CommunicateTopology = None
    ) -> tuple[list[int], Group]:
        parallel_group = []
        parallel_comm_group = None
-        parallel_size = self._topo.get_dim(parallel_method)
+        if topo is None:
+            topo = self._topo
+        parallel_size = topo.get_dim(parallel_method)
        for idx in range(parallel_size):
            parallel_groups = self._topo.get_axis_list(parallel_method, idx)
            comm_group = paddle.distributed.new_group(ranks=parallel_groups)
@@ -563,6 +568,9 @@ def get_pp_mp_parallel_group(self) -> Group:
        self._check_sep_exist()
        return self._pp_mp_comm_group

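+    # The base class reports no MoE sharding; EPHybridCommunicateGroup below
+    # overrides this with the configured moe_sharding degree.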
+    def get_moe_sharding_parallel_world_size(self) -> int:
+        return 0
+
    def create_fuse_group(
        self, fused_strategy_list: list[str]
    ) -> tuple[list[list[int]], list[Group]] | tuple[list[int], Group]:
@@ -593,6 +601,274 @@ def create_fuse_group(
        return parallel_group[0], parallel_comm_group[0]


+class EPHybridCommunicateGroup(HybridCommunicateGroup):
+    def __init__(
+        self,
+        hybrid_group_names: list[str] = [
+            "pipe",
+            "moe_sharding",
+            "expert",
+            "data",
+            "sharding",
+            "sep",
+            "model",
+        ],
+        dims: list[int] = [1, 1, 1, 1, 1, 1, 1],
+    ) -> None:
+        self.nranks = paddle.distributed.get_world_size()
+        self.global_rank = paddle.distributed.get_rank()
+
+        dim_dict = dict(zip(hybrid_group_names, dims))
+        self._ep_degree = dim_dict['expert']
+        self._moe_sharding_degree = dim_dict['moe_sharding']
+        self._moe_pp_degree = dim_dict['pipe']
+        self._dp_degree = dim_dict['data']
+        self._mp_degree = dim_dict['model']
+        self._pp_degree = dim_dict['pipe']
+        self._sharding_degree = dim_dict['sharding']
+        self._sep_degree = dim_dict['sep']
+
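+        # Two topologies are built over the same set of ranks: a MoE topology
+        # over [pipe, moe_sharding, expert] used for the expert layers, and a
+        # dense topology over [moe_sharding, pipe, dense_sharding, data, sep,
+        # model], where dense_sharding = sharding // moe_sharding.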
+        moe_hybrid_group_names = []
+        moe_dims = []
+        for name, dim in zip(hybrid_group_names, dims):
+            if name in ["pipe", "moe_sharding", "expert"]:
+                moe_hybrid_group_names.append(name)
+                moe_dims.append(dim)
+
+        self._moe_topo = CommunicateTopology(moe_hybrid_group_names, moe_dims)
+        dim_dict["dense_sharding"] = (
+            dim_dict["sharding"] // dim_dict["moe_sharding"]
+        )
+        dense_group_names = [
+            "moe_sharding",
+            "pipe",
+            "dense_sharding",
+            "data",
+            "sep",
+            "model",
+        ]
+        dense_dims = [dim_dict[name] for name in dense_group_names]
+        self._dense_topo = CommunicateTopology(dense_group_names, dense_dims)
+        self._moe_topo._parent_hcg = self
+        self._dense_topo._parent_hcg = self
+        self._topo = self._dense_topo
+
+        self._data_parallel_id = self._get_parallel_id(self._dense_topo, "data")
+        self._model_parallel_id = self._get_parallel_id(
+            self._dense_topo, "model"
+        )
+        self._sharding_parallel_id = self._get_sharding_parallel_id()
+        self._sep_parallel_id = self._get_parallel_id(self._dense_topo, "sep")
+        self.stage_id = self._get_parallel_id(self._moe_topo, "pipe")
+        self._expert_parallel_id = self._get_parallel_id(
+            self._moe_topo, "expert"
+        )
+        self._moe_sharding_parallel_id = self._get_parallel_id(
+            self._moe_topo, "moe_sharding"
+        )
+
+        assert (
+            self._moe_pp_degree == self._pp_degree
+        ), f"Mismatch moe_pp_degree:{self._moe_pp_degree}, pp_degree:{self._pp_degree}."
+        assert (
+            self._topo._world_size == self._moe_topo._world_size
+        ), f"Mismatch world_size:{self._topo._world_size}, moe_world_size:{self._moe_topo._world_size}."
+        assert (
+            self._sep_degree == 1 and self._dp_degree == 1
+        ), f"sep_degree {self._sep_degree} and dp_degree {self._dp_degree} must be 1 in MoE."
+
+        self._pp_group, self._pp_comm_group = self._set_comm_group(
+            "pipe", self._moe_topo
+        )
+        paddle.distributed.all_reduce(
+            paddle.zeros([1], dtype="int32"),
+            op=paddle.distributed.ReduceOp.SUM,
+            group=self._pp_comm_group,
+        )
+        env_name = "FLAGS_eager_communication_connection"
+        if paddle.get_flags(env_name)[env_name]:
+            if self._pp_comm_group.nranks > 1:
+                self._pp_comm_group.process_group.eager_connect_ring_exchange()
+
+        # create comm group for expert parallel
+        self._ep_group, self._ep_comm_group = self._set_comm_group(
+            "expert", self._moe_topo
+        )
+
+        # create comm group for sharding parallel in MoE layer
+        self._moe_sharding_group, self._moe_sharding_comm_group = (
+            self._set_comm_group("moe_sharding", self._moe_topo)
+        )
+
+        # create comm group for data parallel
+        self._dp_group, self._dp_comm_group = self._set_comm_group(
+            "data", self._dense_topo
+        )
+
+        # create comm group for sep parallel
+        self._sep_group, self._sep_comm_group = self._set_comm_group(
+            "sep", self._dense_topo
+        )
+
+        # create comm group for model parallel
+        self._mp_group, self._mp_comm_group = self._set_comm_group(
+            "model", self._dense_topo
+        )
+
+        # create comm group for sharding parallel
+        self._sharding_group, self._sharding_comm_group = (
+            self.build_sharding_group(self._dense_topo)
+        )
+
+        # create global group for check inf_nan / clip global norm
+        self._check_group, self._check_comm_group = self._set_check_group(
+            "data", self._dense_topo
+        )
+        self.sharding_group, self.sharding_check_comm_group = (
+            self.build_sharding_group(self._dense_topo)
+        )
+
+        # (
+        #     self.sharding_check_group,
+        #     self.sharding_check_comm_group,
+        # ) = self._set_check_group("sharding")
+
+        # create p2p group
+        self.is_first_stage = self.stage_id == 0
+        self.is_last_stage = self.stage_id == (self._pp_degree - 1)
+
+        # create p2p_groups
+        if self._pp_degree > 1:
+            if paddle.framework.core.is_compiled_with_nccl():
+                check_nccl_version_for_p2p()
+            self._set_p2p_prev_next()
+            if _use_four_directions:
+                self._set_four_directions_p2p_group()
+
+        debug_str = (
+            f"HybridParallelInfo: rank_id: {self.global_rank}, mp_degree: {self._mp_degree}, "
+            f"sharding_degree: {self._sharding_degree}, pp_degree: {self._pp_degree}, dp_degree: {self._dp_degree}, sep_degree: {self._sep_degree}, "
+            f"ep_degree: {self._ep_degree}, moe_sharding_degree: {self._moe_sharding_degree}"
+        )
+        debug_str += f", mp_group: {self._mp_group}, sharding_group: {self._sharding_group}, pp_group: {self._pp_group}, dp_group: {self._dp_group}, sep_group: {self._sep_group}, check/clip group: {self._check_group}, ep_group: {self._ep_group}, moe_sharding_group: {self._moe_sharding_group}."
+        logger.info(debug_str)
+
+        global _HYBRID_PARALLEL_GROUP
+        _HYBRID_PARALLEL_GROUP = self
+
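+    # Full sharding groups are obtained by merging the dense_sharding groups
+    # whose ranks fall in the same moe_sharding group (see
+    # merge_inner_comm_list below).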
+    def build_sharding_group(self, topo):
+        parallel_group = []
+        parallel_comm_group = None
+
+        parallel_groups = self.merge_inner_comm_list(
+            topo, "moe_sharding", "dense_sharding"
+        )
+
+        group_nccl_comm_init_option = 0
+
+        for group in parallel_groups:
+            comm_group = paddle.distributed.new_group(
+                ranks=group,
+                nccl_comm_init_option=group_nccl_comm_init_option,
+            )
+            if self.global_rank in group:
+                parallel_group = group
+                parallel_comm_group = comm_group
+
+        assert len(parallel_group) > 0
+        assert parallel_comm_group is not None
+
+        logger.info(
+            f"Total {len(parallel_groups)} sharding comm group(s) created successfully!"
+        )
+        return parallel_group, parallel_comm_group
+
+    def merge_inner_comm_list(self, topo, outer_name, inner_name):
+        """
+        Merge all inner communication lists whose rank ids are in
+        the same outer communication list. E.g.:
+        outer_comm_list: [[0, 4], [1, 5]]
+        inner_comm_list: [[0, 2], [1, 3], [4, 6], [5, 7]]
+        => merged_inner_comm_list: [[0, 2, 4, 6], [1, 3, 5, 7]]
+        """
+        inner_axis = topo._parallel_names.index(inner_name)
+        outer_axis = topo._parallel_names.index(outer_name)
+        inner_comm_list = topo.get_comm_list(inner_name)
+
+        num_merged_groups = len(inner_comm_list) // topo._dims[outer_axis]
+        interval = (
+            math.prod(topo._dims[(outer_axis + 1) :]) // topo._dims[inner_axis]
+        )
+        assert num_merged_groups > 0 and interval > 0
+
+        merged_comm_list = []
+        for i in range(num_merged_groups):
+            comm = []
+            for j in range(topo._dims[outer_axis]):
+                assert i + j * interval < len(
+                    inner_comm_list
+                ), f"Unexpected error in merge_inner_comm_list, {i}, {j}, {interval}, {len(inner_comm_list)}"
+                comm += inner_comm_list[i + j * interval]
+            merged_comm_list.append(comm)
+
+        return merged_comm_list
+
+    def find_col_idx(self, comm_list, global_rank):
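+        # Staircase search from the bottom-left corner of comm_list; assumes
+        # ranks increase left-to-right within a row and top-to-bottom within a
+        # column. Returns the column index of global_rank, or None if absent.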
+        rows = len(comm_list)
+        cols = len(comm_list[0])
+        r = rows - 1
+        c = 0
+
+        while r >= 0 and c < cols:
+            current = comm_list[r][c]
+            if current == global_rank:
+                return c
+            elif current < global_rank:
+                c += 1
+            else:
+                r -= 1
+
+        return None
+
+    def _get_parallel_id(self, topo, parallel_type):
+        comm_list = topo.get_comm_list(parallel_type)
+        parallel_id = self.find_col_idx(comm_list, self.global_rank)
+        assert parallel_id is not None
+        return parallel_id
+
+    def _get_sharding_parallel_id(self):
+        sharding_comm_list = self.merge_inner_comm_list(
+            self._dense_topo, "moe_sharding", "dense_sharding"
+        )
+        parallel_id = self.find_col_idx(sharding_comm_list, self.global_rank)
+        assert parallel_id is not None
+        return parallel_id
+
+    def get_expert_parallel_rank(self) -> int:
+        return self._expert_parallel_id
+
+    def get_expert_parallel_world_size(self) -> int:
+        return self._ep_degree
+
+    def get_expert_parallel_group(self) -> Group:
+        return self._ep_comm_group
+
+    def get_expert_parallel_group_src_rank(self) -> int:
+        return self._ep_comm_group.ranks[0]
+
+    def get_moe_sharding_parallel_rank(self) -> int:
+        return self._moe_sharding_parallel_id
+
+    def get_moe_sharding_parallel_world_size(self) -> int:
+        return self._moe_sharding_degree
+
+    def get_moe_sharding_parallel_group(self) -> Group:
+        return self._moe_sharding_comm_group
+
+    def get_moe_sharding_parallel_group_src_rank(self) -> int:
+        return self._moe_sharding_comm_group.ranks[0]
+
+
class _CommunicateGroup:
    """tmp for static"""
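As a sanity check on the index arithmetic in merge_inner_comm_list, here is a minimal standalone sketch. It is not part of the diff: it assumes the dense topology axes ["moe_sharding", "pipe", "dense_sharding", "data", "sep", "model"] with dims [2, 1, 2, 1, 1, 2] over 8 ranks, and reuses the inner groups from the method's docstring example; merge_inner and its parameters are illustrative names only.

import math

def merge_inner(dims, inner_comm_list, outer_axis, inner_axis):
    # Same index arithmetic as merge_inner_comm_list: inner groups that sit
    # "interval" apart in the list belong to consecutive outer coordinates.
    num_merged_groups = len(inner_comm_list) // dims[outer_axis]
    interval = math.prod(dims[outer_axis + 1 :]) // dims[inner_axis]
    merged = []
    for i in range(num_merged_groups):
        group = []
        for j in range(dims[outer_axis]):
            group += inner_comm_list[i + j * interval]
        merged.append(group)
    return merged

# dense_sharding groups from the docstring example
inner = [[0, 2], [1, 3], [4, 6], [5, 7]]
print(merge_inner([2, 1, 2, 1, 1, 2], inner, outer_axis=0, inner_axis=2))
# -> [[0, 2, 4, 6], [1, 3, 5, 7]], the merged sharding groups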