Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
model: ["deepseek_r1_jax", "kimi_k2", "llama3", "llama4", "qwen3"]
model: ["deepseek_r1_jax", "kimi_k2", "llama3", "llama4", "qwen3", "gpt_oss"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Current contents include:
* [Llama 3](llama3/)
* [Qwen 3](qwen3/)
* [Kimi K2](kimi_k2/)
* [OpenAI GPT OSS](gpt_oss/)

---

Expand Down
246 changes: 133 additions & 113 deletions deepseek_r1_jax/deepseek_r1_jax/model.py

Large diffs are not rendered by default.

147 changes: 0 additions & 147 deletions deepseek_r1_jax/main.ipynb

This file was deleted.

27 changes: 21 additions & 6 deletions deepseek_r1_jax/scripts/convert_hf_r1_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

import jax
from jax.sharding import PartitionSpec as P
from argparse import ArgumentParser

from deepseek_r1_jax.model import ShardingRules, Config
from deepseek_r1_jax import chkpt_utils as utils

def main():
root_path = Path("/mnt/storage/DeepSeek-R1")
dest_path = Path("/mnt/storage/deepseek-r1-jax-chkpt")
def main(root_path, dest_path):
from deepseek_r1_jax.model import ShardingRules, Config
from deepseek_r1_jax import chkpt_utils as utils

root_path, dest_path = Path(root_path), Path(dest_path)
dest_path.mkdir(exist_ok=True, parents=True)

cfg = Config()
cfg.quantize_mlp = False
Expand All @@ -39,4 +41,17 @@ def main():
utils.convert_hf_checkpoint(params_map, root_path, dest_path, cfg)

if __name__ == "__main__":
main()
parser = ArgumentParser()
parser.add_argument(
"--source-path", default="/mnt/storage/DeepSeek-R1-weights-only", required=True, help="HF model directory path"
)
parser.add_argument(
"--dest-path",
default="~/deepseek_r1_jax",
required=True,
help="JAX model model directory (to be created).",
)
args = parser.parse_args()
main(args.source_path, args.dest_path)

main(args)
14 changes: 14 additions & 0 deletions gpt_oss/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
poetry.lock
scratch/**

projects/charformer/data/
projects/bio/data/

# Python ignores
__pycache__/
*.pyc
*.egg-info
build/**

.venv
.vscode
21 changes: 21 additions & 0 deletions gpt_oss/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Minimal OpenAI GPT OSS inference

**tl;dr: open-source OpenAI GPT OSS inference using JAX, minimal yet performant**

This model is a work in progress, but it should already work well on both TPU and GPU.

<br/>

This is a pure JAX implementation of OpenAI's GPT OSS for inference, including a
checkpoint converter for the K2 Instruct weights. on TPU.
It should work on GPU.

The entire model is defined in [model.py](gpt_oss_jax/model.py) and invoked
via [main.py](main.py).

## Quickstart

Run:
```
$ python3 main.py
```
Loading