
Conversation

@vasqu (Contributor) commented Oct 7, 2025

We currently ignore anything passed through the interface and force the correct loading at attention-setting time. This change makes that more flexible: power users can switch implementations simply by setting them in the config.
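For illustration, a minimal usage sketch of what config-driven switching looks like from the user side; the checkpoint name is a placeholder and set_attn_implementation is assumed to be available in the installed transformers release:

    # Hedged sketch: selecting and later switching the attention implementation.
    # "some-org/some-model" is a placeholder checkpoint; set_attn_implementation
    # is assumed to exist in the installed transformers version.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "some-org/some-model",
        attn_implementation="flash_attention_2",  # picked at load time
    )
    # Power users can switch afterwards; the lazy loader should pick up the new
    # implementation name the next time the attention interface is called.
    model.set_attn_implementation("flash_attention_3")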

    if hasattr(kernel, "flash_attn_varlen_func"):
        if attention_wrapper is None:
            attention_wrapper = flash_attention_forward
        kernel_function = partial(attention_wrapper, implementation=kernel)
@vasqu (Contributor, author) commented Oct 7, 2025

This currently had no effect; we only used the forced loading to get the correct implementation.

Collaborator:

no this is super important for paged_attention wrappers 😓

@vasqu (Contributor, author):

Changed the logic so that CB (continuous batching) can now use all flash attention versions with lazy loading, i.e. FA2, FA3, and the kernels-based FA implementations.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

    if attention_wrapper is None:
        attention_wrapper = flash_attention_forward
    kernel_function = partial(attention_wrapper, implementation=kernel)
    lazy_import_flash_attention(kernel, force_import=True)
@vasqu (Contributor, author) commented Oct 7, 2025

Lazy import only happens at init or at explicit setting time:

    # preload flash attention here to allow compile with fullgraph
    if "flash" in applicable_attn_implementation:
        lazy_import_flash_attention(applicable_attn_implementation)

    # lazy loading flash attention including kernel variations
    if "flash" in attn_implementation:
        from ...modeling_flash_attention_utils import lazy_import_paged_flash_attention
        lazy_import_paged_flash_attention(attn_implementation)

Otherwise, it happens when users only set the attention within the config, i.e. when the interface detects an implementation name different from the currently loaded one:

    if implementation is not None and _loaded_implementation != implementation:
        _loaded_implementation = implementation
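As a self-contained illustration of this lazy-load-on-mismatch pattern, a hedged sketch follows; the module names (flash_attn for FA2, flash_attn_interface for FA3) and implementation strings are assumptions for the example, not the PR's exact code:

    # Hedged sketch of lazy loading keyed on the requested implementation name.
    # Module names below are assumptions; the real helper lives in
    # modeling_flash_attention_utils and also handles kernels-hub variants.
    import importlib

    _loaded_implementation = None
    _flash_varlen_fn = None

    def lazy_import_flash_attention(implementation=None, force_import=False):
        global _loaded_implementation, _flash_varlen_fn
        # Re-import only when forced, when nothing is loaded yet, or when the
        # requested implementation differs from the currently loaded one.
        if (
            not force_import
            and _flash_varlen_fn is not None
            and implementation == _loaded_implementation
        ):
            return _flash_varlen_fn
        module_name = "flash_attn_interface" if implementation == "flash_attention_3" else "flash_attn"
        module = importlib.import_module(module_name)
        _flash_varlen_fn = module.flash_attn_varlen_func
        _loaded_implementation = implementation
        return _flash_varlen_fn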

@ArthurZucker (Collaborator) left a comment

make sure you test the continuous_batching_simple script with fa2!


@vasqu (Contributor, author) commented Oct 8, 2025

@ArthurZucker updated the logic a bit to include CB now as well. The loaded attention now properly depends on the config, as in the base FA version, and we can now use FA3 for CB as well.

@ArthurZucker (Collaborator) left a comment

If we do partial(attention_wrapper, implementation=kernel), then we should remove this arg from all the xxx_paged implementations.
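To make the point concrete, a toy sketch of what the partial pre-binds; the wrapper names and the dummy kernel are invented for illustration and are not the transformers signatures:

    # Hedged toy sketch: partial() pre-binds `implementation`, so only wrappers
    # whose signature actually uses it (the flash paged path) need the kwarg;
    # eager/sdpa paged wrappers would just swallow it via **kwargs.
    from functools import partial

    class DummyKernel:
        @staticmethod
        def flash_attn_varlen_func(*args, **kwargs):
            return None

    def flash_paged_wrapper(module, q, k, v, implementation=None, **kwargs):
        # Only this wrapper consumes `implementation` (the loaded kernel/module).
        return implementation.flash_attn_varlen_func(q, k, v)

    def sdpa_paged_wrapper(module, q, k, v, **kwargs):
        # Never looked at `implementation`; passing it would be dead weight.
        return None

    kernel_function = partial(flash_paged_wrapper, implementation=DummyKernel())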

@vasqu (Contributor, author) commented Oct 14, 2025

Seems like the implementation kwarg was only used in flash_paged; the other implementations don't have it:

  • eager

        def eager_paged_attention_forward(
            module: nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            attention_mask: Optional[torch.Tensor],  # shape [seqlen_q, seqlen_k]
            scaling: float,
            **kwargs,
        ):

  • sdpa

        def sdpa_attention_paged_forward(
            module: torch.nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            attention_mask: Optional[torch.Tensor],
            dropout: float = 0.0,
            scaling: Optional[float] = None,
            **kwargs,
        ) -> tuple[torch.Tensor, None]:

  • flash (removed in this PR)

        def paged_attention_forward(
            module: torch.nn.Module,
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            cache: PagedAttentionCache = None,
            cu_seq_lens_q=None,
            cu_seq_lens_k=None,
            max_seqlen_q=None,
            max_seqlen_k=None,
            **kwargs,
        ) -> torch.Tensor:

So we should be good @ArthurZucker?
