Implementation of Mixture of Experts in JAX using Flax #33
Open
RubensZimbres wants to merge 9 commits into ml-gde:main from RubensZimbres:new-moe-branch
+325 −25
Changes from 7 commits
Commits (9)
76515db  RoPEMultiHeadAttention implementation - all tests passed  (RubensZimbres)
0a64c34  mypy fixed for class RoPEMultiHeadAttention  (RubensZimbres)
8fb9b0d  useless comments removed  (RubensZimbres)
6bac808  README file updated with RoPEMultiHeadAttention usage  (RubensZimbres)
ad55578  Docstring (arxiv) added to RoPE implementation  (RubensZimbres)
1d04a60  My local work: updated readme, added MoE implementation and tests  (RubensZimbres)
aebe9e4  Mixture of Experts implementation: all tests passed, pytest, ruff and…  (RubensZimbres)
ab7bd61  comments improved MoE  (RubensZimbres)
74800e1  Merge branch 'main' into new-moe-branch  (RubensZimbres)
.gitignore
@@ -8,6 +8,7 @@ wheels/

# Virtual environments
.venv
jax_env

# Byte-compiled / optimized / DLL files
__pycache__/
New file (RoPE Multi-Head Attention implementation):
@@ -0,0 +1,227 @@
""" | ||
JAX/Flax implementation of Multi-Head Attention with Rotary Positional Embedding (RoPE). | ||
|
||
This code implements the RoPE technique within a standard Multi-Head Attention | ||
framework. RoPE injects relative positional information by rotating pairs of | ||
features in the Query and Key vectors based on their absolute position before | ||
the attention calculation. | ||
|
||
The method was introduced in the paper: | ||
"RoFormer: Enhanced Transformer with Rotary Position Embedding" | ||
by Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, Yunfeng Liu. | ||
arXiv:2104.09864v5 [cs.CL] (Submitted on 20 Apr 2021) | ||
""" | ||
|
||
import flax.linen as nn | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
|
||
def rotate_half(x: jnp.ndarray) -> jnp.ndarray:
    """Rotates half the hidden dims of the input tensor."""
    # Builds the rotated tensor by concatenating the negated second half
    # and the first half along the last dimension. This split-in-half pairing
    # matches the concat(freqs, freqs) layout used in precompute_rotary_embeddings.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return jnp.concatenate((-x2, x1), axis=-1)

def apply_rotary_pos_emb(x: jnp.ndarray, cos_emb: jnp.ndarray, sin_emb: jnp.ndarray) -> jnp.ndarray:
    """Applies Rotary Positional Embedding to the input tensor.

    Args:
        x: Input tensor, e.g., query or key (batch, seq_len, num_heads, head_dim).
        cos_emb: Cosine component of the positional embedding.
            Shape: (1, seq_len, 1, head_dim) or compatible via broadcasting.
        sin_emb: Sine component of the positional embedding.
            Shape: (1, seq_len, 1, head_dim) or compatible via broadcasting.

    Returns:
        Tensor with RoPE applied.
    """
    # Applying the rotation formula:
    # x_rotated = x * cos(theta) + rotate_half(x) * sin(theta)
    # Ensure shapes are broadcastable: cos_emb and sin_emb should have dimensions
    # for sequence length and features, matching the corresponding dimensions in x.
    # The precomputed embeddings have shape (1, seq_len, 1, head_dim) for easy broadcasting.
    return (x * cos_emb) + (rotate_half(x) * sin_emb)

def precompute_rotary_embeddings(
    seq_len: int, head_dim: int, base: float = 10000.0
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Precomputes the RoPE cosine and sine embeddings.

    Args:
        seq_len: The maximum sequence length.
        head_dim: The dimension of each attention head (must be even).
        base: The base value for the inverse frequency calculation.

    Returns:
        cos_emb: Cosine embeddings (1, seq_len, 1, head_dim)
        sin_emb: Sine embeddings (1, seq_len, 1, head_dim)
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even, got {head_dim}")

    # Calculate inverse frequencies (theta_i)
    # theta_i = 1 / (base^(2*i / head_dim)) for i in [0, 1, ..., head_dim/2 - 1]
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))

    # Calculate position indices (m)
    pos = jnp.arange(seq_len, dtype=jnp.float32)

    # Calculate angles (m * theta_i)
    freqs = jnp.outer(pos, inv_freq)  # Shape: (seq_len, head_dim / 2)

    # Duplicate frequencies for the full head dimension (for both elements in pairs)
    emb = jnp.concatenate((freqs, freqs), axis=-1)  # Shape: (seq_len, head_dim)

    # Calculate cosine and sine embeddings
    cos_emb = jnp.cos(emb)[None, :, None, :]  # Shape: (1, seq_len, 1, head_dim)
    sin_emb = jnp.sin(emb)[None, :, None, :]  # Shape: (1, seq_len, 1, head_dim)

    return cos_emb, sin_emb

class RoPEMultiHeadAttention(nn.Module):
    """Multi-Head Attention with Rotary Positional Embeddings."""

    num_heads: int
    head_dim: int
    rope_base: float = 10000.0
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Initializes the attention projections."""
        # Check head_dim validity early during setup
        if self.head_dim % 2 != 0:
            raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.")

        # Define layers here - they will be initialized when the module is first called
        total_head_dim = self.num_heads * self.head_dim
        self.query_proj = nn.Dense(
            features=total_head_dim, use_bias=False, dtype=self.dtype, name="query_proj"
        )
        self.key_proj = nn.Dense(
            features=total_head_dim, use_bias=False, dtype=self.dtype, name="key_proj"
        )
        self.value_proj = nn.Dense(
            features=total_head_dim, use_bias=False, dtype=self.dtype, name="value_proj"
        )
        self.output_proj = nn.Dense(
            features=total_head_dim,  # Output should match embed_dim
            use_bias=False,
            dtype=self.dtype,
            name="output_proj",
        )

    @nn.compact
    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray:
        """Forward pass for RoPE MHA.

        Args:
            x: Input tensor (batch_size, seq_len, embed_dim).
            mask: Optional boolean attention mask of shape (seq_len, seq_len),
                (batch_size, seq_len, seq_len), (1, 1, seq_len, seq_len) or
                (batch_size, 1, seq_len, seq_len). True marks positions that must
                NOT be attended to. Note this is the opposite of the
                flax.linen.make_causal_mask convention, where True means "attend";
                invert such masks before passing them in.

        Returns:
            Output tensor (batch_size, seq_len, embed_dim).
        """
        batch_size, seq_len, embed_dim = x.shape
        total_head_dim = self.num_heads * self.head_dim

        if embed_dim != total_head_dim:
            raise ValueError(
                f"embed_dim ({embed_dim}) must equal num_heads*head_dim ({total_head_dim})"
            )
        # Note: the head_dim even check is done in setup so it fails earlier

        # 1. Linear projections for Q, K, V
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)

        # 2. Reshape for multi-head processing
        # (batch, seq_len, embed_dim) -> (batch, seq_len, num_heads, head_dim)
        query = query.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        key = key.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        value = value.reshape(batch_size, seq_len, self.num_heads, self.head_dim)

        # 3. Precompute RoPE embeddings (cosine and sine)
        # We compute them dynamically based on the input sequence length
        cos_emb, sin_emb = precompute_rotary_embeddings(seq_len, self.head_dim, base=self.rope_base)
        # Ensure RoPE embeddings have correct dtype
        cos_emb = cos_emb.astype(self.dtype)
        sin_emb = sin_emb.astype(self.dtype)

        # 4. Apply RoPE to Query and Key
        query = apply_rotary_pos_emb(query, cos_emb, sin_emb)
        key = apply_rotary_pos_emb(key, cos_emb, sin_emb)

        # 5. Transpose for attention calculation: (batch, num_heads, seq_len, head_dim)
        query = query.transpose((0, 2, 1, 3))
        key = key.transpose((0, 2, 1, 3))
        value = value.transpose((0, 2, 1, 3))

        # 6. Scaled Dot-Product Attention
        # Attention scores: (batch, num_heads, seq_len, seq_len)
        attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt(
            self.head_dim
        ).astype(self.dtype)  # Ensure sqrt is correct dtype

        # Apply mask (if provided)
        if mask is not None:
            # This implementation assumes a boolean mask where True means "mask out".
            # Normalize common mask shapes so they broadcast against attn_scores.
            if mask.ndim == 2:  # Likely (seq_len, seq_len)
                mask = mask[None, None, :, :]  # -> (1, 1, seq_len, seq_len)
            elif mask.ndim == 3 and mask.shape[1] != self.num_heads:
                # Assume (batch, seq_len, seq_len) or causal (1, seq_len, seq_len)
                mask = mask[:, None, :, :]  # -> (batch, 1, seq_len, seq_len)

            # Ensure mask is broadcastable to the attn_scores shape
            mask_shape_expected = (batch_size, self.num_heads, seq_len, seq_len)
            if mask.shape != mask_shape_expected:
                # Attempt broadcasting common causal mask shapes
                if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == (
                    batch_size,
                    1,
                    seq_len,
                    seq_len,
                ):  # Causal mask for all batches/heads
                    mask = jnp.broadcast_to(mask, mask_shape_expected)
                # Add other broadcasting cases if needed
                else:
                    raise ValueError(
                        f"Mask shape {mask.shape} != expected shape {mask_shape_expected}"
                    )

            # Apply mask: use a large negative number where mask is True
            attn_scores = jnp.where(mask, jnp.finfo(self.dtype).min, attn_scores)

        # Softmax to get attention weights
        attn_weights = jax.nn.softmax(attn_scores, axis=-1).astype(
            self.dtype
        )  # Shape: (batch, num_heads, seq_len, seq_len)

        # Apply attention weights to Value
        # Output per head: (batch, num_heads, seq_len, head_dim)
        attn_output = jnp.matmul(attn_weights, value)

        # 7. Concatenate heads and final projection
        # Transpose back: (batch, seq_len, num_heads, head_dim)
        attn_output = attn_output.transpose((0, 2, 1, 3))
        # Reshape to (batch, seq_len, embed_dim)
        attn_output = attn_output.reshape(batch_size, seq_len, total_head_dim)

        # Final linear projection
        output = self.output_proj(attn_output)  # output_proj is defined in setup

        return output
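For reviewers who want to exercise the module, here is a minimal usage sketch; it is not part of the diff. It assumes the new file is importable (the file name is not shown in this view, so the import path below is hypothetical), and it builds the causal mask with jnp.triu so that True marks future positions, matching this implementation's True-means-masked convention.

# Minimal usage sketch (illustrative only; the import path is hypothetical).
import jax
import jax.numpy as jnp

# from rope_attention import RoPEMultiHeadAttention  # hypothetical module name

batch_size, seq_len, num_heads, head_dim = 2, 16, 4, 32
embed_dim = num_heads * head_dim  # __call__ requires embed_dim == num_heads * head_dim

attn = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim)

# Dummy input and a causal mask; True marks positions to mask out (future tokens).
x = jnp.ones((batch_size, seq_len, embed_dim))
causal_mask = jnp.triu(jnp.ones((seq_len, seq_len), dtype=bool), k=1)

variables = attn.init(jax.random.PRNGKey(0), x, causal_mask)
out = attn.apply(variables, x, causal_mask)
print(out.shape)  # (2, 16, 128)

Called without a mask, the module computes full bidirectional attention with RoPE applied to queries and keys.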
Review comment: This logic assumes that if mask.ndim == 3 and mask.shape[1] != self.num_heads, then the mask is either (batch, seq_len, seq_len) or a causal (1, seq_len, seq_len). However, it's possible that the mask has a different shape that is also incompatible with attn_scores. Consider adding a more robust check to ensure that the mask shape is compatible before attempting to broadcast it. Also, consider adding a comment to explain the assumption.
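One possible way to address this, sketched below under the assumptions that masks are boolean with True meaning "masked out" (as in the current code) and that any 3-D mask is batch-leading: promote the mask to 4-D and let a single jnp.broadcast_to call decide compatibility with (batch, num_heads, seq_len, seq_len) instead of enumerating shapes. The helper name _normalize_mask is hypothetical and not part of this PR.

import jax.numpy as jnp


def _normalize_mask(
    mask: jnp.ndarray, batch_size: int, num_heads: int, seq_len: int
) -> jnp.ndarray:
    """Hypothetical helper: bring a boolean mask (True = masked out) to
    (batch, num_heads, seq_len, seq_len), or raise if that is impossible."""
    # Promote to 4-D: (seq_len, seq_len) -> (1, 1, L, L); assume any 3-D mask
    # is (batch, seq_len, seq_len) or (1, seq_len, seq_len) -> add a heads axis.
    if mask.ndim == 2:
        mask = mask[None, None, :, :]
    elif mask.ndim == 3:
        mask = mask[:, None, :, :]
    elif mask.ndim != 4:
        raise ValueError(f"Unsupported mask rank {mask.ndim}; expected 2, 3 or 4 dims.")

    target = (batch_size, num_heads, seq_len, seq_len)
    try:
        # A single broadcasting check covers every compatible shape,
        # instead of enumerating special cases.
        return jnp.broadcast_to(mask, target)
    except ValueError as e:
        raise ValueError(f"Mask shape {mask.shape} is not broadcastable to {target}") from e

With this helper, the body of __call__ could reduce the masking branch to a normalize-then-where step, and the docstring would carry the single stated assumption about 3-D masks.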