🐛 Bug
all_reduce takes much more time and uses much more memory with a torch_xla model than with the same model in plain torch. With torch_xla, a single forward-backward pass with DDP gradient synchronization takes about 20 minutes.
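The synchronization in question calls torch.distributed.all_reduce once per parameter on XLA tensors (see sync_gradients in train_xla.py below). For comparison only, here is a minimal sketch of the same averaging through torch_xla's own collective helpers, assuming xm.all_reduce and an explicit xm.mark_step() apply to this setup; it is not part of the reproduction.

import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr

def sync_gradients_xla(model):
    # Average gradients with one in-place collective over the whole list,
    # instead of one torch.distributed.all_reduce call per parameter.
    world_size = xr.world_size()
    grads = [p.grad for p in model.parameters()
             if p.requires_grad and p.grad is not None]
    xm.all_reduce(xm.REDUCE_SUM, grads, scale=1.0 / world_size)

def train_step(model, optimizer, loss_fn, data, target):
    # Hypothetical step function mirroring step_fn from train_xla.py below.
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    sync_gradients_xla(model)
    optimizer.step()
    xm.mark_step()  # cut and execute the lazily traced step graph
    return loss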
To Reproduce
model.py
import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Positional Embedding (RoPE)"""
def __init__(self, dim: int, device: torch.device, max_seq_length: int = 2048):
super().__init__()
self.dim = dim
self.max_seq_length = max_seq_length
self.device = device
# Precompute frequencies
inv_freq = 1.0 / (
10000 ** (torch.arange(0, dim, 2, device=device).float() / dim)
)
self.register_buffer("inv_freq", inv_freq)
def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
# Generate positions [0, 1, 2, ..., seq_len-1]
t = torch.arange(seq_len, device=self.device).type_as(self.inv_freq)
# Compute frequencies [seq_len, dim//2]
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Create complex exponential form [seq_len, dim//2]
emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, dim]
emb = emb[None, :, None, :] # [1, seq_len, 1, dim]
# Apply rotary embeddings
cos = torch.cos(emb)
sin = torch.sin(emb)
# Rotate queries/keys
x_rot = x * cos + self._rotate_half(x) * sin
return x_rot
def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
class MLP(nn.Module):
"""Feed-forward network with GELU activation"""
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = F.gelu(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
class Attention(nn.Module):
"""Multi-head self-attention with RoPE"""
def __init__(
self, hidden_size: int, num_heads: int, head_dim: int, device: torch.device
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.device = device
self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
self.rotary_emb = RotaryPositionalEmbedding(head_dim, device)
def forward(
self,
x: torch.Tensor,
batch_size: int,
seq_len: int,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Project queries, keys, values
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
# Apply rotary positional embeddings
q = self.rotary_emb(q, seq_len)
k = self.rotary_emb(k, seq_len)
# Transpose for attention computation [batch_size, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Compute attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Apply attention mask if provided
if attention_mask is not None:
scores = scores + attention_mask
# Softmax and attention output
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
# Transpose back and combine heads
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(
batch_size, seq_len, self.num_heads * self.head_dim
)
# Final projection
return self.o_proj(attn_output)
class TransformerBlock(nn.Module):
"""Single transformer decoder block"""
def __init__(
self,
hidden_size: int,
num_heads: int,
head_dim: int,
intermediate_size: int,
device: torch.device,
):
super().__init__()
self.self_attn = Attention(hidden_size, num_heads, head_dim, device)
self.mlp = MLP(hidden_size, intermediate_size)
self.input_layernorm = nn.RMSNorm(hidden_size)
self.post_attention_layernorm = nn.RMSNorm(hidden_size)
def forward(
self,
x: torch.Tensor,
batch_size: int,
seq_len: int,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Self-attention with residual connection
residual = x
x = self.input_layernorm(x)
x = self.self_attn(x, batch_size, seq_len, attention_mask)
x = residual + x
# MLP with residual connection
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
class SmallM135M(nn.Module):
"""SmallM 135M parameter model implementation"""
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 768,
num_layers: int = 24,
num_heads: int = 12,
head_dim: int = 64,
intermediate_size: int = 2048,
max_seq_length: int = 2048,
device: torch.device = torch.device("cpu"),
):
super().__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
self.max_seq_length = max_seq_length
self.device = device
# Model components
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# Transformer layers
self.layers = nn.ModuleList(
[
TransformerBlock(
hidden_size, num_heads, head_dim, intermediate_size, device
)
for _ in range(num_layers)
]
)
self.norm = nn.RMSNorm(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
# Tie weights
self.lm_head.weight = self.embed_tokens.weight
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module: nn.Module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
input_ids: torch.Tensor,
batch_size: int,
seq_len: int,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Create attention mask if not provided
if attention_mask is None:
attention_mask = self._create_attention_mask(seq_len)
# Embed tokens
x = self.embed_tokens(input_ids)
# Apply transformer layers
for layer in self.layers:
x = layer(x, batch_size, seq_len, attention_mask)
# Final normalization
x = self.norm(x)
# Language modeling head
logits = self.lm_head(x)
return logits
def _create_attention_mask(self, seq_len: int) -> torch.Tensor:
"""Create causal attention mask"""
mask = torch.triu(
torch.ones(seq_len, seq_len, device=self.device) * float("-inf"), diagonal=1
)
return mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
batch_size: int,
seq_len: int,
max_length: int = 100,
temperature: float = 1.0,
top_k: Optional[int] = None,
) -> torch.Tensor:
"""Simple greedy generation"""
self.eval()
current_seq_len = seq_len
current_input_ids = input_ids
for _ in range(max_length - current_seq_len):
outputs = self(current_input_ids, batch_size, current_seq_len)
            logits = outputs[:, -1, :] / temperature  # forward() returns the logits tensor directly
# Apply top-k filtering if specified
if top_k is not None:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = float("-inf")
# Greedy sampling
next_token = torch.argmax(logits, dim=-1, keepdim=True)
current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
current_seq_len += 1
return current_input_ids
# Example usage with wrapper for convenience
class SmallM135MWrapper(nn.Module):
"""Wrapper that automatically extracts shape parameters"""
def __init__(self, device):
super().__init__()
self.model = SmallM135M(device=device).to(device)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_len = input_ids.shape
        # SmallM135M.forward does not take labels, so they are not forwarded here
        return self.model(input_ids, batch_size, seq_len, attention_mask)
def generate(
self,
input_ids: torch.Tensor,
max_length: int = 100,
temperature: float = 1.0,
top_k: Optional[int] = None,
) -> torch.Tensor:
batch_size, seq_len = input_ids.shape
return self.model.generate(
input_ids, batch_size, seq_len, max_length, temperature, top_k
        )

train_xla.py
from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import os
import torch.distributed as distr
import time
import itertools
import torch
import torch_xla
import torchvision
import torch.optim as optim
import torch.nn as nn
from model import SmallM135M
class TrainBaseTransformer:
def __init__(self):
self.img_dim = 224
self.batch_size = 128
self.num_steps = 300
self.num_epochs = 100
        # For the purpose of this example, we use fake data of roughly this size.
        self.train_dataset_len = 12000
self.device = torch_xla.device()
        self.model = SmallM135M(device=self.device).to(self.device)  # any LM model; it contains no torch_xla-specific code
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
def _train_update(self, step, loss, tracker, epoch):
print(f"epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}")
def run_optimizer(self):
self.optimizer.step()
def loss_fn(self, out, target):
        return ((out) ** 2).mean()  # dummy loss on fake data; target is unused
def step_fn(self, data, attn_mask, target, batch_size, seq_len):
self.optimizer.zero_grad()
output = self.model(data, batch_size, seq_len, attn_mask)
loss = self.loss_fn(output, target)
loss.backward()
        self.sync_gradients()  # comment this out and the problem disappears, but then the weights are no longer synchronized across machines
self.optimizer.step()
return loss
def train_loop_fn(self, epoch):
tracker = xm.RateTracker()
self.model.train()
for step in range(
self.train_dataset_len // self.batch_size // distr.get_world_size()
):
data, mask, target = (
torch.randint(0, 100, (1, 100)).to(self.device),
torch.ones((1, 100)).to(self.device),
torch.zeros((self.batch_size,)).to(self.device),
)
batch_size, seq_len = data.shape
loss = self.step_fn(data, mask, target, batch_size, seq_len)
print(f"Step: {step} Model weight sum: {self.get_weight_sum()}")
if step % 10 == 0:
self._train_update(step, loss, tracker, epoch)
def sync_gradients(self):
world_size = float(distr.get_world_size())
for p in self.model.parameters():
if p.requires_grad and p.grad is not None:
distr.all_reduce(p.grad.data, op=distr.ReduceOp.SUM)
p.grad.data /= world_size
def sync_model_parameters(self):
world_size = float(distr.get_world_size())
for p in self.model.parameters():
distr.all_reduce(p.data, op=distr.ReduceOp.SUM)
p.data /= world_size
def get_weight_sum(self):
weight_sum = 0
for p in self.model.parameters():
weight_sum += p.sum()
return weight_sum
def start_training(self, rank: int):
rank = distr.get_rank()
self.sync_model_parameters()
# self.sync_optimizer_parameters()
for epoch in range(1, self.num_epochs + 1):
print(
"Rank {} Epoch {} train begin {}".format(
rank, epoch, time.strftime("%l:%M%p %Z on %b %d, %Y")
)
)
self.train_loop_fn(epoch)
print(
"Rank {} Epoch {} Weight sum {} train end {}".format(
rank,
epoch,
self.get_weight_sum(),
time.strftime("%l:%M%p %Z on %b %d, %Y"),
)
)
def init_process_group():
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
print(f"Initializing rank {rank}/{world_size - 1} via {master_addr}:{master_port}")
    # Process-group timeout (the default here; increase it for slow multi-machine setups)
timeout = torch.distributed.default_pg_timeout
distr.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method=f"tcp://{master_addr}:{master_port}",
world_size=world_size,
rank=rank,
timeout=timeout,
)
def print_distributed_info():
if distr.is_initialized():
print("=== Distributed Training Info ===")
print(f"Rank: {distr.get_rank()}/{distr.get_world_size() - 1}")
print(f"Backend: {distr.get_backend()}")
print(f"Initialized: {distr.is_initialized()}")
# Try to get more details (may not be available in all backends)
try:
print(f"MPI available: {distr.is_mpi_available()}")
print(f"NCCL available: {distr.is_nccl_available()}")
print(f"GLOO available: {distr.is_gloo_available()}")
except:
pass
# Check if CUDA is available and being used
if torch.cuda.is_available():
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"Current CUDA device: {torch.cuda.current_device()}")
print(f"CUDA device name: {torch.cuda.get_device_name()}")
else:
print("Distributed not initialized")
if __name__ == "__main__":
# torch.manual_seed(42)
# torch.cuda.manual_seed_all(42)
init_process_group()
print_distributed_info()
base = TrainBaseTransformer()
try:
torch_xla.launch(base.start_training)
finally:
        distr.destroy_process_group()

Dockerfile
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y python3 python3-distutils python3-venv python3-pip && \
apt-get install -y curl wget unzip vim nano
RUN apt-get install -y libjemalloc2
RUN apt-get install -y iputils-ping net-tools
RUN echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> ~/.bashrc
RUN pip3 install torch==2.5.0 torchvision --index-url https://download.pytorch.org/whl/cu121
RUN pip install torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl

compose.yaml
services:
smallm_xla_train_0:
build:
dockerfile: Dockerfile
container_name: smallm_xla_train_0
shm_size: 16GB
volumes:
- ./:/app
working_dir: /app
networks:
- gpu_xla_25_network
tty: true
stdin_open: true
environment:
- RANK=0
- MASTER_ADDR=smallm_xla_train_0
- MASTER_PORT=1234
- GPU_NUM_DEVICES=1
- PJRT_DEVICE=CUDA
- WORLD_SIZE=2
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: torchrun --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr="smallm_xla_train_0" --master_port=1234 train_xla.py
smallm_xla_train_1:
build:
dockerfile: Dockerfile
container_name: smallm_xla_train_1
shm_size: 16GB
volumes:
- ./:/app
working_dir: /app
networks:
- gpu_xla_25_network
tty: true
stdin_open: true
environment:
- RANK=1
- MASTER_ADDR=smallm_xla_train_0
- MASTER_PORT=1234
- GPU_NUM_DEVICES=1
- PJRT_DEVICE=CUDA
- WORLD_SIZE=2
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['1']
capabilities: [gpu]
command: torchrun --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr="smallm_xla_train_0" --master_port=1234 train_xla.py
networks:
gpu_xla_25_network:
name: gpu_xla_25_network
    driver: bridge

Steps to reproduce the behavior:
- Go to the directory containing all four files provided above
- Run docker compose up -d on a PC with at least 2 GPUs
- Run docker attach smallm_xla_train_0 and watch the logs
Expected behavior
Every one to two seconds, a new log line reports that another 10 forward-backward iterations have completed.
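For scale, the constants in train_xla.py above imply the following per-rank step count per epoch (a sketch of the arithmetic only):

# Steps per rank per epoch implied by the reproduction constants.
train_dataset_len = 12000
batch_size = 128
world_size = 2

steps_per_epoch = train_dataset_len // batch_size // world_size
print(steps_per_epoch)  # 46; _train_update logs at steps 0, 10, 20, 30 and 40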
Environment
- Reproducible on XLA backend [CPU/TPU]: NVIDIA GPU (torch_xla 2.5.0 supports it)
- torch_xla version: 2.5.0
- GPU: 2 x RTX 4090
- CUDA 12.1
- Python 3.8