Skip to content

Commit 46c897e

Browse files
[Distributed] Support forward backward overlap schedule for VPP (#71995)
* [Distributed] Support forward_backward_overlap mode for VPP * add * fix name
1 parent ed0209b commit 46c897e

File tree

5 files changed

+659
-228
lines changed

5 files changed

+659
-228
lines changed

paddle/fluid/framework/distributed_strategy.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ message PpConfig {
9292
optional bool enable_offload_queue = 11 [ default = false ];
9393
optional bool enable_dynamic_shape = 12 [ default = false ];
9494
optional bool use_dualpipev = 13 [ default = false ];
95+
optional bool forward_backward_overlap_scheduler = 14 [ default = false ];
9596
}
9697

9798
message DygraphShardingConfig {

python/paddle/distributed/fleet/meta_parallel/dualpipev.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,7 @@ def _forward_compute(self, phase: int, micro_datasets=None) -> None:
150150
inputs = self._get_forward_inputs(micro_datasets, phase, acc_id)
151151

152152
if self.overlapped_forward_backward:
153-
schedule_chunk = self._layers.forward(
154-
inputs, chunk_id=phase, overlap_schedule_mode=True
155-
)
153+
schedule_chunk = self._layers.get_schedule_chunk(chunk_id=phase)
156154
outputs = schedule_chunk.forward(inputs)
157155
else:
158156
schedule_chunk = None
@@ -300,9 +298,7 @@ def _forward_backward_compute(
300298
)
301299

302300
# forward & backward
303-
forward_chunk = self._layers.forward(
304-
None, chunk_id=forward_phase, overlap_schedule_mode=True
305-
)
301+
forward_chunk = self._layers.get_schedule_chunk(chunk_id=forward_phase)
306302
backward_chunk = self.schedule_chunks[backward_phase][backward_acc_id]
307303
forward_outputs, forward_loss, backward_input_grads = (
308304
self._layers.overlapped_forward_backward(

python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ def execute_func(*x):
10171017

10181018
return execute_func
10191019

1020-
def forward(self, input, chunk_id=None, overlap_schedule_mode=False):
1020+
def update_run_function(self, chunk_id):
10211021
if chunk_id is not None:
10221022
assert isinstance(chunk_id, int), "chunk_id should be an int"
10231023
assert (
@@ -1035,9 +1035,15 @@ def forward(self, input, chunk_id=None, overlap_schedule_mode=False):
10351035
# But for interleave, self.run_function will keep updating to the target functions at every run.
10361036
self.run_function = model_chunk.get_run_function()
10371037

1038-
if overlap_schedule_mode:
1039-
assert self._recompute_interval == 0
1040-
return self.build_schedule_nodes(0, len(self.run_function))
1038+
def get_schedule_chunk(self, chunk_id):
1039+
self.update_run_function(chunk_id)
1040+
1041+
assert self._recompute_interval == 0
1042+
return self.build_schedule_nodes(0, len(self.run_function))
1043+
1044+
def forward(self, input, chunk_id=None):
1045+
self.update_run_function(chunk_id)
1046+
10411047
if self._recompute_interval == 0:
10421048
input = self.forward_function(0, len(self.run_function))(input)
10431049
else:

0 commit comments

Comments
 (0)