
Conversation

amitsrivastava78
Collaborator

  • Modified load_own_variables() to use _direct_assign() for sharded variables (see the sketch after this list)
  • Prevents loading full weight tensors on single device before distribution
  • Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
  • Maintains backward compatibility for non-sharded variables
  • Enables loading of models like Gemma2 2B/7B without OOM errors
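
A minimal, hypothetical sketch of the loading pattern described above; the attribute and method names (_layout, _direct_assign) mirror the PR description, but the real Keras internals differ in detail:

```python
def load_own_variables_sharded(layer, store):
    """Illustrative only: assign checkpoint values without first
    materializing full tensors on a single device."""
    for i, variable in enumerate(layer.weights):
        value = store[str(i)]
        if getattr(variable, "_layout", None) is not None:
            # Sharded variable: the backend's direct-assign path applies the
            # variable's sharding layout while placing the data, so the full
            # array is never built on one device.
            variable._direct_assign(value)
        else:
            # Non-sharded variable: the existing assignment path is unchanged.
            variable.assign(value)
```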

github-actions bot added the Gemma (Gemma model specific issues) label on Oct 3, 2025
gemini-code-assist bot (Contributor)

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces crucial improvements to Keras's JAX backend, specifically targeting Out-Of-Memory (OOM) issues that arise when loading large models in a ModelParallel distributed setup. By modifying how sharded variables are initialized and assigned, the system now avoids loading entire weight tensors onto a single device, instead distributing them directly across available devices. This change, encapsulated in a new load_variable_with_sharded_support utility, ensures efficient memory usage and enables the successful loading of very large models like Gemma2 2B/7B. Additionally, it enhances the robustness of dynamic shape handling and prevents premature garbage collection of sharded JAX arrays, solidifying Keras's distributed capabilities on JAX.

Highlights

  • OOM Fix for ModelParallel: Modified weight loading logic to prevent Out-Of-Memory (OOM) errors when using ModelParallel with large models, especially on the JAX backend.
  • Sharded Variable Loading: Implemented a new utility function, load_variable_with_sharded_support, to ensure sharded variables are loaded directly onto their respective devices, avoiding full tensor materialization on a single device (a sketch follows after this list).
  • JAX Backend Enhancements: Updated JaxVariable and NnxVariable to manage strong references to sharded JAX arrays, preventing premature garbage collection and ensuring data availability during inference.
  • Dynamic Shape Handling: Improved compute_output_spec in the JAX backend to more robustly handle dynamic shapes during symbolic tracing.
  • Broad Layer Support: Integrated the new sharded loading mechanism across various Keras layers (e.g., Dense, Conv, Embedding, EinsumDense, BatchNormalization) and optimizers.
  • Comprehensive Testing: Added new tests to validate the correct behavior of sharded array protection, strong references, and the end-to-end sharded variable loading process for ModelParallel.
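
The helper name load_variable_with_sharded_support comes from the summary above, but its signature is not shown here; the snippet below is only a plausible JAX-level sketch of the idea, using jax.device_put with an explicit NamedSharding so each shard is transferred straight to its target device:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def load_sharded_value(value, mesh, spec):
    """Place a host-side array onto devices shard by shard.

    Hypothetical stand-in for a helper like load_variable_with_sharded_support;
    the real function in this PR may have a different signature.
    """
    sharding = NamedSharding(mesh, spec)
    # device_put with an explicit sharding transfers each shard to its target
    # device, so the full tensor never needs to fit on a single accelerator.
    return jax.device_put(np.asarray(value), sharding)


# Example: shard a (1024, 512) kernel across all devices along its first axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))
kernel = np.zeros((1024, 512), dtype=np.float32)  # 1024 assumed divisible by device count
sharded_kernel = load_sharded_value(kernel, mesh, PartitionSpec("model", None))
```

Because device_put honors the sharding, each device only holds the shards assigned to it, which is what keeps peak per-device memory bounded during loading.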

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request addresses an out-of-memory (OOM) issue during weight loading for models using ModelParallel by introducing sharded variable assignment. The changes primarily involve modifying load_own_variables methods to use a new _direct_assign approach for sharded variables, preventing the full weight tensor from being loaded onto a single device. A new helper function, load_variable_with_sharded_support, centralizes this logic, and it has been integrated into various layers and the base optimizer.

My review has identified a few areas for improvement, including a potential memory leak, duplicated code, and some inconsistencies. Addressing these points will enhance the robustness and maintainability of the solution. Overall, the changes are well-structured and include thorough testing, which is excellent.
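
For context, here is a hedged sketch of the end-to-end workflow this fix targets: building a model under keras.distribution.ModelParallel and then loading checkpoint weights. The mesh shape, layout rules, layer sizes, and checkpoint path are placeholders, and the ModelParallel constructor arguments have varied slightly across Keras 3 releases:

```python
import keras

devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh)
# Shard Dense kernels over the "model" axis; the sharded dimension must be
# divisible by the number of devices on that axis.
layout_map["dense.*kernel"] = (None, "model")
layout_map["dense.*bias"] = ("model",)

keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)

model = keras.Sequential(
    [keras.Input(shape=(4096,)), keras.layers.Dense(16384, name="dense_1")]
)
# With this PR, sharded variables are assigned shard-by-shard during loading
# instead of first materializing each full tensor on a single device.
model.load_weights("model.weights.h5")  # placeholder path
```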


codecov-commenter commented Oct 3, 2025

Codecov Report

❌ Patch coverage is 77.44681% with 53 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.56%. Comparing base (3fac66f) to head (303f241).
⚠️ Report is 6 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/jax/core.py 70.96% 30 Missing and 15 partials ⚠️
keras/src/layers/core/dense.py 92.00% 1 Missing and 1 partial ⚠️
keras/src/layers/core/einsum_dense.py 92.00% 1 Missing and 1 partial ⚠️
keras/src/layers/core/embedding.py 77.77% 1 Missing and 1 partial ⚠️
keras/src/layers/preprocessing/index_lookup.py 33.33% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21712      +/-   ##
==========================================
- Coverage   82.60%   82.56%   -0.04%     
==========================================
  Files         572      573       +1     
  Lines       58326    58722     +396     
  Branches     9134     9197      +63     
==========================================
+ Hits        48179    48484     +305     
- Misses       7817     7886      +69     
- Partials     2330     2352      +22     
Flag Coverage Δ
keras 82.36% <77.44%> (-0.04%) ⬇️
keras-jax 63.23% <77.44%> (-0.09%) ⬇️
keras-numpy 57.48% <36.59%> (-0.19%) ⬇️
keras-openvino 34.26% <11.48%> (-0.05%) ⬇️
keras-tensorflow 63.84% <36.59%> (-0.22%) ⬇️
keras-torch 63.39% <36.59%> (-0.26%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

- Modified load_own_variables() to use _direct_assign() for sharded variables
- Prevents loading full weight tensors on single device before distribution
- Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
- Maintains backward compatibility for non-sharded variables
- Enables loading of models like Gemma2 2B/7B without OOM errors
- Added EinsumDense layer testing to ModelParallel sharded variable loading
- Fixed line length issues and code formatting
- Fix _maybe_create_strong_reference to handle JAX tracers during compilation
- Add GPTQ support to legacy load_own_variables methods
- Fix variable ordering for GPTQ in legacy loading
- Add int4 support to embedding legacy loading
- Add _is_jax_tracer() and _safe_has_addressable_shards() helper functions (sketched after this list)
- Update all addressable_shards checks to use safe versions
- Prevent tracer concretization errors during JAX JIT compilation
- Fix random test failures and improve sharding robustness
- Add _require_min_devices() helper method to skip tests when insufficient devices
- Apply skip conditions to 13 failing tests requiring 8 devices and 1 test requiring 4 devices
- Tests now skip gracefully in CI environments with limited hardware instead of failing
- Maintains full test coverage on systems with sufficient devices
- Prevent mutation of Variable attributes during JAX tracing
- Use try/except to gracefully handle trace context errors
- Maintain strong reference creation for non-traced execution
- Fixes integration test failures while preserving memory leak fix
- Remove os.environ XLA_FLAGS manipulation that forces 2 devices
- Add _require_min_devices() helper method to skip tests when insufficient devices
- Apply device requirement checks to sharding-related tests that need >= 2 devices
- Tests now skip gracefully in single-device environments instead of forcing device creation
- Maintains test coverage on multi-device systems while being CI-friendly
- Change Embedding input_dim from 100 to 96 to be divisible by 8 devices
- Add skip condition for tests requiring < 2 devices (sharding not meaningful)
- Prevents ValueError when sharding dimensions not divisible by device count
- Test now properly skips on single-device systems and passes on multi-device CI
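
A hedged sketch of the tracer-safe checks referenced in the commit notes above; the actual helpers in keras/src/backend/jax/core.py may be implemented differently:

```python
import jax


def _is_jax_tracer(x):
    # Values seen inside jit/grad tracing are Tracer instances, not concrete arrays.
    return isinstance(x, jax.core.Tracer)


def _safe_has_addressable_shards(x):
    # Only concrete jax.Array values expose addressable_shards; probing a tracer
    # for concrete sharding info during compilation would fail, so guard first.
    return (
        not _is_jax_tracer(x)
        and isinstance(x, jax.Array)
        and hasattr(x, "addressable_shards")
    )
```

The _require_min_devices() test helper mentioned above would presumably be a similar guard, skipping a test when len(jax.devices()) is below the required count.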