Fix cutlass_blackwell_fmha_custom_op and add comprehensive FMHA tests #5108
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
This diff fixes the cutlass_blackwell_fmha_custom_op.py to be fully functional and adds comprehensive testing for Blackwell FMHA (Fused Multi-Head Attention).
Changes Made:
1. Fixed
cutlass_blackwell_fmha_custom_op.pyfmha_fwd:page_table,seqlen_k,window_size_left,window_size_right,bottom_rightfmha_bwd:softmax_scale,window_size_left,window_size_right,bottom_right,deterministictorch.ops.fbgemm.fmha_fwd/bwdexpectintandbooltypes, notOptional[int]orOptional[bool]window_size_left = -1(default for no left window)window_size_right = -1(default for no right window)bottom_right = True(default)deterministic = False(default)_backward,_setup_context, and wrapper functions to properly pass all parameterstorch.ops.fbgemm.fmha_fwdandtorch.ops.fbgemm.fmha_bwd2. Created
blackwell_fmha.pyTest Fileblackwell_gdpa.pyas referencecutlass_blackwell_fmha_custom_op(Cutlass implementation) for forward and backward passesjagged_flash_attention_v2(Triton JFA v2 implementation)generate_jagged_datautility for proper test data generation3. Updated BUCK Dependencies
//ads_mkl/ops:jfato//ads_mkl/ops/triton:triton_jfa_v2//ads_mkl/ops/utils:jfa_utilsfor data generation utilitiesblackwell_attention_ops_gputoblackwell_attentionto include Python bindingsDifferential Revision: D86583157