Skip to content

torch.distributed.all_reduce works extremely slow with torch_xla #9696

@Arioll

Description

@Arioll

🐛 Bug

all_reduce takes much more time and memory being used with torch_xla model than the one used with torch model. With torch_xla it takes about 20 minutes for one forward-backward with DDP gradient synchronization

To Reproduce

model.py

import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


class RotaryPositionalEmbedding(nn.Module):
    """Rotary Positional Embedding (RoPE)"""

    def __init__(self, dim: int, device: torch.device, max_seq_length: int = 2048):
        super().__init__()
        self.dim = dim
        self.max_seq_length = max_seq_length
        self.device = device

        # Precompute frequencies
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, dim, 2, device=device).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        # Generate positions [0, 1, 2, ..., seq_len-1]
        t = torch.arange(seq_len, device=self.device).type_as(self.inv_freq)

        # Compute frequencies [seq_len, dim//2]
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Create complex exponential form [seq_len, dim//2]
        emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, dim]
        emb = emb[None, :, None, :]  # [1, seq_len, 1, dim]

        # Apply rotary embeddings
        cos = torch.cos(emb)
        sin = torch.sin(emb)

        # Rotate queries/keys
        x_rot = x * cos + self._rotate_half(x) * sin
        return x_rot

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


class MLP(nn.Module):
    """Feed-forward network with GELU activation"""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.gelu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class Attention(nn.Module):
    """Multi-head self-attention with RoPE"""

    def __init__(
        self, hidden_size: int, num_heads: int, head_dim: int, device: torch.device
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device

        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

        self.rotary_emb = RotaryPositionalEmbedding(head_dim, device)

    def forward(
        self,
        x: torch.Tensor,
        batch_size: int,
        seq_len: int,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Project queries, keys, values
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Apply rotary positional embeddings
        q = self.rotary_emb(q, seq_len)
        k = self.rotary_emb(k, seq_len)

        # Transpose for attention computation [batch_size, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply attention mask if provided
        if attention_mask is not None:
            scores = scores + attention_mask

        # Softmax and attention output
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Transpose back and combine heads
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(
            batch_size, seq_len, self.num_heads * self.head_dim
        )

        # Final projection
        return self.o_proj(attn_output)


class TransformerBlock(nn.Module):
    """Single transformer decoder block"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        intermediate_size: int,
        device: torch.device,
    ):
        super().__init__()
        self.self_attn = Attention(hidden_size, num_heads, head_dim, device)
        self.mlp = MLP(hidden_size, intermediate_size)
        self.input_layernorm = nn.RMSNorm(hidden_size)
        self.post_attention_layernorm = nn.RMSNorm(hidden_size)

    def forward(
        self,
        x: torch.Tensor,
        batch_size: int,
        seq_len: int,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self-attention with residual connection
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x, batch_size, seq_len, attention_mask)
        x = residual + x

        # MLP with residual connection
        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = residual + x

        return x


class SmallM135M(nn.Module):
    """SmallM 135M parameter model implementation"""

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 768,
        num_layers: int = 24,
        num_heads: int = 12,
        head_dim: int = 64,
        intermediate_size: int = 2048,
        max_seq_length: int = 2048,
        device: torch.device = torch.device("cpu"),
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_length = max_seq_length
        self.device = device

        # Model components
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_size, num_heads, head_dim, intermediate_size, device
                )
                for _ in range(num_layers)
            ]
        )

        self.norm = nn.RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Tie weights
        self.lm_head.weight = self.embed_tokens.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        batch_size: int,
        seq_len: int,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Create attention mask if not provided
        if attention_mask is None:
            attention_mask = self._create_attention_mask(seq_len)

        # Embed tokens
        x = self.embed_tokens(input_ids)

        # Apply transformer layers
        for layer in self.layers:
            x = layer(x, batch_size, seq_len, attention_mask)

        # Final normalization
        x = self.norm(x)

        # Language modeling head
        logits = self.lm_head(x)

        return logits

    def _create_attention_mask(self, seq_len: int) -> torch.Tensor:
        """Create causal attention mask"""
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=self.device) * float("-inf"), diagonal=1
        )
        return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        batch_size: int,
        seq_len: int,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """Simple greedy generation"""
        self.eval()

        current_seq_len = seq_len
        current_input_ids = input_ids

        for _ in range(max_length - current_seq_len):
            outputs = self(current_input_ids, batch_size, current_seq_len)
            logits = outputs["logits"][:, -1, :] / temperature

            # Apply top-k filtering if specified
            if top_k is not None:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float("-inf")

            # Greedy sampling
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
            current_seq_len += 1

        return current_input_ids


# Example usage with wrapper for convenience
class SmallM135MWrapper(nn.Module):
    """Wrapper that automatically extracts shape parameters"""

    def __init__(self, device):
        super().__init__()
        self.model = SmallM135M(device=device).to(device)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        batch_size, seq_len = input_ids.shape
        return self.model(input_ids, batch_size, seq_len, attention_mask, labels)

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        batch_size, seq_len = input_ids.shape
        return self.model.generate(
            input_ids, batch_size, seq_len, max_length, temperature, top_k
        )

train_xla.py

from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import os

import torch.distributed as distr

import time
import itertools

import torch
import torch_xla
import torchvision
import torch.optim as optim
import torch.nn as nn

from model import SmallM135M


class TrainBaseTransformer:
    def __init__(self):
        self.img_dim = 224
        self.batch_size = 128
        self.num_steps = 300
        self.num_epochs = 100
        # Roughly the size of Imagenet dataset.
        self.train_dataset_len = 12000
        # For the purpose of this example, we are going to use fake data.

        self.device = torch_xla.device()
        self.model = SmallM135M().to(self.device)  # Any LM model, there are no torch_xla in it
        self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)

    def _train_update(self, step, loss, tracker, epoch):
        print(f"epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}")

    def run_optimizer(self):
        self.optimizer.step()

    def loss_fn(self, out, target):
        return ((out) ** 2).mean()

    def step_fn(self, data, attn_mask, target, batch_size, seq_len):
        self.optimizer.zero_grad()
        output = self.model(data, batch_size, seq_len, attn_mask)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.sync_gradients()  # Comment this out and the problem is solved but the weights will not be synchronized across machines
        self.optimizer.step()
        return loss

    def train_loop_fn(self, epoch):
        tracker = xm.RateTracker()
        self.model.train()
        for step in range(
            self.train_dataset_len // self.batch_size // distr.get_world_size()
        ):
            data, mask, target = (
                torch.randint(0, 100, (1, 100)).to(self.device),
                torch.ones((1, 100)).to(self.device),
                torch.zeros((self.batch_size,)).to(self.device),
            )
            batch_size, seq_len = data.shape
            loss = self.step_fn(data, mask, target, batch_size, seq_len)
            print(f"Step: {step} Model weight sum: {self.get_weight_sum()}")
            if step % 10 == 0:
                self._train_update(step, loss, tracker, epoch)

    def sync_gradients(self):
        world_size = float(distr.get_world_size())
        for p in self.model.parameters():
            if p.requires_grad and p.grad is not None:
                distr.all_reduce(p.grad.data, op=distr.ReduceOp.SUM)
                p.grad.data /= world_size

    def sync_model_parameters(self):
        world_size = float(distr.get_world_size())
        for p in self.model.parameters():
            distr.all_reduce(p.data, op=distr.ReduceOp.SUM)
            p.data /= world_size

    def get_weight_sum(self):
        weight_sum = 0
        for p in self.model.parameters():
            weight_sum += p.sum()
        return weight_sum

    def start_training(self, rank: int):
        rank = distr.get_rank()
        self.sync_model_parameters()
        # self.sync_optimizer_parameters()
        for epoch in range(1, self.num_epochs + 1):
            print(
                "Rank {} Epoch {} train begin {}".format(
                    rank, epoch, time.strftime("%l:%M%p %Z on %b %d, %Y")
                )
            )
            self.train_loop_fn(epoch)
            print(
                "Rank {} Epoch {} Weight sum {} train end {}".format(
                    rank,
                    epoch,
                    self.get_weight_sum(),
                    time.strftime("%l:%M%p %Z on %b %d, %Y"),
                )
            )


def init_process_group():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]

    print(f"Initializing rank {rank}/{world_size - 1} via {master_addr}:{master_port}")

    # Increase timeout for multi-machine setups
    timeout = torch.distributed.default_pg_timeout

    distr.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=world_size,
        rank=rank,
        timeout=timeout,
    )


def print_distributed_info():
    if distr.is_initialized():
        print("=== Distributed Training Info ===")
        print(f"Rank: {distr.get_rank()}/{distr.get_world_size() - 1}")
        print(f"Backend: {distr.get_backend()}")
        print(f"Initialized: {distr.is_initialized()}")

        # Try to get more details (may not be available in all backends)
        try:
            print(f"MPI available: {distr.is_mpi_available()}")
            print(f"NCCL available: {distr.is_nccl_available()}")
            print(f"GLOO available: {distr.is_gloo_available()}")
        except:
            pass

        # Check if CUDA is available and being used
        if torch.cuda.is_available():
            print(f"CUDA device count: {torch.cuda.device_count()}")
            print(f"Current CUDA device: {torch.cuda.current_device()}")
            print(f"CUDA device name: {torch.cuda.get_device_name()}")
    else:
        print("Distributed not initialized")


if __name__ == "__main__":
    # torch.manual_seed(42)
    # torch.cuda.manual_seed_all(42)
    init_process_group()
    print_distributed_info()
    base = TrainBaseTransformer()
    try:
        torch_xla.launch(base.start_training)
    finally:
        distr.destroy_process_group()

Dockerfile

FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y python3 python3-distutils python3-venv python3-pip && \ 
    apt-get install -y curl wget unzip vim nano

RUN apt-get install -y libjemalloc2
RUN apt-get install -y iputils-ping net-tools

RUN echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> ~/.bashrc
RUN pip3 install torch==2.5.0 --index-url https://download.pytorch.org/whl/cu121 torchvision
RUN pip install torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl

compose.yaml

services:
  smallm_xla_train_0:
    build:
      dockerfile: Dockerfile
    container_name: smallm_xla_train_0
    shm_size: 16GB
    volumes:
      - ./:/app
    working_dir: /app
    networks:
      - gpu_xla_25_network
    tty: true
    stdin_open: true
    environment:
      - RANK=0
      - MASTER_ADDR=smallm_xla_train_0
      - MASTER_PORT=1234
      - GPU_NUM_DEVICES=1 
      - PJRT_DEVICE=CUDA
      - WORLD_SIZE=2
    deploy:
      resources:
        reservations:
          devices:
          - driver: nvidia
            device_ids: ['0']
            capabilities: [gpu]
    command: torchrun --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr="smallm_xla_train_0" --master_port=1234 train_xla.py

  smallm_xla_train_1:
    build:
      dockerfile: Dockerfile
    container_name: smallm_xla_train_1
    shm_size: 16GB
    volumes:
      - ./:/app
    working_dir: /app
    networks:
      - gpu_xla_25_network
    tty: true
    stdin_open: true
    environment:
      - RANK=1
      - MASTER_ADDR=smallm_xla_train_0
      - MASTER_PORT=1234
      - GPU_NUM_DEVICES=1 
      - PJRT_DEVICE=CUDA
      - WORLD_SIZE=2
    deploy:
      resources:
        reservations:
          devices:
          - driver: nvidia
            device_ids: ['1']
            capabilities: [gpu]
    command: torchrun --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr="smallm_xla_train_0" --master_port=1234 train_xla.py

networks:
  gpu_xla_25_network:
    name: gpu_xla_25_network
    driver: bridge

Steps to reproduce the behavior:

  1. Go to the directory containing all the 4 files provided above
  2. Run docker compose up -d on a pc with at least 2 GPUs.
  3. Run docker attach smallm_xla_train_0 and see logs

Expected behavior

Every one-two seconds I see new logs reporting that 10 forward-backward iterations are done

Environment

  • Reproducible on XLA backend [CPU/TPU]: NVIDIA GPU (torch_xla 2.5.0 supports it)
  • torch_xla version: 2.5.0
  • GPU: 2 x RTX 4090
  • CUDA 12.1
  • Python 3.8

Additional context

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions