
Commit 15de811

Initial commit to add native cudnn_decode
Fix the code review comments
1 parent 4e9da5d commit 15de811

File tree

2 files changed: +295 −26 lines


flashinfer/cudnn/decode.py

Lines changed: 282 additions & 19 deletions
@@ -1,10 +1,253 @@
 import functools
+from enum import Enum
 from typing import Optional
 
 import torch
 
 from ..jit import get_cudnn_fmha_gen_module
 
+try:
+    import cudnn
+
+    CUDNN_AVAILABLE = True
+except ImportError:
+    cudnn = None
+    CUDNN_AVAILABLE = False
+
+# Global cudnn handle. need to make it per device in future
+_cudnn_handle = None
+
+
+def _create_cudnn_handle(stream: torch.cuda.Stream):
+    global _cudnn_handle
+    if _cudnn_handle is None:
+        _cudnn_handle = cudnn.create_handle()
+    cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
+    return _cudnn_handle
+
+
+# Tensor ids
+class UIDs(Enum):
+    RESERVED_INVALID_UID = 0
+
+    Q_UID = 1  # Query tensor
+    K_UID = 2  # Key cache tensor
+    V_UID = 3  # Value cache tensor
+
+    ACTUAL_SEQ_LENS_Q_UID = 100  # Actual sequence lengths for query tensor
+    ACTUAL_SEQ_LENS_KV_UID = 101  # Actual sequence lengths for key/value tensor
+
+    BLOCK_TABLES_UID = 200  # Block tables tensor
+    BLOCK_TABLES_K_UID = 201  # Block tables tensor for key
+    BLOCK_TABLES_V_UID = 202  # Block tables tensor for value
+
+    RAGGED_Q_UID = 50  # Ragged query tensor
+    RAGGED_O_UID = 51  # Ragged output tensor
+    RAGGED_STATS_UID = 52  # Ragged stats tensor
+
+    O_UID = 1000  # Output tensor
+    STATS_UID = 1001  # Stats tensor
+
+
+def _sdpa_decode_key_fn(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    scale: float,
+    *,
+    max_sequence_kv: int,
+    block_size: Optional[int] = 1,
+    actual_seq_lens_q: Optional[torch.Tensor] = None,
+    actual_seq_lens_kv: Optional[torch.Tensor] = None,
+    block_tables: Optional[torch.Tensor] = None,
+    batch_offsets_q: Optional[torch.Tensor] = None,
+    batch_offsets_o: Optional[torch.Tensor] = None,
+):
+    return (
+        "decode",
+        max_sequence_kv,
+        tuple(q.shape),
+        tuple(k_cache.shape),
+    )
+
+
+@cudnn.jit(heur_modes=[cudnn.heur_mode.A])
+@cudnn.graph_cache(key_fn=_sdpa_decode_key_fn)
+def _build_decode_graph(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    scale: float,
+    *,
+    max_sequence_kv: int,
+    block_size: Optional[int] = 1,
+    actual_seq_lens_q: Optional[torch.Tensor] = None,
+    actual_seq_lens_kv: Optional[torch.Tensor] = None,
+    block_tables: Optional[torch.Tensor] = None,
+    batch_offsets_q: Optional[torch.Tensor] = None,
+    batch_offsets_o: Optional[torch.Tensor] = None,
+):
+    handle = _create_cudnn_handle(torch.cuda.current_stream())
+
+    with cudnn.graph(handle) as (g, _):
+
+        if q.dim() == 3:
+            s_qo = 1
+            b, h_qo, d_qk = q.shape[0], q.shape[1], q.shape[2]
+        elif q.dim() == 4:
+            b, h_qo, s_qo, d_qk = (
+                q.shape[0],
+                q.shape[1],
+                q.shape[2],
+                q.shape[3],
+            )
+        else:
+            raise ValueError(f"q must have 3 or 4 dimensions, got {q.dim()}")
+
+        assert s_qo == 1, "q must have a sequence length of 1"
+        assert k_cache.dim() == 4, "k_cache must have 4 dimensions"
+
+        h_kv = k_cache.shape[1]
+        s_kv = max_sequence_kv
+        d_vo = v_cache.shape[3]
+
+        cudnn_q = g.tensor(
+            name="q",
+            dim=(b, h_qo, s_qo, d_qk),
+            stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
+            data_type=cudnn.data_type.BFLOAT16,
+        )
+        if batch_offsets_q is not None:
+            ragged_q = g.tensor_like(batch_offsets_q)
+            ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
+            cudnn_q.set_ragged_offset(ragged_q)
+
+        cudnn_k_cache = g.tensor_like(k_cache)
+        cudnn_v_cache = g.tensor_like(v_cache)
+
+        cudnn_q.set_uid(UIDs.Q_UID.value)
+        cudnn_k_cache.set_uid(UIDs.K_UID.value)
+        cudnn_v_cache.set_uid(UIDs.V_UID.value)
+
+        if block_tables is not None:
+            nd_block_tables = block_tables.reshape(
+                block_tables.shape[0], 1, block_tables.shape[1], 1
+            )
+            cudnn_k_block_tables = g.tensor_like(nd_block_tables)
+            cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value)
+
+            cudnn_v_block_tables = g.tensor_like(nd_block_tables)
+            cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value)
+
+        if actual_seq_lens_q is not None:
+            cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q)
+            cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value)
+
+        if actual_seq_lens_kv is not None:
+            cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv)
+            cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value)
+            cudnn_actual_seq_lens_kv.set_is_pass_by_value(False)
+
+        padding_mask = actual_seq_lens_kv is not None
+
+        O, _ = g.sdpa(
+            name="sdpa",
+            q=cudnn_q,
+            k=cudnn_k_cache,
+            v=cudnn_v_cache,
+            seq_len_q=(
+                cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
+            ),
+            seq_len_kv=(
+                cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
+            ),
+            use_padding_mask=padding_mask,
+            is_inference=True,
+            attn_scale=scale,
+            paged_attention_k_table=cudnn_k_block_tables,
+            paged_attention_v_table=cudnn_v_block_tables,
+            paged_attention_max_seq_len_kv=max_sequence_kv,
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+
+        if batch_offsets_o is not None:
+            ragged_o = g.tensor_like(batch_offsets_o)
+            ragged_o.set_uid(UIDs.RAGGED_O_UID.value)
+            O.set_ragged_offset(ragged_o)
+
+        O.set_uid(UIDs.O_UID.value).set_output(True).set_dim(
+            [b, h_qo, s_qo, d_vo]
+        ).set_stride([d_vo * h_qo, d_vo, d_vo * h_qo, 1]).set_data_type(
+            cudnn.data_type.BFLOAT16
+        )
+
+        tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O]
+
+        if actual_seq_lens_q is not None:
+            tensors_to_return.append(cudnn_actual_seq_lens_q)
+        if actual_seq_lens_kv is not None:
+            tensors_to_return.append(cudnn_actual_seq_lens_kv)
+
+        return g, tensors_to_return
+
+
+def _batch_decode_with_kv_cache(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    *,
+    max_sequence_kv: int,
+    actual_seq_lens_q: Optional[torch.Tensor] = None,
+    actual_seq_lens_kv: Optional[torch.Tensor] = None,
+    block_tables: Optional[torch.Tensor] = None,
+    block_size: Optional[int] = 1,
+    batch_offsets_q: Optional[torch.Tensor] = None,
+    batch_offsets_o: Optional[torch.Tensor] = None,
+    batch_offsets_k: Optional[torch.Tensor] = None,
+    batch_offsets_v: Optional[torch.Tensor] = None,
+    out: torch.Tensor,
+) -> torch.Tensor:
+
+    graph, tensors = _build_decode_graph(
+        q=q,
+        k_cache=k_cache,
+        v_cache=v_cache,
+        scale=scale,
+        max_sequence_kv=max_sequence_kv,
+        actual_seq_lens_q=actual_seq_lens_q,
+        actual_seq_lens_kv=actual_seq_lens_kv,
+        block_tables=block_tables,
+        block_size=block_size,
+        batch_offsets_q=batch_offsets_q if batch_offsets_q is not None else None,
+        batch_offsets_o=batch_offsets_o if batch_offsets_o is not None else None,
+    )
+
+    var_map = {
+        UIDs.Q_UID.value: q,
+        UIDs.K_UID.value: k_cache,
+        UIDs.V_UID.value: v_cache,
+        UIDs.O_UID.value: out,
+    }
+    if actual_seq_lens_q is not None:
+        var_map[UIDs.ACTUAL_SEQ_LENS_Q_UID.value] = actual_seq_lens_q
+    if actual_seq_lens_kv is not None:
+        var_map[UIDs.ACTUAL_SEQ_LENS_KV_UID.value] = actual_seq_lens_kv
+
+    if batch_offsets_q is not None:
+        var_map[UIDs.RAGGED_Q_UID.value] = batch_offsets_q
+    if batch_offsets_o is not None:
+        var_map[UIDs.RAGGED_O_UID.value] = batch_offsets_o
+
+    if block_tables is not None:
+        var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables
+        var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables
+
+    graph.execute(var_map, workspace=workspace_buffer)
+
+    return out
+
 
 def cudnn_batch_decode_with_kv_cache(
     q: torch.Tensor,
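Note on the graph cache in the hunk above: `_sdpa_decode_key_fn` keys cached graphs only on the op name, `max_sequence_kv`, and the shapes of `q` and `k_cache`, so calls that differ only in actual sequence lengths, block tables, or ragged offsets reuse the same compiled graph. A minimal standalone sketch of that key behavior (the helper and tensor shapes below are illustrative, not part of the diff):

import torch

def sdpa_decode_key(q, k_cache, max_sequence_kv):
    # Mirrors _sdpa_decode_key_fn: only shapes and the max KV length matter.
    return ("decode", max_sequence_kv, tuple(q.shape), tuple(k_cache.shape))

q1 = torch.empty(8, 32, 128)             # (batch, q_heads, head_dim)
q2 = torch.empty(8, 32, 128)
k_cache = torch.empty(512, 8, 16, 128)   # (num_pages, kv_heads, page_size, head_dim)

# Same shapes and same max KV length -> same cache key, so the decorated
# _build_decode_graph would hit its cached graph on the second call.
assert sdpa_decode_key(q1, k_cache, 1024) == sdpa_decode_key(q2, k_cache, 1024)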
@@ -37,7 +280,6 @@ def cudnn_batch_decode_with_kv_cache(
         is_cuda_graph_compatible: Whether the decode operation is compatible with CUDA graph
         batch_offsets: Optional batch offsets tensor of shape (batch_size,) on GPU
         out: Optional pre-allocated output tensor
-        lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None
         batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
         batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
         batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -53,30 +295,51 @@ def cudnn_batch_decode_with_kv_cache(
     """
 
     bs = q.shape[0]
-    s_q = 1
     h_qo = q.shape[1]
     d_vo = v_cache.shape[3]
 
     if out is None:
         out = torch.empty(bs, h_qo, d_vo, device=q.device, dtype=q.dtype)
 
-    actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.device, non_blocking=True)
+    if not CUDNN_AVAILABLE:
+        actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.device, non_blocking=True)
 
-    run_func = get_cudnn_fmha_gen_module().decode
-    run_func(
-        max_sequence_kv,
-        q,
-        k_cache,
-        v_cache,
-        scale,
-        workspace_buffer,
-        actual_seq_lens_kv,
-        actual_seq_lens_kv_gpu,
-        block_tables,
-        out,
-        batch_offsets_q,
-        batch_offsets_o,
-        is_cuda_graph_compatible,
-    )
+        run_func = get_cudnn_fmha_gen_module().decode
+        run_func(
+            max_sequence_kv,
+            q,
+            k_cache,
+            v_cache,
+            scale,
+            workspace_buffer,
+            actual_seq_lens_kv,
+            actual_seq_lens_kv_gpu,
+            block_tables,
+            out,
+            batch_offsets_q,
+            batch_offsets_o,
+            is_cuda_graph_compatible,
+        )
+    else:
+        actual_seq_lens_q = torch.ones(
+            (bs, 1, 1, 1), device=q.device, dtype=torch.int32
+        )
+        block_size = k_cache.shape[2]
+
+        _batch_decode_with_kv_cache(
+            q=q,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            scale=scale,
+            workspace_buffer=workspace_buffer,
+            max_sequence_kv=max_sequence_kv,
+            actual_seq_lens_q=actual_seq_lens_q,
+            actual_seq_lens_kv=actual_seq_lens_kv,
+            block_tables=block_tables,
+            batch_offsets_q=batch_offsets_q,
+            batch_offsets_o=batch_offsets_o,
+            block_size=block_size,
+            out=out,
+        )
 
     return out
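For orientation, a minimal usage sketch of the new decode path. The tensor shapes, bfloat16 dtypes, and k_cache layout (num_pages, kv_heads, page_size, head_dim) are inferred from the graph builder above; the workspace size and the `flashinfer.cudnn_batch_decode_with_kv_cache` export path are assumptions, and only the keyword names visible in this diff and in the test below are used:

import math
import torch
import flashinfer

bs, h_qo, h_kv, d, page_size = 8, 32, 8, 128, 16
max_sequence_kv = 1024                         # max KV length per sequence
pages_per_seq = max_sequence_kv // page_size
num_pages = bs * pages_per_seq
device = "cuda"

q = torch.randn(bs, h_qo, d, device=device, dtype=torch.bfloat16)
k_cache = torch.randn(num_pages, h_kv, page_size, d, device=device, dtype=torch.bfloat16)
v_cache = torch.randn_like(k_cache)
block_tables = torch.arange(num_pages, device=device, dtype=torch.int32).reshape(bs, pages_per_seq)
actual_seq_lens_kv = torch.randint(1, max_sequence_kv + 1, (bs, 1, 1, 1), device=device, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, device=device, dtype=torch.uint8)  # size is a guess

out = flashinfer.cudnn_batch_decode_with_kv_cache(   # assumed export path
    q=q,
    k_cache=k_cache,
    v_cache=v_cache,
    scale=1.0 / math.sqrt(d),
    workspace_buffer=workspace_buffer,
    max_sequence_kv=max_sequence_kv,
    actual_seq_lens_kv=actual_seq_lens_kv,
    block_tables=block_tables,
    is_cuda_graph_compatible=False,
)
print(out.shape)  # expected (bs, h_qo, d)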

tests/test_cudnn_decode.py

Lines changed: 13 additions & 7 deletions
@@ -7,12 +7,12 @@
 import flashinfer
 
 
-@pytest.mark.parametrize("batch_size", [4, 8, 17, 64])
-@pytest.mark.parametrize("s_kv", [8, 40, 1024])
-@pytest.mark.parametrize("page_size", [1, 8])
-@pytest.mark.parametrize("num_kv_heads", [4])
-@pytest.mark.parametrize("num_qo_heads", [4, 32])
-@pytest.mark.parametrize("is_cuda_graph_compatible", [False, True])
+@pytest.mark.parametrize("batch_size", [8, 16, 32])
+@pytest.mark.parametrize("s_kv", [1024, 8192])
+@pytest.mark.parametrize("page_size", [16])
+@pytest.mark.parametrize("num_kv_heads", [8])
+@pytest.mark.parametrize("num_qo_heads", [32])
+@pytest.mark.parametrize("is_cuda_graph_compatible", [True, False])
 def test_cudnn_decode(
     batch_size,
     s_kv,
@@ -79,7 +79,11 @@ def test_cudnn_decode(
 
     # Actual sequence lengths (should be randomized across batches. )
     actual_seq_lens_kv = torch.randint(
-        0, s_kv, (batch_size, 1, 1, 1), dtype=torch.int32
+        0, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
+    )
+
+    ragged_q = torch.arange(0, batch_size + 1, device=device) * (
+        num_qo_heads * head_dim
     )
 
     workspace_buffer_size = math.ceil(
@@ -106,6 +110,8 @@ def test_cudnn_decode(
         actual_seq_lens_kv=actual_seq_lens_kv,
         block_tables=block_tables,
         is_cuda_graph_compatible=is_cuda_graph_compatible,
+        batch_offsets_q=ragged_q,
+        batch_offsets_o=ragged_q,
     )
 
     actual_seq_lens_kv_device = actual_seq_lens_kv.to(device)
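The new `ragged_q` tensor the test passes as `batch_offsets_q`/`batch_offsets_o` holds per-sequence element offsets into the packed query/output buffers: with one decode token per sequence, offsets advance by `num_qo_heads * head_dim`. A small sketch of that layout, with illustrative sizes (the packed-buffer interpretation is inferred from how the offsets are computed, not stated in the diff):

import torch

batch_size, num_qo_heads, head_dim = 4, 32, 128

# Element offset of each sequence's single query token in a packed buffer.
ragged_q = torch.arange(0, batch_size + 1) * (num_qo_heads * head_dim)
print(ragged_q.tolist())  # [0, 4096, 8192, 12288, 16384]

# Recovering sequence b from a flat (batch_size * num_qo_heads * head_dim,) buffer:
packed = torch.randn(batch_size * num_qo_heads * head_dim)
b = 2
q_b = packed[ragged_q[b] : ragged_q[b + 1]].view(num_qo_heads, head_dim)
assert q_b.shape == (num_qo_heads, head_dim)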
