
Mistral: Add support for interleaved attention #39799


Closed · 16 commits

Conversation

manueldeprada (Contributor)

Adds support for interleaved attention masks to the Mistral model.
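For context, a rough sketch of the idea (illustrative only: the every-Nth-layer rule and names below are assumptions, not this PR's exact config surface). "Interleaved" means some decoder layers use sliding-window attention while others attend over the full context, chosen per layer:

# Illustrative sketch of interleaved layer types (assumed pattern, not the PR's code).
num_hidden_layers = 8
full_attention_every = 4  # assumption: every 4th layer attends over the full context

layer_types = [
    "full_attention" if (i + 1) % full_attention_every == 0 else "sliding_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention', ...]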

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@manueldeprada manueldeprada changed the title Adds support for interleaved attention on Mistral models Mistral: Add support for interleaved attention Jul 31, 2025
@manueldeprada manueldeprada marked this pull request as ready for review July 31, 2025 13:16
@manueldeprada (Contributor Author)

run-slow: mistral

github-actions bot (Contributor)

This comment contains run-slow, running the specified jobs:

models: ['models/mistral']
quantizations: [] ...

@huggingface huggingface deleted a comment from github-actions bot Jul 31, 2025
@huggingface huggingface deleted a comment from github-actions bot Jul 31, 2025
@huggingface huggingface deleted a comment from github-actions bot Jul 31, 2025
@Cyrilvallez (Member) left a comment


A few comments to make it fit our standard way of doing 🤗

github-actions bot (Contributor) commented Aug 6, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: minimax, mistral, mixtral, phi3, phi4_multimodal, starcoder2

"past_key_values": past_key_values,
"position_ids": position_ids,
}
full_mask_already_prepared = isinstance(attention_mask, torch.Tensor) and len(attention_mask.shape) > 2
@manueldeprada (Contributor Author) commented Aug 6, 2025


This check is necessary due to test_modeling_mistral.py::Mask4DTestHard::test_stacked_causal_mask passing a raw 4d mask.

Maybe I should change the test to pass a dict instead? But that would break BC...
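For illustration, a small sketch of the shapes this check distinguishes (the 4D mask below is a stand-in for what the test passes, not the test's actual tensors):

import torch

batch, seq_len = 2, 8

# Typical 2D padding mask, shape (batch, seq_len): the model still has to
# expand it into a 4D attention mask.
padding_mask = torch.ones(batch, seq_len, dtype=torch.long)

# A raw 4D mask, shape (batch, 1, query_len, kv_len), already prepared by the
# caller, roughly what Mask4DTestHard::test_stacked_causal_mask provides.
raw_4d_mask = torch.tril(torch.ones(seq_len, seq_len)).expand(batch, 1, seq_len, seq_len)

def full_mask_already_prepared(attention_mask):
    # Equivalent to the check above: more than 2 dims means "already prepared".
    return isinstance(attention_mask, torch.Tensor) and attention_mask.dim() > 2

assert not full_mask_already_prepared(padding_mask)
assert full_mask_already_prepared(raw_4d_mask)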

Collaborator


this does not make sense, the model is sliding only

@manueldeprada (Contributor Author)

run-slow: mistral

github-actions bot (Contributor) commented Aug 6, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/mistral']
quantizations: [] ...

@ArthurZucker (Collaborator) left a comment


The Mistral model was never interleaved, so my first review leans towards no.
We can maybe define this in the config but hardcode it to be sliding. Our philosophy is to not add abstraction when it is not relevant. Same for mixtral: modeling code should be unchanged.

"past_key_values": past_key_values,
"position_ids": position_ids,
}
full_mask_already_prepared = isinstance(attention_mask, torch.Tensor) and len(attention_mask.shape) > 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not make sense, the model is sliding only

manueldeprada (Contributor Author) commented Aug 8, 2025

The Mistral model was never interleaved, so my first review leans towards no.
We can maybe define this in the config but hardcode it to be sliding. Our philosophy is to not add abstraction when it is not relevant.

mistralai/Ministral-8B-Instruct-2410 is interleaved, as reported by @hmellor, quoting the model's README: "Trained with a 128k context window with interleaved sliding-window attention".

I had understood that we wanted to support this 😅 But if Ministral isn’t relevant enough to justify the modeling change, sure just close the PR 🤗

We can maybe define this in the config but hardcode it to be sliding.

The PR already does this: defaults are set to sliding or full attention unless the model's config says otherwise:

if self.layer_types is None:
    self.layer_types = [
        "sliding_attention" if self.sliding_window is not None else "full_attention"
    ] * num_hidden_layers

(several Mistral models use full attention, example here)
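For illustration, a minimal sketch of how such layer_types could drive per-layer mask selection (the mask builders below are simplified stand-ins, not the actual transformers masking utilities):

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where key j is visible to query i: causal and within the window.
    idx = torch.arange(seq_len)
    return (idx[:, None] >= idx[None, :]) & ((idx[:, None] - idx[None, :]) < window)

def full_causal_mask(seq_len: int) -> torch.Tensor:
    idx = torch.arange(seq_len)
    return idx[:, None] >= idx[None, :]

seq_len, window = 6, 3
layer_types = ["sliding_attention", "sliding_attention", "full_attention"]
masks = {
    "sliding_attention": sliding_window_causal_mask(seq_len, window),
    "full_attention": full_causal_mask(seq_len),
}
# Each decoder layer picks the mask matching its configured type.
per_layer_masks = [masks[layer_type] for layer_type in layer_types]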

Same for mixtral: modeling code should be unchanged.

This PR doesn't change Mixtral's modeling; it only moves Attention's sliding_window into __init__ so that inheriting models can modify it.
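For illustration, a minimal sketch of that pattern (simplified, not the actual modeling code): resolving the window once in __init__ lets an inheriting model, or an interleaved layer_types entry, change it per layer without touching forward():

class SketchAttention:
    # Simplified stand-in for the Attention classes touched by this PR.
    def __init__(self, config, layer_idx: int):
        # Resolved here rather than hardcoded in forward(), so a subclass or a
        # per-layer config entry can override it.
        self.sliding_window = (
            config.sliding_window
            if config.layer_types[layer_idx] == "sliding_attention"
            else None
        )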

@ArthurZucker (Collaborator)

I don't mind the config, but the modeling code is no longer mistral, but qwen2. We need to check, because that architecture already exists -> it does not make sense to me to make an exception here! So as much as possible, let's check which model is the closest.

manueldeprada (Contributor Author) commented Aug 8, 2025

I don't mind the config, but the modeling code is no longer mistral, but qwen2. We need to check, because that architecture already exists -> it does not make sense to me to make an exception here! So as much as possible, let's check which model is the closest.

🤯 wow, you have every model in your head! I made a quick Ministral modular from Qwen2, and it exactly matches this PR's outputs with just the bias removed:

class MinistralAttention(Qwen2Attention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Qwen2 uses biased q/k/v projections; Mistral/Ministral do not,
        # so re-create them without bias.
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)

So I understand now! What's the preferred approach?

  1. Give up on Ministral and only do the config change for external libraries (and maybe warn that Ministral may be incorrect on long context windows?).
  2. Add a new Ministral model with the slim modular from Qwen2.

hmellor (Member) commented Aug 8, 2025

https://huggingface.co/mistralai/Ministral-8B-Instruct-2410/blob/main/params.json (Mistral format config) shows that Ministral was supposed to be interleaved.

From what I understand about the history, when this model was originally contributed to Transformers, interleaved sliding attention was not supported, so the model was capped to use all sliding attention.

@ArthurZucker (Collaborator)

"Add a new Ministral model with the slim modular from Qwen2" would be the best IMO! We can open a PR to support it, or just hardcode a case for mistral + interleaved -> load Ministral.

@manueldeprada (Contributor Author)

"Add a new Ministral model with the slim modular from Qwen2" would be the best IMO! We can open a PR to support it, or just hardcode a case for mistral + interleaved -> load Ministral.

I am interested in learning about model bring-up, so I’ll open the PR! Closing this
