-
Notifications
You must be signed in to change notification settings - Fork 33
[Transform] Attention/Cache transforms #436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# flake8: noqa | ||
# isort: off | ||
from .kvcache import * | ||
from .attention import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import inspect | ||
from typing import Callable, Optional | ||
from weakref import ref | ||
|
||
from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache | ||
from compressed_tensors.quantization.lifecycle.forward import forward_quantize | ||
from compressed_tensors.utils import getattr_chain | ||
from compressed_tensors.utils.internal import InternalModule | ||
from torch import Tensor | ||
from torch.nn import Module | ||
from torch.utils.hooks import RemovableHandle | ||
from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel | ||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS | ||
|
||
|
||
__all__ = [ | ||
"QuantizedAttentionImpl", | ||
"initialize_hooked_attention", | ||
"register_query_hook", | ||
"IMPL_ATTR", | ||
] | ||
|
||
|
||
IMPL_ATTR = "impl" | ||
HOOKED_ATTENTION_NAME = "ct_hooked_attention" | ||
|
||
|
||
class QuantizedAttentionImpl(InternalModule): | ||
""" | ||
QuantizedAttentionImpl module which wraps the functionality of the original | ||
attention implementation. Unlike the original attention function, this | ||
implementation is a `torch.nn.Module` which can be hooked to trigger | ||
transforms and calibration hooks. | ||
|
||
This module works by being registered as a submodule to attention modules via | ||
`initialize_hooked_attention`, registering a new attention implementation function | ||
which calls this module, then setting the model attention implementation to the new | ||
function. After triggering hooks and quantization, this module calls the original | ||
attention implementation function. | ||
|
||
:param attn_module: parent attention module | ||
""" | ||
|
||
_original_impl = "eager" | ||
|
||
def __init__(self, config: PretrainedConfig, attn_module: Module): | ||
super().__init__() | ||
self.config = config | ||
self.attn_module = ref(attn_module) # avoid circular references | ||
|
||
def forward( | ||
self, | ||
module: Module, | ||
query: Tensor, | ||
key: Tensor, | ||
value: Tensor, | ||
*args, | ||
**kwargs, | ||
): | ||
# quantization | ||
quant_args_attr = "quantization_scheme.input_activations" | ||
quant_args = getattr_chain(module, quant_args_attr, None) | ||
quant_enabled = getattr(module, "quantization_enabled", True) | ||
if quant_args is not None and quant_enabled: | ||
query = forward_quantize(module, query, "q", quant_args) | ||
brian-dellabetta marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# original attention | ||
return ALL_ATTENTION_FUNCTIONS[_original_impl]( | ||
module, | ||
query, | ||
key, | ||
value, | ||
*args, | ||
**kwargs, | ||
) | ||
|
||
|
||
# ----- initialize ----- # | ||
|
||
|
||
def _ct_hooked_attention(module: Module, *args, **kwargs): | ||
if hasattr(module, IMPL_ATTR): | ||
return module.impl(module, *args, **kwargs) | ||
else: | ||
return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) | ||
|
||
|
||
def initialize_hooked_attention(model: PreTrainedModel, module: Module): | ||
""" | ||
Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances | ||
attached to attention | ||
|
||
:param model: parent model of attention module | ||
:param module: attention module to initialize with | ||
""" | ||
if not hasattr(module, IMPL_ATTR): | ||
module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config, module)) | ||
if model.config._attn_implementation != HOOKED_ATTENTION_NAME: | ||
# assumes only one model at a time | ||
global _original_impl | ||
Comment on lines
+113
to
+114
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😬 i don't want to delay things, but we should briefly consider if there are alternative solutions There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I spent 20 minutes exploring this, it requires creating specialized There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, but in order to register the The first time, you "get" it from You could register it to the model module itself or something like that, but I think that that's less reliable than just a a global store. If it's functionality you're after, we can turn it into a hash table or something, keyed by model hash. |
||
_original_impl = model.config._attn_implementation | ||
|
||
AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention) | ||
model.config._attn_implementation = HOOKED_ATTENTION_NAME | ||
|
||
initialize_hooked_kv_cache(model, module) | ||
|
||
|
||
# ----- hooks ----- # | ||
|
||
|
||
def register_query_hook( | ||
module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] | ||
) -> RemovableHandle: | ||
""" | ||
Register a hook which takes post-rope query states as an argument and | ||
returns the modified query states or `None` | ||
|
||
:param module: attention module to add hook to | ||
:param hook: query hook function | ||
""" | ||
impl = getattr(module, IMPL_ATTR) | ||
|
||
def _hook(impl: QuantizedAttentionImpl, args, kwargs): | ||
bound = inspect.signature(impl.forward).bind(*args, **kwargs) | ||
value = hook(module, bound.arguments["query"]) | ||
if value is not None: | ||
bound.arguments["query"] = value | ||
|
||
return bound.args, bound.kwargs | ||
|
||
return impl.register_forward_pre_hook(_hook, with_kwargs=True) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import inspect | ||
from typing import Callable, Optional, Tuple | ||
from weakref import ref | ||
|
||
from compressed_tensors.quantization.lifecycle.forward import forward_quantize | ||
from compressed_tensors.utils import getattr_chain | ||
from compressed_tensors.utils.internal import InternalModule | ||
from torch import Tensor | ||
from torch.nn import Module | ||
from torch.utils.hooks import RemovableHandle | ||
from transformers import Cache, PretrainedConfig, PreTrainedModel | ||
|
||
|
||
__all__ = [ | ||
"QuantizedKVCache", | ||
"initialize_hooked_kv_cache", | ||
"register_key_hook", | ||
"register_value_hook", | ||
"KV_CACHE_ATTR", | ||
] | ||
|
||
|
||
KV_CACHE_ATTR = "kv_cache" | ||
|
||
|
||
class QuantizedKVCache(InternalModule): | ||
""" | ||
QuantizedKVCache module which wraps the functionality of any existing kvcache args. | ||
Unlike transform Cache instances, this cache is a `torch.nn.Module` which can be | ||
hooked to trigger transforms and calibration hooks. | ||
|
||
This module works by being registered as a submodule to attention modules via | ||
`initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values` | ||
kwargs with this module. This module adopts the functionality of the replaced cache, | ||
preserving caching functionality such as sliding window attention, ect. | ||
|
||
:param attn_module: parent attention module | ||
""" | ||
|
||
def __init__(self, config: PretrainedConfig, attn_module: Module): | ||
super().__init__() | ||
self.config = config | ||
self.attn_module = ref(attn_module) # avoid circular reference | ||
self.past_key_values: Optional[Cache] = None | ||
|
||
def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: | ||
return self(*args, **kwargs) | ||
|
||
def forward( | ||
self, | ||
key_states: Tensor, | ||
value_states: Tensor, | ||
*args, | ||
**kwargs, | ||
) -> Tuple[Tensor, Tensor]: | ||
# quantization | ||
module = self.attn_module() | ||
quant_args_attr = "quantization_scheme.input_activations" | ||
quant_args = getattr_chain(module, quant_args_attr, None) | ||
quant_enabled = getattr(module, "quantization_enabled", True) | ||
if quant_args is not None and quant_enabled: | ||
key_states = forward_quantize(module, key_states, "k", quant_args) | ||
value_states = forward_quantize(module, value_states, "v", quant_args) | ||
|
||
# original cache | ||
if self.past_key_values is not None: | ||
ret = self.past_key_values.update(key_states, value_states, *args, **kwargs) | ||
else: | ||
ret = (key_states, value_states) | ||
|
||
self.past_key_values = None | ||
return ret | ||
|
||
|
||
# ----- initialize ----- # | ||
|
||
|
||
def _kv_cache_attention_hook(module: Module, args, kwargs): | ||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) | ||
_past_kv_name = ( | ||
"past_key_values" # transformers#39956 | ||
if "past_key_values" in inspect.signature(module.forward).parameters | ||
else "past_key_value" | ||
) | ||
kv_cache.past_key_values = kwargs.get(_past_kv_name, None) | ||
kwargs[_past_kv_name] = kv_cache | ||
|
||
return args, kwargs | ||
|
||
|
||
def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module): | ||
""" | ||
Initialize a `QuantizedKVCache` instance attached to attention | ||
|
||
:param model: parent model of attention module | ||
:param module: attention module to initialize with | ||
""" | ||
if not hasattr(module, KV_CACHE_ATTR): | ||
module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module)) | ||
module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True) | ||
|
||
|
||
# ----- hooks ----- # | ||
|
||
|
||
def register_key_hook( | ||
module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] | ||
) -> RemovableHandle: | ||
""" | ||
Register a hook which takes post-rope key states as an argument and | ||
returns the modified key states or `None` | ||
|
||
:param module: attention module to add hook to | ||
:param hook: key hook function | ||
""" | ||
kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) | ||
|
||
def _hook(cache: QuantizedKVCache, args, kwargs): | ||
bound = inspect.signature(cache.forward).bind(*args, **kwargs) | ||
value = hook(module, bound.arguments["key_states"]) | ||
if value is not None: | ||
bound.arguments["key_states"] = value | ||
|
||
return bound.args, bound.kwargs | ||
|
||
return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) | ||
|
||
|
||
def register_value_hook( | ||
brian-dellabetta marked this conversation as resolved.
Show resolved
Hide resolved
|
||
module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] | ||
) -> RemovableHandle: | ||
""" | ||
Register a hook which takes value states as an argument and | ||
returns the modified value states or `None` | ||
|
||
:param module: attention module to add hook to | ||
:param hook: value hook function | ||
""" | ||
kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) | ||
|
||
def _hook(cache: QuantizedKVCache, args, kwargs): | ||
bound = inspect.signature(cache.forward).bind(*args, **kwargs) | ||
value = hook(module, bound.arguments["value_states"]) | ||
if value is not None: | ||
bound.arguments["value_states"] = value | ||
|
||
return bound.args, bound.kwargs | ||
|
||
return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this wrap every single attention block? If so,
global _original_impl
will be re-set multiple times, though if the same attention function is used throughout the entire model that's probably ok?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We guard against multiple sets using
if model.config._attn_implementation != HOOKED_ATTENTION_NAME: