Commit a36d51e

🚨 Always return Cache objects in modelings (to align with generate) (#39765)
* watch the world burn
* fix models, pipelines
* make the error a warning
* remove kwargs and return_legacy_cache
* fix reformer
1 parent 57e230c commit a36d51e
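
What this means for callers, as a minimal sketch (checkpoint name illustrative; any causal LM behaves the same way): a forward pass now always hands back a `Cache` instance, matching what `generate()` already produced, and code that still needs the legacy tuple format converts explicitly.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import Cache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello there", return_tensors="pt")
outputs = model(**inputs, use_cache=True)

# After this commit, past_key_values is always a Cache object,
# never a tuple of per-layer (key, value) tuples.
assert isinstance(outputs.past_key_values, Cache)

# Callers that still depend on the legacy format convert it themselves:
legacy = outputs.past_key_values.to_legacy_cache()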

File tree: 72 files changed, +378 additions, -538 deletions


examples/modular-transformers/modeling_dummy_bert.py (5 additions, 7 deletions)

@@ -13,7 +13,7 @@
 from torch import nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
@@ -540,14 +540,15 @@ def forward(
                 )
                 use_cache = False
 
-        return_legacy_cache = False
-        if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache):
+        if use_cache and self.config.is_decoder and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+
+        if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                 "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
             )
-            return_legacy_cache = True
             past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
 
         for i, layer_module in enumerate(self.layer):
@@ -576,9 +577,6 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v
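
The deprecation branch above still accepts legacy tuples but no longer converts back on the way out. A sketch of the explicit conversion the warning recommends, with illustrative tensor shapes (batch, heads, seq_len, head_dim) and a hypothetical two-layer cache:

import torch
from transformers.cache_utils import EncoderDecoderCache

# Legacy seq2seq layout: per layer, four tensors
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value).
legacy = tuple(
    tuple(torch.zeros(1, 2, 3, 8) for _ in range(4)) for _ in range(2)
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 3 -- the self-attention sequence length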

examples/modular-transformers/modeling_roberta.py (5 additions, 7 deletions)

@@ -13,7 +13,7 @@
 from packaging import version
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
@@ -543,14 +543,15 @@ def forward(
                 )
                 use_cache = False
 
-        return_legacy_cache = False
-        if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache):
+        if use_cache and self.config.is_decoder and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+
+        if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                 "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
             )
-            return_legacy_cache = True
             past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
 
         for i, layer_module in enumerate(self.layer):
@@ -579,9 +580,6 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v

src/transformers/cache_utils.py (4 additions, 0 deletions)

@@ -1063,6 +1063,8 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tens
         backward compatibility.
         """
         cache = cls()
+        if past_key_values is None:
+            logger.warning_once("past_key_values should not be None in from_legacy_cache()")
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
                 key_states, value_states = past_key_values[layer_idx]
@@ -1528,6 +1530,8 @@ def from_legacy_cache(
         cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        if past_key_values is None:
+            logger.warning_once("past_key_values should not be None in from_legacy_cache()")
         cache = cls(
             self_attention_cache=DynamicCache(),
             cross_attention_cache=DynamicCache(),
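
A sketch of the converter these hunks touch, with illustrative shapes and a hypothetical two-layer cache; per the commit's "make the error a warning", passing None now logs a one-time warning rather than raising:

import torch
from transformers.cache_utils import DynamicCache

# Legacy decoder-only layout: one (key, value) pair per layer,
# each tensor shaped (batch, num_heads, seq_len, head_dim).
legacy = tuple(
    (torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 4

# None triggers warning_once and yields an empty cache.
empty = DynamicCache.from_legacy_cache(None)
print(empty.get_seq_length())  # 0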

src/transformers/models/autoformer/modeling_autoformer.py (4 additions, 7 deletions)

@@ -26,7 +26,7 @@
 from torch import nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
     _prepare_4d_attention_mask_for_sdpa,
@@ -1154,14 +1154,14 @@ def forward(
                 )
                 use_cache = False
 
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+        if use_cache and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                 "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
             )
-            return_legacy_cache = True
             past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
 
         # expand encoder attention mask
@@ -1229,9 +1229,6 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v

src/transformers/models/bark/modeling_bark.py (4 additions, 7 deletions)

@@ -23,7 +23,7 @@
 from torch import nn
 from torch.nn import functional as F
 
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import DynamicCache
 from ...generation import GenerationMixin
 from ...generation.logits_process import (
     AlternatingCodebooksLogitsProcessor,
@@ -498,14 +498,14 @@ def forward(
                 )
                 use_cache = False
 
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+        if use_cache and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `DynamicCache` instead, e.g. "
                 "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
             )
-            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -569,9 +569,6 @@ def forward(
 
         logits = self.lm_head(hidden_states)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None
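
The retained `past_length = past_key_values.get_seq_length()` line leans on the Cache API throughout. A hedged caller-side sketch of the same pattern, assuming a recent transformers version and an illustrative decoder-only checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Prefill with an explicit DynamicCache rather than the removed tuple path.
ids = tok("Cache objects", return_tensors="pt").input_ids
out = model(input_ids=ids, past_key_values=DynamicCache(), use_cache=True)
print(out.past_key_values.get_seq_length())  # == ids.shape[1]

# Decode step: feed only the newest token; the cache supplies the history.
next_id = out.logits[:, -1:].argmax(-1)
out = model(input_ids=next_id, past_key_values=out.past_key_values, use_cache=True)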

src/transformers/models/bart/modeling_bart.py (8 additions, 7 deletions)

@@ -24,7 +24,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
@@ -1042,9 +1042,13 @@ def forward(
             inputs_embeds = self.embed_tokens(input)
 
         # initialize `past_key_values`
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
-            return_legacy_cache = True
+        if use_cache and past_key_values is None:
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(), DynamicCache())
+                if encoder_hidden_states is not None
+                else DynamicCache()
+            )
+        if use_cache and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
@@ -1138,9 +1142,6 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v
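
BART's decoder runs either standalone (no cross-attention inputs) or inside a seq2seq stack, hence the conditional above. The wrapper it builds simply pairs two independent caches, as the cache_utils hunk earlier shows; a small sketch:

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Self-attention entries grow with every decoding step; cross-attention
# entries are computed once from the encoder output and then reused.
print(type(cache.self_attention_cache).__name__)   # DynamicCache
print(type(cache.cross_attention_cache).__name__)  # DynamicCache
print(cache.get_seq_length())  # 0 -- nothing cached yet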

src/transformers/models/bert/modeling_bert.py (5 additions, 7 deletions)

@@ -28,7 +28,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_layers import GradientCheckpointingLayer
@@ -639,14 +639,15 @@ def forward(
                 )
                 use_cache = False
 
-        return_legacy_cache = False
-        if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache):
+        if use_cache and self.config.is_decoder and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+
+        if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                 "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
             )
-            return_legacy_cache = True
             past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
 
         for i, layer_module in enumerate(self.layer):
@@ -675,9 +676,6 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v

src/transformers/models/bert_generation/modeling_bert_generation.py (4 additions, 7 deletions)

@@ -22,7 +22,7 @@
 from torch import nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
@@ -380,14 +380,14 @@ def forward(
                 )
                 use_cache = False
 
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+        if use_cache and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                 "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
             )
-            return_legacy_cache = True
             past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
 
         for i, layer_module in enumerate(self.layer):
@@ -416,9 +416,6 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v

src/transformers/models/big_bird/modeling_big_bird.py (3 additions, 6 deletions)

@@ -1568,14 +1568,14 @@ def forward(
                 )
                 use_cache = False
 
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+        if use_cache and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `DynamicCache` instead, e.g. "
                 "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
             )
-            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         for i, layer_module in enumerate(self.layer):
@@ -1608,9 +1608,6 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py (8 additions, 7 deletions)

@@ -23,7 +23,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
@@ -2200,9 +2200,13 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         # initialize `past_key_values`
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
-            return_legacy_cache = True
+        if use_cache and past_key_values is None:
+            past_key_values = (
+                EncoderDecoderCache(DynamicCache(), DynamicCache())
+                if encoder_hidden_states is not None
+                else DynamicCache()
+            )
+        if use_cache and isinstance(past_key_values, tuple):
             logger.warning_once(
                 "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                 "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
@@ -2297,9 +2301,6 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        if return_legacy_cache:
-            past_key_values = past_key_values.to_legacy_cache()
-
         if not return_dict:
             return tuple(
                 v
