126 changes: 126 additions & 0 deletions examples/nanogpt_example.py
@@ -0,0 +1,126 @@
"""Example script demonstrating NanoGPT usage."""

import os
import time
from typing import Dict, Tuple

import jax
import jax.numpy as jnp
from flax import serialization
from tqdm import tqdm

from jax_layers.nanogpt.config import Config
from jax_layers.nanogpt.data import create_dataset, get_shakespeare
from jax_layers.nanogpt.model import GPT
from jax_layers.nanogpt.train import (
TrainState,
create_train_state,
create_learning_rate_schedule,
eval_step,
train_step,
)

def estimate_loss(
state: TrainState,
eval_data: Tuple[jnp.ndarray, jnp.ndarray],
eval_iters: int,
) -> Dict[str, float]:
"""Estimate loss on evaluation data."""
out = {}
    for i in range(eval_iters):
        batch = eval_data[0][i], eval_data[1][i]
metrics = eval_step(state, batch)
for k, v in metrics.items():
out[k] = out.get(k, 0.0) + v
for k in out:
out[k] /= eval_iters
return out

def main():
# Initialize configuration
config = Config()

# Create output directory
os.makedirs(config.out_dir, exist_ok=True)

# Load and prepare data
print('Loading Shakespeare dataset...')
text = get_shakespeare()
    train_x, train_y, val_x, val_y = create_dataset(
        text=text,
        block_size=config.model.block_size,
        batch_size=config.train.batch_size,
    )
    train_data = (train_x, train_y)
    val_data = (val_x, val_y)

# Initialize model and training state
print('Initializing model...')
model = GPT(
vocab_size=config.model.vocab_size,
block_size=config.model.block_size,
n_layer=config.model.n_layer,
n_head=config.model.n_head,
n_embd=config.model.n_embd,
dropout=config.model.dropout,
dtype=getattr(jnp, config.model.dtype),
)

# Create learning rate schedule
lr_schedule = create_learning_rate_schedule(
learning_rate=config.train.learning_rate,
warmup_steps=config.train.warmup_iters,
total_steps=config.train.max_iters,
)

# Initialize training state
key = jax.random.PRNGKey(config.train.seed)
dropout_rng, key = jax.random.split(key)
state = create_train_state(
model=model,
learning_rate=lr_schedule,
dropout_rng=dropout_rng,
key=key,
)

# Training loop
print('Starting training...')
best_val_loss = float('inf')
t0 = time.time()

for iter_num in tqdm(range(config.train.max_iters)):
        # The learning-rate schedule passed to create_train_state is applied inside
        # the optimizer; the value computed here is only used for logging.
        lr = lr_schedule(iter_num) if config.train.decay_lr else config.train.learning_rate

        # Sample a batch of data, cycling through the pre-built batches
        batch_idx = iter_num % len(train_data[0])
        batch = train_data[0][batch_idx], train_data[1][batch_idx]

# Evaluate the loss on train/val sets
if iter_num % config.train.eval_interval == 0:
train_losses = estimate_loss(state, train_data, config.train.eval_iters)
val_losses = estimate_loss(state, val_data, config.train.eval_iters)
print(f'iter {iter_num}: train loss {train_losses["loss"]:.4f}, val loss {val_losses["loss"]:.4f}')

# Save best model
if val_losses['loss'] < best_val_loss:
best_val_loss = val_losses['loss']
if iter_num > 0:
                    checkpoint = {'model': state.params}
                    with open(os.path.join(config.out_dir, 'best.ckpt'), 'wb') as f:
                        f.write(serialization.to_bytes(checkpoint))

# Forward backward update
state, metrics = train_step(state, batch)

# Timing and logging
t1 = time.time()
dt = t1 - t0
t0 = t1
if iter_num % config.train.eval_interval == 0:
            print(f'iter {iter_num}: loss {metrics["loss"]:.4f}, lr {lr:.2e}, time {dt*1000:.2f}ms')

print('Training completed!')

if __name__ == '__main__':
main()
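
A minimal sketch of restoring the best.ckpt written above, assuming flax is available and a fresh state has been rebuilt with the same create_train_state call so that the parameter structure matches what was saved:

import os

from flax import serialization

# config and state are assumed to be reconstructed exactly as in main() above;
# from_bytes needs a target with the same pytree structure as the saved checkpoint.
with open(os.path.join(config.out_dir, 'best.ckpt'), 'rb') as f:
    checkpoint = serialization.from_bytes({'model': state.params}, f.read())
state = state.replace(params=checkpoint['model'])
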
30 changes: 30 additions & 0 deletions jax_layers/nanogpt/__init__.py
@@ -0,0 +1,30 @@
"""NanoGPT implementation using JAX and Flax."""

from .config import Config, GPTConfig, TrainConfig
from .data import Tokenizer, create_dataset, get_batch, get_shakespeare
from .model import GPT, Block, CausalSelfAttention
from .train import (
TrainState,
create_train_state,
create_learning_rate_schedule,
eval_step,
train_step,
)

__all__ = [
'Config',
'GPTConfig',
'TrainConfig',
'Tokenizer',
'create_dataset',
'get_batch',
'get_shakespeare',
'GPT',
'Block',
'CausalSelfAttention',
'TrainState',
'create_train_state',
'create_learning_rate_schedule',
'eval_step',
'train_step',
]
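
These re-exports make the NanoGPT pieces importable from the package root. A small sketch of that, assuming the constructor signature used in examples/nanogpt_example.py above and the defaults from config.py below:

import jax.numpy as jnp

from jax_layers.nanogpt import Config, GPT, Tokenizer

config = Config()
tokenizer = Tokenizer(vocab_size=config.model.vocab_size)
# Build the model from the configured hyperparameters.
model = GPT(
    vocab_size=config.model.vocab_size,
    block_size=config.model.block_size,
    n_layer=config.model.n_layer,
    n_head=config.model.n_head,
    n_embd=config.model.n_embd,
    dropout=config.model.dropout,
    dtype=getattr(jnp, config.model.dtype),
)
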
44 changes: 44 additions & 0 deletions jax_layers/nanogpt/config.py
@@ -0,0 +1,44 @@
"""Configuration for NanoGPT."""

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class GPTConfig:
"""GPT model configuration."""

vocab_size: int = 256
block_size: int = 128
n_layer: int = 6
n_head: int = 6
n_embd: int = 384
dropout: float = 0.1
dtype: str = 'float32'

@dataclass
class TrainConfig:
"""Training configuration."""

batch_size: int = 64
learning_rate: float = 3e-4
max_iters: int = 5000
eval_interval: int = 500
eval_iters: int = 200
warmup_iters: int = 2000
weight_decay: float = 0.1
beta1: float = 0.9
beta2: float = 0.95
grad_clip: float = 1.0
decay_lr: bool = True
min_lr: float = 3e-5
device: str = 'cpu'
seed: int = 42

@dataclass
class Config:
"""Main configuration."""

    model: GPTConfig = field(default_factory=GPTConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
out_dir: str = 'out'
resume: Optional[str] = None
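
Because these are plain dataclasses, a run-specific configuration can be derived without editing the defaults. A short sketch using dataclasses.replace (values are illustrative):

from dataclasses import replace

from jax_layers.nanogpt.config import Config

config = Config()
# Shrink the run for a quick smoke test; replace() returns new instances and
# leaves the defaults untouched.
config = replace(
    config,
    train=replace(config.train, max_iters=100, eval_interval=10, eval_iters=5),
    out_dir='out-smoke',
)
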
91 changes: 91 additions & 0 deletions jax_layers/nanogpt/data.py
@@ -0,0 +1,91 @@
"""Data processing utilities for NanoGPT."""

import os
from typing import List, Tuple

import jax.numpy as jnp
import numpy as np

class Tokenizer:
"""Simple character-level tokenizer."""

def __init__(self, vocab_size: int = 256):
self.vocab_size = vocab_size
self.chars = [chr(i) for i in range(vocab_size)]
self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}

def encode(self, text: str) -> List[int]:
"""Encode text to token ids."""
return [self.char_to_idx.get(ch, 0) for ch in text]

def decode(self, ids: List[int]) -> str:
"""Decode token ids to text."""
return ''.join(self.idx_to_char.get(i, '') for i in ids)

def get_shakespeare() -> str:
"""Download and load Shakespeare dataset."""
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')

if not os.path.exists(input_file_path):
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
import urllib.request
print(f'Downloading Shakespeare dataset to {input_file_path}...')
urllib.request.urlretrieve(data_url, input_file_path)

with open(input_file_path, 'r', encoding='utf-8') as f:
return f.read()

def create_dataset(
text: str,
block_size: int,
batch_size: int,
split: float = 0.9,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Create train and validation datasets from text."""
# Create tokenizer and encode text
tokenizer = Tokenizer()
data = np.array(tokenizer.encode(text))

# Split into train and validation sets
n = int(split * len(data))
train_data = data[:n]
val_data = data[n:]

def get_batches(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
# Create input and target sequences
x = []
y = []
for i in range(0, len(data) - block_size):
x.append(data[i:i + block_size])
y.append(data[i + 1:i + block_size + 1])

# Stack into batches
x = np.stack(x)
y = np.stack(y)

# Shuffle and create batches
indices = np.random.permutation(len(x))
x = x[indices]
y = y[indices]

n_batches = len(x) // batch_size
x = x[:n_batches * batch_size].reshape(n_batches, batch_size, -1)
y = y[:n_batches * batch_size].reshape(n_batches, batch_size, -1)

return x, y

train_x, train_y = get_batches(train_data)
val_x, val_y = get_batches(val_data)

return train_x, train_y, val_x, val_y

def get_batch(
x: np.ndarray,
y: np.ndarray,
batch_idx: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Get a single batch from the dataset."""
return jnp.array(x[batch_idx]), jnp.array(y[batch_idx])
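
A short sketch of how these pieces compose, using the shapes produced by create_dataset above (block_size and batch_size values are illustrative):

from jax_layers.nanogpt.data import create_dataset, get_batch, get_shakespeare

text = get_shakespeare()
train_x, train_y, val_x, val_y = create_dataset(text, block_size=128, batch_size=64)
# train_x has shape (n_batches, batch_size, block_size); targets are the same
# sequences shifted right by one character.
x, y = get_batch(train_x, train_y, batch_idx=0)
print(x.shape, y.shape)  # (64, 128) (64, 128)
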