4 changes: 3 additions & 1 deletion src/transformers/generation/utils.py
@@ -1750,9 +1750,11 @@ def _prepare_generation_config(
     if key == "cache_implementation" and model_generation_config.cache_implementation == "hybrid":
         continue
     global_default_value = getattr(global_default_generation_config, key, None)
+    has_custom_gen_config_value = hasattr(generation_config, key)
     custom_gen_config_value = getattr(generation_config, key, None)
     if (
-        custom_gen_config_value == global_default_value
+        (not has_custom_gen_config_value)
+        and custom_gen_config_value == global_default_value
Contributor Author:
Here is the case: when the user sets generation_config = GenerationConfig(max_new_tokens=20, do_sample=False), do_sample is False and global_default_value is also False, while model_gen_config_value is True. The current condition cannot distinguish whether the value was explicitly set by the user or is just the global default, so it silently overwrites it with the model default. This adds another condition so the overwrite only happens when the user didn't explicitly specify the value.
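
A minimal runnable sketch of the ambiguity, using plain objects as stand-ins for the three configs (the names and the simplified conditions are illustrative, not the actual transformers code):

from types import SimpleNamespace

# Stand-ins: the library-wide default, the model's generation_config.json,
# and a config the user passed explicitly.
global_default = SimpleNamespace(do_sample=False)
model_default = SimpleNamespace(do_sample=True)
user_config = SimpleNamespace(do_sample=False)  # user explicitly asked for greedy decoding

key = "do_sample"
global_default_value = getattr(global_default, key, None)   # False
model_gen_config_value = getattr(model_default, key, None)  # True
custom_gen_config_value = getattr(user_config, key, None)   # False

# Old condition: evaluates to True here, so the user's explicit
# do_sample=False would be overwritten with the model default True.
old = custom_gen_config_value == global_default_value and model_gen_config_value != global_default_value

# New condition: additionally require that the attribute was never set on the
# user's config, so an explicitly passed value survives.
new = (
    not hasattr(user_config, key)
    and custom_gen_config_value == global_default_value
    and model_gen_config_value != global_default_value
)

print(old, new)  # True False -> the fix keeps the user's value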

Member:
Indeed, this is an issue; AFAIR we discussed it with @gante after merging the PR with config-value overwriting. cc @gante for review

         and model_gen_config_value != global_default_value
     ):
         modified_values[key] = model_gen_config_value
9 changes: 4 additions & 5 deletions tests/models/gemma3n/test_modeling_gemma3n.py
@@ -748,7 +748,7 @@ def setUp(self):
     audio_ds = load_dataset(
         "etechgrid/28.5k_wavfiles_dataset", "default", data_files="wav_dataset/103-1240-0000.wav"
     )
-    self.audio_file_path = audio_ds["train"][0]["audio"]["path"]
+    self.audio_file_path = audio_ds["train"][0]["audio"].metadata.path
     cleanup(torch_device, gc_collect=True)

 def tearDown(self):
@@ -996,15 +996,14 @@ def test_generation_beyond_sliding_window_with_generation_config(self):
     input_size = inputs.input_ids.shape[-1]
     self.assertTrue(input_size > model.config.get_text_config().sliding_window)

-    out = model.generate(**inputs, generation_config=GenerationConfig(max_new_tokens=20, do_sample=False))[
-        :, input_size:
-    ]
+    generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
+    out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
     output_text = tokenizer.batch_decode(out)

     EXPECTED_COMPLETIONS = Expectations({
         # FIXME: This test is VERY flaky on ROCm
         ("cuda", None): [" and I am glad to be here. This is a nice place. This is a nice place.", ", green, yellow, purple, orange, pink, brown, black, white.\n\nHere are"],
         ("rocm", (9, 4)): [' and I think it makes this place special. This is a nice place. This is a nice place', ', green, yellow, purple, orange, pink, brown, black, white.\n\nHere are'],
-        ("xpu", None): [" and I think it is very nice. I think it is nice. This is a nice place.", ", green, yellow, purple, orange, pink, brown, black, white.\n\nHere are"],
+        ("xpu", None): [" and I think it's a nice place to visit. This is a nice place. This is", ", green, yellow, orange, purple, pink, brown, black, white.\n\nHere'"],
     }).get_expectation()  # fmt: skip
     self.assertEqual(output_text, EXPECTED_COMPLETIONS)
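
For context, a minimal sketch of the scenario this test guards against. The checkpoint name, auto class, and prompt are illustrative stand-ins, not the test's exact setup; the real test runs a Gemma 3n model on inputs longer than its sliding window:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Illustrative checkpoint; any model whose generation_config.json sets
# do_sample=True would exercise the same code path.
tok = AutoTokenizer.from_pretrained("google/gemma-3n-E2B-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-3n-E2B-it")

inputs = tok("This is a nice place.", return_tensors="pt")
input_size = inputs.input_ids.shape[-1]

# Before the fix above, do_sample=False (equal to the global default) could be
# silently replaced by the model's default do_sample=True, making this
# greedy-decoding test nondeterministic.
generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
print(tok.batch_decode(out))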