Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
40c604b
fix gemma
manueldeprada Jul 28, 2025
d143de4
fix min
Cyrilvallez Jul 29, 2025
404208a
fix quant init issue
manueldeprada Jul 29, 2025
8aff749
Merge branch 'main' of github.com:huggingface/transformers into max-c…
manueldeprada Jul 29, 2025
ee1fe17
fix gemma 3n
manueldeprada Jul 29, 2025
31b1bbe
Merge branch 'max-cache-len-fix' of https://github.com/manueldeprada/…
manueldeprada Jul 29, 2025
82a2c5f
skip quant cache test
manueldeprada Jul 29, 2025
e4e6cc7
fix modular
manueldeprada Jul 29, 2025
ffb2c61
new test for Gemma
manueldeprada Jul 29, 2025
e3ca2a3
include cyril change
manueldeprada Jul 29, 2025
4243098
gemma3n tests and code improvements
manueldeprada Jul 29, 2025
83f2599
Merge branch main
manueldeprada Jul 29, 2025
501f651
modular fix
manueldeprada Jul 11, 2025
49d52d7
opsie
manueldeprada Jul 29, 2025
3f630b8
modular
manueldeprada Jul 29, 2025
b23e3ca
modular
manueldeprada Jul 29, 2025
7d79307
fix audio
manueldeprada Jul 29, 2025
afeca3b
fix test, remove sliding_window_pattern mention
manueldeprada Jul 30, 2025
b8f7f09
add flash_attn pytest marks
manueldeprada Jul 30, 2025
9d4ecb6
ops docstring
manueldeprada Jul 30, 2025
8131164
ops
manueldeprada Jul 30, 2025
dd77392
add cleanup
manueldeprada Jul 30, 2025
c9ca022
try to fix OOMs
manueldeprada Jul 30, 2025
fbfa424
fix modular
manueldeprada Jul 30, 2025
6247789
Merge branch 'main' into gemma3n-fixes
manueldeprada Jul 30, 2025
1a0bd59
raushan review, yih-dar review
manueldeprada Aug 4, 2025
1c5b653
cyril review
manueldeprada Aug 6, 2025
c12d304
cyril review
manueldeprada Aug 18, 2025
5d4785a
Merge commit 'cf243a1bf85e2197dac2cfc1f9b23c0e99493fa2' into gemma3n-…
manueldeprada Aug 18, 2025
eb88eb0
Merge commit '95510ab0182b6581822c55472cd53e85daa3379b' into gemma3n-…
manueldeprada Aug 18, 2025
e9bcc4c
Merge commit 'dc11a3cbb2c6cd96986519a144d4a22610fd8487' into gemma3n-…
manueldeprada Aug 18, 2025
abd6c5d
Merge commit 'a1a4fcd03e3455772415e6400fee91f3159e7ac5' into gemma3n-…
manueldeprada Aug 18, 2025
df1cf46
Merge commit 'e4223fa9150580beca9a3ae5fc72e0e1ef20fe37' into gemma3n-…
manueldeprada Aug 18, 2025
3eda455
Merge commit '5337f3052db90e8f5f8f64afcbf257da603d56fb' into gemma3n-…
manueldeprada Aug 18, 2025
84b80af
Merge branch 'main' of github.com:huggingface/transformers into gemma…
manueldeprada Aug 18, 2025
46cd717
fix gemma3n cache layer sharing
manueldeprada Aug 20, 2025
c7047e3
Merge branch 'main' into gemma3n-fixes
manueldeprada Aug 20, 2025
32e252d
fix tests
manueldeprada Aug 21, 2025
d083330
Merge branch 'gemma3n-fixes' of https://github.com/manueldeprada/tran…
manueldeprada Aug 21, 2025
4f69ef3
no cache update on top layers
manueldeprada Aug 21, 2025
0436f99
style
manueldeprada Aug 21, 2025
db221f0
fix
manueldeprada Aug 21, 2025
8fcd111
Merge branch 'main' of github.com:huggingface/transformers into gemma…
manueldeprada Aug 28, 2025
29c8c4d
fix tests and generation with cache having less layers than model
manueldeprada Aug 28, 2025
7fabbac
ops
manueldeprada Aug 28, 2025
f6ee7bd
review
manueldeprada Aug 28, 2025
fff923a
cyril review 2
manueldeprada Aug 28, 2025
b572f5d
ops
manueldeprada Aug 28, 2025
6f9a6a1
add comment explaining
manueldeprada Aug 28, 2025
03b3f9f
Merge branch 'main' into gemma3n-fixes
manueldeprada Aug 28, 2025
f2155d7
fix slow tests
manueldeprada Aug 28, 2025
5f8a3d8
Merge branch 'main' into gemma3n-fixes
manueldeprada Aug 28, 2025
c53c18a
slow tests 2
manueldeprada Aug 28, 2025
c2f72e5
revert test_flash_attn_2_fp32_ln fix
manueldeprada Aug 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,9 @@ def __init__(
"sliding_attention" if sliding_window is not None else "full_attention"
for _ in range(config.num_hidden_layers)
]
# Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
if hasattr(config, "num_kv_shared_layers"):
layer_types = layer_types[: -config.num_kv_shared_layers]

