diff --git a/examples/nanogpt_example.py b/examples/nanogpt_example.py
new file mode 100644
index 0000000..5930a0f
--- /dev/null
+++ b/examples/nanogpt_example.py
@@ -0,0 +1,122 @@
+"""Example script demonstrating NanoGPT usage."""
+
+import os
+import time
+from typing import Dict, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import serialization
+from tqdm import tqdm
+
+from jax_layers.nanogpt.config import Config
+from jax_layers.nanogpt.data import create_dataset, get_shakespeare
+from jax_layers.nanogpt.model import GPT
+from jax_layers.nanogpt.train import (
+    TrainState,
+    create_train_state,
+    create_learning_rate_schedule,
+    eval_step,
+    train_step,
+)
+
+def estimate_loss(
+    state: TrainState,
+    eval_data: Tuple[jnp.ndarray, jnp.ndarray],
+    eval_iters: int,
+) -> Dict[str, float]:
+    """Estimate loss on evaluation data."""
+    out = {}
+    for i in range(eval_iters):
+        batch = eval_data[0][i % len(eval_data[0])], eval_data[1][i % len(eval_data[1])]
+        metrics = eval_step(state, batch)
+        for k, v in metrics.items():
+            out[k] = out.get(k, 0.0) + v
+    for k in out:
+        out[k] /= eval_iters
+    return out
+
+def main():
+    # Initialize configuration
+    config = Config()
+
+    # Create output directory
+    os.makedirs(config.out_dir, exist_ok=True)
+
+    # Load and prepare data
+    print('Loading Shakespeare dataset...')
+    text = get_shakespeare()
+    train_data, val_data = create_dataset(
+        text=text,
+        block_size=config.model.block_size,
+        batch_size=config.train.batch_size,
+    )
+
+    # Initialize model
+    print('Initializing model...')
+    model = GPT(
+        vocab_size=config.model.vocab_size,
+        block_size=config.model.block_size,
+        n_layer=config.model.n_layer,
+        n_head=config.model.n_head,
+        n_embd=config.model.n_embd,
+        dropout=config.model.dropout,
+        dtype=getattr(jnp, config.model.dtype),
+    )
+
+    # Create learning rate schedule
+    lr_schedule = create_learning_rate_schedule(
+        learning_rate=config.train.learning_rate,
+        warmup_steps=config.train.warmup_iters,
+        total_steps=config.train.max_iters,
+    )
+
+    # Initialize training state
+    key = jax.random.PRNGKey(config.train.seed)
+    dropout_rng, key = jax.random.split(key)
+    state = create_train_state(
+        model=model,
+        learning_rate=lr_schedule if config.train.decay_lr else config.train.learning_rate,
+        dropout_rng=dropout_rng,
+        key=key,
+    )
+
+    # Training loop
+    print('Starting training...')
+    best_val_loss = float('inf')
+    t0 = time.time()
+
+    for iter_num in tqdm(range(config.train.max_iters)):
+        # The learning-rate schedule was handed to the optimizer above, so it
+        # is applied automatically at every step.
+
+        # Sample a batch of data
+        batch = train_data[0][iter_num % len(train_data[0])], train_data[1][iter_num % len(train_data[1])]
+
+        # Evaluate the loss on train/val sets
+        if iter_num % config.train.eval_interval == 0:
+            train_losses = estimate_loss(state, train_data, config.train.eval_iters)
+            val_losses = estimate_loss(state, val_data, config.train.eval_iters)
+            print(f'iter {iter_num}: train loss {train_losses["loss"]:.4f}, val loss {val_losses["loss"]:.4f}')
+
+            # Save best model
+            if val_losses['loss'] < best_val_loss:
+                best_val_loss = val_losses['loss']
+                if iter_num > 0:
+                    checkpoint = {'model': state.params}
+                    with open(os.path.join(config.out_dir, 'best.ckpt'), 'wb') as f:
+                        f.write(serialization.to_bytes(checkpoint))
+
+        # Forward backward update
+        state, metrics = train_step(state, batch)
+
+        # Timing and logging
+        t1 = time.time()
+        dt = t1 - t0
+        t0 = t1
+        if iter_num % config.train.eval_interval == 0:
+            print(f'iter {iter_num}: loss {metrics["loss"]:.4f}, time {dt*1000:.2f}ms')
+
+    print('Training completed!')
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/jax_layers/nanogpt/__init__.py b/jax_layers/nanogpt/__init__.py
new file mode 100644
index 0000000..5db5829
--- /dev/null
+++ b/jax_layers/nanogpt/__init__.py
@@ -0,0 +1,30 @@
+"""NanoGPT implementation using JAX and Flax."""
+
+from .config import Config, GPTConfig, TrainConfig
+from .data import Tokenizer, create_dataset, get_batch, get_shakespeare
+from .model import GPT, Block, CausalSelfAttention
+from .train import (
+    TrainState,
+    create_train_state,
+    create_learning_rate_schedule,
+    eval_step,
+    train_step,
+)
+
+__all__ = [
+    'Config',
+    'GPTConfig',
+    'TrainConfig',
+    'Tokenizer',
+    'create_dataset',
+    'get_batch',
+    'get_shakespeare',
+    'GPT',
+    'Block',
+    'CausalSelfAttention',
+    'TrainState',
+    'create_train_state',
+    'create_learning_rate_schedule',
+    'eval_step',
+    'train_step',
+]
\ No newline at end of file
diff --git a/jax_layers/nanogpt/config.py b/jax_layers/nanogpt/config.py
new file mode 100644
index 0000000..2dec949
--- /dev/null
+++ b/jax_layers/nanogpt/config.py
@@ -0,0 +1,45 @@
+"""Configuration for NanoGPT."""
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+@dataclass
+class GPTConfig:
+    """GPT model configuration."""
+
+    vocab_size: int = 256
+    block_size: int = 128
+    n_layer: int = 6
+    n_head: int = 6
+    n_embd: int = 384
+    dropout: float = 0.1
+    dtype: str = 'float32'
+
+@dataclass
+class TrainConfig:
+    """Training configuration."""
+
+    batch_size: int = 64
+    learning_rate: float = 3e-4
+    max_iters: int = 5000
+    eval_interval: int = 500
+    eval_iters: int = 200
+    warmup_iters: int = 2000
+    weight_decay: float = 0.1
+    beta1: float = 0.9
+    beta2: float = 0.95
+    grad_clip: float = 1.0
+    decay_lr: bool = True
+    min_lr: float = 3e-5
+    device: str = 'cpu'
+    seed: int = 42
+
+@dataclass
+class Config:
+    """Main configuration."""
+
+    # default_factory avoids sharing one mutable default across instances
+    model: GPTConfig = field(default_factory=GPTConfig)
+    train: TrainConfig = field(default_factory=TrainConfig)
+    out_dir: str = 'out'
+    resume: Optional[str] = None
\ No newline at end of file
diff --git a/jax_layers/nanogpt/data.py b/jax_layers/nanogpt/data.py
new file mode 100644
index 0000000..d6c1eec
--- /dev/null
+++ b/jax_layers/nanogpt/data.py
@@ -0,0 +1,89 @@
+"""Data processing utilities for NanoGPT."""
+
+import os
+from typing import List, Tuple
+
+import jax.numpy as jnp
+import numpy as np
+
+class Tokenizer:
+    """Simple character-level tokenizer."""
+
+    def __init__(self, vocab_size: int = 256):
+        self.vocab_size = vocab_size
+        self.chars = [chr(i) for i in range(vocab_size)]
+        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
+        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
+
+    def encode(self, text: str) -> List[int]:
+        """Encode text to token ids."""
+        return [self.char_to_idx.get(ch, 0) for ch in text]
+
+    def decode(self, ids: List[int]) -> str:
+        """Decode token ids to text."""
+        return ''.join(self.idx_to_char.get(i, '') for i in ids)
+
+def get_shakespeare() -> str:
+    """Download and load Shakespeare dataset."""
+    input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
+
+    if not os.path.exists(input_file_path):
+        data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
+        import urllib.request
+        print(f'Downloading Shakespeare dataset to {input_file_path}...')
+        urllib.request.urlretrieve(data_url, input_file_path)
+
+    with open(input_file_path, 'r', encoding='utf-8') as f:
+        return f.read()
+
+def create_dataset(
+    text: str,
+    block_size: int,
+    batch_size: int,
+    split: float = 0.9,
+) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
+    """Create train and validation datasets as ((train_x, train_y), (val_x, val_y))."""
+    # Create tokenizer and encode text
+    tokenizer = Tokenizer()
+    data = np.array(tokenizer.encode(text))
+
+    # Split into train and validation sets
+    n = int(split * len(data))
+    train_data = data[:n]
+    val_data = data[n:]
+
+    def get_batches(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        # Create input and target sequences
+        x = []
+        y = []
+        for i in range(0, len(data) - block_size):
+            x.append(data[i:i + block_size])
+            y.append(data[i + 1:i + block_size + 1])
+
+        # Stack into batches
+        x = np.stack(x)
+        y = np.stack(y)
+
+        # Shuffle and create batches
+        indices = np.random.permutation(len(x))
+        x = x[indices]
+        y = y[indices]
+
+        n_batches = len(x) // batch_size
+        x = x[:n_batches * batch_size].reshape(n_batches, batch_size, -1)
+        y = y[:n_batches * batch_size].reshape(n_batches, batch_size, -1)
+
+        return x, y
+
+    train_x, train_y = get_batches(train_data)
+    val_x, val_y = get_batches(val_data)
+
+    return (train_x, train_y), (val_x, val_y)
+
+def get_batch(
+    x: np.ndarray,
+    y: np.ndarray,
+    batch_idx: int,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    """Get a single batch from the dataset."""
+    return jnp.array(x[batch_idx]), jnp.array(y[batch_idx])
\ No newline at end of file
diff --git a/jax_layers/nanogpt/model.py b/jax_layers/nanogpt/model.py
new file mode 100644
index 0000000..ae0f27a
--- /dev/null
+++ b/jax_layers/nanogpt/model.py
@@ -0,0 +1,122 @@
+"""NanoGPT model implementation using JAX and Flax."""
+
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+
+from jax_layers.attention import MultiHeadAttention
+
+class CausalSelfAttention(nn.Module):
+    """Causal self-attention layer."""
+
+    n_head: int
+    n_embd: int
+    dropout: float = 0.1
+    dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
+        # Create causal mask if not provided
+        if mask is None:
+            B, T = x.shape[:2]
+            mask = jnp.tril(jnp.ones((B, 1, T, T)))
+
+        # Apply attention
+        attn = MultiHeadAttention(
+            num_heads=self.n_head,
+            in_features=self.n_embd,
+            dropout_rate=self.dropout,
+            dtype=self.dtype,
+            implementation="flash",  # Use Flash Attention if available
+        )
+        return attn(x, mask=mask)
+
+class Block(nn.Module):
+    """Transformer block."""
+
+    n_head: int
+    n_embd: int
+    dropout: float = 0.1
+    dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
+        # attention
+        y = nn.LayerNorm(dtype=self.dtype)(x)
+        y = CausalSelfAttention(self.n_head, self.n_embd, self.dropout, self.dtype)(y, mask)
+        x = x + y
+
+        # mlp
+        y = nn.LayerNorm(dtype=self.dtype)(x)
+        y = nn.Dense(4 * self.n_embd, dtype=self.dtype)(y)
+        y = jax.nn.gelu(y)
+        y = nn.Dense(self.n_embd, dtype=self.dtype)(y)
+        x = x + y
+        return x
+
+class GPT(nn.Module):
+    """GPT model."""
+
+    vocab_size: int
+    block_size: int
+    n_layer: int
+    n_head: int
+    n_embd: int
+    dropout: float = 0.1
+    dtype: jnp.dtype = jnp.float32
+
+    @nn.compact
+    def __call__(
+        self,
+        idx: jnp.ndarray,
+        mask: Optional[jnp.ndarray] = None,
+        deterministic: bool = True,
+    ) -> jnp.ndarray:
+        _, T = idx.shape
+
+        # token and position embeddings
+        tok_emb = nn.Embed(self.vocab_size, self.n_embd, dtype=self.dtype)(idx)
+        pos = jnp.arange(0, T, dtype=jnp.int32)[None, :]  # shape (1, T)
+        pos_emb = nn.Embed(self.block_size, self.n_embd, dtype=self.dtype)(pos)
+        x = nn.Dropout(rate=self.dropout)(tok_emb + pos_emb, deterministic=deterministic)
+
+        # transformer blocks
+        for _ in range(self.n_layer):
+            x = Block(self.n_head, self.n_embd, self.dropout, self.dtype)(x, mask)
+
+        # final layer norm
+        x = nn.LayerNorm(dtype=self.dtype)(x)
+
+        # language modeling head
+        logits = nn.Dense(self.vocab_size, dtype=self.dtype)(x)
+        return logits
+
+    def generate(
+        self,
+        idx: jnp.ndarray,
+        key: jax.Array,
+        max_new_tokens: int,
+        temperature: float = 1.0,
+    ) -> jnp.ndarray:
+        """Generate new tokens given a sequence.
+
+        Call via ``model.apply(variables, idx, key, max_new_tokens, method='generate')``.
+        """
+        for _ in range(max_new_tokens):
+            # crop context if needed
+            idx_cond = idx if idx.shape[1] <= self.block_size else idx[:, -self.block_size:]
+
+            # get predictions
+            logits = self(idx_cond)
+            logits = logits[:, -1, :] / temperature
+
+            # sample from the distribution (categorical expects logits)
+            key, subkey = jax.random.split(key)
+            idx_next = jax.random.categorical(subkey, logits, axis=-1)
+
+            # append sampled index to the sequence
+            idx = jnp.concatenate((idx, idx_next[:, None]), axis=1)
+
+        return idx
\ No newline at end of file
diff --git a/jax_layers/nanogpt/train.py b/jax_layers/nanogpt/train.py
new file mode 100644
index 0000000..9ed12c6
--- /dev/null
+++ b/jax_layers/nanogpt/train.py
@@ -0,0 +1,102 @@
+"""Training utilities for NanoGPT."""
+
+from typing import Any, Dict, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+import optax
+from flax.core import FrozenDict
+from flax.training import train_state
+
+from .model import GPT
+
+class TrainState(train_state.TrainState):
+    """Training state for the GPT model."""
+    dropout_rng: jax.Array
+    key: jax.Array
+
+def create_train_state(
+    model: GPT,
+    learning_rate: Union[float, optax.Schedule],
+    dropout_rng: jax.Array,
+    key: jax.Array,
+) -> TrainState:
+    """Create initial training state."""
+    params = model.init(key, jnp.ones((1, 1), dtype=jnp.int32))['params']
+    tx = optax.adamw(learning_rate=learning_rate)
+    return TrainState.create(
+        apply_fn=model.apply,
+        params=params,
+        tx=tx,
+        dropout_rng=dropout_rng,
+        key=key,
+    )
+
+def train_step(
+    state: TrainState,
+    batch: Tuple[jnp.ndarray, jnp.ndarray],
+) -> Tuple[TrainState, Dict[str, float]]:
+    """Perform a single training step."""
+    inputs, targets = batch
+
+    def loss_fn(params: FrozenDict[str, Any]) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        logits = state.apply_fn(
+            {'params': params},
+            inputs,
+            deterministic=False,
+            rngs={'dropout': state.dropout_rng},
+        )
+        loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
+        return loss, logits
+
+    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+    (loss, logits), grads = grad_fn(state.params)
+
+    # Update state
+    state = state.apply_gradients(grads=grads)
+
+    # Generate new dropout rng
+    _, new_dropout_rng = jax.random.split(state.dropout_rng)
+    state = state.replace(dropout_rng=new_dropout_rng)
+
+    # Compute metrics
+    accuracy = (jnp.argmax(logits, axis=-1) == targets).mean()
+
+    return state, {
+        'loss': loss,
+        'accuracy': accuracy,
+    }
+
+def eval_step(
+    state: TrainState,
+    batch: Tuple[jnp.ndarray, jnp.ndarray],
+) -> Dict[str, float]:
+    """Perform a single evaluation step."""
+    inputs, targets = batch
+
+    logits = state.apply_fn(
+        {'params': state.params},
+        inputs,
+        rngs={'dropout': jax.random.PRNGKey(0)},  # Use fixed rng for eval
+    )
+    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
+    accuracy = (jnp.argmax(logits, axis=-1) == targets).mean()
+
+    return {
+        'loss': loss,
+        'accuracy': accuracy,
+    }
+
+def create_learning_rate_schedule(
+    learning_rate: float,
+    warmup_steps: int,
+    total_steps: int,
+) -> optax.Schedule:
+    """Create a learning rate schedule with warmup."""
+    return optax.warmup_cosine_decay_schedule(
+        init_value=0.0,
+        peak_value=learning_rate,
+        warmup_steps=warmup_steps,
+        decay_steps=total_steps - warmup_steps,
+        end_value=0.0,
+    )
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 83ce2d9..47d4e3c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,8 +24,9 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
 ]
 dependencies = [
     "jax>=0.4.34,<0.5",
     "flax>=0.8.3",
+    "optax>=0.1.7,<0.2.0",
 ]
 
 [project.optional-dependencies]