1
1
import functools
2
+ from enum import Enum
2
3
from typing import Optional
3
4
4
5
import torch
5
6
6
7
from ..jit import get_cudnn_fmha_gen_module
7
8
9
+ try :
10
+ import cudnn
11
+
12
+ CUDNN_AVAILABLE = True
13
+ except ImportError :
14
+ cudnn = None
15
+ CUDNN_AVAILABLE = False
16
+
17
+ # Global cudnn handle. need to make it per device in future
18
+ _cudnn_handle = None
19
+
20
+
21
+ def _create_cudnn_handle (stream : torch .cuda .Stream ):
22
+ global _cudnn_handle
23
+ if _cudnn_handle is None :
24
+ _cudnn_handle = cudnn .create_handle ()
25
+ cudnn .set_stream (_cudnn_handle , stream .cuda_stream )
26
+ return _cudnn_handle
27
+
28
+
29
+ # Tensor ids
30
+ class UIDs (Enum ):
31
+ RESERVED_INVALID_UID = 0
32
+
33
+ Q_UID = 1 # Query tensor
34
+ K_UID = 2 # Key cache tensor
35
+ V_UID = 3 # Value cache tensor
36
+
37
+ ACTUAL_SEQ_LENS_Q_UID = 100 # Actual sequence lengths for query tensor
38
+ ACTUAL_SEQ_LENS_KV_UID = 101 # Actual sequence lengths for key/value tensor
39
+
40
+ BLOCK_TABLES_UID = 200 # Block tables tensor
41
+ BLOCK_TABLES_K_UID = 201 # Block tables tensor for key
42
+ BLOCK_TABLES_V_UID = 202 # Block tables tensor for value
43
+
44
+ RAGGED_Q_UID = 50 # Ragged query tensor
45
+ RAGGED_O_UID = 51 # Ragged output tensor
46
+ RAGGED_STATS_UID = 52 # Ragged stats tensor
47
+
48
+ O_UID = 1000 # Output tensor
49
+ STATS_UID = 1001 # Stats tensor
50
+
51
+
52
+ def _sdpa_decode_key_fn (
53
+ q : torch .Tensor ,
54
+ k_cache : torch .Tensor ,
55
+ v_cache : torch .Tensor ,
56
+ scale : float ,
57
+ * ,
58
+ max_sequence_kv : int ,
59
+ block_size : Optional [int ] = 1 ,
60
+ actual_seq_lens_q : Optional [torch .Tensor ] = None ,
61
+ actual_seq_lens_kv : Optional [torch .Tensor ] = None ,
62
+ block_tables : Optional [torch .Tensor ] = None ,
63
+ batch_offsets_q : Optional [torch .Tensor ] = None ,
64
+ batch_offsets_o : Optional [torch .Tensor ] = None ,
65
+ ):
66
+ return (
67
+ "decode" ,
68
+ max_sequence_kv ,
69
+ tuple (q .shape ),
70
+ tuple (k_cache .shape ),
71
+ )
72
+
73
+
74
+ @cudnn .jit (heur_modes = [cudnn .heur_mode .A ])
75
+ @cudnn .graph_cache (key_fn = _sdpa_decode_key_fn )
76
+ def _build_decode_graph (
77
+ q : torch .Tensor ,
78
+ k_cache : torch .Tensor ,
79
+ v_cache : torch .Tensor ,
80
+ scale : float ,
81
+ * ,
82
+ max_sequence_kv : int ,
83
+ block_size : Optional [int ] = 1 ,
84
+ actual_seq_lens_q : Optional [torch .Tensor ] = None ,
85
+ actual_seq_lens_kv : Optional [torch .Tensor ] = None ,
86
+ block_tables : Optional [torch .Tensor ] = None ,
87
+ batch_offsets_q : Optional [torch .Tensor ] = None ,
88
+ batch_offsets_o : Optional [torch .Tensor ] = None ,
89
+ ):
90
+ handle = _create_cudnn_handle (torch .cuda .current_stream ())
91
+
92
+ with cudnn .graph (handle ) as (g , _ ):
93
+
94
+ if q .dim () == 3 :
95
+ s_qo = 1
96
+ b , h_qo , d_qk = q .shape [0 ], q .shape [1 ], q .shape [2 ]
97
+ elif q .dim () == 4 :
98
+ b , h_qo , s_qo , d_qk = (
99
+ q .shape [0 ],
100
+ q .shape [1 ],
101
+ q .shape [2 ],
102
+ q .shape [3 ],
103
+ )
104
+ else :
105
+ raise ValueError (f"q must have 3 or 4 dimensions, got { q .dim ()} " )
106
+
107
+ assert s_qo == 1 , "q must have a sequence length of 1"
108
+ assert k_cache .dim () == 4 , "k_cache must have 4 dimensions"
109
+
110
+ h_kv = k_cache .shape [1 ]
111
+ s_kv = max_sequence_kv
112
+ d_vo = v_cache .shape [3 ]
113
+
114
+ cudnn_q = g .tensor (
115
+ name = "q" ,
116
+ dim = (b , h_qo , s_qo , d_qk ),
117
+ stride = (h_qo * d_qk , d_qk , d_qk * h_qo , 1 ),
118
+ data_type = cudnn .data_type .BFLOAT16 ,
119
+ )
120
+ if batch_offsets_q is not None :
121
+ ragged_q = g .tensor_like (batch_offsets_q )
122
+ ragged_q .set_uid (UIDs .RAGGED_Q_UID .value )
123
+ cudnn_q .set_ragged_offset (ragged_q )
124
+
125
+ cudnn_k_cache = g .tensor_like (k_cache )
126
+ cudnn_v_cache = g .tensor_like (v_cache )
127
+
128
+ cudnn_q .set_uid (UIDs .Q_UID .value )
129
+ cudnn_k_cache .set_uid (UIDs .K_UID .value )
130
+ cudnn_v_cache .set_uid (UIDs .V_UID .value )
131
+
132
+ if block_tables is not None :
133
+ nd_block_tables = block_tables .reshape (
134
+ block_tables .shape [0 ], 1 , block_tables .shape [1 ], 1
135
+ )
136
+ cudnn_k_block_tables = g .tensor_like (nd_block_tables )
137
+ cudnn_k_block_tables .set_uid (UIDs .BLOCK_TABLES_K_UID .value )
138
+
139
+ cudnn_v_block_tables = g .tensor_like (nd_block_tables )
140
+ cudnn_v_block_tables .set_uid (UIDs .BLOCK_TABLES_V_UID .value )
141
+
142
+ if actual_seq_lens_q is not None :
143
+ cudnn_actual_seq_lens_q = g .tensor_like (actual_seq_lens_q )
144
+ cudnn_actual_seq_lens_q .set_uid (UIDs .ACTUAL_SEQ_LENS_Q_UID .value )
145
+
146
+ if actual_seq_lens_kv is not None :
147
+ cudnn_actual_seq_lens_kv = g .tensor_like (actual_seq_lens_kv )
148
+ cudnn_actual_seq_lens_kv .set_uid (UIDs .ACTUAL_SEQ_LENS_KV_UID .value )
149
+ cudnn_actual_seq_lens_kv .set_is_pass_by_value (False )
150
+
151
+ padding_mask = actual_seq_lens_kv is not None
152
+
153
+ O , _ = g .sdpa (
154
+ name = "sdpa" ,
155
+ q = cudnn_q ,
156
+ k = cudnn_k_cache ,
157
+ v = cudnn_v_cache ,
158
+ seq_len_q = (
159
+ cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
160
+ ),
161
+ seq_len_kv = (
162
+ cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
163
+ ),
164
+ use_padding_mask = padding_mask ,
165
+ is_inference = True ,
166
+ attn_scale = scale ,
167
+ paged_attention_k_table = cudnn_k_block_tables ,
168
+ paged_attention_v_table = cudnn_v_block_tables ,
169
+ paged_attention_max_seq_len_kv = max_sequence_kv ,
170
+ compute_data_type = cudnn .data_type .FLOAT ,
171
+ )
172
+
173
+ if batch_offsets_o is not None :
174
+ ragged_o = g .tensor_like (batch_offsets_o )
175
+ ragged_o .set_uid (UIDs .RAGGED_O_UID .value )
176
+ O .set_ragged_offset (ragged_o )
177
+
178
+ O .set_uid (UIDs .O_UID .value ).set_output (True ).set_dim (
179
+ [b , h_qo , s_qo , d_vo ]
180
+ ).set_stride ([d_vo * h_qo , d_vo , d_vo * h_qo , 1 ]).set_data_type (
181
+ cudnn .data_type .BFLOAT16
182
+ )
183
+
184
+ tensors_to_return = [cudnn_q , cudnn_k_cache , cudnn_v_cache , O ]
185
+
186
+ if actual_seq_lens_q is not None :
187
+ tensors_to_return .append (cudnn_actual_seq_lens_q )
188
+ if actual_seq_lens_kv is not None :
189
+ tensors_to_return .append (cudnn_actual_seq_lens_kv )
190
+
191
+ return g , tensors_to_return
192
+
193
+
194
+ def _batch_decode_with_kv_cache (
195
+ q : torch .Tensor ,
196
+ k_cache : torch .Tensor ,
197
+ v_cache : torch .Tensor ,
198
+ scale : float ,
199
+ workspace_buffer : torch .Tensor ,
200
+ * ,
201
+ max_sequence_kv : int ,
202
+ actual_seq_lens_q : Optional [torch .Tensor ] = None ,
203
+ actual_seq_lens_kv : Optional [torch .Tensor ] = None ,
204
+ block_tables : Optional [torch .Tensor ] = None ,
205
+ block_size : Optional [int ] = 1 ,
206
+ batch_offsets_q : Optional [torch .Tensor ] = None ,
207
+ batch_offsets_o : Optional [torch .Tensor ] = None ,
208
+ batch_offsets_k : Optional [torch .Tensor ] = None ,
209
+ batch_offsets_v : Optional [torch .Tensor ] = None ,
210
+ out : torch .Tensor ,
211
+ ) -> torch .Tensor :
212
+
213
+ graph , tensors = _build_decode_graph (
214
+ q = q ,
215
+ k_cache = k_cache ,
216
+ v_cache = v_cache ,
217
+ scale = scale ,
218
+ max_sequence_kv = max_sequence_kv ,
219
+ actual_seq_lens_q = actual_seq_lens_q ,
220
+ actual_seq_lens_kv = actual_seq_lens_kv ,
221
+ block_tables = block_tables ,
222
+ block_size = block_size ,
223
+ batch_offsets_q = batch_offsets_q if batch_offsets_q is not None else None ,
224
+ batch_offsets_o = batch_offsets_q if batch_offsets_q is not None else None ,
225
+ )
226
+
227
+ var_map = {
228
+ UIDs .Q_UID .value : q ,
229
+ UIDs .K_UID .value : k_cache ,
230
+ UIDs .V_UID .value : v_cache ,
231
+ UIDs .O_UID .value : out ,
232
+ }
233
+ if actual_seq_lens_q is not None :
234
+ var_map [UIDs .ACTUAL_SEQ_LENS_Q_UID .value ] = actual_seq_lens_q
235
+ if actual_seq_lens_kv is not None :
236
+ var_map [UIDs .ACTUAL_SEQ_LENS_KV_UID .value ] = actual_seq_lens_kv
237
+
238
+ if batch_offsets_q is not None :
239
+ var_map [UIDs .RAGGED_Q_UID .value ] = batch_offsets_q
240
+ if batch_offsets_o is not None :
241
+ var_map [UIDs .RAGGED_O_UID .value ] = batch_offsets_o
242
+
243
+ if block_tables is not None :
244
+ var_map [UIDs .BLOCK_TABLES_K_UID .value ] = block_tables
245
+ var_map [UIDs .BLOCK_TABLES_V_UID .value ] = block_tables
246
+
247
+ graph .execute (var_map , workspace = workspace_buffer )
248
+
249
+ return out
250
+
8
251
9
252
def cudnn_batch_decode_with_kv_cache (
10
253
q : torch .Tensor ,
@@ -37,7 +280,6 @@ def cudnn_batch_decode_with_kv_cache(
37
280
is_cuda_graph_compatible: Whether the decode operation is compatible with CUDA graph
38
281
batch_offsets: Optional batch offsets tensor of shape (batch_size,) on GPU
39
282
out: Optional pre-allocated output tensor
40
- lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None
41
283
batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
42
284
batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
43
285
batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -53,30 +295,51 @@ def cudnn_batch_decode_with_kv_cache(
53
295
"""
54
296
55
297
bs = q .shape [0 ]
56
- s_q = 1
57
298
h_qo = q .shape [1 ]
58
299
d_vo = v_cache .shape [3 ]
59
300
60
301
if out is None :
61
302
out = torch .empty (bs , h_qo , d_vo , device = q .device , dtype = q .dtype )
62
303
63
- actual_seq_lens_kv_gpu = actual_seq_lens_kv .to (q .device , non_blocking = True )
304
+ if not CUDNN_AVAILABLE :
305
+ actual_seq_lens_kv_gpu = actual_seq_lens_kv .to (q .device , non_blocking = True )
64
306
65
- run_func = get_cudnn_fmha_gen_module ().decode
66
- run_func (
67
- max_sequence_kv ,
68
- q ,
69
- k_cache ,
70
- v_cache ,
71
- scale ,
72
- workspace_buffer ,
73
- actual_seq_lens_kv ,
74
- actual_seq_lens_kv_gpu ,
75
- block_tables ,
76
- out ,
77
- batch_offsets_q ,
78
- batch_offsets_o ,
79
- is_cuda_graph_compatible ,
80
- )
307
+ run_func = get_cudnn_fmha_gen_module ().decode
308
+ run_func (
309
+ max_sequence_kv ,
310
+ q ,
311
+ k_cache ,
312
+ v_cache ,
313
+ scale ,
314
+ workspace_buffer ,
315
+ actual_seq_lens_kv ,
316
+ actual_seq_lens_kv_gpu ,
317
+ block_tables ,
318
+ out ,
319
+ batch_offsets_q ,
320
+ batch_offsets_o ,
321
+ is_cuda_graph_compatible ,
322
+ )
323
+ else :
324
+ actual_seq_lens_q = torch .ones (
325
+ (bs , 1 , 1 , 1 ), device = q .device , dtype = torch .int32
326
+ )
327
+ block_size = k_cache .shape [2 ]
328
+
329
+ _batch_decode_with_kv_cache (
330
+ q = q ,
331
+ k_cache = k_cache ,
332
+ v_cache = v_cache ,
333
+ scale = scale ,
334
+ workspace_buffer = workspace_buffer ,
335
+ max_sequence_kv = max_sequence_kv ,
336
+ actual_seq_lens_q = actual_seq_lens_q ,
337
+ actual_seq_lens_kv = actual_seq_lens_kv ,
338
+ block_tables = block_tables ,
339
+ batch_offsets_q = batch_offsets_q ,
340
+ batch_offsets_o = batch_offsets_o ,
341
+ block_size = block_size ,
342
+ out = out ,
343
+ )
81
344
82
345
return out
0 commit comments