diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index 1065afe80d36..49146cae31b2 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -13,7 +13,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions @@ -540,14 +540,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -576,9 +577,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py index 68ce50d4529f..f0b976a5be71 100644 --- a/examples/modular-transformers/modeling_roberta.py +++ b/examples/modular-transformers/modeling_roberta.py @@ -13,7 +13,7 @@ from packaging import version from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions @@ -543,14 +543,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -579,9 +580,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c88a7ab0dcf3..56a9e7a4b5a9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1063,6 +1063,8 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tens backward compatibility. """ cache = cls() + if past_key_values is None: + logger.warning_once("past_key_values should not be None in from_legacy_cache()") if past_key_values is not None: for layer_idx in range(len(past_key_values)): key_states, value_states = past_key_values[layer_idx] @@ -1528,6 +1530,8 @@ def from_legacy_cache( cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...] ) -> "EncoderDecoderCache": """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + if past_key_values is None: + logger.warning_once("past_key_values should not be None in from_legacy_cache()") cache = cls( self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache(), diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index bc0e5faf5965..b2006ad72ffa 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -26,7 +26,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1154,14 +1154,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) # expand encoder attention mask @@ -1229,9 +1229,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index bba22e65e610..3703af30b28e 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...generation.logits_process import ( AlternatingCodebooksLogitsProcessor, @@ -498,14 +498,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. " "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -569,9 +569,6 @@ def forward( logits = self.lm_head(hidden_states) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index de20331e829e..82931e2eb785 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -1042,9 +1042,13 @@ def forward( inputs_embeds = self.embed_tokens(input) # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1138,9 +1142,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 04323c7ec4ae..edc511315805 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -639,14 +639,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -675,9 +676,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index d582a914395a..b3e6fa053e9d 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -22,7 +22,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions @@ -380,14 +380,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -416,9 +416,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 97f6187a1f98..1861bb525c84 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1568,14 +1568,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. " "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -1608,9 +1608,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 9af25afd797f..6e4a9e100b6f 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -23,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -2200,9 +2200,13 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -2297,9 +2301,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index bba29c892fce..bdf0729f48a5 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -556,15 +556,15 @@ def forward( use_cache = False # initialize past_key_values - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) batch_size, seq_length = inputs_embeds.size()[:-1] past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -578,11 +578,7 @@ def forward( mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - self_attn_cache = ( - past_key_values.self_attention_cache - if isinstance(past_key_values, EncoderDecoderCache) - else past_key_values - ) + self_attn_cache = past_key_values causal_mask = self._update_causal_mask( attention_mask, @@ -646,9 +642,6 @@ def forward( hidden_states = self.layer_norm(hidden_states) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 7b29640cd8f2..d500fe9ec7ec 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -23,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -380,15 +380,15 @@ def forward( use_cache = False # initialize past_key_values - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) batch_size, seq_length = inputs_embeds.size()[:-1] past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -402,11 +402,7 @@ def forward( mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - self_attn_cache = ( - past_key_values.self_attention_cache - if isinstance(past_key_values, EncoderDecoderCache) - else past_key_values - ) + self_attn_cache = past_key_values causal_mask = self._update_causal_mask( attention_mask, @@ -470,9 +466,6 @@ def forward( hidden_states = self.layer_norm(hidden_states) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 78dd0223bc71..22a4faccd71a 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -998,9 +998,13 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1096,9 +1100,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 212c7cb135a3..c705be8ab556 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -23,7 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -984,9 +984,13 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1081,9 +1085,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 55b274596ab5..4aa44a9afb4f 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -433,20 +433,22 @@ def forward( ) use_cache = False - return_legacy_cache = False if use_cache: - if not isinstance(past_key_values, Cache): + if isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) # The model acts as encoder decoder but is not an encoder decoder. So we cast all cache objects to # `EncoderDecoderCache` type assuming that the incoming cache is from `self_attention` elif isinstance(past_key_values, DynamicCache): past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif past_key_values is None: + past_key_values = EncoderDecoderCache( + self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache() + ) all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -479,9 +481,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index b547d160ab0f..50318d34dfd8 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN, QuickGELUActivation -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -767,14 +767,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -803,9 +804,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 746112afef26..be5904a6293a 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -596,14 +596,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -632,9 +633,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index c8077eb8719e..40490b23d2ef 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -1069,14 +1069,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. " "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1155,9 +1155,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 91dba7da6c8e..c8a61d6bd0ee 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -636,14 +636,14 @@ def forward( position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1) span = torch.full((batch, seq_length), 0, dtype=dtype, device=device) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. " "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -687,9 +687,6 @@ def forward( new_hidden_states += (hidden_state[:, self.prompt_length :, :],) all_hidden_states = new_hidden_states - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index f942cb1531c1..1dd5d8ad2aed 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel @@ -338,14 +338,14 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. " "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -422,9 +422,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index e6754770d038..cbe35c7b1070 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -23,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -484,14 +484,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -520,9 +521,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 6b3419beb638..3b2e3b70726b 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -567,19 +567,18 @@ def forward( token_type_ids = token_type_ids.view(-1, input_shape[-1]) # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder and similar addition in GPT2Model - return_legacy_cache = False if use_cache: if past_key_values is None: - return_legacy_cache = True past_key_values = DynamicCache() - elif not isinstance(past_key_values, Cache): - return_legacy_cache = True + elif isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. " "You should pass an instance of `Cache` instead, e.g. " "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." ) past_key_values = DynamicCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = DynamicCache() if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache): past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) @@ -697,12 +696,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) past_key_values = past_key_values if use_cache else None - if return_legacy_cache: - past_key_values = ( - past_key_values.self_attention_cache.to_legacy_cache() - if self.config.add_cross_attention - else past_key_values.to_legacy_cache() - ) + # no return to legacy cache if not return_dict: return tuple( v diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index f402614e7530..ee921b681326 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -539,14 +539,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -575,9 +576,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index bea39d1d8be5..b21267322c16 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -468,14 +468,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -504,9 +505,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index 3643119988ba..6c0ca07d2682 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu, get_activation -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, @@ -858,7 +858,10 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device - if not isinstance(cache, Cache): + if cache is None: + cache = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if isinstance(cache, tuple): cache = EncoderDecoderCache.from_legacy_cache(cache) if lengths is None: @@ -1144,12 +1147,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If @@ -1255,12 +1254,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ @@ -1352,12 +1347,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1495,12 +1486,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels whether a question has an answer or no answer (SQuAD 2.0) cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1636,12 +1623,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 8a7b3e898840..36fcb173ece5 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -35,7 +35,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import ( @@ -649,9 +649,9 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -718,9 +718,6 @@ def forward( x = self.output_projection(x) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [x, past_key_values, all_hidden_states, all_self_attns, all_cross_attns] if v is not None diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index b08c2f718af3..3b2159a4ede0 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -839,13 +839,10 @@ def forward( use_cache = False # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder - return_legacy_cache = False if use_cache: if past_key_values is None: - return_legacy_cache = True past_key_values = DynamicCache() - elif not isinstance(past_key_values, Cache): - return_legacy_cache = True + elif isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. " "You should pass an instance of `Cache` instead, e.g. " @@ -961,12 +958,6 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) past_key_values = past_key_values if use_cache else None - if return_legacy_cache: - past_key_values = ( - past_key_values.self_attention_cache.to_legacy_cache() - if self.config.add_cross_attention - else past_key_values.to_legacy_cache() - ) if not return_dict: return tuple( v diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 0d21f30f490c..5d7434137928 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -22,7 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import is_flash_attn_available @@ -487,14 +487,14 @@ def forward( if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) if inputs_embeds is None: @@ -589,9 +589,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 2e13db7f29d9..7472f669b226 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -646,14 +646,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values @@ -760,9 +760,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 8ac7e905a20b..ae0e3019360e 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -26,7 +26,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1216,9 +1216,9 @@ def forward( input_shape = inputs_embeds.size()[:-1] # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1304,9 +1304,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 92aa3ace5a8e..1e4110fb196a 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -23,7 +23,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -1050,14 +1050,18 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1134,9 +1138,6 @@ def forward( # add final layer norm hidden_states = self.layer_norm(hidden_states) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 1d3beefd7d59..4b162cfdd957 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -1760,14 +1760,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1844,9 +1844,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index fadba94f9efd..4a0a804a8449 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -22,7 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -1055,9 +1055,9 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1158,9 +1158,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 20dc02213dc7..5ed87c86f31e 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -997,9 +997,13 @@ def forward( inputs_embeds = inputs_embeds * self.embed_scale # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1090,9 +1094,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 11d8ca2d269c..499a626f4440 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -23,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -1038,9 +1038,13 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1134,9 +1138,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index b12e97e68cf3..69239a0b468f 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -509,14 +509,14 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) all_hidden_states = () if output_hidden_states else None @@ -555,9 +555,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index dff967861284..45db56a9ee11 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -636,9 +636,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index cd81408283c9..c095abdefed5 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -658,9 +658,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 891602c33897..8dbd8235b86b 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -354,9 +354,9 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `DynamicCache` instead, e.g. " @@ -405,9 +405,6 @@ def forward( # Add last hidden state hidden_states = self.norm_f(hidden_states) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 74f61e79029e..0df49dae8a35 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import ( ClassifierFreeGuidanceLogitsProcessor, GenerationConfig, @@ -559,14 +559,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -638,9 +638,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index b0afcf6da7ef..559f7375977b 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import ( ClassifierFreeGuidanceLogitsProcessor, GenerationConfig, @@ -521,14 +521,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -601,9 +601,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index c15975fd6fd4..e2d1072f6a71 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -23,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, @@ -855,14 +855,18 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -939,9 +943,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 819ebef200ad..f5b786b9767c 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -22,7 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -1233,9 +1233,9 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1326,9 +1326,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 174e9308cb92..78e2dc0cb4eb 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -1047,9 +1047,13 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1141,9 +1145,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index faf71a29f8a5..578cc9958e07 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -1305,9 +1305,9 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1391,9 +1391,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index de92fb89aa4c..b348b7a88f7f 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -958,9 +958,13 @@ def forward( inputs_embeds = self.embed_tokens(input) # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1054,9 +1058,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index dbe813648cab..4085cecaea37 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -26,7 +26,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput @@ -1238,14 +1238,18 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1357,9 +1361,6 @@ def forward( if self.config.ngram > 0: all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - # split last_hidden_state for return last_hidden_state = hidden_states[:, :sequence_length] last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 760edd48453a..367af6692357 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -30,7 +30,6 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel @@ -62,13 +61,12 @@ ) -class ReformerDynamicCache(DynamicCache): +class ReformerDynamicCache: """ A dynamic cache that stores past buckets instead of key/values. """ def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: - super().__init__() self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen self.buckets_cache: list[torch.Tensor] = [] self.states_cache: list[torch.Tensor] = [] @@ -1827,14 +1825,14 @@ def forward( all_attentions = [] # init cached hidden states if necessary - return_legacy_cache = False - if use_cache or not isinstance(past_buckets_states, ReformerDynamicCache): + if use_cache and past_buckets_states is None: + past_buckets_states = ReformerDynamicCache() + elif use_cache and isinstance(past_buckets_states, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `ReformerDynamicCache` instead, e.g. " "`past_key_values=ReformerDynamicCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_buckets_states = ReformerDynamicCache.from_legacy_cache(past_buckets_states) # concat same tensor for reversible ResNet @@ -1861,8 +1859,6 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) next_cache = past_buckets_states if use_cache else None - if return_legacy_cache: - next_cache = past_buckets_states.to_legacy_cache() return ReformerEncoderOutput( hidden_states=hidden_states, @@ -2360,15 +2356,15 @@ def prepare_inputs_for_generation( def _reorder_cache(self, past_key_values, beam_idx): reord_past_buckets_states = [] - for layer_past in past_key_values: + for buckets, hidden_states in past_key_values: # buckets - if layer_past[0] is not None: - reord_buckets = layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)) + if buckets is not None and buckets.numel() > 0: + reord_buckets = buckets.index_select(0, beam_idx.to(buckets.device)) else: reord_buckets = None # hidden states - reord_hidden_states = layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)) + reord_hidden_states = hidden_states.index_select(0, beam_idx.to(hidden_states.device)) reord_past_buckets_states.append((reord_buckets, reord_hidden_states)) if isinstance(past_key_values, ReformerDynamicCache): diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index fc3f1862e67d..10530f4949ea 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -505,14 +505,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) hidden_states = self.embedding_hidden_mapping_in(hidden_states) @@ -545,9 +545,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index e1770bb4db3f..de8fd818cbd6 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -595,14 +595,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -631,9 +632,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 30cc18801d40..6f6adeeb0937 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -476,14 +476,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -512,9 +513,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 69f9787d2232..990331044bed 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -595,14 +595,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -631,9 +632,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 2a4ddc69417d..2a3490e34380 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -540,14 +540,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) all_hidden_states = () if output_hidden_states else None @@ -586,9 +586,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 2a9e23cb92fb..6a84b884fba5 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -1800,9 +1800,9 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1866,9 +1866,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 775d8c1a68ed..e475b5a9f81f 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -1843,9 +1843,9 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1910,9 +1910,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 56b9a582a042..7469e684613a 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -22,7 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, @@ -873,9 +873,9 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -949,9 +949,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 60ddec72c37e..fefbb8b6604f 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -1582,14 +1582,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1655,9 +1655,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 90b4672c9da0..7ad3fd845af6 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -812,10 +812,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if not self.training and use_cache and past_key_values is None: - past_key_values = EncoderDecoderCache( - self_attention_cache=DynamicCache(config=self.config), - cross_attention_cache=DynamicCache(config=self.config), - ) + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index ed407be82a57..988e24e65365 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -675,10 +675,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if not self.training and use_cache and past_key_values is None: - past_key_values = EncoderDecoderCache( - self_attention_cache=DynamicCache(config=self.config), - cross_attention_cache=DynamicCache(config=self.config), - ) + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 6704c77b3b77..6b3ae41e67d3 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel @@ -577,7 +577,9 @@ def forward( return_dict=True, cache_position=None, ): - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 20e6b11b6864..53f1ffed9bfb 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -22,7 +22,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -993,9 +993,9 @@ def forward( input_shape = inputs_embeds.size()[:-1] # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -1081,9 +1081,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index e9a12069a0cf..d357643af0ee 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -22,7 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, @@ -582,14 +582,18 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -668,9 +672,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index d9df105204e8..b0de0327c657 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -22,7 +22,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_layers import GradientCheckpointingLayer @@ -494,9 +494,13 @@ def forward( use_cache = False # initialize `past_key_values` - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(), DynamicCache()) + if encoder_hidden_states is not None + else DynamicCache() + ) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " @@ -578,9 +582,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 57b88811e6ce..5655195102bf 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu, get_activation -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, @@ -811,12 +811,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -831,7 +827,10 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device - if not isinstance(cache, Cache): + if cache is None: + cache = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if isinstance(cache, tuple): cache = EncoderDecoderCache.from_legacy_cache(cache) if lengths is None: @@ -1033,12 +1032,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` @@ -1126,12 +1121,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If @@ -1240,12 +1231,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1347,12 +1334,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels whether a question has an answer or no answer (SQuAD 2.0) cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1468,12 +1451,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ @@ -1579,12 +1558,8 @@ def forward( also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in `[0, ..., input_ids.size(-1)]`. cache (`dict[str, torch.FloatTensor]`, *optional*): - Dictionary string to `torch.FloatTensor` that contains precomputed hidden states (key and values in the - attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential + Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential decoding. - - The dictionary object will be modified in-place during the forward pass to add newly computed - hidden-states. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index fdff73ff77fc..a0e7075cd8bc 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -596,14 +596,15 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): @@ -632,9 +633,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index c666aef841cb..054aed9ee074 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -585,14 +585,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) all_hidden_states = () if output_hidden_states else None @@ -627,9 +627,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 80e16dc966ae..e2bb9e2e74b9 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -23,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu -from ...cache_utils import Cache, EncoderDecoderCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -534,14 +534,14 @@ def forward( ) use_cache = False - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if use_cache and isinstance(past_key_values, tuple): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) all_hidden_states = () if output_hidden_states else None @@ -578,9 +578,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if return_legacy_cache: - past_key_values = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v diff --git a/src/transformers/pipelines/pt_utils.py b/src/transformers/pipelines/pt_utils.py index 19663437cd69..e48950aa4914 100644 --- a/src/transformers/pipelines/pt_utils.py +++ b/src/transformers/pipelines/pt_utils.py @@ -86,13 +86,15 @@ def loader_batch_item(self): elif isinstance(element[0], np.ndarray): loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element) continue - if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple): + if k in {"hidden_states", "attentions"} and isinstance(element, tuple): # Those are stored as lists of tensors so need specific unbatching. if isinstance(element[0], torch.Tensor): loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element) elif isinstance(element[0], np.ndarray): loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element) continue + if k == "past_key_values": + continue if element is None: # This can happen for optional data that get passed around loader_batched[k] = None diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index 9cc0a8be2437..8d33d9dc1b22 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -17,6 +17,7 @@ from packaging import version from transformers import AutoTokenizer, BertConfig, is_torch_available +from transformers.cache_utils import EncoderDecoderCache from transformers.models.auto import get_values from transformers.testing_utils import ( CaptureLogger, @@ -716,8 +717,8 @@ def test_sdpa_ignored_mask(self): # Case where query length != kv_length. Note that model needs to be a decoder so we can use cache model.config.is_decoder = True model_sdpa.config.is_decoder = True - res_eager = model(**inp, past_key_values=pkv, use_cache=True) - res_sdpa = model_sdpa(**inp, past_key_values=pkv, use_cache=True) + res_eager = model(**inp, past_key_values=EncoderDecoderCache.from_legacy_cache(pkv), use_cache=True) + res_sdpa = model_sdpa(**inp, past_key_values=EncoderDecoderCache.from_legacy_cache(pkv), use_cache=True) self.assertTrue( torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4) ) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 5155f6d9a0ec..2b3c887f5c3a 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. - import math import unittest import pytest -from transformers import GPT2Config, is_torch_available +from transformers import DynamicCache, GPT2Config, is_torch_available from transformers.testing_utils import ( Expectations, cleanup, @@ -443,9 +442,15 @@ def create_and_check_cached_forward_with_and_without_attention_mask(self, config # Cached forward once with the attention mask provided and the other time without it (which should assume full attention) cache_outputs = model(**cache_inputs) - full_outputs_with_attention_mask = model( - **non_cache_inputs, past_key_values=cache_outputs.past_key_values - ).last_hidden_state + # Caches are mutable (unlike legacy tuples), so we need to copy them before using multiple times + pkv_copy = DynamicCache() + pkv_copy.update( + cache_outputs.past_key_values.layers[0].keys, cache_outputs.past_key_values.layers[0].values, 0 + ) + pkv_copy.update( + cache_outputs.past_key_values.layers[1].keys, cache_outputs.past_key_values.layers[1].values, 1 + ) + full_outputs_with_attention_mask = model(**non_cache_inputs, past_key_values=pkv_copy).last_hidden_state full_outputs_without_attention_mask = model( non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values ).last_hidden_state diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4c2b64f0d07f..a609b75ea0c7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1316,21 +1316,6 @@ def test_trainer_works_with_dict(self): _ = trainer.evaluate() _ = trainer.predict(eval_dataset) - def test_evaluation_with_keys_to_drop(self): - config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) - tiny_gpt2 = GPT2LMHeadModel(config) - x = torch.randint(0, 100, (128,)) - eval_dataset = RepeatDataset(x) - args = TrainingArguments(self.get_auto_remove_tmp_dir(), report_to="none") - trainer = Trainer(tiny_gpt2, args, eval_dataset=eval_dataset) - # By default the past_key_values are removed - result = trainer.predict(eval_dataset) - self.assertTrue(isinstance(result.predictions, np.ndarray)) - # We can still get them by setting ignore_keys to [] - result = trainer.predict(eval_dataset, ignore_keys=[]) - self.assertTrue(isinstance(result.predictions, tuple)) - self.assertEqual(len(result.predictions), 2) - def test_training_arguments_are_left_untouched(self): tmp_dir = self.get_auto_remove_tmp_dir() trainer = get_regression_trainer(output_dir=tmp_dir) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 54a7dc24cf63..043a43865212 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -1197,6 +1197,28 @@ def test_dynamic_cache(self): "DynamicCache Scenario 2 layer 1 failed", ) + def test_dynamic_cache_batch_select_indices(self): + """Select a subset of batches in-place using batch_select_indices.""" + cache = DynamicCache() + # Shape: (batch=3, heads=1, seq_len=2, head_dim=1) + prefill = torch.tensor( + [ + [[[1.0], [2.0]]], + [[[10.0], [20.0]]], + [[[100.0], [200.0]]], + ] + ) + cache.update(prefill, prefill, 0) + self.assertEqual(cache.layers[0].keys.shape[0], 3) + + # Keep batches 0 and 2 + cache.batch_select_indices((0, 2)) + self.assertEqual(cache.layers[0].keys.shape[0], 2) + self.assertEqual( + cache.layers[0].keys[:, 0, :, 0].tolist(), + [[1.0, 2.0], [100.0, 200.0]], + ) + def test_hybrid_cache(self): """ Test HybridCache with a mix of static and sliding layers,