Skip to content

Commit 2fff58c

Browse files
committed
Add cudnn-frontend for all architectures, since it has been published to PyPI
1 parent acfad91 commit 2fff58c

File tree

3 files changed

+117
-116
lines changed

3 files changed

+117
-116
lines changed

flashinfer/cudnn/decode.py

Lines changed: 114 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -71,128 +71,129 @@ def _sdpa_decode_key_fn(
7171
)
7272

7373

74-
@cudnn.jit(heur_modes=[cudnn.heur_mode.A])
75-
@cudnn.graph_cache(key_fn=_sdpa_decode_key_fn)
76-
def _build_decode_graph(
77-
q: torch.Tensor,
78-
k_cache: torch.Tensor,
79-
v_cache: torch.Tensor,
80-
scale: float,
81-
*,
82-
max_sequence_kv: int,
83-
block_size: Optional[int] = 1,
84-
actual_seq_lens_q: Optional[torch.Tensor] = None,
85-
actual_seq_lens_kv: Optional[torch.Tensor] = None,
86-
block_tables: Optional[torch.Tensor] = None,
87-
batch_offsets_q: Optional[torch.Tensor] = None,
88-
batch_offsets_o: Optional[torch.Tensor] = None,
89-
):
90-
handle = _create_cudnn_handle(torch.cuda.current_stream())
91-
92-
# WAR: override batch offsets for now, as it leads to a poor performance
93-
batch_offsets_q = None
94-
batch_offsets_o = None
95-
96-
with cudnn.graph(handle) as (g, _):
97-
98-
if q.dim() == 3:
99-
s_qo = 1
100-
b, h_qo, d_qk = q.shape[0], q.shape[1], q.shape[2]
101-
elif q.dim() == 4:
102-
b, h_qo, s_qo, d_qk = (
103-
q.shape[0],
104-
q.shape[1],
105-
q.shape[2],
106-
q.shape[3],
74+
if CUDNN_AVAILABLE:
75+
76+
@cudnn.jit(heur_modes=[cudnn.heur_mode.A])
77+
@cudnn.graph_cache(key_fn=_sdpa_decode_key_fn)
78+
def _build_decode_graph(
79+
q: torch.Tensor,
80+
k_cache: torch.Tensor,
81+
v_cache: torch.Tensor,
82+
scale: float,
83+
*,
84+
max_sequence_kv: int,
85+
block_size: Optional[int] = 1,
86+
actual_seq_lens_q: Optional[torch.Tensor] = None,
87+
actual_seq_lens_kv: Optional[torch.Tensor] = None,
88+
block_tables: Optional[torch.Tensor] = None,
89+
batch_offsets_q: Optional[torch.Tensor] = None,
90+
batch_offsets_o: Optional[torch.Tensor] = None,
91+
):
92+
handle = _create_cudnn_handle(torch.cuda.current_stream())
93+
94+
# WAR: override batch offsets for now, as it leads to a poor performance
95+
batch_offsets_q = None
96+
batch_offsets_o = None
97+
98+
with cudnn.graph(handle) as (g, _):
99+
if q.dim() == 3:
100+
s_qo = 1
101+
b, h_qo, d_qk = q.shape[0], q.shape[1], q.shape[2]
102+
elif q.dim() == 4:
103+
b, h_qo, s_qo, d_qk = (
104+
q.shape[0],
105+
q.shape[1],
106+
q.shape[2],
107+
q.shape[3],
108+
)
109+
else:
110+
raise ValueError(f"q must have 3 or 4 dimensions, got {q.dim()}")
111+
112+
assert s_qo == 1, "q must have a sequence length of 1"
113+
assert k_cache.dim() == 4, "k_cache must have 4 dimensions"
114+
115+
h_kv = k_cache.shape[1]
116+
s_kv = max_sequence_kv
117+
d_vo = v_cache.shape[3]
118+
119+
cudnn_q = g.tensor(
120+
name="q",
121+
dim=(b, h_qo, s_qo, d_qk),
122+
stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
123+
data_type=cudnn.data_type.BFLOAT16,
124+
)
125+
if batch_offsets_q is not None:
126+
ragged_q = g.tensor_like(batch_offsets_q)
127+
ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
128+
cudnn_q.set_ragged_offset(ragged_q)
129+
130+
cudnn_k_cache = g.tensor_like(k_cache)
131+
cudnn_v_cache = g.tensor_like(v_cache)
132+
133+
cudnn_q.set_uid(UIDs.Q_UID.value)
134+
cudnn_k_cache.set_uid(UIDs.K_UID.value)
135+
cudnn_v_cache.set_uid(UIDs.V_UID.value)
136+
137+
if block_tables is not None:
138+
nd_block_tables = block_tables.reshape(
139+
block_tables.shape[0], 1, block_tables.shape[1], 1
140+
)
141+
cudnn_k_block_tables = g.tensor_like(nd_block_tables)
142+
cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value)
143+
144+
cudnn_v_block_tables = g.tensor_like(nd_block_tables)
145+
cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value)
146+
147+
if actual_seq_lens_q is not None:
148+
cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q)
149+
cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value)
150+
151+
if actual_seq_lens_kv is not None:
152+
cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv)
153+
cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value)
154+
cudnn_actual_seq_lens_kv.set_is_pass_by_value(False)
155+
156+
padding_mask = actual_seq_lens_kv is not None
157+
158+
O, _ = g.sdpa(
159+
name="sdpa",
160+
q=cudnn_q,
161+
k=cudnn_k_cache,
162+
v=cudnn_v_cache,
163+
seq_len_q=(
164+
cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
165+
),
166+
seq_len_kv=(
167+
cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
168+
),
169+
use_padding_mask=padding_mask,
170+
is_inference=True,
171+
attn_scale=scale,
172+
paged_attention_k_table=cudnn_k_block_tables,
173+
paged_attention_v_table=cudnn_v_block_tables,
174+
paged_attention_max_seq_len_kv=max_sequence_kv,
175+
compute_data_type=cudnn.data_type.FLOAT,
107176
)
108-
else:
109-
raise ValueError(f"q must have 3 or 4 dimensions, got {q.dim()}")
110-
111-
assert s_qo == 1, "q must have a sequence length of 1"
112-
assert k_cache.dim() == 4, "k_cache must have 4 dimensions"
113-
114-
h_kv = k_cache.shape[1]
115-
s_kv = max_sequence_kv
116-
d_vo = v_cache.shape[3]
117-
118-
cudnn_q = g.tensor(
119-
name="q",
120-
dim=(b, h_qo, s_qo, d_qk),
121-
stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
122-
data_type=cudnn.data_type.BFLOAT16,
123-
)
124-
if batch_offsets_q is not None:
125-
ragged_q = g.tensor_like(batch_offsets_q)
126-
ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
127-
cudnn_q.set_ragged_offset(ragged_q)
128-
129-
cudnn_k_cache = g.tensor_like(k_cache)
130-
cudnn_v_cache = g.tensor_like(v_cache)
131177

