
SIGSEGV crash during multi-host training on TPU VM (v4-32) #9677

@Locutusque

Description

🐛 Bug

I’m encountering a segmentation fault (SIGSEGV) when running multi-host / multi-worker training on a TPU VM (v4-32). The error is vague, and I can’t tell whether it’s a bug in Torch XLA or a mistake in my training setup. Because documentation on multi-host training is limited, I’m raising this issue to get help (or to confirm a bug).

I have tested with both torch_xla 2.6.0 and 2.8.0.

The crash happens in PjRtComputationClient::ExecuteReplicated() during training.

There is also a memory allocation warning (large alloc) before the crash.

The training also seems to “desynchronize” (some hosts lag, some hosts crash) even though the script includes rendezvous/synchronization calls.
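
For context, the rendezvous/synchronization pattern I’m referring to looks roughly like the sketch below. This is a simplified illustration, not the actual script: the model, data loader, and loss call are placeholders, and the full script is in the gist linked under Additional Context.

import torch_xla.core.xla_model as xm

def train_loop(model, loader, optimizer):
    device = xm.xla_device()
    model = model.to(device)
    xm.rendezvous("train_start")      # barrier: all hosts enter training together
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = model(inputs, labels)  # placeholder: model returns the loss directly
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduce gradients, then optimizer step
        xm.mark_step()                # materialize the pending lazy graph
    xm.rendezvous("train_end")        # barrier: all hosts leave training together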

Steps to Reproduce

Below is the sequence I used to set up and launch the run. The training script (repro.py) is linked as a gist under Additional Context below.

  1. Create a TPU VM with TPU v4-32, using software stack v2-alpha-tpuv4-pod. Name it node-1, located in zone us-central2-b.

  2. SSH (all workers) and install environment:

gcloud compute tpus tpu-vm ssh node-1 --zone=us-central2-b --worker=all --strict-host-key-checking=no --command="wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && bash ~/miniconda.sh -b -p ~/miniconda && rm ~/miniconda.sh && ~/miniconda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && ~/miniconda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r && ~/miniconda/bin/conda create -n py310 python=3.10 -y && source ~/miniconda/bin/activate py310 && pip install transformers zstandard jsonlines peft wandb bitsandbytes accelerate datasets sentencepiece langchain && pip install --upgrade torch==2.6.0 'torch_xla[tpu]==2.6.0' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html && pip install 'torch_xla[pallas]' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html && pip uninstall -y tensorflow && pip install tensorflow-cpu && cd /tmp && git clone https://github.com/IsNoobgrammer/Pytorch-Optimizers optims && cp -r optims ~/optims"
  3. Copy the training script (repro.py) to each worker.

  4. Launch the training across workers:

gcloud compute tpus tpu-vm ssh node-1 \
  --zone=us-central2-b \
  --worker=all \
  --strict-host-key-checking=no \
  --command="export PJRT_DEVICE=TPU && \
  source ~/miniconda/bin/activate && \
  python repro.py"
  5. The crash manifests with a stack trace like this:
*** SIGSEGV (@0x208), see go/stacktraces#s15 received by PID 20977 (TID 22268) on cpu 77; stack trace: ***
PC: @     0x7f6fafd71651  (unknown)  torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()::{lambda()#1}::operator()()
    @     0x7f6f246a7a01       1888  (unknown)
    @     0x7f711f5e43c0       1936  (unknown)
    @     0x7f6fb9c7464e         32  std::_Function_handler<>::_M_invoke()
    @     0x7f6fb09a7c32        304  Eigen::ThreadPoolDevice::parallelFor()
    @     0x7f6fb9c77464        576  tsl::thread::ThreadPool::ParallelFor()
    @     0x7f6fb087f438        928  torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()
    @     0x7f6fb05c3593        624  torch_xla::XLAGraphExecutor::ScheduleSyncTensorsGraph()::{lambda()#1}::operator()()
    @     0x7f70645c93fa  (unknown)  torch::lazy::MultiWait::Complete()
    @        0x100000000  (unknown)  (unknown)
… (more frames) …

There are also memory allocation warnings before the crash, e.g.:

tcmalloc: large alloc 1555824640 bytes == 0x1dfd60000 @ …
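
For diagnosis, the same launch command can be wrapped so that each worker writes its output to a per-host log file, which makes it easier to see which host crashes or stalls first. The tee and log path below are illustrative additions, not part of the repro itself:

gcloud compute tpus tpu-vm ssh node-1 \
  --zone=us-central2-b \
  --worker=all \
  --strict-host-key-checking=no \
  --command="export PJRT_DEVICE=TPU && \
  source ~/miniconda/bin/activate && \
  python repro.py 2>&1 | tee ~/repro_\$(hostname).log"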

Expected Behavior

The multi-host training run should proceed without a crash.

Each worker should stay synchronized (i.e. no drifting or missed rendezvous).

If there’s an internal bug in Torch XLA’s multi-host replication logic, I hope this issue helps surface it and get it resolved.

Additional Context

During execution, I observed signs of desynchronization: some worker ranks lag or stall, even though the script includes rendezvous/synchronization calls.

Reproduction code:

https://gist.github.com/Locutusque/1cd36c96609a8aff3422f78ee69591f7
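
The gist is the authoritative repro. Separately, the sketch below (not part of the gist; the print format and barrier tag are illustrative) shows the kind of per-host progress logging and barriers I use to spot which rank drifts or stalls:

import time

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def log_progress(step):
    # identify this worker process and its global rank among all participants
    print(f"[process {xr.process_index()} rank {xr.global_ordinal()}/{xr.world_size()}] "
          f"step={step} t={time.time():.1f}", flush=True)

def sync_point(tag):
    # every host must reach the same tag; a host that never arrives is the stall I observe
    xm.rendezvous(tag)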

Metadata

    Labels

    bug (Something isn't working), distributed (SPMD and other distributed things), xla:tpu (TPU specific issues and PRs)
