Skip to content

Commit 595a1a4

Browse files
committed
fixes
1 parent ee2693d commit 595a1a4

File tree

2 files changed: +4 −4 lines changed

src/transformers/models/mbart/modeling_mbart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
@@ -1036,7 +1036,7 @@ def forward(
 
         # initialize `past_key_values`
         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache()
+            past_key_values = EncoderDecoderCache() if encoder_hidden_states is not None else DynamicCache()
         return_legacy_cache = False
         if use_cache and isinstance(past_key_values, tuple):
             return_legacy_cache = True

src/transformers/models/xglm/modeling_xglm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
 from torch import nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, EncoderDecoderCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_layers import GradientCheckpointingLayer
@@ -492,7 +492,7 @@ def forward(
 
         # initialize `past_key_values`
         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache()
+            past_key_values = EncoderDecoderCache() if encoder_hidden_states is not None else DynamicCache()
         return_legacy_cache = False
         if use_cache and isinstance(past_key_values, tuple):
             return_legacy_cache = True

0 commit comments

Comments
 (0)