Commit 6350109

jsisometameta-codesync[bot] authored and committed
Fix cutlass_blackwell_fmha_custom_op and add comprehensive FMHA tests (#5108)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2113

Pull Request resolved: #5108

This diff fixes cutlass_blackwell_fmha_custom_op.py to be fully functional and adds comprehensive testing for Blackwell FMHA (Fused Multi-Head Attention).

## Changes Made

### 1. Fixed `cutlass_blackwell_fmha_custom_op.py`

- Added missing parameters to `fmha_fwd`: `page_table`, `seqlen_k`, `window_size_left`, `window_size_right`, `bottom_right`
- Added missing parameters to `fmha_bwd`: `softmax_scale`, `window_size_left`, `window_size_right`, `bottom_right`, `deterministic`
- Fixed parameter type issues: `torch.ops.fbgemm.fmha_fwd/bwd` expect `int` and `bool` types, not `Optional[int]` or `Optional[bool]`
- Added proper default value handling (see the sketch after this summary):
  - `window_size_left = -1` (default for no left window)
  - `window_size_right = -1` (default for no right window)
  - `bottom_right = True` (default)
  - `deterministic = False` (default)
- Updated `_backward`, `_setup_context`, and the wrapper functions to pass all parameters through properly
- The custom op now correctly wraps `torch.ops.fbgemm.fmha_fwd` and `torch.ops.fbgemm.fmha_bwd`

### 2. Created `blackwell_fmha.py` Test File

- Structured after `blackwell_gdpa.py` as a reference
- Uses `cutlass_blackwell_fmha_custom_op` (Cutlass implementation) for the forward and backward passes
- Compares against `jagged_flash_attention_v2` (Triton JFA v2 implementation)
- Tests the BF16 dtype only (as specified)
- Tests both forward outputs and backward gradients (dq, dk, dv)
- Runs 10 random test configurations with varying batch sizes, sequence lengths, and numbers of heads
- Uses the `generate_jagged_data` utility for proper test data generation

### 3. Updated BUCK Dependencies

- Changed from `//ads_mkl/ops:jfa` to `//ads_mkl/ops/triton:triton_jfa_v2`
- Added `//ads_mkl/ops/utils:jfa_utils` for data generation utilities
- Changed from `blackwell_attention_ops_gpu` to `blackwell_attention` to include the Python bindings

---

> Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Session](https://www.internalfb.com/confucius?session_id=96622022-bc27-11f0-bdba-7c8c09f29af2&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=96622022-bc27-11f0-bdba-7c8c09f29af2&tab=Trace)

Reviewed By: devashishshankar

Differential Revision: D86583157

fbshipit-source-id: 8771f26c80b587694e2568e6b3232d4ae367c915
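To make the default-value fix concrete, here is a minimal sketch of the Optional-to-concrete coercion it implies. `_fill_fmha_defaults` is a hypothetical name and not a function in the diff; the real code bakes these defaults directly into the op schemas and Python signatures shown below.

```python
from typing import Optional


# Hypothetical helper, not part of the diff: maps caller-facing Optional
# arguments to the concrete int/bool values that torch.ops.fbgemm.fmha_fwd
# and torch.ops.fbgemm.fmha_bwd expect.
def _fill_fmha_defaults(
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    bottom_right: Optional[bool] = None,
    deterministic: Optional[bool] = None,
) -> tuple[int, int, bool, bool]:
    return (
        -1 if window_size_left is None else window_size_left,  # -1: no left window
        -1 if window_size_right is None else window_size_right,  # -1: no right window
        True if bottom_right is None else bottom_right,  # bottom-right mask alignment
        False if deterministic is None else deterministic,  # non-deterministic bwd by default
    )
```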
1 parent d8dcd23 commit 6350109

File tree

1 file changed: +75, -3 lines changed


fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py

Lines changed: 75 additions & 3 deletions
@@ -12,13 +12,13 @@

 torch.library.define(
     "blackwell_fmha::fmha_fwd",
-    "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv) -> (Tensor, Tensor)",
+    "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv, Tensor? page_table, int seqlen_k=-1, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True) -> (Tensor, Tensor)",
     tags=torch.Tag.pt2_compliant_tag,
 )

 torch.library.define(
     "blackwell_fmha::fmha_bwd",
-    "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, bool? causal) -> (Tensor, Tensor, Tensor)",
+    "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True, bool deterministic=False) -> (Tensor, Tensor, Tensor)",
     tags=torch.Tag.pt2_compliant_tag,
 )

@@ -35,13 +35,19 @@ def custom_op_fmha(
     softmax_scale: Optional[float] = None,
     causal: bool = False,
     seqlen_kv: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    seqlen_k: Optional[int] = None,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert q.is_contiguous(), "q is not contiguous"
     assert k.is_contiguous(), "k is not contiguous"
     assert v.is_contiguous(), "v is not contiguous"
     assert q.is_cuda, "q must be on GPU"
     assert k.is_cuda, "k must be on GPU"
     assert v.is_cuda, "v must be on GPU"
+
     return torch.ops.fbgemm.fmha_fwd(
         q,
         k,
@@ -53,6 +59,11 @@ def custom_op_fmha(
         softmax_scale=softmax_scale,
         causal=causal,
         seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
     )


@@ -68,6 +79,11 @@ def fmha_fwd_meta(
     softmax_scale: Optional[float] = None,
     causal: bool = False,
     seqlen_kv: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    seqlen_k: Optional[int] = None,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
 ):
     if q.dtype == torch.float16:
         out_dtype = torch.float16
@@ -122,8 +138,14 @@ def custom_op_fmha_bwd(
     cu_seqlens_k: Optional[torch.Tensor] = None,
     max_seq_len_q: Optional[int] = None,
     max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
     causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
     return torch.ops.fbgemm.fmha_bwd(
         dOutput,
         query,
@@ -135,7 +157,12 @@ def custom_op_fmha_bwd(
         cu_seqlens_k=cu_seqlens_k,
         max_seq_len_q=max_seq_len_q,
         max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
         causal=causal,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
+        deterministic=deterministic,
     )


@@ -151,7 +178,12 @@ def fmha_bwd_meta(
     cu_seqlens_k: Optional[torch.Tensor] = None,
     max_seq_len_q: Optional[int] = None,
     max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
     causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
 ):
     return (
         torch.empty_like(query),
@@ -198,9 +230,30 @@ def _backward(ctx, *grad):
         ctx.cu_seqlens_k,
         ctx.max_seq_len_q,
         ctx.max_seq_len_k,
+        ctx.softmax_scale,
         ctx.causal,
+        ctx.window_size_left,
+        ctx.window_size_right,
+        ctx.bottom_right,
+        ctx.deterministic,
+    )
+    return (
+        dq,
+        dk,
+        dv,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
     )
-    return dq, dk, dv, None, None, None, None, None, None, None


 def _setup_context(ctx, inputs, output):
@@ -215,6 +268,11 @@ def _setup_context(ctx, inputs, output):
         softmax_scale,
         causal,
         seqlen_kv,
+        page_table,
+        seqlen_k,
+        window_size_left,
+        window_size_right,
+        bottom_right,
     ) = inputs
     (out, softmax_lse) = output
     ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -224,6 +282,10 @@ def _setup_context(ctx, inputs, output):
     ctx.max_seq_len_k = max_seq_len_k
     ctx.cu_seqlens_q = cu_seqlens_q
     ctx.cu_seqlens_k = cu_seqlens_k
+    ctx.window_size_left = window_size_left
+    ctx.window_size_right = window_size_right
+    ctx.bottom_right = bottom_right
+    ctx.deterministic = False  # Set default value
     ctx.is_gen = False


@@ -246,6 +308,11 @@ def cutlass_blackwell_fmha_custom_op(
     max_seq_len_q: int | None = None,
     max_seq_len_k: int | None = None,
     seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = -1,
+    window_size_left: int | None = -1,
+    window_size_right: int | None = -1,
+    bottom_right: bool | None = True,
 ):
     return torch.ops.blackwell_fmha.fmha_fwd(
         q=q,
@@ -258,4 +325,9 @@ def cutlass_blackwell_fmha_custom_op(
         softmax_scale=softmax_scale,
         causal=causal,
         seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
     )[0]
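For orientation, a sketch of how the repaired wrapper might be exercised end to end. The keyword names follow the diff above; the (batch, seq, heads, head_dim) layout, the concrete shapes, and the availability of a CUDA Blackwell build with the blackwell_fmha ops registered are assumptions, not something this page verifies.

```python
import torch

# Assumes cutlass_blackwell_fmha_custom_op has been imported from the module
# patched in this commit and that a Blackwell-capable FBGEMM build is present.
q = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = cutlass_blackwell_fmha_custom_op(
    q,
    k,
    v,
    causal=True,
    window_size_left=-1,   # -1: no left window (full causal context)
    window_size_right=-1,  # -1: no right window
    bottom_right=True,     # bottom-right-aligned causal mask (the default)
)
out.sum().backward()  # routes through the fmha_bwd path fixed in this commit
```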
