/* Copyright (c) 2022, Tri Dao.
 */

+ #include "static_switch.h"
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"

@@ -22,106 +23,107 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
    static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
-
-     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
-     bool is_causal = params.is_causal;
-     auto kernel = is_dropout
-         ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
-         : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
-     if (params.seqlen_k == blocksize_c) {
-         kernel = is_dropout
-             ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
-             : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
-     } else if (params.seqlen_k == blocksize_c * 2) {
-         kernel = is_dropout
-             ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
-             : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
-     }
-
    // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
-     if( smem_size_dq_dk_dv >= 48 * 1024 ) {
-         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
-             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-     }
-     dim3 grid(params.b, params.h);
-     kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
-     FMHA_CHECK_CUDA(cudaPeekAtLastError());
+
+     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
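+     // BOOL_SWITCH (from static_switch.h) lifts a runtime bool into a compile-time constant,
+     // roughly: if (cond) { constexpr bool NAME = true; body(); } else { constexpr bool NAME = false; body(); }.
+     // Here it instantiates the kernel template for every (IsDropoutConst, IsCausalConst) combination
+     // without hand-writing the four-way ternary dispatch that this replaces.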
+     BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+         BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
+             auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                 Kernel_traits, IsDropoutConst, IsCausalConst>;
+             if (params.seqlen_k == blocksize_c) {
+                 kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                     Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
+             } else if (params.seqlen_k == blocksize_c * 2) {
+                 kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
+                     Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
+             }
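+             // CUDA caps dynamic shared memory at 48 KB per block unless the kernel opts in
+             // via cudaFuncAttributeMaxDynamicSharedMemorySize, so larger tiles need this call.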
+             if( smem_size_dq_dk_dv >= 48 * 1024 ) {
+                 FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+             }
+             dim3 grid(params.b, params.h);
+             kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
+             FMHA_CHECK_CUDA(cudaPeekAtLastError());
+         });
+     });
}

void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
-     if (params.d == 16) {
-         if( params.seqlen_k == 128 ) {
-             using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u>;
-             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-         } else if( params.seqlen_k == 256 ) {
-             using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
-             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-         } else {
-             // TD [2022-05-15] 512 gives wrong results rn
-             // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u>;
-             using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
-             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-         }
-     } else if (params.d == 32) {
-         if( params.seqlen_k == 128 ) {
-             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u>;
-             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-         } else if( params.seqlen_k >= 256 ) {
-             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
-             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-         }
-     } else if (params.d == 64) {
-         if( params.seqlen_k == 128 ) {
-             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-         } else if( params.seqlen_k >= 256 ) {
-             auto dprops = at::cuda::getCurrentDeviceProperties();
-             if (dprops->major == 8 && dprops->minor == 0) {
-                 // Don't share smem for K & V, and don't keep V in registers
-                 // This speeds things up by 2-3% by avoiding register spills, but it
-                 // uses more shared memory, which is fine on A100 but not other GPUs.
-                 // For other GPUs, we keep V in registers.
-                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
+     BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
+         using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
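+         // elem_type (__half or __nv_bfloat16) is forwarded as the last FMHA_kernel_traits
+         // template argument below, so one launcher covers both fp16 and bf16 inputs.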
+         auto dprops = at::cuda::getCurrentDeviceProperties();
+         if (params.d == 16) {
+             if( params.seqlen_k == 128 ) {
+                 using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
+                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+             } else if( params.seqlen_k == 256 ) {
+                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-             } else if (dprops->major == 8 && dprops->minor > 0) {
-                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
+             } else {
+                 // TD [2022-05-15] 512 gives wrong results rn
+                 // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
+                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-             } else if (dprops->major == 7 && dprops->minor == 5) {
-                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+             }
+         } else if (params.d == 32) {
+             if( params.seqlen_k == 128 ) {
+                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
+                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+             } else if( params.seqlen_k >= 256 ) {
+                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
            }
+         } else if (params.d == 64) {
+             if( params.seqlen_k == 128 ) {
+                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+             } else if( params.seqlen_k >= 256 ) {
+                 if (dprops->major == 8 && dprops->minor == 0) {
+                     // Don't share smem for K & V, and don't keep V in registers
+                     // This speeds things up by 2-3% by avoiding register spills, but it
+                     // uses more shared memory, which is fine on A100 but not other GPUs.
+                     // For other GPUs, we keep V in registers.
+                     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
+                     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                 } else if (dprops->major == 8 && dprops->minor > 0) {
+                     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
+                     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                 } else if (dprops->major == 7 && dprops->minor == 5) {
+                     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+                     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                 }
+             }
+         } else if (params.d == 128) {
+             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
+             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        }
-     } else if (params.d == 128) {
-         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
-         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-     }
-     // if (params.d == 64) {
-     //     auto dprops = at::cuda::getCurrentDeviceProperties();
-     //     if (dprops->major == 7 && dprops->minor == 5) {
-     //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-     //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-     //     } else {
-     //         if( params.seqlen_k == 128 ) {
-     //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
-     //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-     //         } else if( params.seqlen_k >= 256 ) {
-     //             if (dprops->major == 8 && dprops->minor == 0) {
-     //                 // Don't share smem for K & V, and don't keep V in registers
-     //                 // This speeds things up by 2-3% by avoiding register spills, but it
-     //                 // uses more shared memory, which is fine on A100 but not other GPUs.
-     //                 // For other GPUs, we keep V in registers.
-     //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
-     //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-     //             } else if (dprops->major == 8 && dprops->minor > 0) {
-     //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
-     //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-     //             }
-     //         }
-     //     }
-     // }
-     // if (params.d == 128) {
-     //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
-     //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-     // }
+         // if (params.d == 64) {
+         //     if (dprops->major == 7 && dprops->minor == 5) {
+         //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+         //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+         //     } else {
+         //         if( params.seqlen_k == 128 ) {
+         //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+         //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+         //         } else if( params.seqlen_k >= 256 ) {
+         //             if (dprops->major == 8 && dprops->minor == 0) {
+         //                 // Don't share smem for K & V, and don't keep V in registers
+         //                 // This speeds things up by 2-3% by avoiding register spills, but it
+         //                 // uses more shared memory, which is fine on A100 but not other GPUs.
+         //                 // For other GPUs, we keep V in registers.
+         //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
+         //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+         //             } else if (dprops->major == 8 && dprops->minor > 0) {
+         //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
+         //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+         //             }
+         //         }
+         //     }
+         // }
+         // if (params.d == 128) {
+         //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
+         //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+         // }
+     });
}