Fix ModelParallel OOM issue during weight loading #21712
base: master
Conversation
amitsrivastava78 commented on Oct 3, 2025:
- Modified load_own_variables() to use _direct_assign() for sharded variables
- Prevents loading full weight tensors on a single device before distribution (see the JAX sketch after this list)
- Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
- Maintains backward compatibility for non-sharded variables
- Enables loading of models like Gemma2 2B/7B without OOM errors
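The memory saving comes from letting the backend place each shard directly rather than materializing the whole tensor on one device first. As a rough, stand-alone illustration in plain JAX (not the Keras code in this PR; the mesh, shapes, and dtype are made up), jax.device_put with a NamedSharding transfers only each device's slice of a host array:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1D "model" mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))
spec = NamedSharding(mesh, PartitionSpec("model", None))

# Host-side weights (e.g. read from a checkpoint) stay in host memory;
# device_put moves only the per-device slice onto each accelerator.
host_weights = np.zeros((len(devices) * 128, 256), dtype=np.float32)
sharded_weights = jax.device_put(host_weights, spec)

print(sharded_weights.shape)                    # full logical shape
print(len(sharded_weights.addressable_shards))  # one shard per local device
```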
Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces improvements to Keras's JAX backend, specifically targeting out-of-memory (OOM) issues that arise when loading large models in a ModelParallel distributed setup. By modifying how sharded variables are initialized and assigned, the system now avoids loading entire weight tensors onto a single device and instead distributes them directly across the available devices. This change is encapsulated in the new load_variable_with_sharded_support helper.
Code Review
This pull request addresses an out-of-memory (OOM) issue during weight loading for models using ModelParallel by introducing sharded variable assignment. The changes primarily involve modifying load_own_variables methods to use a new _direct_assign approach for sharded variables, preventing the full weight tensor from being loaded onto a single device. A new helper function, load_variable_with_sharded_support, centralizes this logic, and it has been integrated into various layers and the base optimizer.

My review has identified a few areas for improvement, including a potential memory leak, duplicated code, and some inconsistencies. Addressing these points will enhance the robustness and maintainability of the solution. Overall, the changes are well-structured and include thorough testing, which is excellent.
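For readers skimming the diff, here is a minimal sketch of what such a centralized helper could look like, assuming the `_layout` attribute and `_direct_assign()` method referenced in this PR (the actual signature and checks in the patch may differ):

```python
import numpy as np


def load_variable_with_sharded_support(variable, value):
    """Assign a checkpoint value, routing sharded variables through the
    distribution-aware path so the full tensor never sits on one device."""
    value = np.asarray(value)
    if getattr(variable, "_layout", None) is not None:
        # Sharded variable: the backend places each shard on its own device.
        variable._direct_assign(value)
    else:
        # Non-sharded variable: keep the original assignment path.
        variable.assign(value)
```

A layer's load_own_variables() (or the base optimizer's) can then call the helper for every entry in its weight store instead of branching locally.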
Codecov Report

❌ Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##           master   #21712      +/-   ##
==========================================
- Coverage   82.60%   82.56%   -0.04%
==========================================
  Files         572      573       +1
  Lines       58326    58722     +396
  Branches     9134     9197      +63
==========================================
+ Hits        48179    48484     +305
- Misses       7817     7886      +69
- Partials     2330     2352      +22
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
- Modified load_own_variables() to use _direct_assign() for sharded variables
- Prevents loading full weight tensors on a single device before distribution
- Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
- Maintains backward compatibility for non-sharded variables
- Enables loading of models like Gemma2 2B/7B without OOM errors
- Added EinsumDense layer testing to ModelParallel sharded variable loading
- Fixed line length issues and code formatting
Force-pushed from 6514085 to ff78e07
- Fix _maybe_create_strong_reference to handle JAX tracers during compilation
- Add GPTQ support to legacy load_own_variables methods
- Fix variable ordering for GPTQ in legacy loading
- Add int4 support to embedding legacy loading
- Add _is_jax_tracer() and _safe_has_addressable_shards() helper functions (sketched below)
- Update all addressable_shards checks to use the safe versions
- Prevent tracer concretization errors during JAX JIT compilation
- Fix random test failures and improve sharding robustness
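A plausible shape for these two helpers is to test for jax.core.Tracer before touching properties that would concretize a traced value; the following is a sketch based on the commit message, not the exact code in the PR:

```python
import jax


def _is_jax_tracer(x):
    """True if `x` is an abstract tracer produced during jit/grad tracing."""
    return isinstance(x, jax.core.Tracer)


def _safe_has_addressable_shards(x):
    """Like checking for `addressable_shards`, but never forces a tracer
    to a concrete value during JIT compilation."""
    if _is_jax_tracer(x):
        return False
    return hasattr(x, "addressable_shards")
```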
- Add _require_min_devices() helper method to skip tests when insufficient devices
- Apply skip conditions to 13 failing tests requiring 8 devices and 1 test requiring 4 devices
- Tests now skip gracefully in CI environments with limited hardware instead of failing
- Maintains full test coverage on systems with sufficient devices
Force-pushed from 7816f0c to af6c766
… with assignment for shard references
- Prevent mutation of Variable attributes during JAX tracing
- Use try/except to gracefully handle trace context errors
- Maintain strong reference creation for non-traced execution
- Fixes integration test failures while preserving memory leak fix
- Remove os.environ XLA_FLAGS manipulation that forces 2 devices
- Add _require_min_devices() helper method to skip tests when insufficient devices are available (see the sketch after this list)
- Apply device requirement checks to sharding-related tests that need >= 2 devices
- Tests now skip gracefully in single-device environments instead of forcing device creation
- Maintains test coverage on multi-device systems while being CI-friendly
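As context, a guard like the _require_min_devices() described here can be as small as the following sketch (the test class name and body are placeholders, not the PR's actual tests):

```python
import unittest

import jax


class ModelParallelLoadingTest(unittest.TestCase):
    def _require_min_devices(self, min_devices):
        """Skip gracefully on small CI machines instead of failing."""
        available = len(jax.devices())
        if available < min_devices:
            self.skipTest(
                f"Requires at least {min_devices} devices, found {available}."
            )

    def test_sharded_weight_loading(self):
        self._require_min_devices(2)
        # ... build a 2-device mesh and exercise sharded loading here ...
```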
- Change Embedding input_dim from 100 to 96 so it is divisible by 8 devices (see the note below)
- Add skip condition when fewer than 2 devices are available (sharding not meaningful)
- Prevents ValueError when sharding dimensions are not divisible by the device count
- Test now properly skips on single-device systems and passes on multi-device CI
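The divisibility constraint behind the 100 → 96 change is plain arithmetic: a dimension partitioned across a mesh axis must be a multiple of that axis's device count, so 96 splits evenly across 8 devices while 100 does not:

```python
num_devices = 8
for input_dim in (100, 96):
    if input_dim % num_devices == 0:
        print(f"{input_dim}: OK, {input_dim // num_devices} rows per device")
    else:
        print(f"{input_dim}: cannot be split evenly across {num_devices} devices")
```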
Force-pushed from 3a86f1e to 303f241