From 76515db8fa5eca3d20d136424586ad919bc17290 Mon Sep 17 00:00:00 2001 From: RubensZimbres Date: Wed, 30 Apr 2025 12:27:41 -0300 Subject: [PATCH 1/8] RoPEMultiHeadAttention implementation - all tests passed --- .gitignore | 1 + .../attention/rope_multi_head_attention.py | 194 ++++++++++++++++++ .../attention/test_RoPEMultiHeadAttention.py | 136 ++++++++++++ 3 files changed, 331 insertions(+) create mode 100644 jaxgarden/attention/rope_multi_head_attention.py create mode 100644 tests/attention/test_RoPEMultiHeadAttention.py diff --git a/.gitignore b/.gitignore index 6437e9e..fbe9ab1 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ wheels/ # Virtual environments .venv +jax_env # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/jaxgarden/attention/rope_multi_head_attention.py b/jaxgarden/attention/rope_multi_head_attention.py new file mode 100644 index 0000000..57bddd5 --- /dev/null +++ b/jaxgarden/attention/rope_multi_head_attention.py @@ -0,0 +1,194 @@ +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def rotate_half(x): + """Rotates half the hidden dims of the input tensor.""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + # Builds the rotated tensor by concatenating the negated second half + # and the first half along the last dimension. + return jnp.concatenate((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(x, cos_emb, sin_emb): + """Applies Rotary Positional Embedding to the input tensor. + + Args: + x: Input tensor, e.g., query or key (batch, seq_len, num_heads, head_dim) + cos_emb: Cosine component of the positional embedding. + Shape: (1, seq_len, 1, head_dim) or compatible via broadcasting. + sin_emb: Sine component of the positional embedding. + Shape: (1, seq_len, 1, head_dim) or compatible via broadcasting. + Returns: + Tensor with RoPE applied. + """ + # Applying the rotation formula: + # x_rotated = x * cos(theta) + rotate_half(x) * sin(theta) + # Ensure shapes are broadcastable: cos_emb and sin_emb should have dimensions + # for sequence length and features, matching the corresponding dimensions in x. + # Typically, precomputed embeddings have shape (seq_len, head_dim) + # or (1, seq_len, 1, head_dim) for easy broadcasting. + return (x * cos_emb) + (rotate_half(x) * sin_emb) + + +def precompute_rotary_embeddings(seq_len, head_dim, base=10000.0): + """Precomputes the RoPE cosine and sine embeddings. + + Args: + seq_len: The maximum sequence length. + head_dim: The dimension of each attention head (must be even). + base: The base value for the inverse frequency calculation. 
+ + Returns: + cos_emb: Cosine embeddings (1, seq_len, 1, head_dim) + sin_emb: Sine embeddings (1, seq_len, 1, head_dim) + """ + if head_dim % 2 != 0: + raise ValueError(f"head_dim must be even, got {head_dim}") + + # Calculate inverse frequencies (theta_i) + # theta_i = 1 / (base^(2*i / head_dim)) for i in [0, 1, ..., head_dim/2 - 1] + inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)) + + # Calculate position indices (m) + pos = jnp.arange(seq_len, dtype=jnp.float32) + + # Calculate angles (m * theta_i) + freqs = jnp.outer(pos, inv_freq) # Shape: (seq_len, head_dim / 2) + + # Duplicate frequencies for the full head dimension (for both elements in pairs) + emb = jnp.concatenate((freqs, freqs), axis=-1) # Shape: (seq_len, head_dim) + + # Calculate cosine and sine embeddings + cos_emb = jnp.cos(emb)[None, :, None, :] # Shape: (1, seq_len, 1, head_dim) + sin_emb = jnp.sin(emb)[None, :, None, :] # Shape: (1, seq_len, 1, head_dim) + + return cos_emb, sin_emb + + +class RoPEMultiHeadAttention(nn.Module): + """Multi-Head Attention with Rotary Positional Embeddings.""" + + num_heads: int + head_dim: int + rope_base: float = 10000.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # Check head_dim validity early during setup + if self.head_dim % 2 != 0: + raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.") + + # Define layers here - they will be initialized when the module is first called + total_head_dim = self.num_heads * self.head_dim + self.query_proj = nn.Dense( + features=total_head_dim, use_bias=False, dtype=self.dtype, name="query_proj" + ) + self.key_proj = nn.Dense( + features=total_head_dim, use_bias=False, dtype=self.dtype, name="key_proj" + ) + self.value_proj = nn.Dense( + features=total_head_dim, use_bias=False, dtype=self.dtype, name="value_proj" + ) + self.output_proj = nn.Dense( + features=self.num_heads * self.head_dim, # Output should match embed_dim + use_bias=False, + dtype=self.dtype, + name="output_proj" + ) + + + @nn.compact + def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None): + """Forward pass for RoPE MHA. + + Args: + x: Input tensor (batch_size, seq_len, embed_dim). + mask: Optional attention mask (batch_size, 1, seq_len, seq_len) + or (batch_size, 1, 1, seq_len) for causal masking. + Mask values should be 0 where attention is allowed, -inf otherwise. + Flax convention often uses boolean masks (True=masked). We'll handle both. + + Returns: + Output tensor (batch_size, seq_len, embed_dim). + """ + batch_size, seq_len, embed_dim = x.shape + total_head_dim = self.num_heads * self.head_dim + + if embed_dim != total_head_dim: + raise ValueError( + f"embed_dim ({embed_dim}) must equal num_heads*head_dim" + f" ({total_head_dim})" + ) + # Note: head_dim even check moved to setup for earlier failure + + # 1. Linear projections for Q, K, V + query = self.query_proj(x) + key = self.key_proj(x) + value = self.value_proj(x) + + # 2. Reshape for multi-head processing + # (batch, seq_len, embed_dim) -> (batch, seq_len, num_heads, head_dim) + query = query.reshape(batch_size, seq_len, self.num_heads, self.head_dim) + key = key.reshape(batch_size, seq_len, self.num_heads, self.head_dim) + value = value.reshape(batch_size, seq_len, self.num_heads, self.head_dim) + + # 3. 
Precompute RoPE embeddings (cosine and sine) + # We compute them dynamically based on the input sequence length + cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, self.head_dim, base=self.rope_base) + # Ensure RoPE embeddings have correct dtype + cos_emb = cos_emb.astype(self.dtype) + sin_emb = sin_emb.astype(self.dtype) + + + # 4. Apply RoPE to Query and Key + query = apply_rotary_pos_emb(query, cos_emb, sin_emb) + key = apply_rotary_pos_emb(key, cos_emb, sin_emb) + + # 5. Transpose for attention calculation: (batch, num_heads, seq_len, head_dim) + query = query.transpose((0, 2, 1, 3)) + key = key.transpose((0, 2, 1, 3)) + value = value.transpose((0, 2, 1, 3)) + + # 6. Scaled Dot-Product Attention + # Attention scores: (batch, num_heads, seq_len, seq_len) + attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt(self.head_dim) + + # Apply mask (if provided) + if mask is not None: + # Standard Flax causal mask is boolean (True means mask) + # nn.make_causal_mask returns (1, seq_len, seq_len) or (batch, 1, seq_len, seq_len) + # Check if mask needs broadcasting or conversion + if mask.ndim == 2: # Likely (seq_len, seq_len) + mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len) + elif mask.ndim == 3: # Likely (batch, seq_len, seq_len) or causal (1, sl, sl) + if mask.shape[0] != batch_size and mask.shape[0] == 1: + mask = mask[None, :, :, :] # Add head dim: (batch, 1, seq_len, seq_len) + else: # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len) + mask = mask[:, None, :, :] + + # Apply mask: Use large negative number where mask is True + # Assuming boolean mask convention (True = mask) common in Flax examples + attn_scores = jnp.where(mask, -jnp.inf, attn_scores) + + # Softmax to get attention weights + attn_weights = jax.nn.softmax( + attn_scores, axis=-1 + ) # Shape: (batch, num_heads, seq_len, seq_len) + + # Apply attention weights to Value + # Output per head: (batch, num_heads, seq_len, head_dim) + attn_output = jnp.matmul(attn_weights, value) + + # 7. 
Concatenate heads and final projection + # Transpose back: (batch, seq_len, num_heads, head_dim) + attn_output = attn_output.transpose((0, 2, 1, 3)) + # Reshape to (batch, seq_len, embed_dim) + attn_output = attn_output.reshape(batch_size, seq_len, total_head_dim) + + # Final linear projection + output = self.output_proj(attn_output) # Use self.output_proj defined in setup + + return output diff --git a/tests/attention/test_RoPEMultiHeadAttention.py b/tests/attention/test_RoPEMultiHeadAttention.py new file mode 100644 index 0000000..85841da --- /dev/null +++ b/tests/attention/test_RoPEMultiHeadAttention.py @@ -0,0 +1,136 @@ +"""Tests for the RoPEMultiHeadAttention class.""" + +import flax.linen as nn +import jax +import jax.numpy as jnp +import pytest + +from jaxgarden.attention.rope_multi_head_attention import ( + RoPEMultiHeadAttention, + apply_rotary_pos_emb, + precompute_rotary_embeddings, + rotate_half, +) + + +def test_rotate_half(): + """Tests the rotate_half function.""" + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (2, 4, 6, 8)) # batch, seq, heads, dim + rotated_x = rotate_half(x) + + assert rotated_x.shape == x.shape + # Check specific values after rotation + x1 = x[..., ::2] + x2 = x[..., 1::2] + expected = jnp.concatenate((-x2, x1), axis=-1) + assert jnp.allclose(rotated_x, expected) + + +def test_precompute_rotary_embeddings(): + """Tests the precompute_rotary_embeddings function.""" + seq_len = 16 + head_dim = 8 + base = 10000.0 + + cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, head_dim, base) + + assert cos_emb.shape == (1, seq_len, 1, head_dim) + assert sin_emb.shape == (1, seq_len, 1, head_dim) + + # Check properties - e.g., cos^2 + sin^2 = 1 + assert jnp.allclose(cos_emb**2 + sin_emb**2, jnp.ones_like(cos_emb), atol=1e-6) + + # Check different base value + cos_emb_b2, sin_emb_b2 = precompute_rotary_embeddings(seq_len, head_dim, base=500.0) + assert not jnp.allclose(cos_emb, cos_emb_b2) + + # Test with odd head_dim (should raise error) + with pytest.raises(ValueError, match="head_dim must be even"): + precompute_rotary_embeddings(seq_len, head_dim=7) + + +def test_apply_rotary_pos_emb(): + """Tests the apply_rotary_pos_emb function.""" + key = jax.random.PRNGKey(1) + batch, seq_len, num_heads, head_dim = 2, 16, 4, 8 + x = jax.random.normal(key, (batch, seq_len, num_heads, head_dim)) + + cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, head_dim) + + rotated_x = apply_rotary_pos_emb(x, cos_emb, sin_emb) + + assert rotated_x.shape == x.shape + # Applying RoPE again should not give the original x (unless pos=0, which isn't the whole seq) + assert not jnp.allclose(rotated_x, x) + + + +# --- Test RoPEMultiHeadAttention Module --- + +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_rope_mha_forward_pass(dtype): + """Tests the forward pass of RoPEMultiHeadAttention.""" + key = jax.random.PRNGKey(2) + batch_size = 2 + seq_len = 16 + num_heads = 4 + head_dim = 8 + embed_dim = num_heads * head_dim + + x = jax.random.normal(key, (batch_size, seq_len, embed_dim), dtype=dtype) + + rope_mha = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim, dtype=dtype) + params = rope_mha.init(key, x)["params"] + output = rope_mha.apply({"params": params}, x) + + assert output.shape == (batch_size, seq_len, embed_dim) + assert output.dtype == dtype + + +def test_rope_mha_masking(): + """Tests causal masking in RoPEMultiHeadAttention.""" + key = jax.random.PRNGKey(3) + batch_size = 1 + seq_len = 4 + num_heads = 2 + head_dim = 4 + 
embed_dim = num_heads * head_dim + + x = jax.random.normal(key, (batch_size, seq_len, embed_dim)) + # Create a causal mask (True means masked) + causal_mask = nn.make_causal_mask(x[:, :, 0]) # Gets (batch, seq, seq) or (1, seq, seq) + + rope_mha = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim) + params = rope_mha.init(key, x, causal_mask)["params"] + + # Apply without mask + output_unmasked = rope_mha.apply({"params": params}, x) + + # Apply with mask + output_masked = rope_mha.apply({"params": params}, x, mask=causal_mask) + + # Basic check: outputs should differ if mask has an effect + assert not jnp.allclose(output_unmasked, output_masked, atol=1e-5) + + # More rigorous check (requires inspecting attention weights, omitted for brevity) + + +def test_rope_mha_errors(): + """Tests error conditions for RoPEMultiHeadAttention.""" + key = jax.random.PRNGKey(4) + rope_mha_odd_dim = RoPEMultiHeadAttention(num_heads=8, head_dim=7) + x_dummy_odd = jax.random.normal(key, (2, 16, 8 * 7)) + # Test with odd head_dim (should raise error during initialization/setup) + with pytest.raises(ValueError, match=r"head_dim \(\d+\) must be even"): + + rope_mha_odd_dim.init(key, x_dummy_odd) + + + # Test with mismatched embed_dim (should raise error during forward pass / init) + rope_mha = RoPEMultiHeadAttention(num_heads=4, head_dim=8) # Expects embed_dim 32 + x_mismatch = jax.random.normal(key, (2, 16, 100)) # Incorrect embed_dim + + with pytest.raises(ValueError, match=r"embed_dim \(\d+\) must equal"): + rope_mha.init(key, x_mismatch) + From 0a64c34ed675f89533717087cbd7d57950d02f97 Mon Sep 17 00:00:00 2001 From: RubensZimbres Date: Wed, 30 Apr 2025 14:04:34 -0300 Subject: [PATCH 2/8] mypy fixed for class RoPEMultiHeadAttention --- .../attention/rope_multi_head_attention.py | 48 ++++++++++++++----- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/jaxgarden/attention/rope_multi_head_attention.py b/jaxgarden/attention/rope_multi_head_attention.py index 57bddd5..0ffe63c 100644 --- a/jaxgarden/attention/rope_multi_head_attention.py +++ b/jaxgarden/attention/rope_multi_head_attention.py @@ -1,9 +1,10 @@ + import flax.linen as nn import jax import jax.numpy as jnp -def rotate_half(x): +def rotate_half(x: jnp.ndarray) -> jnp.ndarray: """Rotates half the hidden dims of the input tensor.""" x1 = x[..., ::2] x2 = x[..., 1::2] @@ -12,7 +13,7 @@ def rotate_half(x): return jnp.concatenate((-x2, x1), axis=-1) -def apply_rotary_pos_emb(x, cos_emb, sin_emb): +def apply_rotary_pos_emb(x: jnp.ndarray, cos_emb: jnp.ndarray, sin_emb: jnp.ndarray) -> jnp.ndarray: """Applies Rotary Positional Embedding to the input tensor. Args: @@ -33,7 +34,8 @@ def apply_rotary_pos_emb(x, cos_emb, sin_emb): return (x * cos_emb) + (rotate_half(x) * sin_emb) -def precompute_rotary_embeddings(seq_len, head_dim, base=10000.0): +def precompute_rotary_embeddings(seq_len: int, head_dim: int, + base: float = 10000.0) -> tuple[jnp.ndarray, jnp.ndarray]: """Precomputes the RoPE cosine and sine embeddings. 
Args: @@ -76,7 +78,9 @@ class RoPEMultiHeadAttention(nn.Module): rope_base: float = 10000.0 dtype: jnp.dtype = jnp.float32 - def setup(self): + # Corrected method definition + def setup(self) -> None: # Added -> None return type + """Initializes the attention projections.""" # Check head_dim validity early during setup if self.head_dim % 2 != 0: raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.") @@ -101,7 +105,9 @@ def setup(self): @nn.compact - def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None): + # Corrected method definition (return type was missing) + # Also using Optional for the mask type hint for clarity with None default + def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray: """Forward pass for RoPE MHA. Args: @@ -154,7 +160,8 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None): # 6. Scaled Dot-Product Attention # Attention scores: (batch, num_heads, seq_len, seq_len) - attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt(self.head_dim) + attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt( + self.head_dim).astype(self.dtype) # Ensure sqrt is correct dtype # Apply mask (if provided) if mask is not None: @@ -163,20 +170,35 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None): # Check if mask needs broadcasting or conversion if mask.ndim == 2: # Likely (seq_len, seq_len) mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len) - elif mask.ndim == 3: # Likely (batch, seq_len, seq_len) or causal (1, sl, sl) - if mask.shape[0] != batch_size and mask.shape[0] == 1: - mask = mask[None, :, :, :] # Add head dim: (batch, 1, seq_len, seq_len) - else: # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len) - mask = mask[:, None, :, :] + elif mask.ndim == 3 and mask.shape[1] != self.num_heads: + # Likely (batch, seq_len, seq_len) or causal (1, sl, sl) + mask = mask[:, None, :, :] + # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len) + + # Ensure mask is broadcastable to attn_scores shape + mask_shape_expected = (batch_size, self.num_heads, seq_len, seq_len) + if mask.shape != mask_shape_expected: + # Attempt broadcasting common causal mask shapes + if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == (batch_size, 1, + seq_len, seq_len): # Causal mask for all batches/heads + mask = jnp.broadcast_to(mask, mask_shape_expected) + # Add other broadcasting cases if needed + else: + raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}") + # Apply mask: Use large negative number where mask is True + # (or where mask value is 0 if using 0/-inf convention) # Assuming boolean mask convention (True = mask) common in Flax examples - attn_scores = jnp.where(mask, -jnp.inf, attn_scores) + # If using 0/-inf mask, the logic would be: attn_scores = attn_scores + mask + attn_scores = jnp.where(mask, jnp.finfo(self.dtype).min, attn_scores) + # Softmax to get attention weights attn_weights = jax.nn.softmax( attn_scores, axis=-1 - ) # Shape: (batch, num_heads, seq_len, seq_len) + ).astype(self.dtype) # Shape: (batch, num_heads, seq_len, seq_len) + # Apply attention weights to Value # Output per head: (batch, num_heads, seq_len, head_dim) From 8fb9b0def6dca20f76af44acae6c6d8e86ee2e08 Mon Sep 17 00:00:00 2001 From: RubensZimbres Date: Wed, 30 Apr 2025 14:24:25 -0300 Subject: [PATCH 3/8] useless comments removed --- jaxgarden/attention/rope_multi_head_attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/jaxgarden/attention/rope_multi_head_attention.py b/jaxgarden/attention/rope_multi_head_attention.py index 0ffe63c..e1dd9d6 100644 --- a/jaxgarden/attention/rope_multi_head_attention.py +++ b/jaxgarden/attention/rope_multi_head_attention.py @@ -78,7 +78,6 @@ class RoPEMultiHeadAttention(nn.Module): rope_base: float = 10000.0 dtype: jnp.dtype = jnp.float32 - # Corrected method definition def setup(self) -> None: # Added -> None return type """Initializes the attention projections.""" # Check head_dim validity early during setup @@ -105,7 +104,6 @@ def setup(self) -> None: # Added -> None return type @nn.compact - # Corrected method definition (return type was missing) # Also using Optional for the mask type hint for clarity with None default def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray: """Forward pass for RoPE MHA. From 6bac808e87ad8f403d62cceeb413a14180e9637f Mon Sep 17 00:00:00 2001 From: RubensZimbres Date: Thu, 1 May 2025 13:31:08 -0300 Subject: [PATCH 4/8] README file updated with RoPEMultiHeadAttention usage --- README.md | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c0dadc3..de20bd2 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ model.from_hf(model_id) # but without the dependency of the whole `transformers` library. # Instead, we simply extend `tokenizers` package and add some cnvenience code for JAX. tokenizer = Tokenizer.from_pretrained(model_id) - + text = "The meaning of life is" model_inputs = tokenizer.encode(text) output = model.generate(**model_inputs, max_length=20, do_sample=True) @@ -97,6 +97,31 @@ mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len) output = attention(x, mask=mask) ``` +### RoPEMultiHeadAttention Module (Flax NNX) + +```python +import jax +import jax.numpy as jnp +import flax.linen as nn +from jaxgarden.attention.rope_multi_head_attention import RoPEMultiHeadAttention + +# 1. Setup +key = jax.random.PRNGKey(0) +batch_size, seq_len = 2, 16 +num_heads, head_dim = 4, 32 +embed_dim = num_heads * head_dim +x = jnp.ones((batch_size, seq_len, embed_dim)) + +# 2. Instantiate Module +attention = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim) + +# 3. Initialize Parameters +params = attention.init(key, x)['params'] + +# 4. Apply Module (Forward Pass) +output = attention.apply({'params': params}, x) +``` + ### Functional API #### Dot Product Attention with Implementation Selection From ad55578d5d93bd248794657761f9d56dd9bcaadd Mon Sep 17 00:00:00 2001 From: RubensZimbres Date: Thu, 1 May 2025 22:28:40 -0300 Subject: [PATCH 5/8] Docstring (arxiv) added to ROPE implementation --- jaxgarden/attention/rope_multi_head_attention.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/jaxgarden/attention/rope_multi_head_attention.py b/jaxgarden/attention/rope_multi_head_attention.py index e1dd9d6..7a0be73 100644 --- a/jaxgarden/attention/rope_multi_head_attention.py +++ b/jaxgarden/attention/rope_multi_head_attention.py @@ -1,3 +1,16 @@ +""" +JAX/Flax implementation of Multi-Head Attention with Rotary Positional Embedding (RoPE). + +This code implements the RoPE technique within a standard Multi-Head Attention +framework. RoPE injects relative positional information by rotating pairs of +features in the Query and Key vectors based on their absolute position before +the attention calculation. 
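+
+In the RoFormer formulation, each feature pair (x_{2i}, x_{2i+1}) at position m is
+rotated by the angle m * theta_i, where theta_i = base^(-2i / head_dim). Since the
+rotation matrices satisfy R_m^T R_n = R_{n-m}, the dot product of a rotated query at
+position m with a rotated key at position n depends only on the relative offset m - n.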
+ +The method was introduced in the paper: +"RoFormer: Enhanced Transformer with Rotary Position Embedding" +by Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, Yunfeng Liu. +arXiv:2104.09864v5 [cs.CL] (Submitted on 20 Apr 2021) +""" import flax.linen as nn import jax From 1d04a6063731c53190a1ad378c8c179bb69bfd00 Mon Sep 17 00:00:00 2001 From: RubensZimbres Date: Wed, 14 May 2025 09:55:37 -0300 Subject: [PATCH 6/8] My local work: updated readme, added MoE implementation and tests --- README.md | 33 +++++ jaxgarden/functional/mixture_of_experts.py | 120 ++++++++++++++++ tests/functional/test_MoE.py | 158 +++++++++++++++++++++ 3 files changed, 311 insertions(+) create mode 100644 jaxgarden/functional/mixture_of_experts.py create mode 100644 tests/functional/test_MoE.py diff --git a/README.md b/README.md index de20bd2..243e39a 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,39 @@ params = attention.init(key, x)['params'] output = attention.apply({'params': params}, x) ``` +### Mixture of Experts (Flax NNX) + +```python +import jax +import jax.numpy as jnp +import flax.linen as nn +from jaxgarden.functional.MoE import MixtureOfExperts + +# 1. Setup +batch_size = 4 +input_dim = 10 +num_experts = 3 +expert_output_dim = 5 +key = jax.random.PRNGKey(0) +dummy_input = jax.random.normal(key, (batch_size, input_dim)) + +# 2. Instantiate Module +moe_model = MixtureOfExperts(num_experts=num_experts, expert_output_dim=expert_output_dim) + +# 3. Initialize the model parameters (weights and biases) +key, params_key = jax.random.split(key) +params = moe_model.init(params_key, dummy_input)['params'] + +print("Initialized MoE parameters:", jax.tree_util.tree_map(lambda x: x.shape, params)) + +# 4. Apply Module (Forward Pass) +output = moe_model.apply({'params': params}, dummy_input) + +print("\nInput shape:", dummy_input.shape) +print("Output shape:", output.shape) +``` + + ### Functional API #### Dot Product Attention with Implementation Selection diff --git a/jaxgarden/functional/mixture_of_experts.py b/jaxgarden/functional/mixture_of_experts.py new file mode 100644 index 0000000..654ada6 --- /dev/null +++ b/jaxgarden/functional/mixture_of_experts.py @@ -0,0 +1,120 @@ +""" +JAX/Flax implementation of a Mixture of Experts (MoE) layer. + +This code provides a conceptual implementation of a Mixture of Experts layer, +a neural network architecture where multiple specialized "expert" sub-networks +are combined. A gating network determines which expert (or combination of +experts) processes a given input, allowing the model to learn to route +different parts of the input space to specialized modules. This can lead to +models with higher capacity and better efficiency, especially in sparse +formulations where only a subset of experts are activated per input. + +The core concept of Mixture of Experts was introduced in the paper: +"Adaptive Mixtures of Local Experts" +by Robert A. Jacobs, Michael I. Jordan, Steven J. Nowlan, and Geoffrey E. Hinton. +Published in Neural Computation, Volume 3, Issue 1, Pages 79-87, 1991. +Available at: https://www.cs.toronto.edu/~hinton/absps/jjnh91.pdf +""" + +import flax.linen as nn +import jax.numpy as jnp + + +# Expert Network +class Expert(nn.Module): + num_outputs: int + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """ + A simple feed-forward expert network. + Args: + x: Input tensor. + Returns: + Output tensor from the expert. 
+ """ + x = nn.Dense(features=self.num_outputs * 2, name="expert_dense_1")(x) + x = nn.relu(x) + x = nn.Dense(features=self.num_outputs, name="expert_dense_2")(x) + return x + +# Gating Network +class GatingNetwork(nn.Module): + num_experts: int # The number of experts to choose from + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """ + A simple gating network that outputs weights for each expert. + Args: + x: Input tensor. + Returns: + A tensor of weights for each expert (after softmax). + """ + # The gating network is often a linear layer followed by a softmax + gate_logits = nn.Dense(features=self.num_experts, name="gating_dense")(x) + gate_weights = nn.softmax(gate_logits, axis=-1) + return gate_weights + +# Mixture of Experts Layer +class MixtureOfExperts(nn.Module): + num_experts: int + expert_output_dim: int + + # You could also type self.experts and self.gating_network here for completeness, + # though mypy might not require it if they are only set in setup. + # experts: List[Expert] + # gating_network: GatingNetwork + + def setup(self) -> None: + """ + Initialize the experts and the gating network. + This method is called by Flax automatically. + """ + # List of Expert modules + # Using nn.scan or vmap can be more efficient for identical experts, + self.experts = [Expert(num_outputs=self.expert_output_dim, + name=f"expert_{i}") for i in range(self.num_experts)] + # Create the GatingNetwork module + self.gating_network = GatingNetwork(num_experts=self.num_experts, name="gating_network") + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """ + Forward pass for the Mixture of Experts layer. + Args: + x: Input tensor. + Returns: + The combined output from the experts. + """ + # 1. Get the gating weights + # Input shape: (batch_size, input_dim) + # Gate weights shape: (batch_size, num_experts) + gate_weights = self.gating_network(x) + + # 2. Get the outputs from all experts + # We'll store expert outputs in a list + expert_outputs = [] + for i in range(self.num_experts): + # Expert output shape: (batch_size, expert_output_dim) + expert_out = self.experts[i](x) + expert_outputs.append(expert_out) + + # Stack expert outputs along a new dimension to facilitate weighted sum + # Stacked expert_outputs shape: (batch_size, num_experts, expert_output_dim) + stacked_expert_outputs = jnp.stack(expert_outputs, axis=1) + + # 3. Combine expert outputs using the gating weights + # We want to weight each expert's output for each item in the batch. + # Gate weights shape: (batch_size, num_experts) + # Needs to be broadcast to: (batch_size, num_experts, expert_output_dim) + # to multiply with stacked_expert_outputs. 
+ expanded_gate_weights = jnp.expand_dims(gate_weights, axis=-1) + + # Weighted outputs shape: (batch_size, num_experts, expert_output_dim) + weighted_expert_outputs = stacked_expert_outputs * expanded_gate_weights + + # Sum the weighted outputs along the num_experts dimension + # Final output shape: (batch_size, expert_output_dim) + final_output = jnp.sum(weighted_expert_outputs, axis=1) + + return final_output diff --git a/tests/functional/test_MoE.py b/tests/functional/test_MoE.py new file mode 100644 index 0000000..ae93eb5 --- /dev/null +++ b/tests/functional/test_MoE.py @@ -0,0 +1,158 @@ +import jax +import jax.numpy as jnp +import pytest + +from jaxgarden.functional.mixture_of_experts import Expert, GatingNetwork, MixtureOfExperts + + +@pytest.fixture +def key(): + """Provides a JAX PRNG key for tests.""" + return jax.random.PRNGKey(0) + +@pytest.fixture +def dummy_input_batch(): + """Provides a dummy input batch.""" + # (batch_size, input_dim) + return jnp.ones((4, 16)) # Batch of 4, input dimension 16 + +def test_expert_initialization_and_call(key, dummy_input_batch): + """Tests the Expert module initialization and forward pass.""" + expert_output_dim = 8 + expert_model = Expert(num_outputs=expert_output_dim) + + params = expert_model.init(key, dummy_input_batch)['params'] + output = expert_model.apply({'params': params}, dummy_input_batch) + + assert output.shape == (dummy_input_batch.shape[0], expert_output_dim) + assert output.dtype == jnp.float32 # Default dtype for Dense + +def test_gating_network_initialization_and_call(key, dummy_input_batch): + """Tests the GatingNetwork module initialization and forward pass.""" + num_experts = 3 + gating_model = GatingNetwork(num_experts=num_experts) + + params = gating_model.init(key, dummy_input_batch)['params'] + gate_weights = gating_model.apply({'params': params}, dummy_input_batch) + + assert gate_weights.shape == (dummy_input_batch.shape[0], num_experts) + # Check if softmax output sums to 1 for each item in the batch + assert jnp.allclose(jnp.sum(gate_weights, axis=-1), + jnp.ones(dummy_input_batch.shape[0]), atol=1e-6) + assert jnp.all(gate_weights >= 0) # Probabilities should be non-negative + +def test_mixture_of_experts_initialization_and_call(key, dummy_input_batch): + """Tests the MixtureOfExperts module initialization and forward pass.""" + num_experts = 5 + expert_output_dim = 10 + moe_model = MixtureOfExperts(num_experts=num_experts, expert_output_dim=expert_output_dim) + + params = moe_model.init(key, dummy_input_batch)['params'] + final_output = moe_model.apply({'params': params}, dummy_input_batch) + + assert final_output.shape == (dummy_input_batch.shape[0], expert_output_dim) + assert final_output.dtype == jnp.float32 + +def test_mixture_of_experts_output_logic(key, dummy_input_batch): + """ + Tests the output logic of MoE by checking if expert outputs are combined + as expected based on gate weights. 
+ """ + num_experts = 2 + expert_output_dim = 3 + # input_dim = dummy_input_batch.shape[1] # Not strictly needed for this test logic after init + + # Create an MoE model + moe_model = MixtureOfExperts(num_experts=num_experts, expert_output_dim=expert_output_dim) + variables = moe_model.init(key, dummy_input_batch) + params = variables['params'] + + # --- Manually compute expected output for a specific case --- + + # Get gate weights by applying a GatingNetwork instance with its specific parameters + gating_sub_model = GatingNetwork(num_experts=num_experts) + gate_weights = gating_sub_model.apply({'params': params['gating_network']}, dummy_input_batch) + + # Get individual expert outputs + expert_sub_model_template = Expert(num_outputs=expert_output_dim) + + expert0_output = expert_sub_model_template.apply({'params': + params['expert_0']}, dummy_input_batch) + expert1_output = expert_sub_model_template.apply({'params': + params['expert_1']}, dummy_input_batch) + + # Expected combination + # gate_weights is (batch_size, num_experts) + # expertN_output is (batch_size, expert_output_dim) + # We use slicing and broadcasting for the weighted sum. + expected_output_manual = \ + gate_weights[:, 0:1] * expert0_output + \ + gate_weights[:, 1:2] * expert1_output + + moe_output = moe_model.apply({'params': params}, dummy_input_batch) + + assert moe_output.shape == (dummy_input_batch.shape[0], expert_output_dim) + assert jnp.allclose(moe_output, expected_output_manual, atol=1e-6) + +def test_mixture_of_experts_single_expert_case(key, dummy_input_batch): + """Tests MoE with only one expert.""" + num_experts = 1 + expert_output_dim = 7 + moe_model = MixtureOfExperts(num_experts=num_experts, expert_output_dim=expert_output_dim) + + params = moe_model.init(key, dummy_input_batch)['params'] + final_output = moe_model.apply({'params': params}, dummy_input_batch) + + # The output should be identical to the output of the single expert + # as gate_weights will be [[1.], [1.], ...] 
+ + # Instantiate an Expert model to apply its specific parameters + expert_sub_model = Expert(num_outputs=expert_output_dim) + expert_output = expert_sub_model.apply({'params': params['expert_0']}, dummy_input_batch) + + assert final_output.shape == (dummy_input_batch.shape[0], expert_output_dim) + assert jnp.allclose(final_output, expert_output, atol=1e-6) + + # Check gate weights for the single expert case + # Instantiate a GatingNetwork model to apply its specific parameters + gating_sub_model = GatingNetwork(num_experts=num_experts) + gate_weights = gating_sub_model.apply({'params': params['gating_network']}, dummy_input_batch) + + assert gate_weights.shape == (dummy_input_batch.shape[0], 1) + assert jnp.allclose(gate_weights, jnp.ones_like(gate_weights), atol=1e-6) + + +def test_expert_different_output_dims(key): + """Tests Expert with varying output dimensions.""" + input_data = jnp.ones((2, 5)) # batch=2, features=5 + for out_dim in [1, 5, 20]: + expert_model = Expert(num_outputs=out_dim) + params = expert_model.init(key, input_data)['params'] + output = expert_model.apply({'params': params}, input_data) + assert output.shape == (input_data.shape[0], out_dim) + +def test_gating_network_different_num_experts(key): + """Tests GatingNetwork with varying number of experts.""" + input_data = jnp.ones((3, 8)) # batch=3, features=8 + for num_exp in [1, 4, 10]: + gating_model = GatingNetwork(num_experts=num_exp) + params = gating_model.init(key, input_data)['params'] + gate_weights = gating_model.apply({'params': params}, input_data) + assert gate_weights.shape == (input_data.shape[0], num_exp) + assert jnp.allclose(jnp.sum(gate_weights, axis=-1), + jnp.ones(input_data.shape[0]), atol=1e-6) + +def test_mixture_of_experts_different_params(key): + """Tests MixtureOfExperts with varying numbers of experts and output dimensions.""" + input_data = jnp.ones((2, 12)) # batch=2, features=12 + configurations = [ + (2, 4), # num_experts, expert_output_dim + (4, 8), + (1, 6), + (3, 3) + ] + for num_exp, exp_out_dim in configurations: + moe_model = MixtureOfExperts(num_experts=num_exp, expert_output_dim=exp_out_dim) + params = moe_model.init(key, input_data)['params'] + final_output = moe_model.apply({'params': params}, input_data) + assert final_output.shape == (input_data.shape[0], exp_out_dim) From aebe9e41101b102f16633283e58bfee2ec802ce4 Mon Sep 17 00:00:00 2001 From: RubensZimbres Date: Wed, 14 May 2025 10:04:47 -0300 Subject: [PATCH 7/8] Mixture of Experts implemnatation: all tests passed, pytest, ruff and mypy passed --- .../attention/rope_multi_head_attention.py | 58 +++++++++---------- .../attention/test_RoPEMultiHeadAttention.py | 11 ++-- 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/jaxgarden/attention/rope_multi_head_attention.py b/jaxgarden/attention/rope_multi_head_attention.py index 7a0be73..14ecb1c 100644 --- a/jaxgarden/attention/rope_multi_head_attention.py +++ b/jaxgarden/attention/rope_multi_head_attention.py @@ -47,8 +47,9 @@ def apply_rotary_pos_emb(x: jnp.ndarray, cos_emb: jnp.ndarray, sin_emb: jnp.ndar return (x * cos_emb) + (rotate_half(x) * sin_emb) -def precompute_rotary_embeddings(seq_len: int, head_dim: int, - base: float = 10000.0) -> tuple[jnp.ndarray, jnp.ndarray]: +def precompute_rotary_embeddings( + seq_len: int, head_dim: int, base: float = 10000.0 +) -> tuple[jnp.ndarray, jnp.ndarray]: """Precomputes the RoPE cosine and sine embeddings. 
Args: @@ -91,11 +92,11 @@ class RoPEMultiHeadAttention(nn.Module): rope_base: float = 10000.0 dtype: jnp.dtype = jnp.float32 - def setup(self) -> None: # Added -> None return type + def setup(self) -> None: # Added -> None return type """Initializes the attention projections.""" # Check head_dim validity early during setup if self.head_dim % 2 != 0: - raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.") + raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.") # Define layers here - they will be initialized when the module is first called total_head_dim = self.num_heads * self.head_dim @@ -109,13 +110,12 @@ def setup(self) -> None: # Added -> None return type features=total_head_dim, use_bias=False, dtype=self.dtype, name="value_proj" ) self.output_proj = nn.Dense( - features=self.num_heads * self.head_dim, # Output should match embed_dim + features=self.num_heads * self.head_dim, # Output should match embed_dim use_bias=False, dtype=self.dtype, - name="output_proj" + name="output_proj", ) - @nn.compact # Also using Optional for the mask type hint for clarity with None default def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray: @@ -136,8 +136,7 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr if embed_dim != total_head_dim: raise ValueError( - f"embed_dim ({embed_dim}) must equal num_heads*head_dim" - f" ({total_head_dim})" + f"embed_dim ({embed_dim}) must equal num_heads*head_dim ({total_head_dim})" ) # Note: head_dim even check moved to setup for earlier failure @@ -159,7 +158,6 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr cos_emb = cos_emb.astype(self.dtype) sin_emb = sin_emb.astype(self.dtype) - # 4. Apply RoPE to Query and Key query = apply_rotary_pos_emb(query, cos_emb, sin_emb) key = apply_rotary_pos_emb(key, cos_emb, sin_emb) @@ -172,31 +170,35 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr # 6. 
Scaled Dot-Product Attention # Attention scores: (batch, num_heads, seq_len, seq_len) attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt( - self.head_dim).astype(self.dtype) # Ensure sqrt is correct dtype + self.head_dim + ).astype(self.dtype) # Ensure sqrt is correct dtype # Apply mask (if provided) if mask is not None: # Standard Flax causal mask is boolean (True means mask) # nn.make_causal_mask returns (1, seq_len, seq_len) or (batch, 1, seq_len, seq_len) # Check if mask needs broadcasting or conversion - if mask.ndim == 2: # Likely (seq_len, seq_len) - mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len) + if mask.ndim == 2: # Likely (seq_len, seq_len) + mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len) elif mask.ndim == 3 and mask.shape[1] != self.num_heads: - # Likely (batch, seq_len, seq_len) or causal (1, sl, sl) + # Likely (batch, seq_len, seq_len) or causal (1, sl, sl) mask = mask[:, None, :, :] - # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len) + # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len) # Ensure mask is broadcastable to attn_scores shape mask_shape_expected = (batch_size, self.num_heads, seq_len, seq_len) if mask.shape != mask_shape_expected: - # Attempt broadcasting common causal mask shapes - if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == (batch_size, 1, - seq_len, seq_len): # Causal mask for all batches/heads - mask = jnp.broadcast_to(mask, mask_shape_expected) - # Add other broadcasting cases if needed - else: - raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}") - + # Attempt broadcasting common causal mask shapes + if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == ( + batch_size, + 1, + seq_len, + seq_len, + ): # Causal mask for all batches/heads + mask = jnp.broadcast_to(mask, mask_shape_expected) + # Add other broadcasting cases if needed + else: + raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}") # Apply mask: Use large negative number where mask is True # (or where mask value is 0 if using 0/-inf convention) @@ -204,12 +206,10 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr # If using 0/-inf mask, the logic would be: attn_scores = attn_scores + mask attn_scores = jnp.where(mask, jnp.finfo(self.dtype).min, attn_scores) - # Softmax to get attention weights - attn_weights = jax.nn.softmax( - attn_scores, axis=-1 - ).astype(self.dtype) # Shape: (batch, num_heads, seq_len, seq_len) - + attn_weights = jax.nn.softmax(attn_scores, axis=-1).astype( + self.dtype + ) # Shape: (batch, num_heads, seq_len, seq_len) # Apply attention weights to Value # Output per head: (batch, num_heads, seq_len, head_dim) @@ -222,6 +222,6 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr attn_output = attn_output.reshape(batch_size, seq_len, total_head_dim) # Final linear projection - output = self.output_proj(attn_output) # Use self.output_proj defined in setup + output = self.output_proj(attn_output) # Use self.output_proj defined in setup return output diff --git a/tests/attention/test_RoPEMultiHeadAttention.py b/tests/attention/test_RoPEMultiHeadAttention.py index 85841da..0bca47a 100644 --- a/tests/attention/test_RoPEMultiHeadAttention.py +++ b/tests/attention/test_RoPEMultiHeadAttention.py @@ -65,9 +65,9 @@ def test_apply_rotary_pos_emb(): assert not jnp.allclose(rotated_x, x) - # --- Test RoPEMultiHeadAttention Module --- + @pytest.mark.parametrize("dtype", 
[jnp.float32, jnp.bfloat16]) def test_rope_mha_forward_pass(dtype): """Tests the forward pass of RoPEMultiHeadAttention.""" @@ -99,7 +99,7 @@ def test_rope_mha_masking(): x = jax.random.normal(key, (batch_size, seq_len, embed_dim)) # Create a causal mask (True means masked) - causal_mask = nn.make_causal_mask(x[:, :, 0]) # Gets (batch, seq, seq) or (1, seq, seq) + causal_mask = nn.make_causal_mask(x[:, :, 0]) # Gets (batch, seq, seq) or (1, seq, seq) rope_mha = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim) params = rope_mha.init(key, x, causal_mask)["params"] @@ -123,14 +123,11 @@ def test_rope_mha_errors(): x_dummy_odd = jax.random.normal(key, (2, 16, 8 * 7)) # Test with odd head_dim (should raise error during initialization/setup) with pytest.raises(ValueError, match=r"head_dim \(\d+\) must be even"): - rope_mha_odd_dim.init(key, x_dummy_odd) - # Test with mismatched embed_dim (should raise error during forward pass / init) - rope_mha = RoPEMultiHeadAttention(num_heads=4, head_dim=8) # Expects embed_dim 32 - x_mismatch = jax.random.normal(key, (2, 16, 100)) # Incorrect embed_dim + rope_mha = RoPEMultiHeadAttention(num_heads=4, head_dim=8) # Expects embed_dim 32 + x_mismatch = jax.random.normal(key, (2, 16, 100)) # Incorrect embed_dim with pytest.raises(ValueError, match=r"embed_dim \(\d+\) must equal"): rope_mha.init(key, x_mismatch) - From ab7bd61e2c4835f10470a7e08f95a158a4d32385 Mon Sep 17 00:00:00 2001 From: RubensZimbres Date: Wed, 14 May 2025 12:35:45 -0300 Subject: [PATCH 8/8] comments improved MoE --- jaxgarden/functional/mixture_of_experts.py | 7 ------- tests/functional/test_MoE.py | 17 +++++++++-------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/jaxgarden/functional/mixture_of_experts.py b/jaxgarden/functional/mixture_of_experts.py index 654ada6..423edde 100644 --- a/jaxgarden/functional/mixture_of_experts.py +++ b/jaxgarden/functional/mixture_of_experts.py @@ -61,18 +61,12 @@ class MixtureOfExperts(nn.Module): num_experts: int expert_output_dim: int - # You could also type self.experts and self.gating_network here for completeness, - # though mypy might not require it if they are only set in setup. - # experts: List[Expert] - # gating_network: GatingNetwork - def setup(self) -> None: """ Initialize the experts and the gating network. This method is called by Flax automatically. """ # List of Expert modules - # Using nn.scan or vmap can be more efficient for identical experts, self.experts = [Expert(num_outputs=self.expert_output_dim, name=f"expert_{i}") for i in range(self.num_experts)] # Create the GatingNetwork module @@ -103,7 +97,6 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # Stacked expert_outputs shape: (batch_size, num_experts, expert_output_dim) stacked_expert_outputs = jnp.stack(expert_outputs, axis=1) - # 3. Combine expert outputs using the gating weights # We want to weight each expert's output for each item in the batch. 
# Gate weights shape: (batch_size, num_experts) # Needs to be broadcast to: (batch_size, num_experts, expert_output_dim) diff --git a/tests/functional/test_MoE.py b/tests/functional/test_MoE.py index ae93eb5..656e8ee 100644 --- a/tests/functional/test_MoE.py +++ b/tests/functional/test_MoE.py @@ -14,7 +14,7 @@ def key(): def dummy_input_batch(): """Provides a dummy input batch.""" # (batch_size, input_dim) - return jnp.ones((4, 16)) # Batch of 4, input dimension 16 + return jnp.ones((4, 16)) def test_expert_initialization_and_call(key, dummy_input_batch): """Tests the Expert module initialization and forward pass.""" @@ -25,7 +25,7 @@ def test_expert_initialization_and_call(key, dummy_input_batch): output = expert_model.apply({'params': params}, dummy_input_batch) assert output.shape == (dummy_input_batch.shape[0], expert_output_dim) - assert output.dtype == jnp.float32 # Default dtype for Dense + assert output.dtype == jnp.float32 def test_gating_network_initialization_and_call(key, dummy_input_batch): """Tests the GatingNetwork module initialization and forward pass.""" @@ -60,15 +60,13 @@ def test_mixture_of_experts_output_logic(key, dummy_input_batch): """ num_experts = 2 expert_output_dim = 3 - # input_dim = dummy_input_batch.shape[1] # Not strictly needed for this test logic after init + # input_dim = dummy_input_batch.shape[1] # Create an MoE model moe_model = MixtureOfExperts(num_experts=num_experts, expert_output_dim=expert_output_dim) variables = moe_model.init(key, dummy_input_batch) params = variables['params'] - # --- Manually compute expected output for a specific case --- - # Get gate weights by applying a GatingNetwork instance with its specific parameters gating_sub_model = GatingNetwork(num_experts=num_experts) gate_weights = gating_sub_model.apply({'params': params['gating_network']}, dummy_input_batch) @@ -133,7 +131,8 @@ def test_expert_different_output_dims(key): def test_gating_network_different_num_experts(key): """Tests GatingNetwork with varying number of experts.""" - input_data = jnp.ones((3, 8)) # batch=3, features=8 + # (batch,features) + input_data = jnp.ones((3, 8)) for num_exp in [1, 4, 10]: gating_model = GatingNetwork(num_experts=num_exp) params = gating_model.init(key, input_data)['params'] @@ -144,9 +143,11 @@ def test_gating_network_different_num_experts(key): def test_mixture_of_experts_different_params(key): """Tests MixtureOfExperts with varying numbers of experts and output dimensions.""" - input_data = jnp.ones((2, 12)) # batch=2, features=12 + # batch=2, features=12 + input_data = jnp.ones((2, 12)) + # num_experts, expert_output_dim configurations = [ - (2, 4), # num_experts, expert_output_dim + (2, 4), (4, 8), (1, 6), (3, 3)