diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py
index e495825de..7957771cd 100644
--- a/flashinfer/comm/mnnvl.py
+++ b/flashinfer/comm/mnnvl.py
@@ -16,9 +16,11 @@
 import ctypes
 import logging
 import os
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 import platform
 import sys
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TYPE_CHECKING
 
 import torch
 from cuda import cuda
@@ -129,6 +131,22 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
     return device_ptr
 
 
+class CommBackend(ABC):
+    """Abstract communication backend interface"""
+
+    @abstractmethod
+    def Get_rank(self) -> int: ...
+
+    @abstractmethod
+    def Get_size(self) -> int: ...
+
+    @abstractmethod
+    def allgather(self, data: int) -> List[int]: ...
+
+    @abstractmethod
+    def Split(self, color: int, key: int) -> "CommBackend": ...
+
+
 if IS_BUILDING_DOCS:
     # Mock classes for building docs
@@ -208,18 +226,66 @@ def supports_mnnvl() -> bool:
 else:
     import pynvml
-    from mpi4py import MPI
+
+    if TYPE_CHECKING:
+        from mpi4py import MPI  # noqa: F401
+
+    def lazy_import_mpi():
+        """Lazy import for mpi4py"""
+        try:
+            from mpi4py import MPI
+
+            return MPI
+        except ImportError as err:
+            raise ImportError("mpi4py is not installed") from err  # type: ignore[no-redef]
 
     class MpiComm:  # type: ignore[no-redef]
-        _comm: MPI.Intracomm = MPI.COMM_WORLD
+        _comm: Any = None
+        _MPI: Any = None
 
         @classmethod
-        def set_mpi_comm(cls, new_comm: MPI.Intracomm):
+        def _get_mpi(cls):
+            if cls._MPI is None:
+                cls._MPI = lazy_import_mpi()
+                cls._comm = cls._MPI.COMM_WORLD
+            return cls._MPI
+
+        @classmethod
+        def set_mpi_comm(cls, new_comm: Any):
+            cls._get_mpi()
             cls._comm = new_comm
 
         def __getattr__(self, name):
+            if self._comm is None:
+                self._get_mpi()
             return getattr(self._comm, name)
 
+    class MPIBackend(CommBackend):
+        def __init__(self):
+            self._mpicomm = MpiComm()
+
+        def Get_rank(self) -> int:
+            return self._mpicomm.Get_rank()
+
+        def Get_size(self) -> int:
+            return self._mpicomm.Get_size()
+
+        def allgather(self, data: int) -> List[int]:
+            return self._mpicomm.allgather(data)
+
+        def Split(self, color: int, key: int) -> CommBackend:
+            new_backend = MPIBackend()
+            new_backend._mpicomm = self._mpicomm.Split(color, key)
+            return new_backend  # new adapter wrapping the split communicator
+
+    @dataclass
+    class MnnvlConfig:
+        """Configuration for MNNVL memory management"""
+
+        comm_backend: Optional[CommBackend] = None
+        allocation_granularity: int = 0
+        fabric_page_size: int = 1 << 29  # 512MB
+
     class MnnvlMemory:  # type: ignore[no-redef]
         initialized: bool = False
@@ -234,13 +300,15 @@ class MnnvlMemory:  # type: ignore[no-redef]
         fabric_page_size: int = 1 << 29
 
         # MPI communicator
-        comm = None
+        comm: Optional[CommBackend] = None
 
         dev_id: int = None
 
         allocated_map: Dict[int, Any] = {}
         address_refcnt: Dict[int, Any] = {}
 
+        config: Optional[MnnvlConfig] = None
+
         def __init__(self, mapping: Mapping, size: int):
             self.mapping = mapping
             self.segment_size = size
@@ -275,6 +343,14 @@ def initialize():
                     pynvml.nvmlInit()
                 MnnvlMemory.initialized = True
 
+        @staticmethod
+        def set_comm_from_config(mapping: Mapping, config: Optional[MnnvlConfig] = None):
+            MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
+            comm = MnnvlMemory.config.comm_backend.Split(
+                mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
+            )
+            MnnvlMemory.comm = comm  # type: ignore[assignment]
+
         @staticmethod
         def get_comm(mapping: Mapping):
             if MnnvlMemory.comm is not None:
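A minimal usage sketch of the config/backend API added above, not part of the patch: the launch details (one process per GPU under an mpirun-style launcher, mpi4py installed, MNNVL-capable hardware) and the workspace size are assumptions. With no comm_backend supplied, set_comm_from_config falls back to the MPIBackend adapter, which lazily imports mpi4py and splits COMM_WORLD into the TP group.

    import torch
    from mpi4py import MPI

    from flashinfer.comm.mapping import Mapping
    from flashinfer.comm.mnnvl import MnnvlMemory

    # Assumed launch: mpirun with one rank per GPU on MNNVL-capable nodes.
    rank = MPI.COMM_WORLD.Get_rank()
    world_size = MPI.COMM_WORLD.Get_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    mapping = Mapping(world_size, rank, world_size, tp_size=world_size)
    MnnvlMemory.initialize()
    # No config given: the default MPIBackend-backed MnnvlConfig is used.
    MnnvlMemory.set_comm_from_config(mapping)
    workspace = MnnvlMemory(mapping, 2 * 1024 * 1024)  # 2 MiB per rank (arbitrary size)
    tensor = workspace.as_torch_strided_tensor(torch.int32)
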
diff --git a/flashinfer/comm/trtllm_alltoall.py b/flashinfer/comm/trtllm_alltoall.py
index 595a84990..3af2de1c2 100644
--- a/flashinfer/comm/trtllm_alltoall.py
+++ b/flashinfer/comm/trtllm_alltoall.py
@@ -26,7 +26,7 @@
 from ..jit import gen_jit_spec
 from ..utils import register_custom_op
 from .mapping import Mapping
-from .mnnvl import MnnvlMemory
+from .mnnvl import MnnvlConfig, MnnvlMemory
 
 
 def gen_comm_alltoall_module() -> JitSpec:
@@ -296,13 +296,15 @@ class MnnvlMoe:
     moe_mapping: Mapping = None
 
     @staticmethod
-    def get_moe_workspaces(mapping: Mapping):
+    def get_moe_workspaces(mapping: Mapping, config: Optional[MnnvlConfig] = None):
         if MnnvlMoe.moe_workspace is not None:
             assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
             return MnnvlMoe.moe_workspace_tensor
 
         MnnvlMoe.moe_mapping = mapping
         workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size)
+        if config:
+            MnnvlMemory.set_comm_from_config(mapping, config)  # type: ignore[attr-defined]
         MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
         MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
             torch.uint64
diff --git a/tests/test_mnnvl_custom_comm.py b/tests/test_mnnvl_custom_comm.py
new file mode 100644
index 000000000..0c7bb25da
--- /dev/null
+++ b/tests/test_mnnvl_custom_comm.py
@@ -0,0 +1,183 @@
+import multiprocessing as mp
+import socket
+from typing import Any
+
+import pytest
+import torch
+import torch.distributed as dist
+
+import pynvml
+
+from flashinfer.comm.mapping import Mapping
+from flashinfer.comm.mnnvl import CommBackend, MnnvlConfig, MnnvlMemory
+
+
+pynvml.nvmlInit()
+
+
+class CustomCommunicator(CommBackend):
+    def __init__(self, group):
+        self._group = group
+
+    def Get_rank(self) -> int:
+        return dist.get_rank(self._group)
+
+    def Get_size(self) -> int:
+        return dist.get_world_size(self._group)
+
+    def allgather(self, data: int | bytes):
+        device = f"cuda:{torch.cuda.current_device()}"
+        if isinstance(data, int):
+            local_tensor = torch.tensor([data], device=device, dtype=torch.int32)
+            world_size = self.Get_size()
+            gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
+
+            dist.all_gather(gathered, local_tensor, group=self._group)
+            return [int(x.item()) for x in gathered]
+
+        elif isinstance(data, bytes):
+            gathered = [None] * self.Get_size()
+            dist.all_gather_object(gathered, data, group=self._group)
+            return gathered
+        else:
+            raise TypeError(f"Unsupported type for allgather: {type(data)}")
+
+    def Split(self, color: int, key: int) -> "CustomCommunicator":
+        return self
+
+
+def get_open_port() -> int:
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("127.0.0.1", 0))
+            return s.getsockname()[1]
+    except OSError:
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("::1", 0))
+            return s.getsockname()[1]
+
+
+def multi_process_parallel(
+    world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
+) -> None:
+    mp.set_start_method("spawn", force=True)
+
+    procs = []
+    distributed_init_port = get_open_port()
+    for i in range(world_size):
+        proc_args = (world_size, i, dtype, distributed_init_port) + target_args
+        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
+        proc.start()
+        procs.append(proc)
+
+    for i in range(world_size):
+        procs[i].join()
+        assert procs[i].exitcode == 0, (
+ f"Process {i} failed with exit code {procs[i].exitcode}" + ) + + +def align_memory(size: int): + align_size = 2 * 1024 * 1024 + return (size + align_size - 1) // align_size * align_size + + +def _init_mnnvl_memory(world_size, rank, dtype, distributed_init_port): + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + group = dist.group.WORLD + + torch.cuda.set_device(rank) + MnnvlMemory.initialize() + mapping = Mapping(world_size, rank, world_size, tp_size=world_size) + + allocate0_size = 4 * 1024 * 1024 - 3 * 1024 + mnnvl_config = MnnvlConfig( + comm_backend=CustomCommunicator(group), + fabric_page_size=1 << 29, # 512MB + allocation_granularity=0, # Auto-detect + ) + MnnvlMemory.set_comm_from_config(mapping, mnnvl_config) + mnnvl_memory0 = MnnvlMemory(mapping, allocate0_size) + allocate0_size_aligned = align_memory(allocate0_size) + + assert MnnvlMemory.current_mem_offset == allocate0_size_aligned + tensor0 = mnnvl_memory0.as_torch_strided_tensor(torch.int32) + numel_per_rank = allocate0_size // 4 + tensor0[(rank + 1) % world_size] = torch.arange( + start=rank, end=rank + numel_per_rank, device="cuda" + ) + dist.barrier(group=group) + for r in range(world_size): + torch.equal( + tensor0[(r + 1) % world_size], + torch.arange(start=r, end=r + numel_per_rank, device="cuda"), + ) + + allocate1_size = 30 * 1024 * 1024 - 2 * 1024 + mnnvl_memory1 = MnnvlMemory(mapping, allocate1_size) + allocate1_size_aligned = align_memory(allocate1_size) + assert ( + MnnvlMemory.current_mem_offset + == allocate0_size_aligned + allocate1_size_aligned + ) + tensor1 = mnnvl_memory1.as_torch_strided_tensor(torch.float32) + numel_per_rank = allocate1_size // 4 + tensor1[(rank + 5) % world_size] = torch.arange( + start=rank, + end=rank + numel_per_rank, + dtype=torch.float32, + device="cuda", + ) + dist.barrier(group=group) + for r in range(world_size): + torch.equal( + tensor1[(r + 5) % world_size], + torch.arange( + start=r, end=r + numel_per_rank, dtype=torch.float32, device="cuda" + ), + ) + dist.barrier(group=group) + del tensor0, mnnvl_memory0 + dist.barrier(group=group) + + large_allocation2_size = 768 * 1024 * 1024 + large_mnnvl_memory2 = MnnvlMemory(mapping, large_allocation2_size) + allocate2_size_aligned = align_memory(large_allocation2_size) + assert MnnvlMemory.current_mem_offset == allocate2_size_aligned + assert large_mnnvl_memory2.rank_stride == (1 << 30) + + del tensor1 + + +@pytest.mark.skipif( + not MnnvlMemory.supports_mnnvl(), + reason="Mnnvl memory is not supported on this platform", +) +@pytest.mark.parametrize("world_size", [2, 4]) +def test_mnnvl_custom_communicator(world_size): + dtype = torch.float16 + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + raise ValueError( + f"world_size {world_size} is greater than available_gpus {available_gpus}" + ) + print(f"Running test for world_size={world_size}") + + multi_process_parallel( + world_size, + dtype, + _init_mnnvl_memory, + target_args=(), + ) + print(f"custom mnnvl communicator world_size = {world_size}: OK")