@@ -1403,6 +1403,13 @@ def __init__(self, layers, hcg, strategy):
1403
1403
self .overlap_schedule_mode = hasattr (
1404
1404
type (self ._layers ), "overlapped_forward_backward"
1405
1405
)
1406
+ if self .overlap_schedule_mode :
1407
+ assert (
1408
+ not self ._overlap_p2p_comm
1409
+ ), "Overlap p2p comm is not incompatible with overlap_schedule_mode."
1410
+ assert (
1411
+ not self ._profiling
1412
+ ), "Profiling is not compatible with overlap_schedule_mode."
1406
1413
logger .info (
1407
1414
f"Using PipelineParallelWithInterleave with overlapping forward backward={ self .overlap_schedule_mode } "
1408
1415
)
@@ -1591,12 +1598,7 @@ def _get_virtual_pp_rank(self, micro_step, forward):
1591
1598
1592
1599
return virtual_pp_stage
1593
1600
1594
- def _forward_step_helper (
1595
- self , micro_dataset , micro_step , overlap_schedule_mode = False
1596
- ):
1597
- virtual_pp_rank = self ._get_virtual_pp_rank (micro_step , forward = True )
1598
- self .set_virtual_pipeline_rank (virtual_pp_rank )
1599
-
1601
+ def _get_forward_input (self , virtual_pp_rank ):
1600
1602
# some checkers
1601
1603
assert hasattr (self , 'input_tensors' )
1602
1604
assert hasattr (self , 'output_tensors' )
@@ -1606,14 +1608,15 @@ def _forward_step_helper(
1606
1608
len (self .output_tensors [virtual_pp_rank ]) + 1
1607
1609
)
1608
1610
input_tensor = self .input_tensors [virtual_pp_rank ][- 1 ]
1609
- output_tensor , schedule_chunk , loss_fn_node = self ._forward_step (
1610
- input_tensor ,
1611
- micro_dataset ,
1612
- virtual_pp_rank ,
1613
- step_id = micro_step ,
1614
- overlap_schedule_mode = overlap_schedule_mode ,
1615
- )
1611
+ return input_tensor
1616
1612
1613
+ def _store_forward_outputs (
1614
+ self ,
1615
+ virtual_pp_rank ,
1616
+ output_tensor ,
1617
+ schedule_chunk = None ,
1618
+ loss_fn_node = None ,
1619
+ ):
1617
1620
self .output_tensors [virtual_pp_rank ].append (output_tensor )
1618
1621
# If overlap_schedule_mode eq False, the schedule chunk is a None
1619
1622
self .schedule_chunks [virtual_pp_rank ].append (schedule_chunk )
@@ -1624,6 +1627,26 @@ def _forward_step_helper(
1624
1627
# no need to store tensor for backward
1625
1628
self .input_tensors [virtual_pp_rank ].pop ()
1626
1629
self .output_tensors [virtual_pp_rank ].pop ()
1630
+
1631
+ def _forward_step_helper (
1632
+ self , micro_dataset , micro_step , overlap_schedule_mode = False
1633
+ ):
1634
+ virtual_pp_rank = self ._get_virtual_pp_rank (micro_step , forward = True )
1635
+ self .set_virtual_pipeline_rank (virtual_pp_rank )
1636
+
1637
+ input_tensor = self ._get_forward_input (virtual_pp_rank )
1638
+
1639
+ output_tensor , schedule_chunk , loss_fn_node = self ._forward_step (
1640
+ input_tensor ,
1641
+ micro_dataset ,
1642
+ virtual_pp_rank ,
1643
+ step_id = micro_step ,
1644
+ overlap_schedule_mode = overlap_schedule_mode ,
1645
+ )
1646
+
1647
+ self ._store_forward_outputs (
1648
+ virtual_pp_rank , output_tensor , schedule_chunk , loss_fn_node
1649
+ )
1627
1650
return output_tensor
1628
1651
1629
1652
def _overlap_comm_grads (self ):
@@ -1659,10 +1682,7 @@ def _sync_overlap_grads(self):
1659
1682
for buffer in buffers :
1660
1683
buffer .scale_grads ()
1661
1684
1662
- def _backward_step_helper (self , micro_step , overlap_schedule_mode = False ):
1663
- virtual_pp_rank = self ._get_virtual_pp_rank (micro_step , forward = False )
1664
- self .set_virtual_pipeline_rank (virtual_pp_rank )
1665
-
1685
+ def _get_backward_input (self , virtual_pp_rank ):
1666
1686
# some checkers
1667
1687
assert hasattr (self , 'input_tensors' )
1668
1688
assert hasattr (self , 'output_tensors' )
@@ -1684,6 +1704,26 @@ def _backward_step_helper(self, micro_step, overlap_schedule_mode=False):
1684
1704
else :
1685
1705
loss_fn_node = None
1686
1706
1707
+ return (
1708
+ input_tensor ,
1709
+ output_tensor ,
1710
+ output_tensor_grad ,
1711
+ schedule_chunk ,
1712
+ loss_fn_node ,
1713
+ )
1714
+
1715
+ def _backward_step_helper (self , micro_step , overlap_schedule_mode = False ):
1716
+ virtual_pp_rank = self ._get_virtual_pp_rank (micro_step , forward = False )
1717
+ self .set_virtual_pipeline_rank (virtual_pp_rank )
1718
+
1719
+ (
1720
+ input_tensor ,
1721
+ output_tensor ,
1722
+ output_tensor_grad ,
1723
+ schedule_chunk ,
1724
+ loss_fn_node ,
1725
+ ) = self ._get_backward_input (virtual_pp_rank )
1726
+
1687
1727
input_tensor_grad = self ._backward_step (
1688
1728
input_tensor ,
1689
1729
output_tensor ,
@@ -1698,6 +1738,133 @@ def _backward_step_helper(self, micro_step, overlap_schedule_mode=False):
1698
1738
1699
1739
return input_tensor_grad
1700
1740
1741
+ def _forward_backward_helper (
1742
+ self , micro_dataset , forward_micro_step_id , backward_micro_step_id
1743
+ ):
1744
+ if not self .overlap_schedule_mode :
1745
+ self ._record_stamp ("F" , forward_micro_step_id , '"B"' , forward = True )
1746
+ output_tensor = self ._forward_step_helper (
1747
+ micro_dataset ,
1748
+ forward_micro_step_id ,
1749
+ )
1750
+ self ._record_stamp ("F" , forward_micro_step_id , '"E"' , forward = True )
1751
+
1752
+ # backward
1753
+ self ._record_stamp (
1754
+ "B" , backward_micro_step_id , '"B"' , forward = False
1755
+ )
1756
+ input_tensor_grad = self ._backward_step_helper (
1757
+ backward_micro_step_id ,
1758
+ )
1759
+ self ._record_stamp (
1760
+ "B" , backward_micro_step_id , '"E"' , forward = False
1761
+ )
1762
+ return output_tensor , input_tensor_grad
1763
+ else :
1764
+ # 1. prepare forward inputs
1765
+ forward_virtual_pp_rank = self ._get_virtual_pp_rank (
1766
+ forward_micro_step_id , forward = True
1767
+ )
1768
+ self .set_virtual_pipeline_rank (forward_virtual_pp_rank )
1769
+
1770
+ if self .user_hooks_enabled :
1771
+ self .forward_hooks .run_hook ()
1772
+
1773
+ forward_inputs = self ._get_forward_input (forward_virtual_pp_rank )
1774
+ if self .is_pipeline_first_stage ():
1775
+ forward_inputs = next (micro_dataset )[0 ]
1776
+ self ._check_micro_batch_data_valid (forward_inputs )
1777
+ if self .is_pipeline_last_stage ():
1778
+ labels = next (micro_dataset )[1 ]
1779
+
1780
+ # 2. get forward chunks
1781
+ forward_chunk = self ._layers .get_schedule_chunk (
1782
+ chunk_id = forward_virtual_pp_rank
1783
+ )
1784
+
1785
+ if self .is_pipeline_last_stage ():
1786
+ assert len (self ._layers ._loss_fn ) == 1
1787
+ forward_loss_fn_node = self ._layers ._loss_fn [
1788
+ 0
1789
+ ].build_schedule_node ()
1790
+ forward_loss_fn_node .labels = labels
1791
+ if self .accumulate_steps > 1 and not self ._delay_scale_loss :
1792
+ forward_loss_fn_node .scale_loss_factor = (
1793
+ self .accumulate_steps
1794
+ )
1795
+ else :
1796
+ forward_loss_fn_node = None
1797
+
1798
+ # 3. prepare backward inputs & get backward chunks
1799
+ backward_virtual_pp_rank = self ._get_virtual_pp_rank (
1800
+ backward_micro_step_id , forward = False
1801
+ )
1802
+ self .set_virtual_pipeline_rank (backward_virtual_pp_rank )
1803
+
1804
+ if self .user_hooks_enabled :
1805
+ self .backward_hooks .run_hook ()
1806
+
1807
+ (
1808
+ _ ,
1809
+ _ ,
1810
+ backward_grads ,
1811
+ backward_chunk ,
1812
+ backward_loss_fn_node ,
1813
+ ) = self ._get_backward_input (backward_virtual_pp_rank )
1814
+
1815
+ # 4. forward & backward
1816
+ if self .processed_steps < g_profile_pipeline_details_steps :
1817
+ get_sync_logger ().info ("Before forward_backward_step" )
1818
+ if self ._enable_timer :
1819
+ self .timers ("forward_backward_step" ).start ()
1820
+ output_tensor , forward_loss , input_tensor_grad = (
1821
+ self ._layers .overlapped_forward_backward (
1822
+ forward_chunk ,
1823
+ forward_inputs ,
1824
+ forward_loss_fn_node ,
1825
+ backward_chunk ,
1826
+ backward_loss_fn_node ,
1827
+ backward_grads ,
1828
+ self .scaler ,
1829
+ )
1830
+ )
1831
+ if self .processed_steps < g_profile_pipeline_details_steps :
1832
+ get_sync_logger ().info ("After forward_backward_step" )
1833
+ if self ._enable_timer :
1834
+ self .timers ("forward_backward_step" ).stop ()
1835
+
1836
+ # 5. process forward outputs
1837
+ forward_virtual_pp_rank = self ._get_virtual_pp_rank (
1838
+ forward_micro_step_id , forward = True
1839
+ )
1840
+ self .set_virtual_pipeline_rank (forward_virtual_pp_rank )
1841
+ self ._store_forward_outputs (
1842
+ forward_virtual_pp_rank ,
1843
+ output_tensor ,
1844
+ forward_chunk ,
1845
+ forward_loss_fn_node ,
1846
+ )
1847
+
1848
+ if self .is_pipeline_first_stage () or self .is_pipeline_last_stage ():
1849
+ # Only increase micro batch id at virtual first/last pp stage.
1850
+ # The micro batch id is used to load data, therefore, only increase it when load data.
1851
+ self .micro_batch_id += 1
1852
+
1853
+ if self .is_pipeline_last_stage ():
1854
+ # In overlap mode, only one loss_fn is supported.
1855
+ if self .total_loss is None :
1856
+ self .total_loss = [[]]
1857
+ self .total_loss [0 ].append (forward_loss .detach ())
1858
+
1859
+ # 6. process backward outputs
1860
+ backward_virtual_pp_rank = self ._get_virtual_pp_rank (
1861
+ backward_micro_step_id , forward = False
1862
+ )
1863
+ self .set_virtual_pipeline_rank (backward_virtual_pp_rank )
1864
+ self ._overlap_comm_grads ()
1865
+
1866
+ return output_tensor , input_tensor_grad
1867
+
1701
1868
def bw_hook_func (self , buffer , param ):
1702
1869
# For pipeline with interleave, we need to add grad to buffer without communication.
1703
1870
# Use communication where appropriate to avoid dp communication and pp scheduling conflicts.
@@ -1753,6 +1920,12 @@ def forward_backward_pipeline(
1753
1920
self ._using_cache
1754
1921
), "cache should be enabled for pipeline with interleave"
1755
1922
1923
+ self .overlap_schedule_mode = hasattr (
1924
+ type (self ._layers ), "overlapped_forward_backward"
1925
+ )
1926
+ if forward_only :
1927
+ self .overlap_schedule_mode = False
1928
+
1756
1929
# init some attributes for this batch run
1757
1930
self .scaler = scaler
1758
1931
self .total_loss = None
@@ -1859,6 +2032,7 @@ def _process_bwd_buffer(step_id, tensor):
1859
2032
startup_steps += (self .num_model_chunks - 1 ) * first_chunk_acc
1860
2033
startup_steps = min (startup_steps , num_steps )
1861
2034
2035
+ # An additional micro step is needed for overlapping schedule
1862
2036
if self .overlap_schedule_mode :
1863
2037
startup_steps += 1
1864
2038
steady_steps = num_steps - startup_steps
@@ -2169,30 +2343,15 @@ def _process_bwd_buffer(step_id, tensor):
2169
2343
overlap_p2p_comm = True ,
2170
2344
)
2171
2345
else :
2172
- self ._record_stamp (
2173
- "F" , forward_micro_step_id , '"B"' , forward = True
2174
- )
2175
- output_tensor = self ._forward_step_helper (
2176
- micro_dataset ,
2177
- forward_micro_step_id ,
2178
- overlap_schedule_mode = self .overlap_schedule_mode ,
2179
- )
2180
- self ._record_stamp (
2181
- "F" , forward_micro_step_id , '"E"' , forward = True
2182
- )
2183
-
2184
- # backward
2185
2346
backward_micro_step_id = micro_step
2186
- self ._record_stamp (
2187
- "B" , backward_micro_step_id , '"B"' , forward = False
2188
- )
2189
- input_tensor_grad = self ._backward_step_helper (
2190
- backward_micro_step_id ,
2191
- overlap_schedule_mode = self .overlap_schedule_mode ,
2192
- )
2193
- self ._record_stamp (
2194
- "B" , backward_micro_step_id , '"E"' , forward = False
2347
+ output_tensor , input_tensor_grad = (
2348
+ self ._forward_backward_helper (
2349
+ micro_dataset ,
2350
+ forward_micro_step_id ,
2351
+ backward_micro_step_id ,
2352
+ )
2195
2353
)
2354
+
2196
2355
if (
2197
2356
self ._best_unbalanced_scheduler
2198
2357
and self .is_pipeline_last_stage (ignore_virtual = True )
0 commit comments