Skip to content

Commit 6567a14

Browse files
committed
fix models, pipelines
1 parent fe5f374 commit 6567a14

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+153
-57
lines changed

src/transformers/cache_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,15 +1732,15 @@ class EncoderDecoderCache(Cache):
17321732
# Override @property from Cache
17331733
is_compileable = None
17341734

1735-
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
1735+
def __init__(self, self_attention_cache: Cache = None, cross_attention_cache: Cache = None):
17361736
super().__init__(layer_classes=DynamicLayer)
1737-
self.self_attention_cache = self_attention_cache
1738-
self.cross_attention_cache = cross_attention_cache
1737+
self.self_attention_cache = self_attention_cache if self_attention_cache is not None else DynamicCache()
1738+
self.cross_attention_cache = cross_attention_cache if cross_attention_cache is not None else DynamicCache()
17391739
self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False)
17401740

17411741
self.is_updated = {}
1742-
for layer_idx in range(len(cross_attention_cache)):
1743-
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
1742+
for layer_idx in range(len(self.cross_attention_cache)):
1743+
self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
17441744

17451745
def __iter__(self):
17461746
"""

src/transformers/models/autoformer/modeling_autoformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,8 +1151,10 @@ def forward(
11511151
)
11521152
use_cache = False
11531153

1154+
if use_cache and past_key_values is None:
1155+
past_key_values = EncoderDecoderCache()
11541156
return_legacy_cache = False
1155-
if use_cache and not isinstance(past_key_values, Cache):
1157+
if use_cache and isinstance(past_key_values, tuple):
11561158
logger.warning_once(
11571159
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
11581160
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "

src/transformers/models/bark/modeling_bark.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from torch import nn
2424
from torch.nn import functional as F
2525

26-
from ...cache_utils import Cache, DynamicCache
26+
from ...cache_utils import DynamicCache
2727
from ...generation import GenerationMixin
2828
from ...generation.logits_process import (
2929
AlternatingCodebooksLogitsProcessor,
@@ -498,8 +498,10 @@ def forward(
498498
)
499499
use_cache = False
500500

501+
if use_cache and past_key_values is None:
502+
past_key_values = DynamicCache()
501503
return_legacy_cache = False
502-
if use_cache and not isinstance(past_key_values, Cache):
504+
if use_cache and isinstance(past_key_values, tuple):
503505
logger.warning_once(
504506
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
505507
"You should pass an instance of `DynamicCache` instead, e.g. "

src/transformers/models/bart/modeling_bart.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1039,8 +1039,10 @@ def forward(
10391039
inputs_embeds = self.embed_tokens(input)
10401040

10411041
# initialize `past_key_values`
1042+
if use_cache and past_key_values is None:
1043+
past_key_values = EncoderDecoderCache()
10421044
return_legacy_cache = False
1043-
if use_cache and not isinstance(past_key_values, Cache):
1045+
if use_cache and isinstance(past_key_values, tuple):
10441046
return_legacy_cache = True
10451047
logger.warning_once(
10461048
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "

src/transformers/models/bert_generation/modeling_bert_generation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,10 @@ def forward(
376376
)
377377
use_cache = False
378378

379+
if use_cache and past_key_values is None:
380+
past_key_values = EncoderDecoderCache()
379381
return_legacy_cache = False
380-
if use_cache and not isinstance(past_key_values, Cache):
382+
if use_cache and isinstance(past_key_values, tuple):
381383
logger.warning_once(
382384
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
383385
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "

src/transformers/models/big_bird/modeling_big_bird.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1566,8 +1566,10 @@ def forward(
15661566
)
15671567
use_cache = False
15681568

1569+
if use_cache and past_key_values is None:
1570+
past_key_values = DynamicCache()
15691571
return_legacy_cache = False
1570-
if use_cache and not isinstance(past_key_values, Cache):
1572+
if use_cache and isinstance(past_key_values, tuple):
15711573
logger.warning_once(
15721574
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
15731575
"You should pass an instance of `DynamicCache` instead, e.g. "

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2197,8 +2197,10 @@ def forward(
21972197
inputs_embeds = self.embed_tokens(input_ids)
21982198

21992199
# initialize `past_key_values`
2200+
if use_cache and past_key_values is None:
2201+
past_key_values = EncoderDecoderCache()
22002202
return_legacy_cache = False
2201-
if use_cache and not isinstance(past_key_values, Cache):
2203+
if use_cache and isinstance(past_key_values, tuple):
22022204
return_legacy_cache = True
22032205
logger.warning_once(
22042206
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "

src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,10 @@ def forward(
553553
use_cache = False
554554

555555
# initialize past_key_values
556+
if use_cache and past_key_values is None:
557+
past_key_values = EncoderDecoderCache()
556558
return_legacy_cache = False
557-
if use_cache and not isinstance(past_key_values, Cache):
559+
if use_cache and isinstance(past_key_values, tuple):
558560
return_legacy_cache = True
559561
logger.warning_once(
560562
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "

src/transformers/models/biogpt/modular_biogpt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,10 @@ def forward(
378378
use_cache = False
379379

380380
# initialize past_key_values
381+
if use_cache and past_key_values is None:
382+
past_key_values = EncoderDecoderCache()
381383
return_legacy_cache = False
382-
if use_cache and not isinstance(past_key_values, Cache):
384+
if use_cache and isinstance(past_key_values, tuple):
383385
return_legacy_cache = True
384386
logger.warning_once(
385387
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "

src/transformers/models/blenderbot/modeling_blenderbot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -995,8 +995,10 @@ def forward(
995995
use_cache = False
996996

997997
# initialize `past_key_values`
998+
if use_cache and past_key_values is None:
999+
past_key_values = EncoderDecoderCache()
9981000
return_legacy_cache = False
999-
if use_cache and not isinstance(past_key_values, Cache):
1001+
if use_cache and isinstance(past_key_values, tuple):
10001002
return_legacy_cache = True
10011003
logger.warning_once(
10021004
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "

0 commit comments

Comments (0)