for layer_type in layer_types:
if layer_type in ("sliding_attention", "chunked_attention"):
Expand Down Expand Up @@ -1128,6 +1131,9 @@ def __init__(
layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
else:
layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
# Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
if hasattr(config, "num_kv_shared_layers"):
layer_types = layer_types[: -config.num_kv_shared_layers]

layers = []
for layer_type in layer_types:
Expand Down
14 changes: 9 additions & 5 deletions src/transformers/models/gemma3n/configuration_gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,13 @@ class Gemma3nTextConfig(PretrainedConfig):
The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
layers in the model "share" the KV values in that each local and global layer in this range uses the KV
cache values computed for the last local or global layer, respectively, before entering this range. The
value should be `num_kv_shared_layers` should be a scalar of `sliding_window_pattern`.
value should be a multiple of the attention pattern size (see `layer_types` parameter).
laurel_rank (int, *optional*, defaults to 64):
The intermediate size for the linear projections in the Learned Augmented Residual Layer.
activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`):
activation_sparsity_pattern (Sequence[float], *optional*):
The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
explicitly provide a sparsity value for each layer in the model.
explicitly provide a sparsity value for each layer in the model. By default, the first 10 layers are
sparse with a sparsity factor of 0.95 and the rest are dense.

```python
>>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
Expand Down Expand Up @@ -227,7 +228,7 @@ def __init__(
altup_num_inputs: int = 4,
num_kv_shared_layers: int = 15,
laurel_rank: int = 64,
activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = None,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -289,7 +290,10 @@ def __init__(
self.laurel_rank = laurel_rank

if activation_sparsity_pattern is None:
activation_sparsity_pattern = [0.0] * num_hidden_layers
num_sparse_layers = 10 if num_hidden_layers > 10 else 0
activation_sparsity_pattern = (0.95,) * num_sparse_layers + (0.0,) * (
num_hidden_layers - num_sparse_layers
)
Comment on lines 292 to +296
Copy link
Contributor Author

@manueldeprada manueldeprada Aug 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hardcoding the number of layers in the default is a problem: the code crashes when instantiating a model with a different number of layers.

The default is therefore `None`, and the pattern is derived from `num_hidden_layers` when it is not provided. Removing the previous default is safe because no model on the Hub relied on it; see here for the discussion.


if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
raise ValueError(
Expand Down
48 changes: 25 additions & 23 deletions src/transformers/models/gemma3n/modeling_gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import torch.nn.functional as F

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
Expand Down Expand Up @@ -1299,13 +1299,17 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):

first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
layer_type = config.layer_types[layer_idx]
self.kv_shared_layer_index = (
first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
if self.is_kv_shared_layer
else None
)
prev_layers = config.layer_types[:first_kv_shared_layer_idx]
if self.is_kv_shared_layer:
# For shared layers, find the last non-shared layer of the same type before sharing starts
self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
self.store_full_length_kv = False
else:
self.kv_shared_layer_index = None
# For non-shared layers, store full-length kv if this is the last non-shared layer of its type
self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
config.layer_types[layer_idx]
)

@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
Expand All @@ -1327,21 +1331,12 @@ def forward(
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)

if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None:
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
layer = past_key_values.layers[self.kv_shared_layer_index]
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
if self.is_kv_shared_layer and past_key_values is not None:
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
# Device of past layer may be different from current one
indices = cache_position.to(layer.keys.device)
# Sliding window cache layers might have smaller size (for full layers, we never go beyond)
if isinstance(layer, SlidingWindowLayer):
if cache_position.shape[0] > layer.get_max_cache_shape():
indices = slice(0, layer.get_max_cache_shape())
else:
indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1)

# Device of past layer may be different from current one
key_states = layer.keys[:, :, indices].to(query_states.device)
value_states = layer.values[:, :, indices].to(query_states.device)
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states)
Expand All @@ -1360,7 +1355,14 @@ def forward(
"cache_position": cache_position,
"sliding_window": self.sliding_window,
}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
if not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
if self.store_full_length_kv:
if not hasattr(past_key_values, "shared_layers"):
past_key_values.shared_layers = {}
past_key_values.shared_layers[self.layer_idx] = key_states, value_states

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
Expand Down
62 changes: 34 additions & 28 deletions src/transformers/models/gemma3n/modular_gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch.nn.functional as F

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
Expand Down Expand Up @@ -184,12 +184,13 @@ class Gemma3nTextConfig(Gemma2Config, PretrainedConfig):
The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
layers in the model "share" the KV values in that each local and global layer in this range uses the KV
cache values computed for the last local or global layer, respectively, before entering this range. The
value should be `num_kv_shared_layers` should be a scalar of `sliding_window_pattern`.
value should be a multiple of the attention pattern size (see `layer_types` parameter).
laurel_rank (int, *optional*, defaults to 64):
The intermediate size for the linear projections in the Learned Augmented Residual Layer.
activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`):
activation_sparsity_pattern (Sequence[float], *optional*):
The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
explicitly provide a sparsity value for each layer in the model.
explicitly provide a sparsity value for each layer in the model. By default, the first 10 layers are
sparse with a sparsity factor of 0.95 and the rest are dense.

```python
>>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
Expand Down Expand Up @@ -240,7 +241,7 @@ def __init__(
altup_num_inputs: int = 4,
num_kv_shared_layers: int = 15,
laurel_rank: int = 64,
activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = None,
**kwargs,
):
PretrainedConfig.__init__(
Expand Down Expand Up @@ -302,7 +303,10 @@ def __init__(
self.laurel_rank = laurel_rank

if activation_sparsity_pattern is None:
activation_sparsity_pattern = [0.0] * num_hidden_layers
num_sparse_layers = 10 if num_hidden_layers > 10 else 0
activation_sparsity_pattern = (0.95,) * num_sparse_layers + (0.0,) * (
num_hidden_layers - num_sparse_layers
)

if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
raise ValueError(
Expand Down Expand Up @@ -1746,13 +1750,17 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):

first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
layer_type = config.layer_types[layer_idx]
self.kv_shared_layer_index = (
first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
if self.is_kv_shared_layer
else None
)
prev_layers = config.layer_types[:first_kv_shared_layer_idx]
if self.is_kv_shared_layer:
# For shared layers, find the last non-shared layer of the same type before sharing starts
self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
self.store_full_length_kv = False
else:
self.kv_shared_layer_index = None
# For non-shared layers, store full-length kv if this is the last non-shared layer of its type
self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
config.layer_types[layer_idx]
)

@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
Expand All @@ -1774,21 +1782,12 @@ def forward(
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)

if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None:
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
layer = past_key_values.layers[self.kv_shared_layer_index]
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
if self.is_kv_shared_layer and past_key_values is not None:
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
# Device of past layer may be different from current one
indices = cache_position.to(layer.keys.device)
# Sliding window cache layers might have smaller size (for full layers, we never go beyond)
if isinstance(layer, SlidingWindowLayer):
if cache_position.shape[0] > layer.get_max_cache_shape():
indices = slice(0, layer.get_max_cache_shape())
else:
indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1)

# Device of past layer may be different from current one
key_states = layer.keys[:, :, indices].to(query_states.device)
value_states = layer.values[:, :, indices].to(query_states.device)
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states)
Expand All @@ -1807,7 +1806,14 @@ def forward(
"cache_position": cache_position,
"sliding_window": self.sliding_window,
}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
if not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
if self.store_full_length_kv:
if not hasattr(past_key_values, "shared_layers"):
past_key_values.shared_layers = {}
past_key_values.shared_layers[self.layer_idx] = key_states, value_states

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
Expand Down
4 changes: 0 additions & 4 deletions src/transformers/models/gemma3n/processing_gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@


class Gemma3nImagesKwargs(ImagesKwargs):
do_pan_and_scan: Optional[bool]
pan_and_scan_min_crop_size: Optional[int]
pan_and_scan_max_num_crops: Optional[int]
pan_and_scan_min_ratio_to_activate: Optional[float]
do_convert_rgb: Optional[bool]


Expand Down
Loading