
Commit 322710d

feat: fix tuning for the all-gather gemm && move reset_signals() into the forward critical path (bytedance#19)

* Change the default tile config for the all-gather gemm v2
* Fix tuning for the all-gather gemm and add the tuning script
* Move reset_signals() into the critical forward path of ag-gemm, so users no longer need to reset the signals manually
* The gemm-rs tuning still has some issues; remove it temporarily
1 parent 9eb16a9 commit 322710d

4 files changed (+241 -5 lines)

src/all_gather/gemm_v2_ag_kernel.hpp (+2 -2)

@@ -182,12 +182,12 @@ struct GemmV2AGKernel_Space : OpSpaceBase<GemmV2AGKernel_Space> {
         make_gemm_v2_hparams(Shape<_64, _64, _32>{}, Shape<_16, _8, _16>{}, _StreamkDP{})),
     cute::make_tuple(Auto{}),
     cute::make_tuple(
+        Shape<_128, _128, _64>{},
+        Shape<_128, _128, _32>{},
         Shape<_64, _128, _32>{},
         Shape<_64, _128, _64>{},
         Shape<_64, _256, _32>{},
         Shape<_64, _256, _64>{},
-        Shape<_128, _128, _32>{},
-        Shape<_128, _128, _64>{},
         Shape<_128, _256, _32>{},
         Shape<_256, _128, _32>{}),
     cute::make_tuple(Auto{}),

src/all_gather/ths_op/all_gather_gemm_kernel.cc (+4 -2)

@@ -651,7 +651,7 @@ class AGKernel : public torch::CustomClassHolder {
 
     this->chunk_size = input.numel() * input.element_size();
     this->split_chunk_size = this->chunk_size / SPLIT;
-    return forward_impl(
+    auto result = forward_impl(
         std::move(input),
         std::move(weight),
         std::move(bias),
@@ -660,6 +660,8 @@ class AGKernel : public torch::CustomClassHolder {
         std::move(output_scale),
         fast_accum,
         c10::nullopt);
+    this->reset_signals();  // clear the signals at the end of the forward
+    return result;
   }
 
   torch::Tensor
@@ -678,14 +680,14 @@ class AGKernel : public torch::CustomClassHolder {
 #else
     flux_barrier_all_on_stream(current_stream, this->sync_buffers, this->rank);
 #endif
-    c10::cuda::stream_synchronize(current_stream);
 
     this->barrier_buffer.zero_();
   }
 
   void
   copy_local(const torch::Tensor &input) {
     this->chunk_size = input.numel() * input.element_size();
+    this->split_chunk_size = this->chunk_size / SPLIT;
     cudaStream_t current_stream = c10::cuda::getCurrentCUDAStream();
     void *input_ptr = input.data_ptr();
     void *input_buffer_ptr = this->input_buffer.data_ptr();
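
With reset_signals() now called at the end of forward, callers no longer need to reset the signals between iterations. Below is a minimal caller-side sketch of the resulting pattern. The constructor arguments and the copy_local()/reset_signals() calls follow tools/tune_ag_gemm_kernel.py and test/test_ag_kernel.py in this commit; the shapes are illustrative, and the forward(...) call and its argument order are an assumption inferred from the forward_impl arguments above, not part of this diff.

import os
import torch
import torch.distributed as dist
import flux

# Standard launcher-provided environment variables, as in the tuning script.
RANK = int(os.environ.get("RANK", 0))
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))

torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(backend="nccl", world_size=WORLD_SIZE, rank=RANK)
TP_GROUP = dist.new_group(ranks=list(range(WORLD_SIZE)), backend="nccl")
flux.init_flux_shm(TP_GROUP)

local_M, N, K = 512, 6144, 12288  # illustrative: per-rank input rows, local weight rows, reduction dim
input = torch.rand((local_M, K), dtype=torch.bfloat16, device="cuda")
weight = torch.rand((N, K), dtype=torch.bfloat16, device="cuda")

ag_kernel = flux.AGKernel(
    TP_GROUP,
    1,                      # NNODES
    local_M * WORLD_SIZE,   # M after the all-gather
    N,
    K,
    input.dtype,
    transpose_weight=False,
    local_copy=True,
)

for _ in range(10):
    # ag_kernel.reset_signals()   # no longer needed: forward() now resets the signals itself
    ag_kernel.copy_local(input)   # stage the local shard into the all-gather buffer
    output = ag_kernel.forward(input, weight)  # assumed Python signature mirroring forward_impl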

test/test_ag_kernel.py (+1 -1)

@@ -269,7 +269,7 @@ def perf_flux(
 
     torch.distributed.barrier()
     for i in range(total_iters):
-        all_gather_gemm_kernel.reset_signals()
+        # all_gather_gemm_kernel.reset_signals()  # moved to the critical path, no need to reset the signal manually
         if local_copy:
             all_gather_gemm_kernel.copy_local(input)
         start_events[i].record()

tools/tune_ag_gemm_kernel.py (new file, +234 lines)

################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import argparse
from functools import partial
import os
import torch
import time
import datetime
import numpy as np
import flux
from typing import List
import itertools
import dataclasses
from flux import pynvshmem

RANK = int(os.environ.get("RANK", 0))
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
NNODES = WORLD_SIZE // LOCAL_WORLD_SIZE
torch.cuda.set_device(LOCAL_RANK)

os.environ["NCCL_DEBUG"] = "ERROR"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True, warn_only=True)
torch.set_printoptions(precision=8)
torch.manual_seed(3 + RANK)
torch.cuda.manual_seed_all(3 + RANK)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
np.random.seed(3 + RANK)

torch.distributed.init_process_group(
    backend="nccl", world_size=WORLD_SIZE, rank=RANK, timeout=datetime.timedelta(seconds=1800)
)
# use all ranks as tp group
TP_GROUP = torch.distributed.new_group(ranks=list(range(WORLD_SIZE)), backend="nccl")
print = partial(print, flush=True)


@dataclasses.dataclass
class TuningConfig:
    M: int
    N: int
    K: int
    transpose_weight: bool
    dtype: str
    has_bias: bool


def gen_tuning_space():
    space: List[TuningConfig] = []
    ## Training shapes
    # space_M = [4096]
    # space_N = [10240, 24576, 8192, 57344, 28672]
    # space_K = [8192]
    space_M = [64, 256, 512, 1024, 2048, 4096, 8192]
    space_N = [16384, 49152]
    space_K = [12288]
    space_transpose_weight = [True, False]
    space_dtype = [torch.bfloat16, torch.float16]
    space_has_bias = [False, True]
    for M, N, K, transpose_weight, dtype, has_bias in itertools.product(
        space_M,
        space_N,
        space_K,
        space_transpose_weight,
        space_dtype,
        space_has_bias,
    ):
        config = TuningConfig(
            M=M,
            N=N,
            K=K,
            transpose_weight=transpose_weight,
            dtype=dtype,
            has_bias=has_bias,
        )
        space.append(config)
    return space


def get_torch_output(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    local_M = input.size(0)
    M = local_M * TP_GROUP.size()

    torch.distributed.barrier()
    full_input = torch.zeros(
        (M, input.size(1)),
        dtype=input.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(full_input, input, group=TP_GROUP)
    output = torch.matmul(full_input, weight.t())
    if bias is not None:
        output += bias
    torch.distributed.barrier()
    return output


def run_flux_profiling(
    prof_ctx: flux.ProfilingContext,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    config: TuningConfig,
):
    local_M = input.size(0)
    M = local_M * TP_GROUP.size()
    K = input.size(1)

    if config.transpose_weight:
        w = weight.t().contiguous()
        N = w.size(1)
    else:
        w = weight
        N = w.size(0)

    ag_gemm_kernel = flux.AGKernel(
        TP_GROUP,
        NNODES,
        M,
        N,
        K,
        input.dtype,
        transpose_weight=config.transpose_weight,
        local_copy=True,
    )

    output = ag_gemm_kernel.profiling(input, w, bias=bias, prof_ctx=prof_ctx)
    return output


def tune_one_config(prof_ctx: flux.ProfilingContext, config: TuningConfig):
    assert config.M % TP_GROUP.size() == 0
    assert config.N % TP_GROUP.size() == 0
    local_M = config.M // TP_GROUP.size()
    local_N = config.N // TP_GROUP.size()

    # input: [M, K], weight: [N, K]
    input = (
        torch.rand((local_M, config.K), dtype=config.dtype).cuda() / 100 * ((TP_GROUP.rank() + 1))
    )
    weight = (
        torch.rand((local_N, config.K), dtype=config.dtype).cuda() / 100 * ((TP_GROUP.rank() + 1))
    )

    bias = None
    if config.has_bias:
        bias = torch.rand((config.M, local_N), dtype=input.dtype, device=input.device)

    torch_output = get_torch_output(input, weight, bias)
    torch.distributed.barrier()
    flux_output = run_flux_profiling(prof_ctx, input, weight, bias, config)
    torch.distributed.barrier()
    torch.cuda.current_stream().synchronize()

    if config.dtype == torch.bfloat16:
        atol, rtol = 0.02, 0.02
    else:
        atol, rtol = 0.01, 0.01

    for rank in range(WORLD_SIZE):
        if rank == RANK:
            flux.torch_allclose(flux_output.cpu(), torch_output.cpu(), atol=atol, rtol=rtol)
        torch.distributed.barrier()

    time.sleep(1)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir", default="", type=str, help="Directory to store generated files"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.output_dir and not os.path.isdir(args.output_dir):
        raise Exception(f"{args.output_dir} not exist")

    torch.cuda.set_device(LOCAL_RANK)
    flux.init_flux_shm(TP_GROUP)
    torch.cuda.synchronize()

    arch: int = flux.get_arch()
    name: str = f"config_ag_gemm_kernel_sm{arch}_tp{WORLD_SIZE}_nnodes{NNODES}"
    prof_ctx = flux.ProfilingContext(name)
    config_space = gen_tuning_space()
    for i, config in enumerate(config_space):
        if TP_GROUP.rank() == 0:
            print(f"==== #{i + 1}/{len(config_space)} Tuning for {config}")
        tune_one_config(prof_ctx=prof_ctx, config=config)
        if TP_GROUP.rank() == 0:
            print(prof_ctx.get_latest_prof_result())

    if TP_GROUP.rank() == 0:
        if os.path.isdir(args.output_dir):
            code_path = os.path.join(args.output_dir, f"{name}.cu")
            result_path = os.path.join(args.output_dir, f"{name}.prof.log")

            with open(code_path, "w") as fout:
                print(prof_ctx.get_code(), file=fout)

            with open(result_path, "w") as fout:
                for record in prof_ctx.get_all_prof_results():
                    print(record, file=fout)
        else:
            print("Generated Code:")
            print(prof_ctx.get_code())
            print()

            print("Profiling Results:")
            for record in prof_ctx.get_all_prof_results():
                print(record)
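
The script sweeps the full Cartesian product of the candidate shapes and options in gen_tuning_space() (the RANK/LOCAL_RANK/WORLD_SIZE variables it reads are the ones a launcher such as torchrun provides; the launch command itself is not part of this commit). The snippet below simply recomputes the sweep size from those lists, as a sanity check on how long a tuning run will be.

import itertools

# Mirrors the lists in gen_tuning_space() above.
space_M = [64, 256, 512, 1024, 2048, 4096, 8192]
space_N = [16384, 49152]
space_K = [12288]
space_transpose_weight = [True, False]
space_dtype = ["bfloat16", "float16"]
space_has_bias = [False, True]

configs = list(itertools.product(
    space_M, space_N, space_K, space_transpose_weight, space_dtype, space_has_bias
))
print(len(configs))  # 7 * 2 * 1 * 2 * 2 * 2 = 112 configurations per tuning run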
