Skip to content

Conversation

SamuelMarks
Copy link

@SamuelMarks SamuelMarks commented Apr 17, 2025

[llama4 -> llama4_jax] Refactor to be a proper installable Python package ; [llama4_jax/pyproject.toml] Add missing dependency ; [llama4_jax/README.md] Document new usage

Also I know you have path = Path(path).expanduser().absolute() but that doesn't provide nice --help text and should be expanded earlier anyway. I can remove your Path(path).expanduser() if you give the go-ahead.

PS: If you merge this PR I can send you a new PR for deepseek_r1_jax

…kage ; [llama4_jax/pyproject.toml] Add missing dependency ; [llama4_jax/README.md] Document new usage
Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR! I left some comments

additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
for additional_file in additional_files:
full_paths = list(model_path.glob(f"**/{additional_file}"))
full_paths = list(model_path.glob(os.path.join("**", additional_file)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we stick with pathlib.Path throughout here, it largely replaces os.path.*

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After 11 years maybe I can rely on it being available!

os.path.join is actually used by it internally IIRC; but sure can move to pathlib for my jax-ml/jax-llm-examples contributions.

"--dest-root-path",
required=True,
default="~/",
default=os.path.join(os.path.expanduser("~"), ""),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Path("~/").expanduser()

parser = ArgumentParser()
parser.add_argument(
"--path", default="~/DeepSeek-R1-Distill-Llama-70B", required=True, help="Existing JAX model checkpoint path"
"--path", default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-70B"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Path("~/DeepSeek-R1-Distill-Llama-70B").expanduser()

name = "llama4_jax"
version = "0.1.0"
description = ""
description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!


The entire model is defined in [model.py](llama4_jax/model.py) and invoked
via [main.py](main.py). Among other things, the model code demonstrates:
The entire model is defined in [__main__.py](llama4_jax/__main__.py) and invoked
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep main.py separate, I think an explicit script might be more clear that it's just a starting point for a larger program rather than default behavior for the llama4 module

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdyro It's your call. But it is very nonstandard Python, the __main__.py properly integrates… seeing itself as within Python module; being relocatable; and installable.

One other idea is to create a hierarchy:

.
├── jax_examples_cli
├── deepseek_r1_jax
└── llama4_jax

All packages are installable. There are no main.pys. Outside of jax_examples_cli there are no __main__.pys. In the Command Line Interface it finds installed packages and/or packages within a specific location (e.g., os.getcwd() or JAX_LLM_EXAMPLES env var); indicating where to find packages compatible with the jax_examples_cli/__main__.py.

Already your main.pys are very similar, it wouldn't be hard to hoist them up. Usage would be: jax_examples_cli --model <deepseek-r1-jax | llama4_jax> --ckpt_path …

# Conflicts:
#	llama4_jax/llama4_jax/model.py
#	llama4_jax/scripts/convert_weights.py
#	llama4_jax/scripts/download_model.py
#	llama4_jax/scripts/quantize_model.py
@rdyro
Copy link
Collaborator

rdyro commented Jun 24, 2025

I like the changes! One request, let's please avoid os.path entirely in favor of pathlib.Path, I prefer not to mix them.

Lots of changes, so I'll need to provision some compute to test those, give me a couple of days, thanks!

Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, I really appreciate the desire to make this repo better, but I cannot merge this PR as is. It's introducing too many unrelated changes at once:

  • import sorting in python files
  • dependency sorting in pyproject files
  • extra spaces for import groups separation
  • renaming llama4 to llama4_jax
  • renaming of main.py to __main__.py
  • support for non-unix path separators

Can you please separate those into separate PRs or focus on only one of those?

license = { text = "Apache-2.0" }

dependencies = [
"datasets",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please undo import sorting

from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert extra spaces in imports

from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P, use_mesh
from jax.experimental.shard import auto_axes, reshard

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert extra spaces in imports

import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert extra spaces in imports

import re
from concurrent.futures import ThreadPoolExecutor
import math

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert extra spaces in imports

from pathlib import Path

from etils import epath

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert extra spaces in imports


model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser()
files = list(model_path.glob("**/*safetensors"))
files = list(model_path.glob(os.path.join("**", "*safetensors")))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't mix pathlib and os.path

files = list(model_path.glob(os.path.join("**", "*safetensors")))
assert len(files) > 1
config_files = list(model_path.glob("**/config.json"))
config_files = list(model_path.glob(os.path.join("**", "config.json")))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't mix pathlib and os.path


import dataclasses
import os

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert extra spaces in imports


def main(model_path: str | Path, ckpt_path: str | Path):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

importing at top level makes for a very slow help message, import under main

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants