Skip to content

Commit ee2693d

Browse files
committed
fixes
1 parent d962ae2 commit ee2693d

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/transformers/models/prophetnet/modeling_prophetnet.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
from torch.nn import LayerNorm
2727

2828
from ...activations import ACT2FN
29-
from ...cache_utils import Cache, EncoderDecoderCache
29+
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
3030
from ...generation import GenerationMixin
3131
from ...modeling_layers import GradientCheckpointingLayer
3232
from ...modeling_outputs import BaseModelOutput
@@ -1235,7 +1235,7 @@ def forward(
12351235
use_cache = False
12361236

12371237
if use_cache and past_key_values is None:
1238-
past_key_values = EncoderDecoderCache()
1238+
past_key_values = EncoderDecoderCache() if encoder_hidden_states is not None else DynamicCache()
12391239
return_legacy_cache = False
12401240
if use_cache and isinstance(past_key_values, tuple):
12411241
logger.warning_once(

tests/generation/test_utils.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2071,7 +2071,12 @@ def test_generate_with_quant_cache(self):
20712071
model.generate(past_key_values=DynamicCache(), **generation_kwargs, **inputs_dict)
20722072

20732073
# setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense
2074-
generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128}
2074+
generation_kwargs["cache_config"] = {
2075+
"backend": "quanto",
2076+
"nbits": 60,
2077+
"q_group_size": 8,
2078+
"residual_length": 128,
2079+
}
20752080
with self.assertRaises(ValueError):
20762081
model.generate(**generation_kwargs, **inputs_dict)
20772082

0 commit comments

Comments (0)