File tree (expand / collapse file tree): 2 files changed, +8 -3 lines changed
src/transformers/models/prophetnet (expand / collapse file tree): 2 files changed, +8 -3 lines changed
Original file line number | Diff line number | Diff line change
26
26
from torch .nn import LayerNorm
27
27
28
28
from ...activations import ACT2FN
29
- from ...cache_utils import Cache , EncoderDecoderCache
29
+ from ...cache_utils import Cache , DynamicCache , EncoderDecoderCache
30
30
from ...generation import GenerationMixin
31
31
from ...modeling_layers import GradientCheckpointingLayer
32
32
from ...modeling_outputs import BaseModelOutput
@@ -1235,7 +1235,7 @@ def forward(
1235
1235
use_cache = False
1236
1236
1237
1237
if use_cache and past_key_values is None :
1238
- past_key_values = EncoderDecoderCache ()
1238
+ past_key_values = EncoderDecoderCache () if encoder_hidden_states is not None else DynamicCache ()
1239
1239
return_legacy_cache = False
1240
1240
if use_cache and isinstance (past_key_values , tuple ):
1241
1241
logger .warning_once (
Original file line number | Diff line number | Diff line change
@@ -2071,7 +2071,12 @@ def test_generate_with_quant_cache(self):
2071
2071
model .generate (past_key_valyes = DynamicCache (), ** generation_kwargs , ** inputs_dict )
2072
2072
2073
2073
# setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense
2074
- generation_kwargs ["cache_config" ] = {"nbits" : 60 , "q_group_size" : 8 , "residual_length" : 128 }
2074
+ generation_kwargs ["cache_config" ] = {
2075
+ "backend" : "quanto" ,
2076
+ "nbits" : 60 ,
2077
+ "q_group_size" : 8 ,
2078
+ "residual_length" : 128 ,
2079
+ }
2075
2080
with self .assertRaises (ValueError ):
2076
2081
model .generate (** generation_kwargs , ** inputs_dict )
2077
2082
You can’t perform that action at this time.
0 commit comments