 from ...processing_utils import Unpack
 from ...utils import auto_docstring, logging
 from ...utils.generic import TransformersKwargs, check_model_inputs
+from ..phi3.configuration_phi3 import Phi3Config
 from ..phi3.modeling_phi3 import (
     Phi3DecoderLayer,
     Phi3ForCausalLM,
@@ -277,7 +278,7 @@ def __init__(
         self.nemo_final_size = length


-class Phi4MultimodalConfig(PretrainedConfig):
+class Phi4MultimodalConfig(Phi3Config):
     r"""
     This is the configuration class to store the configuration of a [`Phi4MultimodalModel`]. It is used to instantiate a
     Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -370,20 +371,6 @@ class Phi4MultimodalConfig(PretrainedConfig):
     >>> configuration = model.config
     ```"""

-    model_type = "phi4_multimodal"
-    keys_to_ignore_at_inference = ["past_key_values"]
-    base_model_tp_plan = {
-        "layers.*.self_attn.qkv_proj": "colwise_rep",  # we need to replicate here due to the slicing of qkv
-        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the slicing of qkv
-        "layers.*.mlp.gate_up_proj": "colwise_rep",  # we need to replicate here due to the `chunk` operation
-        "layers.*.mlp.down_proj": "rowwise_rep",  # we need to replicate here due to the `chunk` operation
-    }
-    base_model_pp_plan = {
-        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
-        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
-        "norm": (["hidden_states"], ["hidden_states"]),
-    }
-
     sub_configs = {"audio_config": Phi4MultimodalAudioConfig, "vision_config": Phi4MultimodalVisionConfig}

     def __init__(
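Note on the two hunks above: this is the modular source file, and the point of switching the base class to `Phi3Config` is that class-level attributes such as `keys_to_ignore_at_inference`, `base_model_tp_plan`, and `base_model_pp_plan` no longer need to be re-declared; ordinary Python attribute lookup finds them on the parent, and the standalone configuration file is regenerated from this one by `utils/modular_model_converter.py` (which presumably also restores naming details like `model_type` in the expanded file). A minimal sketch of the mechanism with toy names, not the transformers implementation:

```python
# Toy illustration: attributes declared on a parent class are visible on
# every subclass, which is why the plans deleted in the hunk above are
# redundant once Phi4MultimodalConfig subclasses Phi3Config.
class ParentConfig:
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {"layers.*.self_attn.o_proj": "rowwise_rep"}


class ChildConfig(ParentConfig):
    pass  # nothing re-declared


assert ChildConfig.keys_to_ignore_at_inference == ["past_key_values"]
assert ChildConfig.base_model_tp_plan is ParentConfig.base_model_tp_plan
```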
@@ -416,37 +403,31 @@ def __init__(
         **kwargs,
     ):
         super().__init__(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+            resid_pdrop=resid_pdrop,
+            embd_pdrop=embd_pdrop,
+            attention_dropout=attention_dropout,
+            hidden_act=hidden_act,
+            max_position_embeddings=max_position_embeddings,
+            initializer_range=initializer_range,
+            rms_norm_eps=rms_norm_eps,
+            use_cache=use_cache,
+            tie_word_embeddings=tie_word_embeddings,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            partial_rotary_factor=partial_rotary_factor,
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
             pad_token_id=pad_token_id,
-            tie_word_embeddings=tie_word_embeddings,
+            original_max_position_embeddings=original_max_position_embeddings,
+            sliding_window=sliding_window,
             **kwargs,
         )
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.resid_pdrop = resid_pdrop
-        self.embd_pdrop = embd_pdrop
-        self.attention_dropout = attention_dropout
-        self.hidden_act = hidden_act
-        self.max_position_embeddings = max_position_embeddings
-        self.original_max_position_embeddings = original_max_position_embeddings
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.partial_rotary_factor = partial_rotary_factor
-        self._rope_scaling_adjustment()
-        self._rope_scaling_validation()
-        self.sliding_window = sliding_window

         if isinstance(vision_config, dict):
             vision_config = Phi4MultimodalVisionConfig(**vision_config)
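The hunk above replaces the hand-rolled `self.*` assignments with keyword forwarding to `Phi3Config.__init__`, which already contains the `num_key_value_heads is None` fallback as well as the rope-scaling adjustment and validation deleted further down. A hedged usage sketch (assuming `Phi4MultimodalConfig` is exported by your transformers build; the values are illustrative, not checkpoint defaults):

```python
from transformers import Phi4MultimodalConfig

# Externally the config should behave as before: every text-backbone
# hyperparameter is still accepted here and forwarded to Phi3Config.
config = Phi4MultimodalConfig(
    num_key_value_heads=None,  # the parent falls back to num_attention_heads
    sliding_window=131072,     # illustrative value, not the checkpoint default
)
assert config.num_key_value_heads == config.num_attention_heads
```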
@@ -460,60 +441,6 @@ def __init__(
             audio_config = Phi4MultimodalAudioConfig()
         self.audio_config = audio_config

-    def _rope_scaling_adjustment(self):
-        """
-        Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
-        """
-        if self.rope_scaling is None:
-            return
-
-        rope_scaling_type = self.rope_scaling.get("type", None)
-
-        # For backward compatibility if previous version used "su" or "yarn"
-        if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
-            self.rope_scaling["type"] = "longrope"
-
-    def _rope_scaling_validation(self):
-        """
-        Validate the `rope_scaling` configuration.
-        """
-        if self.rope_scaling is None:
-            return
-
-        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
-            raise ValueError(
-                "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
-                f"got {self.rope_scaling}"
-            )
-        rope_scaling_type = self.rope_scaling.get("type", None)
-        rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
-        rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
-        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
-            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
-        if not (
-            isinstance(rope_scaling_short_factor, list)
-            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
-        ):
-            raise ValueError(
-                f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
-            )
-        rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
-        if not len(rope_scaling_short_factor) == rotary_ndims // 2:
-            raise ValueError(
-                f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
-            )
-        if not (
-            isinstance(rope_scaling_long_factor, list)
-            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
-        ):
-            raise ValueError(
-                f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
-            )
-        if not len(rope_scaling_long_factor) == rotary_ndims // 2:
-            raise ValueError(
-                f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
-            )
-

 class Phi4MultimodalVisionMLP(SiglipMLP):
     pass
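The `_rope_scaling_adjustment` / `_rope_scaling_validation` pair deleted in the hunk above duplicated checks that now run in the inherited `Phi3Config.__init__`. For reference, a `rope_scaling` dict that satisfies those checks is sketched below; the dimension values are illustrative, not the checkpoint's:

```python
# The validation requires `type` == "longrope" ("su"/"yarn" are remapped for
# backward compatibility) and both factor lists of length rotary_ndims // 2.
hidden_size = 3072            # illustrative
num_attention_heads = 24      # illustrative
partial_rotary_factor = 0.75  # illustrative
rotary_ndims = int(hidden_size // num_attention_heads * partial_rotary_factor)

rope_scaling = {
    "type": "longrope",
    "short_factor": [1.0] * (rotary_ndims // 2),  # 48 entries with these values
    "long_factor": [1.0] * (rotary_ndims // 2),
}
```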