diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 4d5f6071a67d..e993693c93e4 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1017,6 +1017,9 @@ def __init__( "sliding_attention" if sliding_window is not None else "full_attention" for _ in range(config.num_hidden_layers) ] + # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n) + if hasattr(config, "num_kv_shared_layers"): + layer_types = layer_types[: -config.num_kv_shared_layers] for layer_type in layer_types: if layer_type in ("sliding_attention", "chunked_attention"): @@ -1128,6 +1131,9 @@ def __init__( layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)] else: layer_types = ["full_attention" for _ in range(config.num_hidden_layers)] + # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n) + if hasattr(config, "num_kv_shared_layers"): + layer_types = layer_types[: -config.num_kv_shared_layers] layers = [] for layer_type in layer_types: diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py index b5f144bcc0ee..1cb2e1f7c7b2 100644 --- a/src/transformers/models/gemma3n/configuration_gemma3n.py +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -156,12 +156,13 @@ class Gemma3nTextConfig(PretrainedConfig): The number of layer that share KV cache values. During the forward pass, the last `num_kv_shared_layers` layers in the model "share" the KV values in that each local and global layer in this range uses the KV cache values computed for the last local or global layer, respectively, before entering this range. The - value should be `num_kv_shared_layers` should be a scalar of `sliding_window_pattern`. + value should be a multiple of the attention pattern size (see `layer_types` parameter). laurel_rank (int, *optional*, defaults to 64): The intermediate size for the linear projections in the Learned Augmented Residual Layer. - activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`): + activation_sparsity_pattern (Sequence[float], *optional*): The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must - explicitly provide a sparsity value for each layer in the model. + explicitly provide a sparsity value for each layer in the model. By default, the first 10 layers are + sparse with a sparsity factor of 0.95 and the rest are dense. 
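Editor's note on the new default described above: the pattern is no longer hard-coded for 35 layers but derived from `num_hidden_layers`. The helper below is only an illustrative sketch (the function name is ours, not part of the patch) of how a `None` value resolves, mirroring the `__init__` logic shown further down in this diff.

```python
# Illustrative sketch, not part of the patch: how a `None` activation_sparsity_pattern
# is resolved by the new Gemma3nTextConfig.__init__ logic in this diff.
def default_activation_sparsity_pattern(num_hidden_layers: int) -> tuple:
    # First 10 layers sparse at 0.95, the rest dense; configs with <= 10 layers
    # (e.g. small test configs) fall back to an all-dense pattern.
    num_sparse_layers = 10 if num_hidden_layers > 10 else 0
    return (0.95,) * num_sparse_layers + (0.0,) * (num_hidden_layers - num_sparse_layers)

print(default_activation_sparsity_pattern(35))  # matches the previously hard-coded default
print(default_activation_sparsity_pattern(4))   # (0.0, 0.0, 0.0, 0.0)
```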
```python >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig @@ -227,7 +228,7 @@ def __init__( altup_num_inputs: int = 4, num_kv_shared_layers: int = 15, laurel_rank: int = 64, - activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25, + activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = None, **kwargs, ): super().__init__( @@ -289,7 +290,10 @@ def __init__( self.laurel_rank = laurel_rank if activation_sparsity_pattern is None: - activation_sparsity_pattern = [0.0] * num_hidden_layers + num_sparse_layers = 10 if num_hidden_layers > 10 else 0 + activation_sparsity_pattern = (0.95,) * num_sparse_layers + (0.0,) * ( + num_hidden_layers - num_sparse_layers + ) if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers: raise ValueError( diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 3c21143a3205..c2f05d5d6895 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -30,7 +30,7 @@ import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -1299,13 +1299,17 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int): first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 - # Find the index of the last sliding or full layer before sharing starts (or None if no sharing) - layer_type = config.layer_types[layer_idx] - self.kv_shared_layer_index = ( - first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type) - if self.is_kv_shared_layer - else None - ) + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + if self.is_kv_shared_layer: + # For shared layers, find the last non-shared layer of the same type before sharing starts + self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx]) + self.store_full_length_kv = False + else: + self.kv_shared_layer_index = None + # For non-shared layers, store full-length kv if this is the last non-shared layer of its type + self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index( + config.layer_types[layer_idx] + ) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -1327,21 +1331,12 @@ def forward( query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) query_states = query_states.transpose(1, 2) - if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None: - # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) - layer = past_key_values.layers[self.kv_shared_layer_index] + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer + if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index] # Device of past layer may be different from current 
one - indices = cache_position.to(layer.keys.device) - # Sliding window cache layers might have smaller size (for full layers, we never go beyond) - if isinstance(layer, SlidingWindowLayer): - if cache_position.shape[0] > layer.get_max_cache_shape(): - indices = slice(0, layer.get_max_cache_shape()) - else: - indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1) - - # Device of past layer may be different from current one - key_states = layer.keys[:, :, indices].to(query_states.device) - value_states = layer.values[:, :, indices].to(query_states.device) + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) @@ -1360,7 +1355,14 @@ def forward( "cache_position": cache_position, "sliding_window": self.sliding_window, } - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + if not self.is_kv_shared_layer: + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + if self.store_full_length_kv: + if not hasattr(past_key_values, "shared_layers"): + past_key_values.shared_layers = {} + past_key_values.shared_layers[self.layer_idx] = key_states, value_states attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index b6629053f118..6118752144b8 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowLayer +from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig, layer_type_validation from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -184,12 +184,13 @@ class Gemma3nTextConfig(Gemma2Config, PretrainedConfig): The number of layer that share KV cache values. During the forward pass, the last `num_kv_shared_layers` layers in the model "share" the KV values in that each local and global layer in this range uses the KV cache values computed for the last local or global layer, respectively, before entering this range. The - value should be `num_kv_shared_layers` should be a scalar of `sliding_window_pattern`. + value should be a multiple of the attention pattern size (see `layer_types` parameter). laurel_rank (int, *optional*, defaults to 64): The intermediate size for the linear projections in the Learned Augmented Residual Layer. - activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`): + activation_sparsity_pattern (Sequence[float], *optional*): The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must - explicitly provide a sparsity value for each layer in the model. + explicitly provide a sparsity value for each layer in the model. By default, the first 10 layers are + sparse with a sparsity factor of 0.95 and the rest are dense. 
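Editor's note on the KV-sharing bookkeeping introduced in `Gemma3nTextAttention.__init__` above: each layer resolves its sharing role from `layer_types` and `num_kv_shared_layers`. The standalone sketch below mirrors that logic; the function and the toy layer pattern are illustrative only and not part of the patch.

```python
# Illustrative sketch, not part of the patch: per-layer decision of whether to reuse
# another layer's KV cache and whether to stash full-length KV for later reuse.
def resolve_kv_sharing(layer_types, num_kv_shared_layers, layer_idx):
    first_shared_idx = len(layer_types) - num_kv_shared_layers
    is_shared = layer_idx >= first_shared_idx > 0
    prev_layers = layer_types[:first_shared_idx]
    # Index of the last non-shared layer with the same attention type as `layer_idx`
    last_same_type = len(prev_layers) - 1 - prev_layers[::-1].index(layer_types[layer_idx])
    if is_shared:
        # Shared layers read the KV states cached by that earlier layer
        return {"kv_shared_layer_index": last_same_type, "store_full_length_kv": False}
    # Only the last non-shared layer of each type stores full-length KV for the shared layers
    return {"kv_shared_layer_index": None, "store_full_length_kv": layer_idx == last_same_type}

# Toy 6-layer pattern with the last 2 layers shared (numbers are hypothetical)
types = ["sliding_attention"] * 3 + ["full_attention", "sliding_attention", "full_attention"]
print(resolve_kv_sharing(types, 2, 4))  # shared sliding layer -> reuses layer 2
print(resolve_kv_sharing(types, 2, 3))  # last full-attention layer -> stores full-length KV
```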
```python >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig @@ -240,7 +241,7 @@ def __init__( altup_num_inputs: int = 4, num_kv_shared_layers: int = 15, laurel_rank: int = 64, - activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25, + activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = None, **kwargs, ): PretrainedConfig.__init__( @@ -302,7 +303,10 @@ def __init__( self.laurel_rank = laurel_rank if activation_sparsity_pattern is None: - activation_sparsity_pattern = [0.0] * num_hidden_layers + num_sparse_layers = 10 if num_hidden_layers > 10 else 0 + activation_sparsity_pattern = (0.95,) * num_sparse_layers + (0.0,) * ( + num_hidden_layers - num_sparse_layers + ) if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers: raise ValueError( @@ -1746,13 +1750,17 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int): first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 - # Find the index of the last sliding or full layer before sharing starts (or None if no sharing) - layer_type = config.layer_types[layer_idx] - self.kv_shared_layer_index = ( - first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type) - if self.is_kv_shared_layer - else None - ) + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + if self.is_kv_shared_layer: + # For shared layers, find the last non-shared layer of the same type before sharing starts + self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx]) + self.store_full_length_kv = False + else: + self.kv_shared_layer_index = None + # For non-shared layers, store full-length kv if this is the last non-shared layer of its type + self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index( + config.layer_types[layer_idx] + ) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -1774,21 +1782,12 @@ def forward( query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) query_states = query_states.transpose(1, 2) - if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None: - # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) - layer = past_key_values.layers[self.kv_shared_layer_index] + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer + if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index] # Device of past layer may be different from current one - indices = cache_position.to(layer.keys.device) - # Sliding window cache layers might have smaller size (for full layers, we never go beyond) - if isinstance(layer, SlidingWindowLayer): - if cache_position.shape[0] > layer.get_max_cache_shape(): - indices = slice(0, layer.get_max_cache_shape()) - else: - indices = indices.clamp(min=0, max=layer.get_max_cache_shape() - 1) - - # Device of past layer may be different from current one - key_states = layer.keys[:, :, indices].to(query_states.device) - value_states = layer.values[:, :, indices].to(query_states.device) + key_states = key_states.to(query_states.device) + value_states = 
value_states.to(query_states.device) else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) @@ -1807,7 +1806,14 @@ def forward( "cache_position": cache_position, "sliding_window": self.sliding_window, } - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + if not self.is_kv_shared_layer: + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + if self.store_full_length_kv: + if not hasattr(past_key_values, "shared_layers"): + past_key_values.shared_layers = {} + past_key_values.shared_layers[self.layer_idx] = key_states, value_states attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/gemma3n/processing_gemma3n.py b/src/transformers/models/gemma3n/processing_gemma3n.py index 19274fece4c1..89d2880cd5c3 100644 --- a/src/transformers/models/gemma3n/processing_gemma3n.py +++ b/src/transformers/models/gemma3n/processing_gemma3n.py @@ -24,10 +24,6 @@ class Gemma3nImagesKwargs(ImagesKwargs): - do_pan_and_scan: Optional[bool] - pan_and_scan_min_crop_size: Optional[int] - pan_and_scan_max_num_crops: Optional[int] - pan_and_scan_min_ratio_to_activate: Optional[float] do_convert_rgb: Optional[bool] diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index af0918596fed..f5115efc59ca 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -14,6 +14,8 @@ # limitations under the License. """Testing suite for the PyTorch Gemma3n model.""" +import copy +import inspect import tempfile import unittest @@ -26,11 +28,13 @@ AutoModelForCausalLM, AutoProcessor, AutoTokenizer, + Cache, Gemma3nAudioConfig, Gemma3nAudioFeatureExtractor, Gemma3nConfig, Gemma3nTextConfig, GenerationConfig, + StaticCache, is_torch_available, ) from transformers.testing_utils import ( @@ -39,11 +43,14 @@ require_read_token, require_torch, require_torch_gpu, + set_config_for_less_flaky_test, + set_model_for_less_flaky_test, slow, torch_device, ) +from transformers.utils import is_flash_attn_2_available -from ...generation.test_utils import GenerationTesterMixin +from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, @@ -370,6 +377,13 @@ def _check_hidden_states_for_generate( [expected_shape] * len(iter_hidden_states), ) + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="Gemma3n flash attention does not support right padding") + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) def test_eager_matches_sdpa_inference( self, @@ -433,6 +447,274 @@ def test_eager_padding_matches_padding_free_with_position_ids(self): def test_sdpa_padding_matches_padding_free_with_position_ids(self): pass + @unittest.skip("Gemma3n only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @pytest.mark.generate + def test_generate_from_inputs_embeds_with_static_cache(self): + """ + Test that StaticCache can generate from inputs_embeds and calculates max_cache_length + correctly in `generate()`. 
We force the model to not stop generation until max-length is reached + to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache. + """ + for model_class in self.all_generative_model_classes: + # Here, we should ideally not skip any model, and test them all. However, some old models cannot correctly + # use a static cache because they don't create the causal masks correctly. + # TODO: cyril -> relax this by adding a `_support_static_cache` attribute + if not model_class._can_compile_fullgraph: + self.skipTest(reason="This model does not support the static cache format") + + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + + if config.get_text_config(decoder=True).is_encoder_decoder: + self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") + + model = model_class(config).to(torch_device).eval() + if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters: + self.skipTest(reason="This model does not support `inputs_embeds` in generation") + + input_ids = inputs_dict.pop("input_ids") + + model.config.use_cache = True + model.config.is_decoder = True + batch_size = input_ids.shape[0] + max_new_tokens = 10 + + # here we force to not stop at eos and go until max-length + model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1 + generation_kwargs = { + "max_new_tokens": max_new_tokens, + "cache_implementation": "static", + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + text_config = model.config.get_text_config() + head_dim = ( + getattr(text_config, "head_dim", None) or text_config.hidden_size // text_config.num_attention_heads + ) + num_key_value_heads = ( + text_config.num_attention_heads + if getattr(text_config, "num_key_value_heads", None) is None + else text_config.num_key_value_heads + ) + num_hidden_layers = text_config.num_hidden_layers + + inputs_embeds = model.get_input_embeddings()(input_ids) + outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict) + + # we should get `max_length - 1` in shape, not `max_length - embeds_length`. + # -1 because the last generated token isn't yet in the cache. + max_length = max_new_tokens + inputs_embeds.shape[1] - 1 + cache_shape = [batch_size, num_key_value_heads, max_length, head_dim] + self.assertIsInstance(outputs.past_key_values, StaticCache) + self.assertEqual(len(outputs.past_key_values), num_hidden_layers - text_config.num_kv_shared_layers) + self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape) + + @pytest.mark.generate + def test_generate_with_static_cache(self): + """ + Tests that generating with static cache give almost same results as with dynamic cache, and the output cache + has the expected shapes + """ + for model_class in self.all_generative_model_classes: + # Here, we should ideally not skip any model, and test them all. However, some old models cannot correctly + # use a static cache because they don't create the causal masks correctly. 
+ # TODO: cyril -> relax this by adding a `_support_static_cache` attribute + if not model_class._can_compile_fullgraph: + self.skipTest(reason="This model does not support the static cache format") + + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + set_config_for_less_flaky_test(config) + main_input = inputs_dict[model_class.main_input_name] + + if config.get_text_config(decoder=True).is_encoder_decoder: + self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") + + config.is_decoder = True + batch_size = main_input.shape[0] + seq_length = self.model_tester.seq_length + max_new_tokens = 20 + + for dtype in (torch.float32, torch.float16): + model = model_class(copy.deepcopy(config)).to(torch_device).to(dtype).eval() + inputs_dict = { + k: v.to(dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v + for k, v in inputs_dict.items() + } + set_model_for_less_flaky_test(model) + + generation_kwargs = { + "max_new_tokens": max_new_tokens, + "return_dict_in_generate": True, # Required to return `past_key_values` + "output_scores": True, + "use_cache": True, + } + + static_cache_generation = model.generate( + **generation_kwargs, **inputs_dict, cache_implementation="static" + ) + + # Check 1: The cache shapes must match the expected shapes + max_cache_len = seq_length + max_new_tokens - 1 # cache len = gen len - 1, the last token has no cache + text_config = config.text_config if hasattr(config, "text_config") else config + head_dim = ( + getattr(text_config, "head_dim", None) + or text_config.hidden_size // text_config.num_attention_heads + ) + num_key_value_heads = ( + text_config.num_attention_heads + if getattr(text_config, "num_key_value_heads", None) is None + else text_config.num_key_value_heads + ) + num_hidden_layers = text_config.num_hidden_layers + cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) + self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) + self.assertTrue( + len(static_cache_generation.past_key_values) + == num_hidden_layers - text_config.num_kv_shared_layers + ) + self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape) + + # Check 2: The outputs must be similar to the case with dynamic cache + dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) + self.assertTrue(has_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)) + + @pytest.mark.generate + def test_past_key_values_format(self, custom_all_cache_shapes=None): + """ + Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the + expected cache shapes. + Having a standard KV cache format is important for a consistent API (and for advanced generation methods). + """ + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + # 1. If it doesn't support cache, skip the test + if not hasattr(config.get_text_config(), "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + + model = model_class(config).to(torch_device) + model = model.eval() + if "use_cache" not in inputs: + inputs["use_cache"] = True + outputs = model(**inputs) + + if "past_key_values" not in outputs: + self.skipTest(reason="This model doesn't return `past_key_values`") + + # 2. 
retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided) + past_kv = outputs["past_key_values"] + is_legacy_cache = not isinstance(past_kv, Cache) + + text_config = config.get_text_config() + num_decoder_layers = ( + getattr(text_config, "decoder_layers", None) + or getattr(text_config, "num_decoder_layers", None) + or text_config.num_hidden_layers + ) + + if custom_all_cache_shapes is None: + num_query_attention_heads = getattr( + text_config, "decoder_attention_heads", text_config.num_attention_heads + ) + embed_dim = getattr(text_config, "d_model", text_config.hidden_size) + per_head_embed_dim = embed_dim // num_query_attention_heads + num_key_value_heads = ( + text_config.num_key_value_heads + if getattr(text_config, "num_key_value_heads", None) is not None + else num_query_attention_heads + ) + if config.is_encoder_decoder: + encoder_num_attention_heads = ( + text_config.encoder_attention_heads + if hasattr(text_config, "encoder_attention_heads") + else text_config.num_attention_heads + ) + encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads + batch_size, seq_length = inputs["decoder_input_ids"].shape[:2] + # The sequence length for the encoder K V depends on the model. Since it is not manipulated in + # autoregressive generation, we're keeping the test general and not checking the 3rd dim + default_cross_attention_shape = ( + batch_size, + encoder_num_attention_heads, + encoder_per_head_embed_dim, + ) + default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) + all_cache_shapes = [ + [ + default_self_attention_shape, + default_self_attention_shape, + default_cross_attention_shape, + default_cross_attention_shape, + ] + for _ in range(num_decoder_layers) + ] + else: + batch_size, seq_length = inputs["input_ids"].shape[:2] + default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) + all_cache_shapes = [ + [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers) + ] + + else: + all_cache_shapes = custom_all_cache_shapes + + # 3. Check cache shapes + # 3.1. 
Encoder-Decoder checks + if config.is_encoder_decoder: + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) + self.assertEqual(num_cache_decoder_layers, num_decoder_layers) + + for i in range(num_decoder_layers): + if is_legacy_cache: + self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple + + # Self attention + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys + ) + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values + ) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) + + # Cross attention (ignore 3rd dim, see default shape preparation) + cross_attention_layer_keys = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys + ) + cross_attention_layer_values = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values + ) + cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] + cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] + self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) + + # 3.2. Decoder-only checks + else: + num_cache_decoder_layers = len(past_kv) + self.assertEqual(num_cache_decoder_layers, num_decoder_layers - text_config.num_kv_shared_layers) + + for i in range(num_decoder_layers - text_config.num_kv_shared_layers): + if is_legacy_cache: + self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple + + # Self attention + if is_legacy_cache: + self_attention_layer_keys = past_kv[i][0] + self_attention_layer_values = past_kv[i][1] + elif getattr(past_kv, "layers", None) is None: + # Cache is lot layered (i.e, Mamba derivatives) + self_attention_layer_keys = past_kv.key_cache[i] + self_attention_layer_values = past_kv.value_cache[i] + else: + self_attention_layer_keys = past_kv.layers[i].keys + self_attention_layer_values = past_kv.layers[i].values + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) + class Gemma3nVision2TextModelTester: text_config = {"activation_sparsity_pattern": None} @@ -606,7 +888,6 @@ def test_automodelforcausallm(self): self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM) -@unittest.skip("Skipped for now!") @slow @require_torch_gpu @require_read_token @@ -629,7 +910,8 @@ def setUp(self): audio_ds = load_dataset( "etechgrid/28.5k_wavfiles_dataset", "default", data_files="wav_dataset/103-1240-0000.wav" ) - self.audio_file_path = audio_ds["train"][0]["audio"]["path"] + self.audio_file_path = audio_ds["train"][0]["audio"].metadata.path + cleanup(torch_device, gc_collect=True) def tearDown(self): cleanup(torch_device, gc_collect=True) @@ -637,9 +919,7 @@ def tearDown(self): def test_model_4b_bf16(self): model_id = "Google/gemma-3n-E4B-it" - model = Gemma3nForConditionalGeneration.from_pretrained( - model_id, low_cpu_mem_usage=True, dtype=torch.bfloat16 - ).to(torch_device) + model = Gemma3nForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device) inputs = self.processor.apply_chat_template( self.messages, @@ -651,8 +931,7 @@ def test_model_4b_bf16(self): output = model.generate(**inputs, 
max_new_tokens=30, do_sample=False) output_text = self.processor.batch_decode(output, skip_special_tokens=True) - - EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip + EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'] # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) def test_model_with_audio(self): @@ -664,8 +943,8 @@ def test_model_with_audio(self): model_id = "Google/gemma-3n-E4B-it" model = Gemma3nForConditionalGeneration.from_pretrained( - model_id, low_cpu_mem_usage=True, dtype=torch.bfloat16 - ).to(torch_device) + model_id, dtype=torch.bfloat16, device_map=torch_device + ) messages = [ [ @@ -701,8 +980,8 @@ def test_model_4b_batch(self): model_id = "Google/gemma-3n-E4B-it" model = Gemma3nForConditionalGeneration.from_pretrained( - model_id, low_cpu_mem_usage=False, dtype=torch.bfloat16 - ).to(torch_device) + model_id, dtype=torch.bfloat16, device_map=torch_device + ) messages_2 = [ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, @@ -731,27 +1010,15 @@ def test_model_4b_batch(self): output = model.generate(**inputs, max_new_tokens=30, do_sample=False) output_text = self.processor.batch_decode(output, skip_special_tokens=True) - EXPECTED_TEXTS = [ - 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like', - "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow" - ] # fmt: skip + EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. 
\n\nHere's a breakdown of the differences:\n\n* **Subject:** The first image features a cow"] # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) - def test_model_4b_crops(self): + def test_model_4b_image(self): model_id = "Google/gemma-3n-E4B-it" model = Gemma3nForConditionalGeneration.from_pretrained( - model_id, low_cpu_mem_usage=True, dtype=torch.bfloat16 - ).to(torch_device) - - crop_config = { - "images_kwargs": { - "do_pan_and_scan": True, - "pan_and_scan_max_num_crops": 448, - "pan_and_scan_min_crop_size": 32, - "pan_and_scan_min_ratio_to_activate": 0.3, - } - } + model_id, dtype=torch.bfloat16, device_map=torch_device + ) inputs = self.processor.apply_chat_template( self.messages, @@ -759,14 +1026,13 @@ def test_model_4b_crops(self): return_dict=True, return_tensors="pt", add_generation_prompt=True, - **crop_config, ).to(torch_device) output = model.generate(**inputs, max_new_tokens=30, do_sample=False) output_text = self.processor.batch_decode(output, skip_special_tokens=True) - EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images - EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.'] # fmt: skip + EXPECTED_NUM_IMAGES = 1 # Gemma3n does not support crops + EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'] # fmt: skip self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES) self.assertEqual(output_text, EXPECTED_TEXTS) @@ -774,8 +1040,8 @@ def test_model_4b_multiimage(self): model_id = "Google/gemma-3n-E4B-it" model = Gemma3nForConditionalGeneration.from_pretrained( - model_id, low_cpu_mem_usage=True, dtype=torch.bfloat16 - ).to(torch_device) + model_id, dtype=torch.bfloat16, device_map=torch_device + ) messages = [ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, @@ -800,15 +1066,14 @@ def test_model_4b_multiimage(self): output = model.generate(**inputs, max_new_tokens=30, do_sample=False) output_text = self.processor.batch_decode(output, skip_special_tokens=True) - EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"] # fmt: skip + EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. 
Here are some key elements:\n\n* **A prominent red'] # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) + @unittest.skip("For now, using a gemma model with the 3n class is not supported") def test_model_1b_text_only(self): model_id = "google/gemma-3-1b-it" - model = Gemma3nForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, dtype=torch.bfloat16).to( - torch_device - ) + model = Gemma3nForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=torch_device) tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device) @@ -818,38 +1083,17 @@ def test_model_1b_text_only(self): EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'] # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) - # TODO: raushan FA2 generates gibberish for no reason, check later - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - def test_model_4b_flash_attn(self): - model_id = "Google/gemma-3n-E4B-it" - - model = Gemma3nForConditionalGeneration.from_pretrained( - model_id, dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ).to(torch_device) - - inputs = self.processor.apply_chat_template( - self.messages, - tokenize=True, - return_dict=True, - return_tensors="pt", - add_generation_prompt=True, - ).to(torch_device) - - output = model.generate(**inputs, max_new_tokens=30, do_sample=False) - output_text = self.processor.batch_decode(output, skip_special_tokens=True) - - EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and'] # fmt: skip - self.assertEqual(output_text, EXPECTED_TEXTS) - - @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)]) + @parameterized.expand([("sdpa",), ("eager",), ("flash_attention_2",)]) def test_generation_beyond_sliding_window(self, attn_implementation: str): """Test that we can correctly generate beyond the sliding window. This is non trivial as we need to correctly slice the attention mask in all cases (because we use a hybrid cache). Outputs for every attention functions should be coherent and identical. """ - model_id = "google/gemma-3-1b-it" + + if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available(): + self.skipTest("Test requires Flash Attention") + + model_id = "google/gemma-3n-E2B-it" input_text = [ "This is a nice place. 
" * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens @@ -859,26 +1103,25 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str): inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) model = AutoModelForCausalLM.from_pretrained( - model_id, attn_implementation=attn_implementation, dtype=torch.float16 - ).to(torch_device) + model_id, attn_implementation=attn_implementation, dtype=torch.float16, device_map=torch_device + ) # Make sure prefill is larger than sliding window input_size = inputs.input_ids.shape[-1] - self.assertTrue(input_size > model.config.sliding_window) + self.assertTrue(input_size > model.config.get_text_config().sliding_window) - out = model.generate(**inputs, max_new_tokens=20)[:, input_size:] + out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:] output_text = tokenizer.batch_decode(out) - EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip + EXPECTED_COMPLETIONS = [" and I think it's a nice place to visit. This is a nice place. This is", ", green, yellow, orange, purple, pink, brown, black, white.\n\nHere'"] # fmt: skip self.assertEqual(output_text, EXPECTED_COMPLETIONS) def test_generation_beyond_sliding_window_with_generation_config(self): - """ - Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 -- + """Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 -- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`. """ - model_id = "google/gemma-3-1b-it" - attn_implementation = "sdpa" + + model_id = "google/gemma-3n-E2B-it" input_text = [ "This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens @@ -887,18 +1130,16 @@ def test_generation_beyond_sliding_window_with_generation_config(self): tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) - model = AutoModelForCausalLM.from_pretrained( - model_id, attn_implementation=attn_implementation, dtype=torch.float16 - ).to(torch_device) + model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float16, device_map=torch_device) # Make sure prefill is larger than sliding window input_size = inputs.input_ids.shape[-1] - self.assertTrue(input_size > model.config.sliding_window) + self.assertTrue(input_size > model.config.get_text_config().sliding_window) - generation_config = GenerationConfig(max_new_tokens=20) - - out = model.generate(**inputs, generation_config=generation_config)[:, input_size:] + out = model.generate(**inputs, generation_config=GenerationConfig(max_new_tokens=20, do_sample=False))[ + :, input_size: + ] output_text = tokenizer.batch_decode(out) - EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip + EXPECTED_COMPLETIONS = [" and I am glad to be here. This is a nice place. This is a nice place.", ", green, yellow, purple, orange, pink, brown, black, white.\n\nHere are"] # fmt: skip self.assertEqual(output_text, EXPECTED_COMPLETIONS)