🐛 Bug
I’m encountering a segmentation fault (SIGSEGV) crash when running multi-host / multi-worker training on a TPU VM (v4-32) setup. The error is vague and I can’t pinpoint whether it’s a bug in Torch XLA or a misuse in my training setup. Because documentation on multi-host training is limited, I wanted to raise an issue here to get help (or confirm a bug).
- I have tested with both torch_xla 2.6.0 and 2.8.0 (version check command below).
- The crash happens in PjRtComputationClient::ExecuteReplicated() during training.
- There is also a memory allocation warning (large alloc) before the crash.
- The training seems to “desynchronize” (some hosts lag, some hosts crash) despite placing rendezvous/synchronization logic in the script.
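For reference, I check which versions are installed on every worker with a command along these lines (same SSH pattern as the setup steps below; the conda path matches my install and may need adjusting):
gcloud compute tpus tpu-vm ssh node-1 \
--zone=us-central2-b \
--worker=all \
--command="source ~/miniconda/bin/activate py310 && python -c 'import torch, torch_xla; print(torch.__version__, torch_xla.__version__)'"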
Steps to Reproduce
Below is the sequence I used to set up and launch the run. The training script (repro.py) is linked as a gist under Additional Context at the end of this issue.
- Create a TPU VM with TPU v4-32, using the software stack v2-alpha-tpuv4-pod. Name it node-1, in zone us-central2-b (roughly the command shown below).
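A sketch of the creation command I mean (standard gcloud invocation; project/network flags omitted):
gcloud compute tpus tpu-vm create node-1 \
--zone=us-central2-b \
--accelerator-type=v4-32 \
--version=v2-alpha-tpuv4-pod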
- SSH into all workers and install the environment:
gcloud compute tpus tpu-vm ssh node-1 \
--zone=us-central2-b \
--worker=all \
--strict-host-key-checking=no \
--command="wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
bash ~/miniconda.sh -b -p ~/miniconda && rm ~/miniconda.sh && \
~/miniconda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \
~/miniconda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r && \
~/miniconda/bin/conda create -n py310 python=3.10 -y && \
source ~/miniconda/bin/activate py310 && \
pip install transformers zstandard jsonlines peft wandb bitsandbytes accelerate datasets sentencepiece langchain && \
pip install --upgrade torch==2.6.0 'torch_xla[tpu]==2.6.0' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html && \
pip install 'torch_xla[pallas]' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html && \
pip uninstall -y tensorflow && pip install tensorflow-cpu && \
cd /tmp && git clone https://github.com/IsNoobgrammer/Pytorch-Optimizers optims && cp -r optims ~/optims"
- Copy the training script (repro.py) to each worker (for example, as sketched below).
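One way to do the copy, assuming repro.py is in the current directory (any copy mechanism should be equivalent):
gcloud compute tpus tpu-vm scp repro.py node-1:~/ \
--zone=us-central2-b \
--worker=all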
- Launch the training across workers:
gcloud compute tpus tpu-vm ssh node-1 \
--zone=us-central2-b \
--worker=all \
--strict-host-key-checking=no \
--command="export PJRT_DEVICE=TPU && \
source ~/miniconda/bin/activate && \
python repro.py"
- The crash manifests with a stack trace like this:
*** SIGSEGV (@0x208), see go/stacktraces#s15 received by PID 20977 (TID 22268) on cpu 77; stack trace: ***
PC: @ 0x7f6fafd71651 (unknown) torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()::{lambda()#1}::operator()()
@ 0x7f6f246a7a01 1888 (unknown)
@ 0x7f711f5e43c0 1936 (unknown)
@ 0x7f6fb9c7464e 32 std::_Function_handler<>::_M_invoke()
@ 0x7f6fb09a7c32 304 Eigen::ThreadPoolDevice::parallelFor()
@ 0x7f6fb9c77464 576 tsl::thread::ThreadPool::ParallelFor()
@ 0x7f6fb087f438 928 torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()
@ 0x7f6fb05c3593 624 torch_xla::XLAGraphExecutor::ScheduleSyncTensorsGraph()::{lambda()#1}::operator()()
@ 0x7f70645c93fa (unknown) torch::lazy::MultiWait::Complete()
@ 0x100000000 (unknown) (unknown)
… (more frames) …
There are also memory allocation warnings before the crash, e.g.:
tcmalloc: large alloc 1555824640 bytes == 0x1dfd60000 @ …
Expected Behavior
- The multi-host training run should complete without crashing.
- Each worker should stay synchronized (no drifting or missed rendezvous).
- If there is an internal bug in Torch XLA’s multi-host replication logic, I hope this issue helps surface and resolve it.
Additional Context
During program execution, I observed signs of desynchronization: some worker ranks lag or stall, even though I included rendezvous/synchronization calls in the script.
Reproduction code:
https://gist.github.com/Locutusque/1cd36c96609a8aff3422f78ee69591f7