Skip to content

Commit bfbb68e

Browse files
authored
Merge pull request #93 from coreweave/es/torch-v2.7.0
feat(torch): Update to PyTorch 2.7 & CUDA 12.8.1
2 parents 85e3f72 + 106e5d6 commit bfbb68e

File tree

4 files changed

+29
-22
lines changed

4 files changed

+29
-22
lines changed

.github/configurations/torch-base.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
cuda: [ 12.8.0, 12.6.3, 12.4.1 ]
1+
cuda: [ 12.8.1, 12.6.3, 12.4.1 ]
22
os: [ ubuntu22.04 ]
3-
abi: [ 1, 0 ]
3+
abi: [ 1 ]
44
include:
5-
- torch: 2.6.0
6-
vision: 0.21.0
7-
audio: 2.6.0
5+
- torch: 2.7.0
6+
vision: 0.22.0
7+
audio: 2.7.0

.github/configurations/torch-nccl.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
cuda: [ 12.8.0, 12.6.3, 12.4.1 ]
1+
cuda: [ 12.8.1, 12.6.3, 12.4.1 ]
22
os: [ ubuntu22.04 ]
3-
abi: [ 1, 0 ]
3+
abi: [ 1 ]
44
include:
5-
- torch: 2.6.0
6-
vision: 0.21.0
7-
audio: 2.6.0
8-
nccl: 2.25.1-1
9-
nccl-tests-hash: 57fa979
5+
- torch: 2.7.0
6+
vision: 0.22.0
7+
audio: 2.7.0
8+
nccl: 2.26.5-1
9+
nccl-tests-hash: 4658d92

torch-extras/Dockerfile

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ RUN apt-get -qq update && apt-get -qq install -y \
6868
{ wget -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc \
6969
| gpg --dearmor -o /etc/apt/trusted.gpg.d/kitware.gpg; } && \
7070
apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" && \
71-
apt-get -qq update && apt-get -qq install -y cmake && apt-get clean
71+
apt-get -qq update && \
72+
apt-get -qq install -y 'cmake=3.31.6-*' 'cmake-data=3.31.6-*' && \
73+
apt-get clean && \
74+
python3 -m pip install --no-cache-dir 'cmake==3.31.6'
7275

7376
# Update compiler (GCC) and linker (LLD) versions
7477
# gfortran-11 is just for compiler_wrapper.f95
@@ -105,10 +108,10 @@ RUN if [ "$(uname -m)" = "aarch64" ]; then \
105108
COPY --chmod=755 effective_cpu_count.sh .
106109
COPY --chmod=755 scale.sh .
107110

108-
ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=compute_90a"
111+
ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=sm_90a"
109112
RUN FLAGS="$BUILD_NVCC_APPEND_FLAGS" && \
110113
case "${NV_CUDA_LIB_VERSION}" in 12.[89].*) \
111-
FLAGS="${FLAGS} -gencode=arch=compute_100a,code=sm_100a" ;; \
114+
FLAGS="${FLAGS} -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_100a,code=sm_100a" ;; \
112115
esac && \
113116
echo "-Wno-deprecated-gpu-targets -diag-suppress 191,186,177${FLAGS:+ $FLAGS}" > /build/nvcc.conf
114117
ARG BUILD_MAX_JOBS

torch/Dockerfile

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# syntax=docker/dockerfile:1.7
2-
ARG BUILDER_BASE_IMAGE="nvidia/cuda:12.8.0-devel-ubuntu22.04"
3-
ARG FINAL_BASE_IMAGE="nvidia/cuda:12.8.0-base-ubuntu22.04"
2+
ARG BUILDER_BASE_IMAGE="nvidia/cuda:12.8.1-devel-ubuntu22.04"
3+
ARG FINAL_BASE_IMAGE="nvidia/cuda:12.8.1-base-ubuntu22.04"
44

5-
ARG BUILD_TORCH_VERSION="2.5.1"
6-
ARG BUILD_TORCH_VISION_VERSION="0.20.0"
7-
ARG BUILD_TORCH_AUDIO_VERSION="2.5.0"
5+
ARG BUILD_TORCH_VERSION="2.7.0"
6+
ARG BUILD_TORCH_VISION_VERSION="0.22.0"
7+
ARG BUILD_TORCH_AUDIO_VERSION="2.7.0"
88
ARG BUILD_TRANSFORMERENGINE_VERSION="1.13"
99
ARG BUILD_FLASH_ATTN_VERSION="2.7.4.post1"
1010
ARG BUILD_FLASH_ATTN_3_VERSION="2.7.2.post1"
@@ -174,7 +174,10 @@ RUN apt-get -qq update && apt-get -qq install -y \
174174
{ wget -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc \
175175
| gpg --dearmor -o /etc/apt/trusted.gpg.d/kitware.gpg; } && \
176176
apt-add-repository -n "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" && \
177-
apt-get -qq update && apt-get -qq install -y cmake && apt-get clean
177+
apt-get -qq update && \
178+
apt-get -qq install -y 'cmake=3.31.6-*' 'cmake-data=3.31.6-*' && \
179+
apt-get clean && \
180+
python3 -m pip install --no-cache-dir 'cmake==3.31.6'
178181

179182
RUN mkdir /tmp/ccache-install && \
180183
cd /tmp/ccache-install && \
@@ -340,7 +343,7 @@ ARG BUILD_NVCC_APPEND_FLAGS="-gencode=arch=compute_90a,code=sm_90a"
340343
# Add sm_100a build if NV_CUDA_LIB_VERSION matches 12.[89].*
341344
RUN FLAGS="$BUILD_NVCC_APPEND_FLAGS" && \
342345
case "${NV_CUDA_LIB_VERSION}" in 12.[89].*) \
343-
FLAGS="${FLAGS} -gencode=arch=compute_100a,code=sm_100a" ;; \
346+
FLAGS="${FLAGS} -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_100a,code=sm_100a" ;; \
344347
esac && \
345348
echo "-Wno-deprecated-gpu-targets -diag-suppress 191,186,177${FLAGS:+ $FLAGS}" > /build/nvcc.conf
346349

@@ -659,6 +662,7 @@ RUN export \
659662
cuda-nvrtc-${CUDA_PACKAGE_VERSION} \
660663
libcusparse-${CUDA_PACKAGE_VERSION} \
661664
libcusolver-${CUDA_PACKAGE_VERSION} \
665+
libcufile-${CUDA_PACKAGE_VERSION} \
662666
cuda-cupti-${CUDA_PACKAGE_VERSION} \
663667
libnvjpeg-${CUDA_PACKAGE_VERSION} \
664668
libnvtoolsext1 && \

0 commit comments

Comments
 (0)