23
23
import torch .nn .functional as F
24
24
25
25
from ...activations import ACT2FN
26
- from ...cache_utils import Cache , DynamicCache , SlidingWindowLayer
26
+ from ...cache_utils import Cache , DynamicCache
27
27
from ...configuration_utils import PretrainedConfig , layer_type_validation
28
28
from ...masking_utils import create_causal_mask , create_sliding_window_causal_mask
29
29
from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -1748,13 +1748,17 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
1748
1748
1749
1749
first_kv_shared_layer_idx = self .config .num_hidden_layers - self .config .num_kv_shared_layers
1750
1750
self .is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
1751
- # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
1752
- layer_type = config .layer_types [layer_idx ]
1753
- self .kv_shared_layer_index = (
1754
- first_kv_shared_layer_idx - 1 - config .layer_types [first_kv_shared_layer_idx - 1 :: - 1 ].index (layer_type )
1755
- if self .is_kv_shared_layer
1756
- else None
1757
- )
1751
+ prev_layers = config .layer_types [:first_kv_shared_layer_idx ]
1752
+ if self .is_kv_shared_layer :
1753
+ # For shared layers, find the last non-shared layer of the same type before sharing starts
1754
+ self .kv_shared_layer_index = len (prev_layers ) - 1 - prev_layers [::- 1 ].index (config .layer_types [layer_idx ])
1755
+ self .store_full_length_kv = False
1756
+ else :
1757
+ self .kv_shared_layer_index = None
1758
+ # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
1759
+ self .store_full_length_kv = layer_idx == len (prev_layers ) - 1 - prev_layers [::- 1 ].index (
1760
+ config .layer_types [layer_idx ]
1761
+ )
1758
1762
1759
1763
@deprecate_kwarg ("past_key_value" , new_name = "past_key_values" , version = "4.58" )
1760
1764
def forward (
@@ -1780,18 +1784,10 @@ def forward(
1780
1784
# During prefill, cache_position is a full range [0, 1, ..., max_cache_len-1], but in autoregressive mode it's a single position [last_token_idx].
1781
1785
# Shared layers read the key/value states stored in `past_key_values.shared_layers` before the cache update, so no clamping or slicing of indices to the sliding window cache's max length is needed here.
1782
1786
if self .is_kv_shared_layer and self .kv_shared_layer_index is not None and past_key_values is not None :
1783
- layer = past_key_values .layers [self .kv_shared_layer_index ]
1784
- # Device of past layer may be different from current one
1785
- indices = cache_position .to (layer .keys .device )
1786
- if isinstance (layer , SlidingWindowLayer ):
1787
- if cache_position .shape [0 ] > layer .get_max_cache_shape ():
1788
- indices = slice (0 , layer .get_max_cache_shape ())
1789
- else :
1790
- indices = indices .clamp (min = 0 , max = layer .get_max_cache_shape () - 1 )
1791
-
1787
+ key_states , value_states = past_key_values .shared_layers [self .kv_shared_layer_index ]
1792
1788
# Device of past layer may be different from current one
1793
- key_states = layer . keys [:, :, indices ] .to (query_states .device )
1794
- value_states = layer . values [:, :, indices ] .to (query_states .device )
1789
+ key_states = key_states .to (query_states .device )
1790
+ value_states = value_states .to (query_states .device )
1795
1791
else :
1796
1792
key_states = self .k_proj (hidden_states ).view (hidden_shape )
1797
1793
key_states = self .k_norm (key_states )
@@ -1810,6 +1806,10 @@ def forward(
1810
1806
"cache_position" : cache_position ,
1811
1807
"sliding_window" : self .sliding_window ,
1812
1808
}
1809
+ if self .store_full_length_kv :
1810
+ if not hasattr (past_key_values , "shared_layers" ):
1811
+ past_key_values .shared_layers = {}
1812
+ past_key_values .shared_layers [self .layer_idx ] = key_states , value_states
1813
1813
key_states , value_states = past_key_values .update (key_states , value_states , self .layer_idx , cache_kwargs )
1814
1814
1815
1815
attention_interface : Callable = eager_attention_forward
0 commit comments