|
| 1 | +"""Test for core.py.""" |
| 2 | + |
1 | 3 | import os
|
2 | 4 |
|
| 5 | +os.environ["KERAS_BACKEND"] = "jax" |
| 6 | +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" |
| 7 | + |
| 8 | +import tempfile |
| 9 | + |
3 | 10 | import jax
|
4 | 11 | import jax.numpy as jnp
|
5 | 12 | import numpy as np
|
6 | 13 | import pytest
|
7 | 14 |
|
8 | 15 | import keras
|
9 | 16 | from keras.src import backend
|
| 17 | +from keras.src import layers |
| 18 | +from keras.src import models |
10 | 19 | from keras.src import testing
|
11 | 20 | from keras.src.backend.config import is_nnx_enabled
|
| 21 | +from keras.src.backend.jax.core import JaxVariable |
| 22 | +from keras.src.backend.jax.core import _ProtectedShardedArray |
| 23 | + |
| 24 | +if is_nnx_enabled(): |
| 25 | + from keras.src.backend.jax.core import NnxVariable |
| 26 | +from keras.src.utils.variable_loading import load_variable_with_sharded_support |
12 | 27 |
|
13 | 28 | if is_nnx_enabled():
|
14 | 29 | from flax import nnx
|
15 | 30 |
|
16 | 31 | from keras.src.backend.jax.core import NnxVariable
|
17 | 32 |
|
18 | 33 |
|
class JaxCoreTest(testing.TestCase):
    """Tests for sharding protection in the JAX backend core.

    Exercises `_ProtectedShardedArray` and the shard-reference bookkeeping
    on `JaxVariable`/`NnxVariable`, which guard against
    "Array has been deleted" errors during repeated inference.
    """

    def _make_sharded_array(self, shape):
        """Return an array genuinely sharded across all host devices.

        The `XLA_FLAGS` set at the top of this file force 2 host CPU
        devices, so the resulting array has multiple addressable shards.
        """
        mesh = jax.sharding.Mesh(np.array(jax.devices()), ("x",))
        sharding = jax.sharding.NamedSharding(
            mesh, jax.sharding.PartitionSpec("x")
        )
        return jax.device_put(jnp.ones(shape), sharding)

    def test_protected_sharded_array_deletion(self):
        """Test _ProtectedShardedArray prevents deletion of sharded arrays."""
        # NOTE: `jax.Array` objects are immutable and `addressable_shards`
        # is a read-only property, so we construct a genuinely sharded
        # array instead of assigning fake shards onto a single-device
        # array (the previous attribute assignment raised AttributeError).
        sharded_array = self._make_sharded_array((10, 10))

        protected = _ProtectedShardedArray(sharded_array)

        # Attempt deletion (should be a no-op for sharded arrays).
        protected.delete()

        # The wrapped array must still be alive and flagged as sharded.
        self.assertIs(protected._array, sharded_array)
        self.assertTrue(
            hasattr(protected, "_is_sharded") and protected._is_sharded
        )

    def test_jax_variable_strong_references_and_logging(self):
        """Test JaxVariable strong references and logging."""
        var = JaxVariable(jnp.ones((100, 100)))

        # The variable must hold at least one strong shard reference.
        self.assertTrue(hasattr(var, "_shard_references"))
        self.assertGreater(len(var._shard_references), 0)

        # Repeated value access simulates inference; none of these may
        # trigger an "Array has been deleted" error.
        for _ in range(5):
            self.assertIsNotNone(var.value)

        # Final check: the value is still accessible.
        self.assertIsNotNone(var.value)

    @pytest.mark.skipif(not is_nnx_enabled(), reason="NNX not enabled")
    def test_nnx_variable_strong_references_and_logging(self):
        """Test NnxVariable strong references and logging."""
        var = NnxVariable(jnp.ones((50, 50)), layout=("model", None))

        # The variable must hold at least one strong shard reference.
        self.assertTrue(hasattr(var, "_shard_references"))
        self.assertGreater(len(var._shard_references), 0)

        # First access (first inference step) must succeed.
        self.assertIsNotNone(var.value)

        # Additional accesses simulate repeated inference steps.
        for _ in range(5):
            self.assertIsNotNone(var.value)

    def test_variable_loading_with_sharding(self):
        """Test variable loading with sharding support."""
        # Persist a reference value to a temporary .npy file.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as f:
            temp_file = f.name
        np.save(temp_file, jnp.ones((10, 10)))

        try:
            # Create a variable, then load the saved data through the
            # sharded-loading helper.
            var = JaxVariable(jnp.zeros((10, 10)))
            load_variable_with_sharded_support(var, np.load(temp_file))

            # Verify it remains a JaxVariable with live shard references.
            self.assertIsInstance(var, JaxVariable)
            self.assertTrue(hasattr(var, "_shard_references"))
            self.assertGreater(len(var._shard_references), 0)

            # Value access must not hit a deleted array.
            self.assertIsNotNone(var.value)
        finally:
            os.unlink(temp_file)

    def test_inference_simulation_no_array_deletion(self):
        """Test inference simulation for no 'Array has been deleted' errors."""
        # Build a small functional model.
        inputs = layers.Input(shape=(10,))
        outputs = layers.Dense(50, name="dense")(inputs)
        model = models.Model(inputs, outputs)

        # Build and touch every weight once to trigger shard protection.
        model.build((None, 10))
        for var in model.weights:
            self.assertIsNotNone(var.value)

        # Repeated forward passes simulate inference traffic; each call
        # must succeed without a deleted-array error.
        test_input = np.random.randn(1, 10)
        for _ in range(10):
            self.assertIsNotNone(model(test_input))

        # Weights remain accessible after inference.
        for var in model.weights:
            self.assertIsNotNone(var.value)
| 141 | + |
19 | 142 | @pytest.mark.skipif(
|
20 | 143 | backend.backend() != "jax",
|
21 | 144 | reason="JAX backend specific test for core Variable integration with NNX.",
|
|
25 | 148 | reason="Test requires NNX backend to be enabled by default for setup.",
|
26 | 149 | )
|
27 | 150 | class NnxVariableTest(testing.TestCase):
|
28 |
| - def setup(self): |
29 |
| - super().setup() |
| 151 | + def setUp(self): |
| 152 | + super().setUp() |
30 | 153 |
|
31 | 154 | class NNXModel(nnx.Module):
|
32 | 155 | def __init__(self, rngs):
|
|
0 commit comments