
Commit 795318e

[mxfp8 moe training] add mxfp8 all to all impl
1 parent e43621c commit 795318e

11 files changed (+84, -9 lines)

torchtitan/components/quantization/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -58,5 +58,5 @@ def _validate(job_config: JobConfig):
 
 # Import to register quantization modules as ModelConverter
 # (imports down here to avoid circular imports with QuantizationConverter)
-import torchtitan.components.quantization.float8  # noqa: F401
-import torchtitan.components.quantization.mx  # noqa: F401
+import torchtitan.components.quantization.float8.converters  # noqa: F401
+import torchtitan.components.quantization.mx.converters  # noqa: F401
Lines changed: 5 additions & 0 deletions (new file)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
torchtitan/components/quantization/float8.py renamed to torchtitan/components/quantization/float8/converters.py

Lines changed: 1 addition & 2 deletions

@@ -11,6 +11,7 @@
     FP8_GROUP_ALIGNMENT_SIZE,
     QuantizationConverter,
 )
+from torchtitan.components.quantization.utils import module_filter_fn
 
 from torchtitan.config.job_config import Float8Linear, JobConfig
 from torchtitan.distributed import ParallelDims
@@ -19,8 +20,6 @@
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import has_cuda_capability
 
-from .utils import module_filter_fn
-
 AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"
 
Lines changed: 5 additions & 0 deletions (new file)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.

torchtitan/components/quantization/mx.py renamed to torchtitan/components/quantization/mx/converters.py

Lines changed: 1 addition & 2 deletions

@@ -13,6 +13,7 @@
     MXFP8_GROUP_ALIGNMENT_SIZE,
     QuantizationConverter,
 )
+from torchtitan.components.quantization.utils import module_filter_fn
 
 from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import ParallelDims
@@ -21,8 +22,6 @@
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import has_cuda_capability
 
-from .utils import module_filter_fn
-
 
 class MXLinearConverter(QuantizationConverter):
     """Converts the linear layers of `model` to `MXLinear`."""
Lines changed: 22 additions & 0 deletions (new file)

@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.distributed.expert_parallel import ExpertParallel
+
+
+class MXExpertParallel(ExpertParallel):
+    def __init__(self) -> None:
+        super().__init__()
+        try:
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                to_mxfp8_a2a_dequant,
+            )
+        except ImportError as err:
+            raise ImportError(
+                "Please install torchao v0.14+ to use MXExpertParallel"
+            ) from err
+        self._a2a_dispatch_impl = to_mxfp8_a2a_dequant
+        self._a2a_combine_impl = to_mxfp8_a2a_dequant
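The new plan does not change routing; it only rebinds the two all-to-all callables that the base ExpertParallel plan exposes (see the expert_parallel.py diff below). A minimal usage sketch, not part of the commit, assuming torchao v0.14+ is installed and using the module path under which the llama4 parallelize diff imports the class:

# Minimal sketch (illustration only, not from the commit).
from torchtitan.distributed.expert_parallel import ExpertParallel
from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel

default_plan = ExpertParallel()  # dispatch/combine use all_to_all_single_autograd
mxfp8_plan = MXExpertParallel()  # both steps use torchao's to_mxfp8_a2a_dequant,
                                 # which quantizes to MXFP8, exchanges data and
                                 # scales, then dequantizes back to original precision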

torchtitan/config/job_config.py

Lines changed: 18 additions & 0 deletions

@@ -752,6 +752,24 @@ class Quantize:
     grouped_mm: QuantizedGroupedMM = field(default_factory=QuantizedGroupedMM)
     """Quantized training config for grouped GEMMs"""
 
+    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
+    """
+    All-to-all implementation to use for the token dispatch step in expert parallelism.
+    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
+    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
+    using all_to_all_single on the quantized data and scales, then dequantizing
+    the outputs back to original precision.
+    """
+
+    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
+    """
+    All-to-all implementation to use for the token combine step in expert parallelism.
+    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
+    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
+    using all_to_all_single on the quantized data and scales, then dequantizing
+    the outputs back to original precision.
+    """
+
 
 @dataclass
 class Comm:
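The docstrings above describe the two new knobs; below is a minimal sketch, not part of the commit, of reading and overriding them on a JobConfig instance. The attribute path job_config.quantize.* matches how apply_moe_ep_tp reads the fields in the parallelize diff further down; constructing JobConfig() with no arguments is an assumption made for illustration.

# Minimal sketch (illustration only, not from the commit).
from torchtitan.config.job_config import JobConfig

job_config = JobConfig()  # assumes every field has a default
assert job_config.quantize.expert_parallel_a2a_dispatch_impl == "default"

# Route both the token dispatch and token combine all-to-alls through MXFP8:
job_config.quantize.expert_parallel_a2a_dispatch_impl = "mxfp8"
job_config.quantize.expert_parallel_a2a_combine_impl = "mxfp8"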

torchtitan/distributed/expert_parallel.py

Lines changed: 4 additions & 2 deletions

@@ -71,6 +71,8 @@ def __init__(self):
         self.output_splits = None
         self.input_shape = None
         self.permuted_indices = None
+        self._a2a_dispatch_impl = all_to_all_single_autograd
+        self._a2a_combine_impl = all_to_all_single_autograd
 
     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
@@ -107,7 +109,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         self.output_splits = output_splits.tolist()
 
         # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input = self._a2a_dispatch_impl(
             routed_input,
             self.output_splits,
             self.input_splits,
@@ -150,7 +152,7 @@ def _token_combine(self, mod, routed_output, device_mesh):
             routed_output, self.input_shape, self.permuted_indices
         )
 
-        routed_output = all_to_all_single_autograd(
+        routed_output = self._a2a_combine_impl(
             routed_output,
             self.input_splits,
             self.output_splits,

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions

@@ -91,6 +91,7 @@ def parallelize_deepseekv3(
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
+            job_config,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
             ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
             ep_tp_mesh=(

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 24 additions & 1 deletion

@@ -17,6 +17,7 @@
     RowwiseParallel,
     SequenceParallel,
 )
+from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.config.job_config import Compile as CompileConfig
 from torchtitan.distributed import NoParallel, ParallelDims
@@ -98,6 +99,7 @@ def parallelize_llama(
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
+            job_config,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
             ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
             ep_tp_mesh=(
@@ -436,13 +438,34 @@ def apply_fsdp(
 
 
 def apply_moe_ep_tp(
     model: nn.Module,
+    job_config: JobConfig,
     tp_mesh: DeviceMesh | None,
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
 ):
     assert ep_mesh is not None or tp_mesh is not None
 
+    EP_IMPLS = {
+        "default": ExpertParallel,
+        "mxfp8": MXExpertParallel,
+    }
+    assert (
+        job_config.quantize.expert_parallel_a2a_dispatch_impl in EP_IMPLS
+    ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_dispatch_impl}, must be one of {EP_IMPLS.keys()}"
+    assert (
+        job_config.quantize.expert_parallel_a2a_combine_impl in EP_IMPLS
+    ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_combine_impl}, must be one of {EP_IMPLS.keys()}"
+
+    logger.info(
+        f"Using all-to-all dispatch implementation: {job_config.quantize.expert_parallel_a2a_dispatch_impl}"
+    )
+    logger.info(
+        f"Using all-to-all combine implementation: {job_config.quantize.expert_parallel_a2a_combine_impl}"
+    )
+
+    ep_class = EP_IMPLS[job_config.quantize.expert_parallel_a2a_dispatch_impl]
+
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
             continue
@@ -491,7 +514,7 @@ def apply_moe_ep_tp(
         elif tp_mesh is None or not etp_enabled:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel()
+            experts_plan = ep_class()
         else:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel()
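Because apply_moe_ep_tp resolves the plan class through the EP_IMPLS mapping and ExpertParallel now exposes the _a2a_dispatch_impl / _a2a_combine_impl attributes, other quantized all-to-all variants could be slotted in the same way. A hypothetical sketch, not part of the commit: a plan that quantizes only the dispatch step and leaves the combine step in original precision.

# Hypothetical sketch (illustration only, not from the commit).
from torchtitan.distributed.expert_parallel import ExpertParallel


class MXDispatchOnlyExpertParallel(ExpertParallel):
    def __init__(self) -> None:
        super().__init__()  # binds both impls to all_to_all_single_autograd
        from torchao.prototype.moe_training.kernels.mxfp8.comms import (
            to_mxfp8_a2a_dequant,
        )
        self._a2a_dispatch_impl = to_mxfp8_a2a_dequant
        # _a2a_combine_impl is left as all_to_all_single_autograd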
