17 | 17 | RowwiseParallel, |
18 | 18 | SequenceParallel, |
19 | 19 | ) |
| 20 | +from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel |
20 | 21 | from torchtitan.config import JobConfig, TORCH_DTYPE_MAP |
21 | 22 | from torchtitan.config.job_config import Compile as CompileConfig |
22 | 23 | from torchtitan.distributed import NoParallel, ParallelDims |
@@ -98,6 +99,7 @@ def parallelize_llama( |
98 | 99 | if parallel_dims.tp_enabled or parallel_dims.ep_enabled: |
99 | 100 | apply_moe_ep_tp( |
100 | 101 | model, |
| 102 | + job_config, |
101 | 103 | tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, |
102 | 104 | ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, |
103 | 105 | ep_tp_mesh=( |
@@ -436,13 +438,34 @@ def apply_fsdp( |
436 | 438 |
437 | 439 | def apply_moe_ep_tp( |
438 | 440 | model: nn.Module, |
| 441 | + job_config: JobConfig, |
439 | 442 | tp_mesh: DeviceMesh | None, |
440 | 443 | ep_mesh: DeviceMesh | None, |
441 | 444 | ep_tp_mesh: DeviceMesh | None, |
442 | 445 | etp_enabled: bool, |
443 | 446 | ): |
444 | 447 | assert ep_mesh is not None or tp_mesh is not None |
445 | 448 |
| 449 | + EP_IMPLS = { |
| 450 | + "default": ExpertParallel, |
| 451 | + "mxfp8": MXExpertParallel, |
| 452 | + } |
| 453 | + assert ( |
| 454 | + job_config.quantize.expert_parallel_a2a_dispatch_impl in EP_IMPLS |
| 455 | + ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_dispatch_impl}, must be one of {EP_IMPLS.keys()}" |
| 456 | + assert ( |
| 457 | + job_config.quantize.expert_parallel_a2a_combine_impl in EP_IMPLS |
| 458 | + ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_combine_impl}, must be one of {EP_IMPLS.keys()}" |
| 459 | + |
| 460 | + logger.info( |
| 461 | + f"Using all-to-all dispatch implementation: {job_config.quantize.expert_parallel_a2a_dispatch_impl}" |
| 462 | + ) |
| 463 | + logger.info( |
| 464 | + f"Using all-to-all combine implementation: {job_config.quantize.expert_parallel_a2a_combine_impl}" |
| 465 | + ) |
| 466 | + |
| 467 | + ep_class = EP_IMPLS[job_config.quantize.expert_parallel_a2a_dispatch_impl] |
| 468 | + |
446 | 469 | for transformer_block in model.layers.values(): |
447 | 470 | if not transformer_block.moe_enabled: |
448 | 471 | continue |
@@ -491,7 +514,7 @@ def apply_moe_ep_tp( |
491 | 514 | elif tp_mesh is None or not etp_enabled: |
492 | 515 | experts_mesh = ep_mesh |
493 | 516 | # input / output sharding on the batch / tokens dim |
494 | | - experts_plan = ExpertParallel() |
| 517 | + experts_plan = ep_class() |
495 | 518 | else: |
496 | 519 | experts_mesh = ep_tp_mesh |
497 | 520 | experts_plan = ExpertTensorParallel() |
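
Note: the added logic follows a registry-plus-lookup pattern — both all-to-all implementation names are validated against the same table, and the ExpertParallel class is then keyed off the dispatch choice. The sketch below is a minimal, self-contained illustration of that pattern only; the config object and implementation classes are stand-ins, not torchtitan's actual JobConfig, ExpertParallel, or MXExpertParallel.

# Minimal sketch of the registry-based selection added in apply_moe_ep_tp.
# All names below are illustrative placeholders, not torchtitan's real API.
from types import SimpleNamespace


class DefaultA2AExpertParallel:
    """Stand-in for the default ExpertParallel all-to-all implementation."""


class MXFP8A2AExpertParallel:
    """Stand-in for the MXFP8 (MXExpertParallel) all-to-all implementation."""


EP_IMPLS = {
    "default": DefaultA2AExpertParallel,
    "mxfp8": MXFP8A2AExpertParallel,
}


def select_ep_class(quantize_cfg) -> type:
    # Validate both the dispatch and combine choices against the same registry,
    # then select the ExpertParallel class from the dispatch implementation name.
    for name in (
        quantize_cfg.expert_parallel_a2a_dispatch_impl,
        quantize_cfg.expert_parallel_a2a_combine_impl,
    ):
        if name not in EP_IMPLS:
            raise ValueError(
                f"Unknown EP impl: {name}, must be one of {list(EP_IMPLS)}"
            )
    return EP_IMPLS[quantize_cfg.expert_parallel_a2a_dispatch_impl]


if __name__ == "__main__":
    cfg = SimpleNamespace(
        expert_parallel_a2a_dispatch_impl="mxfp8",
        expert_parallel_a2a_combine_impl="mxfp8",
    )
    print(select_ep_class(cfg))  # -> MXFP8A2AExpertParallel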