diff --git a/python/paddle/distributed/auto_parallel/pipelining/__init__.py b/python/paddle/distributed/auto_parallel/pipelining/__init__.py new file mode 100644 index 0000000000000..9cc1ba789f6dc --- /dev/null +++ b/python/paddle/distributed/auto_parallel/pipelining/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [] diff --git a/python/paddle/distributed/auto_parallel/pipelining/_backward.py b/python/paddle/distributed/auto_parallel/pipelining/_backward.py new file mode 100644 index 0000000000000..4db4214b25b69 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/pipelining/_backward.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Any, Iterator + +import paddle + +from .utils import _map_debug_info + +logger = logging.getLogger(__name__) + + +def stage_backward_input( + stage_outputs_or_loss: list[paddle.Tensor], + output_grads: list[paddle.Tensor] | None, + input_values: list[paddle.Tensor], + weights: Iterator[paddle.Tensor], +) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any]]]: + raise NotImplementedError("stage_backward_input is not implemented yet") + + +def stage_backward_weight( + weights: Iterator[paddle.Tensor], + param_groups: list[dict[str, Any]], + retain_graph=False, +) -> tuple[paddle.Tensor | None, ...]: + raise NotImplementedError("stage_backward_weight is not implemented yet") + + +def stage_backward( + stage_output, + output_grads, + input_values, +) -> tuple[paddle.Tensor | None, ...]: + """ + This is a helper function to: + 1. compute the gradients for the stage inputs, and + 2. accumulate gradients for the stage module's parameters. + + Given the input value(s) and the corresponding gradient for the output + value(s), compute and accumulate gradients for all parameter values (leaves + in the autograd trace) as well as return a list of the gradients for the + input values + + """ + + try: + # stage_output may be a composite datatype like dict. 
Extract all individual + # tensor values here + stage_output_tensors: list[paddle.Tensor] = [] + output_grad_tensors: list[paddle.Tensor | None] = [] + + def extract_tensors_with_grads( + output_val, + grad_val, + extract_tensors_with_grads, + ): + if isinstance(output_val, paddle.Tensor): + if output_val.stop_gradient and output_val.grad_fn is None: + return + assert isinstance( + grad_val, (paddle.Tensor, type(None)) + ), f"Expected Tensor or None gradient but got {type(grad_val)}" + stage_output_tensors.append(output_val) + output_grad_tensors.append(grad_val) + elif isinstance(output_val, (tuple, list)): + if grad_val is None: + return + assert isinstance( + grad_val, (tuple, list) + ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + assert len(output_val) == len(grad_val) + for ov, gv in zip(output_val, grad_val): + extract_tensors_with_grads( + ov, + gv, + extract_tensors_with_grads, + ) + elif isinstance(output_val, dict): + if grad_val is None: + return + assert isinstance(grad_val, dict) + assert set(output_val.keys()) == set(grad_val.keys()) + for k in output_val.keys(): + extract_tensors_with_grads( + output_val[k], grad_val[k], extract_tensors_with_grads + ) + else: + # Output is a non-tensor type; just ignore it + pass + + # Note: ref cycle + # break a ref cycle that would keep tensors alive until GC runs + # 1. extract_tensors_with_grads refers to a cell that holds refs to any vars defined in stage_backward + # and used in extract_tensors_with_grads + # 2. extract_tensors_with_grads referred to both stage_output_tensors, output_grad_tensors, + # and to itself (extract_tensors_with_grads) since it makes a recursive call + # 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad + # fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore + extract_tensors_with_grads( + stage_output, output_grads, extract_tensors_with_grads + ) + paddle.autograd.backward( + stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type] + ) + + # Extract gradients wrt the input values + grad_inputs: list[paddle.Tensor | None] = [] + for val in input_values: + if isinstance(val, paddle.Tensor): + grad_inputs.append(val.grad) + else: + grad_inputs.append(None) + + except Exception as e: + exc_msg = f""" + Failed to run stage backward: + Stage output: {_map_debug_info(stage_output)} + Output gradient: {_map_debug_info(output_grads)} + Input: {_map_debug_info(input_values)} + """ + raise RuntimeError(exc_msg) from e + + return tuple(grad_inputs) diff --git a/python/paddle/distributed/auto_parallel/pipelining/microbatch.py b/python/paddle/distributed/auto_parallel/pipelining/microbatch.py new file mode 100644 index 0000000000000..9a7658d923309 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/pipelining/microbatch.py @@ -0,0 +1,262 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
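
Before moving on to `microbatch.py`, here is a minimal, self-contained sketch of how the `stage_backward` helper added above is meant to be driven. The layer, shapes and gradient values are hypothetical; only the import path and the function signature come from this diff:

```python
import paddle

from paddle.distributed.auto_parallel.pipelining._backward import stage_backward

# A toy sublayer standing in for one pipeline stage.
layer = paddle.nn.Linear(8, 4)

# Stage input for one microbatch; we want its gradient so it can be
# sent back to the previous stage.
x = paddle.randn([2, 8])
x.stop_gradient = False

out = layer(x)  # forward of one microbatch

# Gradient of the downstream loss w.r.t. this stage's output (all-ones here
# purely for illustration; normally it is received from the next stage).
out_grad = paddle.ones_like(out)

grad_inputs = stage_backward(
    stage_output=out,
    output_grads=out_grad,
    input_values=[x],
)

# Input gradients are returned (one entry per input, None for non-tensors),
# while parameter gradients are accumulated in place on the layer.
assert grad_inputs[0].shape == x.shape
assert layer.weight.grad is not None
```
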
+ +from __future__ import annotations + +import logging +from typing import Any + +import paddle +from paddle.utils import flatten, map_structure, pack_sequence_as + +logger = logging.getLogger(__name__) + +# Default chunking dimension is 0. This is used for the case where the user did +# not specify a chunking dimension. +DEFAULT_CHUNK_DIM = 0 + + +class TensorChunkSpec: + """ + Class used to specify chunking of inputs + """ + + def __init__(self, split_axis): + self.split_axis = split_axis + + split_axis: int + + def __repr__(self): + return f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_axis})" + + def __str__(self): + return f"TensorChunkSpec({self.split_axis})" + + +def _split_args_helper( + args_dict, + args_chunk_spec, + num_chunks, +): + """ + A helper function of split_args_kwargs_into_chunks. + """ + assert len(args_dict) == len( + args_chunk_spec + ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + + shared_args_dict_flat = {} + # handle args one by one + for arg_key, arg in args_dict.items(): + arg_flat = flatten(arg) + + chunk_spec = args_chunk_spec[arg_key] + assert chunk_spec is not None + + chunk_spec_flat = flatten(chunk_spec) + assert len(chunk_spec_flat) == len( + arg_flat + ), f"{arg_key} {len(arg_flat)} != {len(chunk_spec_flat)}" + + shard_arg_flat = [] + + for v, chunk_v in zip(arg_flat, chunk_spec_flat): + if not isinstance(v, paddle.Tensor): + shard_arg_flat.append([v] * num_chunks) + elif isinstance(chunk_v, TensorChunkSpec): + v_split_axis_size = v.shape[chunk_v.split_axis] + + if v_split_axis_size < num_chunks: + raise ValueError( + f"Arg {arg_key} on chunking dimension has a size of {v_split_axis_size}, " + f"smaller than the number of chunks {num_chunks}. " + "Please adjust your num_chunks setting." + ) + # split tensor v + chunk_tensors = paddle.tensor_split( + v, num_chunks, chunk_v.split_axis + ) + + shard_arg_flat.append(chunk_tensors) + else: + raise TypeError(f"Unrecognized chunk spec: {chunk_v}") + + shared_args_dict_flat[arg_key] = shard_arg_flat + + # the structure of each element in args_split is the same as the original args_dict + args_split = [] + for idx in range(num_chunks): + chunk_args = {} + for key, arg in shared_args_dict_flat.items(): + arg_of_curr_chunk = ( + [v[idx] for v in arg] if len(arg) > 1 else arg[0][idx] + ) + chunk_args[key] = arg_of_curr_chunk + + # flatten chunk_args first, and then pack chunk_args as the origin args_dict + chunk_args = pack_sequence_as(args_dict, flatten(chunk_args)) + args_split.append(chunk_args) + return args_split + + +def split_args_kwargs_into_chunks( + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, + chunks: int, + args_chunk_spec: ( + tuple[ + tuple[TensorChunkSpec, ...] + | list[TensorChunkSpec, ...] + | TensorChunkSpec, + ..., + ] + | None + ) = None, + kwargs_chunk_spec: ( + dict[ + str, + tuple[TensorChunkSpec, ...] + | list[TensorChunkSpec, ...] + | TensorChunkSpec, + ] + | None + ) = None, +) -> tuple[list[tuple], list[dict]]: + """ + Given a sequence of args and kwargs, split them into a number of chunks + according to their respective chunking specs. 
+ + Args: + args: tuple of args + kwargs: dict of kwargs + chunks: Number of chunks to split the args and kwargs into + args_chunk_spec: chunking specs for args, in same shape as args + kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs + + Returns: + args_split: list of sharded args + kwargs_split: list of sharded kwargs + """ + + if kwargs is None: + kwargs = {} + + if args_chunk_spec is None: + args_chunk_spec = map_structure( + lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), args + ) + + if kwargs_chunk_spec is None: + kwargs_chunk_spec = map_structure( + lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), kwargs + ) + + args_split_dict = _split_args_helper( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + chunks, + ) + kwargs_split = _split_args_helper( + kwargs, + kwargs_chunk_spec, + chunks, + ) + + assert len(args_split_dict) == len(kwargs_split), ( + "args and kwargs are split into difference number of chunks: " + f"{len(args_split_dict)}, {len(kwargs_split)}" + ) + + # the form of each args_chunk should be tuple + args_split = [ + tuple(args_chunk[i] for i in range(len(args_chunk))) + for args_chunk in args_split_dict + ] + + return args_split, kwargs_split + + +def merge_chunks( + chunks: list[Any], + chunk_spec, +): + """ + Given a list of chunks, merge them into a single chunk according to + the chunk spec. + + Args: + chunks: list of chunks + chunk_spec: Chunking spec for the chunks + + Returns: + chunk: chunks merged value + """ + if len(chunks) == 0: + logger.warning("No chunks to merge.") + return chunks + + if chunk_spec is None: + chunk_spec = map_structure( + lambda _: TensorChunkSpec(DEFAULT_CHUNK_DIM), chunks[0] + ) + + chunks_flat = [] + # flatten chunk_spec first + chunk_spec = flatten(chunk_spec) + for chunk in chunks: + chunk_flat = flatten(chunk) + assert len(chunk_flat) == len( + chunk_spec + ), f"Chunk {chunk} did not match chunk spec {chunk_spec}" + chunks_flat.append(chunk_flat) + + def _merge_non_tensor_type_arg(chunks, idx, chunk_spec_of_arg=None): + # use the first chunk's value as the merged result + arg_0 = chunks[0][idx] + for chunk_idx in range(1, len(chunks)): + assert chunks[chunk_idx][idx] == arg_0, ( + f"Cannot merge chunks with index 0 and {idx} with different values," + f"When the arg's TensorChunkSpec is {chunk_spec_of_arg}" + ) + return arg_0 + + args_flat = [] + for arg_idx, chunk_spec_of_arg in enumerate(chunk_spec): + if isinstance(chunk_spec_of_arg, TensorChunkSpec): + if isinstance(chunks_flat[0][arg_idx], paddle.Tensor): + arg_chunks_to_merge = [ + chunks_flat[chunk_idx][arg_idx] + for chunk_idx in range(len(chunks_flat)) + ] + merged_arg = paddle.concat( + arg_chunks_to_merge, axis=chunk_spec_of_arg.split_axis + ) + else: + logger.warning( + f"Cannot merge chunks with TensorChunkSpec {chunk_spec_of_arg}." + "The TensorChunkSpec only supports paddle.Tensor type." + ) + + merged_arg = _merge_non_tensor_type_arg( + chunks_flat, arg_idx, chunk_spec_of_arg + ) + else: + merged_arg = _merge_non_tensor_type_arg( + chunks_flat, arg_idx, chunk_spec_of_arg + ) + + args_flat.append(merged_arg) + + # pack args_flat as the input chunks[0] + return pack_sequence_as(chunks[0], args_flat) diff --git a/python/paddle/distributed/auto_parallel/pipelining/schedules.py b/python/paddle/distributed/auto_parallel/pipelining/schedules.py new file mode 100644 index 0000000000000..94c08241d433c --- /dev/null +++ b/python/paddle/distributed/auto_parallel/pipelining/schedules.py @@ -0,0 +1,1218 @@ +# Copyright (c) 2025 PaddlePaddle Authors. 
All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import re
+from abc import ABC, abstractmethod
+from collections import Counter, defaultdict
+from enum import Enum
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    NamedTuple,
+)
+
+if TYPE_CHECKING:
+    from .stage import _PipelineStageBase
+
+
+from .microbatch import (
+    TensorChunkSpec,
+    merge_chunks,
+    split_args_kwargs_into_chunks,
+)
+
+import paddle
+import paddle.distributed as dist
+from paddle import profiler
+
+logger = logging.getLogger(__name__)
+
+
+class _ActType(Enum):
+    FORWARD = 1
+    BACKWARD_INPUT = 2
+    BACKWARD_WEIGHT = 3
+    UNSHARD = 4
+    RESHARD = 5
+    SEND_F = 6
+    RECV_F = 7
+    SEND_B = 8
+    RECV_B = 9
+    FULL_BACKWARD = 10
+
+    def __str__(self):
+        str_map = {
+            _ActType.FORWARD: "F",
+            _ActType.BACKWARD_INPUT: "I",
+            _ActType.BACKWARD_WEIGHT: "W",
+            _ActType.UNSHARD: "UNSHARD",
+            _ActType.RESHARD: "RESHARD",
+            _ActType.SEND_F: "SEND_F",
+            _ActType.RECV_F: "RECV_F",
+            _ActType.SEND_B: "SEND_B",
+            _ActType.RECV_B: "RECV_B",
+            _ActType.FULL_BACKWARD: "B",
+        }
+        return str_map[self]
+
+    @staticmethod
+    def from_str(action):
+        if action == "F":
+            return _ActType.FORWARD
+        elif action == "I":
+            return _ActType.BACKWARD_INPUT
+        elif action == "W":
+            return _ActType.BACKWARD_WEIGHT
+        elif action == "UNSHARD":
+            return _ActType.UNSHARD
+        elif action == "RESHARD":
+            return _ActType.RESHARD
+        elif action == "SEND_F":
+            return _ActType.SEND_F
+        elif action == "RECV_F":
+            return _ActType.RECV_F
+        elif action == "SEND_B":
+            return _ActType.SEND_B
+        elif action == "RECV_B":
+            return _ActType.RECV_B
+        elif action == "B":
+            return _ActType.FULL_BACKWARD
+        else:
+            raise RuntimeError(f"Invalid computation type {action}")
+
+
+FORWARD = _ActType.FORWARD
+BACKWARD_INPUT = _ActType.BACKWARD_INPUT
+BACKWARD_WEIGHT = _ActType.BACKWARD_WEIGHT
+UNSHARD = _ActType.UNSHARD
+RESHARD = _ActType.RESHARD
+SEND_F = _ActType.SEND_F
+RECV_F = _ActType.RECV_F
+SEND_B = _ActType.SEND_B
+RECV_B = _ActType.RECV_B
+FULL_BACKWARD = _ActType.FULL_BACKWARD
+
+# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
+F = FORWARD
+I = BACKWARD_INPUT
+W = BACKWARD_WEIGHT
+B = FULL_BACKWARD
+
+# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
+_action_regex = re.compile(
+    r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
+)
+
+
+class _Action(NamedTuple):
+    stage_index: int
+    computation_type: _ActType
+    microbatch_index: int | None = None
+
+    def __repr__(self):
+        repr = str(self.stage_index)
+        repr += str(self.computation_type)
+        if self.microbatch_index is not None:
+            repr += str(self.microbatch_index)
+        return repr
+
+
+class _PipelineSchedule(ABC):
+    def __init__(
+        self,
+        n_microbatches: int,
+        loss_fn: Callable[..., paddle.Tensor] | None = None,
+        args_chunk_spec: tuple[TensorChunkSpec, ...]
| None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + ): + # From arguments + self._n_microbatches = n_microbatches + self._loss_fn = loss_fn + # Chunking specification for positional inputs. (default: `None`) + self._args_chunk_spec = args_chunk_spec + # Chunking specification for keyword inputs. (default: `None`) + self._kwargs_chunk_spec = kwargs_chunk_spec + self._output_merge_spec = output_merge_spec + """ + # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. + # They are used to convert batch to microbatches in `step(x)`. See + # `TensorChunkSpec` for helper methods for creating them. + """ + + # Derived + self._has_backward = self._loss_fn is not None + + # Holds the losses for each microbatch. + self._internal_losses: list[paddle.Tensor] = [] + logger.info("Using %s", self.__class__.__name__) + + def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): + if stage.is_last and self._has_backward: + loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] + self._internal_losses.append(loss) + + def _maybe_get_loss(self, stage, mb_index): + valid_index = 0 <= mb_index < len(self._internal_losses) + if stage.is_last and self._has_backward and valid_index: + return self._internal_losses[mb_index] + elif len(self._internal_losses) != 0 and not valid_index: + raise RuntimeError( + f"Loss for microbatch {mb_index} is not available. " + f"Available losses for microbatches: {self._internal_losses}" + ) + else: + return None + + def _update_losses(self, stages, losses): + """ + Update the losses to those in the internal state + """ + # if stages not a list turn into a list + if not isinstance(stages, list): + stages = [stages] + contains_last_stage = any(stage.is_last for stage in stages) + + # Return losses if there is a container passed in + if contains_last_stage and losses is not None: + if len(self._internal_losses) != self._n_microbatches: + raise RuntimeError( + f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" + ) + + # Clean external container first + losses.clear() + # Copy internal losses to external container + losses.extend(self._internal_losses) + + self._internal_losses.clear() + + @abstractmethod + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the schedule + implementation. + + Args: + microbatches: list of microbatch args. + """ + raise NotImplementedError + + @abstractmethod + def step(self, *args, target=None, losses: list | None = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. 
+ """ + raise NotImplementedError + + def _check_inputs( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ): + """ + Pre-process/check inputs + """ + + def check_type_and_len(mbs, name: str): + if not isinstance(mbs, list): + raise TypeError(f"{name} must be a list but got a {type(mbs)}") + if len(mbs) != self._n_microbatches: + raise ValueError( + f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" + ) + + if arg_mbs is not None: + check_type_and_len(arg_mbs, "arg_mbs") + else: + arg_mbs = [()] * self._n_microbatches + + if kwarg_mbs is not None: + check_type_and_len(kwarg_mbs, "kwarg_mbs") + else: + kwarg_mbs = [{}] * self._n_microbatches + + if target_mbs is not None: + check_type_and_len(target_mbs, "target_mbs") + + if losses is not None: + if not isinstance(losses, list): + raise TypeError( + f"losses must be a list but got a {type(losses)}" + ) + + return arg_mbs, kwarg_mbs + + def _compute_loss(self, output, target): + return self._loss_fn(output, target) # type: ignore[misc] + + def _split_inputs( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ): + """ + Splits a full-batch input into chunks (i.e. microbatches) and returns + the chunks + """ + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self._n_microbatches, + self._args_chunk_spec, + self._kwargs_chunk_spec, + ) + return args_split, kwargs_split + else: + # Empty inputs (e.g. when called on middle stages) + # Return a list of empty tuples/dicts with matching length as chunks + return [()] * self._n_microbatches, [{}] * self._n_microbatches + + def _merge_outputs(self, output_chunks: list[Any]) -> Any: + """ + Merge output chunks back to a batch state. + If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). + """ + return merge_chunks( + output_chunks, + self._output_merge_spec, + ) + + +class PipelineScheduleSingle(_PipelineSchedule): + """ + Base class for single-stage schedules. + Implements the `step` method. + Derived classes should implement `_step_microbatches`. + """ + + def __init__( + self, + stage: _PipelineStageBase, + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + # Self attributes + self._stage = stage + self._num_stages = stage.num_stages + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward + self._stage_initialized = False + + def _initialize_stage(self, args, kwargs): + self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) + if self._has_backward: + self._stage._prepare_backward_infra(self._n_microbatches) + self._stage_initialized = True + + def step(self, *args, target=None, losses: list | None = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). 
+ kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + + # Clean per iteration + self._stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list( + paddle.tensor_split(target, self._n_microbatches) + ) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + if self._stage.is_last: + return self._merge_outputs(self._stage.output_chunks) + else: + return None + + +def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None): + """ + Simple wrapper over batch_isend_irecv from paddle.distributed, which just adds a descriptive logger on top. + """ + if len(p2p_ops) == 0: + return None + desc_str = f"{desc}, " if desc else "" + logger.info("batch_p2p %s%s", desc_str, p2p_ops) + return dist.batch_isend_irecv(p2p_ops).pop() + + +def _sorted_batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None): + """ + Sorts the list of P2P ops by the peer rank, and then calls + batch_isend_irecv. Return a dictionary of works by peer rank. This function + helps us avoid hangs in case of skip connections. + """ + # Arrange p2p_ops by peer rank: + # int is the peer rank; + # list is the list of ops towards the peer + ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list) + work_by_peer: dict[int, dist.Work] = {} + if len(p2p_ops) == 0: + return work_by_peer + + # Classify the ops by peer rank + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = _batch_p2p(ops, desc=desc) + + return work_by_peer + + +class ScheduleGPipe(PipelineScheduleSingle): + """ + The GPipe schedule. + Will go through all the microbatches in a fill-drain manner. + """ + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the GPipe schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs( + arg_mbs, kwarg_mbs, target_mbs, losses + ) + + if not self._stage_initialized: + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: list[dist.Work] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with profiler.RecordEvent(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + work.wait() + + output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug( + "[%s] Forwarded microbatch %s", self._stage.stage_index, i + ) + + self._maybe_compute_loss(self._stage, output, target_mbs, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. 
+ for work in fwd_sends_to_wait: + work.wait() + + # No loss function, no need to run backward + if not self._has_backward: + return + + # Run backward + # Delay send waits + bwd_sends_to_wait: list[dist.Work] = [] + for i in range(self._n_microbatches): + with profiler.RecordEvent(f"Backward {i}"): + ops = self._stage.get_bwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_recv") + for work in works.values(): + work.wait() + + loss = self._maybe_get_loss(self._stage, i) + self._stage.backward_one_chunk( + i, loss=loss, last_backward=i == self._n_microbatches - 1 + ) + + ops = self._stage.get_bwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_send") + bwd_sends_to_wait.extend(works.values()) + + logger.debug( + "[%s] Backwarded microbatch %s", self._stage.stage_index, i + ) + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + # Wait for all backward sends to finish + for work in bwd_sends_to_wait: + work.wait() + + +class Schedule1F1B(PipelineScheduleSingle): + """ + The 1F1B schedule. + Will perform one forward and one backward on the microbatches in steady state. + """ + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the 1F1B schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs( + arg_mbs, kwarg_mbs, target_mbs, losses + ) + + if not self._stage_initialized: + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Last stage has 1 warmup, second-to-last 2 warmups, ... + # first stage `num_stages` warmups + warmup_chunks = min( + self._n_microbatches, + self._num_stages - self._stage.stage_index, + ) + + # Chunk counters + fwd_mb_index = 0 + bwd_mb_index = 0 + + # Warmup phase + send_work = None + fwd_sends = [] + for _ in range(warmup_chunks): + # Receive activations + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"): + recv_work.wait() + + # Compute + output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + + # Clear previous chunk's forward sends (hopefully they have well + # finished, otherwise, we are heavily communication bound, in which + # case it doesn't create a lot of benefit to compute next chunk + # eagerly either) + if send_work: + send_work.wait() + + # Send activations + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + if fwd_mb_index != warmup_chunks - 1: + # Safe to fire + send_work = _batch_p2p(fwd_sends, desc="fwd_send") + # otherwise: + # The last forward send is left for fuse with first 1B in 1B1F below + + # Compute loss + self._maybe_compute_loss( + self._stage, output, target_mbs, fwd_mb_index + ) + fwd_mb_index += 1 + + # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below. 
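+
+        # Illustrative timeline (assumed example, not derived from runtime
+        # state): with num_stages = 4, stage_index = 1 and n_microbatches = 8,
+        # this rank does warmup_chunks = 3 forwards, then alternates backward
+        # and forward in the 1B1F loop below, and drains the remaining
+        # backwards in the cooldown loop:
+        #   warmup:   F0 F1 F2
+        #   1B1F:     B0 F3  B1 F4  B2 F5  B3 F6  B4 F7  B5
+        #   cooldown: B6 B7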
+ + # 1B1F phase + while True: # Don't worry, we have a break inside + # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + + # Now, we need to fire the fwd_sends and bwd_recvs together + if fuse_work := _batch_p2p( + fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv" + ): + fuse_work.wait() + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Get the bwd send ops, but don't fire, to be fused with the 1F below + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + bwd_mb_index += 1 + + if fwd_mb_index == self._n_microbatches: + # We are done with 1B1F, so break with some left-over bwd_sends + break + + # We prepare 1F of the `1B1F` + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + + # Fuse it with bwd_sends above + if fuse_work := _batch_p2p( + bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv" + ): + fuse_work.wait() + + # Now do the fwd + output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + + # Compute loss + self._maybe_compute_loss( + self._stage, output, target_mbs, fwd_mb_index + ) + + # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + fwd_mb_index += 1 + + # Remember we still have some bwd_sends left over after the break? Now it is time to fire it + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + + # Cooldown + while bwd_mb_index < self._n_microbatches: + # prepare bwd recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"): + recv_work.wait() + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Clear previous chunk's backward sends (hopefully they have well finished) + if send_work: + send_work.wait() + + # Get the bwd send ops, fire it + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + bwd_mb_index += 1 + + # Wait for the last backward send to finish + if send_work: + send_work.wait() + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + +class PipelineScheduleMulti(_PipelineSchedule): + """ + Base class for multi-stage schedules. + Implements the `step` method. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] 
| None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + stage_index_to_group_rank: dict[int, int] | None = None, + use_full_backward: bool | None = None, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + # Self attributes + self._stages = stages + self._num_stages = stages[0].num_stages + self.pp_group_size = stages[0].group_size + self.rank = stages[0].group_rank + # Set the pipeline stage states + if stage_index_to_group_rank is not None: + for stage in self._stages: + stage.stage_index_to_group_rank = stage_index_to_group_rank + self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank + + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + self._stages_initialized = False + + # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle + has_loss: bool = self._loss_fn is not None + self._should_compute_loss = lambda stage: stage.is_last and has_loss + + # This will be set during init of derived schedules + self.pipeline_order: dict[int, list[_Action | None]] = {} + + if use_full_backward is not None: + logger.warning( + "Deprecation warning: 'use_full_backward' is no longer supported. " + "Simply stop passing it, and everything should still work fine." + ) + + def _initialize_stages(self, args: tuple[Any, ...], kwargs): + # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) + # or real value (if this stage and next stage are on the same device) + next_stage_args: tuple[Any, ...] = () + for stage in self._stages: + if stage.is_first: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, args, kwargs + ) + else: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, next_stage_args, kwargs + ) + + if self._has_backward: + for stage_reverse in reversed(self._stages): + stage_reverse._prepare_backward_infra(self._n_microbatches) + + self._stages_initialized = True + + def step(self, *args, target=None, losses: list | None = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. 
+ """ + # Clean per iteration + for stage in self._stages: + stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list( + paddle.tensor_split(target, self._n_microbatches) + ) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + for stage in self._stages: + if stage.is_last: + return self._merge_outputs(stage.output_chunks) + # Does not contain the last stage + return None + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + """ + arg_mbs, kwarg_mbs = self._check_inputs( + arg_mbs, kwarg_mbs, target_mbs, losses + ) + + if not self._stages_initialized: + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + # determine prev_rank and next_rank based on which ranks are next to + # the stages in the pipeline_order + all_prev_ranks: set[int] = set() + all_next_ranks: set[int] = set() + for stage_index in stage_index_to_stage.keys(): + # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) + if stage_index > 0: + all_prev_ranks.add( + self.stage_index_to_group_rank[stage_index - 1] + ) + if stage_index < self._num_stages - 1: + all_next_ranks.add( + self.stage_index_to_group_rank[stage_index + 1] + ) + # count either full_backward or backward_weight together, to determine when to sync DP grads + backward_counter: Counter[int] = Counter() + for time_step, action in enumerate(self.pipeline_order[self.rank]): + try: + ops: list[dist.P2POp] = [] + if action is not None: + computation_type = action.computation_type + mb_index = action.microbatch_index + stage_index = action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + if computation_type == _ActType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss( + stage, output, target_mbs, mb_index + ) + ops.extend(stage.get_fwd_send_ops(mb_index)) + elif computation_type == _ActType.FULL_BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + backward_counter[stage_index] += 1 + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=backward_counter[stage_index] + == self._n_microbatches, + ) + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ActType.BACKWARD_INPUT: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ActType.BACKWARD_WEIGHT: + # perform weight update + stage = stage_index_to_stage[stage_index] + 
backward_counter[stage_index] += 1 + stage.backward_weight_one_chunk( + mb_index, + last_backward=backward_counter[stage_index] + == self._n_microbatches, + ) + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + for prev_rank in all_prev_ranks: + prev_rank_ops = self.pipeline_order[prev_rank] + prev_rank_action = None + if time_step < len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type = prev_rank_action.computation_type + mb_index = prev_rank_action.microbatch_index + stage_index = prev_rank_action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + # Only handle sends for the forward from a previous rank + if computation_type == _ActType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index + 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops(mb_index)) + elif computation_type in ( + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + ): + # Previous rank doing backward has no influence for the current rank forward recv + pass + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + for next_rank in all_next_ranks: + next_rank_ops = self.pipeline_order[next_rank] + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if next_rank_action is not None: + computation_type = next_rank_action.computation_type + mb_index = next_rank_action.microbatch_index + stage_index = next_rank_action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + # Only handle receives for the backwards from a next rank + if computation_type in (FORWARD, BACKWARD_WEIGHT): + # Next rank doing forward or weight update has no influence for the current rank backward recv + pass + elif computation_type in ( + BACKWARD_INPUT, + FULL_BACKWARD, + ): + # If not the first stage, then receive bwd gradients + if stage_index - 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops(mb_index)) + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + + # do the communication + if ops: + _batch_p2p(ops).wait() + except Exception as e: + logger.error( + "[Rank %s] pipeline schedule %s caught the following exception \ + at time_step %s when running action %s", + self.rank, + self.__class__.__name__, + time_step, + action, + ) + raise e + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +def _get_1f1b_rank_ops( + n_local_stages, + pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches=0, + enable_zero_bubble=False, +): + # All stages start with handling microbatch 0 + fwd_stage_mb_index: dict[int, int] = defaultdict(int) + bwd_stage_mb_index: dict[int, int] = defaultdict(int) + weight_stage_mb_index: dict[int, int] = 
defaultdict(int) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: list[_Action | None] = [None for _ in range(rank)] + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. + # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + if enable_zero_bubble: + post_warmup_ops = pp_group_size - rank - 1 + + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + + backward_op_ids = [] + weight_op_count = 0 + + FULL_BACKWARD_OR_BACKWARD_INPUT = ( + BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD + ) + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ActType.FORWARD, mb_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ActType.FORWARD, fwd_mb_index) + ) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action( + bwd_stage_index, + FULL_BACKWARD_OR_BACKWARD_INPUT, + bwd_mb_index, + ) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ActType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + if not enable_zero_bubble: + rank_ops.append(None) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action( + bwd_stage_index, + FULL_BACKWARD_OR_BACKWARD_INPUT, + bwd_mb_index, + ) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + 
_ActType.BACKWARD_WEIGHT,
+                        weight_mb_index,
+                    )
+                )
+                weight_op_count += 1
+
+    while enable_zero_bubble and weight_op_count < len(backward_op_ids):
+        weight_stage_index = backward_stage_index(
+            backward_op_ids[weight_op_count]
+        )
+        weight_stage_mb_index[weight_stage_index] = (
+            weight_mb_index := weight_stage_mb_index[weight_stage_index]
+        ) + 1
+        rank_ops.append(
+            _Action(
+                weight_stage_index, _ActType.BACKWARD_WEIGHT, weight_mb_index
+            )
+        )
+        weight_op_count += 1
+
+    return rank_ops
+
+
+class ScheduleInterleaved1F1B(PipelineScheduleMulti):
+    """
+    The Interleaved 1F1B schedule.
+    See https://arxiv.org/pdf/2104.04473 for details.
+    Will perform one forward and one backward on the microbatches in steady
+    state and supports multiple stages per rank. When microbatches are ready for
+    multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
+    (also called "depth first").
+
+    This schedule is mostly similar to the original paper.
+    It differs by relaxing the requirement of num_microbatch % pp_size == 0.
+    Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
+    it works as long as n_microbatches % num_rounds is 0. A few examples:
+
+    1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
+    2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
+    """
+
+    def __init__(
+        self,
+        stages: list[_PipelineStageBase],
+        n_microbatches: int,
+        loss_fn: Callable | None = None,
+        args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
+        kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
+        output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
+    ):
+        self.pp_group_size = stages[0].group_size
+        super().__init__(
+            stages=stages,
+            n_microbatches=n_microbatches,
+            loss_fn=loss_fn,
+            args_chunk_spec=args_chunk_spec,
+            kwargs_chunk_spec=kwargs_chunk_spec,
+            output_merge_spec=output_merge_spec,
+        )
+        self.n_local_stages = len(stages)
+        self.rank = stages[0].group_rank
+        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
+        self.microbatches_per_round = n_microbatches // self.number_of_rounds
+        if n_microbatches % self.number_of_rounds != 0:
+            raise ValueError(
+                "Interleaved 1F1B requires the number of microbatches to be a "
+                f"multiple of the number of rounds ({self.number_of_rounds}), "
+                f"but got {n_microbatches}."
+            )
+        # 1. Create the pipeline_order (all ranks do this calculation)
+        # This will be used to keep track of the current state of the entire pipeline
+        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index), ...]
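+        # For example, _Action(stage_index=2, computation_type=F,
+        # microbatch_index=0) prints as "2F0": at that time step this rank
+        # runs the forward of stage 2 on microbatch 0.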
+ self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 2 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = ( + step // self.microbatches_per_round + ) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + ) diff --git a/python/paddle/distributed/auto_parallel/pipelining/stage.py b/python/paddle/distributed/auto_parallel/pipelining/stage.py new file mode 100644 index 0000000000000..d746da37ec809 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/pipelining/stage.py @@ -0,0 +1,1201 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
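
Before the body of `stage.py`, here is a rough end-to-end sketch of how the classes in `schedules.py` are intended to be driven. It is illustrative only: `PipelineStage` is assumed to be a concrete subclass of `_PipelineStageBase` whose constructor mirrors `_PipelineStageBase.__init__` (it is not shown in this hunk), and the layer, shapes and loss are made up:

```python
import paddle
import paddle.distributed as dist

from paddle.distributed.auto_parallel.pipelining.schedules import ScheduleGPipe
from paddle.distributed.auto_parallel.pipelining.stage import PipelineStage  # assumed concrete stage class

dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()

# Each rank owns one slice of the model; a single Linear stands in for it here.
local_layer = paddle.nn.Linear(16, 16)
stage = PipelineStage(local_layer, stage_index=rank, num_stages=world_size)

schedule = ScheduleGPipe(
    stage,
    n_microbatches=4,
    loss_fn=paddle.nn.CrossEntropyLoss(),
)

x = paddle.randn([8, 16])  # whole batch; step() splits it into 4 microbatches
target = paddle.randint(0, 16, [8])
losses = []

if stage.is_first:
    schedule.step(x)  # first stage feeds the data
elif stage.is_last:
    schedule.step(target=target, losses=losses)  # last stage computes the loss
else:
    schedule.step()  # middle stages only relay activations and gradients
```

Each process would be started with `paddle.distributed.launch` (or an equivalent launcher) so that `dist.get_rank()` maps the process to its pipeline stage, matching the default `stage_index_to_group_rank` of `i % group_size`.
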
+ +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Union + +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.distributed.auto_parallel.api import ( + dtensor_from_local, + dtensor_to_local, +) + +from ._backward import stage_backward +from .utils import ( + TensorMeta, + _detach_and_requires_grad, + _flatten_args, + _get_stage_mesh, + _map_debug_info, + _map_structure_only, + _validate_tensors_metadata, + _zero_initialize_with_meta, + map_structure, +) + +if TYPE_CHECKING: + from paddle.distributed.communication.group import Group + +logger = logging.getLogger(__name__) + + +def _normalize_model_output_as_tuple(output: Any) -> tuple[Any]: + """[Note: pipeline model output type] + + The output of the model passed to pipelining can be any type, controlled by the user. + + However, there are 2 API surfaces that complicate this. + (1) the outputs of intermediate stages are passed via Send/Recv ops to subsequent stages. The implicit assumption + is that each element of the outputs is a tensor. Otherwise, Send/Recv would not be supported. The exception + is the last layer of the model, which can output anything any which won't be communicated via Send/Recv. + (2) the outputs of the last layer of the model are returned to the user, or, passed to the loss function. + The loss function can be written in any way, such that its inputs match the outputs of the model. + + It would be convenient if we could strictly type the output signature of the pipeline stage wrapping the model, + but we do not want to impose an unnecessary constraint on user provided models. + """ + if type(output) is list: + output = tuple(output) + + # Unify output form to tuple for easy correspondence with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + return output_tuple + + +class _RootArgPlaceholder: + """ + Placeholder for model-level inputs. + """ + + def __init__(self, tensormeta: TensorMeta): + self.meta = tensormeta + + +class _RecvInfo: + """ + Represents a stage input which is DenseTensor. + """ + + def __init__( + self, + input_name: str, + source: int, + buffer: paddle.Tensor, + ): + # Name of this input + self.input_name = input_name + # Stage index of the source of this input + self.source = source + # Buffer to receive the input into. + self.buffer = buffer + + def __repr__(self): + return f"_RecvInfo(input_name={self.input_name}, source={self.source}, buffer={self.buffer.size})" + + +# An input can be either a received activation or a model input +InputInfo = Union[_RecvInfo, _RootArgPlaceholder] + + +def _make_tensor_from_meta( + example: paddle.Tensor | TensorMeta, +) -> paddle.Tensor: + """ + Create a real dense tensor from a tensor. + """ + return paddle.empty( + example.shape, + dtype=example.dtype, + ) + + +class _PipelineStageBase(ABC): + """ + Base class for pipeline stages. + Defines or implements methods used by manual frontend. + """ + + def __init__( + self, + layer: paddle.nn.Layer, + stage_index: int, + num_stages: int, + group: Group | None = None, + ): + """ + Args: + layer (paddle.nn.Layer): The Layer to be executed in this stage. + stage_index (int): The index of this stage. + num_stages (int): The total number of stages in this pipeline. + group (Group|None): The process group to use for communication. + If `None`, the default process group will be used. + Default: `None`. 
+ """ + super().__init__() + if stage_index >= num_stages: + raise ValueError( + f"Stage index {stage_index} is out of range of {num_stages}" + ) + + self.sublayer = layer + self.stage_index = stage_index + self.num_stages = num_stages + self.group = group + + # backward state + self.backward_state: dict[int, tuple[Any, ...]] = {} + + # store dw_runner per microbatch_id + self.dw_runner: dict[int, Callable[..., None]] = {} + + # `group_rank` is rank in process group `group`. + self.group_rank = dist.get_rank(self.group) + self.group_size = dist.get_world_size(self.group) + if self.group_size > self.num_stages: + raise RuntimeError( + f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}" + ) + + # Run time states + self._outputs_meta: tuple[paddle.Tensor, ...] | None = None + # map microbatch ID to list of forward tensor args + self.fwd_cache: dict[int, tuple[Any, list[paddle.Tensor]]] = {} + # map microbatch ID to list of backward grad tensor args + self.bwd_cache: dict[int, tuple[paddle.Tensor | None, ...]] = {} + # Caching chunk outputs for final output merge or reduction + self.output_chunks: list[Any] = [] + + # Initialize has_backward to false; this will be set to true if loss + # function is passed to pipeline schedule + self.has_backward = False + # Log prefix + self.log_prefix = f"[Stage {self.stage_index}]" + + # Forward infra + self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {} + self.act_send_info: dict[int, list] = {} + + # Backward infra will created lazily + self.grad_recv_info: dict = {} + self.grad_send_info: list | None = None + + # To be populated later by the Schedule + self.chunks: int | None = None + # For V-style pipeline, the calculation of self.stage_index_to_group_rank is not correct here. + self.stage_index_to_group_rank: dict[int, int] = { + i: i % self.group_size for i in range(self.num_stages) + } + + @property + def has_backward(self) -> bool: + """ + Returns true if this stage has a backward pass. + """ + return self._has_backward + + @has_backward.setter + def has_backward(self, has_backward: bool): + self._has_backward = has_backward + + @property + def is_first(self): + """ + Returns true if this stage is the first stage in the pipeline. + """ + return self.stage_index == 0 + + @property + def is_last(self): + """ + Returns true if this stage is the last stage in the pipeline. + """ + return self.stage_index == self.num_stages - 1 + + def _check_chunk_id(self, chunk_id: int): + if self.chunks is None: + raise RuntimeError( + "Attempted to access chunk_id before chunks have been configured." + ) + if chunk_id >= self.chunks: + raise RuntimeError( + f"Chunk id {chunk_id} is out of range [0, {self.chunks})" + ) + + def _configure_outputs_meta(self, outputs_meta: tuple[paddle.Tensor, ...]): + """ + Track the output shapes/dtype of this stage since they determine the send operation(s) which must match + recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial + configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches + which could show up as hangs, silent corruption, or other errors. 
+ """ + assert ( + self._outputs_meta is None + ), "Attempting to reconfigure output_meta, which is not supported" + self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] + + def get_outputs_meta(self) -> tuple[paddle.Tensor, ...]: + """Get the output metadata (meta tensors) representing the outputs of this stage""" + assert ( + self._outputs_meta is not None + ), "Attempted to get_outputs_meta() without configuring output meta" + return self._outputs_meta + + def _create_grad_send_info( + self, + args_recv_info: tuple, + ) -> list[int | None]: + """ + Create a list of stage indices to send gradients to. + """ + grad_send_info: list[int | None] = [] + + def map_recv_to_send(a): + # Note: we send gradients back to previous stage as long as in + # forward it is a received input, regardless of whether it requires + # grad. It is up to the previous stage to disgard this gradient. + if isinstance(a, _RecvInfo): + grad_send_info.append(a.source) + return a.source + else: + grad_send_info.append(None) + return None + + map_structure(map_recv_to_send, args_recv_info) + + logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info) + return grad_send_info + + @abstractmethod + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ) -> tuple[Any, ...]: + raise NotImplementedError + + def _prepare_backward_infra(self, num_microbatches: int) -> tuple[Any, ...]: + raise NotImplementedError + + @abstractmethod + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + raise NotImplementedError + + def _get_recv_ops( + self, + recv_infos: tuple[InputInfo, ...], + ) -> list[dist.P2POp]: + """ + Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`. + Returns a list of ops that correspond to the recv infos. + """ + ops: list[dist.P2POp] = [] + for info in recv_infos: + if not isinstance(info, _RecvInfo): + continue + + peer_rank = self.stage_index_to_group_rank[info.source] + peer_global_rank = ( + peer_rank + if self.group is None + else self.group.get_global_rank(peer_rank) + ) + ops.append( + dist.P2POp( + dist.irecv, info.buffer, peer_global_rank, self.group + ) + ) + + return ops + + """[Note: V-schedule special case] + + V-Schedules have a special case where 2 stages with adjacent stage_id are on the same rank. + + ex: 2 ranks, 4 stages forms a simple V: + rank0: stage 0 stage 3 + rank1: stage 1 stage 2 + + stage 0,1 and 2,3 communicate activations using send/recv as usual, but stage 1,2 do not need to + use communication ops. Instead, they should pass tensor data directly via function call. + + set_local_fwd_input and (get_local_bwd_output + set_local_bwd_input) facilitate this optimization, and + should be called at the appropriate time during the pipeline schedule (after forward or backward execution). + """ + + def set_local_fwd_input( + self, prev_stage_outputs: Any, mb_index: int + ) -> None: + """ + Moves 'prev_stage_outputs' from another stage on the same rank into place as inputs for this stage. Avoids + copying tensor data or using send/recv op. Detaches original tensor and sets stop_gradient so the + tensor can serve as a leaf for autograd and gradients can be collected from it during backward. + """ + recv_infos: tuple[InputInfo, ...] 
= self.args_recv_info[mb_index] + + # See [Note: pipeline model output type] + prev_stage_outputs = _normalize_model_output_as_tuple( + prev_stage_outputs + ) + + for info, tensor in zip(recv_infos, prev_stage_outputs): + assert isinstance( + tensor, paddle.Tensor + ), f"expected tensor values as outputs from prev stage, got {type(tensor)}" + assert isinstance( + info, _RecvInfo + ), "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + + info.buffer = _detach_and_requires_grad(tensor) + + def get_local_bwd_output(self, mb_index): + """ + Returns the input grad tensors for this stage, which correspond to the stage inputs during forward. + """ + assert ( + self.has_backward + ), "can't steal_bwd_input if this stage doesn't have backward" + assert not self.is_first, "can't get bwd output if this stage is first" + + self._check_chunk_id(mb_index) + return self.bwd_cache.pop(mb_index) + + def set_local_bwd_input( + self, + next_stage_bwd_outputs: tuple[paddle.Tensor | None, ...], + mb_index: int, + ) -> None: + """ + Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. + Does not detach or set 'stop_gradient'. + """ + assert isinstance( + next_stage_bwd_outputs, tuple + ), f"Expected tuple, got {type(next_stage_bwd_outputs)}" + + assert ( + self.has_backward + ), "can't set bwd input if this stage doesn't have backward" + assert not self.is_last, "can't set bwd input if this stage is last" + recv_infos = self.grad_recv_info[mb_index] + for info, tensor in zip(recv_infos, next_stage_bwd_outputs): + assert isinstance( + tensor, paddle.Tensor + ), f"expected tensor values as outputs from prev stage, got {type(tensor)}" + assert isinstance( + info, _RecvInfo + ), f"Expected a recv info, got {type(info)}" + info.buffer = tensor + + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the input arguments + for this stage. + """ + recv_infos: tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] + + return self._get_recv_ops(recv_infos) + + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the gradients + for this stage. + """ + if not self.has_backward or self.is_last: + return [] + + recv_infos = self.grad_recv_info[bwd_chunk_id] + return self._get_recv_ops(recv_infos) + + def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + """ + Get the activation send ops for current stage's forward. 
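        A schedule typically batches these ops together with the matching recv ops.
        An illustrative sketch of one microbatch's forward, mirroring the unit-test
        driver (``stage`` and ``data`` are assumed, and the forward infra and global
        mesh must already be prepared as in the test):

            import paddle.distributed as dist

            recv_ops = stage.get_fwd_recv_ops(0)          # empty on the first stage
            if recv_ops:
                dist.batch_isend_irecv(recv_ops).pop().wait()

            out = stage.forward_one_chunk(0, (data,) if stage.is_first else (), {})

            send_ops = stage.get_fwd_send_ops(0)          # empty on the last stage
            if send_ops:
                for work in dist.batch_isend_irecv(send_ops):
                    work.wait()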
+ """ + output = self.output_chunks[fwd_chunk_id] + # Unify output form to tuple for easy correspondence with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + + ops: list[dist.P2POp] = [] + + for idx, out in enumerate(output_tuple): + dst_stages = self.act_send_info[idx] + for dst in dst_stages: + if dst is None: + continue + logger.debug( + "%s Sending tensor to Stage %s: %s", + self.log_prefix, + dst, + out.size, + ) + + peer_rank = self.stage_index_to_group_rank[dst] + peer_global_rank = ( + peer_rank + if self.group is None + else self.group.get_global_rank(peer_rank) + ) + ops.append( + dist.P2POp( + dist.isend, + ( + out + if not out.is_dist() + else dtensor_to_local( + out, + out.process_mesh, + self.grads_meta[idx].placements, + ) + ), + peer_global_rank, + self.group, + ) + ) + + return ops + + def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + """ + Get the gradient send ops for current stage's backward. + """ + self._check_chunk_id(bwd_chunk_id) + + if not self.has_backward or self.is_first: + return [] + + # Create bwd send infra lazily + if self.grad_send_info is None: + # Send info for input grads during backward: + # list of destinations corresponding to input grads + # Can be None if an input has no grad + # `grad_send_info` is a mirror of `args_recv_info` + self.grad_send_info = self._create_grad_send_info( + self.args_recv_info[0] + ) + + ops: list[dist.P2POp] = [] + grads_input = self.bwd_cache.pop(bwd_chunk_id) + for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): + if isinstance(grad, paddle.Tensor) and grad_recv_stage is not None: + logger.debug( + "%s Sending gradient to Stage %s: %s", + self.log_prefix, + grad_recv_stage, + grad.size, + ) + peer_rank = self.stage_index_to_group_rank[grad_recv_stage] + peer_global_rank = ( + peer_rank + if self.group is None + else self.group.get_global_rank(peer_rank) + ) + ops.append( + dist.P2POp( + dist.isend, + ( + grad + if not grad.is_dist() + else dtensor_to_local( + grad, grad.process_mesh, grad.placements + ) + ), + peer_global_rank, + self.group, + ) + ) + else: + if not (grad is None and grad_recv_stage is None): + raise RuntimeError( + f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} " + f"and is expecting to send gradients to stage {grad_recv_stage}" + ) + return ops + + def clear_runtime_states(self) -> None: + """ + Clear runtime states of the stage. + """ + # map microbatch ID to list of forward tensor args + self.fwd_cache.clear() + # Caching chunk outputs for final output merge or reduction + self.output_chunks.clear() + + # Clear grad of input buffers in between schedule steps. This is because + # `paddle.autograd.backward()` will accumulate gradients into leaf + # tensors by default. For gradients to pass back to previous stages, we + # don't want such accumulation. + for ( + recv_tuple + ) in self.args_recv_info.values(): # iterate over all chunks + for a in recv_tuple: # iterate over all input args + if isinstance(a, _RecvInfo): + a.buffer.clear_grad() + + def _map_tensor_from_recv_info( + self, + recv_infos: tuple[InputInfo, ...], + ): + """ + Map tensors from recv infos to a list. 
+ """ + + def get_recv_tensor(info): + if isinstance(info, _RecvInfo): + return info.buffer + else: + raise AssertionError(f"Expected _RecvInfo but got {type(info)}") + + tensors = map_structure( + get_recv_tensor, + recv_infos, # type: ignore[arg-type] + ) + + return tensors + + def _retrieve_recv_activations(self, fwd_chunk_id: int): + """ + Retrieve the activations received for the current stage during forward. + """ + recv_infos = self.args_recv_info[fwd_chunk_id] + activations = self._map_tensor_from_recv_info(recv_infos) + return activations + + def _retrieve_recv_grads( + self, + bwd_chunk_id: int, + ): + """ + Retrieve the gradients received for the current stage during backward. + """ + recv_infos = self.grad_recv_info[bwd_chunk_id] + grads = self._map_tensor_from_recv_info(recv_infos) + return grads + + def forward_maybe_with_nosync(self, *args, **kwargs): + curr_mesh = _get_stage_mesh(self.stage_index, self.group_size) + + def restore_placements_info(args, infos): + if isinstance(args, paddle.Tensor) and infos.placements is not None: + # set the placements attribute of the Tensor + args = dtensor_from_local(args, curr_mesh, infos.placements) + return args + elif isinstance(args, (list, tuple)): + # if args is list or tuple, handle each element recursively + return type(args)( + restore_placements_info(a, i) for a, i in zip(args, infos) + ) + elif isinstance(args, dict): + # if args is dict, recursively handle each key-value pair + return { + key: restore_placements_info(args[key], infos[key]) + for key in args + } + else: + # return directly + return args + + args = restore_placements_info(args, self.inputs_meta) + + out_val = self.sublayer(*args, **kwargs) + + return out_val + + def backward_maybe_with_nosync( + self, backward_type, bwd_kwargs: dict, last_backward=False + ) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any] | None]]: + """ + PP 与 DP 混用时,在每个batch的最后一个microbatch的反向开始时,此时的一些行为可能会有所差异,此时可能需要注意。 + """ + + def perform_backward( + backward_type, + ) -> Callable[ + [], + tuple[ + tuple[paddle.Tensor | None, ...], + list[dict[str, Any] | None], + ], + ]: + if backward_type == "full": + return lambda: ( + stage_backward( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + ), + None, + ) + elif backward_type == "input": + raise NotImplementedError( + "Input based backward is not implemented yet." + ) + elif backward_type == "weight": + raise NotImplementedError( + "Weight based backward is not implemented yet." + ) + else: + raise RuntimeError(f"Unknown backward type: {backward_type}") + + result = perform_backward(backward_type)() + grads, param_groups = result + return grads, param_groups + + def forward_one_chunk( + self, + fwd_chunk_id: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ): + """ + Perform forward pass on the stage with one microbatch. + `args` and `kwargs` are the inputs from *external* to this stage. + - `args` applies to the first stage only, other stages receives args + through activation transmission. + - `kwargs` can be passed to all stages via respective `step` calls. 
+ """ + + if self.is_first: + # First stage doesn't need to receive anything + composite_args = args + else: + # Receive activations for this chunk + # Activations only come in args form + composite_args = self._retrieve_recv_activations(fwd_chunk_id) + + composite_kwargs = kwargs or {} + + self._validate_fwd_input(args, kwargs) + + # Compute forward + try: + output = self.forward_maybe_with_nosync( + *composite_args, **composite_kwargs + ) + + except Exception as e: + exc_msg = f""" + {self.log_prefix} failed to run forward: + args: {_map_debug_info(composite_args)} + kwargs: {_map_debug_info(composite_kwargs)} + """ + raise RuntimeError(exc_msg) from e + + # See [Note: pipeline model output type] + output_tuple = _normalize_model_output_as_tuple(output) + + # Prepare for final output merge or reduction + self.output_chunks.append(output) + + # Save activations and inputs for backward + flat_args = _flatten_args(composite_args) + flat_kwargs = _flatten_args(composite_kwargs) + flatten_input_tensors = flat_args + flat_kwargs + self.fwd_cache[fwd_chunk_id] = ( + output_tuple, # stage_output + flatten_input_tensors, # input_values + ) + + logger.debug( + "%s Forwarded chunk %s, outputs: %s", + self.log_prefix, + fwd_chunk_id, + _map_debug_info(output), + ) + self._validate_fwd_outputs(output_tuple) + + # We return the original user-provided output, not normalized to tuple. + # See [Note: pipeline model output type] + return output + + def backward_one_chunk( + self, + bwd_chunk_id: int, + loss=None, + full_backward: bool = True, + last_backward=False, + ): + """ + Perform backward pass on the module. + This should only be called once per microbatch. + + If full_backward is True (the default), the full backward pass including weight and input gradients will be run, + and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id. + + If full_backward is False, it is optional that `dw_runner` was provided to the PipelineStage at __init__ time, + and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete the backward. + + last_backward is controlled by the schedule and signals synchronization of gradients across DP groups + after the last backward. + """ + self._check_chunk_id(bwd_chunk_id) + + ( + stage_output, + input_values, + ) = self.fwd_cache.pop(bwd_chunk_id) + + # Compute backward + if self.is_last: + # Last stage computes gradients from loss and has no gradients from + # next stage + bwd_kwargs = { + "stage_output": loss, + "output_grads": None, + "input_values": input_values, + } + else: + # Otherwise, receive gradients from next stage + grads_output = self._retrieve_recv_grads(bwd_chunk_id) + # If an input to the pipeline requires gradient, + # `paddle.autograd.backward` will accumulate the gradient into the + # `.grad` field of such input + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": grads_output, + "input_values": input_values, + } + + grads_input: tuple[paddle.Tensor | None, ...] = () + + if full_backward: + grads_input, _ = self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + else: + raise NotImplementedError( + "Input based backward is not implemented yet." 
+ ) + + self.bwd_cache[bwd_chunk_id] = grads_input + if self.is_last and not self.is_first: + # Autograd dependencies: + # rest_of_autograd_graph -> stage_output -> loss + # stage_output is no longer used in the last stage for backward and only needed + # to return to the user in merge_output_chunks, therefore + # this should be detached to release autograd graph context and free memory earlier + for t in stage_output: + t.detach_() + + logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) + return grads_input + + def backward_weight_one_chunk(self, bwd_chunk_id: int, last_backward=False): + raise NotImplementedError( + "Weight based backward is not implemented yet." + ) + + def _validate_fwd_input(self, args, kwargs): + """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" + + if self.is_first: + expected_args = self.args_recv_info[0] + else: + return + + if len(kwargs): + return + + expected_tensors_meta = [ + e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer + for e in expected_args + ] + _validate_tensors_metadata( + f"Stage {self.stage_index} forward inputs", + expected_tensors_meta, + args, + ) + + def _validate_fwd_outputs(self, outputs: tuple[paddle.Tensor, ...]): + """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype. + Most likely, this could be cause either by incorrect user specification of output shapes, or because + shape inference was done on the original model but then at runtime the model is wrapped with something like + mixed precision which changes output dtype. + """ + expected_tensors_meta = self.get_outputs_meta() + _validate_tensors_metadata( + f"Stage {self.stage_index} forward outputs", + expected_tensors_meta, + outputs, + ) + + +class PipelineStage(_PipelineStageBase): + """ + A class representing a pipeline stage in a pipeline parallelism setup. + + PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from + one chunk feed into inputs of the next chunk, with no skip connections. + + PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to + stage1 and so forth, in linear order. To bypass shape inference, pass the `input_args` and `output_args` to each + PipelineStage instance. + + Args: + layer (nn.Layer): The Pypaddle module wrapped by this stage. + stage_index (int): The ID of this stage. + num_stages (int): The total number of stages. + input_args (TensorMeta|tuple[TensorMeta, ...]|None): The input arguments for the layer. + output_args (TensorMeta|tuple[TensorMeta, ...]|None): The output arguments for the layer. + group (Group, None): The process group for distributed training. If None, default group. + """ + + def __init__( + self, + layer: nn.Layer, + stage_index: int, + num_stages: int, + input_args: TensorMeta | tuple[TensorMeta, ...] | None = None, + output_args: TensorMeta | tuple[TensorMeta, ...] | None = None, + group: Group | None = None, + ): + super().__init__(layer, stage_index, num_stages, group) + self.inputs: list[paddle.Tensor] | None = None + self.inputs_meta: tuple[TensorMeta, ...] | None = None + # output's grad meta-info + self.grads_meta: tuple[TensorMeta, ...] | None = None + + if input_args is None: + assert output_args is None, ( + "If specifying output_args, input_args must also be specified. 
" + "Otherwise, shape inference will be performed at runtime" + ) + else: + self.inputs_meta = ( + (input_args,) + if isinstance(input_args, TensorMeta) + else input_args + ) + + assert ( + output_args is not None + ), "If passing input_args, also pass output_args to override shape inference" + self._configure_outputs_meta( + (output_args,) + if isinstance(output_args, TensorMeta) + else output_args + ) + + # these are the buffers used in backwards send/recv, they are allocated later + self.outputs_grad: list[paddle.Tensor] = [] + + def stage_global_rank(peer_rank): + return ( + peer_rank + if self.group is None + else group.get_global_rank(peer_rank) + ) + + self.prev_rank = stage_global_rank( + (self.group_rank - 1) % self.group_size + ) + self.next_rank = stage_global_rank( + (self.group_rank + 1) % self.group_size + ) + + dbg_str = ( + f"Finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " + f"{self.is_last=}, {self.num_stages=}, " + ) + if self.inputs_meta is not None: + dbg_str += ( + f"inputs: {[inp.shape for inp in self.inputs_meta]}, " + f"output: {[output.shape for output in self.get_outputs_meta()]}" + ) + else: + dbg_str += " running shape-inference at runtime" + + logger.debug(dbg_str) + + def _shape_inference( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ): + if kwargs is None: + kwargs = {} + assert args is not None, "Args may be an empty tuple but not None" + + # We skip recv communication if we're the first stage, but also if the previous stage is on the same rank + # and can pass its output shapes in as args instead of using send/recv. + if ( + self.is_first + # if not first stage, then check if prev stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index - 1] + == self.group_rank + ): + logger.debug( + "Shape inference: stage %s skipping recv, because shape info passed in via `args`", + self.stage_index, + ) + args = _map_structure_only( + paddle.Tensor, + lambda x: TensorMeta(x), + args, + ) + else: + assert ( + len(args) == 0 + ), "Can't supply input args for shape inference on non-first stage" + objects = [None] + logger.debug( + "Shape inference: stage %s receiving from stage %s", + self.stage_index, + self.stage_index - 1, + ) + dist.recv_object_list(objects, src=self.prev_rank, group=self.group) + recv_args = objects[0] + assert isinstance(recv_args, tuple), type(recv_args) + args = recv_args + + # cache input shapes for use during recv buffer allocation + self.inputs_meta = args + # zero-initialise tensors only for inference outputs + zero_initialize_with_meta_ = partial( + _zero_initialize_with_meta, + mesh=_get_stage_mesh(self.stage_index, self.group_size), + ) + args = _map_structure_only( + TensorMeta, + zero_initialize_with_meta_, + args, + ) + + # set attributes needed for forward + with ( + paddle.no_grad() if not self.has_backward else paddle.enable_grad() + ): + logger.debug( + "Shape inference: stage %s running forward", self.stage_index + ) + if self.has_backward: + + def requires_grad(x): + x.stop_gradient = False + return x + + args = _map_structure_only(paddle.Tensor, requires_grad, args) + + outputs = self.sublayer(*args, **kwargs) + if self.has_backward: + flatten_input_tensors = _flatten_args(args) + _flatten_args( + kwargs + ) + self.fwd_cache[0] = ( + _normalize_model_output_as_tuple(outputs), # stage_output + flatten_input_tensors, # input_values + ) + + # if single tensor, convert so it is always a list + if isinstance(outputs, paddle.Tensor): + outputs = [outputs] + + # 
communicate meta outputs not real outputs for two reasons + # 1 - its faster (esp. since obj coll pickles tensor data!) + # 2 - avoid activating a cuda context for the src rank when unpickling on the recv end! + outputs_meta = tuple( + _map_structure_only(paddle.Tensor, lambda x: TensorMeta(x), outputs) + ) + self._configure_outputs_meta(outputs_meta) + + # Passing outputs to the next stage: + # two cases- + # 1. Usually: use send/recv communication to pass the output + # 2. Special case: for V-schedules, 2 'adjacent' stages (e.g. stage 3, 4 in an 8-stage 4-rank V) + # pass their shape info via return value and function args rather than send/recv. + if ( + self.is_last + # if not last stage, then check if next stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index + 1] + == self.group_rank + ): + # Case (2) above: pass shape info via return value and caller passes it as args to next stage's + # _shape_inference call + logger.debug( + "Shape inference: stage %s skipping send to next stage", + self.stage_index, + ) + + else: + # Case (1): send shapes via send operation, and ensure not to return it to the caller + logger.debug( + "Shape inference: stage %s sending to stage %s", + self.stage_index, + self.stage_index + 1, + ) + dist.send_object_list( + [outputs_meta], + dst=self.next_rank, + group=self.group, + ) + outputs_meta = () + + return outputs_meta + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ) -> tuple[Any, ...]: + + assert num_microbatches is not None, "num_microbatches must be provided" + + outputs: tuple[Any, ...] = () + if self.inputs_meta is None: + outputs = self._shape_inference(args, kwargs) + + assert self.inputs_meta is not None + + for chunk_id in range(num_microbatches): + if not self.is_first: + # We assume that we always receive from stage - 1 + recv_infos = tuple( + [ + _RecvInfo( + f"recv_for_{self.stage_index}_from_{self.stage_index - 1}", + self.stage_index - 1, + _make_tensor_from_meta(inp), + ) + for inp in self.inputs_meta + ] + ) + # In case there is backward pass, set stop_gradient for receive buffers + if self.has_backward: + for r in recv_infos: + r.buffer.stop_gradient = False + + self.args_recv_info[chunk_id] = recv_infos + else: + self.args_recv_info[chunk_id] = tuple( + [_RootArgPlaceholder(i) for i in self.inputs_meta] + ) + + # Send info during forward for each activation + # only need the rank that is being sent to + self.act_send_info: dict[int, list] = {} + + for idx in range(len(self.get_outputs_meta())): + # We assume we always send to stage + 1 + if not self.is_last: + self.act_send_info[idx] = [self.stage_index + 1] + else: + self.act_send_info[idx] = [] + + return outputs + + def _shape_inference_bwd( + self, + ): + assert self.fwd_cache is not None + stage_output, input_values = self.fwd_cache.pop(0) + if ( + self.is_last + or self.stage_index_to_group_rank[self.stage_index + 1] + == self.group_rank + ): + logger.debug( + "Shape inference: stage %s skipping recv, because shape info passed in via `grads`", + self.stage_index, + ) + grads = (None,) + + else: + objects = [None] + logger.debug( + "Shape inference: stage %s receiving from stage %s", + self.stage_index, + self.stage_index + 1, + ) + dist.recv_object_list(objects, src=self.next_rank, group=self.group) + recv_grads = objects[0] + assert isinstance(recv_grads, tuple), type(recv_grads) + grads = recv_grads + + self.grads_meta = grads + + # zero-initialize tensors only for 
inference backward meta-info + zero_initialize_with_meta_ = partial( + _zero_initialize_with_meta, + mesh=_get_stage_mesh(self.stage_index, self.group_size), + ) + grads = _map_structure_only( + TensorMeta, zero_initialize_with_meta_, grads + ) + + paddle.autograd.backward(stage_output, grads, True) + + # output is the grad meta for input_values(list) + output_meta = tuple( + map_structure(lambda x: TensorMeta(x.grad), input_values) + ) + + if ( + self.is_first + # if not last stage, then check if previous stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index - 1] + == self.group_rank + ): + logger.debug( + "Shape inference: stage %s skipping send to previous stage", + self.stage_index, + ) + else: + logger.debug( + "Shape inference: stage %s sending to stage %s", + self.stage_index, + self.stage_index - 1, + ) + dist.send_object_list( + [output_meta], dst=self.prev_rank, group=self.group + ) + + def _prepare_backward_infra(self, num_microbatches: int) -> tuple[Any, ...]: + assert self.has_backward is not None + + self.chunks = num_microbatches + + for mb_index in range(num_microbatches): + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info[mb_index] = self._create_grad_recv_info( + self.act_send_info + ) + grads: tuple[Any, ...] = () + if self.grads_meta is None: + self._shape_inference_bwd() + + assert self.grads_meta is not None + # clear backward_state + self.clear_runtime_states() + return grads + + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + grad_recv_info: tuple[_RecvInfo, ...] = () + if not self.is_last: + # Receiving gradients from multiple sources is not supported + # hence we only take the first destination + grad_recv_info = tuple( + [ + _RecvInfo( + f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}", + dst_list[0], + _make_tensor_from_meta(self.get_outputs_meta()[idx]), + ) + for idx, dst_list in act_send_info.items() + ] + ) + return grad_recv_info diff --git a/python/paddle/distributed/auto_parallel/pipelining/utils.py b/python/paddle/distributed/auto_parallel/pipelining/utils.py new file mode 100644 index 0000000000000..76f05036e240a --- /dev/null +++ b/python/paddle/distributed/auto_parallel/pipelining/utils.py @@ -0,0 +1,163 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
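The `PipelineStage` defined above can skip its runtime shape inference when the caller provides `TensorMeta` descriptions (defined in this `utils.py`) of the stage's inputs and outputs. A minimal construction-only sketch, assuming the script is launched on two ranks with one `nn.Linear` sub-layer per stage:

```python
import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.distributed.auto_parallel.pipelining.stage import PipelineStage
from paddle.distributed.auto_parallel.pipelining.utils import TensorMeta

dist.init_parallel_env()
group = dist.new_group([0, 1])

layer = nn.Linear(8, 8, bias_attr=False)
meta = TensorMeta(paddle.empty([1, 8], dtype='float32'))

# Supplying both input_args and output_args bypasses the send/recv-based
# shape inference performed in _shape_inference().
stage = PipelineStage(
    layer,
    stage_index=dist.get_rank(),
    num_stages=2,
    input_args=meta,
    output_args=meta,
    group=group,
)
```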
+ +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Tuple, Union + +import paddle +from paddle.distributed import fleet +from paddle.utils import map_structure + +logger = logging.getLogger(__name__) + + +def _detach_and_requires_grad(x): + o = x.detach() + o.stop_gradient = False + return o + + +def _detach_and_keep_grad(x): + o = x.detach_() + o.stop_gradient = x.stop_gradient + return o + + +def _zero_initialize_with_meta(meta, mesh): + assert isinstance(meta, TensorMeta) + x = paddle.zeros(meta.shape, dtype=meta.dtype) + if meta.placements: + x = paddle.distributed.shard_tensor(x, mesh, meta.placements) + return x + + +def _flatten_args(args): + """ + Flatten the args into a list form. + """ + flat_args = [] + + def extract_tensor_args(a): + nonlocal flat_args + flat_args.append(a) + return a + + paddle.utils.map_structure( + extract_tensor_args, + args, + ) + + return flat_args + + +class PipeliningShapeError(RuntimeError): + """Shape mismatch between configured and runtime values.""" + + +def _validate_tensor_metadata(desc, expected, given): + if not expected.shape == given.shape: + raise PipeliningShapeError( + f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}" + ) + if not expected.dtype == given.dtype: + raise PipeliningShapeError( + f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}" + ) + + +def _validate_tensors_metadata( + desc, + expected_tensors: list[paddle.Tensor] | tuple[paddle.Tensor, ...], + actual_tensors: list[paddle.Tensor] | tuple[paddle.Tensor, ...], +): + if len(expected_tensors) != len(actual_tensors): + raise PipeliningShapeError( + f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})" + ) + for i in range(len(expected_tensors)): + _validate_tensor_metadata( + f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] + ) + + +NestedStruct = Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] + + +def _map_structure_only( + type_: Any, fn: Callable[[Any], Any], structure: NestedStruct +) -> NestedStruct: + """ + Apply `fn` to each entry which matches `type_` in `structure` and return a new structure with the same shape. + """ + return map_structure( + lambda x: fn(x) if isinstance(x, type_) else x, structure + ) + + +class TensorMeta: + def __init__(self, tensor: paddle.Tensor): + self.shape = tensor.shape + self.dtype = tensor.dtype + self.placements = None if not tensor.is_dist() else tensor.placements + + def __repr__(self): + return f"TensorMeta(shape={self.shape}, dtype={self.dtype}, placements={self.placements})" + + +def _get_pp_mesh(pp_idx=0, pp_dim_names="pp"): + """ + Get the mesh of the {pp_idx}th PipelineStage. + """ + mesh = fleet.auto.get_mesh() + assert ( + mesh is not None + ), "the mesh is None, please call fleet.auto.set_mesh first." + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp", pp_idx) + else: + logger.warning( + f"The dim name of pp {pp_dim_names} not exist in global mesh {mesh}" + ) + return mesh + + +def _get_stage_mesh(stage_index, pp_group_size, style=None): + if style == "v": + raise NotImplementedError + if style is not None: + raise ValueError(f"Unknown style: {style}, style can be None, v.") + else: + + pp_idx = stage_index % pp_group_size + return _get_pp_mesh(pp_idx) + + +def _friendly_debug_info(v): + """ + Helper function to print out debug info in a friendly way. 
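    For example (an illustrative sketch; the exact string may differ slightly):

        import paddle

        print(_friendly_debug_info(paddle.ones([2, 3])))
        # -> Tensor([2, 3], stop_gradient=True, dtype=paddle.float32)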
+ """ + if isinstance(v, paddle.Tensor): + return f"Tensor({v.shape}, stop_gradient={v.stop_gradient}, dtype={v.dtype})" + else: + return str(v) + + +def _map_debug_info(a): + """ + Helper function to apply `friendly_debug_info` to items in `a`. + `a` may be a list, tuple, or dict. + """ + return map_structure(_friendly_debug_info, a) diff --git a/python/paddle/utils/layers_utils.py b/python/paddle/utils/layers_utils.py index 0d2b6098648d7..767f4afe80270 100644 --- a/python/paddle/utils/layers_utils.py +++ b/python/paddle/utils/layers_utils.py @@ -219,7 +219,12 @@ def _packed_nest_with_indices(structure, flat, index): packed.append(_sequence_like(s, child)) index = new_index else: - packed.append(flat[index]) + # Paddle requires python version > 3.7, so dict is always OrderedDict + packed.append( + flat[index] + if not isinstance(flat, dict) + else list(flat.values())[index] + ) index += 1 return index, packed diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 1e8ca3a3b142a..b43ca38e2c740 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -180,6 +180,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) test_dtensor_from_local_api) py_test_modules(test_dy_local_view_compute MODULES test_dy_local_view_compute) py_test_modules(test_local_view_compute MODULES test_local_view_compute) + py_test_modules(test_PipelineStage MODULES test_PipelineStage) + py_test_modules(test_microbatch MODULES test_microbatch) # End of unittests WITH single card WITHOUT timeout endif() diff --git a/test/auto_parallel/PipelineStage_demo.py b/test/auto_parallel/PipelineStage_demo.py new file mode 100644 index 0000000000000..e4ce0af03ed20 --- /dev/null +++ b/test/auto_parallel/PipelineStage_demo.py @@ -0,0 +1,487 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import logging +import random +from collections import defaultdict +from typing import TYPE_CHECKING + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.pipelining._backward import ( + stage_backward, + stage_backward_input, + stage_backward_weight, +) +from paddle.distributed.auto_parallel.pipelining.stage import ( + PipelineStage, + _RecvInfo, +) +from paddle.distributed.auto_parallel.pipelining.utils import ( + PipeliningShapeError, + _detach_and_keep_grad, + _get_stage_mesh, + _validate_tensor_metadata, + _validate_tensors_metadata, +) +from paddle.io import Dataset + +if TYPE_CHECKING: # 添加类型检查块 + from paddle.distributed.communication.group import Group +logger = logging.getLogger(__name__) + + +def fix_seeds(seed=2025): + """Fix random seeds to ensure reproducibility""" + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + + +def _batch_p2p(p2p_ops, desc=None): + # TODO(zhengtianyu): 等合入Scheduler后,删除该函数 + """Execute batch point-to-point communication operations""" + if len(p2p_ops) == 0: + return None + desc_str = f"{desc}, " if desc else "" + logger.debug("batch_p2p %s%s", desc_str, p2p_ops) + return dist.batch_isend_irecv(p2p_ops).pop() + + +def _sorted_batch_p2p(p2p_ops, desc=None): + # TODO(zhengtianyu): 等合入Scheduler后,删除该函数 + """Sort and execute batch point-to-point communication by peer rank""" + ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list) + work_by_peer: dict[int, dist.Work] = {} + if len(p2p_ops) == 0: + return work_by_peer + + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = _batch_p2p(ops, desc=desc) + + return work_by_peer + + +class MyModel(nn.Layer): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(8, 8, bias_attr=False) + self.linear2 = nn.Linear(8, 8, bias_attr=False) + self.linear3 = nn.Linear(8, 8, bias_attr=False) + self.linear4 = nn.Linear(8, 8, bias_attr=False) + + def forward(self, x): + if hasattr(self, 'linear1'): + x = self.linear1(x) + x = self.linear2(x) + if hasattr(self, 'linear3'): + x = self.linear3(x) + x = self.linear4(x) + return x + + +class PPMyModel(nn.Layer): + def __init__(self): + super().__init__() + self.mesh = paddle.distributed.ProcessMesh([0, 1], dim_names=["pp"]) + self.num_layers = 4 + self.num_layers_per_card = self.num_layers // 2 + + # Create layers same as MyModel + self.linears = nn.LayerList() + for i in range(self.num_layers): + linear = nn.Linear(8, 8, bias_attr=False) + + # Mark network parameters + linear.weight = dist.shard_tensor( + linear.weight, + self.get_pp_mesh(i), + [dist.Replicate()], + ) + + self.linears.append(linear) + + def get_pp_mesh(self, layer_index): + # layer_index=0-3 corresponds to mesh_idx 0,0,1,1 respectively + mesh_idx = int(layer_index / (self.num_layers / 2)) + return self.mesh[mesh_idx] + + def forward(self, x): + x.stop_gradient = False + out = x + + for i in range(self.num_layers): + # Mark intermediate variables, reshard when device switching is needed + if i % self.num_layers_per_card == 0 and i > 0: + out = dist.reshard(out, self.get_pp_mesh(i), [dist.Replicate()]) + + out = self.linears[i](out) + + return paddle.cast(out, 'float32') + + +class RandomDataset(Dataset): + def __init__(self, image_size, num_samples=1): + super().__init__() + self.image_size = image_size + self.num_samples = num_samples + + def 
__getitem__(self, index): + # Keep dimension as [8] + input = paddle.rand([self.image_size], dtype='float32') + label = paddle.rand([8], dtype='float32') + return input, label + + def __len__(self): + return self.num_samples + + +def manual_model_split( + model: MyModel, stage_idx: int, group: Group +) -> PipelineStage: + """Manually split model into pipeline stages""" + if stage_idx == 0: + del model.linear3 + del model.linear4 + elif stage_idx == 1: + del model.linear1 + del model.linear2 + else: + raise ValueError("Invalid stage index.") + + return PipelineStage(model, stage_idx, 2, group=group) + + +class TestPipelineStage: + @classmethod + def setUpClass(cls): + """Initialize test class setup""" + paddle.distributed.init_parallel_env() + cls.group = paddle.distributed.new_group([0, 1]) + cls.rank = dist.get_rank() + cls.mesh = paddle.distributed.ProcessMesh([0, 1], dim_names=["pp"]) + fleet.auto.set_mesh(cls.mesh) + + def test_PipelineStage(self): + """Test complete pipeline including forward, backward and model comparison""" + fix_seeds() + self.model = MyModel() + self.micro_batches = 1 # The PipelineStage component is currently tested separately, so it is set to 1, and the micro_batches > 1 scenario will be overridden when the schedule component is tested in the future + self.stage = manual_model_split(self.model, self.rank, self.group) + self.stage.has_backward = True + opt = paddle.optimizer.AdamW( + learning_rate=0.001, parameters=self.model.parameters() + ) + loss_fn = nn.MSELoss() + dataset = RandomDataset(image_size=8, num_samples=100) + + losses = [] + num_iterations = 20 + + for iter_idx in range(num_iterations): + data, label = dataset[iter_idx] + data = paddle.to_tensor(data).unsqueeze(0) + label = paddle.to_tensor(label).unsqueeze(0) + + # Prepare infrastructure + if self.rank == 0: + self.stage._prepare_forward_infra(self.micro_batches, (data,)) + else: + self.stage._prepare_forward_infra(self.micro_batches, ()) + self.stage._prepare_backward_infra(self.micro_batches) + + # Forward pass + fwd_sends_to_wait = [] + + # Receive operations + ops = self.stage.get_fwd_recv_ops(0) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + work.wait() + + # Forward computation + output = self.stage.forward_one_chunk(0, (data,), ()) + # Send operations + ops = self.stage.get_fwd_send_ops(0) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + # Wait for all send operations to complete + for work in fwd_sends_to_wait: + work.wait() + + # Calculate loss if last stage + loss = None + if self.stage.is_last: + loss = loss_fn(output, label) + assert loss is not None + losses.append(loss.item()) + + # Backward pass + bwd_sends_to_wait = [] + + # Receive gradients + ops = self.stage.get_bwd_recv_ops(0) + works = _sorted_batch_p2p(ops, desc="bwd_recv") + for work in works.values(): + work.wait() + + # Backward computation + grads = self.stage.backward_one_chunk( + 0, loss=loss, last_backward=True + ) + assert grads is not None + + # Send gradients + ops = self.stage.get_bwd_send_ops(0) + works = _sorted_batch_p2p(ops, desc="bwd_send") + bwd_sends_to_wait.extend(works.values()) + + # Wait for all send operations to complete + for work in bwd_sends_to_wait: + work.wait() + + opt.step() + opt.clear_grad() + + return losses + + def test_pp_model(self): + """Test pipeline parallel model using MyModel""" + fix_seeds() + + pp_model = PPMyModel() + opt = paddle.optimizer.AdamW( + learning_rate=0.001, 
parameters=pp_model.parameters() + ) + loss_fn = nn.MSELoss() + + dataset = RandomDataset(image_size=8, num_samples=100) + + pp_losses = [] + num_iterations = 20 + + for iter_idx in range(num_iterations): + data, label = dataset[iter_idx] + data = paddle.to_tensor(data).unsqueeze(0) + label = paddle.to_tensor(label).unsqueeze(0) + + output = pp_model(data) + + loss = loss_fn(output, label) + pp_losses.append(loss.item()) + + loss.backward() + opt.step() + opt.clear_grad() + + return pp_losses + + def test_single_gpu(self): + """Test single GPU training with the complete model""" + # Only run single GPU training on rank 1 + if self.rank == 1: + fix_seeds() + single_model = MyModel() + opt = paddle.optimizer.AdamW( + learning_rate=0.001, parameters=single_model.parameters() + ) + loss_fn = nn.MSELoss() + + dataset = RandomDataset(image_size=8, num_samples=100) + + losses = [] + num_iterations = 20 + + for iter_idx in range(num_iterations): + data, label = dataset[iter_idx] + output = single_model(data) + + loss = loss_fn(output, label) + losses.append(loss.item()) + loss.backward() + + opt.step() + opt.clear_grad() + + return losses + return None + + def test_simple_func_about_schedules(self): + """Test local data transfer functions between stages on the same rank""" + if self.rank == 0: + # 1. Test set_local_fwd_input + tensor = paddle.to_tensor([1.0, 2.0, 3.0]) + stage = PipelineStage(nn.Linear(3, 3), 1, 2, group=self.group) + stage.args_recv_info[0] = (_RecvInfo("test", 0, paddle.empty([3])),) + stage.set_local_fwd_input(tensor, 0) + assert stage.args_recv_info[0][0].buffer is not None + + # 2. Test get_local_bwd_output + stage.has_backward = True + grad_tensor = paddle.to_tensor([4.0, 5.0, 6.0]) + stage.bwd_cache[0] = (grad_tensor,) + stage.chunks = 2 + bwd_output = stage.get_local_bwd_output(0) + assert bwd_output[0].equal_all(grad_tensor) + + # 3. Test set_local_bwd_input + prev_stage = PipelineStage(nn.Linear(3, 3), 0, 2, group=self.group) + prev_stage.has_backward = True + prev_stage.grad_recv_info[0] = ( + _RecvInfo("test", 1, paddle.empty([3])), + ) + grad_input = (paddle.to_tensor([7.0, 8.0, 9.0]),) + prev_stage.set_local_bwd_input(grad_input, 0) + assert prev_stage.grad_recv_info[0][0].buffer.equal_all( + grad_input[0] + ) + + def test_backward_some_simple_examples(self): + """Test simple examples in backward""" + if self.rank == 0: + # 1. Test backward propagation with dictionary and tuple outputs + input_tensor = paddle.to_tensor([1.0, 2.0], stop_gradient=False) + + output_dict = { + "out": input_tensor * 2.0, + "out_tensor_is_dict_grad_is_None": {"out": input_tensor * 2.0}, + "out_tensor_is_tuple_grad_is_None": (input_tensor * 2.0,), + } + grad_dict = { + "out": paddle.to_tensor([0.1, 0.2]), + "out_tensor_is_dict_grad_is_None": None, + "out_tensor_is_tuple_grad_is_None": None, + } + + input_grads = stage_backward(output_dict, grad_dict, [input_tensor]) + expected_grad = paddle.to_tensor([2 * 0.1, 2 * 0.2]) + + np.testing.assert_allclose( + input_grads[0].numpy(), expected_grad.numpy(), rtol=1e-5 + ) + # 2. 
Test not yet implemented stage_backward_input and stage_backward_weight + try: + stage_backward_input( + [input_tensor * 2.0], + [paddle.to_tensor([0.1, 0.2])], + [input_tensor], + iter([paddle.to_tensor([1.0, 1.0])]), + ) + raise AssertionError("Should raise Error") + except NotImplementedError as e: + pass + try: + stage_backward_weight( + iter([paddle.to_tensor([1.0, 1.0])]), + [{"params": [paddle.to_tensor([1.0, 1.0])]}], + ) + raise AssertionError("Should raise Error") + except NotImplementedError as e: + pass + + def test_utils_some_simple_examples(self): + """Test simple examples in utils""" + if self.rank == 0: + # 1. Test exceptions in _get_stage_mesh + try: + _get_stage_mesh(0, 2, style="v") + raise AssertionError("Should raise Error") + except NotImplementedError as e: + pass + try: + _get_stage_mesh(0, 2, style="unknown") + raise AssertionError("Should raise Error") + except ValueError as e: + pass + + # 2. Test exceptions in _validate_tensors_metadata + try: + # Length mismatch + expected = [paddle.to_tensor([1.0, 2.0])] + actual = [paddle.to_tensor([1.0]), paddle.to_tensor([2.0])] + _validate_tensors_metadata("test", expected, actual) + raise AssertionError("Should raise Error") + except PipeliningShapeError as e: + pass + + # 3. Test exceptions in _validate_tensor_metadata + try: + # Shape mismatch + expected = paddle.to_tensor([1.0, 2.0]) + actual = paddle.to_tensor([1.0]) + _validate_tensor_metadata("test", expected, actual) + raise AssertionError("Should raise Error") + except PipeliningShapeError as e: + pass + + try: + # Dtype mismatch + expected = paddle.to_tensor([1.0, 2.0], dtype='float32') + actual = paddle.to_tensor([1, 2], dtype='int32') + _validate_tensor_metadata("test", expected, actual) + raise AssertionError("Should raise Error") + except PipeliningShapeError as e: + pass + + # 4. Test _detach_and_keep_grad + a = paddle.to_tensor([2.0], stop_gradient=False) + b = a * 2 + x = _detach_and_keep_grad(b) + assert x is b + assert x.stop_gradient == b.stop_gradient + assert (x.numpy() == b.numpy()).all() + x.stop_gradient = False + z = x * 3 + z.backward() + + assert x.grad is not None + assert a.grad is None + + def run_test(self): + """Compare losses between three training methods""" + self.setUpClass() + self.test_simple_func_about_schedules() + self.test_backward_some_simple_examples() + self.test_utils_some_simple_examples() + # Run three training methods + pipeline_losses = self.test_PipelineStage() + pp_losses = self.test_pp_model() + single_losses = self.test_single_gpu() + + if self.rank == 1: + np.testing.assert_allclose( + pipeline_losses, + pp_losses, + rtol=1e-5, + ) + + np.testing.assert_allclose( + pipeline_losses, + single_losses, + rtol=1e-5, + ) + + +if __name__ == '__main__': + TestPipelineStage().run_test() diff --git a/test/auto_parallel/test_PipelineStage.py b/test/auto_parallel/test_PipelineStage.py new file mode 100644 index 0000000000000..ac0620c04ca66 --- /dev/null +++ b/test/auto_parallel/test_PipelineStage.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import collective.test_communication_api_base as test_base + + +class TestPipelineStage(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=2, timeout=120) + self._default_envs = { + "shape": "(10, 20)", + "dtype": "float32", + "seeds": str(self._seeds), + "shard": "0", + } + self._changeable_envs = { + "backend": ["gpu"], + } + + def test_PipelineStage(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "PipelineStage_demo.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_microbatch.py b/test/auto_parallel/test_microbatch.py new file mode 100644 index 0000000000000..5d1a910c1c4c1 --- /dev/null +++ b/test/auto_parallel/test_microbatch.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +from paddle.distributed.auto_parallel.pipelining.microbatch import ( + TensorChunkSpec, + merge_chunks, + split_args_kwargs_into_chunks, +) + + +class TestMicrobatch(unittest.TestCase): + def setUp(self): + paddle.seed(2024) + self.batch_size = 8 + self.feature_size = 4 + self.tensor = paddle.randn([self.batch_size, self.feature_size]) + + def test_tensor_chunk_spec(self): + # Test creation and string representation of TensorChunkSpec + spec = TensorChunkSpec(0) + self.assertEqual(spec.split_axis, 0) + self.assertEqual(str(spec), "TensorChunkSpec(0)") + self.assertTrue("TensorChunkSpec(0)" in repr(spec)) + + def test_split_args_kwargs(self): + # Test basic parameter splitting + args = (self.tensor,) + kwargs = {"input": self.tensor} + num_chunks = 2 + + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, kwargs, num_chunks + ) + + self.assertEqual(len(args_split), num_chunks) + self.assertEqual(len(kwargs_split), num_chunks) + self.assertEqual( + args_split[0][0].shape[0], self.batch_size // num_chunks + ) + + # Test splitting with non-tensor parameters + args = (self.tensor, 42, "string") + kwargs = {"tensor": self.tensor, "number": 42} + num_chunks = 2 + + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, kwargs, num_chunks + ) + + # Verify non-tensor parameters remain unchanged in each chunk + self.assertEqual(args_split[0][1], 42) + self.assertEqual(args_split[0][2], "string") + self.assertEqual(kwargs_split[0]["number"], 42) + + # Test splitting with custom specification + tensor_2d = paddle.randn([4, 6]) + args = (tensor_2d,) + args_chunk_spec = (TensorChunkSpec(1),) # Split on second dimension + + args_split, _ = split_args_kwargs_into_chunks( + args, None, 2, args_chunk_spec + ) + + self.assertEqual(args_split[0][0].shape[1], 3) + + def test_merge_chunks(self): + # Test merging chunks + chunk1 = paddle.randn([4, 4]) + chunk2 = 
paddle.randn([4, 4]) + chunks = [chunk1, chunk2] + chunk_spec = [TensorChunkSpec(0)] + + merged = merge_chunks(chunks, chunk_spec) + self.assertEqual(merged.shape[0], 8) + + # Test merging chunks containing non-tensor values + chunks = [(paddle.randn([4, 4]), 42)] * 2 + chunk_spec = [TensorChunkSpec(0), None] + + merged = merge_chunks(chunks, chunk_spec) + self.assertEqual(merged[1], 42) + + # Test error cases + with self.assertRaises(ValueError): + # Test error when tensor size is smaller than number of chunks + small_tensor = paddle.randn([1, 4]) + split_args_kwargs_into_chunks((small_tensor,), None, 2) + + with self.assertRaises(AssertionError): + # Test error when parameter count doesn't match chunk_spec length + split_args_kwargs_into_chunks( + (self.tensor,), + None, + 2, + (TensorChunkSpec(0), TensorChunkSpec(1)), + ) + + # test merge empty chunks + empty_chunks = [] + result = merge_chunks(empty_chunks, None) + self.assertEqual(result, []) + + # test tensor size smaller than chunks number + small_tensor = paddle.randn([1, 4]) + with self.assertRaises(ValueError): + split_args_kwargs_into_chunks((small_tensor,), None, 2) + + # test merge non-tensor with tensor spec + chunks = [(42,), (42,)] + chunk_spec = (TensorChunkSpec(0),) + result = merge_chunks(chunks, chunk_spec) + self.assertEqual(result[0], 42) + + def test_nested_structure(self): + # test nested tensor + nested_tensor = [ + [paddle.randn([4, 2]), paddle.randn([4, 2])], + [paddle.randn([4, 2]), paddle.randn([4, 2])], + ] + + args = (nested_tensor,) + kwargs = {"nested": nested_tensor} + + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, kwargs, 2 + ) + + self.assertEqual(len(args_split), 2) + self.assertEqual(len(args_split[0][0]), 2) + self.assertEqual(len(args_split[0][0][0]), 2) + self.assertEqual(args_split[0][0][0][0].shape, [2, 2]) + + self.assertEqual(len(kwargs_split), 2) + self.assertEqual(len(kwargs_split[0]["nested"]), 2) + self.assertEqual(len(kwargs_split[0]["nested"][0]), 2) + self.assertEqual(kwargs_split[0]["nested"][0][0].shape, [2, 2]) + + merged_args = merge_chunks( + args_split, + [ + [TensorChunkSpec(0), TensorChunkSpec(0)], + [TensorChunkSpec(0), TensorChunkSpec(0)], + ], + ) + + self.assertEqual(merged_args[0][0][0].shape, [4, 2]) + self.assertEqual(merged_args[0][1][1].shape, [4, 2]) + + self.assertEqual(len(merged_args[0]), 2) + self.assertEqual(len(merged_args[0][0]), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_jit_save_load.py b/test/legacy_test/test_jit_save_load.py index cfa87feeebaa3..0efbe4aebb2c2 100644 --- a/test/legacy_test/test_jit_save_load.py +++ b/test/legacy_test/test_jit_save_load.py @@ -698,6 +698,15 @@ def dfs(obj1, obj2): dfs(nested_list_copy, nested_list_copy_pack_back) + dict_x = { + "a": paddle.to_tensor([1.0]), + "b": paddle.to_tensor([2.0]), + "c": paddle.to_tensor([3.0]), + } + dict_y = copy.deepcopy(dict_x) + dict_z = paddle.utils.pack_sequence_as(dict_x, dict_y) + dfs(dict_x, dict_z) + class TestSaveLoadWithDictInput(unittest.TestCase):
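For reference, the microbatch helpers exercised by `test_microbatch.py` above split tensors along a chunking axis into per-chunk views and merge them back. A minimal round-trip sketch based on those tests:

```python
import paddle
from paddle.distributed.auto_parallel.pipelining.microbatch import (
    TensorChunkSpec,
    merge_chunks,
    split_args_kwargs_into_chunks,
)

batch = paddle.randn([8, 4])

# Split positional and keyword tensor args into 2 chunks along axis 0 (the default).
args_split, kwargs_split = split_args_kwargs_into_chunks((batch,), {"x": batch}, 2)
assert args_split[0][0].shape[0] == 4

# Merging per-chunk tensors restores the original batch dimension.
merged = merge_chunks([chunk[0] for chunk in args_split], [TensorChunkSpec(0)])
assert merged.shape[0] == 8
```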