Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llama4/.gitignore → llama4_jax/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ __pycache__/
build/**

.venv
.vscode
.vscode
8 changes: 4 additions & 4 deletions llama4/README.md → llama4_jax/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ This is a pure JAX implementation of Llama 4 inference, including a checkpoint
converter for the weights. It currently runs on TPU. Support for GPU is
in-progress.

The entire model is defined in [model.py](llama4_jax/model.py) and invoked
via [main.py](main.py). Among other things, the model code demonstrates:
The entire model is defined in [__main__.py](llama4_jax/__main__.py) and invoked
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep main.py separate? I think an explicit script makes it clearer that it's just a starting point for a larger program rather than the default behavior of the llama4 module.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdyro It's your call. But it is very nonstandard Python, whereas __main__.py integrates properly: it sees itself as part of a Python module, and it is relocatable and installable.

One other idea is to create a hierarchy:

.
├── jax_examples_cli
├── deepseek_r1_jax
└── llama4_jax

All packages are installable. There are no main.py files. Outside of jax_examples_cli there are no __main__.py files. The command-line interface finds installed packages and/or packages within a specific location (e.g., os.getcwd() or a JAX_LLM_EXAMPLES env var) that indicates where to find packages compatible with jax_examples_cli/__main__.py.

Your main.py files are already very similar, so it wouldn't be hard to hoist them up. Usage would be: jax_examples_cli --model <deepseek-r1-jax | llama4_jax> --ckpt_path …

via `python3 -m llama4_jax`. Among other things, the model code demonstrates:
* an MLA attention implementation;
* expert and tensor-parallelism via JAX's
[`shard_map`](https://docs.jax.dev/en/latest/sharded-computation.html#manual-parallelism-with-shard-map)
Expand Down Expand Up @@ -47,12 +47,12 @@ the full model. We've tested on v5e-64.

Run on all hosts in the TPU cluster:
```
$ python3 main.py
$ python3 -m llama4_jax
```
e.g. for Cloud TPU:
```
$ gcloud compute tpus tpu-vm ssh {TPU_NAME} --worker=all \
--command="cd ~/llama4_jax && python3 main.py"
--command="cd ~/llama4_jax && python3 -m llama4_jax"
```

Responses:
Expand Down
Empty file.
1 change: 0 additions & 1 deletion llama4/main.py → llama4_jax/llama4_jax/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import dataclasses
from etils import epath
import json
from pprint import pformat

import jax
from jax import numpy as jnp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import dataclasses
from typing import Any

import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P
import torch
from tqdm import tqdm

from . import model as l4jax
from llama4_jax import model as l4jax


def quantize_model(ckpt_path: Path, quant_ckpt_path: Path):
Expand Down
7 changes: 4 additions & 3 deletions llama4/llama4_jax/model.py → llama4_jax/llama4_jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
from jax.experimental.shard import auto_axes, reshard
from etils import epath

from . import ragged_attention
from .decode_ragged_dot import decode_ragged_dot
from llama4_jax import ragged_attention
from llama4_jax.decode_ragged_dot import decode_ragged_dot

map, builtin_map = jax.util.safe_map, map
builtin_map = map
map = jax.util.safe_map
AxisName = str | tuple[str, ...] | None
Axes = tuple[AxisName, ...]

Expand Down
3 changes: 2 additions & 1 deletion llama4/pyproject.toml → llama4_jax/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "llama4_jax"
version = "0.1.0"
description = ""
description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

authors = [
{ name = "Robert Dyro" },
]
Expand All @@ -13,6 +13,7 @@ dependencies = [
"jax",
"torch",
"transformers", # for the model config and the tokenizer
"tune_jax",
"tqdm",
"numpy",
"orbax-checkpoint",
Expand Down
Empty file added llama4_jax/scripts/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
#!/usr/bin/env python3

import sys
from pathlib import Path
from pprint import pprint
from argparse import ArgumentParser
import dataclasses
import os.path
import shutil

try:
from llama4_jax import model as l4jax
from llama4_jax import chkpt_utils as utils
except ImportError:
sys.path.append(str(Path(__file__).parent.absolute()))

from llama4_jax import model as l4jax
from llama4_jax import chkpt_utils as utils

from transformers import AutoConfig
from safetensors import safe_open
from tqdm import tqdm

from llama4_jax import model as l4jax
from llama4_jax import chkpt_utils as utils


def main(model_path: str | Path, ckpt_path: str | Path):
model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser()
files = list(model_path.glob("**/*safetensors"))
files = list(model_path.glob(os.path.join("**", "*safetensors")))
assert len(files) > 1
config_files = list(model_path.glob("**/config.json"))
config_files = list(model_path.glob(os.path.join("**", "config.json")))
assert len(config_files) == 1, "Must have only one `config.json` file in the model path"
config = AutoConfig.from_pretrained(config_files[0]).text_config
cfg = l4jax.hf_to_jax_config(config)
Expand All @@ -44,9 +37,9 @@ def main(model_path: str | Path, ckpt_path: str | Path):

additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
for additional_file in additional_files:
full_paths = list(model_path.glob(f"**/{additional_file}"))
full_paths = list(model_path.glob(os.path.join("**", additional_file)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we stick with pathlib.Path throughout here? It largely replaces os.path.*

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After 11 years maybe I can rely on it being available!

IIRC, pathlib actually uses os.path.join internally, but sure, I can move to pathlib for my jax-ml/jax-llm-examples contributions.

if len(full_paths) != 1:
print(f"Found more than 1 file for {additional_file}")
print("Found more than 1 file for", additional_file)
if len(full_paths) == 0:
continue
full_path = full_paths[0]
Expand All @@ -56,11 +49,12 @@ def main(model_path: str | Path, ckpt_path: str | Path):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--source-path", default="~/DeepSeek-R1-Distill-Llama-70B", required=True, help="HF model directory path"
"--source-path", default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-70B"),
required=True, help="HF model directory path"
)
parser.add_argument(
"--dest-path",
default="~/DeepSeek-R1-Distill-Llama-3.1-70B-Instruct",
default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-3.1-70B-Instruct"),
required=True,
help="JAX model model directory (to be created).",
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3

import os
from argparse import ArgumentParser
from pathlib import Path

Expand All @@ -12,19 +12,20 @@


def main(model_id: str, dest_root_path: str | Path):
local_dir = Path(dest_root_path).expanduser().absolute() / str(model_id).replace("/", "--")
local_dir = Path(dest_root_path).expanduser().absolute() / str(model_id).replace(os.path.sep, "--")
snapshot_download(repo_id=model_id, local_dir=local_dir)


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--model-id", required=True, help=f"HuggingFace model / repo id. Examples include: {example_models}"
"--model-id", required=True,
help=f"HuggingFace model / repo id. Examples include: {example_models}"
)
parser.add_argument(
"--dest-root-path",
required=True,
default="~/",
default=os.path.join(os.path.expanduser("~"), ""),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Path("~/").expanduser()

help="Destination root directory, the model will be saved into its own directory.",
)
args = parser.parse_args()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
#!/usr/bin/env python3

import sys
import os.path
from pathlib import Path
from argparse import ArgumentParser

try:
from llama4_jax import chkpt_utils as utils
except ImportError:
sys.path.append(str(Path(__file__).parent.absolute()))

from llama4_jax import chkpt_utils as utils
from llama4_jax import chkpt_utils as utils


def main(path: str | Path, suffix: str):
Expand All @@ -21,7 +16,8 @@ def main(path: str | Path, suffix: str):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--path", default="~/DeepSeek-R1-Distill-Llama-70B", required=True, help="Existing JAX model checkpoint path"
"--path", default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-70B"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Path("~/DeepSeek-R1-Distill-Llama-70B").expanduser()

required=True, help="Existing JAX model checkpoint path"
)
parser.add_argument(
"--suffix",
Expand Down