Skip to content

Commit 46cd717

Browse files
committed
fix gemma3n cache layer sharing
1 parent: 84b80af — commit: 46cd717

File tree

2 files changed

+38
-38
lines changed

2 files changed

+38
-38
lines changed

src/transformers/models/gemma3n/modeling_gemma3n.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import torch.nn.functional as F
3131

3232
from ...activations import ACT2FN
33-
from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer
33+
from ...cache_utils import Cache, DynamicCache
3434
from ...generation import GenerationMixin
3535
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
3636
from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -1297,13 +1297,17 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
12971297

12981298
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
12991299
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
1300-
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
1301-
layer_type = config.layer_types[layer_idx]
1302-
self.kv_shared_layer_index = (
1303-
first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
1304-
if self.is_kv_shared_layer
1305-
else None
1306-
)
1300+
prev_layers = config.layer_types[:first_kv_shared_layer_idx]
1301+
if self.is_kv_shared_layer:
1302+
# For shared layers, find the last non-shared layer of the same type before sharing starts
1303+
self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
1304+
self.store_full_length_kv = False
1305+
else:
1306+
self.kv_shared_layer_index = None
1307+
# For non-shared layers, store full-length kv if this is the last non-shared layer of its type
1308+
self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
1309+
config.layer_types[layer_idx]
1310+
)
13071311

13081312
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
13091313
def forward(
@@ -1329,18 +1333,10 @@ def forward(
13291333
# During prefill, cache_position is a full range [0, 1, ..., max_cache_len-1], but in autoregressive mode it's a single position [last_token_idx].
13301334
# For sliding window layers, we must clamp or slice indices to the cache's max length to avoid out-of-bounds access.
13311335
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None:
1332-
layer = past_key_values.layers[self.kv_shared_layer_index]
1333-
# Device of past layer may be different from current one
1334-
indices = cache_position.to(layer.keys.device)
1335-
if isinstance(layer, SlidingWindowLayer):
1336-
if cache_position.shape[0] > layer.get_max_cache_shape():
1337-
indices = slice(0, layer.get_max_cache_shape())
1338-
else:
1339-
indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1)
1340-
1336+
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
13411337
# Device of past layer may be different from current one
1342-
key_states = layer.keys[:, :, indices].to(query_states.device)
1343-
value_states = layer.values[:, :, indices].to(query_states.device)
1338+
key_states = key_states.to(query_states.device)
1339+
value_states = value_states.to(query_states.device)
13441340
else:
13451341
key_states = self.k_proj(hidden_states).view(hidden_shape)
13461342
key_states = self.k_norm(key_states)
@@ -1359,6 +1355,10 @@ def forward(
13591355
"cache_position": cache_position,
13601356
"sliding_window": self.sliding_window,
13611357
}
1358+
if self.store_full_length_kv:
1359+
if not hasattr(past_key_values, "shared_layers"):
1360+
past_key_values.shared_layers = {}
1361+
past_key_values.shared_layers[self.layer_idx] = key_states, value_states
13621362
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
13631363

13641364
attention_interface: Callable = eager_attention_forward

src/transformers/models/gemma3n/modular_gemma3n.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch.nn.functional as F
2424

2525
from ...activations import ACT2FN
26-
from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer
26+
from ...cache_utils import Cache, DynamicCache
2727
from ...configuration_utils import PretrainedConfig, layer_type_validation
2828
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
2929
from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -1748,13 +1748,17 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
17481748

17491749
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
17501750
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
1751-
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
1752-
layer_type = config.layer_types[layer_idx]
1753-
self.kv_shared_layer_index = (
1754-
first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
1755-
if self.is_kv_shared_layer
1756-
else None
1757-
)
1751+
prev_layers = config.layer_types[:first_kv_shared_layer_idx]
1752+
if self.is_kv_shared_layer:
1753+
# For shared layers, find the last non-shared layer of the same type before sharing starts
1754+
self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
1755+
self.store_full_length_kv = False
1756+
else:
1757+
self.kv_shared_layer_index = None
1758+
# For non-shared layers, store full-length kv if this is the last non-shared layer of its type
1759+
self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
1760+
config.layer_types[layer_idx]
1761+
)
17581762

17591763
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
17601764
def forward(
@@ -1780,18 +1784,10 @@ def forward(
17801784
# During prefill, cache_position is a full range [0, 1, ..., max_cache_len-1], but in autoregressive mode it's a single position [last_token_idx].
17811785
# For sliding window layers, we must clamp or slice indices to the cache's max length to avoid out-of-bounds access.
17821786
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None:
1783-
layer = past_key_values.layers[self.kv_shared_layer_index]
1784-
# Device of past layer may be different from current one
1785-
indices = cache_position.to(layer.keys.device)
1786-
if isinstance(layer, SlidingWindowLayer):
1787-
if cache_position.shape[0] > layer.get_max_cache_shape():
1788-
indices = slice(0, layer.get_max_cache_shape())
1789-
else:
1790-
indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1)
1791-
1787+
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
17921788
# Device of past layer may be different from current one
1793-
key_states = layer.keys[:, :, indices].to(query_states.device)
1794-
value_states = layer.values[:, :, indices].to(query_states.device)
1789+
key_states = key_states.to(query_states.device)
1790+
value_states = value_states.to(query_states.device)
17951791
else:
17961792
key_states = self.k_proj(hidden_states).view(hidden_shape)
17971793
key_states = self.k_norm(key_states)
@@ -1810,6 +1806,10 @@ def forward(
18101806
"cache_position": cache_position,
18111807
"sliding_window": self.sliding_window,
18121808
}
1809+
if self.store_full_length_kv:
1810+
if not hasattr(past_key_values, "shared_layers"):
1811+
past_key_values.shared_layers = {}
1812+
past_key_values.shared_layers[self.layer_idx] = key_states, value_states
18131813
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
18141814

18151815
attention_interface: Callable = eager_attention_forward

0 commit comments

Comments (0)