Fix cutlass_blackwell_fmha_custom_op and add comprehensive FMHA tests (#5108)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2113
Pull Request resolved: #5108
This diff fixes `cutlass_blackwell_fmha_custom_op.py` so that it is fully functional, and adds comprehensive tests for Blackwell FMHA (Fused Multi-Head Attention).
## Changes Made:
### 1. Fixed `cutlass_blackwell_fmha_custom_op.py`
- Added missing parameters to `fmha_fwd`: `page_table`, `seqlen_k`, `window_size_left`, `window_size_right`, `bottom_right`
- Added missing parameters to `fmha_bwd`: `softmax_scale`, `window_size_left`, `window_size_right`, `bottom_right`, `deterministic`
- Fixed parameter type issues: `torch.ops.fbgemm.fmha_fwd/bwd` expect `int` and `bool` types, not `Optional[int]` or `Optional[bool]`
- Added proper default value handling:
  - `window_size_left = -1` (default for no left window)
  - `window_size_right = -1` (default for no right window)
  - `bottom_right = True` (default)
  - `deterministic = False` (default)
- Updated `_backward`, `_setup_context`, and the wrapper functions to pass all parameters through
- The custom op now correctly wraps `torch.ops.fbgemm.fmha_fwd` and `torch.ops.fbgemm.fmha_bwd`; the wiring is sketched below
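To make the default handling and autograd wiring concrete, here is a minimal sketch of both patterns using a toy op in place of the real kernels. `toy_attn`, `_resolve_defaults`, and their signatures are illustrative only; the real `torch.ops.fbgemm.fmha_fwd`/`fmha_bwd` take many more arguments, and the sketch assumes PyTorch >= 2.4 for `torch.library.custom_op`:

```python
from typing import Optional, Tuple

import torch


def _resolve_defaults(
    window_size_left: Optional[int],
    window_size_right: Optional[int],
    bottom_right: Optional[bool],
    deterministic: Optional[bool],
) -> Tuple[int, int, bool, bool]:
    # The registered ops expect concrete int/bool arguments, so Optional
    # values are mapped to the defaults listed above before dispatch.
    return (
        -1 if window_size_left is None else window_size_left,
        -1 if window_size_right is None else window_size_right,
        True if bottom_right is None else bottom_right,
        False if deterministic is None else deterministic,
    )


@torch.library.custom_op("demo::toy_attn", mutates_args=())
def toy_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Unnormalized stand-in for the forward attention kernel.
    return (q @ k.transpose(-2, -1)) @ v


def _setup_context(ctx, inputs, output) -> None:
    q, k, v = inputs
    ctx.save_for_backward(q, k, v)


def _backward(ctx, grad_out):
    # The real wrapper dispatches to torch.ops.fbgemm.fmha_bwd here,
    # forwarding softmax_scale, the window sizes, bottom_right, and
    # deterministic alongside the saved tensors.
    q, k, v = ctx.saved_tensors
    s = q @ k.transpose(-2, -1)
    d_s = grad_out @ v.transpose(-2, -1)
    dq = d_s @ k
    dk = d_s.transpose(-2, -1) @ q
    dv = s.transpose(-2, -1) @ grad_out
    return dq, dk, dv


torch.library.register_autograd(
    "demo::toy_attn", _backward, setup_context=_setup_context
)
```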
### 2. Created `blackwell_fmha.py` Test File
- Structured after `blackwell_gdpa.py` as a reference
- Uses `cutlass_blackwell_fmha_custom_op` (Cutlass implementation) for forward and backward passes
- Compares against `jagged_flash_attention_v2` (Triton JFA v2 implementation)
- Tests BF16 dtype only (as specified)
- Tests both forward outputs and backward gradients (dq, dk, dv); the comparison pattern is sketched after this list
- Runs 10 random test configurations with varying batch sizes, sequence lengths, and head counts
- Uses `generate_jagged_data` utility for proper test data generation
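The core comparison pattern, roughly (not the actual test code: `run_cutlass` and `run_triton` stand in for calls to `cutlass_blackwell_fmha_custom_op` and `jagged_flash_attention_v2`, whose jagged-input signatures are not reproduced here, and the BF16 tolerances are placeholders):

```python
import torch


def check_fwd_bwd(run_cutlass, run_triton, q, k, v, atol=2e-2, rtol=2e-2):
    """Compare forward outputs and dq/dk/dv gradients of two attention impls.

    Each callable takes (q, k, v) and returns the attention output.
    """
    outs, grads = {}, {}
    for name, fn in (("cutlass", run_cutlass), ("triton", run_triton)):
        # Fresh leaf tensors so each implementation accumulates its own grads.
        qi, ki, vi = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
        out = fn(qi, ki, vi)
        out.backward(torch.ones_like(out))
        outs[name] = out
        grads[name] = (qi.grad, ki.grad, vi.grad)

    torch.testing.assert_close(outs["cutlass"], outs["triton"], atol=atol, rtol=rtol)
    for g_cutlass, g_triton in zip(grads["cutlass"], grads["triton"]):
        torch.testing.assert_close(g_cutlass, g_triton, atol=atol, rtol=rtol)
```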
### 3. Updated BUCK Dependencies
- Changed from `//ads_mkl/ops:jfa` to `//ads_mkl/ops/triton:triton_jfa_v2`
- Added `//ads_mkl/ops/utils:jfa_utils` for data generation utilities
- Changed from `blackwell_attention_ops_gpu` to `blackwell_attention` to include the Python bindings; an illustrative target is sketched below
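An illustrative BUCK target showing where the new deps land (the rule and target names are assumed; only the dep labels come from this summary):

```python
python_unittest(
    name = "blackwell_fmha",
    srcs = ["blackwell_fmha.py"],
    deps = [
        "//ads_mkl/ops/triton:triton_jfa_v2",  # was //ads_mkl/ops:jfa
        "//ads_mkl/ops/utils:jfa_utils",       # generate_jagged_data
        ":blackwell_attention",                # was blackwell_attention_ops_gpu
    ],
)
```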
---
> Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/)
[Session](https://www.internalfb.com/confucius?session_id=96622022-bc27-11f0-bdba-7c8c09f29af2&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=96622022-bc27-11f0-bdba-7c8c09f29af2&tab=Trace)
Reviewed By: devashishshankar
Differential Revision: D86583157
fbshipit-source-id: 8771f26c80b587694e2568e6b3232d4ae367c915
Changed file: `fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py`