@@ -346,8 +346,7 @@ def _update_causal_mask(
         token_type_ids,
         past_key_values,
         cache_position,
-        input_ids=None,
-        inputs_embeds=None,
+        input_tensor,
         is_training: bool = False,
     ):
         if self.config.text_config._attn_implementation == "flash_attention_2":
@@ -357,8 +356,7 @@ def _update_causal_mask(
 
         using_static_cache = isinstance(past_key_values, StaticCache)
         min_dtype = torch.finfo(self.dtype).min
-        inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
-        sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
         elif isinstance(past_key_values, HybridCache):
@@ -435,6 +433,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         num_logits_to_keep: int = 0,
+        **lm_kwargs,
     ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
         r"""
         Args:
@@ -525,7 +524,7 @@ def forward(
             labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
 
         causal_mask = self._update_causal_mask(
-            attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
         )
         outputs = self.language_model(
             attention_mask=causal_mask,
@@ -538,6 +537,7 @@ def forward(
             return_dict=return_dict,
             cache_position=cache_position,
             num_logits_to_keep=num_logits_to_keep,
+            **lm_kwargs,
         )
 
         logits = outputs.logits
@@ -613,10 +613,12 @@ def prepare_inputs_for_generation(
             model_inputs["pixel_values"] = pixel_values
         is_training = token_type_ids is not None and labels is not None
         if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
             causal_mask = self._update_causal_mask(
-                attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
             )
             model_inputs["attention_mask"] = causal_mask
+
         return model_inputs
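For context on the first two hunks: `_update_causal_mask` now takes a single `input_tensor` instead of separate `input_ids` / `inputs_embeds` arguments, because only the leading two dimensions (batch size and sequence length) are needed, and those are the same whichever input the caller had. A minimal standalone sketch of that idea, using toy tensors and hypothetical helper names rather than the transformers implementation:

import torch

def pick_input_tensor(input_ids, inputs_embeds):
    # Mirrors the added line in prepare_inputs_for_generation: prefer the
    # embeddings when they are provided, otherwise fall back to token ids.
    return inputs_embeds if inputs_embeds is not None else input_ids

def lead_dims(input_tensor):
    # shape[:2] works for both (batch, seq) token ids and
    # (batch, seq, hidden) embeddings, which is why one tensor argument
    # is enough for the mask helper.
    inputs_lead_dim, sequence_length = input_tensor.shape[:2]
    return inputs_lead_dim, sequence_length

input_ids = torch.randint(0, 100, (2, 7))   # (batch=2, seq=7)
inputs_embeds = torch.randn(2, 7, 16)       # (batch=2, seq=7, hidden=16)
print(lead_dims(pick_input_tensor(input_ids, None)))      # (2, 7)
print(lead_dims(pick_input_tensor(None, inputs_embeds)))  # (2, 7)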
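The `**lm_kwargs` hunks are a plain keyword pass-through: any extra keyword arguments given to the multimodal wrapper's `forward` are handed unchanged to `self.language_model`. A hedged toy illustration of that pattern (class names and the extra keyword are made up for the example, not the real transformers classes):

class ToyLanguageModel:
    def forward(self, attention_mask=None, **kwargs):
        # The inner model decides what to do with the extra keywords.
        return {"received": sorted(kwargs)}

class ToyConditionalGeneration:
    def __init__(self):
        self.language_model = ToyLanguageModel()

    def forward(self, attention_mask=None, **lm_kwargs):
        # The wrapper no longer has to enumerate every language-model argument
        # explicitly; it simply forwards whatever it was given.
        return self.language_model.forward(attention_mask=attention_mask, **lm_kwargs)

print(ToyConditionalGeneration().forward(output_attentions=True))
# {'received': ['output_attentions']}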