
Commit b724233

fix pipeline in dynamic_shape (#72243)
1 parent 622df0a commit b724233

File tree

1 file changed: +9 −4 lines

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py

@@ -74,7 +74,6 @@ def recv_meta(self, group, reverse=False, broadcast=False):
 
         data_numel = paddle.empty([1], dtype="int64")
         if not broadcast:
-            src_rank = _hcg._get_p2p_prev_rank()
             paddle.distributed.recv(data_numel, src=src_rank, group=group)
         else:
             paddle.distributed.broadcast(
@@ -782,13 +781,15 @@ def recv_backward(
         if _timers is not None:
             _timers("recv_backward").start()
 
-        assert (
-            not self._dynamic_shape
-        ), "p2p_helper.recv_backward function doesn't support dynamic_shape now"
+        need_increase_cnt = False
 
         if pp_last_stage:
             output_tensor_grad = None
         else:
+            if self._dynamic_shape:
+                self._recv_meta(reverse=True)
+                need_increase_cnt = True
+
             _, output_tensor_grad, _ = _p2p_helper(
                 tensor_send_next=None,
                 tensor_send_prev=None,
@@ -797,8 +798,12 @@ def recv_backward(
                 sync_recv=sync_recv,
                 send_recv_meta=self._send_recv_meta,
                 batch_p2p_comm=batch_p2p_comm,
+                dynamic_shape=self._dynamic_shape,
             )
 
+        if self._dynamic_shape and need_increase_cnt:
+            self._dynamic_cnt += 1
+
         if _timers is not None:
             _timers("recv_backward").stop()
 
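For context on why recv_backward now calls self._recv_meta(reverse=True) when dynamic_shape is enabled: with dynamic shapes the receiving stage does not know the size of the incoming gradient up front, so shape/numel metadata has to be exchanged first (as recv_meta does above with data_numel) before the payload can be received into a correctly sized buffer. Below is a minimal sketch of that general pattern, not Paddle's actual pipeline code; send_with_meta and recv_with_meta are hypothetical helper names, and the dtype is assumed to be agreed on out of band.

import math

import paddle
import paddle.distributed as dist

# Assumes paddle.distributed.init_parallel_env() has already been called
# and that exactly one peer sends while the other receives.

def send_with_meta(tensor, dst, group=None):
    # 1) send the element count first so the peer can size its buffer
    numel = paddle.to_tensor([math.prod(tensor.shape)], dtype="int64")
    dist.send(numel, dst=dst, group=group)
    # 2) send the flattened payload itself
    dist.send(tensor.flatten(), dst=dst, group=group)

def recv_with_meta(src, dtype="float32", group=None):
    # 1) receive the element count (the shape is not known in advance)
    numel = paddle.empty([1], dtype="int64")
    dist.recv(numel, src=src, group=group)
    # 2) allocate a buffer of exactly that size, then receive the payload
    buf = paddle.empty([int(numel.item())], dtype=dtype)
    dist.recv(buf, src=src, group=group)
    return buf

The commit applies the same idea to the backward pass: receive the metadata, then let _p2p_helper run the actual dynamic-shape receive, and track it via _dynamic_cnt.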
