Mnnvl memory with custom communicator #1245


Draft: wants to merge 5 commits into base: main
@@ -332,7 +332,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(

std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
CutlassGemmConfig::CandidateConfigTypeParam const config) {
#ifdef FAST_BUILD
#ifdef False //FAST_BUILD
// Fast build disables all configs except this one for SM100
return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,

@@ -835,7 +835,7 @@ struct GemmProfilerBackend {
mWType = wtype;
mOType = otype;
mNumExperts = num_experts;
mNumExpertsPerNode = num_experts / (parallelism_config.ep_size * parallelism_config.tp_size);
mNumExpertsPerNode = num_experts / (parallelism_config.ep_size);// * parallelism_config.tp_size);
mK = k;
mExpertHiddenSize = hidden_size;
mExpertInterSize = inter_size;

15 changes: 14 additions & 1 deletion flashinfer/autotuner.py

@@ -382,6 +382,9 @@ def search_cache(
cache_key = r.get_cache_key(custom_op, input_shapes, tuning_config)

if cache_key in self.profiling_cache:
# print(f"self.profiling_cache:{len(self.profiling_cache)}")
# # print("cache hit", cache_key)
# print(tuning_config)
return True, *self.profiling_cache[cache_key]

return False, 0, -1, None
@@ -452,9 +455,13 @@ def choose_one(
)
# Record the total configs to try
self.stats.tuned_op_total_configs[custom_op] = len(profiles)
# print("xxx"*20)
# print(f"profiles:{len(profiles)}")

for p in profiles:
tensors = self._prepare_input_tensors(p, inputs)
# [print(i.shape) for i in tensors]
# [print(i.dtype) for i in tensors]
is_cache_hit, runner, tactic, _ = self.search_cache(
custom_op, runners, p.get_opt_shapes(), tuning_config
)
@@ -464,17 +471,20 @@
runner, tactic = None, None
for runner_id, r in enumerate(runners):
# TODO: use FakeTensor here.
# [print(t.shape) for t in tensors]
valid_tactics = r.get_valid_tactics(tensors)
runner_arg_names = {
p.name for p in inspect.signature(r.forward).parameters.values()
}
if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
r(tensors, tactic=-1, do_preparation=True, **kwargs)
# print(f"valid_tactics: {len(valid_tactics)}")
for tac in valid_tactics:
try:
time_measured = self._profile_single_kernel(
r, tensors, tac, **kwargs
)
# print(f"time_measured: {time_measured}, {tac}")
except Exception as e:
logger.error(
f"[Autotuner]: Failed when profiling {r} {tac}, shapes={[t.size() for t in tensors]}. Error occurred: {e}"
@@ -508,13 +518,16 @@ def choose_one(
logger.debug(
f"[Autotuner]: profiling chosen runner: {runner} {tactic} for {cache_key}"
)
# print(f"[Autotuner]: profiling chosen runner: {runner} {tactic} for {cache_key}")


# Get the best runner and tactic from cache
# If no valid tactic is found, the fallback runner and tactic will be used
# print("search cache")
_, runner_id, tactic, _ = self.search_cache(
custom_op, runners, input_shapes, tuning_config
)

# print(f"returning tactic: {tactic} for {runners[runner_id]}")
return runners[runner_id], tactic

def _profile_single_kernel(

89 changes: 80 additions & 9 deletions flashinfer/comm/mnnvl.py

@@ -14,6 +14,8 @@
# limitations under the License.
# Code imported from TensorRT-LLM/tensorrt_llm/_mnnvl_utils.py
import ctypes
from typing import Optional, Dict, Tuple, List
from abc import ABC, abstractmethod
import logging
import platform
import sys
Expand All @@ -23,7 +25,7 @@
import pynvml
import torch
from cuda import cuda
from mpi4py import MPI
# from mpi4py import MPI

from ..cuda_utils import checkCudaErrors
from .dlpack_utils import create_dlpack_capsule, pack_strided_memory
@@ -111,6 +113,7 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool:


class MpiComm:
from mpi4py import MPI
_comm: MPI.Intracomm = MPI.COMM_WORLD

@classmethod
@@ -120,8 +123,55 @@ def set_mpi_comm(cls, new_comm: MPI.Intracomm):
def __getattr__(self, name):
return getattr(self._comm, name)

class CommBackend(ABC):
"""Abstract communication backend interface"""
@abstractmethod
def Get_rank(self) -> int: ...

@abstractmethod
def Get_size(self) -> int: ...

@abstractmethod
def allgather(self, data: int) -> List[int]: ...

@abstractmethod
def allgather_bytes(self, data): ...

@abstractmethod
def Split(self, color: int, key: int) -> 'CommBackend': ...

class LegacyMPIBackend(CommBackend):
"""Adapter for the original MpiComm singleton pattern"""
def __init__(self):
self._mpicomm = MpiComm()

def Get_rank(self) -> int:
return self._mpicomm.Get_rank()

def Get_size(self) -> int:
return self._mpicomm.Get_size()

def allgather(self, data: int) -> List[int]:
return self._mpicomm.allgather(data)

def allgather_bytes(self, data):
return self._mpicomm.allgather(data)

def Split(self, color: int, key: int) -> CommBackend:
# Original split logic
new_comm = self._mpicomm.Split(color, key)
return LegacyMPIBackend() # Returns new adapter

@dataclass
class MnnvlConfig:
"""Configuration for MNNVL memory management"""
comm_backend: Optional[CommBackend] = None
allocation_granularity: int = 0
fabric_page_size: int = 1 << 29 # 512MB


class MnnvlMemory:
_config: MnnvlConfig = MnnvlConfig(comm_backend=LegacyMPIBackend()) # Default to legacy MPI

initialized: bool = False

current_mem_offset: int = 0
Expand All @@ -145,8 +195,9 @@ class MnnvlMemory:
def __init__(self, mapping: Mapping, size: int):
self.mapping = mapping
self.segment_size = size
# self._config = config or MnnvlConfig(comm_backend=LegacyMPIBackend())
self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(self.mapping, size)

def __del__(self):
if not sys.is_finalizing():
MnnvlMemory.close_mnnvl_memory(self.ptr)
@@ -174,16 +225,33 @@ def initialize():
pynvml.nvmlInit()
MnnvlMemory.initialized = True

@staticmethod
def set_comm(config: MnnvlConfig = None):
MnnvlMemory._config = config or MnnvlConfig(comm_backend=LegacyMPIBackend())

@staticmethod
def get_comm(mapping: Mapping):
if MnnvlMemory.comm is not None:
return MnnvlMemory.comm
comm = MpiComm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
"""Modified to work with configurable backends"""
# If using legacy MPI path (original behavior)
if isinstance(MnnvlMemory._config.comm_backend, LegacyMPIBackend):
if MnnvlMemory.comm is not None:
return MnnvlMemory.comm
comm = MpiComm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
mapping.tp_rank
)
# New backend-aware path
else:
print(MnnvlMemory._config)
backend = MnnvlMemory._config.comm_backend
comm = backend.Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
mapping.tp_rank
)
MnnvlMemory.comm = comm
return comm


@staticmethod
def get_allocation_prop(dev_id: int):
location = cuda.CUmemLocation()
Expand Down Expand Up @@ -249,7 +317,6 @@ def open_mnnvl_memory(mapping: Mapping, size: int):
), "Not all rank allocating same size."
granularity = MnnvlMemory.get_allocation_granularity(dev_id)
aligned_size = (size + granularity - 1) // granularity * granularity

if (
MnnvlMemory.current_mem_offset + aligned_size
> MnnvlMemory.current_rank_stride
Expand All @@ -272,7 +339,11 @@ def open_mnnvl_memory(mapping: Mapping, size: int):
0,
)
)
all_handles_data = comm.allgather(exported_fabric_handle.data)
print(f"cccccccccccccccccc : {exported_fabric_handle.data}")
# all_handles_data = comm.allgather(exported_fabric_handle.data)
all_handles_data = comm.allgather_bytes(exported_fabric_handle.data)
print(f"passssss : {all_handles_data}")
# all_handles_data = comm.allgather(exported_fabric_handle.data)
# all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # noqa: E501
# can use buf = memoryview(data) to import if using plain buffer for data.

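For reviewers, here is a minimal sketch of what a non-MPI communicator could look like behind the `CommBackend` interface added above, assuming `torch.distributed` has already been initialized by the application. The `TorchDistBackend` class and its method bodies are illustrative only and are not part of this PR; only `CommBackend`, `MnnvlConfig`, and `MnnvlMemory.set_comm` come from this diff.

```python
# Illustrative sketch, not part of this PR: a CommBackend built on
# torch.distributed instead of mpi4py. Assumes dist.init_process_group()
# has already been called and that this backend wraps the default group,
# so gathered ranks are global ranks as required by dist.new_group().
from typing import List

import torch.distributed as dist

from flashinfer.comm.mnnvl import CommBackend  # interface added in this diff


class TorchDistBackend(CommBackend):
    def __init__(self, group=None):
        self._group = group  # None -> default (world) process group

    def Get_rank(self) -> int:
        return dist.get_rank(group=self._group)

    def Get_size(self) -> int:
        return dist.get_world_size(group=self._group)

    def allgather(self, data: int) -> List[int]:
        out = [None] * self.Get_size()
        dist.all_gather_object(out, data, group=self._group)
        return out

    def allgather_bytes(self, data):
        # all_gather_object pickles arbitrary Python objects, so the raw
        # fabric-handle bytes round-trip without an MPI-style buffer protocol.
        out = [None] * self.Get_size()
        dist.all_gather_object(out, bytes(data), group=self._group)
        return out

    def Split(self, color: int, key: int) -> "CommBackend":
        # Emulate MPI_Comm_split: gather (color, key, rank) from every rank,
        # then build one subgroup per color. dist.new_group() must be entered
        # by every rank for every subgroup, hence the loop over all colors.
        info = [None] * self.Get_size()
        dist.all_gather_object(info, (color, key, self.Get_rank()), group=self._group)
        my_group = None
        for c in sorted({entry[0] for entry in info}):
            ranks = [r for cc, kk, r in sorted(info) if cc == c]
            group = dist.new_group(ranks=ranks)
            if c == color:
                my_group = group
        return TorchDistBackend(my_group)


# Usage sketch (names from this diff):
#   MnnvlMemory.set_comm(MnnvlConfig(comm_backend=TorchDistBackend()))
#   MnnvlMemory.initialize()
```
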
8 changes: 4 additions & 4 deletions flashinfer/comm/trtllm_alltoall.py

@@ -26,8 +26,7 @@
from ..jit import gen_jit_spec
from ..utils import register_custom_op
from .mapping import Mapping
from .mnnvl import MnnvlMemory

from .mnnvl import MnnvlMemory, MnnvlConfig

def gen_comm_alltoall_module() -> JitSpec:
return gen_jit_spec(
@@ -296,13 +295,14 @@ class MnnvlMoe:
moe_mapping: Mapping = None

@staticmethod
def get_moe_workspaces(mapping: Mapping):
def get_moe_workspaces(mapping: Mapping, config: Optional[MnnvlConfig] = None):
if MnnvlMoe.moe_workspace is not None:
assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
return MnnvlMoe.moe_workspace_tensor

MnnvlMoe.moe_mapping = mapping
workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size)
MnnvlMemory.set_comm(config)
MnnvlMemory.initialize()
MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
torch.uint64
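And a call-site sketch for the new `config` argument on `get_moe_workspaces`, using the hypothetical `TorchDistBackend` from the sketch above; the `Mapping` constructor arguments shown are illustrative assumptions, not values taken from this PR.

```python
# Illustrative call site, not part of this PR.
import torch.distributed as dist

from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe

config = MnnvlConfig(comm_backend=TorchDistBackend())  # hypothetical backend, see sketch above
mapping = Mapping(world_size=8, rank=dist.get_rank(), gpus_per_node=8, tp_size=8)  # assumed args
workspace_tensor = MnnvlMoe.get_moe_workspaces(mapping, config=config)
```
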
12 changes: 9 additions & 3 deletions flashinfer/fused_moe.py

@@ -95,7 +95,7 @@ def gen_fused_moe_sm100_module() -> JitSpec:
"-DCOMPILE_HOPPER_TMA_GEMMS",
],
extra_cflags=[
"-DFAST_BUILD",
# "-DFAST_BUILD",
],
extra_ldflags=["-lcuda"],
extra_include_paths=[
@@ -195,7 +195,6 @@ def get_valid_tactics(
invalid = (m > 128 and min_latency_mode) or (
m <= 128 and min_latency_mode and (not self._is_nvfp4)
)

return (
[] if invalid else list(range(self._fused_moe_runner.get_tactic_num()))
)
@@ -210,6 +209,10 @@ def forward(
x, fc1_expert_weights, fc2_expert_weights, min_latency_mode_tensor = inputs
min_latency_mode = min_latency_mode_tensor.size(0) == 1
# determine if we should use min latency mode according to the profiled seq len
# print("uuuu"*10)
# import traceback
# traceback.print_stack()
# print(f"do_preparation: {do_preparation}, gemm_idx: {gemm_idx}, tactic: {tactic}")
self._fused_moe_runner.run_gemm_profile(
x,
fc1_expert_weights,
@@ -309,7 +312,10 @@ def next_positive_power_of_2(x: int) -> int:
[input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
gemm_idx=2,
)

# print(f"input:{input.shape}")
# print(f"fc1_expert_weights:{fc1_expert_weights.shape}")
# print(f"fc2_expert_weights:{fc2_expert_weights.shape}")
print(gemm_tactic_1, gemm_tactic_2)
run_moe = (
moe_runner._fused_moe_runner.run_moe_min_latency
if min_latency_mode

46 changes: 46 additions & 0 deletions run_collect.py

@@ -0,0 +1,46 @@
import subprocess
import re

# Values to replace the last parameter (1)
values = [
1, 2, 4, 8, 16, 24, 32, 48, 64,
96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]

base_cmd = (
"pytest -s "
"'tests/test_trtllm_cutlass_fused_moe.py::"
"test_moe_nvfp4[True-True-otype0-wtype0-256-8-256-7168-{}]'"
)

time_pattern = re.compile(r"Elapsed time: ([\d.]+) ms")

results = []

for v in values:
print(f"Running with last param = {v}")
cmd = base_cmd.format(v)
try:
output = subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT, text=True)
match = time_pattern.search(output)
if match:
elapsed_time = float(match.group(1))
else:
elapsed_time = None
print(f"Warning: Elapsed time not found in output for {v}")
except subprocess.CalledProcessError as e:
output = e.output
elapsed_time = None
print(f"Error running test for {v}:\n{output}")

results.append((v, elapsed_time))

# Print results as a table
print("\nResults:")
print(f"{'Value':>6} | {'Time (ms)':>10}")
print("-" * 20)
for val, time in results:
time_str = f"{time:.2f}" if time is not None else "N/A"
print(f"{val:6} | {time_str:>10}")
