Improve Gemma3n model and tests #39764
Conversation
@@ -659,7 +658,6 @@ def test_automodelforcausallm(self):
        self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM)

    @unittest.skip("Skipped for now!")
These tests were copied from gemma3 and were skipped; I updated and enabled them.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This comment contains run-slow, running the specified jobs: models: ['models/gemma3', 'models/gemma3n']
run-slow: gemma3n
Once again, sorry for the delay! Currently catching up on reviews! Alright, thanks! Feel free to merge after resolving the conflicts, once you think the tests are good enough!
@manueldeprada can you check the tests with #40163? SWA changes, so some values might change here as well, but I think this is the more appropriate PR to do this 👀
Force-pushed from 3d2c73b to c12d304.
Force-pushed from e7f3cc8 to 5263b2e.
Force-pushed from 5263b2e to 46cd717.
run-slow: gemma3n
This comment contains run-slow, running the specified jobs: models: ['models/gemma3n']
Nice! We can still improve the sharing a bit though, to stop wasting memory. It's not ideal to write to the Cache like this, but alright for now as it actually fixes the cropped states issue. Can you check how it behaves with compile and static cache when you're done with the changes?
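One way to run that compile + static cache check might be a small script along these lines (a hedged sketch, not part of the PR; the checkpoint id and prompt are placeholders, and it assumes the standard `generate` static-cache path):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3n-E2B-it"  # placeholder: any Gemma3n checkpoint should do
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Compile the forward pass; generate() with cache_implementation="static"
# preallocates the KV cache so the compiled graph sees fixed shapes.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
out = model.generate(
    **inputs, max_new_tokens=20, do_sample=False, cache_implementation="static"
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

If the KV-sharing write breaks full-graph compilation, this is where it would show up as recompilations or graph breaks.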
@@ -1325,21 +1329,14 @@ def forward(
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # For layers with shared KV (from kv sharing point onwards), we reuse the cached keys/values from the previous layer.
        # During prefill, cache_position is a full range [0, 1, ..., max_cache_len-1], but in autoregressive mode it's a single position [last_token_idx].
        # For sliding window layers, we must clamp or slice indices to the cache's max length to avoid out-of-bounds access.
        if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None:
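The comment block in this hunk is the core of the index handling: reads from a fixed-size sliding-window cache must stay in bounds both for the full-range prefill case and for the single-position decode case. A toy illustration of that clamp-or-slice idea (a hypothetical helper, not the PR's implementation) could look like:

```python
import torch

def shared_kv_indices(cache_position: torch.Tensor, max_cache_len: int) -> torch.Tensor:
    """Keep indices into a fixed-size (sliding-window) cache in bounds."""
    if cache_position.shape[0] > max_cache_len:
        # Prefill with more tokens than the cache holds: only max_cache_len
        # slots exist, so read all of them.
        return torch.arange(max_cache_len, device=cache_position.device)
    # Autoregressive step: a single absolute position can exceed the cache
    # length, so clamp it to the last valid slot instead of indexing out of bounds.
    return cache_position.clamp(max=max_cache_len - 1)

# Prefill of 10 tokens into a 4-slot window -> read slots [0, 1, 2, 3]
print(shared_kv_indices(torch.arange(10), max_cache_len=4))
# Decoding absolute position 40 with a 32-slot window -> slot 31
print(shared_kv_indices(torch.tensor([40]), max_cache_len=32))
```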
Unless I'm mistaken, we can simplify the path here
-        if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None:
+        if self.is_kv_shared_layer and past_key_values is not None:
        if self.store_full_length_kv:
            if not hasattr(past_key_values, "shared_layers"):
                past_key_values.shared_layers = {}
            past_key_values.shared_layers[self.layer_idx] = key_states, value_states
I believe this should be done after `update` is called, no?
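A hedged sketch of that ordering, written as a hypothetical free-standing helper rather than the PR's actual attention code:

```python
from typing import Optional

import torch
from transformers.cache_utils import Cache

def update_then_share(
    past_key_values: Cache,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[dict] = None,
    store_full_length_kv: bool = False,
):
    # 1) Normal per-layer cache write first.
    key_states, value_states = past_key_values.update(
        key_states, value_states, layer_idx, cache_kwargs
    )
    # 2) Only afterwards, stash the tensors that the KV-shared layers will read back.
    if store_full_length_kv:
        if not hasattr(past_key_values, "shared_layers"):
            past_key_values.shared_layers = {}
        past_key_values.shared_layers[layer_idx] = (key_states, value_states)
    return key_states, value_states
```

This only illustrates the ordering the comment asks about; whether the shared tensors should be the raw full-length states or the ones returned by `update()` is still up to the discussion above.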
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma3n
Improves the Gemma3n model and tests by: