
Commit 4ab582e

fix
1 parent 81deeab · commit 4ab582e

File tree: 2 files changed, +200 -39 lines changed


python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 198 additions & 39 deletions
@@ -1403,6 +1403,13 @@ def __init__(self, layers, hcg, strategy):
         self.overlap_schedule_mode = hasattr(
             type(self._layers), "overlapped_forward_backward"
         )
+        if self.overlap_schedule_mode:
+            assert (
+                not self._overlap_p2p_comm
+            ), "Overlap p2p comm is not compatible with overlap_schedule_mode."
+            assert (
+                not self._profiling
+            ), "Profiling is not compatible with overlap_schedule_mode."
         logger.info(
             f"Using PipelineParallelWithInterleave with overlapping forward backward={self.overlap_schedule_mode}"
         )
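For orientation, the mode switch added above keys off the class of the layers object: overlap scheduling is enabled only when that class defines `overlapped_forward_backward`. A minimal, self-contained sketch of the detection (class names here are illustrative, not Paddle code):

```python
# Illustrative only: how the hasattr check on the class drives the mode flag.
class PlainLayers:
    pass


class OverlappedLayers:
    def overlapped_forward_backward(self, *args, **kwargs):
        # Hypothetical fused forward/backward entry point.
        pass


for layers in (PlainLayers(), OverlappedLayers()):
    overlap_schedule_mode = hasattr(type(layers), "overlapped_forward_backward")
    print(type(layers).__name__, overlap_schedule_mode)
# PlainLayers False
# OverlappedLayers True
```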
@@ -1591,12 +1598,7 @@ def _get_virtual_pp_rank(self, micro_step, forward):
 
         return virtual_pp_stage
 
-    def _forward_step_helper(
-        self, micro_dataset, micro_step, overlap_schedule_mode=False
-    ):
-        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
-        self.set_virtual_pipeline_rank(virtual_pp_rank)
-
+    def _get_forward_input(self, virtual_pp_rank):
         # some checkers
         assert hasattr(self, 'input_tensors')
         assert hasattr(self, 'output_tensors')
@@ -1606,14 +1608,15 @@ def _forward_step_helper(
             len(self.output_tensors[virtual_pp_rank]) + 1
         )
         input_tensor = self.input_tensors[virtual_pp_rank][-1]
-        output_tensor, schedule_chunk, loss_fn_node = self._forward_step(
-            input_tensor,
-            micro_dataset,
-            virtual_pp_rank,
-            step_id=micro_step,
-            overlap_schedule_mode=overlap_schedule_mode,
-        )
+        return input_tensor
 
+    def _store_forward_outputs(
+        self,
+        virtual_pp_rank,
+        output_tensor,
+        schedule_chunk=None,
+        loss_fn_node=None,
+    ):
         self.output_tensors[virtual_pp_rank].append(output_tensor)
         # If overlap_schedule_mode eq False, the schedule chunk is a None
         self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
@@ -1624,6 +1627,26 @@ def _forward_step_helper(
             # no need to store tensor for backward
             self.input_tensors[virtual_pp_rank].pop()
             self.output_tensors[virtual_pp_rank].pop()
+
+    def _forward_step_helper(
+        self, micro_dataset, micro_step, overlap_schedule_mode=False
+    ):
+        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
+        self.set_virtual_pipeline_rank(virtual_pp_rank)
+
+        input_tensor = self._get_forward_input(virtual_pp_rank)
+
+        output_tensor, schedule_chunk, loss_fn_node = self._forward_step(
+            input_tensor,
+            micro_dataset,
+            virtual_pp_rank,
+            step_id=micro_step,
+            overlap_schedule_mode=overlap_schedule_mode,
+        )
+
+        self._store_forward_outputs(
+            virtual_pp_rank, output_tensor, schedule_chunk, loss_fn_node
+        )
         return output_tensor
 
     def _overlap_comm_grads(self):
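The refactor above follows a get-input / compute / store-output split so that the overlapped schedule introduced later can reuse the first and last pieces around a fused forward/backward call. A generic sketch of the pattern, with illustrative names rather than Paddle's API:

```python
# Toy illustration of the decomposition; plain numbers stand in for tensors.
class StepRunner:
    def __init__(self):
        self.inputs = {0: [1.0]}   # per-chunk input stacks
        self.outputs = {0: []}     # per-chunk output stacks

    def _get_forward_input(self, chunk_id):
        return self.inputs[chunk_id][-1]

    def _store_forward_outputs(self, chunk_id, output):
        self.outputs[chunk_id].append(output)

    def _forward_step_helper(self, chunk_id):
        x = self._get_forward_input(chunk_id)
        y = x * 2.0                # stands in for the real forward computation
        self._store_forward_outputs(chunk_id, y)
        return y


runner = StepRunner()
print(runner._forward_step_helper(0))  # 2.0
```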
@@ -1659,10 +1682,7 @@ def _sync_overlap_grads(self):
             for buffer in buffers:
                 buffer.scale_grads()
 
-    def _backward_step_helper(self, micro_step, overlap_schedule_mode=False):
-        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
-        self.set_virtual_pipeline_rank(virtual_pp_rank)
-
+    def _get_backward_input(self, virtual_pp_rank):
         # some checkers
         assert hasattr(self, 'input_tensors')
         assert hasattr(self, 'output_tensors')
@@ -1684,6 +1704,26 @@ def _backward_step_helper(self, micro_step, overlap_schedule_mode=False):
         else:
             loss_fn_node = None
 
+        return (
+            input_tensor,
+            output_tensor,
+            output_tensor_grad,
+            schedule_chunk,
+            loss_fn_node,
+        )
+
+    def _backward_step_helper(self, micro_step, overlap_schedule_mode=False):
+        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
+        self.set_virtual_pipeline_rank(virtual_pp_rank)
+
+        (
+            input_tensor,
+            output_tensor,
+            output_tensor_grad,
+            schedule_chunk,
+            loss_fn_node,
+        ) = self._get_backward_input(virtual_pp_rank)
+
         input_tensor_grad = self._backward_step(
             input_tensor,
             output_tensor,
@@ -1698,6 +1738,133 @@ def _backward_step_helper(self, micro_step, overlap_schedule_mode=False):
 
         return input_tensor_grad
 
+    def _forward_backward_helper(
+        self, micro_dataset, forward_micro_step_id, backward_micro_step_id
+    ):
+        if not self.overlap_schedule_mode:
+            self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
+            output_tensor = self._forward_step_helper(
+                micro_dataset,
+                forward_micro_step_id,
+            )
+            self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)
+
+            # backward
+            self._record_stamp(
+                "B", backward_micro_step_id, '"B"', forward=False
+            )
+            input_tensor_grad = self._backward_step_helper(
+                backward_micro_step_id,
+            )
+            self._record_stamp(
+                "B", backward_micro_step_id, '"E"', forward=False
+            )
+            return output_tensor, input_tensor_grad
+        else:
+            # 1. prepare forward inputs
+            forward_virtual_pp_rank = self._get_virtual_pp_rank(
+                forward_micro_step_id, forward=True
+            )
+            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
+
+            if self.user_hooks_enabled:
+                self.forward_hooks.run_hook()
+
+            forward_inputs = self._get_forward_input(forward_virtual_pp_rank)
+            if self.is_pipeline_first_stage():
+                forward_inputs = next(micro_dataset)[0]
+                self._check_micro_batch_data_valid(forward_inputs)
+            if self.is_pipeline_last_stage():
+                labels = next(micro_dataset)[1]
+
+            # 2. get forward chunks
+            forward_chunk = self._layers.get_schedule_chunk(
+                chunk_id=forward_virtual_pp_rank
+            )
+
+            if self.is_pipeline_last_stage():
+                assert len(self._layers._loss_fn) == 1
+                forward_loss_fn_node = self._layers._loss_fn[
+                    0
+                ].build_schedule_node()
+                forward_loss_fn_node.labels = labels
+                if self.accumulate_steps > 1 and not self._delay_scale_loss:
+                    forward_loss_fn_node.scale_loss_factor = (
+                        self.accumulate_steps
+                    )
+            else:
+                forward_loss_fn_node = None
+
+            # 3. prepare backward inputs & get backward chunks
+            backward_virtual_pp_rank = self._get_virtual_pp_rank(
+                backward_micro_step_id, forward=False
+            )
+            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
+
+            if self.user_hooks_enabled:
+                self.backward_hooks.run_hook()
+
+            (
+                _,
+                _,
+                backward_grads,
+                backward_chunk,
+                backward_loss_fn_node,
+            ) = self._get_backward_input(backward_virtual_pp_rank)
+
+            # 4. forward & backward
+            if self.processed_steps < g_profile_pipeline_details_steps:
+                get_sync_logger().info("Before forward_backward_step")
+            if self._enable_timer:
+                self.timers("forward_backward_step").start()
+            output_tensor, forward_loss, input_tensor_grad = (
+                self._layers.overlapped_forward_backward(
+                    forward_chunk,
+                    forward_inputs,
+                    forward_loss_fn_node,
+                    backward_chunk,
+                    backward_loss_fn_node,
+                    backward_grads,
+                    self.scaler,
+                )
+            )
+            if self.processed_steps < g_profile_pipeline_details_steps:
+                get_sync_logger().info("After forward_backward_step")
+            if self._enable_timer:
+                self.timers("forward_backward_step").stop()
+
+            # 5. process forward outputs
+            forward_virtual_pp_rank = self._get_virtual_pp_rank(
+                forward_micro_step_id, forward=True
+            )
+            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
+            self._store_forward_outputs(
+                forward_virtual_pp_rank,
+                output_tensor,
+                forward_chunk,
+                forward_loss_fn_node,
+            )
+
+            if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
+                # Only increase micro batch id at virtual first/last pp stage.
+                # The micro batch id is used to load data, therefore, only increase it when loading data.
+                self.micro_batch_id += 1
+
+            if self.is_pipeline_last_stage():
+                # In overlap mode, only one loss_fn is supported.
+                if self.total_loss is None:
+                    self.total_loss = [[]]
+                self.total_loss[0].append(forward_loss.detach())
+
+            # 6. process backward outputs
+            backward_virtual_pp_rank = self._get_virtual_pp_rank(
+                backward_micro_step_id, forward=False
+            )
+            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
+            self._overlap_comm_grads()
+
+            return output_tensor, input_tensor_grad
+
     def bw_hook_func(self, buffer, param):
         # For pipeline with interleave, we need to add grad to buffer without communication.
         # Use communication where appropriate to avoid dp communication and pp scheduling conflicts.
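The overlapped branch of `_forward_backward_helper` delegates the fused computation to `self._layers.overlapped_forward_backward(...)`, whose implementation lives in the layers class and is not part of this diff. The sketch below only illustrates the call contract as it appears from this call site (a hypothetical stand-in, not Paddle's implementation): it consumes a forward chunk with its inputs and optional loss node, a backward chunk with its grads and loss node, and returns `(output_tensor, loss, input_tensor_grad)`.

```python
# Hedged sketch of the assumed contract; plain callables stand in for chunks.
def overlapped_forward_backward(
    forward_chunk,          # runs the forward model chunk
    forward_inputs,
    forward_loss_fn_node,   # None except on the last pipeline stage
    backward_chunk,         # runs the backward pass of another chunk
    backward_loss_fn_node,
    backward_grads,
    scaler=None,
):
    output_tensor = forward_chunk(forward_inputs)
    loss = (
        forward_loss_fn_node(output_tensor)
        if forward_loss_fn_node is not None
        else None
    )
    input_tensor_grad = backward_chunk(backward_grads)
    return output_tensor, loss, input_tensor_grad


out, loss, grad = overlapped_forward_backward(
    forward_chunk=lambda x: x * 2,
    forward_inputs=3,
    forward_loss_fn_node=None,
    backward_chunk=lambda g: g + 1,
    backward_loss_fn_node=None,
    backward_grads=0,
)
print(out, loss, grad)  # 6 None 1
```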
@@ -1753,6 +1920,12 @@ def forward_backward_pipeline(
             self._using_cache
         ), "cache should be enabled for pipeline with interleave"
 
+        self.overlap_schedule_mode = hasattr(
+            type(self._layers), "overlapped_forward_backward"
+        )
+        if forward_only:
+            self.overlap_schedule_mode = False
+
         # init some attributes for this batch run
         self.scaler = scaler
         self.total_loss = None
@@ -1859,6 +2032,7 @@ def _process_bwd_buffer(step_id, tensor):
         startup_steps += (self.num_model_chunks - 1) * first_chunk_acc
         startup_steps = min(startup_steps, num_steps)
 
+        # An additional micro step is needed for the overlapping schedule
         if self.overlap_schedule_mode:
             startup_steps += 1
         steady_steps = num_steps - startup_steps
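Numerically, the extra startup micro step simply shifts one step out of the steady phase. A quick illustration with made-up values (the real numbers depend on the schedule):

```python
# Illustrative arithmetic only; values are not taken from the code.
num_steps = 32
startup_steps = 6            # whatever the scheduler computed above
overlap_schedule_mode = True

if overlap_schedule_mode:
    startup_steps += 1
steady_steps = num_steps - startup_steps
print(startup_steps, steady_steps)  # 7 25
```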
@@ -2169,30 +2343,15 @@ def _process_bwd_buffer(step_id, tensor):
                             overlap_p2p_comm=True,
                         )
                 else:
-                    self._record_stamp(
-                        "F", forward_micro_step_id, '"B"', forward=True
-                    )
-                    output_tensor = self._forward_step_helper(
-                        micro_dataset,
-                        forward_micro_step_id,
-                        overlap_schedule_mode=self.overlap_schedule_mode,
-                    )
-                    self._record_stamp(
-                        "F", forward_micro_step_id, '"E"', forward=True
-                    )
-
-                    # backward
                     backward_micro_step_id = micro_step
-                    self._record_stamp(
-                        "B", backward_micro_step_id, '"B"', forward=False
-                    )
-                    input_tensor_grad = self._backward_step_helper(
-                        backward_micro_step_id,
-                        overlap_schedule_mode=self.overlap_schedule_mode,
-                    )
-                    self._record_stamp(
-                        "B", backward_micro_step_id, '"E"', forward=False
+                    output_tensor, input_tensor_grad = (
+                        self._forward_backward_helper(
+                            micro_dataset,
+                            forward_micro_step_id,
+                            backward_micro_step_id,
+                        )
                     )
+
                 if (
                     self._best_unbalanced_scheduler
                     and self.is_pipeline_last_stage(ignore_virtual=True)

python/paddle/distributed/fleet/meta_parallel/pp_utils/forward_backward_overlap_utils.py

Lines changed: 2 additions & 0 deletions
@@ -143,6 +143,8 @@ def backward(self, output_grad=None, scaler=None):
 
         if not isinstance(self.inputs, (tuple, list)):
            inputs = (self.inputs,)
+        else:
+            inputs = self.inputs
         grad = tuple([e.grad if e is not None else None for e in inputs])
         self._reset_states()
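The two added lines appear to close a gap in the input normalization: when `self.inputs` was already a tuple or list, `inputs` was never assigned before the `grad = ...` line that uses it. A standalone illustration of the fixed pattern (a hypothetical helper, not the actual class method):

```python
# Why the added `else` branch matters: both input shapes now normalize.
def normalize_inputs(raw):
    if not isinstance(raw, (tuple, list)):
        inputs = (raw,)
    else:
        inputs = raw  # the newly added branch
    return inputs


print(normalize_inputs(1))          # (1,)
print(normalize_inputs((1, 2, 3)))  # (1, 2, 3)
```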
