from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent version
# of PyTorch. The following 4 lines keep backward compatibility with older
# PyTorch releases.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle
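

# Illustrative async usage of the raw collectives (a sketch; assumes the
# process group is already initialized, and `x` / `pg` are placeholder names):
#   out, handle = all_gather_raw(x, pg, async_op=True)
#   ...                # overlap independent work with the communication
#   handle.wait()      # `handle` is None when async_op=False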


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
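
# Illustrative sequence-parallel usage (a sketch; `x` and `pg` are placeholder
# names): if each rank holds a shard of shape (s, ...), all_gather(x, pg)
# concatenates the shards along dim 0 into shape (world_size * s, ...), and the
# backward pass reduce-scatters the gradient back into per-rank shards.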


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce-scatter the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order
    # on all ranks, as different ranks might have different numbers of parameters
    # (e.g., only rank 0 has bias).
    params_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(params_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be a global rank, not the group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order
    # on all ranks, as different ranks might have different numbers of parameters
    # (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
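

# Illustrative tagging of parameters (a sketch; the attribute names come from the
# getattr checks above, but `model.lm_head.weight`, `model.norm.weight`, and
# `tp_group` are hypothetical):
#   model.lm_head.weight._shared_params = True      # broadcast by sync_shared_params
#   model.norm.weight._sequence_parallel = True     # grad all-reduced after backward
#   sync_shared_params(model, tp_group)
#   loss.backward()
#   allreduce_sequence_parallel_grad(model, tp_group)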


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
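

# Illustrative example: get_dim_for_local_rank(dim=10, world_size=4, local_rank=r)
# yields 3, 3, 2, 2 for r = 0..3 (ranks with local_rank < mod get one extra
# multiple of `multiple_of`).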