@@ -763,6 +763,7 @@ def recv_forward(self, pp_first_stage, sync_recv=True, batch_p2p_comm=True):
763
763
sync_recv = sync_recv ,
764
764
send_recv_meta = self ._send_recv_meta ,
765
765
batch_p2p_comm = batch_p2p_comm ,
766
+ dynamic_shape = self ._dynamic_shape ,
766
767
)
767
768
if self ._dynamic_shape :
768
769
self ._dynamic_cnt += 1
@@ -807,9 +808,6 @@ def recv_backward(
807
808
if _timers is not None :
808
809
_timers ("recv_backward" ).stop ()
809
810
810
- if self ._dynamic_shape and need_increase_cnt :
811
- self ._dynamic_cnt += 1
812
-
813
811
return output_tensor_grad
814
812
815
813
def send_forward (
@@ -823,10 +821,6 @@ def send_forward(
823
821
if _timers is not None :
824
822
_timers ("send_forward" ).start ()
825
823
826
- assert (
827
- not self ._dynamic_shape
828
- ), "p2p_helper.send_forward function doesn't support dynamic_shape now"
829
-
830
824
if not pp_last_stage :
831
825
self ._send_meta (output_tensor , skip_check_meta = skip_check_meta )
832
826
_p2p_helper (
@@ -836,7 +830,11 @@ def send_forward(
836
830
recv_next = False ,
837
831
send_recv_meta = self ._send_recv_meta ,
838
832
batch_p2p_comm = batch_p2p_comm ,
833
+ dynamic_shape = self ._dynamic_shape ,
839
834
)
835
+ if self ._dynamic_shape :
836
+ self ._dynamic_cnt += 1
837
+
840
838
if _timers is not None :
841
839
_timers ("send_forward" ).stop ()
842
840
@@ -847,19 +845,20 @@ def send_backward(
847
845
if _timers is not None :
848
846
_timers ("send_backward" ).start ()
849
847
850
- assert (
851
- not self ._dynamic_shape
852
- ), "p2p_helper.send_backward function doesn't support dynamic_shape now"
853
-
854
848
if not pp_first_stage :
849
+ if self ._dynamic_shape :
850
+ self ._send_meta (input_tensor_grad , reverse = True )
855
851
_p2p_helper (
856
852
tensor_send_next = None ,
857
853
tensor_send_prev = input_tensor_grad ,
858
854
recv_prev = False ,
859
855
recv_next = False ,
860
856
send_recv_meta = self ._send_recv_meta ,
861
857
batch_p2p_comm = batch_p2p_comm ,
858
+ dynamic_shape = self ._dynamic_shape ,
862
859
)
860
+ if self ._dynamic_shape :
861
+ self ._dynamic_cnt += 1
863
862
if _timers is not None :
864
863
_timers ("send_backward" ).stop ()
865
864
0 commit comments