Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
fe5f374
watch the world burn
manueldeprada Jul 29, 2025
6567a14
fix models, pipelines
manueldeprada Aug 5, 2025
605ba14
fix models that are both enc-dec and StandaloneDecoder
manueldeprada Aug 5, 2025
0f4a749
Merge branch 'main' into no-from_legacy-init
manueldeprada Aug 5, 2025
85812b9
Merge branch 'main' into no-from_legacy-init
manueldeprada Aug 5, 2025
7cdd948
replace routing_weights.device.index with a
manueldeprada Aug 5, 2025
375146f
Merge branch 'fix-cpu-tests' into no-from_legacy-init
manueldeprada Aug 5, 2025
c6fc969
Merge branch 'main' into no-from_legacy-init
manueldeprada Aug 6, 2025
0d8d68b
cyril review
manueldeprada Aug 12, 2025
053cdf1
Merge branch 'main' of github.com:huggingface/transformers into no-from_legacy-init
manueldeprada Aug 12, 2025
f725e21
raise error if None is passed
manueldeprada Aug 12, 2025
495ef60
fix xlm model
manueldeprada Aug 12, 2025
00a1cd4
fixes
manueldeprada Aug 12, 2025
18dde86
ops
manueldeprada Aug 12, 2025
3bfcfae
more models
manueldeprada Aug 13, 2025
d6fb95e
Merge branch 'main' of github.com:huggingface/transformers into no-from_legacy-init
manueldeprada Aug 14, 2025
94e4712
fix flaubert
manueldeprada Aug 14, 2025
c7f336a
fix bert test, dont use legacy
manueldeprada Aug 14, 2025
e051637
Merge branch 'main' into no-from_legacy-init
manueldeprada Aug 14, 2025
959271f
fix pipelines definitely, sanitize cache's select_index, create tests
manueldeprada Aug 14, 2025
0ad4c8f
fix tests/trainer/test_trainer.py::TrainerIntegrationTest::test_evalu…
manueldeprada Aug 14, 2025
24949fa
Revert "fix tests/trainer/test_trainer.py::TrainerIntegrationTest::te…
manueldeprada Aug 14, 2025
89dbbe0
skip test
manueldeprada Aug 15, 2025
04021bc
revert cache changes
manueldeprada Aug 15, 2025
56a1b09
ops
manueldeprada Aug 15, 2025
35dd4db
fix gpt2 test
manueldeprada Aug 15, 2025
259de5a
make the error a warning
manueldeprada Aug 15, 2025
4693a10
remove kwargs and return_legacy_cache
manueldeprada Aug 18, 2025
260fab9
cont
manueldeprada Aug 18, 2025
9baf489
cyril review
manueldeprada Aug 18, 2025
a0eaa4a
Merge branch 'main' of github.com:huggingface/transformers into no-from_legacy-init
manueldeprada Aug 18, 2025
28cd663
fix reformer
manueldeprada Aug 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/modular-transformers/modeling_dummy_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
Expand Down Expand Up @@ -540,8 +540,13 @@ def forward(
)
use_cache = False

if use_cache and self.config.is_decoder and past_key_values is None:
past_key_values = EncoderDecoderCache(
self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
)

return_legacy_cache = False
if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache):
if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
Expand Down
9 changes: 7 additions & 2 deletions examples/modular-transformers/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from packaging import version

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
Expand Down Expand Up @@ -543,8 +543,13 @@ def forward(
)
use_cache = False

if use_cache and self.config.is_decoder and past_key_values is None:
past_key_values = EncoderDecoderCache(
self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
)

