@@ -804,10 +804,13 @@ def _insert_send_recv(cur_id, prev_id):
804
804
inputs = {'x' : var },
805
805
attrs = {
806
806
'peer' : 1 ,
807
+ self ._op_role_key : op_role ,
807
808
'ring_id' : ring_id ,
808
809
},
809
810
)
810
- send_op .dist_attr .execution_stream = "default"
811
+ print ("---------pipeline----------" )
812
+ print (send_op )
813
+ # send_op.dist_attr.execution_stream = "default"
811
814
extra_index_info ['index' ] += 1
812
815
var_shape = list (var .shape )
813
816
var_shape [0 ] = (
@@ -823,9 +826,12 @@ def _insert_send_recv(cur_id, prev_id):
823
826
'dtype' : var .dtype ,
824
827
'peer' : 0 ,
825
828
'ring_id' : ring_id ,
829
+ self ._op_role_key : op_role ,
826
830
},
827
831
)
828
- recv_op .dist_attr .execution_stream = "default"
832
+ print ("---------pipeline----------" )
833
+ print (recv_op )
834
+ # recv_op.dist_attr.execution_stream = "default"
829
835
extra_index_info ['index' ] += 1
830
836
elif self .schedule_mode == '1F1B' : # 1F1B
831
837
var_shape = list (var .shape )
@@ -892,9 +898,12 @@ def _insert_send_recv(cur_id, prev_id):
892
898
attrs = {
893
899
'ring_id' : ring_id ,
894
900
'peer' : 1 ,
901
+ self ._op_role_key : op_role ,
895
902
},
896
903
)
897
- send_op .dist_attr .execution_stream = "default"
904
+ print ("---------pipeline----------" )
905
+ print (send_op )
906
+ # send_op.dist_attr.execution_stream = "default"
898
907
else :
899
908
block ._insert_op_without_sync (
900
909
index = index + extra_index_info ['index' ],
@@ -906,6 +915,7 @@ def _insert_send_recv(cur_id, prev_id):
906
915
'use_calc_stream' : False ,
907
916
'num' : self .mp_degree ,
908
917
'id' : self .mp_rank ,
918
+ self ._op_role_key : op_role ,
909
919
},
910
920
)
911
921
extra_index_info ['index' ] += 1
@@ -941,9 +951,12 @@ def _insert_send_recv(cur_id, prev_id):
941
951
'dtype' : var .dtype ,
942
952
'peer' : 0 ,
943
953
'ring_id' : ring_id ,
954
+ self ._op_role_key : op_role ,
944
955
},
945
956
)
946
- recv_op .dist_attr .execution_stream = "default"
957
+ print ("---------pipeline----------" )
958
+ print (recv_op )
959
+ # recv_op.dist_attr.execution_stream = "default"
947
960
else :
948
961
block ._insert_op_without_sync (
949
962
index = index + extra_index_info ['index' ],
@@ -956,6 +969,7 @@ def _insert_send_recv(cur_id, prev_id):
956
969
'out_shape' : var_shape ,
957
970
'num' : self .mp_degree ,
958
971
'id' : self .mp_rank ,
972
+ self ._op_role_key : op_role ,
959
973
},
960
974
)
961
975
extra_index_info ['index' ] += 1
@@ -1622,8 +1636,11 @@ def _process_persistable_vars_in_multi_sections(
1622
1636
# microbatch
1623
1637
'peer' : read_dev_index ,
1624
1638
'ring_id' : ring_id ,
1639
+ self ._op_role_key : self ._op_role .LRSched ,
1625
1640
},
1626
1641
)
1642
+ print ("---------pipeline----------" )
1643
+ # print(recv_op)
1627
1644
read_block ._insert_op (
1628
1645
index = 0 ,
1629
1646
type = 'p_recv' ,
@@ -1634,6 +1651,7 @@ def _process_persistable_vars_in_multi_sections(
1634
1651
# microbatch
1635
1652
'peer' : write_dev_index ,
1636
1653
'ring_id' : ring_id ,
1654
+ self ._op_role_key : self ._op_role .LRSched ,
1637
1655
},
1638
1656
)
1639
1657
read_block ._insert_op (
0 commit comments