File tree Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change 23
23
from torch .nn import BCEWithLogitsLoss , CrossEntropyLoss , MSELoss
24
24
25
25
from ...activations import ACT2FN
26
- from ...cache_utils import Cache , EncoderDecoderCache
26
+ from ...cache_utils import Cache , DynamicCache , EncoderDecoderCache
27
27
from ...generation import GenerationMixin
28
28
from ...modeling_attn_mask_utils import (
29
29
AttentionMaskConverter ,
@@ -1036,7 +1036,7 @@ def forward(
1036
1036
1037
1037
# initialize `past_key_values`
1038
1038
if use_cache and past_key_values is None :
1039
- past_key_values = EncoderDecoderCache ()
1039
+ past_key_values = EncoderDecoderCache () if encoder_hidden_states is not None else DynamicCache ()
1040
1040
return_legacy_cache = False
1041
1041
if use_cache and isinstance (past_key_values , tuple ):
1042
1042
return_legacy_cache = True
Original file line number Diff line number Diff line change 22
22
from torch import nn
23
23
24
24
from ...activations import ACT2FN
25
- from ...cache_utils import Cache , EncoderDecoderCache
25
+ from ...cache_utils import Cache , DynamicCache , EncoderDecoderCache
26
26
from ...generation import GenerationMixin
27
27
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask , _prepare_4d_causal_attention_mask
28
28
from ...modeling_layers import GradientCheckpointingLayer
@@ -492,7 +492,7 @@ def forward(
492
492
493
493
# initialize `past_key_values`
494
494
if use_cache and past_key_values is None :
495
- past_key_values = EncoderDecoderCache ()
495
+ past_key_values = EncoderDecoderCache () if encoder_hidden_states is not None else DynamicCache ()
496
496
return_legacy_cache = False
497
497
if use_cache and isinstance (past_key_values , tuple ):
498
498
return_legacy_cache = True
You can’t perform that action at this time.
0 commit comments