@@ -74,7 +74,6 @@ def recv_meta(self, group, reverse=False, broadcast=False):
 
         data_numel = paddle.empty([1], dtype="int64")
         if not broadcast:
-            src_rank = _hcg._get_p2p_prev_rank()
             paddle.distributed.recv(data_numel, src=src_rank, group=group)
         else:
             paddle.distributed.broadcast(
@@ -782,13 +781,15 @@ def recv_backward(
         if _timers is not None:
             _timers("recv_backward").start()
 
-        assert (
-            not self._dynamic_shape
-        ), "p2p_helper.recv_backward function doesn't support dynamic_shape now"
+        need_increase_cnt = False
 
         if pp_last_stage:
             output_tensor_grad = None
         else:
+            if self._dynamic_shape:
+                self._recv_meta(reverse=True)
+                need_increase_cnt = True
+
             _, output_tensor_grad, _ = _p2p_helper(
                 tensor_send_next=None,
                 tensor_send_prev=None,
@@ -797,8 +798,12 @@ def recv_backward(
                 sync_recv=sync_recv,
                 send_recv_meta=self._send_recv_meta,
                 batch_p2p_comm=batch_p2p_comm,
+                dynamic_shape=self._dynamic_shape,
             )
 
+        if self._dynamic_shape and need_increase_cnt:
+            self._dynamic_cnt += 1
+
         if _timers is not None:
             _timers("recv_backward").stop()
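For context, the change replaces the hard assert with a dynamic-shape path: before receiving the gradient tensor, the receiver first pulls the tensor's shape metadata (`self._recv_meta(reverse=True)`), so the buffer can be allocated to match whatever the sender produced. Below is a minimal standalone sketch of this metadata-first pattern; the helpers `send_dynamic`/`recv_dynamic` and the `ndim` argument are illustrative assumptions, not Paddle's actual `p2p_helper`/`SendRecvMeta` API.

```python
import math

import paddle
import paddle.distributed as dist


def send_dynamic(tensor, dst, group=None):
    # Metadata first: number of elements, then the shape, so the peer can
    # allocate a matching buffer before the payload arrives. The numel
    # exchange mirrors the data_numel handshake in the commit above.
    numel = paddle.to_tensor([math.prod(tensor.shape)], dtype="int64")
    dist.send(numel, dst=dst, group=group)
    shape = paddle.to_tensor(tensor.shape, dtype="int64")
    dist.send(shape, dst=dst, group=group)
    dist.send(tensor, dst=dst, group=group)


def recv_dynamic(src, ndim, dtype="float32", group=None):
    # Mirror of send_dynamic. `ndim` is assumed to be agreed out of band;
    # Paddle's real SendRecvMeta exchanges richer metadata instead.
    numel = paddle.empty([1], dtype="int64")
    dist.recv(numel, src=src, group=group)
    shape = paddle.empty([ndim], dtype="int64")
    dist.recv(shape, src=src, group=group)
    # Allocate the receive buffer from the just-received shape, then
    # receive the payload into it.
    buf = paddle.empty(shape.numpy().tolist(), dtype=dtype)
    dist.recv(buf, src=src, group=group)
    return buf
```

Run under `paddle.distributed.launch` with two ranks, with rank 0 calling `send_dynamic` and rank 1 calling `recv_dynamic`. The sketch omits the bookkeeping the real helper keeps (the `_dynamic_cnt` counter incremented above), which presumably keeps the send and recv sides' meta exchanges in step across microbatches.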