Commit 2166534 ("cyril review")
1 parent 6d245ee

File tree: 6 files changed (+75, -161 lines)


src/transformers/models/minimax/modular_minimax.py
Lines changed: 1 addition & 4 deletions

@@ -378,10 +378,7 @@ def forward(


 class MiniMaxAttention(MixtralAttention):
-    def __init__(self, config: MiniMaxConfig, layer_idx: int):
-        super().__init__(config, layer_idx)
-        del is_sliding  # noqa: F821
-        self.sliding_window = getattr(config, "sliding_window", None)
+    pass


 class MiniMaxSparseMoeBlock(MixtralSparseMoeBlock):

src/transformers/models/mistral/modeling_mistral.py
Lines changed: 26 additions & 30 deletions

@@ -15,7 +15,7 @@
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
-from ...masking_utils import create_masks_for_generate
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
     GenericForQuestionAnswering,
@@ -135,13 +135,7 @@ def __init__(self, config: MistralConfig, layer_idx: int):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        # This check is necessary to support models that inherit via modular (e.g. Mixtral) and do not use layer_types
-        is_sliding = (
-            config.layer_types[layer_idx] == "sliding_attention"
-            if getattr(config, "layer_types", None) is not None
-            else getattr(config, "sliding_window", None) is not None
-        )
-        self.sliding_window = config.sliding_window if is_sliding else None
+        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

     def forward(
         self,
@@ -217,9 +211,7 @@ def __init__(self, config: MistralConfig, layer_idx: int):
         self.mlp = MistralMLP(config)
         self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = (
-            config.layer_types[layer_idx] if getattr(config, "layer_types", None) is not None else None
-        )
+        self.attention_type = config.layer_types[layer_idx]

     def forward(
         self,
@@ -358,31 +350,35 @@ def forward(
             position_ids = cache_position.unsqueeze(0)

         # It may already have been prepared by e.g. `generate`
-        mask_already_prepared = isinstance(attention_mask, dict) or (
-            isinstance(attention_mask, torch.Tensor) and len(attention_mask.shape) > 2
-        )
-        if not mask_already_prepared:
-            attention_mask = create_masks_for_generate(
-                config=self.config,
-                input_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-                past_key_values=past_key_values,
-                position_ids=position_ids,
-            )
-
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+            full_mask_already_prepared = isinstance(attention_mask, torch.Tensor) and len(attention_mask.shape) > 2
+            causal_mask_mapping = {}
+            if "sliding_attention" in self.config.layer_types:
+                sliding_attention_mask = (
+                    create_sliding_window_causal_mask(**mask_kwargs)
+                    if not full_mask_already_prepared
+                    else attention_mask
+                )
+                causal_mask_mapping["sliding_attention"] = sliding_attention_mask
+            if "full_attention" in self.config.layer_types:
+                causal_mask = create_causal_mask(**mask_kwargs) if not full_mask_already_prepared else attention_mask
+                causal_mask_mapping["full_attention"] = causal_mask
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            causal_mask = (
-                attention_mask[decoder_layer.attention_type]
-                if decoder_layer.attention_type is not None and isinstance(attention_mask, dict)
-                else attention_mask
-            )
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 use_cache=use_cache,
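
The refactor above builds the masks once per forward pass and keys them by attention type; every decoder layer then looks up `causal_mask_mapping[decoder_layer.attention_type]`. Below is a minimal, self-contained sketch of that pattern. The mask builders, config, and names are toy stand-ins, not the real `create_causal_mask` / `create_sliding_window_causal_mask` helpers, whose signatures and return formats differ.

```python
import torch

# Toy stand-ins for the masking helpers; they only illustrate the pattern.
def make_causal_mask(seq_len: int) -> torch.Tensor:
    # True where attention is allowed.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def make_sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    causal = make_causal_mask(seq_len)
    # Additionally forbid looking further back than `window - 1` positions.
    dist = torch.arange(seq_len)[:, None] - torch.arange(seq_len)[None, :]
    return causal & (dist < window)

class ToyConfig:
    layer_types = ["full_attention", "sliding_attention", "full_attention"]
    sliding_window = 2

def build_mask_mapping(config: ToyConfig, seq_len: int) -> dict:
    # Build one mask per attention type that actually appears in the model.
    mapping = {}
    if "full_attention" in config.layer_types:
        mapping["full_attention"] = make_causal_mask(seq_len)
    if "sliding_attention" in config.layer_types:
        mapping["sliding_attention"] = make_sliding_window_causal_mask(seq_len, config.sliding_window)
    return mapping

config = ToyConfig()
masks = build_mask_mapping(config, seq_len=5)
# Each decoder layer then simply indexes the mapping with its own attention type.
for layer_idx, attention_type in enumerate(config.layer_types):
    layer_mask = masks[attention_type]
    print(layer_idx, attention_type, int(layer_mask.sum()))
```

The per-layer lookup in the loop replaces the old per-layer ternary, which had to handle the case where `attention_type` was `None`.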

src/transformers/models/mistral/modular_mistral.py
Lines changed: 26 additions & 30 deletions

@@ -6,7 +6,7 @@
 from transformers.utils.generic import check_model_inputs

 from ...cache_utils import Cache, DynamicCache
-from ...masking_utils import create_masks_for_generate
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
     GenericForQuestionAnswering,
@@ -49,13 -49,7 @@ def __init__(self, config: MistralConfig, layer_idx: int):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        # This check is necessary to support models that inherit via modular (e.g. Mixtral) and do not use layer_types
-        is_sliding = (
-            config.layer_types[layer_idx] == "sliding_attention"
-            if getattr(config, "layer_types", None) is not None
-            else getattr(config, "sliding_window", None) is not None
-        )
-        self.sliding_window = config.sliding_window if is_sliding else None
+        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

     def forward(
         self,
@@ -107,9 +101,7 @@ def __init__(self, config: MistralConfig, layer_idx: int):
         super().__init__(config, layer_idx)
         self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
         self.mlp = MistralMLP(config)
-        self.attention_type = (
-            config.layer_types[layer_idx] if getattr(config, "layer_types", None) is not None else None
-        )
+        self.attention_type = config.layer_types[layer_idx]


 class MistralPreTrainedModel(LlamaPreTrainedModel):
@@ -152,31 +144,35 @@ def forward(
             position_ids = cache_position.unsqueeze(0)

         # It may already have been prepared by e.g. `generate`
-        mask_already_prepared = isinstance(attention_mask, dict) or (
-            isinstance(attention_mask, torch.Tensor) and len(attention_mask.shape) > 2
-        )
-        if not mask_already_prepared:
-            attention_mask = create_masks_for_generate(
-                config=self.config,
-                input_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-                past_key_values=past_key_values,
-                position_ids=position_ids,
-            )
-
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+            full_mask_already_prepared = isinstance(attention_mask, torch.Tensor) and len(attention_mask.shape) > 2
+            causal_mask_mapping = {}
+            if "sliding_attention" in self.config.layer_types:
+                sliding_attention_mask = (
+                    create_sliding_window_causal_mask(**mask_kwargs)
+                    if not full_mask_already_prepared
+                    else attention_mask
+                )
+                causal_mask_mapping["sliding_attention"] = sliding_attention_mask
+            if "full_attention" in self.config.layer_types:
+                causal_mask = create_causal_mask(**mask_kwargs) if not full_mask_already_prepared else attention_mask
+                causal_mask_mapping["full_attention"] = causal_mask
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            causal_mask = (
-                attention_mask[decoder_layer.attention_type]
-                if decoder_layer.attention_type is not None and isinstance(attention_mask, dict)
-                else attention_mask
-            )
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 use_cache=use_cache,
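
The modular file mirrors the same change. The other simplification both files rely on is that `config.layer_types` is assumed to be populated, so the attention module's `sliding_window` and the decoder layer's `attention_type` both come straight from `config.layer_types[layer_idx]`. A small sketch of that wiring, using stand-in classes rather than the real Mistral modules:

```python
from dataclasses import dataclass, field

@dataclass
class ToyConfig:
    sliding_window: int = 4096
    # One entry per layer; assumed to always be populated on the config.
    layer_types: list = field(
        default_factory=lambda: ["full_attention", "sliding_attention", "full_attention"]
    )

class ToyAttention:
    def __init__(self, config: ToyConfig, layer_idx: int):
        # Mirrors the simplified check: no getattr fallback, layer_types is required.
        self.sliding_window = (
            config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
        )

class ToyDecoderLayer:
    def __init__(self, config: ToyConfig, layer_idx: int):
        self.self_attn = ToyAttention(config, layer_idx)
        self.attention_type = config.layer_types[layer_idx]

config = ToyConfig()
layers = [ToyDecoderLayer(config, i) for i in range(len(config.layer_types))]
for i, layer in enumerate(layers):
    print(i, layer.attention_type, layer.self_attn.sliding_window)
```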

src/transformers/models/mixtral/modular_mixtral.py
Lines changed: 0 additions & 1 deletion

@@ -225,7 +225,6 @@ class MixtralRMSNorm(MistralRMSNorm):
 class MixtralAttention(MistralAttention):
     def __init__(self, config: MixtralConfig, layer_idx: int):
         super().__init__(config, layer_idx)
-        del is_sliding  # noqa: F821
         self.sliding_window = getattr(config, "sliding_window", None)

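Mixtral applies the (optional) window to every layer rather than per layer, so its modular override calls the parent constructor and then overwrites `sliding_window` unconditionally; with the parent's `is_sliding` local gone, the `del is_sliding` line is dropped too. A small sketch of that override pattern with stand-in classes, not the real Mixtral code:

```python
class ParentAttention:
    def __init__(self, config, layer_idx: int):
        # Parent decides per layer from config.layer_types.
        self.sliding_window = (
            config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
        )

class ChildAttention(ParentAttention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Child applies the window to every layer, regardless of layer_types.
        self.sliding_window = getattr(config, "sliding_window", None)

class Cfg:
    layer_types = ["full_attention", "sliding_attention"]
    sliding_window = 1024

print([ChildAttention(Cfg(), i).sliding_window for i in range(2)])  # [1024, 1024]
```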

src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py
Lines changed: 22 additions & 95 deletions

@@ -35,6 +35,7 @@
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, logging
 from ...utils.generic import TransformersKwargs, check_model_inputs
+from ..phi3.configuration_phi3 import Phi3Config
 from ..phi3.modeling_phi3 import (
     Phi3DecoderLayer,
     Phi3ForCausalLM,
@@ -277,7 +278,7 @@ def __init__(
         self.nemo_final_size = length


-class Phi4MultimodalConfig(PretrainedConfig):
+class Phi4MultimodalConfig(Phi3Config):
     r"""
     This is the configuration class to store the configuration of a [`Phi4MultimodalModel`]. It is used to instantiate a
     Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -370,20 +371,6 @@ class Phi4MultimodalConfig(PretrainedConfig):
     >>> configuration = model.config
     ```"""

-    model_type = "phi4_multimodal"
-    keys_to_ignore_at_inference = ["past_key_values"]
-    base_model_tp_plan = {
-        "layers.*.self_attn.qkv_proj": "colwise_rep",  # we need to replicate here due to the slicing of qkv
-        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the slicing of qkv
-        "layers.*.mlp.gate_up_proj": "colwise_rep",  # we need to replicate here due to the `chunk` operation
-        "layers.*.mlp.down_proj": "rowwise_rep",  # we need to replicate here due to the `chunk` operation
-    }
-    base_model_pp_plan = {
-        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
-        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
-        "norm": (["hidden_states"], ["hidden_states"]),
-    }
-
     sub_configs = {"audio_config": Phi4MultimodalAudioConfig, "vision_config": Phi4MultimodalVisionConfig}

     def __init__(
@@ -416,37 +403,31 @@ def __init__(
         **kwargs,
     ):
         super().__init__(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+            resid_pdrop=resid_pdrop,
+            embd_pdrop=embd_pdrop,
+            attention_dropout=attention_dropout,
+            hidden_act=hidden_act,
+            max_position_embeddings=max_position_embeddings,
+            initializer_range=initializer_range,
+            rms_norm_eps=rms_norm_eps,
+            use_cache=use_cache,
+            tie_word_embeddings=tie_word_embeddings,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            partial_rotary_factor=partial_rotary_factor,
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
             pad_token_id=pad_token_id,
-            tie_word_embeddings=tie_word_embeddings,
+            original_max_position_embeddings=original_max_position_embeddings,
+            sliding_window=sliding_window,
             **kwargs,
         )
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.resid_pdrop = resid_pdrop
-        self.embd_pdrop = embd_pdrop
-        self.attention_dropout = attention_dropout
-        self.hidden_act = hidden_act
-        self.max_position_embeddings = max_position_embeddings
-        self.original_max_position_embeddings = original_max_position_embeddings
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.partial_rotary_factor = partial_rotary_factor
-        self._rope_scaling_adjustment()
-        self._rope_scaling_validation()
-        self.sliding_window = sliding_window

         if isinstance(vision_config, dict):
             vision_config = Phi4MultimodalVisionConfig(**vision_config)
@@ -460,60 +441,6 @@ def __init__(
             audio_config = Phi4MultimodalAudioConfig()
         self.audio_config = audio_config

-    def _rope_scaling_adjustment(self):
-        """
-        Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
-        """
-        if self.rope_scaling is None:
-            return
-
-        rope_scaling_type = self.rope_scaling.get("type", None)
-
-        # For backward compatibility if previous version used "su" or "yarn"
-        if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
-            self.rope_scaling["type"] = "longrope"
-
-    def _rope_scaling_validation(self):
-        """
-        Validate the `rope_scaling` configuration.
-        """
-        if self.rope_scaling is None:
-            return
-
-        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
-            raise ValueError(
-                "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
-                f"got {self.rope_scaling}"
-            )
-        rope_scaling_type = self.rope_scaling.get("type", None)
-        rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
-        rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
-        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
-            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
-        if not (
-            isinstance(rope_scaling_short_factor, list)
-            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
-        ):
-            raise ValueError(
-                f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
-            )
-        rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
-        if not len(rope_scaling_short_factor) == rotary_ndims // 2:
-            raise ValueError(
-                f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
-            )
-        if not (
-            isinstance(rope_scaling_long_factor, list)
-            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
-        ):
-            raise ValueError(
-                f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
-            )
-        if not len(rope_scaling_long_factor) == rotary_ndims // 2:
-            raise ValueError(
-                f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
-            )
-

 class Phi4MultimodalVisionMLP(SiglipMLP):
     pass
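
The config change applies the same idea at the class level: `Phi4MultimodalConfig` now inherits from `Phi3Config` and forwards the shared text-model arguments to `super().__init__()` instead of re-assigning them (and re-implementing the rope-scaling checks) itself; only the genuinely multimodal pieces stay in the child. A minimal sketch of that pattern with toy classes, not the real Phi3/Phi4 configs:

```python
class ToyBaseConfig:
    def __init__(self, hidden_size=1024, num_hidden_layers=24, sliding_window=None, **kwargs):
        # The parent owns the shared text-model fields (and any validation of them).
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.sliding_window = sliding_window
        self.extra = kwargs

class ToyMultimodalConfig(ToyBaseConfig):
    def __init__(self, hidden_size=3072, num_hidden_layers=32, sliding_window=262144,
                 vision_config=None, audio_config=None, **kwargs):
        # Shared fields are forwarded to the parent instead of being re-set by hand.
        super().__init__(
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            sliding_window=sliding_window,
            **kwargs,
        )
        # Only the genuinely new pieces remain in the child.
        self.vision_config = vision_config or {}
        self.audio_config = audio_config or {}

cfg = ToyMultimodalConfig()
print(cfg.hidden_size, cfg.sliding_window, cfg.vision_config)
```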

src/transformers/models/starcoder2/modular_starcoder2.py
Lines changed: 0 additions & 1 deletion

@@ -77,7 +77,6 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
-        del is_sliding  # noqa: F821
         del self.sliding_window

     def forward(
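
Starcoder2's modular attention instead removes the inherited `sliding_window` attribute in its `__init__`; with the parent's `is_sliding` local gone, only `del self.sliding_window` remains. A tiny, purely illustrative sketch of that attribute-deletion pattern with stand-in classes:

```python
class ParentAttention:
    def __init__(self):
        self.sliding_window = 4096
        self.head_dim = 64

class ChildAttention(ParentAttention):
    def __init__(self):
        super().__init__()
        # Remove an attribute the parent set; child instances no longer carry it.
        del self.sliding_window

attn = ChildAttention()
print(hasattr(attn, "sliding_window"))  # False
print(attn.head_dim)                    # 64
```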