return_legacy_cache = False
if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache):
if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,8 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tens
backward compatibility.
"""
cache = cls()
if past_key_values is None:
raise ValueError("past_key_values cannot be None in from_legacy_cache()")
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
Expand Down Expand Up @@ -1522,6 +1524,8 @@ def from_legacy_cache(
cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
if past_key_values is None:
raise ValueError("past_key_values cannot be None in from_legacy_cache()")
cache = cls(
self_attention_cache=DynamicCache(),
cross_attention_cache=DynamicCache(),
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/models/autoformer/modeling_autoformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
Expand Down Expand Up @@ -1154,8 +1154,12 @@ def forward(
)
use_cache = False

if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(
self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/bark/modeling_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch import nn
from torch.nn import functional as F

from ...cache_utils import Cache, DynamicCache
from ...cache_utils import DynamicCache
from ...generation import GenerationMixin
from ...generation.logits_process import (
AlternatingCodebooksLogitsProcessor,
Expand Down Expand Up @@ -498,8 +498,10 @@ def forward(
)
use_cache = False

if use_cache and past_key_values is None:
past_key_values = DynamicCache()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `DynamicCache` instead, e.g. "
Expand Down
10 changes: 8 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
Expand Down Expand Up @@ -1042,8 +1042,14 @@ def forward(
inputs_embeds = self.embed_tokens(input)

# initialize `past_key_values`
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
if encoder_hidden_states is not None
else DynamicCache()
)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
Expand Down Expand Up @@ -639,8 +639,13 @@ def forward(
)
use_cache = False

if use_cache and self.config.is_decoder and past_key_values is None:
past_key_values = EncoderDecoderCache(
self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
)

return_legacy_cache = False
if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache):
if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
Expand Down Expand Up @@ -380,8 +380,12 @@ def forward(
)
use_cache = False

if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(
self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/models/big_bird/modeling_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,8 +1568,10 @@ def forward(
)
use_cache = False

if use_cache and past_key_values is None:
past_key_values = DynamicCache()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `DynamicCache` instead, e.g. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
Expand Down Expand Up @@ -2200,8 +2200,14 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

# initialize `past_key_values`
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
if encoder_hidden_states is not None
else DynamicCache()
)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
Expand Down
18 changes: 8 additions & 10 deletions src/transformers/models/biogpt/modeling_biogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
Expand Down Expand Up @@ -556,15 +556,17 @@ def forward(
use_cache = False

# initialize past_key_values
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
"You should pass an instance of `DynamicCache` instead, e.g. "
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

batch_size, seq_length = inputs_embeds.size()[:-1]
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
Expand All @@ -578,11 +580,7 @@ def forward(
mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

self_attn_cache = (
past_key_values.self_attention_cache
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
self_attn_cache = past_key_values

causal_mask = self._update_causal_mask(
attention_mask,
Expand Down
18 changes: 8 additions & 10 deletions src/transformers/models/biogpt/modular_biogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
Expand Down Expand Up @@ -380,15 +380,17 @@ def forward(
use_cache = False

# initialize past_key_values
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
"You should pass an instance of `DynamicCache` instead, e.g. "
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

batch_size, seq_length = inputs_embeds.size()[:-1]
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
Expand All @@ -402,11 +404,7 @@ def forward(
mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

self_attn_cache = (
past_key_values.self_attention_cache
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
self_attn_cache = past_key_values

causal_mask = self._update_causal_mask(
attention_mask,
Expand Down
10 changes: 8 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
Expand Down Expand Up @@ -998,8 +998,14 @@ def forward(
use_cache = False

# initialize `past_key_values`
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
if encoder_hidden_states is not None
else DynamicCache()
)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
Expand Down Expand Up @@ -984,8 +984,14 @@ def forward(
use_cache = False

# initialize `past_key_values`
if use_cache and past_key_values is None:
past_key_values = (
EncoderDecoderCache(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
if encoder_hidden_states is not None
else DynamicCache()
)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and isinstance(past_key_values, tuple):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
Expand Down
6 changes: 5 additions & 1 deletion src/transformers/models/blip/modeling_blip_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def forward(

return_legacy_cache = False
if use_cache:
if not isinstance(past_key_values, Cache):
if isinstance(past_key_values, tuple):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
Expand All @@ -447,6 +447,10 @@ def forward(
# `EncoderDecoderCache` type assuming that the incoming cache is from `self_attention`
elif isinstance(past_key_values, DynamicCache):
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif past_key_values is None:
past_key_values = EncoderDecoderCache(
self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
)

all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
Expand Down
Loading