
Commit de19de7

Implement for bf16
1 parent 6a77a6d commit de19de7

10 files changed: +329 −262 lines changed

README.md

Lines changed: 2 additions & 2 deletions

@@ -31,12 +31,12 @@ Our tentative roadmap:
 2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
 3. [Jun 2022] Refactor to use Cutlass.
 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
-5. [Jun 2022] Support bf16.
+5. ~~[Jun 2022] Support bf16~~[Done].
 6. ~~[Jul 2022] Implement cross-attention~~[Done].
 7. ~~[Jul 2022] Support head dimension 128~~[Done].
 8. [Jul 2022] Support SM70 GPUs (V100).
 9. [Aug 2022] Fuse rotary embedding.
-10. [Aug 2022] Support Attention linear bias (e.g. ALiBi).
+10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).

 ## Speedup and Memory Savings

csrc/flash_attn/fmha_api.cpp

Lines changed: 17 additions & 13 deletions

@@ -56,11 +56,13 @@ void set_params_fprop(FMHA_fprop_params &params,
                       bool is_causal) {

     Data_type acc_type = DATA_TYPE_FP32;
-    Data_type data_type = DATA_TYPE_FP16;
+    Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;

     // Reset the parameters
     memset(&params, 0, sizeof(params));

+    params.is_bf16 = q.dtype() == torch::kBFloat16;
+
     // Set the pointers and strides.
     params.q_ptr = q.data_ptr();
     params.k_ptr = k.data_ptr();
@@ -192,9 +194,10 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);

-    TORCH_CHECK(q.dtype() == torch::kFloat16);
-    TORCH_CHECK(k.dtype() == torch::kFloat16);
-    TORCH_CHECK(v.dtype() == torch::kFloat16);
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(k.dtype() == q_dtype);
+    TORCH_CHECK(v.dtype() == q_dtype);
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);

@@ -326,14 +329,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();

-    TORCH_CHECK(q.dtype() == torch::kFloat16);
-    TORCH_CHECK(k.dtype() == torch::kFloat16);
-    TORCH_CHECK(v.dtype() == torch::kFloat16);
-    TORCH_CHECK(out.dtype() == torch::kFloat16);
-    TORCH_CHECK(dout.dtype() == torch::kFloat16);
-    TORCH_CHECK(dq.dtype() == torch::kFloat16);
-    TORCH_CHECK(dk.dtype() == torch::kFloat16);
-    TORCH_CHECK(dv.dtype() == torch::kFloat16);
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(k.dtype() == q_dtype);
+    TORCH_CHECK(v.dtype() == q_dtype);
+    TORCH_CHECK(out.dtype() == q_dtype);
+    TORCH_CHECK(dout.dtype() == q_dtype);
+    TORCH_CHECK(dq.dtype() == q_dtype);
+    TORCH_CHECK(dk.dtype() == q_dtype);
+    TORCH_CHECK(dv.dtype() == q_dtype);
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);

@@ -720,4 +724,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("bwd", &mha_bwd, "Backward pass");
     m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)");
     m.def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)");
-}
+}
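The entry points now accept bf16 as well as fp16, with bf16 gated on SM8x GPUs and all other tensors required to match q's dtype. A minimal sketch of the dtype-selection logic in isolation (the helper name and the stand-in `Data_type` enum below are illustrative, not part of the extension):

```cpp
// Hypothetical helper mirroring the dispatch in set_params_fprop: bf16 maps to
// DATA_TYPE_BF16, everything else falls through to fp16 (the TORCH_CHECKs above
// reject anything that is neither fp16 nor SM8x bf16 before this point).
#include <torch/extension.h>

enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32 };  // stand-in for the repo's enum

inline Data_type fprop_data_type(const at::Tensor &q) {
    return q.dtype() == torch::kBFloat16 ? DATA_TYPE_BF16 : DATA_TYPE_FP16;
}
```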

csrc/flash_attn/src/fmha.h

Lines changed: 1 addition & 0 deletions

@@ -123,6 +123,7 @@ struct FMHA_fprop_params : public Qkv_params {
     // Random state.
     at::PhiloxCudaState philox_args;

+    bool is_bf16;
     bool is_causal;
 };

csrc/flash_attn/src/fmha/kernel_traits.h

Lines changed: 5 additions & 1 deletion

@@ -25,11 +25,13 @@
  *
  ******************************************************************************/

+#include <cuda_fp16.h>
+
 #pragma once

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
+template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type_=__half>
 struct FMHA_kernel_traits {

     // The CTA description for the 1st GEMM.
@@ -80,6 +82,8 @@ struct FMHA_kernel_traits {
     // The shared memory tile to store dp sum.
     using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;

+    using elem_type = elem_type_;
+
     // Make sure the number of threads match.
     static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
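The new trailing template parameter `elem_type_` defaults to `__half`, so existing fp16 instantiations compile unchanged, while bf16 callers pass `__nv_bfloat16` explicitly and kernels read the choice back through the `elem_type` typedef. A self-contained sketch of the same pattern with a toy traits struct (not the repo's `FMHA_kernel_traits`):

```cpp
// Toy illustration of the defaulted element-type parameter; only the shape of
// the template and the exposed typedef mirror FMHA_kernel_traits.
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>

template<int S, int D, typename elem_type_ = __half>
struct Toy_kernel_traits {
    using elem_type = elem_type_;                           // read by the kernels
    static constexpr int BYTES_PER_ELEMENT = sizeof(elem_type);
};

// fp16 callers keep the old spelling; bf16 callers add one template argument.
static_assert(std::is_same<Toy_kernel_traits<128, 64>::elem_type, __half>::value, "");
static_assert(std::is_same<Toy_kernel_traits<128, 64, __nv_bfloat16>::elem_type, __nv_bfloat16>::value, "");
```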

Lines changed: 94 additions & 92 deletions

@@ -1,6 +1,7 @@
 /* Copyright (c) 2022, Tri Dao.
  */

+#include "static_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"

@@ -22,106 +23,107 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
     static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

     constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
-
-    bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
-    bool is_causal = params.is_causal;
-    auto kernel = is_dropout
-        ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
-        : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
-    if (params.seqlen_k == blocksize_c) {
-        kernel = is_dropout
-            ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
-            : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
-    } else if (params.seqlen_k == blocksize_c * 2) {
-        kernel = is_dropout
-            ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
-            : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
-    }
-
     // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
-    if( smem_size_dq_dk_dv >= 48 * 1024 ) {
-        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
-            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-    }
-    dim3 grid(params.b, params.h);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
-    FMHA_CHECK_CUDA(cudaPeekAtLastError());
+
+    bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
+    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+        BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
+            auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                Kernel_traits, IsDropoutConst, IsCausalConst>;
+            if (params.seqlen_k == blocksize_c) {
+                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                    Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
+            } else if (params.seqlen_k == blocksize_c * 2) {
+                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                    Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
+            }
+            if( smem_size_dq_dk_dv >= 48 * 1024 ) {
+                FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+            }
+            dim3 grid(params.b, params.h);
+            kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
+            FMHA_CHECK_CUDA(cudaPeekAtLastError());
+        });
+    });
 }

 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
-    if (params.d == 16) {
-        if( params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.seqlen_k == 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else {
-            // TD [2022-05-15] 512 gives wrong results rn
-            // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u>;
-            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        }
-    } else if (params.d == 32) {
-        if( params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.seqlen_k >= 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        }
-    } else if (params.d == 64) {
-        if( params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.seqlen_k >= 256 ) {
-            auto dprops = at::cuda::getCurrentDeviceProperties();
-            if (dprops->major == 8 && dprops->minor == 0) {
-                // Don't share smem for K & V, and don't keep V in registers
-                // This speeds things up by 2-3% by avoiding register spills, but it
-                // uses more shared memory, which is fine on A100 but not other GPUs.
-                // For other GPUs, we keep V in registers.
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
+    BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
+        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        if (params.d == 16) {
+            if( params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if( params.seqlen_k == 256 ) {
+                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-            } else if (dprops->major == 8 && dprops->minor > 0) {
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
+            } else {
+                // TD [2022-05-15] 512 gives wrong results rn
+                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
+                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-            } else if (dprops->major == 7 && dprops->minor == 5) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+            }
+        } else if (params.d == 32) {
+            if( params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if( params.seqlen_k >= 256 ) {
+                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
             }
+        } else if (params.d == 64) {
+            if( params.seqlen_k == 128 ) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if( params.seqlen_k >= 256 ) {
+                if (dprops->major == 8 && dprops->minor == 0) {
+                    // Don't share smem for K & V, and don't keep V in registers
+                    // This speeds things up by 2-3% by avoiding register spills, but it
+                    // uses more shared memory, which is fine on A100 but not other GPUs.
+                    // For other GPUs, we keep V in registers.
+                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                } else if (dprops->major == 8 && dprops->minor > 0) {
+                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                } else if (dprops->major == 7 && dprops->minor == 5) {
+                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                }
+            }
+        } else if (params.d == 128) {
+            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
+            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         }
-    } else if (params.d == 128) {
-        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
-        run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    }
-    // if (params.d == 64) {
-    //     auto dprops = at::cuda::getCurrentDeviceProperties();
-    //     if (dprops->major == 7 && dprops->minor == 5) {
-    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-    //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //     } else {
-    //         if( params.seqlen_k == 128 ) {
-    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-    //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //         } else if( params.seqlen_k >= 256 ) {
-    //             if (dprops->major == 8 && dprops->minor == 0) {
-    //                 // Don't share smem for K & V, and don't keep V in registers
-    //                 // This speeds things up by 2-3% by avoiding register spills, but it
-    //                 // uses more shared memory, which is fine on A100 but not other GPUs.
-    //                 // For other GPUs, we keep V in registers.
-    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
-    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //             } else if (dprops->major == 8 && dprops->minor > 0) {
-    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
-    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //             }
-    //         }
-    //     }
-    // }
-    // if (params.d == 128) {
-    //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
-    //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    // }
+        // if (params.d == 64) {
+        //     if (dprops->major == 7 && dprops->minor == 5) {
+        //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+        //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        //     } else {
+        //         if( params.seqlen_k == 128 ) {
+        //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+        //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        //         } else if( params.seqlen_k >= 256 ) {
+        //             if (dprops->major == 8 && dprops->minor == 0) {
+        //                 // Don't share smem for K & V, and don't keep V in registers
+        //                 // This speeds things up by 2-3% by avoiding register spills, but it
+        //                 // uses more shared memory, which is fine on A100 but not other GPUs.
+        //                 // For other GPUs, we keep V in registers.
+        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
+        //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        //             } else if (dprops->major == 8 && dprops->minor > 0) {
+        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
+        //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        //             }
+        //         }
+        //     }
+        // }
+        // if (params.d == 128) {
+        //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u_elem_type>;
+        //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+        // }
+    });
 }
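The nested ternaries over `is_dropout`/`is_causal` (and now `is_bf16`) are replaced by `BOOL_SWITCH` from the newly included `static_switch.h`, which maps a runtime bool onto a `constexpr` name inside a lambda so it can be used as a template argument. A minimal sketch of what such a macro typically looks like; the repo's actual `static_switch.h` may differ in details:

```cpp
// Sketch of a BOOL_SWITCH-style macro (assumed shape, not copied from the repo):
// evaluates COND at runtime and invokes the lambda body with CONST_NAME bound to
// a compile-time true or false, so both kernel variants get instantiated.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr bool CONST_NAME = true;         \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr bool CONST_NAME = false;        \
            return __VA_ARGS__();                     \
        }                                             \
    }()
```

Nesting two of these in `run_fmha_dgrad_fp16_sm80_loop_` (dropout, causal) and a third over `params.is_bf16` in `run_fmha_dgrad_fp16_sm80` instantiates every kernel combination from a single code path instead of spelling out each ternary by hand.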
