Skip to content

Commit ff78e07

Browse files
Fix ModelParallel OOM issue during weight loading

- Modified load_own_variables() to use _direct_assign() for sharded variables
- Prevents loading full weight tensors on a single device before distribution
- Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
- Maintains backward compatibility for non-sharded variables
- Enables loading of models like Gemma2 2B/7B without OOM errors
- Added EinsumDense layer testing to ModelParallel sharded variable loading
- Fixed line length issues and code formatting
1 parent 3fac66f commit ff78e07

File tree

11 files changed

+976
-80
lines changed

11 files changed

+976
-80
lines changed

keras/src/backend/jax/core.py

Lines changed: 369 additions & 43 deletions
Large diffs are not rendered by default.

keras/src/backend/jax/core_test.py

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,144 @@
1+
"""Test for core.py."""
2+
13
import os
24

5+
os.environ["KERAS_BACKEND"] = "jax"
6+
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
7+
8+
import tempfile
9+
310
import jax
411
import jax.numpy as jnp
512
import numpy as np
613
import pytest
714

815
import keras
916
from keras.src import backend
17+
from keras.src import layers
18+
from keras.src import models
1019
from keras.src import testing
1120
from keras.src.backend.config import is_nnx_enabled
21+
from keras.src.backend.jax.core import JaxVariable
22+
from keras.src.backend.jax.core import _ProtectedShardedArray
23+
24+
if is_nnx_enabled():
25+
from keras.src.backend.jax.core import NnxVariable
26+
from keras.src.utils.variable_loading import load_variable_with_sharded_support
1227

1328
if is_nnx_enabled():
1429
from flax import nnx
1530

1631
from keras.src.backend.jax.core import NnxVariable
1732

1833

34+
class JaxCoreTest(testing.TestCase):
35+
def test_protected_sharded_array_deletion(self):
36+
"""Test _ProtectedShardedArray prevents deletion of sharded arrays."""
37+
# Create a mock sharded array
38+
array = jax.numpy.ones((10, 10))
39+
sharded_array = jax.device_put(array, jax.devices()[0])
40+
sharded_array.addressable_shards = [
41+
jax.device_put(array, d) for d in jax.devices()
42+
]
43+
44+
protected = _ProtectedShardedArray(sharded_array)
45+
46+
# Attempt deletion (should not delete sharded arrays)
47+
protected.delete()
48+
49+
# Verify array is still accessible
50+
self.assertIs(protected._array, sharded_array)
51+
self.assertTrue(
52+
hasattr(protected, "_is_sharded") and protected._is_sharded
53+
)
54+
55+
def test_jax_variable_strong_references_and_logging(self):
56+
"""Test JaxVariable strong references and logging."""
57+
# Create a sharded variable
58+
var = JaxVariable(jax.numpy.ones((100, 100)))
59+
60+
# Check strong references
61+
self.assertTrue(hasattr(var, "_shard_references"))
62+
self.assertGreater(len(var._shard_references), 0)
63+
64+
# Access value multiple times to simulate inference
65+
for _ in range(5):
66+
value = var.value
67+
self.assertIsNotNone(
68+
value
69+
) # Ensure no "Array has been deleted" error
70+
71+
# Final check: Value should still be accessible
72+
self.assertIsNotNone(var.value)
73+
74+
@pytest.mark.skipif(not is_nnx_enabled(), reason="NNX not enabled")
75+
def test_nnx_variable_strong_references_and_logging(self):
76+
"""Test NnxVariable strong references and logging."""
77+
# Create NNX variable with sharding
78+
var = NnxVariable(jax.numpy.ones((50, 50)), layout=("model", None))
79+
80+
# Check strong references
81+
self.assertTrue(hasattr(var, "_shard_references"))
82+
self.assertGreater(len(var._shard_references), 0)
83+
84+
# Access value (simulates inference) and assert no deletion
85+
value = var.value
86+
self.assertIsNotNone(value) # Ensure no "Array has been deleted" error
87+
88+
# Additional accesses to simulate repeated inference
89+
for _ in range(5):
90+
value = var.value
91+
self.assertIsNotNone(value)
92+
93+
def test_variable_loading_with_sharding(self):
94+
"""Test variable loading with sharding support."""
95+
# Create a temporary file for the variable
96+
with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as f:
97+
temp_file = f.name
98+
np.save(temp_file, jax.numpy.ones((10, 10)))
99+
100+
try:
101+
# Create variable with sharding
102+
var = JaxVariable(jax.numpy.zeros((10, 10)))
103+
# Load data into it
104+
load_variable_with_sharded_support(var, np.load(temp_file))
105+
106+
# Verify it's a JaxVariable with sharding
107+
self.assertIsInstance(var, JaxVariable)
108+
self.assertTrue(hasattr(var, "_shard_references"))
109+
self.assertGreater(len(var._shard_references), 0)
110+
111+
# Access value to ensure no deletion
112+
self.assertIsNotNone(var.value)
113+
finally:
114+
os.unlink(temp_file)
115+
116+
def test_inference_simulation_no_array_deletion(self):
117+
"""Test inference simulation for no 'Array has been deleted' errors."""
118+
# Create a simple model with sharding
119+
inputs = layers.Input(shape=(10,))
120+
x = layers.Dense(50, name="dense")(inputs)
121+
model = models.Model(inputs, x)
122+
123+
# Build and access weights (triggers sharding and protection)
124+
model.build((None, 10))
125+
for var in model.weights:
126+
value = var.value # Access to trigger protection
127+
self.assertIsNotNone(value) # Ensure initial access succeeds
128+
129+
# Simulate inference (multiple accesses) and assert no deletion
130+
test_input = np.random.randn(1, 10)
131+
for _ in range(10):
132+
output = model(test_input)
133+
self.assertIsNotNone(
134+
output
135+
) # Ensure inference succeeds without errors
136+
137+
# Final check: Weights should still be accessible
138+
for var in model.weights:
139+
self.assertIsNotNone(var.value)
140+
141+
19142
@pytest.mark.skipif(
20143
backend.backend() != "jax",
21144
reason="JAX backend specific test for core Variable integration with NNX.",
@@ -25,8 +148,8 @@
25148
reason="Test requires NNX backend to be enabled by default for setup.",
26149
)
27150
class NnxVariableTest(testing.TestCase):
28-
def setup(self):
29-
super().setup()
151+
def setUp(self):
152+
super().setUp()
30153

31154
class NNXModel(nnx.Module):
32155
def __init__(self, rngs):

0 commit comments

Comments
 (0)