
Commit 775e061

zheng-ningxin, kongroo, wenlei-bao, houqi, and ZihengJiang authored
Support IPC && SM90 version of AG-GEMM, GEMM-RS (bytedance#9)

* Support IPC && SM90 version of AG-GEMM, GEMM-RS

Simultaneously supports IPC and NVSHMEM, allowing users to choose whether to enable NVSHMEM, and also supports the SM90 version of the two OPs. Besides, update the README accordingly and add some performance data.

---------

Co-authored-by: Chengquan Jiang <imjcqt@gmail.com>
Co-authored-by: Wenlei Bao <wenlei.bao@bytedance.com>
Co-authored-by: Qi Hou <houqi1993@gmail.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
Co-authored-by: Xin Liu <liuxin.ai@bytedance.com>
Co-authored-by: Liwen Chang <liwen.chang@bytedance.com>
Co-authored-by: Haibin Lin <haibin.lin@bytedance.com>
1 parent 96b2e03 commit 775e061

File tree: 119 files changed, +10236 −20580 lines


CMakeLists.txt (+5 −2)
```diff
@@ -10,6 +10,7 @@ set(BUILD_TEST ON CACHE INTERNAL "Build unit tests")
 set(ENABLE_NVSHMEM ON CACHE INTERNAL "Use NVSHMEM to transfer data")
 set(CUTLASS_TRACE OFF CACHE INTERNAL "Print CUTLASS Host Trace info")
 set(FLUX_DEBUG OFF CACHE INTERNAL "Define FLUX_DEBUG")
+OPTION(WITH_PROTOBUF "build with protobuf" OFF)
 message("PYTHONPATH: ${PYTHONPATH}")
 message("NVShmem Support: ${ENABLE_NVSHMEM}")
@@ -21,6 +22,8 @@ if(CUDAToolkit_VERSION VERSION_LESS "11.0")
   message(FATAL_ERROR "requires cuda to be >= 11.0")
 elseif(CUDAToolkit_VERSION VERSION_LESS "12.0")
   set(CUDAARCHS "80" CACHE STRING "CUDA Architectures")
+elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.4")
+  set(CUDAARCHS "80;89;90" CACHE STRING "CUDA Architectures")
 else()
   set(CUDAARCHS "80;90" CACHE STRING "CUDA Architectures")
 endif()
@@ -143,9 +146,9 @@ set(COMMON_HEADER_DIRS
 
 set(COMMON_LIB_DIRS "")
 list(APPEND COMMON_LIB_DIRS "${CUDAToolkit_LIBRARY_DIR}")
-
+message(ENABLE_NVSHMEM "ENABLE_NVSHMEM is set to: ${ENABLE_NVSHMEM}")
 if(ENABLE_NVSHMEM)
-  add_definitions(-DFLUX_USE_NVSHMEM)
+  add_definitions(-DFLUX_SHM_USE_NVSHMEM)
   set(NVSHMEM_BUILD_DIR ${PROJECT_SOURCE_DIR}/3rdparty/nvshmem/build)
   message(STATUS "NVSHMEM build dir: ${NVSHMEM_BUILD_DIR}")
   if(NOT EXISTS ${NVSHMEM_BUILD_DIR})
```
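Since `CUDAARCHS` is a cached variable, the toolkit-derived default above can be overridden at configure time. A minimal sketch of a manual configure, mirroring the flags `build.sh` passes (values here are illustrative; `build.sh` remains the supported entry point):

```bash
# Sketch of a manual configure; build.sh normally drives this.
# An explicit -DCUDAARCHS overrides the default derived from the
# CUDA Toolkit version; WITH_PROTOBUF is the new option, default OFF.
mkdir -p build && cd build
cmake .. \
  -DCUDAARCHS="90" \
  -DENABLE_NVSHMEM=OFF \
  -DWITH_PROTOBUF=OFF \
  -DBUILD_TEST=ON
cmake --build . -j
```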

README.md (+35 −10)
````diff
@@ -9,34 +9,45 @@ Flux significantly can reduce latency and increase throughput for tensor paralle
 
 ## Build
 ```bash
+git clone https://github.com/bytedance/flux.git
+git submodule update --init --recursive
+# Ampere
+./build.sh --arch 80
+# Hopper
+./build.sh --arch 90
+```
+## Build for cross-machine TP
+FLUX relies on NVSHMEM for communication across nodes. Therefore, if you need support for cross-machine tensor parallelism (TP), you must manually download the NVSHMEM source code and enable the nvshmem option during compilation.
+
+```bash
+git clone https://github.com/bytedance/flux.git
 # Download nvshmem-2.11(https://developer.nvidia.com/nvshmem) and place it to flux/3rdparty/nvshmem
 # Flux is temporarily dependent on a specific version of nvshmem (2.11).
 tar Jxvf nvshmem_src_2.11.0-5.txz
 mv nvshmem_src_2.11.0-5 ${YOUR_PATH}/flux/3rdparty/nvshmem
 git submodule update --init --recursive
-# Ampere
-./build.sh --arch 80
 
+# Ampere
+./build.sh --arch 80 --nvshmem
+# Hopper
+./build.sh --arch 90 --nvshmem
 ```
 
 If you are tired of the cmake process, you can set environment variable `FLUX_BUILD_SKIP_CMAKE` to 1 to skip cmake if `build/CMakeCache.txt` already exists.
 
 If you want to build a wheel package, add `--package` to the build command. find the output wheel file under dist/
 
-```
+```bash
 # Ampere
 ./build.sh --arch 80 --package
-```
 
-For development release, run build script with `FLUX_FINAL_RELEASE=0`.
-
-```
-# Ampere
-FLUX_FINAL_RELEASE=0 ./build.sh --arch 80 --package
+# Hopper
+./build.sh --arch 90 --package
 ```
 
+
 ## Run Demo
-```
+```bash
 # gemm only
 PYTHONPATH=./python:$PYTHONPATH python3 test/test_gemm_only.py 4096 12288 6144 --dtype=float16
 
@@ -47,6 +58,20 @@ PYTHONPATH=./python:$PYTHONPATH python3 test/test_gemm_only.py 4096 12288 6144 -
 ./scripts/launch.sh test/test_ag_kernel.py 4096 49152 12288 --dtype=float16 --iters=10
 ```
 
+## Performance
+We measured the examples from the above demo on both A800s and H800s. Each machine has 8 GPUs, with a TP size set to 8. The table below shows the performance comparison between flux and torch+nccl. It can be observed that by overlapping fine-grained computation and communication, Flux is able to effectively hide a significant portion of the communication time.
+
+| | M | K | N | Torch Gemm | Torch NCCL | Torch Total | Flux Gemm | Flux NCCL | Flux Total |
+|----------|----------|----------|----------|----------|----------|----------|----------|----------|-----------|
+| AG+Gemm(A800) | 4096 | 12288 | 49152 | 2.438ms | 0.662ms | 3.099ms | 2.378ms | 0.091ms | 2.469ms |
+| Gemm+RS(A800) | 4096 | 49152 | 12288 | 2.453ms | 0.646ms | 3.100ms | 2.429ms | 0.080ms | 2.508ms |
+| AG+Gemm(H800) | 4096 | 12288 | 49152 | 0.846ms | 0.583ms | 1.429ms | 0.814ms | 0.143ms | 0.957ms |
+| Gemm+RS(H800) | 4096 | 49152 | 12288 | 0.818ms | 0.590ms | 1.408ms | 0.822ms | 0.111ms | 0.932ms |
+
+AG refers to AllGather.
+RS refers to ReduceScatter.
+
 ## Citing
 
 If you use Flux in a scientific publication, we encourage you to add the following reference
````
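Putting the README's `FLUX_BUILD_SKIP_CMAKE` note into practice, an edit-compile loop might look like the following sketch (illustrative; the flag only takes effect once `build/CMakeCache.txt` exists):

```bash
# First build: runs the full cmake configure and compile.
./build.sh --arch 80
# Later rebuilds: skip the cmake configure step and recompile only.
FLUX_BUILD_SKIP_CMAKE=1 ./build.sh --arch 80
```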

build.sh (+75 −31)
```diff
@@ -2,10 +2,15 @@
 set -x
 set -e
 
+export PATH=/usr/local/cuda/bin:$PATH
+CMAKE=${CMAKE:-cmake}
+
 ARCH=""
 BUILD_TEST="ON"
 BDIST_WHEEL="OFF"
+WITH_PROTOBUF="OFF"
 FLUX_DEBUG="OFF"
+ENABLE_NVSHMEM="OFF"
 
 function clean_py() {
   rm -rf build/lib.*
@@ -52,11 +57,20 @@ while [[ $# -gt 0 ]]; do
     ;;
   --debug)
     FLUX_DEBUG="ON"
-    shift;;
+    shift
+    ;;
   --package)
     BDIST_WHEEL="ON"
     shift # Skip the argument key
     ;;
+  --protobuf)
+    WITH_PROTOBUF="ON"
+    shift
+    ;;
+  --nvshmem)
+    ENABLE_NVSHMEM="ON"
+    shift
+    ;;
   *)
     # Unknown argument
     echo "Unknown argument: $1"
@@ -67,6 +81,7 @@ done
 
 SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
 PROJECT_ROOT=${SCRIPT_DIR}
+PROTOBUF_ROOT=$PROJECT_ROOT/3rdparty/protobuf
 
 cd ${PROJECT_ROOT}
 
@@ -78,6 +93,20 @@ if [[ -z $JOBS ]]; then
   JOBS=$(nproc --ignore 2)
 fi
 
+##### build protobuf #####
+function build_protobuf() {
+  if [ $WITH_PROTOBUF == "ON" ]; then
+    pushd $PROTOBUF_ROOT
+    mkdir -p $PWD/build/local
+    pushd build
+    CFLAGS="-fPIC" CXXFLAGS="-fPIC" cmake ../cmake -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_INSTALL_PREFIX=$(realpath local)
+    make -j$(nproc)
+    make install
+    popd
+    popd
+  fi
+}
+
 function build_nccl() {
   pushd $NCCL_ROOT
   export BUILDDIR=${NCCL_ROOT}/build
@@ -108,6 +137,8 @@ function build_nccl() {
 
 ##### build nvshmem_bootstrap_torch #####
 function build_pynvshmem() {
+  PYNVSHMEM_DIR=$PROJECT_ROOT/pynvshmem
+  export NVSHMEM_HOME=$PROJECT_ROOT/3rdparty/nvshmem/build/src
   mkdir -p ${PYNVSHMEM_DIR}/build
 
   pushd ${PYNVSHMEM_DIR}/build
@@ -126,11 +157,18 @@ function build_flux_cuda() {
   pushd build
   if [ ! -f CMakeCache.txt ] || [ -z ${FLUX_BUILD_SKIP_CMAKE} ]; then
     CMAKE_ARGS=(
-      -DENABLE_NVSHMEM=on
+      -DENABLE_NVSHMEM=${ENABLE_NVSHMEM}
       -DCUDAARCHS=${ARCH}
       -DCMAKE_EXPORT_COMPILE_COMMANDS=1
       -DBUILD_TEST=${BUILD_TEST}
     )
+    if [ $WITH_PROTOBUF == "ON" ]; then
+      CMAKE_ARGS+=(
+        -DWITH_PROTOBUF=ON
+        -DProtobuf_ROOT=${PROTOBUF_ROOT}/build/local
+        -DProtobuf_PROTOC_EXECUTABLE=${PROTOBUF_ROOT}/build/local/bin/protoc
+      )
+    fi
     if [ $FLUX_DEBUG == "ON" ]; then
       CMAKE_ARGS+=(
         -DFLUX_DEBUG=ON
@@ -142,28 +180,6 @@ function build_flux_cuda() {
   popd
 }
 
-function build_flux_py {
-  LIBDIR=${PROJECT_ROOT}/python/lib
-  mkdir -p ${LIBDIR}
-
-  rm -f ${LIBDIR}/libflux_cuda.so
-  rm -f ${LIBDIR}/nvshmem_bootstrap_torch.so
-  rm -f ${LIBDIR}/nvshmem_transport_ibrc.so.2
-  rm -f ${LIBDIR}/libnvshmem_host.so.2
-  pushd ${LIBDIR}
-  cp -s ../../build/lib/libflux_cuda.so .
-  cp -s ../../pynvshmem/build/nvshmem_bootstrap_torch.so .
-  cp -s ../../3rdparty/nvshmem/build/src/lib/nvshmem_transport_ibrc.so.2 .
-  cp -s ../../3rdparty/nvshmem/build/src/lib/libnvshmem_host.so.2 .
-  popd
-
-  ##### build flux torch bindings #####
-  MAX_JOBS=${JOBS} python3 setup.py develop --user
-  if [ $BDIST_WHEEL == "ON" ]; then
-    MAX_JOBS=${JOBS} python3 setup.py bdist_wheel
-  fi
-}
-
 function merge_compile_commands() {
   if command -v ninja >/dev/null 2>&1; then
     # generate compile_commands.json
@@ -185,17 +201,45 @@ EOF
   fi
 }
 
+function build_flux_py {
+  LIBDIR=${PROJECT_ROOT}/python/lib
+  rm -rf ${LIBDIR}
+  mkdir -p ${LIBDIR}
+
+  # rm -f ${LIBDIR}/libflux_cuda.so
+  # rm -f ${LIBDIR}/nvshmem_bootstrap_torch.so
+  # rm -f ${LIBDIR}/nvshmem_transport_ibrc.so.2
+  # rm -f ${LIBDIR}/libnvshmem_host.so.2
+  pushd ${LIBDIR}
+  cp -s ../../build/lib/libflux_cuda.so .
+  if [ $ENABLE_NVSHMEM == "ON" ]; then
+    cp -s ../../pynvshmem/build/nvshmem_bootstrap_torch.so .
+    cp -s ../../3rdparty/nvshmem/build/src/lib/nvshmem_transport_ibrc.so.2 .
+    cp -s ../../3rdparty/nvshmem/build/src/lib/libnvshmem_host.so.2 .
+    export FLUX_SHM_USE_NVSHMEM=1
+  fi
+  popd
+  ##### build flux torch bindings #####
+  MAX_JOBS=${JOBS} python3 setup.py develop --user
+  if [ $BDIST_WHEEL == "ON" ]; then
+    MAX_JOBS=${JOBS} python3 setup.py bdist_wheel
+  fi
+  merge_compile_commands
+}
+
 NCCL_ROOT=$PROJECT_ROOT/3rdparty/nccl
 build_nccl
 
-./build_nvshmem.sh ${build_args} --jobs ${JOBS}
 
-export PATH=/usr/local/cuda/bin:$PATH
-CMAKE=${CMAKE:-cmake}
-PYNVSHMEM_DIR=$PROJECT_ROOT/pynvshmem
-export NVSHMEM_HOME=$PROJECT_ROOT/3rdparty/nvshmem/build/src
+if [ $ENABLE_NVSHMEM == "ON" ]; then
+  ./build_nvshmem.sh ${build_args} --jobs ${JOBS}
+fi
+
+build_protobuf
+
+if [ $ENABLE_NVSHMEM == "ON" ]; then
+  build_pynvshmem
+fi
 
-build_pynvshmem
 build_flux_cuda
 build_flux_py
-merge_compile_commands
```
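With these changes, NVSHMEM and protobuf are strictly opt-in and IPC is the default transport. A few illustrative invocations of the updated script (combinations assumed from the flag parsing above, not exhaustively tested here):

```bash
# Default: IPC-only build for Hopper; no NVSHMEM, no protobuf.
./build.sh --arch 90
# Cross-machine TP: also builds NVSHMEM and pynvshmem, links the
# nvshmem_* libraries into python/lib, and sets FLUX_SHM_USE_NVSHMEM=1.
./build.sh --arch 90 --nvshmem
# Static protobuf from 3rdparty/protobuf plus a wheel under dist/.
./build.sh --arch 80 --protobuf --package
```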

gen_version.py (+114, new file)
```diff
@@ -0,0 +1,114 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import argparse
+import os
+import subprocess
+from pathlib import Path
+import shutil
+import re
+from typing import Optional, Tuple
+
+CUR_DIR = os.path.dirname(os.path.realpath(__file__))
+
+
+def _check_env_option(opt, default=""):
+    return os.getenv(opt, default).upper() in ["ON", "1", "YES", "TRUE"]
+
+
+def check_final_release():
+    return _check_env_option("FLUX_FINAL_RELEASE", "1")
+
+
+def get_git_commit(src_dir):
+    try:
+        return (
+            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=src_dir)
+            .decode("ascii")
+            .strip()
+        )
+    except Exception:
+        return "unknown"
+
+
+def cuda_version() -> Tuple[int, ...]:
+    """CUDA Toolkit version as a (major, minor) by nvcc --version"""
+
+    # Try finding NVCC
+    nvcc_bin: Optional[Path] = None
+    if nvcc_bin is None and os.getenv("CUDA_HOME"):
+        # Check in CUDA_HOME
+        cuda_home = Path(os.getenv("CUDA_HOME"))
+        nvcc_bin = cuda_home / "bin" / "nvcc"
+    if nvcc_bin is None:
+        # Check if nvcc is in path
+        nvcc_bin = shutil.which("nvcc")
+        if nvcc_bin is not None:
+            nvcc_bin = Path(nvcc_bin)
+    if nvcc_bin is None:
+        # Last-ditch guess in /usr/local/cuda
+        cuda_home = Path("/usr/local/cuda")
+        nvcc_bin = cuda_home / "bin" / "nvcc"
+    if not nvcc_bin.is_file():
+        raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")
+
+    # Query NVCC for version info
+    output = subprocess.run(
+        [nvcc_bin, "-V"],
+        capture_output=True,
+        check=True,
+        universal_newlines=True,
+    )
+    match = re.search(r"release\s*([\d.]+)", output.stdout)
+    version = match.group(1).split(".")
+    return tuple(int(v) for v in version)
+
+
+def get_flux_version(version_txt, *, dev=False):
+    with open(version_txt) as f:
+        version = f.readline()
+        version = version.strip()
+    cuda_version_major, cuda_version_minor = cuda_version()
+    version = version + f"+cu{cuda_version_major}{cuda_version_minor}"
+    if dev:
+        commit_id = get_git_commit(CUR_DIR)
+
+        version += ".dev{}".format(commit_id[:8])
+    # version = version + (f'.{os.getenv("ARCH")}' if os.getenv("ARCH") else "")
+    return version
+
+
+def generate_versoin_file(version_txt, version_file, *, dev=False):
+    flux_ver = get_flux_version(version_txt, dev=dev)
+
+    with open(version_file, "w") as f:
+        f.write("__version__ = '{}'\n".format(flux_ver))
+        f.write("git_version = {}\n".format(repr(get_git_commit(CUR_DIR))))
+        cuda_version_major, cuda_version_minor = cuda_version()
+        f.write("cuda = {}.{}\n".format(cuda_version_major, cuda_version_minor))
+
+    return flux_ver
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="generate version.py")
+    parser.add_argument("--input", type=str, required=True)
+    parser.add_argument("--output", type=str, required=True)
+    parser.add_argument("--dev", action="store_true")
+    args = parser.parse_args()
+
+    generate_versoin_file(args.input, args.output, dev=args.dev)
```
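The script is presumably driven by the packaging step, but it can also be run standalone; the paths below are hypothetical, chosen only to show the shape of the input and output:

```bash
# Hypothetical invocation: version.txt holds a base version such as "1.0.0".
python3 gen_version.py --input version.txt --output python/flux/version.py
# With --dev, an eight-character commit suffix is appended, producing
# something like __version__ = '1.0.0+cu124.dev12345678' (illustrative).
python3 gen_version.py --input version.txt --output python/flux/version.py --dev
```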
