Support mnnvl all2allv from Flashinfer #21003
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces support for Flashinfer's mnnvl all2allv for Mixture-of-Experts (MoE) layers, which is a significant performance enhancement for distributed inference. The changes are comprehensive, touching custom ops, distributed communicators, the MoE layer implementation, and quantization methods.
The core of the change is the new FlashInferAllToAllManager and its integration into the MoE forward pass. The review focuses on potential issues such as hardcoded values, code duplication, and the correctness of the communication logic, to ensure the new feature is robust and maintainable.
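For orientation, the role of the manager in the forward pass can be summarized with a minimal sketch. This is not the PR's actual API; the method name and arguments are assumed from the dispatch() call excerpted below:

# Sketch only -- illustrates the shape of the new manager, not its real code.
class FlashInferAllToAllManagerSketch:
    def __init__(self, ep_rank: int, ep_size: int):
        self.ep_rank = ep_rank
        self.ep_size = ep_size
        self.workspace = None  # e.g. an MnnvlMoe workspace, allocated lazily

    def dispatch(self, comm, global_num_tokens_cpu, hidden_states,
                 topk_ids, topk_weights, top_k, num_experts,
                 ep_rank, ep_size):
        # Route each token (plus its routing metadata) to the EP rank that
        # owns its selected experts, and return the routed tensors together
        # with an alltoall_info handle needed to reverse the exchange later.
        ...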
gpus_per_node: int = 4, #TODO(shuw): remove hardcode
):
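The default of 4 GPUs per node looks like a placeholder, as the TODO notes. One possible way to avoid the hardcode is to derive the per-node GPU count at runtime; the helper below is a sketch only and is not part of the PR, which may instead take this value from the parallel config:

import torch

def _default_gpus_per_node() -> int:
    # Illustrative helper: use the number of visible CUDA devices on this
    # node instead of the hardcoded 4.
    return torch.cuda.device_count()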
gathered_topk_weights = torch.flatten(gathered_topk_weights.contiguous(),
                                      start_dim=0,
                                      end_dim=-2)
# _flashinfer_all2all = comm.all2all_manager?
print("xxxx"*100) | ||
print(all2all_manager) | ||
print(f"ep_size:{self.ep_size}, {self.ep_rank}") |
print("xxxx"*100) | ||
print(all2all_manager) | ||
print(f"ep_size:{self.ep_size}, {self.ep_rank}") | ||
assert all2all_manager is not None | ||
# TODO(shuw): need to consider chunking for global_num_tokens_cpu | ||
x1, topk_ids1, topk_weights1, alltoall_info = all2all_manager.dispatch( | ||
get_dp_group().device_communicator, | ||
global_num_tokens_cpu, | ||
a1, | ||
topk_ids, | ||
topk_weights, | ||
top_k, | ||
num_experts, | ||
self.ep_rank, | ||
self.ep_size, | ||
) | ||
self.alltoall_info = alltoall_info |
if enable_flashinfer_fp4_allgather:
    topk_weights, topk_ids, a1q, a1q_scale = \
        get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],
                                   dim=0,
                                   sizes=get_local_sizes(local_tokens))
The code block guarded by if enable_flashinfer_fp4_allgather: appears to perform a redundant communication. The all2all_manager.dispatch call on line 122 already performs a gather operation, and it is then followed by another get_dp_group().all_gatherv here. This looks redundant and could impact performance. Please verify whether both are necessary; if not, the redundant call should be removed.
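If the two paths are indeed mutually exclusive, one option is to skip the extra all-gather whenever the flashinfer all-to-all dispatch has already exchanged the tokens. This is a sketch only, reusing the flag and helper names that appear in this diff (enable_flashinfer_alltoall, get_dp_group, get_local_sizes); the actual fix may differ:

# Sketch: fall back to the dp-group all-gather only when the flashinfer
# all-to-all path is not active.
if enable_flashinfer_fp4_allgather and not enable_flashinfer_alltoall:
    topk_weights, topk_ids, a1q, a1q_scale = \
        get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],
                                   dim=0,
                                   sizes=get_local_sizes(local_tokens))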
# if enable_flashinfer_alltoall:
#     print("all2allcalling"*100)
#     a1q = MnnvlMoe.mnnvl_moe_alltoallv(a1q, self.alltoall_info,
#                                        self.alltoall_workspace,
#                                        self.ep_rank, self.ep_size)
#     a1q_scale = MnnvlMoe.mnnvl_moe_alltoallv(
#         a1q_scale, alltoall_info, self.alltoall_workspace,
#         self.ep_rank, self.ep_size)
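For reference, if this commented-out path is re-enabled, a cleaned-up version might look like the sketch below (debug print dropped). It assumes the mismatch between self.alltoall_info and the bare alltoall_info in the original is unintentional and uses the former for both calls:

if enable_flashinfer_alltoall:
    # Exchange the quantized activations and their scales across EP ranks
    # with the mnnvl all-to-all, using the info recorded at dispatch time.
    a1q = MnnvlMoe.mnnvl_moe_alltoallv(a1q, self.alltoall_info,
                                       self.alltoall_workspace,
                                       self.ep_rank, self.ep_size)
    a1q_scale = MnnvlMoe.mnnvl_moe_alltoallv(a1q_scale, self.alltoall_info,
                                             self.alltoall_workspace,
                                             self.ep_rank, self.ep_size)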
888fe50 to 1ead2de
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Shu Wang. <shuw@nvidia.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.

Needs flashinfer-ai/flashinfer#1245
Purpose
Test Plan
vs.
VLLM_ALL2ALL_BACKEND="naive" \
...
Test Result
accuracy:
perf:
Alltoallv (this PR):
allgather-reducescatter:
(Optional) Documentation Update