
Commit d880de2

Graduate qwen3 from experiment to core (#1860)
As titled. Added CI tests and fixed a minor TP issue after adding attention_mask.
1 parent 304dfc3 commit d880de2

13 files changed: +48 −8 lines changed

tests/integration_tests/models.py

Lines changed: 28 additions & 0 deletions
@@ -76,6 +76,34 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
             "pp+fsdp+tp+ep+etp",
             ngpu=8,
         ),
+        # Integration Test Cases for Qwen3 dense and MoE model
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name qwen3",
+                    "--parallelism.data_parallel_shard_degree 2",
+                    "--parallelism.tensor_parallel_degree 2",
+                ],
+            ],
+            "Qwen3 FSDP+TP",
+            "qwen3_fsdp+tp",
+            ngpu=4,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name qwen3",
+                    "--model.flavor debugmodel_moe",
+                    "--parallelism.data_parallel_shard_degree 2",
+                    "--parallelism.tensor_parallel_degree 2",
+                    "--parallelism.expert_parallel_degree 2",
+                    "--parallelism.expert_tensor_parallel_degree 2",
+                ],
+            ],
+            "Qwen3 FSDP+TP+EP+ETP",
+            "qwen3_fsdp+tp+ep+etp",
+            ngpu=4,
+        ),
     ]
 
     return model_tests
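
For context, each new entry follows the same shape as the existing integration tests: a list of CLI override lists (one per training run), a human-readable description, a short test name, and the number of GPUs the test needs. Below is a minimal sketch of how such an entry might be consumed; the `OverrideDefinitions` field names and the `run_test` helper are assumptions for illustration, not the repository's exact API.

# Hypothetical sketch of the test-entry shape and how a runner might consume it.
# Field names and run_test are illustrative assumptions, not torchtitan's actual code.
import subprocess
from dataclasses import dataclass


@dataclass
class OverrideDefinitions:
    override_args: list[list[str]]  # one list of CLI overrides per training run
    test_descr: str                 # human-readable description, e.g. "Qwen3 FSDP+TP"
    test_name: str                  # short name used for output folders, e.g. "qwen3_fsdp+tp"
    ngpu: int = 8                   # GPUs required by the test


def run_test(test: OverrideDefinitions, base_cmd: list[str]) -> None:
    # Launch one training run per override list, appending the overrides to the base command.
    for overrides in test.override_args:
        cmd = base_cmd + " ".join(overrides).split()
        print(f"[{test.test_name}] {test.test_descr} on {test.ngpu} GPUs")
        subprocess.run(cmd, check=True)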

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@
 # LICENSE file in the root directory of this source tree.
 
 _supported_experiments = frozenset(
-    ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
+    ["flux", "llama4", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
 )

torchtitan/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-_supported_models = frozenset(["llama3", "llama3_ft", "deepseek_v3"])
+_supported_models = frozenset(["llama3", "llama3_ft", "deepseek_v3", "qwen3"])
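
Taken together, the two frozensets act as allowlists: dropping "qwen3" from `_supported_experiments` and adding it to `_supported_models` is what makes `--model.name qwen3` resolve to the core model package instead of the experimental one. A rough sketch of how such an allowlist check might gate imports is below; the `register_model_by_name` helper and error text are illustrative assumptions, not the repository's actual code.

# Illustrative allowlist check, assuming models are imported dynamically by name.
import importlib

_supported_models = frozenset(["llama3", "llama3_ft", "deepseek_v3", "qwen3"])


def register_model_by_name(name: str):
    # Hypothetical helper: only names in the allowlist resolve to a core model package.
    if name not in _supported_models:
        raise ValueError(
            f"Model {name!r} is not a core model; choose from {sorted(_supported_models)}"
        )
    return importlib.import_module(f"torchtitan.models.{name}")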

torchtitan/experiments/qwen3/README.md renamed to torchtitan/models/qwen3/README.md

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ eg, for Qwen3 0.6B model, the HF repo name is `Qwen/Qwen3-0.6B`. For 1.7B model,
 ## To be added
 - Modeling
   - CP is not supported currently because of RoPE embedding implementation details.
-  - `StateDictAdapter` support for MoE model
 
 - Testing
   - Learning rate verifying: verify learning rate and schedule with real training jobs (eg, 3k stps), or find official references.

torchtitan/experiments/qwen3/__init__.py renamed to torchtitan/models/qwen3/__init__.py

Lines changed: 16 additions & 3 deletions
@@ -30,6 +30,19 @@
 # Adding different variants of the model
 
 qwen3_configs = {
+    "debugmodel": Qwen3ModelArgs(
+        vocab_size=2048,
+        max_seq_len=4096,
+        head_dim=128,
+        dim=256,
+        n_layers=8,
+        n_heads=16,
+        n_kv_heads=8,
+        qk_norm=True,
+        hidden_dim=3072,
+        rope_theta=1000000,
+        enable_weight_tying=True,
+    ),
     "0.6B": Qwen3ModelArgs(
         vocab_size=151936,
         max_seq_len=4096,
@@ -107,11 +120,11 @@
     ),
     # Qwen3-MoE models
     "debugmodel_moe": Qwen3ModelArgs(
-        vocab_size=151936,
+        vocab_size=2048,
         max_seq_len=4096,
         head_dim=128,
-        dim=1024,
-        n_layers=28,
+        dim=256,
+        n_layers=8,
         n_heads=16,
         n_kv_heads=8,
         qk_norm=True,
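
The new "debugmodel" flavor and the shrunken "debugmodel_moe" flavor (vocab_size=2048, dim=256, n_layers=8) keep the 4-GPU FSDP+TP and FSDP+TP+EP+ETP integration tests above small and fast, while the real 0.6B+ flavors keep the full 151936-token vocabulary. A minimal sketch of how a `--model.flavor` value might select one of these configs is below; the stand-in `Qwen3ModelArgs` fields and the `get_model_args` helper are assumptions for illustration.

# Minimal sketch of flavor selection; the dataclass below is a stand-in with only a
# few of the fields shown in the diff, and get_model_args is a hypothetical helper.
from dataclasses import dataclass


@dataclass
class Qwen3ModelArgs:
    vocab_size: int = 2048
    dim: int = 256
    n_layers: int = 8
    moe_enabled: bool = False  # hypothetical flag distinguishing dense vs. MoE flavors


qwen3_configs = {
    "debugmodel": Qwen3ModelArgs(),
    "debugmodel_moe": Qwen3ModelArgs(moe_enabled=True),
}


def get_model_args(flavor: str) -> Qwen3ModelArgs:
    # e.g. flavor comes from --model.flavor on the test command line.
    if flavor not in qwen3_configs:
        raise ValueError(f"Unknown Qwen3 flavor {flavor!r}; available: {sorted(qwen3_configs)}")
    return qwen3_configs[flavor]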

torchtitan/experiments/qwen3/infra/parallelize.py renamed to torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 2 additions & 2 deletions
@@ -239,8 +239,8 @@ def apply_non_moe_tp(
     layer_plan = {
         "attention_norm": SequenceParallel(),
         "attention": prepare_module_input(
-            input_layouts=(Shard(1), Replicate()),
-            desired_input_layouts=(Replicate(), Replicate()),
+            input_layouts=(Shard(1), Replicate(), None),
+            desired_input_layouts=(Replicate(), Replicate(), None),
         ),
         "attention.wq": colwise_parallel(use_local_output=False),
         "attention.wk": colwise_parallel(use_local_output=False),
Remaining files renamed without changes.
