[FA] Cleanup loading logic #41427
base: main
Conversation
if hasattr(kernel, "flash_attn_varlen_func"):
    if attention_wrapper is None:
        attention_wrapper = flash_attention_forward
    kernel_function = partial(attention_wrapper, implementation=kernel)
This currently didn't have any effect; we only used the forced loading to get the correct implementation.
no this is super important for paged_attention wrappers 😓
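For reference, a minimal, self-contained sketch of the binding pattern this thread is discussing; every name below (FakeKernel, attention_wrapper, kernel_function) is an illustrative stand-in, not the actual transformers internals.

from functools import partial

class FakeKernel:
    """Illustrative stand-in for a hub kernel module exposing flash_attn_varlen_func."""
    @staticmethod
    def flash_attn_varlen_func(q, k, v):
        return "varlen attention output"

def attention_wrapper(q, k, v, implementation=None):
    # The wrapper only chooses which varlen function to call; the actual
    # kernel lives in the bound `implementation` module.
    return implementation.flash_attn_varlen_func(q, k, v)

kernel = FakeKernel()
if hasattr(kernel, "flash_attn_varlen_func"):
    kernel_function = partial(attention_wrapper, implementation=kernel)

print(kernel_function(None, None, None))  # -> "varlen attention output"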
Changed the logic so that CB (continuous batching) can now use all FA versions with lazy loading, i.e. FA2, FA3, and the kernels-hub FAs.
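Roughly, the lazy-loading idea reads like the sketch below. The function name, the implementation strings, and the FA3 module name are assumptions for illustration, not the exact transformers code.

import importlib

_loaded_impl = None
_loaded_fn = None

def lazy_load_flash_varlen(attn_implementation: str):
    """Resolve flash_attn_varlen_func for the requested backend on first use."""
    global _loaded_impl, _loaded_fn
    if _loaded_impl == attn_implementation:
        return _loaded_fn  # already imported, reuse the cached callable

    if attn_implementation == "flash_attention_2":
        # FA2 ships the varlen kernel in the flash_attn package.
        module = importlib.import_module("flash_attn")
    elif attn_implementation == "flash_attention_3":
        # FA3 module name is an assumption here; adjust to your install.
        module = importlib.import_module("flash_attn_interface")
    else:
        # A kernels-hub implementation would be fetched and wrapped here instead.
        raise NotImplementedError(f"unknown implementation: {attn_implementation}")

    _loaded_impl = attn_implementation
    _loaded_fn = module.flash_attn_varlen_func
    return _loaded_fn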
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
if attention_wrapper is None:
    attention_wrapper = flash_attention_forward
kernel_function = partial(attention_wrapper, implementation=kernel)
lazy_import_flash_attention(kernel, force_import=True)
Lazy import only happens at init or at explicit setting time:
transformers/src/transformers/modeling_utils.py, lines 2580 to 2582 in 766ac18:

# preload flash attention here to allow compile with fullgraph
if "flash" in applicable_attn_implementation:
    lazy_import_flash_attention(applicable_attn_implementation)
transformers/src/transformers/generation/continuous_batching/continuous_api.py, lines 614 to 618 in ced85ca:

# lazy loading flash attention including kernel variations
if "flash" in attn_implementation:
    from ...modeling_flash_attention_utils import lazy_import_paged_flash_attention

    lazy_import_paged_flash_attention(attn_implementation)
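For context, the init-time path in modeling_utils.py above is the one taken when a flash implementation is requested at load time, e.g. as in the sketch below (the model id is just an example):

from transformers import AutoModelForCausalLM

# Requesting a flash backend at load time means "flash" is in the
# implementation name, so the preload above fires during init and
# torch.compile with fullgraph does not hit a lazy import later.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
)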
Otherwise, it happens when users only set the attention implementation within the config, i.e. when the interface detects an implementation name different from the currently loaded one:
transformers/src/transformers/modeling_flash_attention_utils.py, lines 143 to 144 in 766ac18:

if implementation is not None and _loaded_implementation != implementation:
    _loaded_implementation = implementation
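In other words, the guard behaves like this small sketch (names illustrative): the implementation name is passed on every call, but a (re)load only happens when it differs from what is already loaded.

_loaded_implementation = None

def maybe_reload(implementation):
    global _loaded_implementation
    if implementation is not None and _loaded_implementation != implementation:
        _loaded_implementation = implementation
        print(f"(re)loading flash attention backend: {implementation}")

maybe_reload("flash_attention_2")  # first call: loads FA2
maybe_reload("flash_attention_2")  # no-op, name unchanged
maybe_reload("flash_attention_3")  # config changed: loads FA3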
make sure you test the continuous_batching_simple script with fa2!
@ArthurZucker updated the logic a bit to include CB now as well. The loaded attention now properly depends on the config, like in the base FA version, and we can now use FA3 for CB as well.
If we do:
- partial(attention_wrapper, implementation=kernel)
in that case we should remove this arg from all the xxx_paged wrappers.
Seems like the
We currently ignore anything passed through the interface and force the correct loading at attention-setting time. This change makes it more flexible: power users can switch implementations just by setting them in the config.
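As a hedged usage sketch of that flow (the model id is just an example; _attn_implementation is the private attribute transformers configs use to track the selected backend, and whether editing it alone is enough depends on this PR's logic):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
)

# A power user switches the backend by editing the config; on the next
# forward pass the interface sees a name that differs from the loaded
# one and lazily imports the new implementation.
model.config._attn_implementation = "flash_attention_3"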