[llama4 -> llama4_jax] Refactor to be a proper installable Python package #9
base: main
…kage ; [llama4_jax/pyproject.toml] Add missing dependency ; [llama4_jax/README.md] Document new usage
Thanks for this PR! I left some comments
additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
for additional_file in additional_files:
    full_paths = list(model_path.glob(f"**/{additional_file}"))
    full_paths = list(model_path.glob(os.path.join("**", additional_file)))
can we stick with pathlib.Path throughout here, it largely replaces os.path.*
After 11 years maybe I can rely on it being available! os.path.join is actually used by it internally IIRC; but sure, I can move to pathlib for my jax-ml/jax-llm-examples contributions.
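For reference, the glob calls under discussion can stay entirely within pathlib. The sketch below is a minimal illustration, not the PR's code: the helper name find_additional_files is hypothetical, and it assumes (to my understanding of the standard library) that pathlib accepts "/" inside glob patterns on every platform, so os.path.join is not needed to build the pattern.

```python
from pathlib import Path


def find_additional_files(
    model_path,
    names=("config.json", "tokenizer.json", "tokenizer_config.json"),
):
    """Map each file name to every match found anywhere under model_path.

    Hypothetical helper for illustration only.
    """
    model_path = Path(model_path).expanduser()
    # "**/name" is a pathlib glob pattern rather than an OS path, so the
    # forward slash is assumed to work regardless of the platform separator.
    return {name: sorted(model_path.glob(f"**/{name}")) for name in names}
```

Calling find_additional_files on a downloaded model directory returns a dict of lists, and a missing file simply yields an empty list.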
llama4_jax/scripts/download_model.py
Outdated
"--dest-root-path",
required=True,
default="~/",
default=os.path.join(os.path.expanduser("~"), ""),
Path("~/").expanduser()
llama4_jax/scripts/quantize_model.py
Outdated
parser = ArgumentParser()
parser.add_argument(
"--path", default="~/DeepSeek-R1-Distill-Llama-70B", required=True, help="Existing JAX model checkpoint path"
"--path", default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-70B"),
Path("~/DeepSeek-R1-Distill-Llama-70B").expanduser()
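A minimal sketch of how that suggestion could look inside an argparse declaration, so that the expanded path also shows up in the --help text. The names build_parser and expanded_path are hypothetical, and dropping required=True is my assumption: a required option never falls back to its default, so the two settings do not combine usefully.

```python
from argparse import ArgumentParser
from pathlib import Path


def expanded_path(value: str) -> Path:
    """Convert a command-line string to a Path with "~" expanded."""
    return Path(value).expanduser()


def build_parser() -> ArgumentParser:
    """Hypothetical parser factory, for illustration only."""
    parser = ArgumentParser()
    parser.add_argument(
        "--path",
        type=expanded_path,
        # Not marked required: a required option would never use its default.
        default=Path("~/DeepSeek-R1-Distill-Llama-70B").expanduser(),
        help="Existing JAX model checkpoint path (default: %(default)s)",
    )
    return parser
```

With this shape the script body receives a ready-to-use Path and does not need a second expanduser() call.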
name = "llama4_jax"
version = "0.1.0"
description = ""
description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights."
nice!
The entire model is defined in [model.py](llama4_jax/model.py) and invoked
via [main.py](main.py). Among other things, the model code demonstrates:
The entire model is defined in [__main__.py](llama4_jax/__main__.py) and invoked
Can we keep main.py separate? I think an explicit script might be more clear that it's just a starting point for a larger program rather than default behavior for the llama4 module.
@rdyro It's your call. But it is very nonstandard Python; the __main__.py properly integrates… seeing itself as within a Python module; being relocatable; and installable.
One other idea is to create a hierarchy:
.
├── jax_examples_cli
├── deepseek_r1_jax
└── llama4_jax
All packages are installable. There are no main.py files. Outside of jax_examples_cli there are no __main__.py files. The Command Line Interface finds installed packages and/or packages within a specific location (e.g., os.getcwd() or a JAX_LLM_EXAMPLES env var) indicating where to find packages compatible with jax_examples_cli/__main__.py.
Your main.py files are already very similar; it wouldn't be hard to hoist them up. Usage would be: jax_examples_cli --model <deepseek-r1-jax | llama4_jax> --ckpt_path …
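To make the proposal concrete, here is a rough sketch of what such a dispatcher could look like. Everything in it is an assumption for illustration: the run function, the MODELS tuple, and especially the convention that each model package exposes a main(ckpt_path) callable are hypothetical and not part of this repository.

```python
import importlib
from argparse import ArgumentParser
from pathlib import Path

# Hypothetical list of model packages the CLI knows how to launch.
MODELS = ("deepseek_r1_jax", "llama4_jax")


def run(argv=None):
    """Parse the command line and hand off to the selected model package."""
    parser = ArgumentParser(prog="jax_examples_cli")
    parser.add_argument(
        "--model",
        required=True,
        # Accept both "deepseek-r1-jax" and "deepseek_r1_jax" spellings;
        # argparse checks choices after applying this conversion.
        type=lambda s: s.replace("-", "_"),
        choices=MODELS,
    )
    parser.add_argument("--ckpt_path", required=True, type=Path)
    args = parser.parse_args(argv)

    # Import the installed package by name; fails if it is not installed.
    module = importlib.import_module(args.model)
    # Assumed convention: every model package provides main(ckpt_path).
    return module.main(args.ckpt_path)
```

The design choice here is that the shared CLI owns argument parsing while each model package stays a plain importable library.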
# Conflicts:
#   llama4_jax/llama4_jax/model.py
#   llama4_jax/scripts/convert_weights.py
#   llama4_jax/scripts/download_model.py
#   llama4_jax/scripts/quantize_model.py
…face-hub` to deps ; sort deps
I like the changes! One request: let's please avoid os.path entirely in favor of pathlib.
Lots of changes, so I'll need to provision some compute to test those. Give me a couple of days, thanks!
Hey, I really appreciate the desire to make this repo better, but I cannot merge this PR as is. It's introducing too many unrelated changes at once:
- import sorting in python files
- dependency sorting in pyproject files
- extra spaces for import groups separation
- renaming llama4 to llama4_jax
- renaming of main.py to __main__.py
- support for non-unix path separators
Can you please separate those into separate PRs or focus on only one of those?
license = { text = "Apache-2.0" }

dependencies = [
    "datasets",
please undo import sorting
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
revert extra spaces in imports
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P, use_mesh
from jax.experimental.shard import auto_axes, reshard
revert extra spaces in imports
import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P
revert extra spaces in imports
import re
from concurrent.futures import ThreadPoolExecutor
import math
revert extra spaces in imports
from pathlib import Path

from etils import epath
revert extra spaces in imports
model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser()
files = list(model_path.glob("**/*safetensors"))
files = list(model_path.glob(os.path.join("**", "*safetensors")))
don't mix pathlib and os.path
files = list(model_path.glob(os.path.join("**", "*safetensors")))
assert len(files) > 1
config_files = list(model_path.glob("**/config.json"))
config_files = list(model_path.glob(os.path.join("**", "config.json")))
don't mix pathlib and os.path
import dataclasses
import os
revert extra spaces in imports
def main(model_path: str | Path, ckpt_path: str | Path):
importing at top level makes for a very slow help message, import under main
Also, I know you have path = Path(path).expanduser().absolute(), but that doesn't provide nice --help text, and the path should be expanded earlier anyway. I can remove your Path(path).expanduser() if you give the go-ahead.
PS: If you merge this PR I can send you a new PR for deepseek_r1_jax.
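The request to import under main, so that --help stays fast, can be sketched as follows. This is a generic illustration rather than the repository's script: the cli function is hypothetical, and the standard-library statistics module stands in for a heavy dependency such as jax.

```python
from argparse import ArgumentParser


def main(values):
    # Deferred import: the cost is paid only when main() actually runs,
    # so printing --help never triggers it.
    import statistics  # stand-in for a heavy dependency such as jax

    return statistics.mean(values)


def cli(argv=None):
    """Hypothetical entry point: parse arguments first, do the work after."""
    parser = ArgumentParser(description="Average some numbers.")
    parser.add_argument("values", nargs="+", type=float)
    args = parser.parse_args(argv)  # --help prints and exits here
    return main(args.values)
```

The trade-off is that import errors surface only when main() runs rather than at script start.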