33 changes: 33 additions & 0 deletions README.md
@@ -122,6 +122,39 @@ params = attention.init(key, x)['params']
output = attention.apply({'params': params}, x)
```

### Mixture of Experts (Flax Linen)

```python
import jax

from jaxgarden.functional.mixture_of_experts import MixtureOfExperts

# 1. Setup
batch_size = 4
input_dim = 10
num_experts = 3
expert_output_dim = 5
key = jax.random.PRNGKey(0)
dummy_input = jax.random.normal(key, (batch_size, input_dim))

# 2. Instantiate Module
moe_model = MixtureOfExperts(num_experts=num_experts, expert_output_dim=expert_output_dim)

# 3. Initialize the model parameters (weights and biases)
key, params_key = jax.random.split(key)
params = moe_model.init(params_key, dummy_input)['params']

print("Initialized MoE parameters:", jax.tree_util.tree_map(lambda x: x.shape, params))

# 4. Apply Module (Forward Pass)
output = moe_model.apply({'params': params}, dummy_input)

print("\nInput shape:", dummy_input.shape)
print("Output shape:", output.shape)
```
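
If the layer is called repeatedly (for example inside a training step), the forward pass can be JIT-compiled as usual; a minimal sketch reusing the variables from the snippet above:

```python
# Optional: JIT-compile the forward pass (reuses moe_model, params, dummy_input from above).
jitted_apply = jax.jit(moe_model.apply)
jitted_output = jitted_apply({'params': params}, dummy_input)
assert jitted_output.shape == (batch_size, expert_output_dim)  # (4, 5)
```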


### Functional API

#### Dot Product Attention with Implementation Selection
35 changes: 16 additions & 19 deletions jaxgarden/attention/rope_multi_head_attention.py
@@ -47,8 +47,9 @@ def apply_rotary_pos_emb(x: jnp.ndarray, cos_emb: jnp.ndarray, sin_emb: jnp.ndar
return (x * cos_emb) + (rotate_half(x) * sin_emb)


def precompute_rotary_embeddings(
seq_len: int, head_dim: int, base: float = 10000.0
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Precomputes the RoPE cosine and sine embeddings.

Args:
@@ -91,11 +92,11 @@ class RoPEMultiHeadAttention(nn.Module):
rope_base: float = 10000.0
dtype: jnp.dtype = jnp.float32

def setup(self) -> None: # Added -> None return type
"""Initializes the attention projections."""
# Check head_dim validity early during setup
if self.head_dim % 2 != 0:
raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.")
raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.")

# Define layers here - they will be initialized when the module is first called
total_head_dim = self.num_heads * self.head_dim
@@ -109,13 +110,12 @@ def setup(self) -> None: # Added -> None return type
features=total_head_dim, use_bias=False, dtype=self.dtype, name="value_proj"
)
self.output_proj = nn.Dense(
features=self.num_heads * self.head_dim, # Output should match embed_dim
use_bias=False,
dtype=self.dtype,
name="output_proj"
name="output_proj",
)


@nn.compact
# Also using Optional for the mask type hint for clarity with None default
def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray:
@@ -136,8 +136,7 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr

if embed_dim != total_head_dim:
raise ValueError(
f"embed_dim ({embed_dim}) must equal num_heads*head_dim"
f" ({total_head_dim})"
f"embed_dim ({embed_dim}) must equal num_heads*head_dim ({total_head_dim})"
)
# Note: head_dim even check moved to setup for earlier failure

@@ -159,7 +158,6 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
cos_emb = cos_emb.astype(self.dtype)
sin_emb = sin_emb.astype(self.dtype)


# 4. Apply RoPE to Query and Key
query = apply_rotary_pos_emb(query, cos_emb, sin_emb)
key = apply_rotary_pos_emb(key, cos_emb, sin_emb)
@@ -179,12 +177,12 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
# Standard Flax causal mask is boolean (True means mask)
# nn.make_causal_mask returns (1, seq_len, seq_len) or (batch, 1, seq_len, seq_len)
# Check if mask needs broadcasting or conversion
if mask.ndim == 2: # Likely (seq_len, seq_len)
mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len)
elif mask.ndim == 3 and mask.shape[1] != self.num_heads:
# Likely (batch, seq_len, seq_len) or causal (1, sl, sl)
mask = mask[:, None, :, :]
# Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)

# Ensure mask is broadcastable to attn_scores shape
mask_shape_expected = (batch_size, self.num_heads, seq_len, seq_len)
@@ -197,19 +195,17 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
else:
raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}")


# Apply mask: Use large negative number where mask is True
# (or where mask value is 0 if using 0/-inf convention)
# Assuming boolean mask convention (True = mask) common in Flax examples
# If using 0/-inf mask, the logic would be: attn_scores = attn_scores + mask
attn_scores = jnp.where(mask, jnp.finfo(self.dtype).min, attn_scores)


# Softmax to get attention weights
# Shape: (batch, num_heads, seq_len, seq_len)
attn_weights = jax.nn.softmax(
attn_scores, axis=-1
).astype(self.dtype)

# Apply attention weights to Value
# Output per head: (batch, num_heads, seq_len, head_dim)
@@ -222,6 +218,7 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
attn_output = attn_output.reshape(batch_size, seq_len, total_head_dim)

# Final linear projection
output = self.output_proj(attn_output) # Use self.output_proj defined in setup


return output
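
For context, a minimal usage sketch of the module reformatted above, with shapes chosen to match the tests (`embed_dim` must equal `num_heads * head_dim`):

```python
import jax
from jaxgarden.attention.rope_multi_head_attention import RoPEMultiHeadAttention

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 16, 32))  # (batch, seq_len, embed_dim) with embed_dim = 4 * 8

rope_mha = RoPEMultiHeadAttention(num_heads=4, head_dim=8)
params = rope_mha.init(key, x)["params"]
output = rope_mha.apply({"params": params}, x)  # shape: (2, 16, 32)
```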
113 changes: 113 additions & 0 deletions jaxgarden/functional/mixture_of_experts.py
@@ -0,0 +1,113 @@
"""
JAX/Flax implementation of a Mixture of Experts (MoE) layer.

This code provides a conceptual implementation of a Mixture of Experts layer,
a neural network architecture where multiple specialized "expert" sub-networks
are combined. A gating network determines which expert (or combination of
experts) processes a given input, allowing the model to learn to route
different parts of the input space to specialized modules. This can lead to
models with higher capacity and better efficiency, especially in sparse
formulations where only a subset of experts are activated per input.

The core concept of Mixture of Experts was introduced in the paper:
"Adaptive Mixtures of Local Experts"
by Robert A. Jacobs, Michael I. Jordan, Steven J. Nowlan, and Geoffrey E. Hinton.
Published in Neural Computation, Volume 3, Issue 1, Pages 79-87, 1991.
Available at: https://www.cs.toronto.edu/~hinton/absps/jjnh91.pdf
"""

import flax.linen as nn
import jax.numpy as jnp


# Expert Network
class Expert(nn.Module):
num_outputs: int

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""
A simple feed-forward expert network.
Args:
x: Input tensor.
Returns:
Output tensor from the expert.
"""
x = nn.Dense(features=self.num_outputs * 2, name="expert_dense_1")(x)
x = nn.relu(x)
x = nn.Dense(features=self.num_outputs, name="expert_dense_2")(x)
return x

# Gating Network
class GatingNetwork(nn.Module):
num_experts: int # The number of experts to choose from

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""
A simple gating network that outputs weights for each expert.
Args:
x: Input tensor.
Returns:
A tensor of weights for each expert (after softmax).
"""
# The gating network is often a linear layer followed by a softmax
gate_logits = nn.Dense(features=self.num_experts, name="gating_dense")(x)
gate_weights = nn.softmax(gate_logits, axis=-1)
return gate_weights
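
# Note: this gate is dense (soft routing): every expert receives a nonzero weight.
# The sparse formulations mentioned in the module docstring would instead keep only
# the top-k entries of gate_weights and renormalize them, so that only k experts
# run per input.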

# Mixture of Experts Layer
class MixtureOfExperts(nn.Module):
num_experts: int
expert_output_dim: int

def setup(self) -> None:
"""
Initialize the experts and the gating network.
This method is called by Flax automatically.
"""
# List of Expert modules
self.experts = [Expert(num_outputs=self.expert_output_dim,
name=f"expert_{i}") for i in range(self.num_experts)]
# Create the GatingNetwork module
self.gating_network = GatingNetwork(num_experts=self.num_experts, name="gating_network")

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""
Forward pass for the Mixture of Experts layer.
Args:
x: Input tensor.
Returns:
The combined output from the experts.
"""
# 1. Get the gating weights
# Input shape: (batch_size, input_dim)
# Gate weights shape: (batch_size, num_experts)
gate_weights = self.gating_network(x)

# 2. Get the outputs from all experts
# We'll store expert outputs in a list
expert_outputs = []
for i in range(self.num_experts):
# Expert output shape: (batch_size, expert_output_dim)
expert_out = self.experts[i](x)
expert_outputs.append(expert_out)

# Stack expert outputs along a new dimension to facilitate weighted sum
# Stacked expert_outputs shape: (batch_size, num_experts, expert_output_dim)
stacked_expert_outputs = jnp.stack(expert_outputs, axis=1)

# We want to weight each expert's output for each item in the batch.
# Gate weights shape: (batch_size, num_experts)
# Needs to be broadcast to: (batch_size, num_experts, expert_output_dim)
# to multiply with stacked_expert_outputs.
expanded_gate_weights = jnp.expand_dims(gate_weights, axis=-1)

# Weighted outputs shape: (batch_size, num_experts, expert_output_dim)
weighted_expert_outputs = stacked_expert_outputs * expanded_gate_weights

# Sum the weighted outputs along the num_experts dimension
# Final output shape: (batch_size, expert_output_dim)
final_output = jnp.sum(weighted_expert_outputs, axis=1)

return final_output
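
# The explicit loop over experts keeps the routing easy to follow. Since num_experts
# is a static module attribute, the same weighted combination could equivalently be
# written in one step, e.g.:
#     final_output = jnp.einsum('be,bed->bd', gate_weights, stacked_expert_outputs)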
10 changes: 4 additions & 6 deletions tests/attention/test_RoPEMultiHeadAttention.py
@@ -65,7 +65,6 @@ def test_apply_rotary_pos_emb():
assert not jnp.allclose(rotated_x, x)



# --- Test RoPEMultiHeadAttention Module ---

@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
@@ -99,7 +98,8 @@ def test_rope_mha_masking():

x = jax.random.normal(key, (batch_size, seq_len, embed_dim))
# Create a causal mask (True means masked)
causal_mask = nn.make_causal_mask(x[:, :, 0])
# Gets (batch, seq, seq) or (1, seq, seq)

rope_mha = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim)
params = rope_mha.init(key, x, causal_mask)["params"]
@@ -126,11 +126,9 @@ def test_rope_mha_errors():

rope_mha_odd_dim.init(key, x_dummy_odd)


# Test with mismatched embed_dim (should raise error during forward pass / init)
rope_mha = RoPEMultiHeadAttention(num_heads=4, head_dim=8) # Expects embed_dim 32
x_mismatch = jax.random.normal(key, (2, 16, 100)) # Incorrect embed_dim

with pytest.raises(ValueError, match=r"embed_dim \(\d+\) must equal"):
rope_mha.init(key, x_mismatch)
