[Transform] Attention/Cache transforms #436
Conversation
This looks good, though I have a number of questions and minor suggestions.
# assumes only one model at a time
global _original_impl
😬 I don't want to delay things, but we should briefly consider whether there are alternative solutions.
I spent 20 minutes exploring this; it requires creating specialized `_ct_hooked_attention` functions and a specialized `QuantizedAttentionImpl`, which is more complexity than value added, imho.
Can `_original_impl` be registered at the module level (i.e., on each `self_attn` block) instead of setting a global var?
Sure, but in order to register `_original_impl`, it needs to be gotten from somewhere. The first time, you "get" it from `model.config`. However, on subsequent calls, `model.config` is overridden. This means that in order to "get" the original implementation, you'd have to go find the last Attention module you registered it to, or else store it in some global store.
You could register it to the model module itself or something like that, but I think that's less reliable than just a global store. If it's functionality you're after, we can turn it into a hash table or something, keyed by model hash (sketched below).
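Roughly, a sketch of what that hash-table variant could look like; the names here are illustrative rather than the PR's actual helpers, and `id(model)` stands in for a "model hash":

```python
# illustrative sketch only: HOOKED_ATTENTION_NAME, register_original_impl, and
# original_impl are hypothetical names; id(model) stands in for a model hash
from typing import Dict

from transformers import PreTrainedModel

HOOKED_ATTENTION_NAME = "ct_hooked_attention"

# maps id(model) -> the attention implementation the model was loaded with
_original_impls: Dict[int, str] = {}


def register_original_impl(model: PreTrainedModel):
    # capture the original implementation only the first time we see this model;
    # on later calls, model.config already points at the hooked implementation
    if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
        _original_impls[id(model)] = model.config._attn_implementation
        model.config._attn_implementation = HOOKED_ATTENTION_NAME


def original_impl(model: PreTrainedModel) -> str:
    return _original_impls[id(model)]
```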
If the goal is to use this generally for kv cache and attention quantization, can we move `initialize_hooked_attention` and `initialize_hooked_kv_cache` to `initialize.py`?
I understand we haven't hooked them in yet for those workflows, but I think these belong there.
Do a pass through on any missing docstrings; otherwise LGTM.
Nice work!
Following for the most part. A few clarifications, but this makes sense to me.
QuantizedAttentionImpl module which wraps the functionality of the original
attention implementation. Unlike the original attention function, this
implementation is a `torch.nn.Module` which can be hooked to trigger
transforms and calibration hooks.
Does this wrap every single attention block? If so, the global `_original_impl` will be re-set multiple times, though if the same attention function is used throughout the entire model, that's probably OK?
We guard against multiple sets using `if model.config._attn_implementation != HOOKED_ATTENTION_NAME:`.
Purpose
Prerequisites
Changes
New Classes
- `QuantizedAttentionImpl`
  - injects itself into the model by registering a new attention implementation called `ct_hooked_attention` and overriding `model.config._attn_implementation` to be the new implementation name (see the sketch after this list)
- `QuantizedKVCache`
  - injects itself into the model by overriding the `past_key_values` input kwarg to attention, and wrapping the functionality of the original cache
- `register_query_hook`, `register_key_hook`, `register_value_hook` (a usage sketch follows at the end of this description)
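As a rough illustration of the injection mechanism described for `QuantizedAttentionImpl`, the sketch below registers a hooked attention function with `transformers` and points the model at it. It is not the PR's actual code: the function body and helper names are assumptions, and the registry API may differ across `transformers` versions.

```python
# hypothetical sketch, not the PR's implementation; assumes a recent transformers
# version where ALL_ATTENTION_FUNCTIONS acts as a mutable name -> function registry
from transformers import PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

HOOKED_ATTENTION_NAME = "ct_hooked_attention"
_original_impl = "sdpa"  # assumed to be captured from model.config before overriding


def ct_hooked_attention(module, query, key, value, attention_mask, *args, **kwargs):
    # a hookable wrapper module (e.g. QuantizedAttentionImpl) would run its
    # query/key/value hooks here before delegating to the original implementation
    original_fn = ALL_ATTENTION_FUNCTIONS[_original_impl]
    return original_fn(module, query, key, value, attention_mask, *args, **kwargs)


def initialize_hooked_attention_sketch(model: PreTrainedModel):
    global _original_impl
    # register the new implementation once and switch the whole model over to it,
    # guarding against repeated calls as discussed in the review thread
    if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
        _original_impl = model.config._attn_implementation
        ALL_ATTENTION_FUNCTIONS[HOOKED_ATTENTION_NAME] = ct_hooked_attention
        model.config._attn_implementation = HOOKED_ATTENTION_NAME
```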
Quantization Lifecycle Changes
- `initialize_hooked_kv_cache` and `initialize_hooked_attention` are called if attention modules are explicitly targeted (see `is_narrow_match`) in `initialize_module_for_quantization`
- `QuantizationConfig.from_pretrained` was cleaned up with additional comments
- the `kv_cache_scheme` field is added if there are any attention modules with a `quantization_scheme` attached
Helpers
- `is_narrow_match` is used to check that attention modules are being specifically targeted (rather than targeting all modules in a layer)
- `get_head_dim` is used to get the attention `head_dim` from a config
Testing
- `is_narrow_match`
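To make the hook entry points above concrete, here is a hypothetical calibration-style usage. The import path and the `(module, states)` hook signature are assumptions inferred from the names in this description, not the PR's verified API.

```python
# hypothetical usage sketch: the import path and hook signature are assumptions
import torch

from compressed_tensors.modeling import register_key_hook, register_value_hook  # assumed path

observed_absmax = {}  # tensor name -> running absolute maximum


def make_absmax_hook(name: str):
    # track the largest magnitude seen, e.g. to inform kv cache quantization scales
    def hook(module: torch.nn.Module, states: torch.Tensor):
        value = states.detach().abs().amax()
        if name in observed_absmax:
            value = torch.maximum(observed_absmax[name], value)
        observed_absmax[name] = value

    return hook


def attach_kv_observers(model: torch.nn.Module):
    # attach one key hook and one value hook per attention module
    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            register_key_hook(module, make_absmax_hook(f"{name}.k"))
            register_value_hook(module, make_absmax_hook(f"{name}.v"))
```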