Skip to content

Commit bea4f53

Browse files
authored
dynamic_shape support balanced vpp (#72386)
1 parent 3800607 commit bea4f53

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

+10
Original file line numberDiff line numberDiff line change
@@ -2760,6 +2760,9 @@ def backward_async_comm(
27602760

27612761
# reset dynamic meta counter
27622762
if self._dynamic_shape:
2763+
assert self._p2p_helper._dynamic_cnt == len(
2764+
self._p2p_helper._send_recv_meta_list
2765+
), "p2p dynamic_cnt should equal to send_recv_meta_list"
27632766
self._p2p_helper._dynamic_cnt = 0
27642767

27652768
return train_loss
@@ -3359,6 +3362,13 @@ def forward_backward_pipeline(
33593362
backward_send_recv_buffer_queue.empty()
33603363
), "send_recv buffer should be empty"
33613364

3365+
# reset dynamic meta counter
3366+
if self._dynamic_shape:
3367+
assert self._p2p_helper._dynamic_cnt == len(
3368+
self._p2p_helper._send_recv_meta_list
3369+
), "p2p dynamic_cnt should equal to send_recv_meta_list"
3370+
self._p2p_helper._dynamic_cnt = 0
3371+
33623372
self._flush_records()
33633373
self._sync_overlap_grads()
33643374

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py

+10 -11
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,7 @@ def recv_forward(self, pp_first_stage, sync_recv=True, batch_p2p_comm=True):
763763
sync_recv=sync_recv,
764764
send_recv_meta=self._send_recv_meta,
765765
batch_p2p_comm=batch_p2p_comm,
766+
dynamic_shape=self._dynamic_shape,
766767
)
767768
if self._dynamic_shape:
768769
self._dynamic_cnt += 1
@@ -807,9 +808,6 @@ def recv_backward(
807808
if _timers is not None:
808809
_timers("recv_backward").stop()
809810

810-
if self._dynamic_shape and need_increase_cnt:
811-
self._dynamic_cnt += 1
812-
813811
return output_tensor_grad
814812

815813
def send_forward(
@@ -823,10 +821,6 @@ def send_forward(
823821
if _timers is not None:
824822
_timers("send_forward").start()
825823

826-
assert (
827-
not self._dynamic_shape
828-
), "p2p_helper.send_forward function doesn't support dynamic_shape now"
829-
830824
if not pp_last_stage:
831825
self._send_meta(output_tensor, skip_check_meta=skip_check_meta)
832826
_p2p_helper(
@@ -836,7 +830,11 @@ def send_forward(
836830
recv_next=False,
837831
send_recv_meta=self._send_recv_meta,
838832
batch_p2p_comm=batch_p2p_comm,
833+
dynamic_shape=self._dynamic_shape,
839834
)
835+
if self._dynamic_shape:
836+
self._dynamic_cnt += 1
837+
840838
if _timers is not None:
841839
_timers("send_forward").stop()
842840

@@ -847,19 +845,20 @@ def send_backward(
847845
if _timers is not None:
848846
_timers("send_backward").start()
849847

850-
assert (
851-
not self._dynamic_shape
852-
), "p2p_helper.send_backward function doesn't support dynamic_shape now"
853-
854848
if not pp_first_stage:
849+
if self._dynamic_shape:
850+
self._send_meta(input_tensor_grad, reverse=True)
855851
_p2p_helper(
856852
tensor_send_next=None,
857853
tensor_send_prev=input_tensor_grad,
858854
recv_prev=False,
859855
recv_next=False,
860856
send_recv_meta=self._send_recv_meta,
861857
batch_p2p_comm=batch_p2p_comm,
858+
dynamic_shape=self._dynamic_shape,
862859
)
860+
if self._dynamic_shape:
861+
self._dynamic_cnt += 1
863862
if _timers is not None:
864863
_timers("send_backward").stop()
865864

0 commit comments

Comments
 (0)