132-
cudnn_q.set_uid(UIDs.Q_UID.value)
133-
cudnn_k_cache.set_uid(UIDs.K_UID.value)
134-
cudnn_v_cache.set_uid(UIDs.V_UID.value)
178+
if batch_offsets_o is not None:
179+
ragged_o = g.tensor_like(batch_offsets_o)
180+
ragged_o.set_uid(UIDs.RAGGED_O_UID.value)
181+
O.set_ragged_offset(ragged_o)
135182

136-
if block_tables is not None:
137-
nd_block_tables = block_tables.reshape(
138-
block_tables.shape[0], 1, block_tables.shape[1], 1
183+
O.set_uid(UIDs.O_UID.value).set_output(True).set_dim(
184+
[b, h_qo, s_qo, d_vo]
185+
).set_stride([d_vo * h_qo, d_vo, d_vo * h_qo, 1]).set_data_type(
186+
cudnn.data_type.BFLOAT16
139187
)
140-
cudnn_k_block_tables = g.tensor_like(nd_block_tables)
141-
cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value)
142188

143-
cudnn_v_block_tables = g.tensor_like(nd_block_tables)
144-
cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value)
189+
tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O]
145190

146191
if actual_seq_lens_q is not None:
147-
cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q)
148-
cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value)
149-
192+
tensors_to_return.append(cudnn_actual_seq_lens_q)
150193
if actual_seq_lens_kv is not None:
151-
cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv)
152-
cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value)
153-
cudnn_actual_seq_lens_kv.set_is_pass_by_value(False)
154-
155-
padding_mask = actual_seq_lens_kv is not None
156-
157-
O, _ = g.sdpa(
158-
name="sdpa",
159-
q=cudnn_q,
160-
k=cudnn_k_cache,
161-
v=cudnn_v_cache,
162-
seq_len_q=(
163-
cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
164-
),
165-
seq_len_kv=(
166-
cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
167-
),
168-
use_padding_mask=padding_mask,
169-
is_inference=True,
170-
attn_scale=scale,
171-
paged_attention_k_table=cudnn_k_block_tables,
172-
paged_attention_v_table=cudnn_v_block_tables,
173-
paged_attention_max_seq_len_kv=max_sequence_kv,
174-
compute_data_type=cudnn.data_type.FLOAT,
175-
)
176-
177-
if batch_offsets_o is not None:
178-
ragged_o = g.tensor_like(batch_offsets_o)
179-
ragged_o.set_uid(UIDs.RAGGED_O_UID.value)
180-
O.set_ragged_offset(ragged_o)
181-
182-
O.set_uid(UIDs.O_UID.value).set_output(True).set_dim(
183-
[b, h_qo, s_qo, d_vo]
184-
).set_stride([d_vo * h_qo, d_vo, d_vo * h_qo, 1]).set_data_type(
185-
cudnn.data_type.BFLOAT16
186-
)
187-
188-
tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O]
189-
190-
if actual_seq_lens_q is not None:
191-
tensors_to_return.append(cudnn_actual_seq_lens_q)
192-
if actual_seq_lens_kv is not None:
193-
tensors_to_return.append(cudnn_actual_seq_lens_kv)
194+
tensors_to_return.append(cudnn_actual_seq_lens_kv)
194195

195-
return g, tensors_to_return
196+
return g, tensors_to_return
196197

197198

198199
def _batch_decode_with_kv_cache(

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def generate_build_meta(aot_build_meta: dict) -> None:
6161
"pynvml",
6262
"einops",
6363
"nvidia-nvshmem-cu12",
64-
"nvidia-cudnn-cu12>=9.11.0",
65-
'nvidia-cudnn-frontend>=1.13.0; platform_machine == "x86_64" or platform_machine == "AMD64"',
64+
"nvidia-cudnn-cu12",
65+
"nvidia-cudnn-frontend>=1.13.0",
6666
]
6767
generate_build_meta({})
6868

tests/test_cudnn_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
@pytest.mark.parametrize("batch_size", [8, 16, 32])
11-
@pytest.mark.parametrize("s_kv", [1024, 8192])
11+
@pytest.mark.parametrize("s_kv", [512, 8192])
1212
@pytest.mark.parametrize("page_size", [16])
1313
@pytest.mark.parametrize("num_kv_heads", [8])
1414
@pytest.mark.parametrize("num_qo_heads", [32])

0 commit comments

Comments (0)