################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
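"""Tuning script for the flux all-gather + GEMM kernel (flux.AGKernel).

Sweeps the config space from gen_tuning_space(), validates every flux result
against a plain torch all-gather + matmul reference, and finally dumps the
generated kernel config code and the profiling log.

Example launch on a single node with 8 GPUs (the script filename below is
illustrative):

    torchrun --nproc_per_node=8 tune_ag_kernel.py --output_dir ./output

torchrun populates the RANK / LOCAL_RANK / WORLD_SIZE / LOCAL_WORLD_SIZE
environment variables read below.
"""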
import argparse
import dataclasses
import datetime
import itertools
import os
import time
from functools import partial
from typing import List, Optional

import numpy as np
import torch

import flux
from flux import pynvshmem

# Distributed environment, populated by the launcher (e.g. torchrun).
RANK = int(os.environ.get("RANK", 0))
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
NNODES = WORLD_SIZE // LOCAL_WORLD_SIZE
torch.cuda.set_device(LOCAL_RANK)

# Keep the run as deterministic as possible so the flux output can be
# compared against the torch reference with tight tolerances.
os.environ["NCCL_DEBUG"] = "ERROR"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True, warn_only=True)
torch.set_printoptions(precision=8)
torch.manual_seed(3 + RANK)
torch.cuda.manual_seed_all(3 + RANK)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
np.random.seed(3 + RANK)

torch.distributed.init_process_group(
    backend="nccl", world_size=WORLD_SIZE, rank=RANK, timeout=datetime.timedelta(seconds=1800)
)
# Use all ranks as a single tensor-parallel group.
TP_GROUP = torch.distributed.new_group(ranks=list(range(WORLD_SIZE)), backend="nccl")
print = partial(print, flush=True)


@dataclasses.dataclass
class TuningConfig:
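    """One point in the tuning space.

    M, N and K are the global GEMM shape; at runtime each rank holds an
    M // world_size slice of the input and an N // world_size slice of the
    weight.
    """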
    M: int
    N: int
    K: int
    transpose_weight: bool
    dtype: torch.dtype
    has_bias: bool


def gen_tuning_space():
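    """Enumerate the Cartesian product of shapes, weight layouts, dtypes and
    bias options that make up the tuning space."""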
    space: List[TuningConfig] = []
    # Training shapes:
    # space_M = [4096]
    # space_N = [10240, 24576, 8192, 57344, 28672]
    # space_K = [8192]
    space_M = [64, 256, 512, 1024, 2048, 4096, 8192]
    space_N = [16384, 49152]
    space_K = [12288]
    space_transpose_weight = [True, False]
    space_dtype = [torch.bfloat16, torch.float16]
    space_has_bias = [False, True]
    for M, N, K, transpose_weight, dtype, has_bias in itertools.product(
        space_M,
        space_N,
        space_K,
        space_transpose_weight,
        space_dtype,
        space_has_bias,
    ):
        config = TuningConfig(
            M=M,
            N=N,
            K=K,
            transpose_weight=transpose_weight,
            dtype=dtype,
            has_bias=has_bias,
        )
        space.append(config)
    return space


def get_torch_output(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
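    """Reference path: all-gather the input shards across TP_GROUP, then run a
    plain torch matmul. Returns this rank's [M, local_N] output slice."""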
    local_M = input.size(0)
    M = local_M * TP_GROUP.size()

    torch.distributed.barrier()
    full_input = torch.zeros(
        (M, input.size(1)),
        dtype=input.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(full_input, input, group=TP_GROUP)
    output = torch.matmul(full_input, weight.t())
    if bias is not None:
        output += bias
    torch.distributed.barrier()
    return output


def run_flux_profiling(
    prof_ctx: flux.ProfilingContext,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    config: TuningConfig,
):
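    """Construct a flux.AGKernel for this shape and profile it, recording the
    results into prof_ctx. Returns this rank's [M, local_N] output slice."""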
    local_M = input.size(0)
    M = local_M * TP_GROUP.size()
    K = input.size(1)

    if config.transpose_weight:
        # Pass the weight as [K, local_N] instead of [local_N, K].
        w = weight.t().contiguous()
        N = w.size(1)
    else:
        w = weight
        N = w.size(0)

    ag_gemm_kernel = flux.AGKernel(
        TP_GROUP,
        NNODES,
        M,
        N,
        K,
        input.dtype,
        transpose_weight=config.transpose_weight,
        local_copy=True,
    )

    output = ag_gemm_kernel.profiling(input, w, bias=bias, prof_ctx=prof_ctx)
    return output


def tune_one_config(prof_ctx: flux.ProfilingContext, config: TuningConfig):
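    """Profile a single config and check the flux output against the torch
    reference (rank by rank, with dtype-dependent tolerances)."""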
    assert config.M % TP_GROUP.size() == 0
    assert config.N % TP_GROUP.size() == 0
    local_M = config.M // TP_GROUP.size()
    local_N = config.N // TP_GROUP.size()

    # Per-rank shards: input [local_M, K], weight [local_N, K].
    # Scale by (rank + 1) so each rank contributes distinguishable values.
    input = (
        torch.rand((local_M, config.K), dtype=config.dtype).cuda() / 100 * (TP_GROUP.rank() + 1)
    )
    weight = (
        torch.rand((local_N, config.K), dtype=config.dtype).cuda() / 100 * (TP_GROUP.rank() + 1)
    )

    bias = None
    if config.has_bias:
        bias = torch.rand((config.M, local_N), dtype=input.dtype, device=input.device)

    torch_output = get_torch_output(input, weight, bias)
    torch.distributed.barrier()
    flux_output = run_flux_profiling(prof_ctx, input, weight, bias, config)
    torch.distributed.barrier()
    torch.cuda.current_stream().synchronize()

    if config.dtype == torch.bfloat16:
        atol, rtol = 0.02, 0.02
    else:
        atol, rtol = 0.01, 0.01

    # Check ranks one at a time so failure reports are not interleaved.
    for rank in range(WORLD_SIZE):
        if rank == RANK:
            flux.torch_allclose(flux_output.cpu(), torch_output.cpu(), atol=atol, rtol=rtol)
        torch.distributed.barrier()

    time.sleep(1)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir", default="", type=str, help="Directory to store generated files"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.output_dir and not os.path.isdir(args.output_dir):
        raise ValueError(f"{args.output_dir} does not exist")

    torch.cuda.set_device(LOCAL_RANK)
    flux.init_flux_shm(TP_GROUP)
    torch.cuda.synchronize()

    arch: int = flux.get_arch()
    name: str = f"config_ag_gemm_kernel_sm{arch}_tp{WORLD_SIZE}_nnodes{NNODES}"
    prof_ctx = flux.ProfilingContext(name)
    config_space = gen_tuning_space()
    for i, config in enumerate(config_space):
        if TP_GROUP.rank() == 0:
            print(f"==== #{i + 1}/{len(config_space)} Tuning for {config}")
        tune_one_config(prof_ctx=prof_ctx, config=config)
        if TP_GROUP.rank() == 0:
            print(prof_ctx.get_latest_prof_result())

    if TP_GROUP.rank() == 0:
        if os.path.isdir(args.output_dir):
            # Dump the generated kernel config code and the full profiling log.
            code_path = os.path.join(args.output_dir, f"{name}.cu")
            result_path = os.path.join(args.output_dir, f"{name}.prof.log")

            with open(code_path, "w") as fout:
                print(prof_ctx.get_code(), file=fout)

            with open(result_path, "w") as fout:
                for record in prof_ctx.get_all_prof_results():
                    print(record, file=fout)
        else:
            print("Generated Code:")
            print(prof_ctx.get_code())
            print()

            print("Profiling Results:")
            for record in prof_ctx.get_all_prof_results():
                print(record)