Conversation

@jsisometa
Contributor

Summary:
This diff fixes cutlass_blackwell_fmha_custom_op.py so that it is fully functional and adds comprehensive tests for Blackwell FMHA (Fused Multi-Head Attention).

Changes Made:

1. Fixed cutlass_blackwell_fmha_custom_op.py

  • Added missing parameters to fmha_fwd: page_table, seqlen_k, window_size_left, window_size_right, bottom_right
  • Added missing parameters to fmha_bwd: softmax_scale, window_size_left, window_size_right, bottom_right, deterministic
  • Fixed parameter type issues: torch.ops.fbgemm.fmha_fwd/bwd expect int and bool types, not Optional[int] or Optional[bool]
  • Added proper default value handling:
    • window_size_left = -1 (default for no left window)
    • window_size_right = -1 (default for no right window)
    • bottom_right = True (default)
    • deterministic = False (default)
  • Updated _backward, _setup_context, and wrapper functions to properly pass all parameters
  • The custom op now correctly wraps torch.ops.fbgemm.fmha_fwd and torch.ops.fbgemm.fmha_bwd (a sketch of the wrapper pattern follows this list)
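
The default handling above amounts to resolving Optional arguments to concrete values before dispatching to the underlying op. Below is a minimal sketch of that pattern for the forward direction, assuming keyword names matching the parameters listed here; the real signature, argument order, and return values live in cutlass_blackwell_fmha_custom_op.py and may differ:

```python
from typing import Optional

import torch


def _fmha_fwd_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    page_table: Optional[torch.Tensor] = None,
    seqlen_k: Optional[torch.Tensor] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    bottom_right: Optional[bool] = None,
):
    # The registered op expects concrete int/bool values, so Optional
    # arguments are resolved to the documented defaults before the call.
    # The keyword names below mirror the parameters described in this diff
    # and are assumptions about the actual op schema.
    return torch.ops.fbgemm.fmha_fwd(
        q,
        k,
        v,
        page_table=page_table,
        seqlen_k=seqlen_k,
        window_size_left=-1 if window_size_left is None else window_size_left,
        window_size_right=-1 if window_size_right is None else window_size_right,
        bottom_right=True if bottom_right is None else bottom_right,
    )
```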

2. Created blackwell_fmha.py Test File

  • Structured following blackwell_gdpa.py as a reference
  • Uses cutlass_blackwell_fmha_custom_op (Cutlass implementation) for forward and backward passes
  • Compares against jagged_flash_attention_v2 (Triton JFA v2 implementation)
  • Tests BF16 dtype only (as specified)
  • Tests both forward outputs and backward gradients (dq, dk, dv)
  • Runs 10 random test configurations with varying batch sizes, sequence lengths, and numbers of heads
  • Uses the generate_jagged_data utility for test data generation (a comparison sketch follows this list)
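
The forward/backward comparison boils down to running both implementations on cloned, grad-enabled copies of the inputs, backpropagating a common loss, and comparing outputs and dq/dk/dv at BF16-appropriate tolerances. A minimal sketch, where cutlass_attn and reference_attn are hypothetical stand-ins for the Cutlass custom op and jagged_flash_attention_v2 (their real call signatures, jagged inputs, and tolerances come from the test file itself):

```python
import torch


def compare_fwd_bwd(cutlass_attn, reference_attn, q, k, v, atol=2e-2, rtol=2e-2):
    # Run each implementation on its own cloned, grad-enabled copies of the
    # inputs so the two backward passes do not interfere with each other.
    results = []
    for attn in (cutlass_attn, reference_attn):
        q_, k_, v_ = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
        out = attn(q_, k_, v_)
        out.sum().backward()  # arbitrary scalar loss, just to populate gradients
        results.append((out, q_.grad, k_.grad, v_.grad))

    # Compare the forward output and dq/dk/dv; the tolerances here are only a
    # BF16-friendly guess, not the values used by the actual test.
    for test, ref in zip(results[0], results[1]):
        torch.testing.assert_close(test, ref, atol=atol, rtol=rtol)
```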

3. Updated BUCK Dependencies

  • Changed from //ads_mkl/ops:jfa to //ads_mkl/ops/triton:triton_jfa_v2
  • Added //ads_mkl/ops/utils:jfa_utils for data generation utilities
  • Changed from blackwell_attention_ops_gpu to blackwell_attention to include the Python bindings (a sketch of the resulting BUCK deps follows this list)
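
Taken together, the dependency changes would look roughly like the sketch below. Only the dep labels are taken from this diff; the rule kind, target name, and remaining attributes are placeholders:

```python
# Hypothetical BUCK target for the new test file; only the dep labels are
# from this diff, everything else is illustrative.
python_unittest(
    name = "blackwell_fmha",
    srcs = ["blackwell_fmha.py"],
    deps = [
        "//ads_mkl/ops/triton:triton_jfa_v2",  # replaces //ads_mkl/ops:jfa
        "//ads_mkl/ops/utils:jfa_utils",       # provides generate_jagged_data
        ":blackwell_attention",                # replaces blackwell_attention_ops_gpu
    ],
)
```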

Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/)
[Session](https://www.internalfb.com/confucius?session_id=96622022-bc27-11f0-bdba-7c8c09f29af2&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=96622022-bc27-11f0-bdba-7c8c09f29af2&tab=Trace)

Differential Revision: D86583157

@netlify

netlify bot commented Nov 10, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

| Name | Link |
|------|------|
| 🔨 Latest commit | bed170c |
| 🔍 Latest deploy log | https://app.netlify.com/projects/pytorch-fbgemm-docs/deploys/691364b0918a1a0008f92021 |
| 😎 Deploy Preview | https://deploy-preview-5108--pytorch-fbgemm-docs.netlify.app |

meta-cla bot added the cla signed label Nov 10, 2025
@meta-codesync
Contributor

meta-codesync bot commented Nov 10, 2025

@jsisometa has exported this pull request. If you are a Meta employee, you can view the originating Diff in D86583157.

…pytorch#5108)

Summary:
X-link: facebookresearch/FBGEMM#2113

This diff fixes cutlass_blackwell_fmha_custom_op.py to be fully functional and adds comprehensive testing for Blackwell FMHA; the full list of changes is in the pull request description above.

Differential Revision: D86583157
@meta-codesync
Contributor

meta-codesync bot commented Nov 12, 2025

This pull request has been merged in 6350109.
