
Conversation

@manueldeprada (Contributor) commented Jul 29, 2025

This PR removes the reliance on Cache.from_legacy_cache(past_key_values) for initializing a cache when past_key_values is None, replacing it with explicit cache initialization. The previous approach also set return_legacy_cache=True, which unintentionally returned legacy tuples and masked other issues.

This change is necessary to support the upcoming deprecation of from_legacy_cache in v4.58.
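As a rough sketch (not the actual diff; use_cache handling and naming vary per model), the initialization change looks like this:

```python
from transformers import DynamicCache

def init_past_key_values(past_key_values, use_cache=True):
    # Before: past_key_values = DynamicCache.from_legacy_cache(past_key_values),
    # which also set return_legacy_cache=True and converted outputs back to tuples.
    # After: explicit Cache initialization when nothing was passed in.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    return past_key_values
```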

Note: this update revealed an issue in pipelines, where loader_batch_item expects legacy tuples when iterating over ModelOutputs and failed when it encountered Cache objects.
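For context, here is a minimal sketch of the kind of per-item slicing that breaks, and one possible way to special-case Cache objects. The helper slice_output_field is hypothetical; this is not the actual pipelines code nor the fix in this PR:

```python
import torch
from transformers.cache_utils import Cache

def slice_output_field(element, i):
    # Legacy past_key_values were nested tuples of tensors, so tuple-style
    # recursion worked; a Cache object is neither a tuple nor a tensor.
    if isinstance(element, Cache):
        return element  # e.g. keep the cache whole instead of slicing it per item
    if isinstance(element, torch.Tensor):
        return element[i].unsqueeze(0)
    if isinstance(element, tuple):
        return tuple(slice_output_field(sub, i) for sub in element)
    return element
```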

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

gante (Member) previously approved these changes Aug 5, 2025 and left a comment

LGTM! thank you for working on it

@manueldeprada marked this pull request as ready for review on August 5, 2025 at 17:28
@manueldeprada changed the title from "[draft] No more using from_legacy_cache as initialization" to "Stop using from_legacy_cache as Cache initialization" on Aug 5, 2025
@Cyrilvallez (Member) left a comment

Hey, sorry for the delay on this! Thanks a lot, happy to start cleaning up everything to finally drop the legacy format soon! This needs a rebase to fix the conflict, though. And for EncoderDecoderCache, let's not provide default values to the init; let's instantiate it with two DynamicCache instances in the models instead (see the two PRs I linked!).
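A minimal sketch of the suggested pattern, assuming EncoderDecoderCache keeps its two-argument constructor:

```python
from transformers import DynamicCache, EncoderDecoderCache

# Instead of EncoderDecoderCache providing default values in its __init__,
# models build it explicitly from two DynamicCache instances.
past_key_values = EncoderDecoderCache(
    self_attention_cache=DynamicCache(),
    cross_attention_cache=DynamicCache(),
)
```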

Cyrilvallez (Member) previously approved these changes Aug 12, 2025 and left a comment
Perfect, thanks a lot! Looks like you just need to run make fix-copies for Flaubert (probably only the docstrings change) to make CI happy.
If the copy is no longer consistent because of a real change, you can just break it!
Feel free to merge once that's done and CI is happy!

@manueldeprada changed the title from "Stop using from_legacy_cache as Cache initialization" to "Stop using from_legacy_cache as Cache init, make pipeline use Caches, fix and test Cache.select_indices" on Aug 14, 2025
@gante (Member) commented Aug 14, 2025

@manueldeprada @Cyrilvallez sorry to potentially add one more task here :D Should we update all the DynamicCache() instantiations added in this PR to DynamicCache(config=self.config), following the pattern introduced in #40039?
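Roughly, the pattern in question would look like the sketch below; the DynamicCache(config=...) signature is assumed from #40039 and not verified here:

```python
from transformers import DynamicCache

def init_cache_with_config(config, past_key_values=None, use_cache=True):
    # Assumed pattern: pass the model config so the cache can set up its
    # per-layer structure up front instead of growing lazily.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache(config=config)
    return past_key_values
```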

@@ -1197,6 +1197,28 @@ def test_dynamic_cache(self):
"DynamicCache Scenario 2 layer 1 failed",
)

def test_dynamic_cache_batch_select_indices(self):
@manueldeprada (Contributor, author) commented on this diff

This additional test does not hurt, since batch_select_indices was never tested
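For illustration, a minimal sketch of what such a test exercises; this is not the test added in the PR, and the per-layer cache[0] access returning a (keys, values) pair is assumed:

```python
import torch
from transformers import DynamicCache

def check_batch_select_indices():
    cache = DynamicCache()
    keys = torch.randn(4, 2, 6, 8)    # (batch, num_heads, seq_len, head_dim)
    values = torch.randn(4, 2, 6, 8)
    cache.update(keys, values, layer_idx=0)

    # Keep only batch items 0 and 2, in place.
    cache.batch_select_indices(torch.tensor([0, 2]))

    selected_keys, selected_values = cache[0]
    assert selected_keys.shape[0] == 2
    assert torch.equal(selected_keys, keys[[0, 2]])
    assert torch.equal(selected_values, values[[0, 2]])
```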

@Cyrilvallez (Member) left a comment

Alright, thanks for reverting the previous changes.
We just need to remove the kwarg in all the EncoderDecoderCache initializations to simplify our lives with #40008.
Also, I just checked generate, and we actually return the Cache format all the time, even when the KV cache is provided as a tuple. So I think it makes sense to do the same here and always return a Cache, whatever the input, not only half the time. That way we can drop return_legacy_cache in all modeling code and never go back to the legacy format 🤗
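A minimal sketch of the resulting modeling-side behaviour (illustrative, not the exact code): legacy tuple inputs are still converted on the way in, but the output is always a Cache and return_legacy_cache disappears:

```python
from transformers import DynamicCache
from transformers.cache_utils import Cache

def normalize_past_key_values(past_key_values, use_cache=True):
    if not use_cache:
        return None
    if past_key_values is None:
        return DynamicCache()  # explicit init, no from_legacy_cache(None)
    if not isinstance(past_key_values, Cache):
        # a legacy tuple input is converted once; a Cache is always returned
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values
```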


[For maintainers] Suggested jobs to run (before merge)

run-slow: autoformer, bark, bart, bert, bert_generation, big_bird, bigbird_pegasus, biogpt, blenderbot, blenderbot_small, blip, bridgetower, camembert, clvp, cpmant, ctrl

@manueldeprada (Contributor, author)

@Cyrilvallez all done!! thanks, now the diff is very nice: +378 −538 🧹 🧹 😄

@manueldeprada (Contributor, author)

run-slow: autoformer, bark, bart, bert, bert_generation, big_bird, bigbird_pegasus, biogpt, blenderbot, blenderbot_small, blip, bridgetower, camembert, clvp, cpmant, ctrl


This comment contains run-slow, running the specified jobs:

models: ['models/autoformer', 'models/bark', 'models/bart', 'models/bert', 'models/bert_generation', 'models/big_bird', 'models/bigbird_pegasus', 'models/biogpt', 'models/blenderbot', 'models/blenderbot_small', 'models/blip', 'models/bridgetower', 'models/camembert', 'models/clvp', 'models/cpmant', 'models/ctrl']
quantizations: [] ...

@manueldeprada requested a review from Cyrilvallez and removed the review request for Cyrilvallez on August 18, 2025 at 11:19
@manueldeprada (Contributor, author)

Checked the slow tests and all the failures seem unrelated. A bunch of them will be fixed by #40008!

@Cyrilvallez changed the title from "🚨 Return Cache objects in models when past_key_values are not provided (to align with generate)" to "🚨 Always return Cache objects in modelings (to align with generate)" on Aug 18, 2025
@Cyrilvallez (Member) left a comment

LGTM! Thanks a lot for the cleanup! Happy to get rid of it! Feel free to merge! 🤗

@gante (Member) commented Aug 18, 2025

Generally, we don't want any link Trainer <-> Cache as caching is only an inference speedup trick

@Cyrilvallez for completeness, PEFT has some fine-tuning methods that use caches, which is why we don't simply set use_cache &= not self.training. AFAIK Trainer has no use for caches except in the eval step.

@manueldeprada merged commit a36d51e into huggingface:main on Aug 18, 2025
24 of 25 checks passed
@gante (Member) left a comment

LGTM, thank you for iterating 💛

@Cyrilvallez (Member) commented Aug 18, 2025

Nice, thanks @gante for the heads-up! cc @BenjaminBossan just in case, but I believe PEFT already supports both formats anyway, as recent models have been returning Cache objects for a long time! 🤗

@BenjaminBossan (Member) commented

Thanks for the heads up, I ran the relevant tests against the latest transformers main branch and they all passed. The new warning was not triggered, so I think we're good on the PEFT side.
