From ff56b7ee487bb5b186daf731e7eeb66c65d12c5d Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 22 Aug 2025 21:40:12 +0000 Subject: [PATCH 1/3] Add mnnvl_moe_alltoallv_prepare_without_allgather --- 3rdparty/cutlass | 2 +- csrc/trtllm_alltoall.cu | 99 ++- csrc/trtllm_alltoall_prepare.cu | 686 ++++++++++++++++++ flashinfer/comm/trtllm_alltoall.py | 150 ++++ .../comm/trtllm_alltoall_prepare.cuh | 128 ++++ tests/test_trtllm_alltoall.py | 290 ++++++++ 6 files changed, 1353 insertions(+), 2 deletions(-) create mode 100644 csrc/trtllm_alltoall_prepare.cu create mode 100644 include/flashinfer/comm/trtllm_alltoall_prepare.cuh diff --git a/3rdparty/cutlass b/3rdparty/cutlass index e51efbfe1..f115c3f85 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit e51efbfe18fe4f4cbb66ab814c55bf4aa0185491 +Subproject commit f115c3f85467d5d9619119d1dbeb9c03c3d73864 diff --git a/csrc/trtllm_alltoall.cu b/csrc/trtllm_alltoall.cu index 1bdc546fd..4baaa9f37 100644 --- a/csrc/trtllm_alltoall.cu +++ b/csrc/trtllm_alltoall.cu @@ -17,9 +17,12 @@ #include -#include "flashinfer/comm/trtllm_alltoall.cuh" +#include #include "pytorch_extension_utils.h" +#include "flashinfer/comm/trtllm_alltoall.cuh" +#include "flashinfer/comm/trtllm_alltoall_prepare.cuh" + using namespace flashinfer::trtllm_alltoall; void moeCommPrepareIndicesOp(at::Tensor gatheredTargetRankIds, @@ -217,10 +220,104 @@ void setMaxUsableSmCount(int64_t maxSmCount) { flashinfer::trtllm_alltoall::setMaxUsableSmCount(static_cast(maxSmCount)); } +int64_t getPrepareWorkspaceSizePerRank(int64_t epSize) +{ + int epSize32 = static_cast(epSize); + return flashinfer::trtllm_alltoall::moe_prepare::getMoePrepareWorkspaceSize(epSize32); +} + +std::tuple, at::Tensor, at::Tensor, at::Tensor, at::Tensor, + at::Tensor, c10::optional> +moePrepareOp(at::Tensor expertsIds, c10::optional scales, c10::optional expertsStatics, + at::Tensor allWorkspaces, int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount, + int64_t slotCount, int64_t topK) +{ + CHECK_INPUT_TYPE(expertsIds, at::ScalarType::Int); + TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4"); + TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4"); + + int64_t maxSendRanksPerToken = std::max(epSize, topK); + int64_t tokenCount = expertsIds.size(0); + + at::Tensor preparedLocalExpertIds + = at::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(at::ScalarType::Int)); + + at::Tensor sendRankCountCumSum = at::empty({epSize}, expertsIds.options().dtype(at::ScalarType::Int)); + at::Tensor RecvRankCountCumSum = at::empty({epSize}, expertsIds.options().dtype(at::ScalarType::Int)); + + at::Tensor gatherRecvRankIndices + = at::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(at::ScalarType::Int)); + at::Tensor recvRankIndices + = at::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(at::ScalarType::Int)); + + at::Tensor gatherBackwardRecvRankIndices + = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(at::ScalarType::Int)); + at::Tensor backwardRecvRankIndices + = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(at::ScalarType::Int)); + + at::Tensor gatherSendRankIndices + = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(at::ScalarType::Int)); + at::Tensor sendRankIndices + = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, 
expertsIds.options().dtype(at::ScalarType::Int)); + + c10::optional preparedLocalScales; + float* scalesPtr = nullptr; + float* preparedLocalScalesPtr = nullptr; + if (scales.has_value()) + { + CHECK_INPUT_TYPE(scales.value(), at::ScalarType::Float); + scalesPtr = scales->data_ptr(); + preparedLocalScales + = at::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(at::ScalarType::Float)); + preparedLocalScalesPtr = preparedLocalScales->data_ptr(); + } + + int* localExpertStaticsPtr = nullptr; + int* gatheredExpertStaticsPtr = nullptr; + c10::optional gatheredExpertStatics; + if (expertsStatics.has_value()) + { + localExpertStaticsPtr = expertsStatics.value().data_ptr(); + gatheredExpertStatics = at::empty({epSize, expertCount}, expertsIds.options().dtype(at::ScalarType::Int)); + gatheredExpertStaticsPtr = gatheredExpertStatics.value().data_ptr(); + } + + flashinfer::trtllm_alltoall::moe_prepare::MoeCommWorkspace workspace; + workspace.workspacePtr = allWorkspaces.data_ptr(); + workspace.rankStrideInU64 = allWorkspaces.stride(0); + + auto stream = at::cuda::getCurrentCUDAStream(); + + flashinfer::trtllm_alltoall::moe_prepare::computeCountAndIndice(expertsIds.data_ptr(), + sendRankCountCumSum.data_ptr(), RecvRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), + backwardRecvRankIndices.data_ptr(), recvRankIndices.data_ptr(), workspace, tokenCount, + maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream); + + flashinfer::trtllm_alltoall::moe_prepare::computeCumsum( + sendRankCountCumSum.data_ptr(), RecvRankCountCumSum.data_ptr(), epRank, epSize, stream); + + flashinfer::trtllm_alltoall::moe_prepare::moveIndice(sendRankCountCumSum.data_ptr(), + RecvRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), gatherSendRankIndices.data_ptr(), + backwardRecvRankIndices.data_ptr(), gatherBackwardRecvRankIndices.data_ptr(), + recvRankIndices.data_ptr(), gatherRecvRankIndices.data_ptr(), epRank, epSize, maxTokenCountPerRank, + stream); + + flashinfer::trtllm_alltoall::moe_prepare::allToAllMetadata(expertsIds.data_ptr(), + preparedLocalExpertIds.data_ptr(), scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr, + gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), + RecvRankCountCumSum.data_ptr(), recvRankIndices.data_ptr(), tokenCount, maxTokenCountPerRank, topK, + expertCount, slotCount, epRank, epSize, stream); + + return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices, + RecvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics); +} + TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("moe_comm_prepare_indices", &moeCommPrepareIndicesOp); m.def("moe_local_gather", &moeLocalGatherOp); m.def("moe_comm", &moeCommOp); m.def("set_moe_max_usable_sm_count", static_cast(&setMaxUsableSmCount)); m.def("get_moe_commworkspace_size_per_rank", &getWorkspaceSizePerRank); + m.def("get_moe_prepare_workspace_size_per_rank", &getPrepareWorkspaceSizePerRank); + m.def("moe_prepare", &moePrepareOp); } diff --git a/csrc/trtllm_alltoall_prepare.cu b/csrc/trtllm_alltoall_prepare.cu new file mode 100644 index 000000000..8c115fcc4 --- /dev/null +++ b/csrc/trtllm_alltoall_prepare.cu @@ -0,0 +1,686 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flashinfer/comm/trtllm_alltoall_prepare.cuh" + +#include +#include +#include +#include + +#include "flashinfer/exception.h" +#include "flashinfer/utils.cuh" + +// Local definition to avoid multiple definition issues from trtllm_alltoall.cuh +static int getMultiProcessorCount() { + int device_id; + int multi_processor_count; + FLASHINFER_CUDA_CALL(cudaGetDevice(&device_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id)); + return multi_processor_count; +} + +namespace cg = cooperative_groups; + +namespace flashinfer::trtllm_alltoall +{ + +namespace moe_prepare +{ + +__device__ __forceinline__ void st_release_sys_global(uint64_t volatile* ptr, uint64_t val) +{ + asm volatile("st.release.sys.global.u64 [%0], %1;" ::"l"(ptr), "l"(val) : "memory"); +} + +__device__ __forceinline__ uint64_t ld_acquire_sys_global(uint64_t volatile* ptr) +{ + uint64_t ret; + asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ int ld_acquire_sys_global_int(int volatile* ptr) +{ + int ret; + asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +class StepCommunicatorBase +{ +public: + static constexpr int META_SIZE = sizeof(MoeCommFifoConnInfo); + + __device__ __inline__ StepCommunicatorBase(MoeCommFifoConnInfo* fifoConnInfo) + : fifoConnInfo(fifoConnInfo) + , localCachedHead(0) + , localCachedTail(0) + { + } + + __forceinline__ __device__ void reset() + { + fifoConnInfo->head = 0; + fifoConnInfo->tail = 0; + } + + __forceinline__ __device__ void releaseSendStep() + { + localCachedHead += 1; + st_release_sys_global(&(fifoConnInfo->head), uint64_t(localCachedHead)); + } + + __forceinline__ __device__ void releaseRecvStep() + { + localCachedTail += 1; + st_release_sys_global(&(fifoConnInfo->tail), uint64_t(localCachedTail)); + } + + __forceinline__ __device__ uint64_t acquireTail() + { + uint64_t tail = ld_acquire_sys_global(&(fifoConnInfo->tail)); + localCachedTail = tail; + return tail; + } + + __forceinline__ __device__ uint64_t acquireHead() + { + uint64_t head = ld_acquire_sys_global(&(fifoConnInfo->head)); + localCachedHead = head; + return head; + } + + __forceinline__ __device__ int acquireNewSendStep() + { + + int64_t tail; + do + { + tail = acquireTail(); + } while (localCachedHead >= tail + STEP_DEPTH); + // depth = 2, head = 1, tail = 0 , ok + // depth = 2, head = 2, tail = 0, should wait + + return localCachedHead % STEP_DEPTH; + } + + __forceinline__ __device__ int acquireNewRecvStep() + { + int64_t head = 0; + do + { + head = acquireHead(); + } while (localCachedTail >= head); + + return localCachedTail % STEP_DEPTH; + } + +public: + MoeCommFifoConnInfo* fifoConnInfo; + uint64_t localCachedHead; + uint64_t localCachedTail; + int rank; + int targetRank; +}; + +// Use MoeCommFifoConnInfo as media to transfer a counter number. +// Use the "head" field as flag. +// Use the "tail" field to transfer the counter number. 
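+// Note (descriptive, matching the code below): in this implementation the
+// dedicated "count" field carries the hand-off value. releaseValue() stores
+// value + 1 so that 0 can keep meaning "empty", and acquireValue() spins until
+// it observes a non-zero value, resets the field to 0 for the next hand-off,
+// and returns the stored value minus 1.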
+class CounterCommunicator +{ +public: + __device__ __inline__ CounterCommunicator(MoeCommFifoConnInfo* fifoConnInfo) + : fifoConnInfo(fifoConnInfo) + { + } + + __forceinline__ __device__ void releaseValue(uint64_t value) + { + // Avoid block on 0 + st_release_sys_global(&(fifoConnInfo->count), value + 1); + } + + __forceinline__ __device__ uint64_t acquireValue() + { + uint64_t localCount = 0; + do + { + localCount = ld_acquire_sys_global(&(fifoConnInfo->count)); + } while (localCount == 0); + + fifoConnInfo->count = 0; // reset the count + + return localCount - 1; + } + +protected: + MoeCommFifoConnInfo* fifoConnInfo; +}; + +template +__device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount, int* sharedSendRecvRankCount, + int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, MoeCommWorkspace workspace, + int maxTokenCountPerRank, int expertCount, int topK, int epRank, int epSize) +{ + cg::thread_block_tile tile = cg::tiled_partition(cg::this_thread_block()); + int laneInTile = tile.thread_rank(); + int tileId = threadIdx.x / kThreadsGroupSize; + int tileCountPerBlock = blockDim.x / kThreadsGroupSize; + int expertCountPerRank = expertCount / epSize; + if (threadIdx.x == 0) + { + *sharedSendRecvRankCount = 0; + } + __syncthreads(); + int targetRankId = blockIdx.x; + int readRankTokenCount = tokenCount; + if (targetRankId >= epSize) + { + return; + } + + int* localSendIndice = sendIndiceWorkspace + targetRankId * maxTokenCountPerRank; + int* localBackwardIndice = backwardIndiceWorkspace + targetRankId * maxTokenCountPerRank; + + for (int i = tileId; i < readRankTokenCount; i += tileCountPerBlock) + { + int expertRankId = laneInTile < topK ? experts[i * topK + laneInTile] / expertCountPerRank : epSize; + bool rankMatched = (expertRankId == targetRankId); + bool hasRankMatched = tile.any(rankMatched); + int mask = tile.ballot(rankMatched); + int firstMatchLane = __ffs(mask) - 1; // only valid if hasRankMatched is true + if (hasRankMatched && laneInTile == 0) + { + int index = atomicAdd_block(sharedSendRecvRankCount, 1); + localSendIndice[index] = i; + localBackwardIndice[index] = i * topK + firstMatchLane; + } + tile.sync(); + } + __syncthreads(); + if (threadIdx.x == 0) + { + CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1)); + int count = *(sharedSendRecvRankCount); + // printf("sendRecvCount: %d, rankId: %d, targetRankId: %d\n", count, rankId, targetRankId); + counter.releaseValue(uint64_t(count)); + *(sendCounts + targetRankId) = count; + } +} + +__device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase, + MoeCommWorkspace workspace, int maxTokenCountPerRank, int rankId, int rankCount) +{ + int rankOffset = threadIdx.x / THREADS_PER_PIPELINE; + if (rankOffset >= PIPELINE_PER_CTA) + { + return; + } + int* sharedCountsThisRank = sharedCountsBase + rankOffset; + int targetRankId = (blockIdx.x - rankCount) * PIPELINE_PER_CTA + rankOffset; + if (targetRankId >= rankCount) + { + return; + } + int unitId = threadIdx.x % UNIT_PER_PIPELINE; + cg::thread_block_tile rankTile + = cg::tiled_partition(cg::this_thread_block()); + int* localRecvIndice = recvIndiceWorkspace + targetRankId * maxTokenCountPerRank; + int rankRecvCount; + if (rankTile.thread_rank() == 0) + { + CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1)); + rankRecvCount = int(counter.acquireValue()); + // printf("rankRecvCount: %d, rankId: %d, 
targetRankId: %d\n", rankRecvCount, rankId, targetRankId); + *(recvCounts + targetRankId) = rankRecvCount; + *(sharedCountsThisRank) = rankRecvCount; + } + rankTile.sync(); + + rankRecvCount = *(sharedCountsThisRank); + for (int tokenId = unitId; tokenId < rankRecvCount; tokenId += UNIT_PER_PIPELINE) + { + *(localRecvIndice + tokenId) = tokenId; + } +} + +template +__global__ void computeCountAndIndiceDevice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, + int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, + int maxTokenCountPerRank, int topK, int expertCount, int rankId, int rankCount) +{ + __shared__ int sharedCounts[PIPELINE_PER_CTA]; + bool isSender = blockIdx.x < rankCount; + if (isSender) + { + computeCountAndSend(experts, tokenCount, &sharedCounts[0], sendCounts, sendIndiceWorkspace, + backwardIndiceWorkspace, workspace, maxTokenCountPerRank, expertCount, topK, rankId, rankCount); + } + else + { + recvCount( + recvIndiceWorkspace, recvCounts, &sharedCounts[0], workspace, maxTokenCountPerRank, rankId, rankCount); + } +} + +__global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice, + int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int maxTokenCountPerRank) +{ + int targetRankId = blockIdx.x; + if (blockIdx.y == 0) + { + // sendIndice and backwardIndice CTA + int startIndex = targetRankId == 0 ? 0 : sendCountsCumsum[targetRankId - 1]; + int endIndex = sendCountsCumsum[targetRankId]; + int count = endIndex - startIndex; + int* localSendIndice = sendIndice + targetRankId * maxTokenCountPerRank; + int* localBackwardIndice = backwardIndice + targetRankId * maxTokenCountPerRank; + for (int localIdx = threadIdx.x; localIdx < count; localIdx += blockDim.x) + { + gatherSendIndice[startIndex + localIdx] = localSendIndice[localIdx]; + gatherBackwardIndice[startIndex + localIdx] = localBackwardIndice[localIdx]; + } + } + else + { + // recvIndice CTA + int startIndex = targetRankId == 0 ? 0 : recvCountsCumsum[targetRankId - 1]; + int endIndex = recvCountsCumsum[targetRankId]; + int count = endIndex - startIndex; + for (int localIdx = threadIdx.x; localIdx < count; localIdx += blockDim.x) + { + gatherRecvIndice[startIndex + localIdx] = startIndex + localIdx; + } + } +} + +__global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount) +{ + int* inputOutputPtr = blockIdx.x == 0 ? sendCountsCumsum : recvCountsCumsum; + + // Use 2 block to comuteCumsum + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + int tid = threadIdx.x; + int threadData = tid < rankCount ? inputOutputPtr[tid] : 0; + int count = threadData; + __syncthreads(); + + BlockScan(temp_storage).InclusiveSum(threadData, threadData); + if (tid < rankCount) + { + inputOutputPtr[tid] = threadData; + // printf("cumsum, send? : %d, rankId:%d, tid:%d, threadData:%d, count:%d\n", blockIdx.x == 0, rankId, tid, + // threadData, count); + } +} + +template +class PacketPipeline +{ +public: + __device__ __inline__ PacketPipeline( + void* bufferBase, StepCommunicatorBase* stepCommunicator, int* sharedNewStepPtr, bool isSender) + : bufferBase(bufferBase) + , stepCommunicator(stepCommunicator) + , shared_new_step(sharedNewStepPtr) + { + step = 0; + needRelease = false; + packetId = isSender ? 
0 : PipelineConfig::PACKET_PER_STEP - 1; + } + + __device__ __forceinline__ void* getFirstSendPacket() + { + return bufferBase; + } + + __device__ __inline__ void* finishSendPacket(bool acquireNewStep) + { + + packetId++; + if (packetId < PipelineConfig::PACKET_PER_STEP) + { + return acquireNewStep ? bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE + + packetId * PipelineConfig::PACKET_SIZE + : nullptr; + } + + __syncthreads(); + if (threadIdx.x == 0) + { + stepCommunicator->releaseSendStep(); + if (acquireNewStep) + { + step = stepCommunicator->acquireNewSendStep(); + *(shared_new_step) = step; + } + } + __syncthreads(); + + if (acquireNewStep) + { + step = *(shared_new_step); + packetId = 0; + return bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP; + } + + return nullptr; + } + + __device__ __forceinline__ void* sendFinalize() + { + if (packetId > 0 && threadIdx.x == 0) + { + stepCommunicator->releaseSendStep(); + } + } + + __device__ __inline__ void* getNewRecvPacket() + { + packetId++; + if (packetId < PipelineConfig::PACKET_PER_STEP) + { + return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE + + packetId * PipelineConfig::PACKET_SIZE; + } + + __syncthreads(); + if (threadIdx.x == 0) + { + if (needRelease) + { + stepCommunicator->releaseRecvStep(); + } + step = stepCommunicator->acquireNewRecvStep(); + needRelease = true; + *(shared_new_step) = step; + } + __syncthreads(); + packetId = 0; + step = *(shared_new_step); + void* packetPtr = bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP; + + return packetPtr; + } + + __device__ __forceinline__ void reset() + { + if (threadIdx.x == 0) + { + stepCommunicator->reset(); + } + } + + void* bufferBase; + StepCommunicatorBase* stepCommunicator; + int step; + int packetId; + bool needRelease; + int* shared_new_step; +}; + +template +__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, + int* localExpertStatics, int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, + int* localSendIndice, int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, + int topK, int expertCount, int slotCount, int rankId, int rankCount) +{ + bool isSender = (blockIdx.y == 0); + int targetRankId = blockIdx.x; + int slotCountPerRank = slotCount / rankCount; + int groupSize = topK / PipelineConfig::UNIT_SIZE; + + __shared__ int sharedNewStep; + __align__(16) int experts[PipelineConfig::UNIT_SIZE]; + __align__(16) float scales[PipelineConfig::UNIT_SIZE]; + + uint8_t* bufferBase = (uint8_t*) (workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1)); + StepCommunicatorBase stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1)); + PacketPipeline pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender); + + if (isSender) + { + int baseCumsum = targetRankId == 0 ? 
0 : *(sendCountsCumsum + targetRankId - 1); + int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum; + int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE; + + void* packPtr = pipeline.getFirstSendPacket(); + int indexBase = 0; + int staticCopyBase = 0; + bool acquireNewStep = unitCount > 0 || (localExpertStatics != nullptr && expertCount > 0); + while (acquireNewStep) + { + if (threadIdx.x < UNIT_PER_ITER) + { + int index = indexBase + threadIdx.x; + int groupId = index % groupSize; + if (index < unitCount) + { + int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize)); + *((ExpertType*) (experts)) + = *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); + +#pragma unroll + for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++) + { + int expertId = experts[j]; + if (expertId / slotCountPerRank != targetRankId) + { + experts[j] = slotCount; + } + } + + int* expertsPtr = (int*) (packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; + *((ExpertType*) (expertsPtr)) = *((ExpertType*) (experts)); + if (sendScales != nullptr) + { + *((ScaleType*) (scales)) + = *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); + float* scaleBasePtr = (float*) (packPtr + PipelineConfig::SCALE_OFFSET); + float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; + *((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales)); + } + } + } + else if (localExpertStatics != nullptr) + { + int staticCopyIdx = threadIdx.x - UNIT_PER_ITER; + if (staticCopyBase + staticCopyIdx * 4 < expertCount) + { + int4* staticBasePtr = (int4*) (packPtr + PipelineConfig::STATIC_COPY_OFFSET); + int4 staticData = *(int4*) (localExpertStatics + staticCopyBase + staticCopyIdx * 4); + *(staticBasePtr + staticCopyIdx) = staticData; + } + } + + indexBase += UNIT_PER_ITER; + staticCopyBase += STATIC_COPY_PER_ITER * 4; + acquireNewStep = indexBase < unitCount || staticCopyBase < expertCount; + packPtr = pipeline.finishSendPacket(acquireNewStep); + } + + pipeline.sendFinalize(); + } + else + { + int baseCumsum = targetRankId == 0 ? 0 : *(recvCountsCumsum + targetRankId - 1); + int recvTokenCount = *(recvCountsCumsum + targetRankId) - baseCumsum; + int recvUnitCount = recvTokenCount * groupSize; + + int unitIdBase = 0; + int staticCopyBase = 0; + while (unitIdBase < recvUnitCount || (localExpertStatics != nullptr && staticCopyBase < expertCount)) + { + void* packetPtr = pipeline.getNewRecvPacket(); + int packetUnitCount + = unitIdBase + UNIT_PER_ITER < recvUnitCount ? 
UNIT_PER_ITER : recvUnitCount - unitIdBase; + packetUnitCount = max(packetUnitCount, 0); + if (threadIdx.x < UNIT_PER_ITER) + { + if (threadIdx.x < packetUnitCount) + { + int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize; + int groupId = (unitIdBase + threadIdx.x) % groupSize; + int* expertsPtr = (int*) (packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; + *((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr)); + ExpertType* dstExpertsPtr + = (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); + *dstExpertsPtr = *((ExpertType*) (experts)); + + if (recvScales != nullptr) + { + float* scaleBasePtr = (float*) (packetPtr + PipelineConfig::SCALE_OFFSET); + float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE; + *((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr)); + ScaleType* dstScalesPtr + = (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); + *dstScalesPtr = *((ScaleType*) (scales)); + } + } + } + else if (localExpertStatics != nullptr) + { + int staticCopyIdx = threadIdx.x - UNIT_PER_ITER; + if (staticCopyBase + staticCopyIdx * 4 < expertCount) + { + int4* staticBasePtr = (int4*) (packetPtr + PipelineConfig::STATIC_COPY_OFFSET); + int4 staticData = *(staticBasePtr + staticCopyIdx); + *(int4*) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4) + = staticData; + } + } + + unitIdBase += packetUnitCount; + staticCopyBase += STATIC_COPY_PER_ITER * 4; + } + + pipeline.reset(); + } +} + +__global__ void memsetExpertIdsDevice( + int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount, int rankCount) +{ + int maxTokenCount = maxTokenCountPerRank * rankCount; + int totalRecvTokenCount = *(recvCountsCumsum + rankCount - 1); + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i + totalRecvTokenCount * topK < maxTokenCount * topK; + i += gridDim.x * blockDim.x) + { + *(expertIds + i + totalRecvTokenCount * topK) = slotCount; + } +} + +void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, + int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, + int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream) +{ + // first rankCount CTAs for count and send, then rankCount / PIPELINE_PER_CTA CTAs only for receive + int grid_x = rankCount + (rankCount + PIPELINE_PER_CTA - 1) / PIPELINE_PER_CTA; + int block_size = 1024; + dim3 block(block_size); + dim3 grid(grid_x); + FLASHINFER_CHECK(topK >= 1 && topK <= 32, "Only 1 <= topK <= 32 is supported now."); + auto* kernelFn = computeCountAndIndiceDevice<1>; + if (topK > 16) + { + kernelFn = computeCountAndIndiceDevice<32>; + } + else if (topK > 8) + { + kernelFn = computeCountAndIndiceDevice<16>; + } + else if (topK > 4) + { + kernelFn = computeCountAndIndiceDevice<8>; + } + else if (topK > 2) + { + kernelFn = computeCountAndIndiceDevice<4>; + } + else if (topK > 1) + { + kernelFn = computeCountAndIndiceDevice<2>; + } + kernelFn<<>>(experts, sendCounts, recvCounts, sendIndiceWorkspace, backwardIndiceWorkspace, + recvIndiceWorkspace, workspace, tokenCount, maxTokenCountPerRank, topK, expert_count, rankId, rankCount); +} + +void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream) +{ + int block_size = CUMSUM_THREADS_PER_BLOCK; + dim3 block(block_size); + dim3 grid(2); + 
computeCumsumDevice<<>>(sendCountsCumsum, recvCountsCumsum, rankId, rankCount); +} + +void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice, + int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount, + int maxTokenCountPerRank, cudaStream_t stream) +{ + dim3 block(512); + dim3 grid(rankCount, 2); + moveIndiceDevice<<>>(sendCountsCumsum, recvCountsCumsum, sendIndice, gatherSendIndice, + backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank); +} + +void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics, + int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice, + int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount, + int slotCount, int rankId, int rankCount, cudaStream_t stream) +{ + int block_size = localExpertStatics == nullptr ? UNIT_PER_ITER : UNIT_PER_ITER + STATIC_COPY_PER_ITER; + dim3 block(block_size); + dim3 grid(rankCount, 2); + + if (topK % 4 == 0) + { + using PipelineConfig = PipelineConfig<4, 16>; + static_assert( + PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64, + "FIFO size is too small"); + allToAllMetadataDevice<<>>(sendExperts, recvExperts, + sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, + localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, + slotCount, rankId, rankCount); + } + else + { + using PipelineConfig = PipelineConfig<1, 64>; + static_assert( + PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64, + "FIFO size is too small"); + allToAllMetadataDevice<<>>(sendExperts, recvExperts, + sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, + localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, + slotCount, rankId, rankCount); + } + + int smCount = getMultiProcessorCount(); + memsetExpertIdsDevice<<>>( + recvExperts, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount); +} + +size_t getMoePrepareWorkspaceSize(int epSize) +{ + return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize; +} + +} // namespace moe_prepare + +} // namespace flashinfer::trtllm_alltoall diff --git a/flashinfer/comm/trtllm_alltoall.py b/flashinfer/comm/trtllm_alltoall.py index 595a84990..f374be806 100644 --- a/flashinfer/comm/trtllm_alltoall.py +++ b/flashinfer/comm/trtllm_alltoall.py @@ -34,6 +34,7 @@ def gen_comm_alltoall_module() -> JitSpec: "comm", [ jit_env.FLASHINFER_CSRC_DIR / "trtllm_alltoall.cu", + jit_env.FLASHINFER_CSRC_DIR / "trtllm_alltoall_prepare.cu", ], ) @@ -184,12 +185,52 @@ def get_moe_commworkspace_size_per_rank( ) -> int: return module.get_moe_commworkspace_size_per_rank(ep_size) + @register_custom_op( + "flashinfer::get_moe_prepare_workspace_size_per_rank", + mutates_args=[], + ) + def get_moe_prepare_workspace_size_per_rank( + ep_size: int, + ) -> int: + return module.get_moe_prepare_workspace_size_per_rank(ep_size) + + @register_custom_op( + "flashinfer::moe_prepare", + mutates_args=[], + ) + def moe_prepare( + experts_ids: torch.Tensor, + scales: Optional[torch.Tensor], + experts_statics: Optional[torch.Tensor], + workspace: torch.Tensor, + 
max_token_count_per_rank: int, + ep_rank: int, + ep_size: int, + expert_count: int, + slot_count: int, + top_k: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return module.moe_prepare( + experts_ids, + scales, + experts_statics, + workspace, + max_token_count_per_rank, + ep_rank, + ep_size, + expert_count, + slot_count, + top_k, + ) + return SimpleNamespace( moe_comm_prepare_indices=moe_comm_prepare_indices, moe_local_gather=moe_local_gather, moe_comm=moe_comm, set_moe_max_usable_sm_count=set_moe_max_usable_sm_count, get_moe_commworkspace_size_per_rank=get_moe_commworkspace_size_per_rank, + get_moe_prepare_workspace_size_per_rank=get_moe_prepare_workspace_size_per_rank, + moe_prepare=moe_prepare, ) @@ -278,6 +319,35 @@ def get_moe_commworkspace_size_per_rank( ) -> int: return get_comm_alltoall_module().get_moe_commworkspace_size_per_rank(ep_size) +def get_moe_prepare_workspace_size_per_rank( + ep_size: int, +) -> int: + return get_comm_alltoall_module().get_moe_prepare_workspace_size_per_rank(ep_size) + +def moe_prepare( + experts_ids: torch.Tensor, + scales: Optional[torch.Tensor], + experts_statics: Optional[torch.Tensor], + workspace: torch.Tensor, + max_token_count_per_rank: int, + ep_rank: int, + ep_size: int, + expert_count: int, + slot_count: int, + top_k: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return get_comm_alltoall_module().moe_prepare( + experts_ids, + scales, + experts_statics, + workspace, + max_token_count_per_rank, + ep_rank, + ep_size, + expert_count, + slot_count, + top_k, + ) @dataclass class MoEAlltoallInfo: @@ -292,7 +362,9 @@ class MoEAlltoallInfo: class MnnvlMoe: moe_workspace: MnnvlMemory = None + moe_prepare_workspace: MnnvlMemory = None moe_workspace_tensor: torch.Tensor = None + moe_prepare_workspace_tensor: torch.Tensor = None moe_mapping: Mapping = None @staticmethod @@ -309,6 +381,20 @@ def get_moe_workspaces(mapping: Mapping): ) return MnnvlMoe.moe_workspace_tensor + @staticmethod + def get_moe_prepare_workspace(mapping: Mapping): + if MnnvlMoe.moe_prepare_workspace_tensor is not None: + assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now" + return MnnvlMoe.moe_prepare_workspace_tensor + workspace_size_per_rank = get_moe_prepare_workspace_size_per_rank( + mapping.tp_size + ) + MnnvlMoe.moe_prepare_workspace = MnnvlMemory(mapping, workspace_size_per_rank) + MnnvlMoe.moe_prepare_workspace_tensor = ( + MnnvlMoe.moe_prepare_workspace.as_torch_strided_tensor(torch.uint64) + ) + return MnnvlMoe.moe_prepare_workspace_tensor + @staticmethod def compute_target_rank_id( token_selected_experts: torch.Tensor, expert_count: int, ep_size: int @@ -320,6 +406,70 @@ def compute_target_rank_id( token_target_rank_ids = token_selected_experts // expert_per_rank return token_target_rank_ids + @staticmethod + def mnnvl_moe_alltoallv_prepare_without_allgather( + expert_ids: torch.Tensor, + scales: torch.Tensor, + expert_statics: Optional[torch.Tensor], + workspace: torch.Tensor, + max_token_count_per_rank: int, + ep_rank: int, + ep_size: int, + expert_count: int, + slot_count: int, + top_k: int, + ): + ( + prepared_local_experts, + prepared_local_scales, + local_send_rank_count_cumsum, + local_send_rank_indices, + local_recv_rank_count_cumsum, + local_recv_rank_indices, + backward_local_recv_rank_indices, + gathered_expert_statics, + ) = moe_prepare( + expert_ids, + scales, + 
expert_statics, + workspace, + max_token_count_per_rank, + ep_rank, + ep_size, + expert_count, + slot_count, + top_k, + ) + + local_token_allocation_count = max_token_count_per_rank * ep_size + # Looks like we don't need this. + local_gather_indices = None + + alltoall_info = MoEAlltoallInfo( + local_gather_indices, + local_send_rank_count_cumsum, + local_send_rank_indices, + local_recv_rank_count_cumsum, + local_recv_rank_indices, + backward_local_recv_rank_indices, + local_token_allocation_count, + ) + + return alltoall_info, prepared_local_experts, prepared_local_scales, gathered_expert_statics + + @staticmethod + def mnnvl_moe_expert_static_allgather( + expert_ids: torch.Tensor, + workspace: torch.Tensor, + ep_rank: int, + ep_size: int, + expert_count: int, + ): + gathered_expert_ids = torch.ops.trtllm.mnnvl_moe_expert_static_allgather( + expert_ids, workspace, ep_rank, ep_size, expert_count + ) + return gathered_expert_ids + @staticmethod def mnnvl_moe_alltoallv_prepare( gathered_target_rank_ids: torch.Tensor, diff --git a/include/flashinfer/comm/trtllm_alltoall_prepare.cuh b/include/flashinfer/comm/trtllm_alltoall_prepare.cuh new file mode 100644 index 000000000..40b1cd51f --- /dev/null +++ b/include/flashinfer/comm/trtllm_alltoall_prepare.cuh @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#define DEBUG_PIPELINE 0 + +namespace flashinfer::trtllm_alltoall +{ + +namespace moe_prepare +{ + +#define STEP_DEPTH 2 +#define THREADS_PER_UNIT 1 +#define UNIT_PER_PIPELINE 128 +#define PIPELINE_PER_CTA 4 +#define EXPERT_BYTES_PER_UNIT 32 +#define SCALE_BYTES_PER_UNIT 32 +#define UNIT_COUNT_PER_PACKET 1024 +#define BYTES_COUNTER 8 +#define CUMSUM_THREADS_PER_BLOCK 128 + +#define UNIT_PER_ITER 256 +#define STATIC_COPY_PER_ITER 128 + +static constexpr int THREADS_PER_PIPELINE = THREADS_PER_UNIT * UNIT_PER_PIPELINE; +static constexpr int THREADS_PER_CTA = THREADS_PER_PIPELINE * PIPELINE_PER_CTA; + +template +struct PipelineConfig +{ + static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT; + static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT; + static constexpr int UNIT_BYTES_SIZE = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)); + static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int); + static constexpr int STATIC_COPY_OFFSET = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)); + static constexpr int PACKET_SIZE = UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int); + static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8); +}; + +// 1MB FIFO size +static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8; + +#ifdef __CUDACC__ +#define ALIGN_256 __align__(256) +#else +#define ALIGN_256 alignas(256) +#endif + +struct ALIGN_256 MoeCommFifoConnInfo +{ + volatile uint64_t head; // write position + volatile uint64_t tail; // read position + volatile uint64_t count; // for counter +}; + +struct MoeCommWorkspace +{ + uint64_t* workspacePtr; + size_t rankStrideInU64; +#ifdef __CUDACC__ + __inline__ __device__ uint64_t* getFifoBasePtr( + bool isSender, int epRank, int peerRank, int channel, int channelCount) const + { + // fifo itself is in receiver's side. + if (isSender) + { + return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * FIFO_SIZE_IN_U64; + } + else + { + return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * FIFO_SIZE_IN_U64; + } + } + + __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo( + bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const + { + // fifoInfo is in sender's side. + uint64_t* fifoInfoPtrU64 = workspacePtr + FIFO_SIZE_IN_U64 * channelCount * epSize; + int strideIndice = isSender ? epRank : peerRank; + int fifoInfoIndice = isSender ? 
peerRank : epRank; + fifoInfoPtrU64 += strideIndice * rankStrideInU64; + MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*) fifoInfoPtrU64; + MoeCommFifoConnInfo* result = fifoInfoPtr + fifoInfoIndice * channelCount + channel; + return result; + } + +#endif +}; + +void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, + int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, + int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream); + +void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream); + +void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice, + int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount, + int maxTokenCountPerRank, cudaStream_t stream); + +void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics, + int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice, + int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount, + int slotCount, int rankId, int rankCount, cudaStream_t stream); + +size_t getMoePrepareWorkspaceSize(int epSize); + +} // namespace moe_prepare + +} // namespace flashinfer::trtllm_alltoall diff --git a/tests/test_trtllm_alltoall.py b/tests/test_trtllm_alltoall.py index a219a7ffa..855c4d8e6 100644 --- a/tests/test_trtllm_alltoall.py +++ b/tests/test_trtllm_alltoall.py @@ -509,6 +509,296 @@ def test_moe_local_gather( assert torch.equal(local_expert_ids, ref_local_expert_ids) assert torch.equal(local_scales, ref_local_scales) +@pytest.mark.parametrize( + "ep_rank, ep_size, expert_count, slot_count, top_k, max_token_count_per_rank", [ + (0, 2, 16, 20, 8, 512), + (0, 2, 16, 16, 3, 300), + (0, 4, 20, 24, 8, 4000), + (0, 8, 96, 96, 8, 1000), + (3, 8, 128, 128, 8, 1000), + (3, 8, 128, 144, 8, 1), + (0, 4, 72, 80, 4, 2256), + (0, 4, 72, 80, 6, 3333), + # Hang with stream count > 8 + #(0, 9, 90, 8, 100), +]) +def test_moe_alltoall_prepare(ep_rank: int, ep_size: int, + expert_count: int, slot_count: int, + top_k: int, max_token_count_per_rank: int): + torch.cuda.set_device(0) + + cpu_expert_ids_all_ranks_lists = [] + cpu_token_count_lists = [] + cpu_scales_all_ranks_lists = [] + for _ in range(ep_size): + token_count = torch.randint(max_token_count_per_rank // 2, + max_token_count_per_rank + 1, (1, ), + dtype=torch.int32, + device=torch.device('cpu')) + token_count = 1 if token_count == 0 else token_count + + token_count = max_token_count_per_rank + + cpu_expert_ids_all_ranks_lists.append( + torch.randint(0, + slot_count, (token_count, top_k), + dtype=torch.int32, + device=torch.device('cpu'))) + + cpu_scales_all_ranks_lists.append( + torch.zeros(token_count, + top_k, + dtype=torch.float32, + device=torch.device('cpu')) + 0.5) + + cpu_token_count_lists.append(token_count) + + def compute_target_rank(expert_id): + ep_per_rank = slot_count // ep_size + return expert_id // ep_per_rank + + def generate_references(): + ref_prepared_local_expert_ids = [] + ref_prepared_local_scales = [] + ref_local_send_rank_count_cumsum = [0] * ep_size + ref_local_recv_rank_count_cumsum = [0] * ep_size + ref_local_recv_rank_indices = [] + + local_token_count = cpu_token_count_lists[ep_rank] + send_token_count_to_ranks = [0] * 
ep_size + + # send part + for token_id in range(local_token_count): + target_set = set() + for pos in range(top_k): + expert_id = int( + cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) + target_rank_id = compute_target_rank(expert_id) + target_set.add(target_rank_id) + + for target_rank_id in target_set: + send_token_count_to_ranks[target_rank_id] += 1 + + total_send_token_count = 0 + for rank in range(ep_size): + #print(f'rank: {rank}, send_token_count_to_ranks[rank]: {send_token_count_to_ranks[rank]}') + base = ref_local_send_rank_count_cumsum[rank - + 1] if rank > 0 else 0 + ref_local_send_rank_count_cumsum[ + rank] = send_token_count_to_ranks[rank] + base + total_send_token_count += send_token_count_to_ranks[rank] + + ref_local_backward_send_rank_indices = [0 + ] * (total_send_token_count) + ref_local_send_rank_indices = [0] * (total_send_token_count) + + current_send_token_ids = [0] * ep_size + for token_id in range(local_token_count): + target_set = set() + for pos in range(top_k): + expert_id = int( + cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) + target_rank_id = compute_target_rank(expert_id) + if target_rank_id not in target_set: + cumsum_before = 0 if target_rank_id == 0 else ref_local_send_rank_count_cumsum[ + target_rank_id - 1] + send_index = cumsum_before + current_send_token_ids[ + target_rank_id] + ref_local_send_rank_indices[send_index] = token_id + ref_local_backward_send_rank_indices[ + send_index] = token_id * top_k + pos + current_send_token_ids[target_rank_id] += 1 + target_set.add(target_rank_id) + + # receive part + total_recv_token_count = 0 + for rank in range(ep_size): + token_count = cpu_token_count_lists[rank] + current_recv_token_count = 0 + for token_id in range(token_count): + token_is_received = False + for pos in range(top_k): + expert_id = int( + cpu_expert_ids_all_ranks_lists[rank][token_id][pos]) + sf = cpu_scales_all_ranks_lists[rank][token_id][pos] + target_rank_id = compute_target_rank(expert_id) + if target_rank_id == ep_rank: + if not token_is_received: + token_is_received = True + ref_prepared_local_expert_ids.append( + [slot_count] * top_k) + ref_prepared_local_scales.append([0.0] * top_k) + ref_prepared_local_expert_ids[-1][pos] = expert_id + ref_prepared_local_scales[-1][pos] = sf + if token_is_received: + ref_local_recv_rank_indices.append( + total_recv_token_count) + total_recv_token_count += 1 + current_recv_token_count += 1 + ref_local_recv_rank_count_cumsum[ + rank] = current_recv_token_count if rank == 0 else ref_local_recv_rank_count_cumsum[ + rank - 1] + current_recv_token_count + + return ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count + + ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count = generate_references( + ) + + cpu_experter_count_lists = [] + for rank in range(ep_size): + local_expert_count = [] + for i in range(expert_count): + local_expert_count.append(rank * expert_count + i) + cpu_experter_count_lists.append(torch.IntTensor(local_expert_count)) + + #expert_ids_all_ranks = torch.tensor(cpu_expert_ids_all_ranks_lists).cuda() + expert_ids_all_ranks = [ + cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size) + ] 
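+    # Each simulated rank keeps its own CUDA copy of the inputs; the per-rank
+    # moe_prepare calls below read these independently on separate streams.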
+ #scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda() + scales_all_ranks = [ + cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size) + ] + + experter_count_lists = [ + cpu_experter_count_lists[i].cuda() for i in range(ep_size) + ] + + cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(ep_size)] + + workspace_size = tllm_alltoall.get_moe_prepare_workspace_size_per_rank( + ep_size) + + all_workspaces = torch.zeros(ep_size, + workspace_size, + dtype=torch.uint64, + device=torch.device('cuda')) + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + tllm_alltoall.moe_prepare( + expert_ids_all_ranks[0], scales_all_ranks[0], + experter_count_lists[0], all_workspaces, + max_token_count_per_rank, 0, 1, expert_count, slot_count, top_k) + stream.wait_stream(torch.cuda.current_stream()) + + # Make torch alloc tensor to avoid cuda sync + prepared_local_experts = [] + prepared_local_scales = [] + local_send_rank_count_cumsum = [] + local_send_rank_indices = [] + local_recv_rank_count_cumsum = [] + local_recv_rank_indices = [] + backward_local_recv_rank_indices = [] + for _ in range(ep_size): + prepared_local_experts.append( + torch.empty(max_token_count_per_rank * ep_size, + top_k, + dtype=torch.int32, + device=torch.device('cuda'))) + prepared_local_scales.append( + torch.empty(max_token_count_per_rank * ep_size, + top_k, + dtype=torch.float32, + device=torch.device('cuda'))) + local_send_rank_count_cumsum.append( + torch.empty(ep_size, + dtype=torch.int32, + device=torch.device('cuda'))) + local_send_rank_indices.append( + torch.empty(max_token_count_per_rank * ep_size, + dtype=torch.int32, + device=torch.device('cuda'))) + local_recv_rank_count_cumsum.append( + torch.empty(0, dtype=torch.int32, device=torch.device('cuda'))) + local_recv_rank_indices.append( + torch.empty(0, dtype=torch.int32, device=torch.device('cuda'))) + backward_local_recv_rank_indices.append( + torch.empty(0, dtype=torch.int32, device=torch.device('cuda'))) + + prepared_local_experts = [] + prepared_local_scales = [] + local_send_rank_count_cumsum = [] + local_send_rank_indices = [] + local_recv_rank_count_cumsum = [] + local_recv_rank_indices = [] + backward_local_recv_rank_indices = [] + + # reset the workspace + all_workspaces = torch.zeros(ep_size, + workspace_size, + dtype=torch.uint64, + device=torch.device('cuda')) + + # do prepare in parallel + for rank in range(ep_size): + with torch.cuda.stream(cuda_streams_all_ranks[rank]): + if rank == ep_rank: + prepared_local_experts, prepared_local_scales, local_send_rank_count_cumsum, \ + local_send_rank_indices, local_recv_rank_count_cumsum, local_recv_rank_indices, \ + backward_local_recv_rank_indices, gathered_expert_statics\ + = tllm_alltoall.moe_prepare(expert_ids_all_ranks[rank], scales_all_ranks[rank], experter_count_lists[rank], all_workspaces, max_token_count_per_rank, + rank, ep_size, expert_count, slot_count, top_k) + else: + tllm_alltoall.moe_prepare( + expert_ids_all_ranks[rank], scales_all_ranks[rank], + experter_count_lists[rank], all_workspaces, + max_token_count_per_rank, rank, ep_size, expert_count, + slot_count, top_k) + for rank in range(ep_size): + cuda_streams_all_ranks[rank].synchronize() + + prepared_local_experts_cpu = prepared_local_experts[: + total_recv_token_count].cpu( + ) + prepared_local_scales_cpu = prepared_local_scales[: + total_recv_token_count].cpu( + ) + for i in range(total_recv_token_count): + for j in range(top_k): + expert_id = int(prepared_local_experts_cpu[i][j]) + assert 0 <= 
expert_id and expert_id <= slot_count + if expert_id < slot_count: + assert compute_target_rank(expert_id) == ep_rank + scale = float(prepared_local_scales_cpu[i][j]) + assert scale > 1e-6 + + gathered_expert_statics_cpu = gathered_expert_statics.cpu() + for rank in range(ep_size): + for i in range(expert_count): + assert int(gathered_expert_statics_cpu[rank] + [i]) == rank * expert_count + i + + ref_local_send_rank_count_cumsum = torch.IntTensor( + ref_local_send_rank_count_cumsum) + assert torch.equal(local_send_rank_count_cumsum.cpu(), + ref_local_send_rank_count_cumsum) + + local_send_rank_indices = local_send_rank_indices.cpu() + backward_local_recv_rank_indices = backward_local_recv_rank_indices.cpu( + ) + for i in range(ep_size): + base = 0 if i == 0 else ref_local_send_rank_count_cumsum[i - 1] + for j in range(base, ref_local_send_rank_count_cumsum[i]): + token_id = local_send_rank_indices[j] + lane_id = backward_local_recv_rank_indices[j] - token_id * top_k + expert_id = int( + cpu_expert_ids_all_ranks_lists[ep_rank][token_id][lane_id]) + assert compute_target_rank(expert_id) == i + + ref_local_recv_rank_count_cumsum = torch.IntTensor( + ref_local_recv_rank_count_cumsum) + assert torch.equal( + local_recv_rank_count_cumsum[:ref_local_recv_rank_count_cumsum. + size(0)].cpu(), + ref_local_recv_rank_count_cumsum) + + ref_local_recv_rank_indices = torch.IntTensor( + ref_local_recv_rank_indices) + assert torch.equal( + local_recv_rank_indices[:ref_local_recv_rank_indices.size(0)].cpu(), + ref_local_recv_rank_indices) + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 111e3d49e0d75cb530f830e302bcdc0ecab3e061 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 22 Aug 2025 22:20:25 +0000 Subject: [PATCH 2/3] Run precommit --- csrc/trtllm_alltoall.cu | 186 +-- csrc/trtllm_alltoall_prepare.cu | 1100 ++++++++--------- flashinfer/comm/trtllm_alltoall.py | 45 +- .../comm/trtllm_alltoall_prepare.cuh | 126 +- tests/test_trtllm_alltoall.py | 332 +++-- 5 files changed, 900 insertions(+), 889 deletions(-) diff --git a/csrc/trtllm_alltoall.cu b/csrc/trtllm_alltoall.cu index 4baaa9f37..6b1a73f61 100644 --- a/csrc/trtllm_alltoall.cu +++ b/csrc/trtllm_alltoall.cu @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -#include - #include -#include "pytorch_extension_utils.h" + +#include #include "flashinfer/comm/trtllm_alltoall.cuh" #include "flashinfer/comm/trtllm_alltoall_prepare.cuh" +#include "pytorch_extension_utils.h" using namespace flashinfer::trtllm_alltoall; @@ -220,96 +220,102 @@ void setMaxUsableSmCount(int64_t maxSmCount) { flashinfer::trtllm_alltoall::setMaxUsableSmCount(static_cast(maxSmCount)); } -int64_t getPrepareWorkspaceSizePerRank(int64_t epSize) -{ - int epSize32 = static_cast(epSize); - return flashinfer::trtllm_alltoall::moe_prepare::getMoePrepareWorkspaceSize(epSize32); +int64_t getPrepareWorkspaceSizePerRank(int64_t epSize) { + int epSize32 = static_cast(epSize); + return flashinfer::trtllm_alltoall::moe_prepare::getMoePrepareWorkspaceSize(epSize32); } std::tuple, at::Tensor, at::Tensor, at::Tensor, at::Tensor, - at::Tensor, c10::optional> -moePrepareOp(at::Tensor expertsIds, c10::optional scales, c10::optional expertsStatics, - at::Tensor allWorkspaces, int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount, - int64_t slotCount, int64_t topK) -{ - CHECK_INPUT_TYPE(expertsIds, at::ScalarType::Int); - TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4"); - TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4"); - - int64_t maxSendRanksPerToken = std::max(epSize, topK); - int64_t tokenCount = expertsIds.size(0); - - at::Tensor preparedLocalExpertIds - = at::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(at::ScalarType::Int)); - - at::Tensor sendRankCountCumSum = at::empty({epSize}, expertsIds.options().dtype(at::ScalarType::Int)); - at::Tensor RecvRankCountCumSum = at::empty({epSize}, expertsIds.options().dtype(at::ScalarType::Int)); - - at::Tensor gatherRecvRankIndices - = at::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(at::ScalarType::Int)); - at::Tensor recvRankIndices - = at::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(at::ScalarType::Int)); - - at::Tensor gatherBackwardRecvRankIndices - = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(at::ScalarType::Int)); - at::Tensor backwardRecvRankIndices - = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(at::ScalarType::Int)); - - at::Tensor gatherSendRankIndices - = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(at::ScalarType::Int)); - at::Tensor sendRankIndices - = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(at::ScalarType::Int)); - - c10::optional preparedLocalScales; - float* scalesPtr = nullptr; - float* preparedLocalScalesPtr = nullptr; - if (scales.has_value()) - { - CHECK_INPUT_TYPE(scales.value(), at::ScalarType::Float); - scalesPtr = scales->data_ptr(); - preparedLocalScales - = at::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(at::ScalarType::Float)); - preparedLocalScalesPtr = preparedLocalScales->data_ptr(); - } - - int* localExpertStaticsPtr = nullptr; - int* gatheredExpertStaticsPtr = nullptr; - c10::optional gatheredExpertStatics; - if (expertsStatics.has_value()) - { - localExpertStaticsPtr = expertsStatics.value().data_ptr(); - gatheredExpertStatics = at::empty({epSize, expertCount}, expertsIds.options().dtype(at::ScalarType::Int)); - gatheredExpertStaticsPtr = gatheredExpertStatics.value().data_ptr(); - } - - flashinfer::trtllm_alltoall::moe_prepare::MoeCommWorkspace workspace; - workspace.workspacePtr = 
allWorkspaces.data_ptr(); - workspace.rankStrideInU64 = allWorkspaces.stride(0); - - auto stream = at::cuda::getCurrentCUDAStream(); - - flashinfer::trtllm_alltoall::moe_prepare::computeCountAndIndice(expertsIds.data_ptr(), - sendRankCountCumSum.data_ptr(), RecvRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), - backwardRecvRankIndices.data_ptr(), recvRankIndices.data_ptr(), workspace, tokenCount, - maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream); - - flashinfer::trtllm_alltoall::moe_prepare::computeCumsum( - sendRankCountCumSum.data_ptr(), RecvRankCountCumSum.data_ptr(), epRank, epSize, stream); - - flashinfer::trtllm_alltoall::moe_prepare::moveIndice(sendRankCountCumSum.data_ptr(), - RecvRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), gatherSendRankIndices.data_ptr(), - backwardRecvRankIndices.data_ptr(), gatherBackwardRecvRankIndices.data_ptr(), - recvRankIndices.data_ptr(), gatherRecvRankIndices.data_ptr(), epRank, epSize, maxTokenCountPerRank, - stream); - - flashinfer::trtllm_alltoall::moe_prepare::allToAllMetadata(expertsIds.data_ptr(), - preparedLocalExpertIds.data_ptr(), scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr, - gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), - RecvRankCountCumSum.data_ptr(), recvRankIndices.data_ptr(), tokenCount, maxTokenCountPerRank, topK, - expertCount, slotCount, epRank, epSize, stream); - - return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices, - RecvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics); + at::Tensor, c10::optional> +moePrepareOp(at::Tensor expertsIds, c10::optional scales, + c10::optional expertsStatics, at::Tensor allWorkspaces, + int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount, + int64_t slotCount, int64_t topK) { + CHECK_INPUT_TYPE(expertsIds, at::ScalarType::Int); + TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4"); + TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4"); + + int64_t maxSendRanksPerToken = std::max(epSize, topK); + int64_t tokenCount = expertsIds.size(0); + + at::Tensor preparedLocalExpertIds = at::empty({maxTokenCountPerRank * epSize, topK}, + expertsIds.options().dtype(at::ScalarType::Int)); + + at::Tensor sendRankCountCumSum = + at::empty({epSize}, expertsIds.options().dtype(at::ScalarType::Int)); + at::Tensor RecvRankCountCumSum = + at::empty({epSize}, expertsIds.options().dtype(at::ScalarType::Int)); + + at::Tensor gatherRecvRankIndices = + at::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(at::ScalarType::Int)); + at::Tensor recvRankIndices = + at::empty({maxTokenCountPerRank * epSize}, expertsIds.options().dtype(at::ScalarType::Int)); + + at::Tensor gatherBackwardRecvRankIndices = + at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, + expertsIds.options().dtype(at::ScalarType::Int)); + at::Tensor backwardRecvRankIndices = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, + expertsIds.options().dtype(at::ScalarType::Int)); + + at::Tensor gatherSendRankIndices = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, + expertsIds.options().dtype(at::ScalarType::Int)); + at::Tensor sendRankIndices = at::empty({maxTokenCountPerRank * maxSendRanksPerToken}, + expertsIds.options().dtype(at::ScalarType::Int)); + + c10::optional preparedLocalScales; + float* scalesPtr = nullptr; + float* preparedLocalScalesPtr = nullptr; + if 
(scales.has_value()) { + CHECK_INPUT_TYPE(scales.value(), at::ScalarType::Float); + scalesPtr = scales->data_ptr(); + preparedLocalScales = at::empty({maxTokenCountPerRank * epSize, topK}, + expertsIds.options().dtype(at::ScalarType::Float)); + preparedLocalScalesPtr = preparedLocalScales->data_ptr(); + } + + int* localExpertStaticsPtr = nullptr; + int* gatheredExpertStaticsPtr = nullptr; + c10::optional gatheredExpertStatics; + if (expertsStatics.has_value()) { + localExpertStaticsPtr = expertsStatics.value().data_ptr(); + gatheredExpertStatics = + at::empty({epSize, expertCount}, expertsIds.options().dtype(at::ScalarType::Int)); + gatheredExpertStaticsPtr = gatheredExpertStatics.value().data_ptr(); + } + + flashinfer::trtllm_alltoall::moe_prepare::MoeCommWorkspace workspace; + workspace.workspacePtr = allWorkspaces.data_ptr(); + workspace.rankStrideInU64 = allWorkspaces.stride(0); + + auto stream = at::cuda::getCurrentCUDAStream(); + + flashinfer::trtllm_alltoall::moe_prepare::computeCountAndIndice( + expertsIds.data_ptr(), sendRankCountCumSum.data_ptr(), + RecvRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), + backwardRecvRankIndices.data_ptr(), recvRankIndices.data_ptr(), workspace, + tokenCount, maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream); + + flashinfer::trtllm_alltoall::moe_prepare::computeCumsum(sendRankCountCumSum.data_ptr(), + RecvRankCountCumSum.data_ptr(), + epRank, epSize, stream); + + flashinfer::trtllm_alltoall::moe_prepare::moveIndice( + sendRankCountCumSum.data_ptr(), RecvRankCountCumSum.data_ptr(), + sendRankIndices.data_ptr(), gatherSendRankIndices.data_ptr(), + backwardRecvRankIndices.data_ptr(), gatherBackwardRecvRankIndices.data_ptr(), + recvRankIndices.data_ptr(), gatherRecvRankIndices.data_ptr(), epRank, epSize, + maxTokenCountPerRank, stream); + + flashinfer::trtllm_alltoall::moe_prepare::allToAllMetadata( + expertsIds.data_ptr(), preparedLocalExpertIds.data_ptr(), scalesPtr, + preparedLocalScalesPtr, localExpertStaticsPtr, gatheredExpertStaticsPtr, workspace, + sendRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), + RecvRankCountCumSum.data_ptr(), recvRankIndices.data_ptr(), tokenCount, + maxTokenCountPerRank, topK, expertCount, slotCount, epRank, epSize, stream); + + return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, + gatherSendRankIndices, RecvRankCountCumSum, gatherRecvRankIndices, + gatherBackwardRecvRankIndices, gatheredExpertStatics); } TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { diff --git a/csrc/trtllm_alltoall_prepare.cu b/csrc/trtllm_alltoall_prepare.cu index 8c115fcc4..1b3cf1131 100644 --- a/csrc/trtllm_alltoall_prepare.cu +++ b/csrc/trtllm_alltoall_prepare.cu @@ -14,13 +14,13 @@ * limitations under the License. 
*/ -#include "flashinfer/comm/trtllm_alltoall_prepare.cuh" - -#include #include #include +#include + #include +#include "flashinfer/comm/trtllm_alltoall_prepare.cuh" #include "flashinfer/exception.h" #include "flashinfer/utils.cuh" @@ -29,658 +29,590 @@ static int getMultiProcessorCount() { int device_id; int multi_processor_count; FLASHINFER_CUDA_CALL(cudaGetDevice(&device_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id)); + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id)); return multi_processor_count; } - + namespace cg = cooperative_groups; -namespace flashinfer::trtllm_alltoall -{ +namespace flashinfer::trtllm_alltoall { -namespace moe_prepare -{ +namespace moe_prepare { -__device__ __forceinline__ void st_release_sys_global(uint64_t volatile* ptr, uint64_t val) -{ - asm volatile("st.release.sys.global.u64 [%0], %1;" ::"l"(ptr), "l"(val) : "memory"); +__device__ __forceinline__ void st_release_sys_global(uint64_t volatile* ptr, uint64_t val) { + asm volatile("st.release.sys.global.u64 [%0], %1;" ::"l"(ptr), "l"(val) : "memory"); } -__device__ __forceinline__ uint64_t ld_acquire_sys_global(uint64_t volatile* ptr) -{ - uint64_t ret; - asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); - return ret; +__device__ __forceinline__ uint64_t ld_acquire_sys_global(uint64_t volatile* ptr) { + uint64_t ret; + asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; } -__device__ __forceinline__ int ld_acquire_sys_global_int(int volatile* ptr) -{ - int ret; - asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); - return ret; +__device__ __forceinline__ int ld_acquire_sys_global_int(int volatile* ptr) { + int ret; + asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; } -class StepCommunicatorBase -{ -public: - static constexpr int META_SIZE = sizeof(MoeCommFifoConnInfo); - - __device__ __inline__ StepCommunicatorBase(MoeCommFifoConnInfo* fifoConnInfo) - : fifoConnInfo(fifoConnInfo) - , localCachedHead(0) - , localCachedTail(0) - { - } - - __forceinline__ __device__ void reset() - { - fifoConnInfo->head = 0; - fifoConnInfo->tail = 0; - } - - __forceinline__ __device__ void releaseSendStep() - { - localCachedHead += 1; - st_release_sys_global(&(fifoConnInfo->head), uint64_t(localCachedHead)); - } - - __forceinline__ __device__ void releaseRecvStep() - { - localCachedTail += 1; - st_release_sys_global(&(fifoConnInfo->tail), uint64_t(localCachedTail)); - } - - __forceinline__ __device__ uint64_t acquireTail() - { - uint64_t tail = ld_acquire_sys_global(&(fifoConnInfo->tail)); - localCachedTail = tail; - return tail; - } - - __forceinline__ __device__ uint64_t acquireHead() - { - uint64_t head = ld_acquire_sys_global(&(fifoConnInfo->head)); - localCachedHead = head; - return head; - } - - __forceinline__ __device__ int acquireNewSendStep() - { - - int64_t tail; - do - { - tail = acquireTail(); - } while (localCachedHead >= tail + STEP_DEPTH); - // depth = 2, head = 1, tail = 0 , ok - // depth = 2, head = 2, tail = 0, should wait - - return localCachedHead % STEP_DEPTH; - } - - __forceinline__ __device__ int acquireNewRecvStep() - { - int64_t head = 0; - do - { - head = acquireHead(); - } while (localCachedTail >= head); - - return localCachedTail % STEP_DEPTH; - } - -public: - MoeCommFifoConnInfo* fifoConnInfo; - uint64_t localCachedHead; - uint64_t 
localCachedTail; - int rank; - int targetRank; +class StepCommunicatorBase { + public: + static constexpr int META_SIZE = sizeof(MoeCommFifoConnInfo); + + __device__ __inline__ StepCommunicatorBase(MoeCommFifoConnInfo* fifoConnInfo) + : fifoConnInfo(fifoConnInfo), localCachedHead(0), localCachedTail(0) {} + + __forceinline__ __device__ void reset() { + fifoConnInfo->head = 0; + fifoConnInfo->tail = 0; + } + + __forceinline__ __device__ void releaseSendStep() { + localCachedHead += 1; + st_release_sys_global(&(fifoConnInfo->head), uint64_t(localCachedHead)); + } + + __forceinline__ __device__ void releaseRecvStep() { + localCachedTail += 1; + st_release_sys_global(&(fifoConnInfo->tail), uint64_t(localCachedTail)); + } + + __forceinline__ __device__ uint64_t acquireTail() { + uint64_t tail = ld_acquire_sys_global(&(fifoConnInfo->tail)); + localCachedTail = tail; + return tail; + } + + __forceinline__ __device__ uint64_t acquireHead() { + uint64_t head = ld_acquire_sys_global(&(fifoConnInfo->head)); + localCachedHead = head; + return head; + } + + __forceinline__ __device__ int acquireNewSendStep() { + int64_t tail; + do { + tail = acquireTail(); + } while (localCachedHead >= tail + STEP_DEPTH); + // depth = 2, head = 1, tail = 0 , ok + // depth = 2, head = 2, tail = 0, should wait + + return localCachedHead % STEP_DEPTH; + } + + __forceinline__ __device__ int acquireNewRecvStep() { + int64_t head = 0; + do { + head = acquireHead(); + } while (localCachedTail >= head); + + return localCachedTail % STEP_DEPTH; + } + + public: + MoeCommFifoConnInfo* fifoConnInfo; + uint64_t localCachedHead; + uint64_t localCachedTail; + int rank; + int targetRank; }; // Use MoeCommFifoConnInfo as media to transfer a counter number. // Use the "head" field as flag. // Use the "tail" field to transfer the counter number. 
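+// Protocol sketch (sender -> receiver), assuming one outstanding value per counter slot:
+//   sender:   counter.releaseValue(n);     // stores n + 1 into the slot with release semantics
+//   receiver: n = counter.acquireValue();  // spins on acquire loads until the slot is non-zero,
+//                                          // then clears it back to 0 and returns n
+// Storing value + 1 keeps 0 free as the "empty" sentinel, so a transferred count of 0 is still
+// distinguishable from "nothing written yet".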
-class CounterCommunicator -{ -public: - __device__ __inline__ CounterCommunicator(MoeCommFifoConnInfo* fifoConnInfo) - : fifoConnInfo(fifoConnInfo) - { - } +class CounterCommunicator { + public: + __device__ __inline__ CounterCommunicator(MoeCommFifoConnInfo* fifoConnInfo) + : fifoConnInfo(fifoConnInfo) {} - __forceinline__ __device__ void releaseValue(uint64_t value) - { - // Avoid block on 0 - st_release_sys_global(&(fifoConnInfo->count), value + 1); - } + __forceinline__ __device__ void releaseValue(uint64_t value) { + // Avoid block on 0 + st_release_sys_global(&(fifoConnInfo->count), value + 1); + } - __forceinline__ __device__ uint64_t acquireValue() - { - uint64_t localCount = 0; - do - { - localCount = ld_acquire_sys_global(&(fifoConnInfo->count)); - } while (localCount == 0); + __forceinline__ __device__ uint64_t acquireValue() { + uint64_t localCount = 0; + do { + localCount = ld_acquire_sys_global(&(fifoConnInfo->count)); + } while (localCount == 0); - fifoConnInfo->count = 0; // reset the count + fifoConnInfo->count = 0; // reset the count - return localCount - 1; - } + return localCount - 1; + } -protected: - MoeCommFifoConnInfo* fifoConnInfo; + protected: + MoeCommFifoConnInfo* fifoConnInfo; }; template -__device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount, int* sharedSendRecvRankCount, - int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, MoeCommWorkspace workspace, - int maxTokenCountPerRank, int expertCount, int topK, int epRank, int epSize) -{ - cg::thread_block_tile tile = cg::tiled_partition(cg::this_thread_block()); - int laneInTile = tile.thread_rank(); - int tileId = threadIdx.x / kThreadsGroupSize; - int tileCountPerBlock = blockDim.x / kThreadsGroupSize; - int expertCountPerRank = expertCount / epSize; - if (threadIdx.x == 0) - { - *sharedSendRecvRankCount = 0; - } - __syncthreads(); - int targetRankId = blockIdx.x; - int readRankTokenCount = tokenCount; - if (targetRankId >= epSize) - { - return; - } - - int* localSendIndice = sendIndiceWorkspace + targetRankId * maxTokenCountPerRank; - int* localBackwardIndice = backwardIndiceWorkspace + targetRankId * maxTokenCountPerRank; - - for (int i = tileId; i < readRankTokenCount; i += tileCountPerBlock) - { - int expertRankId = laneInTile < topK ? 
experts[i * topK + laneInTile] / expertCountPerRank : epSize; - bool rankMatched = (expertRankId == targetRankId); - bool hasRankMatched = tile.any(rankMatched); - int mask = tile.ballot(rankMatched); - int firstMatchLane = __ffs(mask) - 1; // only valid if hasRankMatched is true - if (hasRankMatched && laneInTile == 0) - { - int index = atomicAdd_block(sharedSendRecvRankCount, 1); - localSendIndice[index] = i; - localBackwardIndice[index] = i * topK + firstMatchLane; - } - tile.sync(); - } - __syncthreads(); - if (threadIdx.x == 0) - { - CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1)); - int count = *(sharedSendRecvRankCount); - // printf("sendRecvCount: %d, rankId: %d, targetRankId: %d\n", count, rankId, targetRankId); - counter.releaseValue(uint64_t(count)); - *(sendCounts + targetRankId) = count; - } +__device__ __forceinline__ void computeCountAndSend( + int* experts, int tokenCount, int* sharedSendRecvRankCount, int* sendCounts, + int* sendIndiceWorkspace, int* backwardIndiceWorkspace, MoeCommWorkspace workspace, + int maxTokenCountPerRank, int expertCount, int topK, int epRank, int epSize) { + cg::thread_block_tile tile = + cg::tiled_partition(cg::this_thread_block()); + int laneInTile = tile.thread_rank(); + int tileId = threadIdx.x / kThreadsGroupSize; + int tileCountPerBlock = blockDim.x / kThreadsGroupSize; + int expertCountPerRank = expertCount / epSize; + if (threadIdx.x == 0) { + *sharedSendRecvRankCount = 0; + } + __syncthreads(); + int targetRankId = blockIdx.x; + int readRankTokenCount = tokenCount; + if (targetRankId >= epSize) { + return; + } + + int* localSendIndice = sendIndiceWorkspace + targetRankId * maxTokenCountPerRank; + int* localBackwardIndice = backwardIndiceWorkspace + targetRankId * maxTokenCountPerRank; + + for (int i = tileId; i < readRankTokenCount; i += tileCountPerBlock) { + int expertRankId = + laneInTile < topK ? 
experts[i * topK + laneInTile] / expertCountPerRank : epSize; + bool rankMatched = (expertRankId == targetRankId); + bool hasRankMatched = tile.any(rankMatched); + int mask = tile.ballot(rankMatched); + int firstMatchLane = __ffs(mask) - 1; // only valid if hasRankMatched is true + if (hasRankMatched && laneInTile == 0) { + int index = atomicAdd_block(sharedSendRecvRankCount, 1); + localSendIndice[index] = i; + localBackwardIndice[index] = i * topK + firstMatchLane; + } + tile.sync(); + } + __syncthreads(); + if (threadIdx.x == 0) { + CounterCommunicator counter( + workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1)); + int count = *(sharedSendRecvRankCount); + // printf("sendRecvCount: %d, rankId: %d, targetRankId: %d\n", count, rankId, targetRankId); + counter.releaseValue(uint64_t(count)); + *(sendCounts + targetRankId) = count; + } } -__device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase, - MoeCommWorkspace workspace, int maxTokenCountPerRank, int rankId, int rankCount) -{ - int rankOffset = threadIdx.x / THREADS_PER_PIPELINE; - if (rankOffset >= PIPELINE_PER_CTA) - { - return; - } - int* sharedCountsThisRank = sharedCountsBase + rankOffset; - int targetRankId = (blockIdx.x - rankCount) * PIPELINE_PER_CTA + rankOffset; - if (targetRankId >= rankCount) - { - return; - } - int unitId = threadIdx.x % UNIT_PER_PIPELINE; - cg::thread_block_tile rankTile - = cg::tiled_partition(cg::this_thread_block()); - int* localRecvIndice = recvIndiceWorkspace + targetRankId * maxTokenCountPerRank; - int rankRecvCount; - if (rankTile.thread_rank() == 0) - { - CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1)); - rankRecvCount = int(counter.acquireValue()); - // printf("rankRecvCount: %d, rankId: %d, targetRankId: %d\n", rankRecvCount, rankId, targetRankId); - *(recvCounts + targetRankId) = rankRecvCount; - *(sharedCountsThisRank) = rankRecvCount; - } - rankTile.sync(); - - rankRecvCount = *(sharedCountsThisRank); - for (int tokenId = unitId; tokenId < rankRecvCount; tokenId += UNIT_PER_PIPELINE) - { - *(localRecvIndice + tokenId) = tokenId; - } +__device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCounts, + int* sharedCountsBase, MoeCommWorkspace workspace, + int maxTokenCountPerRank, int rankId, int rankCount) { + int rankOffset = threadIdx.x / THREADS_PER_PIPELINE; + if (rankOffset >= PIPELINE_PER_CTA) { + return; + } + int* sharedCountsThisRank = sharedCountsBase + rankOffset; + int targetRankId = (blockIdx.x - rankCount) * PIPELINE_PER_CTA + rankOffset; + if (targetRankId >= rankCount) { + return; + } + int unitId = threadIdx.x % UNIT_PER_PIPELINE; + cg::thread_block_tile rankTile = + cg::tiled_partition(cg::this_thread_block()); + int* localRecvIndice = recvIndiceWorkspace + targetRankId * maxTokenCountPerRank; + int rankRecvCount; + if (rankTile.thread_rank() == 0) { + CounterCommunicator counter( + workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1)); + rankRecvCount = int(counter.acquireValue()); + // printf("rankRecvCount: %d, rankId: %d, targetRankId: %d\n", rankRecvCount, rankId, + // targetRankId); + *(recvCounts + targetRankId) = rankRecvCount; + *(sharedCountsThisRank) = rankRecvCount; + } + rankTile.sync(); + + rankRecvCount = *(sharedCountsThisRank); + for (int tokenId = unitId; tokenId < rankRecvCount; tokenId += UNIT_PER_PIPELINE) { + *(localRecvIndice + tokenId) = tokenId; + } } template -__global__ void 
computeCountAndIndiceDevice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, - int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, - int maxTokenCountPerRank, int topK, int expertCount, int rankId, int rankCount) -{ - __shared__ int sharedCounts[PIPELINE_PER_CTA]; - bool isSender = blockIdx.x < rankCount; - if (isSender) - { - computeCountAndSend(experts, tokenCount, &sharedCounts[0], sendCounts, sendIndiceWorkspace, - backwardIndiceWorkspace, workspace, maxTokenCountPerRank, expertCount, topK, rankId, rankCount); - } - else - { - recvCount( - recvIndiceWorkspace, recvCounts, &sharedCounts[0], workspace, maxTokenCountPerRank, rankId, rankCount); - } +__global__ void computeCountAndIndiceDevice(int* experts, int* sendCounts, int* recvCounts, + int* sendIndiceWorkspace, int* backwardIndiceWorkspace, + int* recvIndiceWorkspace, MoeCommWorkspace workspace, + int tokenCount, int maxTokenCountPerRank, int topK, + int expertCount, int rankId, int rankCount) { + __shared__ int sharedCounts[PIPELINE_PER_CTA]; + bool isSender = blockIdx.x < rankCount; + if (isSender) { + computeCountAndSend(experts, tokenCount, &sharedCounts[0], sendCounts, + sendIndiceWorkspace, backwardIndiceWorkspace, workspace, + maxTokenCountPerRank, expertCount, topK, rankId, + rankCount); + } else { + recvCount(recvIndiceWorkspace, recvCounts, &sharedCounts[0], workspace, maxTokenCountPerRank, + rankId, rankCount); + } } -__global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice, - int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int maxTokenCountPerRank) -{ - int targetRankId = blockIdx.x; - if (blockIdx.y == 0) - { - // sendIndice and backwardIndice CTA - int startIndex = targetRankId == 0 ? 0 : sendCountsCumsum[targetRankId - 1]; - int endIndex = sendCountsCumsum[targetRankId]; - int count = endIndex - startIndex; - int* localSendIndice = sendIndice + targetRankId * maxTokenCountPerRank; - int* localBackwardIndice = backwardIndice + targetRankId * maxTokenCountPerRank; - for (int localIdx = threadIdx.x; localIdx < count; localIdx += blockDim.x) - { - gatherSendIndice[startIndex + localIdx] = localSendIndice[localIdx]; - gatherBackwardIndice[startIndex + localIdx] = localBackwardIndice[localIdx]; - } - } - else - { - // recvIndice CTA - int startIndex = targetRankId == 0 ? 0 : recvCountsCumsum[targetRankId - 1]; - int endIndex = recvCountsCumsum[targetRankId]; - int count = endIndex - startIndex; - for (int localIdx = threadIdx.x; localIdx < count; localIdx += blockDim.x) - { - gatherRecvIndice[startIndex + localIdx] = startIndex + localIdx; - } - } +__global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, + int* gatherSendIndice, int* backwardIndice, + int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, + int maxTokenCountPerRank) { + int targetRankId = blockIdx.x; + if (blockIdx.y == 0) { + // sendIndice and backwardIndice CTA + int startIndex = targetRankId == 0 ? 
0 : sendCountsCumsum[targetRankId - 1]; + int endIndex = sendCountsCumsum[targetRankId]; + int count = endIndex - startIndex; + int* localSendIndice = sendIndice + targetRankId * maxTokenCountPerRank; + int* localBackwardIndice = backwardIndice + targetRankId * maxTokenCountPerRank; + for (int localIdx = threadIdx.x; localIdx < count; localIdx += blockDim.x) { + gatherSendIndice[startIndex + localIdx] = localSendIndice[localIdx]; + gatherBackwardIndice[startIndex + localIdx] = localBackwardIndice[localIdx]; + } + } else { + // recvIndice CTA + int startIndex = targetRankId == 0 ? 0 : recvCountsCumsum[targetRankId - 1]; + int endIndex = recvCountsCumsum[targetRankId]; + int count = endIndex - startIndex; + for (int localIdx = threadIdx.x; localIdx < count; localIdx += blockDim.x) { + gatherRecvIndice[startIndex + localIdx] = startIndex + localIdx; + } + } } -__global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount) -{ - int* inputOutputPtr = blockIdx.x == 0 ? sendCountsCumsum : recvCountsCumsum; - - // Use 2 block to comuteCumsum - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - - int tid = threadIdx.x; - int threadData = tid < rankCount ? inputOutputPtr[tid] : 0; - int count = threadData; - __syncthreads(); - - BlockScan(temp_storage).InclusiveSum(threadData, threadData); - if (tid < rankCount) - { - inputOutputPtr[tid] = threadData; - // printf("cumsum, send? : %d, rankId:%d, tid:%d, threadData:%d, count:%d\n", blockIdx.x == 0, rankId, tid, - // threadData, count); - } +__global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, + int rankCount) { + int* inputOutputPtr = blockIdx.x == 0 ? sendCountsCumsum : recvCountsCumsum; + + // Use 2 block to comuteCumsum + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + int tid = threadIdx.x; + int threadData = tid < rankCount ? inputOutputPtr[tid] : 0; + int count = threadData; + __syncthreads(); + + BlockScan(temp_storage).InclusiveSum(threadData, threadData); + if (tid < rankCount) { + inputOutputPtr[tid] = threadData; + // printf("cumsum, send? : %d, rankId:%d, tid:%d, threadData:%d, count:%d\n", blockIdx.x == 0, + // rankId, tid, threadData, count); + } } template -class PacketPipeline -{ -public: - __device__ __inline__ PacketPipeline( - void* bufferBase, StepCommunicatorBase* stepCommunicator, int* sharedNewStepPtr, bool isSender) - : bufferBase(bufferBase) - , stepCommunicator(stepCommunicator) - , shared_new_step(sharedNewStepPtr) - { - step = 0; - needRelease = false; - packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1; - } - - __device__ __forceinline__ void* getFirstSendPacket() - { - return bufferBase; +class PacketPipeline { + public: + __device__ __inline__ PacketPipeline(void* bufferBase, StepCommunicatorBase* stepCommunicator, + int* sharedNewStepPtr, bool isSender) + : bufferBase(bufferBase), + stepCommunicator(stepCommunicator), + shared_new_step(sharedNewStepPtr) { + step = 0; + needRelease = false; + packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1; + } + + __device__ __forceinline__ void* getFirstSendPacket() { return bufferBase; } + + __device__ __inline__ void* finishSendPacket(bool acquireNewStep) { + packetId++; + if (packetId < PipelineConfig::PACKET_PER_STEP) { + return acquireNewStep + ? 
bufferBase + + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE + + packetId * PipelineConfig::PACKET_SIZE + : nullptr; } - __device__ __inline__ void* finishSendPacket(bool acquireNewStep) - { - - packetId++; - if (packetId < PipelineConfig::PACKET_PER_STEP) - { - return acquireNewStep ? bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE - + packetId * PipelineConfig::PACKET_SIZE - : nullptr; - } - - __syncthreads(); - if (threadIdx.x == 0) - { - stepCommunicator->releaseSendStep(); - if (acquireNewStep) - { - step = stepCommunicator->acquireNewSendStep(); - *(shared_new_step) = step; - } - } - __syncthreads(); - - if (acquireNewStep) - { - step = *(shared_new_step); - packetId = 0; - return bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP; - } - - return nullptr; + __syncthreads(); + if (threadIdx.x == 0) { + stepCommunicator->releaseSendStep(); + if (acquireNewStep) { + step = stepCommunicator->acquireNewSendStep(); + *(shared_new_step) = step; + } } + __syncthreads(); - __device__ __forceinline__ void* sendFinalize() - { - if (packetId > 0 && threadIdx.x == 0) - { - stepCommunicator->releaseSendStep(); - } + if (acquireNewStep) { + step = *(shared_new_step); + packetId = 0; + return bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP; } - __device__ __inline__ void* getNewRecvPacket() - { - packetId++; - if (packetId < PipelineConfig::PACKET_PER_STEP) - { - return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE - + packetId * PipelineConfig::PACKET_SIZE; - } - - __syncthreads(); - if (threadIdx.x == 0) - { - if (needRelease) - { - stepCommunicator->releaseRecvStep(); - } - step = stepCommunicator->acquireNewRecvStep(); - needRelease = true; - *(shared_new_step) = step; - } - __syncthreads(); - packetId = 0; - step = *(shared_new_step); - void* packetPtr = bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP; + return nullptr; + } - return packetPtr; + __device__ __forceinline__ void* sendFinalize() { + if (packetId > 0 && threadIdx.x == 0) { + stepCommunicator->releaseSendStep(); } + } - __device__ __forceinline__ void reset() - { - if (threadIdx.x == 0) - { - stepCommunicator->reset(); - } + __device__ __inline__ void* getNewRecvPacket() { + packetId++; + if (packetId < PipelineConfig::PACKET_PER_STEP) { + return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE + + packetId * PipelineConfig::PACKET_SIZE; } - void* bufferBase; - StepCommunicatorBase* stepCommunicator; - int step; - int packetId; - bool needRelease; - int* shared_new_step; + __syncthreads(); + if (threadIdx.x == 0) { + if (needRelease) { + stepCommunicator->releaseRecvStep(); + } + step = stepCommunicator->acquireNewRecvStep(); + needRelease = true; + *(shared_new_step) = step; + } + __syncthreads(); + packetId = 0; + step = *(shared_new_step); + void* packetPtr = + bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP; + + return packetPtr; + } + + __device__ __forceinline__ void reset() { + if (threadIdx.x == 0) { + stepCommunicator->reset(); + } + } + + void* bufferBase; + StepCommunicatorBase* stepCommunicator; + int step; + int packetId; + bool needRelease; + int* shared_new_step; }; template -__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, - int* localExpertStatics, int* gatheredExpertStatics, MoeCommWorkspace 
workspace, int* sendCountsCumsum, - int* localSendIndice, int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, - int topK, int expertCount, int slotCount, int rankId, int rankCount) -{ - bool isSender = (blockIdx.y == 0); - int targetRankId = blockIdx.x; - int slotCountPerRank = slotCount / rankCount; - int groupSize = topK / PipelineConfig::UNIT_SIZE; - - __shared__ int sharedNewStep; - __align__(16) int experts[PipelineConfig::UNIT_SIZE]; - __align__(16) float scales[PipelineConfig::UNIT_SIZE]; - - uint8_t* bufferBase = (uint8_t*) (workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1)); - StepCommunicatorBase stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1)); - PacketPipeline pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender); - - if (isSender) - { - int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1); - int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum; - int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE; - - void* packPtr = pipeline.getFirstSendPacket(); - int indexBase = 0; - int staticCopyBase = 0; - bool acquireNewStep = unitCount > 0 || (localExpertStatics != nullptr && expertCount > 0); - while (acquireNewStep) - { - if (threadIdx.x < UNIT_PER_ITER) - { - int index = indexBase + threadIdx.x; - int groupId = index % groupSize; - if (index < unitCount) - { - int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize)); - *((ExpertType*) (experts)) - = *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); +__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, + float* recvScales, int* localExpertStatics, + int* gatheredExpertStatics, MoeCommWorkspace workspace, + int* sendCountsCumsum, int* localSendIndice, + int* recvCountsCumsum, int* localRecvIndice, int tokenCount, + int maxTokenCountPerRank, int topK, int expertCount, + int slotCount, int rankId, int rankCount) { + bool isSender = (blockIdx.y == 0); + int targetRankId = blockIdx.x; + int slotCountPerRank = slotCount / rankCount; + int groupSize = topK / PipelineConfig::UNIT_SIZE; + + __shared__ int sharedNewStep; + __align__(16) int experts[PipelineConfig::UNIT_SIZE]; + __align__(16) float scales[PipelineConfig::UNIT_SIZE]; + + uint8_t* bufferBase = (uint8_t*)(workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1)); + StepCommunicatorBase stepCommunicator( + workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1)); + PacketPipeline pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender); + + if (isSender) { + int baseCumsum = targetRankId == 0 ? 
0 : *(sendCountsCumsum + targetRankId - 1); + int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum; + int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE; + + void* packPtr = pipeline.getFirstSendPacket(); + int indexBase = 0; + int staticCopyBase = 0; + bool acquireNewStep = unitCount > 0 || (localExpertStatics != nullptr && expertCount > 0); + while (acquireNewStep) { + if (threadIdx.x < UNIT_PER_ITER) { + int index = indexBase + threadIdx.x; + int groupId = index % groupSize; + if (index < unitCount) { + int tokenId = + *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize)); + *((ExpertType*)(experts)) = + *(ExpertType*)(sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); #pragma unroll - for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++) - { - int expertId = experts[j]; - if (expertId / slotCountPerRank != targetRankId) - { - experts[j] = slotCount; - } - } - - int* expertsPtr = (int*) (packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; - *((ExpertType*) (expertsPtr)) = *((ExpertType*) (experts)); - if (sendScales != nullptr) - { - *((ScaleType*) (scales)) - = *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); - float* scaleBasePtr = (float*) (packPtr + PipelineConfig::SCALE_OFFSET); - float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; - *((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales)); - } - } - } - else if (localExpertStatics != nullptr) - { - int staticCopyIdx = threadIdx.x - UNIT_PER_ITER; - if (staticCopyBase + staticCopyIdx * 4 < expertCount) - { - int4* staticBasePtr = (int4*) (packPtr + PipelineConfig::STATIC_COPY_OFFSET); - int4 staticData = *(int4*) (localExpertStatics + staticCopyBase + staticCopyIdx * 4); - *(staticBasePtr + staticCopyIdx) = staticData; - } + for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++) { + int expertId = experts[j]; + if (expertId / slotCountPerRank != targetRankId) { + experts[j] = slotCount; } - - indexBase += UNIT_PER_ITER; - staticCopyBase += STATIC_COPY_PER_ITER * 4; - acquireNewStep = indexBase < unitCount || staticCopyBase < expertCount; - packPtr = pipeline.finishSendPacket(acquireNewStep); + } + + int* expertsPtr = (int*)(packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; + *((ExpertType*)(expertsPtr)) = *((ExpertType*)(experts)); + if (sendScales != nullptr) { + *((ScaleType*)(scales)) = + *(ScaleType*)(sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); + float* scaleBasePtr = (float*)(packPtr + PipelineConfig::SCALE_OFFSET); + float* scalesPtr = (float*)(scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; + *((ScaleType*)(scalesPtr)) = *((ScaleType*)(scales)); + } } - - pipeline.sendFinalize(); - } - else - { - int baseCumsum = targetRankId == 0 ? 0 : *(recvCountsCumsum + targetRankId - 1); - int recvTokenCount = *(recvCountsCumsum + targetRankId) - baseCumsum; - int recvUnitCount = recvTokenCount * groupSize; - - int unitIdBase = 0; - int staticCopyBase = 0; - while (unitIdBase < recvUnitCount || (localExpertStatics != nullptr && staticCopyBase < expertCount)) - { - void* packetPtr = pipeline.getNewRecvPacket(); - int packetUnitCount - = unitIdBase + UNIT_PER_ITER < recvUnitCount ? 
UNIT_PER_ITER : recvUnitCount - unitIdBase; - packetUnitCount = max(packetUnitCount, 0); - if (threadIdx.x < UNIT_PER_ITER) - { - if (threadIdx.x < packetUnitCount) - { - int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize; - int groupId = (unitIdBase + threadIdx.x) % groupSize; - int* expertsPtr = (int*) (packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; - *((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr)); - ExpertType* dstExpertsPtr - = (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); - *dstExpertsPtr = *((ExpertType*) (experts)); - - if (recvScales != nullptr) - { - float* scaleBasePtr = (float*) (packetPtr + PipelineConfig::SCALE_OFFSET); - float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE; - *((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr)); - ScaleType* dstScalesPtr - = (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); - *dstScalesPtr = *((ScaleType*) (scales)); - } - } - } - else if (localExpertStatics != nullptr) - { - int staticCopyIdx = threadIdx.x - UNIT_PER_ITER; - if (staticCopyBase + staticCopyIdx * 4 < expertCount) - { - int4* staticBasePtr = (int4*) (packetPtr + PipelineConfig::STATIC_COPY_OFFSET); - int4 staticData = *(staticBasePtr + staticCopyIdx); - *(int4*) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4) - = staticData; - } - } - - unitIdBase += packetUnitCount; - staticCopyBase += STATIC_COPY_PER_ITER * 4; + } else if (localExpertStatics != nullptr) { + int staticCopyIdx = threadIdx.x - UNIT_PER_ITER; + if (staticCopyBase + staticCopyIdx * 4 < expertCount) { + int4* staticBasePtr = (int4*)(packPtr + PipelineConfig::STATIC_COPY_OFFSET); + int4 staticData = *(int4*)(localExpertStatics + staticCopyBase + staticCopyIdx * 4); + *(staticBasePtr + staticCopyIdx) = staticData; + } + } + + indexBase += UNIT_PER_ITER; + staticCopyBase += STATIC_COPY_PER_ITER * 4; + acquireNewStep = indexBase < unitCount || staticCopyBase < expertCount; + packPtr = pipeline.finishSendPacket(acquireNewStep); + } + + pipeline.sendFinalize(); + } else { + int baseCumsum = targetRankId == 0 ? 0 : *(recvCountsCumsum + targetRankId - 1); + int recvTokenCount = *(recvCountsCumsum + targetRankId) - baseCumsum; + int recvUnitCount = recvTokenCount * groupSize; + + int unitIdBase = 0; + int staticCopyBase = 0; + while (unitIdBase < recvUnitCount || + (localExpertStatics != nullptr && staticCopyBase < expertCount)) { + void* packetPtr = pipeline.getNewRecvPacket(); + int packetUnitCount = + unitIdBase + UNIT_PER_ITER < recvUnitCount ? 
UNIT_PER_ITER : recvUnitCount - unitIdBase; + packetUnitCount = max(packetUnitCount, 0); + if (threadIdx.x < UNIT_PER_ITER) { + if (threadIdx.x < packetUnitCount) { + int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize; + int groupId = (unitIdBase + threadIdx.x) % groupSize; + int* expertsPtr = (int*)(packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; + *((ExpertType*)(experts)) = *((ExpertType*)(expertsPtr)); + ExpertType* dstExpertsPtr = + (ExpertType*)(recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); + *dstExpertsPtr = *((ExpertType*)(experts)); + + if (recvScales != nullptr) { + float* scaleBasePtr = (float*)(packetPtr + PipelineConfig::SCALE_OFFSET); + float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE; + *((ScaleType*)(scales)) = *((ScaleType*)(scalesPtr)); + ScaleType* dstScalesPtr = + (ScaleType*)(recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); + *dstScalesPtr = *((ScaleType*)(scales)); + } } + } else if (localExpertStatics != nullptr) { + int staticCopyIdx = threadIdx.x - UNIT_PER_ITER; + if (staticCopyBase + staticCopyIdx * 4 < expertCount) { + int4* staticBasePtr = (int4*)(packetPtr + PipelineConfig::STATIC_COPY_OFFSET); + int4 staticData = *(staticBasePtr + staticCopyIdx); + *(int4*)(gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + + staticCopyIdx * 4) = staticData; + } + } - pipeline.reset(); + unitIdBase += packetUnitCount; + staticCopyBase += STATIC_COPY_PER_ITER * 4; } + + pipeline.reset(); + } } -__global__ void memsetExpertIdsDevice( - int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount, int rankCount) -{ - int maxTokenCount = maxTokenCountPerRank * rankCount; - int totalRecvTokenCount = *(recvCountsCumsum + rankCount - 1); - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i + totalRecvTokenCount * topK < maxTokenCount * topK; - i += gridDim.x * blockDim.x) - { - *(expertIds + i + totalRecvTokenCount * topK) = slotCount; - } +__global__ void memsetExpertIdsDevice(int* expertIds, int* recvCountsCumsum, + int maxTokenCountPerRank, int topK, int slotCount, + int rankCount) { + int maxTokenCount = maxTokenCountPerRank * rankCount; + int totalRecvTokenCount = *(recvCountsCumsum + rankCount - 1); + for (int i = blockIdx.x * blockDim.x + threadIdx.x; + i + totalRecvTokenCount * topK < maxTokenCount * topK; i += gridDim.x * blockDim.x) { + *(expertIds + i + totalRecvTokenCount * topK) = slotCount; + } } void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, - int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, - int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream) -{ - // first rankCount CTAs for count and send, then rankCount / PIPELINE_PER_CTA CTAs only for receive - int grid_x = rankCount + (rankCount + PIPELINE_PER_CTA - 1) / PIPELINE_PER_CTA; - int block_size = 1024; - dim3 block(block_size); - dim3 grid(grid_x); - FLASHINFER_CHECK(topK >= 1 && topK <= 32, "Only 1 <= topK <= 32 is supported now."); - auto* kernelFn = computeCountAndIndiceDevice<1>; - if (topK > 16) - { - kernelFn = computeCountAndIndiceDevice<32>; - } - else if (topK > 8) - { - kernelFn = computeCountAndIndiceDevice<16>; - } - else if (topK > 4) - { - kernelFn = computeCountAndIndiceDevice<8>; - } - else if (topK > 2) - { - kernelFn = computeCountAndIndiceDevice<4>; - } - else if (topK > 1) - { - kernelFn = 
computeCountAndIndiceDevice<2>; - } - kernelFn<<>>(experts, sendCounts, recvCounts, sendIndiceWorkspace, backwardIndiceWorkspace, - recvIndiceWorkspace, workspace, tokenCount, maxTokenCountPerRank, topK, expert_count, rankId, rankCount); + int* backwardIndiceWorkspace, int* recvIndiceWorkspace, + MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, + int topK, int expert_count, int rankId, int rankCount, + cudaStream_t stream) { + // first rankCount CTAs for count and send, then rankCount / PIPELINE_PER_CTA CTAs only for + // receive + int grid_x = rankCount + (rankCount + PIPELINE_PER_CTA - 1) / PIPELINE_PER_CTA; + int block_size = 1024; + dim3 block(block_size); + dim3 grid(grid_x); + FLASHINFER_CHECK(topK >= 1 && topK <= 32, "Only 1 <= topK <= 32 is supported now."); + auto* kernelFn = computeCountAndIndiceDevice<1>; + if (topK > 16) { + kernelFn = computeCountAndIndiceDevice<32>; + } else if (topK > 8) { + kernelFn = computeCountAndIndiceDevice<16>; + } else if (topK > 4) { + kernelFn = computeCountAndIndiceDevice<8>; + } else if (topK > 2) { + kernelFn = computeCountAndIndiceDevice<4>; + } else if (topK > 1) { + kernelFn = computeCountAndIndiceDevice<2>; + } + kernelFn<<>>(experts, sendCounts, recvCounts, sendIndiceWorkspace, + backwardIndiceWorkspace, recvIndiceWorkspace, workspace, + tokenCount, maxTokenCountPerRank, topK, expert_count, rankId, + rankCount); } -void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream) -{ - int block_size = CUMSUM_THREADS_PER_BLOCK; - dim3 block(block_size); - dim3 grid(2); - computeCumsumDevice<<>>(sendCountsCumsum, recvCountsCumsum, rankId, rankCount); +void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, + cudaStream_t stream) { + int block_size = CUMSUM_THREADS_PER_BLOCK; + dim3 block(block_size); + dim3 grid(2); + computeCumsumDevice<<>>(sendCountsCumsum, recvCountsCumsum, rankId, + rankCount); } -void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice, - int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount, - int maxTokenCountPerRank, cudaStream_t stream) -{ - dim3 block(512); - dim3 grid(rankCount, 2); - moveIndiceDevice<<>>(sendCountsCumsum, recvCountsCumsum, sendIndice, gatherSendIndice, - backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank); +void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, + int* gatherSendIndice, int* backwardIndice, int* gatherBackwardIndice, + int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount, + int maxTokenCountPerRank, cudaStream_t stream) { + dim3 block(512); + dim3 grid(rankCount, 2); + moveIndiceDevice<<>>( + sendCountsCumsum, recvCountsCumsum, sendIndice, gatherSendIndice, backwardIndice, + gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank); } -void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics, - int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice, - int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount, - int slotCount, int rankId, int rankCount, cudaStream_t stream) -{ - int block_size = localExpertStatics == nullptr ? 
UNIT_PER_ITER : UNIT_PER_ITER + STATIC_COPY_PER_ITER; - dim3 block(block_size); - dim3 grid(rankCount, 2); - - if (topK % 4 == 0) - { - using PipelineConfig = PipelineConfig<4, 16>; - static_assert( - PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64, - "FIFO size is too small"); - allToAllMetadataDevice<<>>(sendExperts, recvExperts, - sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, - localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, - slotCount, rankId, rankCount); - } - else - { - using PipelineConfig = PipelineConfig<1, 64>; - static_assert( - PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64, - "FIFO size is too small"); - allToAllMetadataDevice<<>>(sendExperts, recvExperts, - sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, - localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, - slotCount, rankId, rankCount); - } - - int smCount = getMultiProcessorCount(); - memsetExpertIdsDevice<<>>( - recvExperts, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount); +void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, + int* localExpertStatics, int* gatheredExpertStatics, + MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice, + int* recvCountsCumsum, int* localRecvIndice, int tokenCount, + int maxTokenCountPerRank, int topK, int expertCount, int slotCount, + int rankId, int rankCount, cudaStream_t stream) { + int block_size = + localExpertStatics == nullptr ? UNIT_PER_ITER : UNIT_PER_ITER + STATIC_COPY_PER_ITER; + dim3 block(block_size); + dim3 grid(rankCount, 2); + + if (topK % 4 == 0) { + using PipelineConfig = PipelineConfig<4, 16>; + static_assert( + PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= + FIFO_SIZE_IN_U64, + "FIFO size is too small"); + allToAllMetadataDevice<<>>( + sendExperts, recvExperts, sendScales, recvScales, localExpertStatics, gatheredExpertStatics, + workspace, sendCountsCumsum, localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, + maxTokenCountPerRank, topK, expertCount, slotCount, rankId, rankCount); + } else { + using PipelineConfig = PipelineConfig<1, 64>; + static_assert( + PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= + FIFO_SIZE_IN_U64, + "FIFO size is too small"); + allToAllMetadataDevice<<>>( + sendExperts, recvExperts, sendScales, recvScales, localExpertStatics, gatheredExpertStatics, + workspace, sendCountsCumsum, localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, + maxTokenCountPerRank, topK, expertCount, slotCount, rankId, rankCount); + } + + int smCount = getMultiProcessorCount(); + memsetExpertIdsDevice<<>>( + recvExperts, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount); } -size_t getMoePrepareWorkspaceSize(int epSize) -{ - return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize; +size_t getMoePrepareWorkspaceSize(int epSize) { + return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize; } -} // namespace moe_prepare +} // namespace moe_prepare -} // namespace flashinfer::trtllm_alltoall +} // namespace flashinfer::trtllm_alltoall diff --git a/flashinfer/comm/trtllm_alltoall.py b/flashinfer/comm/trtllm_alltoall.py index 
f374be806..be965fdc2 100644 --- a/flashinfer/comm/trtllm_alltoall.py +++ b/flashinfer/comm/trtllm_alltoall.py @@ -193,7 +193,7 @@ def get_moe_prepare_workspace_size_per_rank( ep_size: int, ) -> int: return module.get_moe_prepare_workspace_size_per_rank(ep_size) - + @register_custom_op( "flashinfer::moe_prepare", mutates_args=[], @@ -209,7 +209,16 @@ def moe_prepare( expert_count: int, slot_count: int, top_k: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: return module.moe_prepare( experts_ids, scales, @@ -319,11 +328,13 @@ def get_moe_commworkspace_size_per_rank( ) -> int: return get_comm_alltoall_module().get_moe_commworkspace_size_per_rank(ep_size) + def get_moe_prepare_workspace_size_per_rank( ep_size: int, ) -> int: return get_comm_alltoall_module().get_moe_prepare_workspace_size_per_rank(ep_size) + def moe_prepare( experts_ids: torch.Tensor, scales: Optional[torch.Tensor], @@ -335,7 +346,16 @@ def moe_prepare( expert_count: int, slot_count: int, top_k: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: return get_comm_alltoall_module().moe_prepare( experts_ids, scales, @@ -349,6 +369,7 @@ def moe_prepare( top_k, ) + @dataclass class MoEAlltoallInfo: local_gather_indices: torch.Tensor @@ -455,20 +476,12 @@ def mnnvl_moe_alltoallv_prepare_without_allgather( local_token_allocation_count, ) - return alltoall_info, prepared_local_experts, prepared_local_scales, gathered_expert_statics - - @staticmethod - def mnnvl_moe_expert_static_allgather( - expert_ids: torch.Tensor, - workspace: torch.Tensor, - ep_rank: int, - ep_size: int, - expert_count: int, - ): - gathered_expert_ids = torch.ops.trtllm.mnnvl_moe_expert_static_allgather( - expert_ids, workspace, ep_rank, ep_size, expert_count + return ( + alltoall_info, + prepared_local_experts, + prepared_local_scales, + gathered_expert_statics, ) - return gathered_expert_ids @staticmethod def mnnvl_moe_alltoallv_prepare( diff --git a/include/flashinfer/comm/trtllm_alltoall_prepare.cuh b/include/flashinfer/comm/trtllm_alltoall_prepare.cuh index 40b1cd51f..2945eee97 100644 --- a/include/flashinfer/comm/trtllm_alltoall_prepare.cuh +++ b/include/flashinfer/comm/trtllm_alltoall_prepare.cuh @@ -16,16 +16,14 @@ #pragma once -#include #include +#include #define DEBUG_PIPELINE 0 -namespace flashinfer::trtllm_alltoall -{ +namespace flashinfer::trtllm_alltoall { -namespace moe_prepare -{ +namespace moe_prepare { #define STEP_DEPTH 2 #define THREADS_PER_UNIT 1 @@ -44,15 +42,15 @@ static constexpr int THREADS_PER_PIPELINE = THREADS_PER_UNIT * UNIT_PER_PIPELINE static constexpr int THREADS_PER_CTA = THREADS_PER_PIPELINE * PIPELINE_PER_CTA; template -struct PipelineConfig -{ - static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT; - static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT; - static constexpr int UNIT_BYTES_SIZE = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)); - static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int); - static constexpr int STATIC_COPY_OFFSET = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)); - static constexpr int PACKET_SIZE = 
UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int); - static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8); +struct PipelineConfig { + static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT; + static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT; + static constexpr int UNIT_BYTES_SIZE = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)); + static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int); + static constexpr int STATIC_COPY_OFFSET = + UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)); + static constexpr int PACKET_SIZE = UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int); + static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8); }; // 1MB FIFO size @@ -64,65 +62,67 @@ static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8; #define ALIGN_256 alignas(256) #endif -struct ALIGN_256 MoeCommFifoConnInfo -{ - volatile uint64_t head; // write position - volatile uint64_t tail; // read position - volatile uint64_t count; // for counter +struct ALIGN_256 MoeCommFifoConnInfo { + volatile uint64_t head; // write position + volatile uint64_t tail; // read position + volatile uint64_t count; // for counter }; -struct MoeCommWorkspace -{ - uint64_t* workspacePtr; - size_t rankStrideInU64; +struct MoeCommWorkspace { + uint64_t* workspacePtr; + size_t rankStrideInU64; #ifdef __CUDACC__ - __inline__ __device__ uint64_t* getFifoBasePtr( - bool isSender, int epRank, int peerRank, int channel, int channelCount) const - { - // fifo itself is in receiver's side. - if (isSender) - { - return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * FIFO_SIZE_IN_U64; - } - else - { - return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * FIFO_SIZE_IN_U64; - } - } - - __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo( - bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const - { - // fifoInfo is in sender's side. - uint64_t* fifoInfoPtrU64 = workspacePtr + FIFO_SIZE_IN_U64 * channelCount * epSize; - int strideIndice = isSender ? epRank : peerRank; - int fifoInfoIndice = isSender ? peerRank : epRank; - fifoInfoPtrU64 += strideIndice * rankStrideInU64; - MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*) fifoInfoPtrU64; - MoeCommFifoConnInfo* result = fifoInfoPtr + fifoInfoIndice * channelCount + channel; - return result; + __inline__ __device__ uint64_t* getFifoBasePtr(bool isSender, int epRank, int peerRank, + int channel, int channelCount) const { + // fifo itself is in receiver's side. + if (isSender) { + return workspacePtr + peerRank * rankStrideInU64 + + (epRank * channelCount + channel) * FIFO_SIZE_IN_U64; + } else { + return workspacePtr + epRank * rankStrideInU64 + + (peerRank * channelCount + channel) * FIFO_SIZE_IN_U64; } + } + + __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(bool isSender, int epRank, + int peerRank, int channel, int epSize, + int channelCount) const { + // fifoInfo is in sender's side. + uint64_t* fifoInfoPtrU64 = workspacePtr + FIFO_SIZE_IN_U64 * channelCount * epSize; + int strideIndice = isSender ? epRank : peerRank; + int fifoInfoIndice = isSender ? 
peerRank : epRank; + fifoInfoPtrU64 += strideIndice * rankStrideInU64; + MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*)fifoInfoPtrU64; + MoeCommFifoConnInfo* result = fifoInfoPtr + fifoInfoIndice * channelCount + channel; + return result; + } #endif }; void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, - int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, - int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream); - -void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream); - -void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, int* gatherSendIndice, - int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount, - int maxTokenCountPerRank, cudaStream_t stream); - -void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics, - int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice, - int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount, - int slotCount, int rankId, int rankCount, cudaStream_t stream); + int* backwardIndiceWorkspace, int* recvIndiceWorkspace, + MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, + int topK, int expert_count, int rankId, int rankCount, + cudaStream_t stream); + +void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, + cudaStream_t stream); + +void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, + int* gatherSendIndice, int* backwardIndice, int* gatherBackwardIndice, + int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount, + int maxTokenCountPerRank, cudaStream_t stream); + +void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, + int* localExpertStatics, int* gatheredExpertStatics, + MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice, + int* recvCountsCumsum, int* localRecvIndice, int tokenCount, + int maxTokenCountPerRank, int topK, int expertCount, int slotCount, + int rankId, int rankCount, cudaStream_t stream); size_t getMoePrepareWorkspaceSize(int epSize); -} // namespace moe_prepare +} // namespace moe_prepare -} // namespace flashinfer::trtllm_alltoall +} // namespace flashinfer::trtllm_alltoall diff --git a/tests/test_trtllm_alltoall.py b/tests/test_trtllm_alltoall.py index 855c4d8e6..699556eaf 100644 --- a/tests/test_trtllm_alltoall.py +++ b/tests/test_trtllm_alltoall.py @@ -509,47 +509,63 @@ def test_moe_local_gather( assert torch.equal(local_expert_ids, ref_local_expert_ids) assert torch.equal(local_scales, ref_local_scales) + @pytest.mark.parametrize( - "ep_rank, ep_size, expert_count, slot_count, top_k, max_token_count_per_rank", [ - (0, 2, 16, 20, 8, 512), - (0, 2, 16, 16, 3, 300), - (0, 4, 20, 24, 8, 4000), - (0, 8, 96, 96, 8, 1000), - (3, 8, 128, 128, 8, 1000), - (3, 8, 128, 144, 8, 1), - (0, 4, 72, 80, 4, 2256), - (0, 4, 72, 80, 6, 3333), - # Hang with stream count > 8 - #(0, 9, 90, 8, 100), -]) -def test_moe_alltoall_prepare(ep_rank: int, ep_size: int, - expert_count: int, slot_count: int, - top_k: int, max_token_count_per_rank: int): + "ep_rank, ep_size, expert_count, slot_count, top_k, max_token_count_per_rank", + [ + (0, 2, 16, 20, 8, 
512), + (0, 2, 16, 16, 3, 300), + (0, 4, 20, 24, 8, 4000), + (0, 8, 96, 96, 8, 1000), + (3, 8, 128, 128, 8, 1000), + (3, 8, 128, 144, 8, 1), + (0, 4, 72, 80, 4, 2256), + (0, 4, 72, 80, 6, 3333), + # Hang with stream count > 8 + # (0, 9, 90, 8, 100), + ], +) +def test_moe_alltoall_prepare( + ep_rank: int, + ep_size: int, + expert_count: int, + slot_count: int, + top_k: int, + max_token_count_per_rank: int, +): torch.cuda.set_device(0) cpu_expert_ids_all_ranks_lists = [] cpu_token_count_lists = [] cpu_scales_all_ranks_lists = [] for _ in range(ep_size): - token_count = torch.randint(max_token_count_per_rank // 2, - max_token_count_per_rank + 1, (1, ), - dtype=torch.int32, - device=torch.device('cpu')) + token_count = torch.randint( + max_token_count_per_rank // 2, + max_token_count_per_rank + 1, + (1,), + dtype=torch.int32, + device=torch.device("cpu"), + ) token_count = 1 if token_count == 0 else token_count token_count = max_token_count_per_rank cpu_expert_ids_all_ranks_lists.append( - torch.randint(0, - slot_count, (token_count, top_k), - dtype=torch.int32, - device=torch.device('cpu'))) + torch.randint( + 0, + slot_count, + (token_count, top_k), + dtype=torch.int32, + device=torch.device("cpu"), + ) + ) cpu_scales_all_ranks_lists.append( - torch.zeros(token_count, - top_k, - dtype=torch.float32, - device=torch.device('cpu')) + 0.5) + torch.zeros( + token_count, top_k, dtype=torch.float32, device=torch.device("cpu") + ) + + 0.5 + ) cpu_token_count_lists.append(token_count) @@ -571,8 +587,7 @@ def generate_references(): for token_id in range(local_token_count): target_set = set() for pos in range(top_k): - expert_id = int( - cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) + expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) target_rank_id = compute_target_rank(expert_id) target_set.add(target_rank_id) @@ -581,32 +596,33 @@ def generate_references(): total_send_token_count = 0 for rank in range(ep_size): - #print(f'rank: {rank}, send_token_count_to_ranks[rank]: {send_token_count_to_ranks[rank]}') - base = ref_local_send_rank_count_cumsum[rank - - 1] if rank > 0 else 0 - ref_local_send_rank_count_cumsum[ - rank] = send_token_count_to_ranks[rank] + base + # print(f'rank: {rank}, send_token_count_to_ranks[rank]: {send_token_count_to_ranks[rank]}') + base = ref_local_send_rank_count_cumsum[rank - 1] if rank > 0 else 0 + ref_local_send_rank_count_cumsum[rank] = ( + send_token_count_to_ranks[rank] + base + ) total_send_token_count += send_token_count_to_ranks[rank] - ref_local_backward_send_rank_indices = [0 - ] * (total_send_token_count) + ref_local_backward_send_rank_indices = [0] * (total_send_token_count) ref_local_send_rank_indices = [0] * (total_send_token_count) current_send_token_ids = [0] * ep_size for token_id in range(local_token_count): target_set = set() for pos in range(top_k): - expert_id = int( - cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) + expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) target_rank_id = compute_target_rank(expert_id) if target_rank_id not in target_set: - cumsum_before = 0 if target_rank_id == 0 else ref_local_send_rank_count_cumsum[ - target_rank_id - 1] - send_index = cumsum_before + current_send_token_ids[ - target_rank_id] + cumsum_before = ( + 0 + if target_rank_id == 0 + else ref_local_send_rank_count_cumsum[target_rank_id - 1] + ) + send_index = cumsum_before + current_send_token_ids[target_rank_id] ref_local_send_rank_indices[send_index] = token_id - ref_local_backward_send_rank_indices[ - 
send_index] = token_id * top_k + pos + ref_local_backward_send_rank_indices[send_index] = ( + token_id * top_k + pos + ) current_send_token_ids[target_rank_id] += 1 target_set.add(target_rank_id) @@ -618,31 +634,48 @@ def generate_references(): for token_id in range(token_count): token_is_received = False for pos in range(top_k): - expert_id = int( - cpu_expert_ids_all_ranks_lists[rank][token_id][pos]) + expert_id = int(cpu_expert_ids_all_ranks_lists[rank][token_id][pos]) sf = cpu_scales_all_ranks_lists[rank][token_id][pos] target_rank_id = compute_target_rank(expert_id) if target_rank_id == ep_rank: if not token_is_received: token_is_received = True - ref_prepared_local_expert_ids.append( - [slot_count] * top_k) + ref_prepared_local_expert_ids.append([slot_count] * top_k) ref_prepared_local_scales.append([0.0] * top_k) ref_prepared_local_expert_ids[-1][pos] = expert_id ref_prepared_local_scales[-1][pos] = sf if token_is_received: - ref_local_recv_rank_indices.append( - total_recv_token_count) + ref_local_recv_rank_indices.append(total_recv_token_count) total_recv_token_count += 1 current_recv_token_count += 1 - ref_local_recv_rank_count_cumsum[ - rank] = current_recv_token_count if rank == 0 else ref_local_recv_rank_count_cumsum[ - rank - 1] + current_recv_token_count + ref_local_recv_rank_count_cumsum[rank] = ( + current_recv_token_count + if rank == 0 + else ref_local_recv_rank_count_cumsum[rank - 1] + + current_recv_token_count + ) - return ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count + return ( + ref_prepared_local_expert_ids, + ref_prepared_local_scales, + ref_local_send_rank_count_cumsum, + ref_local_send_rank_indices, + ref_local_recv_rank_count_cumsum, + ref_local_recv_rank_indices, + ref_local_backward_send_rank_indices, + total_recv_token_count, + ) - ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count = generate_references( - ) + ( + ref_prepared_local_expert_ids, + ref_prepared_local_scales, + ref_local_send_rank_count_cumsum, + ref_local_send_rank_indices, + ref_local_recv_rank_count_cumsum, + ref_local_recv_rank_indices, + ref_local_backward_send_rank_indices, + total_recv_token_count, + ) = generate_references() cpu_experter_count_lists = [] for rank in range(ep_size): @@ -651,35 +684,37 @@ def generate_references(): local_expert_count.append(rank * expert_count + i) cpu_experter_count_lists.append(torch.IntTensor(local_expert_count)) - #expert_ids_all_ranks = torch.tensor(cpu_expert_ids_all_ranks_lists).cuda() + # expert_ids_all_ranks = torch.tensor(cpu_expert_ids_all_ranks_lists).cuda() expert_ids_all_ranks = [ cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size) ] - #scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda() - scales_all_ranks = [ - cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size) - ] + # scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda() + scales_all_ranks = [cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size)] - experter_count_lists = [ - cpu_experter_count_lists[i].cuda() for i in range(ep_size) - ] + experter_count_lists = [cpu_experter_count_lists[i].cuda() for i in range(ep_size)] 
cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(ep_size)] - workspace_size = tllm_alltoall.get_moe_prepare_workspace_size_per_rank( - ep_size) + workspace_size = tllm_alltoall.get_moe_prepare_workspace_size_per_rank(ep_size) - all_workspaces = torch.zeros(ep_size, - workspace_size, - dtype=torch.uint64, - device=torch.device('cuda')) + all_workspaces = torch.zeros( + ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") + ) stream = torch.cuda.Stream() with torch.cuda.stream(stream): tllm_alltoall.moe_prepare( - expert_ids_all_ranks[0], scales_all_ranks[0], - experter_count_lists[0], all_workspaces, - max_token_count_per_rank, 0, 1, expert_count, slot_count, top_k) + expert_ids_all_ranks[0], + scales_all_ranks[0], + experter_count_lists[0], + all_workspaces, + max_token_count_per_rank, + 0, + 1, + expert_count, + slot_count, + top_k, + ) stream.wait_stream(torch.cuda.current_stream()) # Make torch alloc tensor to avoid cuda sync @@ -692,72 +727,101 @@ def generate_references(): backward_local_recv_rank_indices = [] for _ in range(ep_size): prepared_local_experts.append( - torch.empty(max_token_count_per_rank * ep_size, - top_k, - dtype=torch.int32, - device=torch.device('cuda'))) + torch.empty( + max_token_count_per_rank * ep_size, + top_k, + dtype=torch.int32, + device=torch.device("cuda"), + ) + ) prepared_local_scales.append( - torch.empty(max_token_count_per_rank * ep_size, - top_k, - dtype=torch.float32, - device=torch.device('cuda'))) + torch.empty( + max_token_count_per_rank * ep_size, + top_k, + dtype=torch.float32, + device=torch.device("cuda"), + ) + ) local_send_rank_count_cumsum.append( - torch.empty(ep_size, - dtype=torch.int32, - device=torch.device('cuda'))) + torch.empty(ep_size, dtype=torch.int32, device=torch.device("cuda")) + ) local_send_rank_indices.append( - torch.empty(max_token_count_per_rank * ep_size, - dtype=torch.int32, - device=torch.device('cuda'))) + torch.empty( + max_token_count_per_rank * ep_size, + dtype=torch.int32, + device=torch.device("cuda"), + ) + ) local_recv_rank_count_cumsum.append( - torch.empty(0, dtype=torch.int32, device=torch.device('cuda'))) + torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) + ) local_recv_rank_indices.append( - torch.empty(0, dtype=torch.int32, device=torch.device('cuda'))) + torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) + ) backward_local_recv_rank_indices.append( - torch.empty(0, dtype=torch.int32, device=torch.device('cuda'))) + torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) + ) - prepared_local_experts = [] - prepared_local_scales = [] - local_send_rank_count_cumsum = [] - local_send_rank_indices = [] - local_recv_rank_count_cumsum = [] - local_recv_rank_indices = [] - backward_local_recv_rank_indices = [] + prepared_local_experts = None + prepared_local_scales = None + local_send_rank_count_cumsum = None + local_send_rank_indices = None + local_recv_rank_count_cumsum = None + local_recv_rank_indices = None + backward_local_recv_rank_indices = None # reset the workspace - all_workspaces = torch.zeros(ep_size, - workspace_size, - dtype=torch.uint64, - device=torch.device('cuda')) + all_workspaces = torch.zeros( + ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") + ) # do prepare in parallel for rank in range(ep_size): with torch.cuda.stream(cuda_streams_all_ranks[rank]): if rank == ep_rank: - prepared_local_experts, prepared_local_scales, local_send_rank_count_cumsum, \ - local_send_rank_indices, 
local_recv_rank_count_cumsum, local_recv_rank_indices, \ - backward_local_recv_rank_indices, gathered_expert_statics\ - = tllm_alltoall.moe_prepare(expert_ids_all_ranks[rank], scales_all_ranks[rank], experter_count_lists[rank], all_workspaces, max_token_count_per_rank, - rank, ep_size, expert_count, slot_count, top_k) + ( + prepared_local_experts, + prepared_local_scales, + local_send_rank_count_cumsum, + local_send_rank_indices, + local_recv_rank_count_cumsum, + local_recv_rank_indices, + backward_local_recv_rank_indices, + gathered_expert_statics, + ) = tllm_alltoall.moe_prepare( + expert_ids_all_ranks[rank], + scales_all_ranks[rank], + experter_count_lists[rank], + all_workspaces, + max_token_count_per_rank, + rank, + ep_size, + expert_count, + slot_count, + top_k, + ) else: tllm_alltoall.moe_prepare( - expert_ids_all_ranks[rank], scales_all_ranks[rank], - experter_count_lists[rank], all_workspaces, - max_token_count_per_rank, rank, ep_size, expert_count, - slot_count, top_k) + expert_ids_all_ranks[rank], + scales_all_ranks[rank], + experter_count_lists[rank], + all_workspaces, + max_token_count_per_rank, + rank, + ep_size, + expert_count, + slot_count, + top_k, + ) for rank in range(ep_size): cuda_streams_all_ranks[rank].synchronize() - prepared_local_experts_cpu = prepared_local_experts[: - total_recv_token_count].cpu( - ) - prepared_local_scales_cpu = prepared_local_scales[: - total_recv_token_count].cpu( - ) + prepared_local_experts_cpu = prepared_local_experts[:total_recv_token_count].cpu() + prepared_local_scales_cpu = prepared_local_scales[:total_recv_token_count].cpu() for i in range(total_recv_token_count): for j in range(top_k): expert_id = int(prepared_local_experts_cpu[i][j]) - assert 0 <= expert_id and expert_id <= slot_count + assert expert_id >= 0 and expert_id <= slot_count if expert_id < slot_count: assert compute_target_rank(expert_id) == ep_rank scale = float(prepared_local_scales_cpu[i][j]) @@ -766,38 +830,34 @@ def generate_references(): gathered_expert_statics_cpu = gathered_expert_statics.cpu() for rank in range(ep_size): for i in range(expert_count): - assert int(gathered_expert_statics_cpu[rank] - [i]) == rank * expert_count + i + assert int(gathered_expert_statics_cpu[rank][i]) == rank * expert_count + i - ref_local_send_rank_count_cumsum = torch.IntTensor( - ref_local_send_rank_count_cumsum) - assert torch.equal(local_send_rank_count_cumsum.cpu(), - ref_local_send_rank_count_cumsum) + ref_local_send_rank_count_cumsum = torch.IntTensor(ref_local_send_rank_count_cumsum) + assert torch.equal( + local_send_rank_count_cumsum.cpu(), ref_local_send_rank_count_cumsum + ) local_send_rank_indices = local_send_rank_indices.cpu() - backward_local_recv_rank_indices = backward_local_recv_rank_indices.cpu( - ) + backward_local_recv_rank_indices = backward_local_recv_rank_indices.cpu() for i in range(ep_size): base = 0 if i == 0 else ref_local_send_rank_count_cumsum[i - 1] for j in range(base, ref_local_send_rank_count_cumsum[i]): token_id = local_send_rank_indices[j] lane_id = backward_local_recv_rank_indices[j] - token_id * top_k - expert_id = int( - cpu_expert_ids_all_ranks_lists[ep_rank][token_id][lane_id]) + expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][lane_id]) assert compute_target_rank(expert_id) == i - ref_local_recv_rank_count_cumsum = torch.IntTensor( - ref_local_recv_rank_count_cumsum) + ref_local_recv_rank_count_cumsum = torch.IntTensor(ref_local_recv_rank_count_cumsum) assert torch.equal( - 
local_recv_rank_count_cumsum[:ref_local_recv_rank_count_cumsum. - size(0)].cpu(), - ref_local_recv_rank_count_cumsum) + local_recv_rank_count_cumsum[: ref_local_recv_rank_count_cumsum.size(0)].cpu(), + ref_local_recv_rank_count_cumsum, + ) - ref_local_recv_rank_indices = torch.IntTensor( - ref_local_recv_rank_indices) + ref_local_recv_rank_indices = torch.IntTensor(ref_local_recv_rank_indices) assert torch.equal( - local_recv_rank_indices[:ref_local_recv_rank_indices.size(0)].cpu(), - ref_local_recv_rank_indices) + local_recv_rank_indices[: ref_local_recv_rank_indices.size(0)].cpu(), + ref_local_recv_rank_indices, + ) if __name__ == "__main__": From 0189214e7d5e7917fa442bf9d9fbc293b4251e54 Mon Sep 17 00:00:00 2001 From: "Shu Wang." Date: Mon, 25 Aug 2025 21:30:05 +0000 Subject: [PATCH 3/3] Add alternative comm backend for mnnvl --- flashinfer/comm/mnnvl.py | 58 ++++++++++++++++++++++++++++++ flashinfer/comm/trtllm_alltoall.py | 12 +++++-- 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index e495825de..0aadbc52c 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -16,6 +16,8 @@ import ctypes import logging import os +from abc import ABC, abstractmethod +from dataclasses import dataclass import platform import sys from typing import Any, Dict, List, Optional @@ -220,6 +222,51 @@ def set_mpi_comm(cls, new_comm: MPI.Intracomm): def __getattr__(self, name): return getattr(self._comm, name) + class CommBackend(ABC): + """Abstract communication backend interface""" + @abstractmethod + def Get_rank(self) -> int: ... + + @abstractmethod + def Get_size(self) -> int: ... + + @abstractmethod + def allgather(self, data: int) -> List[int]: ... + + @abstractmethod + def allgather_bytes(self, data): ... + + @abstractmethod + def Split(self, color: int, key: int) -> 'CommBackend': ... 
+    class LegacyMPIBackend(CommBackend):
+        """Adapter for the original MpiComm singleton pattern"""
+        def __init__(self, comm=None):
+            # Wrap the given communicator, or fall back to the MpiComm singleton.
+            self._mpicomm = comm if comm is not None else MpiComm()
+
+        def Get_rank(self) -> int:
+            return self._mpicomm.Get_rank()
+
+        def Get_size(self) -> int:
+            return self._mpicomm.Get_size()
+
+        def allgather(self, data: int) -> List[int]:
+            return self._mpicomm.allgather(data)
+
+        def allgather_bytes(self, data):
+            return self._mpicomm.allgather(data)
+
+        def Split(self, color: int, key: int) -> CommBackend:
+            # Bind the returned adapter to the split communicator.
+            return LegacyMPIBackend(self._mpicomm.Split(color, key))
+    @dataclass
+    class MnnvlConfig:
+        """Configuration for MNNVL memory management"""
+        comm_backend: Optional[CommBackend] = None
+        allocation_granularity: int = 0
+        fabric_page_size: int = 1 << 29  # 512MB
+
+
 class MnnvlMemory:  # type: ignore[no-redef]
     initialized: bool = False
 
@@ -275,6 +322,15 @@ def initialize():
         pynvml.nvmlInit()
         MnnvlMemory.initialized = True
 
+    @staticmethod
+    def set_comm(mapping: Mapping, config: Optional[MnnvlConfig] = None):
+        # Fall back to the MPI-based backend when no config (or no backend) is given.
+        MnnvlMemory._config = config or MnnvlConfig(comm_backend=LegacyMPIBackend())
+        comm_backend = MnnvlMemory._config.comm_backend or LegacyMPIBackend()
+        MnnvlMemory.comm = comm_backend.Split(
+            mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
+        )
+
     @staticmethod
     def get_comm(mapping: Mapping):
         if MnnvlMemory.comm is not None:
diff --git a/flashinfer/comm/trtllm_alltoall.py b/flashinfer/comm/trtllm_alltoall.py
index be965fdc2..7d84a9281 100644
--- a/flashinfer/comm/trtllm_alltoall.py
+++ b/flashinfer/comm/trtllm_alltoall.py
@@ -26,7 +26,7 @@
 from ..jit import gen_jit_spec
 from ..utils import register_custom_op
 from .mapping import Mapping
-from .mnnvl import MnnvlMemory
+from .mnnvl import MnnvlConfig, MnnvlMemory
 
 
 def gen_comm_alltoall_module() -> JitSpec:
@@ -389,13 +389,16 @@ class MnnvlMoe:
     moe_mapping: Mapping = None
 
     @staticmethod
-    def get_moe_workspaces(mapping: Mapping):
+    def get_moe_workspaces(mapping: Mapping, config: Optional[MnnvlConfig] = None):
         if MnnvlMoe.moe_workspace is not None:
             assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
             return MnnvlMoe.moe_workspace_tensor
 
         MnnvlMoe.moe_mapping = mapping
         workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size)
+        if config:
+            MnnvlMemory.set_comm(mapping, config)
+
         MnnvlMemory.initialize()
         MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
         MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
             torch.uint64
@@ -403,13 +406,16 @@ def get_moe_workspaces(mapping: Mapping):
         return MnnvlMoe.moe_workspace_tensor
 
     @staticmethod
-    def get_moe_prepare_workspace(mapping: Mapping):
+    def get_moe_prepare_workspace(mapping: Mapping, config: Optional[MnnvlConfig] = None):
         if MnnvlMoe.moe_prepare_workspace_tensor is not None:
             assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
             return MnnvlMoe.moe_prepare_workspace_tensor
         workspace_size_per_rank = get_moe_prepare_workspace_size_per_rank(
             mapping.tp_size
         )
+        if config:
+            MnnvlMemory.set_comm(mapping, config)
+
         MnnvlMemory.initialize()
         MnnvlMoe.moe_prepare_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
         MnnvlMoe.moe_prepare_workspace_tensor = (
             MnnvlMoe.moe_prepare_workspace.as_torch_strided_tensor(torch.uint64)