
Commit 093bebc

zucchini-nlp authored and ArthurZucker committed
Paligemma: fix generation with Gemma2 (#36044)
* fix paligemma
* nit
* use `kwargs` in models that can load any LM
* update changes to only affect Paligemma
1 parent 97a6cf9 commit 093bebc
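
The commit targets PaliGemma checkpoints whose text backbone is Gemma2 (the PaliGemma 2 family). As a rough illustration of the generation path being fixed, a minimal sketch follows; the checkpoint id and prompt below are assumptions, not taken from the commit itself:

```python
# Sketch of the scenario this commit fixes: running generate() on a
# PaliGemma model whose language backbone is Gemma2. The checkpoint id and
# prompt are assumptions; substitute a checkpoint you have access to.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-pt-224"  # assumed PaliGemma 2 (Gemma2 LM) checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="caption en", images=image, return_tensors="pt")

with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(generated[0], skip_special_tokens=True))
```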

File tree: 3 files changed (+358, −6 lines)

src/transformers/models/paligemma/modeling_paligemma.py

Lines changed: 8 additions & 6 deletions

@@ -346,8 +346,7 @@ def _update_causal_mask(
         token_type_ids,
         past_key_values,
         cache_position,
-        input_ids=None,
-        inputs_embeds=None,
+        input_tensor,
         is_training: bool = False,
     ):
         if self.config.text_config._attn_implementation == "flash_attention_2":
@@ -357,8 +356,7 @@ def _update_causal_mask(

         using_static_cache = isinstance(past_key_values, StaticCache)
         min_dtype = torch.finfo(self.dtype).min
-        inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
-        sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
         elif isinstance(past_key_values, HybridCache):
@@ -435,6 +433,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         num_logits_to_keep: int = 0,
+        **lm_kwargs,
     ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
         r"""
         Args:
@@ -525,7 +524,7 @@ def forward(
             labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)

         causal_mask = self._update_causal_mask(
-            attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
         )
         outputs = self.language_model(
             attention_mask=causal_mask,
@@ -538,6 +537,7 @@ def forward(
             return_dict=return_dict,
             cache_position=cache_position,
             num_logits_to_keep=num_logits_to_keep,
+            **lm_kwargs,
         )

         logits = outputs.logits
@@ -613,10 +613,12 @@ def prepare_inputs_for_generation(
             model_inputs["pixel_values"] = pixel_values
         is_training = token_type_ids is not None and labels is not None
         if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
             causal_mask = self._update_causal_mask(
-                attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
             )
             model_inputs["attention_mask"] = causal_mask
+
         return model_inputs

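The hunks above replace the separate `input_ids`/`inputs_embeds` arguments of `_update_causal_mask` with a single `input_tensor`: both the token-id tensor and the embedding tensor carry batch size and sequence length in their first two dimensions. A small stand-alone sketch of that assumption (illustrative names, not the model code):

```python
# Why a single input_tensor suffices: input_ids has shape (batch, seq) and
# inputs_embeds has shape (batch, seq, hidden), so the first two dimensions
# give the batch size and sequence length in both cases. Illustrative only.
import torch

def leading_dims(input_tensor: torch.Tensor) -> tuple:
    inputs_lead_dim, sequence_length = input_tensor.shape[:2]
    return inputs_lead_dim, sequence_length

input_ids = torch.randint(0, 1000, (2, 7))   # token ids: (batch, seq)
inputs_embeds = torch.randn(2, 7, 64)        # embeddings: (batch, seq, hidden)
assert leading_dims(input_ids) == leading_dims(inputs_embeds) == (2, 7)
```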
tests/models/paligemma2/__init__.py

Whitespace-only changes.
