From 77b7b8af42c425591c20fd9cfee412a895c2a0ba Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Tue, 20 May 2025 11:58:47 +0000 Subject: [PATCH 01/71] init --- .../multiary_infer_sym.cc | 5 + .../infer_symbolic_shape/multiary_infer_sym.h | 1 + paddle/phi/infermeta/multiary.cc | 23 +++ paddle/phi/infermeta/multiary.h | 7 + paddle/phi/infermeta/ternary.cc | 19 +++ paddle/phi/infermeta/ternary.h | 5 + .../kernels/gpu/moe_combine_grad_kernel.cu | 151 ++++++++++++++++++ paddle/phi/kernels/gpu/moe_combine_kernel.cu | 112 +++++++++++++ paddle/phi/kernels/moe_combine_grad_kernel.h | 27 ++++ paddle/phi/kernels/moe_combine_kernel.h | 25 +++ paddle/phi/ops/yaml/backward.yaml | 10 ++ paddle/phi/ops/yaml/ops.yaml | 10 ++ 12 files changed, 395 insertions(+) create mode 100644 paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/moe_combine_kernel.cu create mode 100644 paddle/phi/kernels/moe_combine_grad_kernel.h create mode 100644 paddle/phi/kernels/moe_combine_kernel.h diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index c4f01bc296df97..29bcb9ab4105fe 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -20,6 +20,11 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" namespace paddle::dialect { +bool MoeCombineInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return true; +} + bool AccuracyOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index f19f1926469b3a..7eda3c3ce70907 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -141,5 +141,6 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeChannelWiseDequantizeMaxAbs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MultiDot) OP_DECLARE_INFER_SYMBOLIC_SHAPE(UpdateLossScaling_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloBoxPost) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MoeCombine) } // namespace paddle::dialect diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 1ef5cd2679006a..99bdd6ab8d14bb 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -35,6 +35,29 @@ limitations under the License. 
*/ namespace phi { + void MoeCombineGradInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& grad_y, + MetaTensor* grad_x, + MetaTensor* grad_combine_weights_helper){ + auto x_dim = x.dims(); + auto combine_weights_shape = combine_weights.dims(); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 2, + errors::InvalidArgument("Input X should have 2 dimensions")); + PADDLE_ENFORCE_EQ( + (scatter_index.dtype() == phi::DataType::INT32), + true, + errors::InvalidArgument( + "The input scatter_index type should be int32")); + grad_x->set_dims(phi::make_ddim(x_dim)); + grad_x->set_dtype(x.dtype()); + grad_combine_weights_helper->set_dims(phi::make_ddim({combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); + grad_combine_weights_helper->set_dtype(x.dtype()); +} + std::vector GetMetaTensorsDim( const std::vector& tensors) { std::vector dims; diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index dfe1af6754aa9d..1f6a8dfa7a5651 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -40,6 +40,13 @@ namespace phi { std::vector GetMetaTensorsDim( const std::vector& tensors); +void MoeCombineGradInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& grad_y, + MetaTensor* grad_x, + MetaTensor* grad_combine_weights_helper); + void AdadeltaInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& avg_squared_grad, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index a66797a4d22437..a30d30c49d85ff 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -27,6 +27,25 @@ limitations under the License. */ namespace phi { +void MoeCombineInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + MetaTensor* y){ + auto x_dim = x.dims(); + auto combine_weights_shape = combine_weights.dims(); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 2, + errors::InvalidArgument("Input X should have 2 dimensions")); + PADDLE_ENFORCE_EQ( + combine_weights_shape.size(), + 2, + errors::InvalidArgument("Input combine_weights should have 2 dimensions")); // maybe + y->set_dims(phi::make_ddim({combine_weights_shape[0], x_dim[1]})); + y->set_dtype(x.dtype()); +} + + void AccuracyInferMeta(const MetaTensor& out, const MetaTensor& indice, const MetaTensor& label, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 14dd2685949573..e23f14e3c7db33 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -33,6 +33,11 @@ namespace phi { // // The InferMeta Functions in this file are arranged in alphabetic order. 
+void MoeCombineInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + MetaTensor* y); + void AccuracyInferMeta(const MetaTensor& out, const MetaTensor& indice, const MetaTensor& label, diff --git a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu new file mode 100644 index 00000000000000..78577bc4ec029b --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu @@ -0,0 +1,151 @@ +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_combine_grad_kernel.h" +namespace phi { + +template +__global__ void combine_moe_bwd_kernel(const T* x, + const T* combine_weights, + const int* scatter_index, + const T* grad_y, + T* grad_x, + T* grad_combine_weights_helper, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + const int64_t n) { + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + int64_t row_i = i / hidden_size; + int64_t slice_i = i - row_i * hidden_size; + const int* scatter_index_start = scatter_index + row_i * k; + const T grad_y_i = *(grad_y + i); + // y [ row_i, slice_i] + // combine [row_i, k, slice_i] + int64_t weight_base = row_i * k * hidden_size + slice_i; + + T* grad_cw_ptr = + grad_combine_weights_helper + weight_base; // stride hidden_size + for (int64_t ki = 0; ki < k; ki++) { + // get combine_weights i + int64_t ele_index = + static_cast(*(scatter_index_start + ki)) * hidden_size + + slice_i; + const T* w_ptr = combine_weights + row_i * k + ki; + const T* x_ptr = x + ele_index; + if ((*w_ptr) != T(0)) { + *(grad_x + ele_index) = grad_y_i * (*w_ptr); + } + *(grad_cw_ptr + ki * hidden_size) = grad_y_i * (*x_ptr); + } + } +} + +template +void combine_moe_bwd_kernelLauncher(const T* x, + const T* combine_weights, + const int* scatter_index, + const T* grad_y, + T* grad_x, + T* grad_combine_weights_helper, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + cudaStream_t stream) { + // y is [seqlen, hidden_size] + // for kk in k: + // y[i][j] += x[scatter_index[i][kk]][j] * combine_weights[i][kk] + + const int64_t n = hidden_size * seqlen; + + const int64_t threads = 1024; + const int64_t blocks = (n + threads - 1) / threads; + + combine_moe_bwd_kernel + <<>>(x, + combine_weights, + scatter_index, + grad_y, + grad_x, + grad_combine_weights_helper, + k, + seqlen, + hidden_size, + n); +} + +template +void apply_moe_combine_bwd(const T* x, + const T* combine_weights, + const int* scatter_index, + const T* grad_y, + T* grad_x, + T* grad_combine_weights_helper, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + cudaStream_t stream) { + combine_moe_bwd_kernelLauncher(x, + combine_weights, + scatter_index, + grad_y, + grad_x, + grad_combine_weights_helper, + k, + seqlen, + hidden_size, + stream); +} + +template +void moe_combine_bwd(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& grad_y, + const DenseTensor* grad_x, + const DenseTensor* grad_combine_weights_helper, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size) { + apply_moe_combine_bwd(x.data(), + combine_weights.data(), + scatter_index.data(), + grad_y.data(), + const_cast(grad_x.data()), + const_cast(grad_combine_weights_helper.data()), + k, + seqlen, + hidden_size, + 
dev_ctx.stream()); +} +template +void MoeCombineGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& grad_y, + DenseTensor* grad_x, + DenseTensor* grad_combine_weights_helper) { + dev_ctx.template Alloc(grad_x); + dev_ctx.template Alloc(grad_combine_weights_helper); + moe_combine_bwd(x, + combine_weights, + scatter_index, + grad_y, + grad_x, + grad_combine_weights_helper, + combine_weights_shape[1], // k + combine_weights_shape[0], // seqlen + x_shape[1]); // hidden_size +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_combine_grad, + GPU, + ALL_LAYOUT, + phi::MoeCombineGradKernel, + float, + phi::dtype::bfloat16, + phi::dtype::float16) {} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_combine_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_kernel.cu new file mode 100644 index 00000000000000..c6cfb0ac9bb2eb --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_combine_kernel.cu @@ -0,0 +1,112 @@ +#include "paddle/phi/kernels/moe_combine_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" + +namespace phi { + +template +__global__ void combine_moe_kernel(const T* x, + const T* combine_weights, + const int* scatter_index, + T* y, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + const int64_t n) { + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + int64_t row_i = i / hidden_size; + int64_t slice_i = i - row_i * hidden_size; + const int* scatter_index_start = scatter_index + row_i * k; + T* dest_ptr = y + i; + for (int ki = 0; ki < k; ki++) { + // get combine_weights i + const T* w_ptr = combine_weights + row_i * k + ki; + const T* x_ptr = + x + static_cast(*(scatter_index_start + ki)) * hidden_size + + slice_i; + *(dest_ptr) += (*w_ptr) * (*x_ptr); + } + } +} + +template +void combine_moe_kernelLauncher(const T* x, + const T* combine_weights, + const int* scatter_index, + T* y, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + cudaStream_t stream) { + // y is [seqlen, hidden_size] + // for kk in k: + // y[i][j] += x[scatter_index[i][kk]][j] * combine_weights[i][kk] + const int64_t n = hidden_size * seqlen; + + const int64_t threads = 1024; + const int64_t blocks = (n + threads - 1) / threads; + + combine_moe_kernel<<>>( + x, combine_weights, scatter_index, y, k, seqlen, hidden_size, n); +} + +template +void apply_moe_combine_fwd(const T* x, + const T* combine_weights, + const int* scatter_index, + T* y, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + cudaStream_t stream) { + combine_moe_kernelLauncher( + x, combine_weights, scatter_index, y, k, seqlen, hidden_size, stream); +} + +template +void moe_combine_fwd(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& y, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size) { + apply_moe_combine_fwd(x.data(), + combine_weights.data(), + scatter_index.data(), + const_cast(y.data()), + k, + seqlen, + hidden_size, + dev_ctx.stream()); + } + + template + void MoeCombineKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + DenseTensor* y) { + dev_ctx.template Alloc(y); // T cannot support phi::dtype::float8 very + // well, maybe replaced 
with x.dtype(); + auto combine_weights_shape = combine_weights.dims(); + moe_combine_fwd(x, + combine_weights, + scatter_index, + *y, + combine_weights_shape[1], // k + combine_weights_shape[0], // seqlen + x_shape[1]); // hidden_size + } +} + +PD_REGISTER_KERNEL(moe_combine, + GPU, + ALL_LAYOUT, + phi::MoeCombineKernel, + float, + phi::dtype::bfloat16, + phi::dtype::float16) {} \ No newline at end of file diff --git a/paddle/phi/kernels/moe_combine_grad_kernel.h b/paddle/phi/kernels/moe_combine_grad_kernel.h new file mode 100644 index 00000000000000..dd8e6c2651f8db --- /dev/null +++ b/paddle/phi/kernels/moe_combine_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/common/scalar.h" +namespace phi { +template +void MoeCombineGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& grad_y, + DenseTensor* grad_x, + DenseTensor* grad_combine_weights_helper); +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/kernels/moe_combine_kernel.h b/paddle/phi/kernels/moe_combine_kernel.h new file mode 100644 index 00000000000000..8225241c018759 --- /dev/null +++ b/paddle/phi/kernels/moe_combine_kernel.h @@ -0,0 +1,25 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoeCombineKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + DenseTensor* out); +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 3a48214c8e1579..d0195f069f0f98 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -1,6 +1,16 @@ # This file is designed for backward C++ operators associated with # the operator in ops.yaml. 
+- backward_op : moe_combine_grad + forward : moe_combine (Tensor x, Tensor combine_weights, Tensor scatter_index) -> Tensor(y) + args : (Tensor x, Tensor combine_weights, Tensor scatter_index, Tensor y_grad) + output : Tensor(x_grad), Tensor(combine_weights_grad) + infer_meta : + func : MoeCombineGradInferMeta + kernel : + func : moe_combine_grad + data_type : x + - backward_op : abs_double_grad forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x) args : (Tensor x, Tensor grad_x_grad) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 5bd4eca9a850ac..c8e89064a1c337 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -7,6 +7,16 @@ # interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : moe_combine + args : (Tensor x, Tensor combine_weights, Tensor scatter_index) + output : Tensor(y) + infer_meta : + func : MoeCombineInferMeta + kernel : + func : moe_combine + data_type : x + backward : moe_combine_grad + - op : abs args : (Tensor x) output : Tensor(out) From 492ac04d381dfd2522a2cbc91d1538fc8a866273 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Wed, 21 May 2025 03:30:18 +0000 Subject: [PATCH 02/71] insert moe_combine --- paddle/phi/infermeta/multiary.cc | 2 +- .../kernels/gpu/moe_combine_grad_kernel.cu | 11 ++-- paddle/phi/kernels/gpu/moe_combine_kernel.cu | 18 +++--- .../incubate/nn/functional/moe_combine.py | 63 +++++++++++++++++++ 4 files changed, 81 insertions(+), 13 deletions(-) create mode 100644 python/paddle/incubate/nn/functional/moe_combine.py diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 99bdd6ab8d14bb..577e6b6f027a8d 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -52,7 +52,7 @@ namespace phi { true, errors::InvalidArgument( "The input scatter_index type should be int32")); - grad_x->set_dims(phi::make_ddim(x_dim)); + grad_x->set_dims(phi::make_ddim({x_dim[0],x_dim[1]})); grad_x->set_dtype(x.dtype()); grad_combine_weights_helper->set_dims(phi::make_ddim({combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); grad_combine_weights_helper->set_dtype(x.dtype()); diff --git a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu index 78577bc4ec029b..4192d58a30bd1f 100644 --- a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu @@ -109,12 +109,12 @@ void moe_combine_bwd(const Context& dev_ctx, const int64_t k, const int64_t seqlen, const int64_t hidden_size) { - apply_moe_combine_bwd(x.data(), + apply_moe_combine_bwd(x.data(), combine_weights.data(), scatter_index.data(), grad_y.data(), - const_cast(grad_x.data()), - const_cast(grad_combine_weights_helper.data()), + const_cast(grad_x->data()), + const_cast(grad_combine_weights_helper->data()), k, seqlen, hidden_size, @@ -130,7 +130,10 @@ void MoeCombineGradKernel(const Context& dev_ctx, DenseTensor* grad_combine_weights_helper) { dev_ctx.template Alloc(grad_x); dev_ctx.template Alloc(grad_combine_weights_helper); - moe_combine_bwd(x, + auto x_shape = x.dims(); + auto combine_weights_shape = combine_weights.dims(); + moe_combine_bwd(dev_ctx, + x, combine_weights, scatter_index, grad_y, diff --git a/paddle/phi/kernels/gpu/moe_combine_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_kernel.cu index c6cfb0ac9bb2eb..47ae75dd1db829 100644 --- a/paddle/phi/kernels/gpu/moe_combine_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_kernel.cu @@ -74,7 +74,7 @@ void 
moe_combine_fwd(const Context& dev_ctx, const int64_t k, const int64_t seqlen, const int64_t hidden_size) { - apply_moe_combine_fwd(x.data(), + apply_moe_combine_fwd(x.data(), combine_weights.data(), scatter_index.data(), const_cast(y.data()), @@ -93,13 +93,15 @@ void moe_combine_fwd(const Context& dev_ctx, dev_ctx.template Alloc(y); // T cannot support phi::dtype::float8 very // well, maybe replaced with x.dtype(); auto combine_weights_shape = combine_weights.dims(); - moe_combine_fwd(x, - combine_weights, - scatter_index, - *y, - combine_weights_shape[1], // k - combine_weights_shape[0], // seqlen - x_shape[1]); // hidden_size + auto x_shape = x.dims(); + moe_combine_fwd(dev_ctx, + x, + combine_weights, + scatter_index, + *y, + combine_weights_shape[1], // k + combine_weights_shape[0], // seqlen + x_shape[1]); // hidden_size } } diff --git a/python/paddle/incubate/nn/functional/moe_combine.py b/python/paddle/incubate/nn/functional/moe_combine.py new file mode 100644 index 00000000000000..1e7a26e657a3de --- /dev/null +++ b/python/paddle/incubate/nn/functional/moe_combine.py @@ -0,0 +1,63 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +from paddle import _C_ops +import paddle +# from ....framework import LayerHelper, in_dynamic_or_pir_mode +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + +if TYPE_CHECKING: + from paddle import Tensor + +def moe_combine( + x: Tensor, combine_weights: Tensor, scatter_index: Tensor, name: str | None = None +) -> Tensor: + """ + Args: + x: Input tensor [seq, dim] + combine_weights: Combination weights [s, k] + scatter_index: Scatter indices [k, s] + + Returns: + Output Combined output [s, dim] + """ + if in_dynamic_or_pir_mode(): + return _C_ops.moe_combine(x, combine_weights, scatter_index) + helper = LayerHelper('moe_combine', **locals()) + y = helper.create_variable_for_type_inference(dtype=x.dtype) + inputs = { + 'x': x, + 'combine_weights': combine_weights, + 'scatter_index': scatter_index + } + helper.append_op(type='moe_combine', inputs=inputs, outputs={'y': y}) + return y + +if __name__ == "__main__": + print("This module is not for direct use.") + x = paddle.arange(1, 16).view((5, 3)).astype('float32') + combine_weights = paddle.to_tensor([ + [0, 0], + [0, 0], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5] + ]) + + # 分散索引 + scatter_index = paddle.to_tensor([ + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0] + ]).astype('int32') + + # 输出计算 + output = paddle.zeros((5, 3)) + for s in range(5): + expert0_idx = scatter_index[0, s] + expert1_idx = scatter_index[1, s] + output[s] = ( + x[expert0_idx] * combine_weights[s, 0] + + x[expert1_idx] * combine_weights[s, 1] + ) + print(output) + print(moe_combine(x, combine_weights, scatter_index)) \ No newline at end of file From 22e36430bd8517906a15360aa44975ab3f1f7193 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Wed, 21 May 2025 04:37:27 +0000 Subject: [PATCH 03/71] init --- paddle/phi/infermeta/unary.cc | 23 +++++++++++++++++++++++ paddle/phi/infermeta/unary.h | 7 +++++++ paddle/phi/ops/yaml/ops.yaml | 10 ++++++++++ 3 files changed, 40 insertions(+) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index abf1823d67c86e..3a74bc7bfe9f45 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -46,6 +46,29 @@ static DDim CheckAndGetOutputDim(const DDim& dim_x) { } } // namespace detail +void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, + int64_t num_expert_per_modality, + int64_t 
group_size, + int64_t modality_offset, + bool is_group_expert, + MetaTensor* expert_id_out){ + auto expert_id_dims = expert_id.dims(); + PADDLE_ENFORCE_EQ( + expert_id_dims.size(), + 2, + common::errors::InvalidArgument( + "The input expert_id's dimensions size should be 2. But received " + "expert_id's dimensions size=[%d], expert_id's dimensions=[%s].", + expert_id_dims.size(), + expert_id_dims)); + + int64_t seqlen = expert_id_dims[0]; + int64_t k = expert_id_dims[1]; + expert_id_out->set_dims(common::make_ddim({seqlen, k})); + expert_id_out->set_dtype(expert_id.dtype()); +} + + void AddPositionEncodingInferMeta(const MetaTensor& x, float alpha, float beta, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 6e9454e9fdac9d..f9d26f715536ce 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -33,6 +33,13 @@ struct MetaConfig; // // The InferMeta Functions in this file are arranged in alphabetic order. +void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert + MetaTensor* expert_id_out); + void AddPositionEncodingInferMeta(const MetaTensor& x, float alpha, float beta, diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index c8e89064a1c337..9f43cc2c8d0146 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -17,6 +17,16 @@ data_type : x backward : moe_combine_grad +- op : expand_modality_expert_id + args : (Tensor expert_id, int64_t num_expert_per_modality, int64_t group_size, int64_t modality_offset, bool is_group_expert) + output : Tensor(expert_id_out) + infer_meta : + func : ExpandModalityExpertIdInferMeta + kernel : + func : expand_modality_expert_id + data_type : expert_id + + - op : abs args : (Tensor x) output : Tensor(out) From bebaf443fa3039ed6280f01b7f68203224aa96ad Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Wed, 21 May 2025 06:01:10 +0000 Subject: [PATCH 04/71] update yaml --- .../multiary_infer_sym.cc | 5 --- .../infer_symbolic_shape/multiary_infer_sym.h | 1 - paddle/phi/infermeta/ternary.cc | 37 +++++++++---------- paddle/phi/infermeta/ternary.h | 10 ++--- paddle/phi/ops/yaml/backward.yaml | 19 +++++----- paddle/phi/ops/yaml/ops.yaml | 20 +++++----- 6 files changed, 42 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 29bcb9ab4105fe..c4f01bc296df97 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -20,11 +20,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" namespace paddle::dialect { -bool MoeCombineInferSymbolicShape( - pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return true; -} - bool AccuracyOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 7eda3c3ce70907..f19f1926469b3a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ 
-141,6 +141,5 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeChannelWiseDequantizeMaxAbs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MultiDot) OP_DECLARE_INFER_SYMBOLIC_SHAPE(UpdateLossScaling_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloBoxPost) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(MoeCombine) } // namespace paddle::dialect diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index a30d30c49d85ff..f215a7b68c6206 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -27,25 +27,6 @@ limitations under the License. */ namespace phi { -void MoeCombineInferMeta(const MetaTensor& x, - const MetaTensor& combine_weights, - const MetaTensor& scatter_index, - MetaTensor* y){ - auto x_dim = x.dims(); - auto combine_weights_shape = combine_weights.dims(); - PADDLE_ENFORCE_EQ( - x_dim.size(), - 2, - errors::InvalidArgument("Input X should have 2 dimensions")); - PADDLE_ENFORCE_EQ( - combine_weights_shape.size(), - 2, - errors::InvalidArgument("Input combine_weights should have 2 dimensions")); // maybe - y->set_dims(phi::make_ddim({combine_weights_shape[0], x_dim[1]})); - y->set_dtype(x.dtype()); -} - - void AccuracyInferMeta(const MetaTensor& out, const MetaTensor& indice, const MetaTensor& label, @@ -1631,6 +1612,24 @@ void MultiClassNMSInferMeta(const MetaTensor& bboxes, nms_rois_num->set_dtype(DataType::INT32); } +void MoeCombineInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + MetaTensor* y){ + auto x_dim = x.dims(); + auto combine_weights_shape = combine_weights.dims(); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 2, + common::errors::InvalidArgument("The dimensions of Input(x) must be 1, but " + "received dimensions of" + "Input(x) is [%d]", + x_dim.size())); + // maybe there is more conditions here.... + y->set_dims(phi::make_ddim({combine_weights_shape[0], x_dim[1]})); + y->set_dtype(x.dtype()); +} + void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x, const MetaTensor& in_accum, const MetaTensor& in_state, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index e23f14e3c7db33..ae7ca25ae245f3 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -33,11 +33,6 @@ namespace phi { // // The InferMeta Functions in this file are arranged in alphabetic order. -void MoeCombineInferMeta(const MetaTensor& x, - const MetaTensor& combine_weights, - const MetaTensor& scatter_index, - MetaTensor* y); - void AccuracyInferMeta(const MetaTensor& out, const MetaTensor& indice, const MetaTensor& label, @@ -274,6 +269,11 @@ void MatrixRankAtolRtolInferMeta(const MetaTensor& x, bool hermitian, MetaTensor* out); +void MoeCombineInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + MetaTensor* y); + void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x, const MetaTensor& in_accum, const MetaTensor& in_state, diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index d0195f069f0f98..c36a3c1694ae9b 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -1,16 +1,6 @@ # This file is designed for backward C++ operators associated with # the operator in ops.yaml. 
-- backward_op : moe_combine_grad - forward : moe_combine (Tensor x, Tensor combine_weights, Tensor scatter_index) -> Tensor(y) - args : (Tensor x, Tensor combine_weights, Tensor scatter_index, Tensor y_grad) - output : Tensor(x_grad), Tensor(combine_weights_grad) - infer_meta : - func : MoeCombineGradInferMeta - kernel : - func : moe_combine_grad - data_type : x - - backward_op : abs_double_grad forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x) args : (Tensor x, Tensor grad_x_grad) @@ -2257,6 +2247,15 @@ kernel : func : mode_grad +- backward_op : moe_combine_grad + forward : moe_combine (Tensor x, Tensor combine_weights, Tensor scatter_index) -> Tensor(y) + args : (Tensor x, Tensor combine_weights, Tensor scatter_index, Tensor y_grad) + output : Tensor(x_grad), Tensor(combine_weights_grad) + infer_meta : + func : MoeCombineGradInferMeta + kernel : + func : moe_combine_grad + - backward_op : mp_allreduce_sum_grad forward : mp_allreduce_sum(Tensor x, int ring_id = 0) -> Tensor(out) args : (Tensor out_grad, int ring_id = 0) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index c8e89064a1c337..0b786dd315ac09 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -7,16 +7,6 @@ # interfaces : paddle::dialect::InferSymbolicShapeInterface -- op : moe_combine - args : (Tensor x, Tensor combine_weights, Tensor scatter_index) - output : Tensor(y) - infer_meta : - func : MoeCombineInferMeta - kernel : - func : moe_combine - data_type : x - backward : moe_combine_grad - - op : abs args : (Tensor x) output : Tensor(out) @@ -3593,6 +3583,16 @@ backward : mode_grad interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface +- op : moe_combine + args : (Tensor x, Tensor combine_weights, Tensor scatter_index) + output : Tensor(y) + infer_meta : + func : MoeCombineInferMeta + kernel : + func : moe_combine + data_type : x + backward : moe_combine_grad + - op : momentum_ args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0f, bool multi_precision = false, float rescale_grad = 1.0f) output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) From df96e812e1a1d0746f54fbc3a84a0de54cf3e487 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Wed, 21 May 2025 06:02:22 +0000 Subject: [PATCH 05/71] update python API --- .../incubate/nn/functional/moe_combine.py | 31 +------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/python/paddle/incubate/nn/functional/moe_combine.py b/python/paddle/incubate/nn/functional/moe_combine.py index 1e7a26e657a3de..13dd712e987392 100644 --- a/python/paddle/incubate/nn/functional/moe_combine.py +++ b/python/paddle/incubate/nn/functional/moe_combine.py @@ -31,33 +31,4 @@ def moe_combine( 'scatter_index': scatter_index } helper.append_op(type='moe_combine', inputs=inputs, outputs={'y': y}) - return y - -if __name__ == "__main__": - print("This module is not for direct use.") - x = paddle.arange(1, 16).view((5, 3)).astype('float32') - combine_weights = paddle.to_tensor([ - [0, 0], - [0, 0], - [0.5, 0.5], - [0.5, 0.5], - [0.5, 0.5] - ]) - - # 分散索引 - scatter_index = paddle.to_tensor([ - [0, 1, 0, 0, 0], - [0, 1, 0, 0, 0] - ]).astype('int32') - - # 输出计算 - output = paddle.zeros((5, 3)) - for s in range(5): - expert0_idx = scatter_index[0, s] - expert1_idx = scatter_index[1, s] - output[s] = ( - 
x[expert0_idx] * combine_weights[s, 0] + - x[expert1_idx] * combine_weights[s, 1] - ) - print(output) - print(moe_combine(x, combine_weights, scatter_index)) \ No newline at end of file + return y \ No newline at end of file From 9c7cf2536ab1f85de00224e8783ec74947281673 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Wed, 21 May 2025 07:08:53 +0000 Subject: [PATCH 06/71] delete useless header file --- paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu | 1 - paddle/phi/kernels/gpu/moe_combine_kernel.cu | 1 - paddle/phi/kernels/moe_combine_grad_kernel.h | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu index 4192d58a30bd1f..6e091cf4bafd78 100644 --- a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu @@ -1,6 +1,5 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/moe_combine_grad_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/moe_combine_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_kernel.cu index 47ae75dd1db829..feef6c536b2d0d 100644 --- a/paddle/phi/kernels/gpu/moe_combine_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_kernel.cu @@ -1,7 +1,6 @@ #include "paddle/phi/kernels/moe_combine_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/full_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/moe_combine_grad_kernel.h b/paddle/phi/kernels/moe_combine_grad_kernel.h index dd8e6c2651f8db..43682c941f87fe 100644 --- a/paddle/phi/kernels/moe_combine_grad_kernel.h +++ b/paddle/phi/kernels/moe_combine_grad_kernel.h @@ -14,7 +14,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/common/scalar.h" + namespace phi { template void MoeCombineGradKernel(const Context& dev_ctx, From 81efa8986efa62bd178a130acbc56cc8b4b4e5de Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Wed, 21 May 2025 07:30:55 +0000 Subject: [PATCH 07/71] remove supported by DCU --- paddle/phi/kernels/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 2f45770291bd58..f9f8a91da8ce45 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -228,6 +228,8 @@ if(WITH_ROCM) list( REMOVE_ITEM kernel_gpu + "gpu/moe_combine_kernel.cu" + "gpu/moe_combine_grad_kernel.cu" "gpu/affine_grid_grad_kernel.cu" "gpu/apply_per_channel_scale_kernel.cu" "gpu/calc_reduced_attn_kernel.cu" From 56ce7d28b4b7b86636c58c2b4c9163f1bc8e1502 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Wed, 21 May 2025 09:32:01 +0000 Subject: [PATCH 08/71] add expand_modality_expert_id kernel --- paddle/phi/infermeta/unary.cc | 52 +++++++------- paddle/phi/infermeta/unary.h | 14 ++-- paddle/phi/kernels/CMakeLists.txt | 1 + .../expand_modality_expert_id_kernel.h | 15 ++++ .../gpu/expand_modality_expert_id_kernel.cu | 68 +++++++++++++++++++ .../functional/expand_modality_expert_id.py | 44 ++++++++++++ 6 files changed, 164 insertions(+), 30 deletions(-) create mode 100644 paddle/phi/kernels/expand_modality_expert_id_kernel.h create mode 100644 paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu create mode 100644 python/paddle/incubate/nn/functional/expand_modality_expert_id.py diff --git a/paddle/phi/infermeta/unary.cc 
b/paddle/phi/infermeta/unary.cc index 3a74bc7bfe9f45..954f8b5d35ec7e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -46,29 +46,6 @@ static DDim CheckAndGetOutputDim(const DDim& dim_x) { } } // namespace detail -void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, - int64_t num_expert_per_modality, - int64_t group_size, - int64_t modality_offset, - bool is_group_expert, - MetaTensor* expert_id_out){ - auto expert_id_dims = expert_id.dims(); - PADDLE_ENFORCE_EQ( - expert_id_dims.size(), - 2, - common::errors::InvalidArgument( - "The input expert_id's dimensions size should be 2. But received " - "expert_id's dimensions size=[%d], expert_id's dimensions=[%s].", - expert_id_dims.size(), - expert_id_dims)); - - int64_t seqlen = expert_id_dims[0]; - int64_t k = expert_id_dims[1]; - expert_id_out->set_dims(common::make_ddim({seqlen, k})); - expert_id_out->set_dtype(expert_id.dtype()); -} - - void AddPositionEncodingInferMeta(const MetaTensor& x, float alpha, float beta, @@ -1388,6 +1365,35 @@ void ExpandInferMeta(const MetaTensor& x, #undef EXPAND_MAX_RANK_SUPPORTED } +void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + MetaTensor* expert_id_out){ + auto expert_id_dims = expert_id.dims(); + PADDLE_ENFORCE_EQ( + expert_id_dims.size(), + 2, + common::errors::InvalidArgument( + "The input expert_id's dimensions size should be 2. But received " + "expert_id's dimensions size=[%d], expert_id's dimensions=[%s].", + expert_id_dims.size(), + expert_id_dims)); + PADDLE_ENFORCE_EQ( + expert_id.dtype() == DataType::INT32 || expert_id.dtype() == DataType::INT64, + true, + common::errors::InvalidArgument( + "The dtype of expert_id should be INT32 or INT64. But received" + "dtype=%s.", + DataTypeToString(expert_id.dtype()))); + + int64_t seqlen = expert_id_dims[0]; + int64_t k = expert_id_dims[1]; + expert_id_out->set_dims(common::make_ddim({seqlen, k})); + expert_id_out->set_dtype(expert_id.dtype()); +} + void FakeChannelWiseQuantizeAbsMaxInferMeta(const MetaTensor& x, int bit_length, int round_type, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index f9d26f715536ce..966cbe6466f8a1 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -33,13 +33,6 @@ struct MetaConfig; // // The InferMeta Functions in this file are arranged in alphabetic order. 
-void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, - int64_t num_expert_per_modality, - int64_t group_size, - int64_t modality_offset, - bool is_group_expert - MetaTensor* expert_id_out); - void AddPositionEncodingInferMeta(const MetaTensor& x, float alpha, float beta, @@ -255,6 +248,13 @@ void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, MetaTensor* out); +void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + MetaTensor* expert_id_out); + void FakeChannelWiseQuantizeAbsMaxInferMeta(const MetaTensor& x, int bit_length, int round_type, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index f9f8a91da8ce45..629e59d682b651 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -228,6 +228,7 @@ if(WITH_ROCM) list( REMOVE_ITEM kernel_gpu + "gpu/expand_modality_expert_id_kernel.cu" "gpu/moe_combine_kernel.cu" "gpu/moe_combine_grad_kernel.cu" "gpu/affine_grid_grad_kernel.cu" diff --git a/paddle/phi/kernels/expand_modality_expert_id_kernel.h b/paddle/phi/kernels/expand_modality_expert_id_kernel.h new file mode 100644 index 00000000000000..0f6ce161dce05f --- /dev/null +++ b/paddle/phi/kernels/expand_modality_expert_id_kernel.h @@ -0,0 +1,15 @@ +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ExpandModalityExpertIDKernel(const Context& dev_ctx, + const DenseTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + DenseTensor* expert_id_out); + +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu b/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu new file mode 100644 index 00000000000000..4e06d7274325a4 --- /dev/null +++ b/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu @@ -0,0 +1,68 @@ +#include "paddle/phi/kernels/expand_modality_expert_id_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include + +namespace phi { + +template +void expand_modality_expert_id(const T* expert_id, + T* expert_id_out, + int64_t seqlen, + int64_t k, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + cudaStream_t stream){ + thrust::transform( + thrust::cuda::par.on(stream), + thrust::device_pointer_cast(expert_id), + thrust::device_pointer_cast(expert_id) + seqlen * k, + thrust::counting_iterator(0), + thrust::device_pointer_cast(expert_id_out), + [k, num_expert_per_modality, group_size, modality_offset, is_group_expert] __device__(T e, T idx) { + if (is_group_expert){ + e += idx % k * group_size; + } + if (num_expert_per_modality <= 0) + return static_cast(e); + T rank = e / num_expert_per_modality; + T expert_id_in_rank = e % num_expert_per_modality; + return static_cast(rank * (num_expert_per_modality * 2) // HRAD code: only support 2 modality + + expert_id_in_rank + + modality_offset * num_expert_per_modality); + } + ); +} + +template +void ExpandModalityExpertIDKernel(const Context& dev_ctx, + const DenseTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + DenseTensor* expert_id_out){ + dev_ctx.template Alloc(expert_id_out); + auto expert_id_shape = expert_id.dims(); + int64_t seqlen = expert_id_shape[0]; + 
int64_t k = expert_id_shape[1];
+  expand_modality_expert_id(expert_id.data<T>(),
+                            expert_id_out->data<T>(),
+                            seqlen,
+                            k,
+                            num_expert_per_modality,
+                            group_size,
+                            modality_offset,
+                            is_group_expert,
+                            dev_ctx.stream());
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(expand_modality_expert_id,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ExpandModalityExpertIDKernel,
+                   int,
+                   int64_t) {}
\ No newline at end of file
diff --git a/python/paddle/incubate/nn/functional/expand_modality_expert_id.py b/python/paddle/incubate/nn/functional/expand_modality_expert_id.py
new file mode 100644
index 00000000000000..086b42c83b9a4a
--- /dev/null
+++ b/python/paddle/incubate/nn/functional/expand_modality_expert_id.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+from paddle import _C_ops
+import paddle
+# from ....framework import LayerHelper, in_dynamic_or_pir_mode
+from paddle.base.framework import in_dynamic_or_pir_mode
+from paddle.base.layer_helper import LayerHelper
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+def expand_modality_expert_id(
+    expert_id: Tensor,
+    num_expert_per_modality: int,
+    group_size: int,
+    modality_offset: int,
+    is_group_expert: bool,
+    name: str | None = None
+) -> Tensor:
+    """
+    Args:
+        expert_id:
+        num_expert_per_modality:
+        group_size:
+        modality_offset:
+        is_group_expert:
+
+    Returns:
+    """
+    if in_dynamic_or_pir_mode():
+        return _C_ops.expand_modality_expert_id(expert_id, num_expert_per_modality, group_size, modality_offset, is_group_expert)
+    helper = LayerHelper('expand_modality_expert_id', **locals())
+    expert_id_out = helper.create_variable_for_type_inference(dtype=expert_id.dtype)
+    inputs = {
+        'expert_id': expert_id
+    }
+    attrs = {
+        'num_expert_per_modality': num_expert_per_modality,
+        'group_size': group_size,
+        'modality_offset': modality_offset,
+        'is_group_expert': is_group_expert
+    }
+    helper.append_op(type='expand_modality_expert_id', inputs=inputs, attrs=attrs, outputs={'expert_id_out': expert_id_out})
+    return expert_id_out
From e2e6ac76a1000e8849b5b11660506baa323abd3f Mon Sep 17 00:00:00 2001
From: pesionzhao
Date: Wed, 21 May 2025 09:48:52 +0000
Subject: [PATCH 09/71] reorder the new code and refine OP type

---
 paddle/phi/infermeta/backward.cc              | 23 +++++++++++++++++++
 paddle/phi/infermeta/backward.h               |  7 ++++++
 paddle/phi/infermeta/multiary.cc              | 23 -------------------
 paddle/phi/infermeta/multiary.h               |  7 ------
 .../kernels/gpu/moe_combine_grad_kernel.cu    |  1 +
 paddle/phi/kernels/gpu/moe_combine_kernel.cu  |  1 +
 paddle/phi/ops/yaml/ops.yaml                  | 18 +++++++--------
 7 files changed, 41 insertions(+), 39 deletions(-)

diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index a7d368ea869b22..805e4a0ac868a5 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -1217,6 +1217,29 @@ void MeshgridGradInferMeta(const std::vector& inputs,
   }
 }
 
+void MoeCombineGradInferMeta(const MetaTensor& x,
+                             const MetaTensor& combine_weights,
+                             const MetaTensor& scatter_index,
+                             const MetaTensor& y,
+                             MetaTensor* grad_x,
+                             MetaTensor* grad_combine_weights_helper){
+  auto x_dim = x.dims();
+  auto combine_weights_shape = combine_weights.dims();
+  PADDLE_ENFORCE_EQ(
+      x_dim.size(),
+      2,
+      errors::InvalidArgument("Input X should have 2 dimensions"));
+  PADDLE_ENFORCE_EQ(
+      (scatter_index.dtype() == phi::DataType::INT32),
+      true,
+      errors::InvalidArgument(
+          "The input scatter_index type should be int32"));
+  grad_x->set_dims(phi::make_ddim({x_dim[0],x_dim[1]}));
+  grad_x->set_dtype(x.dtype());
+
grad_combine_weights_helper->set_dims(phi::make_ddim({combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); + grad_combine_weights_helper->set_dtype(x.dtype()); +} + void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad) { diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index bca0c6f53906f9..50cd2500b26d72 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -462,6 +462,13 @@ void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query, MetaTensor* value_grad, MetaTensor* bias_grad); +void MoeCombineGradInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& grad_y, + MetaTensor* grad_x, + MetaTensor* grad_combine_weights_helper); + void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 577e6b6f027a8d..1ef5cd2679006a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -35,29 +35,6 @@ limitations under the License. */ namespace phi { - void MoeCombineGradInferMeta(const MetaTensor& x, - const MetaTensor& combine_weights, - const MetaTensor& scatter_index, - const MetaTensor& grad_y, - MetaTensor* grad_x, - MetaTensor* grad_combine_weights_helper){ - auto x_dim = x.dims(); - auto combine_weights_shape = combine_weights.dims(); - PADDLE_ENFORCE_EQ( - x_dim.size(), - 2, - errors::InvalidArgument("Input X should have 2 dimensions")); - PADDLE_ENFORCE_EQ( - (scatter_index.dtype() == phi::DataType::INT32), - true, - errors::InvalidArgument( - "The input scatter_index type should be int32")); - grad_x->set_dims(phi::make_ddim({x_dim[0],x_dim[1]})); - grad_x->set_dtype(x.dtype()); - grad_combine_weights_helper->set_dims(phi::make_ddim({combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); - grad_combine_weights_helper->set_dtype(x.dtype()); -} - std::vector GetMetaTensorsDim( const std::vector& tensors) { std::vector dims; diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 1f6a8dfa7a5651..dfe1af6754aa9d 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -40,13 +40,6 @@ namespace phi { std::vector GetMetaTensorsDim( const std::vector& tensors); -void MoeCombineGradInferMeta(const MetaTensor& x, - const MetaTensor& combine_weights, - const MetaTensor& scatter_index, - const MetaTensor& grad_y, - MetaTensor* grad_x, - MetaTensor* grad_combine_weights_helper); - void AdadeltaInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& avg_squared_grad, diff --git a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu index 6e091cf4bafd78..6d1247cd856120 100644 --- a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu @@ -149,5 +149,6 @@ PD_REGISTER_KERNEL(moe_combine_grad, ALL_LAYOUT, phi::MoeCombineGradKernel, float, + double, phi::dtype::bfloat16, phi::dtype::float16) {} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_combine_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_kernel.cu index feef6c536b2d0d..ae57420fc985f5 100644 --- a/paddle/phi/kernels/gpu/moe_combine_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_kernel.cu @@ -109,5 +109,6 @@ PD_REGISTER_KERNEL(moe_combine, ALL_LAYOUT, phi::MoeCombineKernel, float, + double, 
phi::dtype::bfloat16, phi::dtype::float16) {} \ No newline at end of file diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index d131e6bda2cba5..6532a4f3812f01 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -7,15 +7,6 @@ # interfaces : paddle::dialect::InferSymbolicShapeInterface -- op : expand_modality_expert_id - args : (Tensor expert_id, int64_t num_expert_per_modality, int64_t group_size, int64_t modality_offset, bool is_group_expert) - output : Tensor(expert_id_out) - infer_meta : - func : ExpandModalityExpertIdInferMeta - kernel : - func : expand_modality_expert_id - data_type : expert_id - - op : abs args : (Tensor x) output : Tensor(out) @@ -1800,6 +1791,15 @@ backward : expand_as_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : expand_modality_expert_id + args : (Tensor expert_id, int64_t num_expert_per_modality, int64_t group_size, int64_t modality_offset, bool is_group_expert) + output : Tensor(expert_id_out) + infer_meta : + func : ExpandModalityExpertIdInferMeta + kernel : + func : expand_modality_expert_id + data_type : expert_id + - op : expm1 args : (Tensor x) output : Tensor(out) From ddf247cf655467e1cfcbb78e4391c1abe20e581b Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Thu, 22 May 2025 07:36:49 +0000 Subject: [PATCH 10/71] add unit test --- .../paddle/incubate/nn/functional/__init__.py | 4 +++ .../incubate/nn/functional/moe_combine.py | 2 +- .../test_expand_modality_expert_id.py | 18 ++++++++++++ test/legacy_test/test_moe_combine.py | 28 +++++++++++++++++++ 4 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 test/legacy_test/test_expand_modality_expert_id.py create mode 100644 test/legacy_test/test_moe_combine.py diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index aec7625145d348..ecba7a94d43517 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -43,6 +43,8 @@ from .variable_length_memory_efficient_attention import ( variable_length_memory_efficient_attention, ) +from .moe_combine import moe_combine +from .expand_modality_expert_id import expand_modality_expert_id __all__ = [ 'fused_multi_head_attention', @@ -62,4 +64,6 @@ "blha_get_max_len", "block_multihead_attention", "swiglu", + "moe_combine", + "expand_modality_expert_id", ] diff --git a/python/paddle/incubate/nn/functional/moe_combine.py b/python/paddle/incubate/nn/functional/moe_combine.py index 13dd712e987392..78964c2cf9a1d6 100644 --- a/python/paddle/incubate/nn/functional/moe_combine.py +++ b/python/paddle/incubate/nn/functional/moe_combine.py @@ -16,7 +16,7 @@ def moe_combine( Args: x: Input tensor [seq, dim] combine_weights: Combination weights [s, k] - scatter_index: Scatter indices [k, s] + scatter_index: Scatter indices [k, s] dtype=int32 Returns: Output Combined output [s, dim] diff --git a/test/legacy_test/test_expand_modality_expert_id.py b/test/legacy_test/test_expand_modality_expert_id.py new file mode 100644 index 00000000000000..d197f6d1f4f597 --- /dev/null +++ b/test/legacy_test/test_expand_modality_expert_id.py @@ -0,0 +1,18 @@ +from paddle.incubate.nn.functional import expand_modality_expert_id +import paddle + +num_expert_per_modality = 4 +group_size = 10 +modality_offset = 3 +is_group_expert = True + +expert_id = paddle.to_tensor([[0, 1, 2,], [3, 4, 5]], dtype='int32') + +expert_id_out = expand_modality_expert_id(expert_id, + num_expert_per_modality, + group_size, + 
modality_offset,
+                                       is_group_expert)
+
+print(expert_id_out)
+
diff --git a/test/legacy_test/test_moe_combine.py b/test/legacy_test/test_moe_combine.py
new file mode 100644
index 00000000000000..eea62e1f0a6728
--- /dev/null
+++ b/test/legacy_test/test_moe_combine.py
@@ -0,0 +1,28 @@
+import paddle
+from paddle.incubate.nn.functional import moe_combine
+
+x = paddle.arange(1, 16).view((5, 3)).astype("float32")  # [[1,2,3], [4,5,6], ..., [13,14,15]]
+x.stop_gradient = False
+
+# combine weights (constructed by hand); dtype must match x
+combine_weights = paddle.to_tensor([
+[0.7, 0.3],
+[0.6, 0.4],
+[0.5, 0.5],
+[0.4, 0.6],
+[0.2, 0.8]
+], stop_gradient=False)
+
+# scatter indices; only int32 is supported
+scatter_index = paddle.to_tensor([
+[0, 1, 2, 3, 4],
+[0, 1, 2, 3, 4]
+], dtype="int32", stop_gradient=False)
+
+y = moe_combine(x, combine_weights, scatter_index)
+print("\n##########forward output##########\n")
+print(y)
+print(f"x.grad: {x.grad}, combine_weights.grad: {combine_weights.grad}, scatter_index.grad: {scatter_index.grad}")
+y.backward()
+print("\n##########backward output##########\n")
+print(f"x.grad: {x.grad}\n combine_weights.grad: {combine_weights.grad}\n scatter_index.grad: {scatter_index.grad}")
\ No newline at end of file
From ad82cc8dd1dc977136c339b96cac5476e4d50b86 Mon Sep 17 00:00:00 2001
From: feixi21 <1802550529@qq.com>
Date: Thu, 22 May 2025 07:48:25 +0000
Subject: [PATCH 11/71] add cal_aux_loss_op and build_src_rank_and_local_expert_id_op

---
 paddle/phi/infermeta/backward.cc              |  20 +
 paddle/phi/infermeta/backward.h               |   9 +
 paddle/phi/infermeta/multiary.cc              |  89 +++++
 paddle/phi/infermeta/multiary.h               |  12 +
 paddle/phi/infermeta/unary.cc                 |  22 ++
 paddle/phi/infermeta/unary.h                  |   7 +
 paddle/phi/kernels/CMakeLists.txt             |   3 +
 ...uild_src_rank_and_local_expert_id_kernel.h |  30 ++
 paddle/phi/kernels/cal_aux_loss_kernel.h      |  45 +++
 ...ild_src_rank_and_local_expert_id_kernel.cu | 110 ++++++
 paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu | 361 ++++++++++++++++++
 paddle/phi/ops/yaml/backward.yaml             |   9 +
 paddle/phi/ops/yaml/ops.yaml                  |  19 +
 .../build_src_rank_and_local_expert_id.py     |  67 ++++
 .../incubate/nn/functional/cal_aux_loss.py    |  90 +++++
 15 files changed, 893 insertions(+)
 create mode 100644 paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h
 create mode 100644 paddle/phi/kernels/cal_aux_loss_kernel.h
 create mode 100644 paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu
 mode change 100755 => 100644 paddle/phi/ops/yaml/ops.yaml
 create mode 100644 python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py
 create mode 100644 python/paddle/incubate/nn/functional/cal_aux_loss.py

diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index a7d368ea869b22..bfe718bf867551 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -1887,4 +1887,24 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
     value_grad->share_lod(values);
   }
 }
+
+void CalAuxLossGradInferMeta(const MetaTensor& gate_prob,
+                             const MetaTensor& seqlen_float,
+                             const MetaTensor& ce,
+                             const MetaTensor& out_grad,
+                             const int64_t num_experts,
+                             const bool use_group,
+                             const int64_t moe_k,
+                             MetaTensor* gate_prob_grad) {
+  auto gate_prob_dims = gate_prob.dims();
+
+  PADDLE_ENFORCE_EQ(
+      gate_prob.dtype(),
+      out_grad.dtype(),
+      errors::InvalidArgument(
+          "The input out_grad type should be equal to gate_prob type"));
+
+  gate_prob_grad->set_dims({gate_prob_dims});
+  gate_prob_grad->set_dtype(gate_prob.dtype());
+}
 }  // 
namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index bca0c6f53906f9..7990568f09b022 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -680,4 +680,13 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, MetaTensor* x_grad, MetaTensor* value_grad); +void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, + const MetaTensor& seqlen_float, + const MetaTensor& ce, + const MetaTensor& out_grad, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + MetaTensor* gate_prob_grad); + } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 1ef5cd2679006a..5ddc5fd5a3080a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -6273,5 +6273,94 @@ void TopPSamplingInferMeta(const MetaTensor& x, } } +void CalAuxLossInferMeta(const MetaTensor& gate_prob, + const MetaTensor& dispatch_mask, + const MetaTensor& tokens_mask, + const MetaTensor& dispatch_tokens_mask, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + const float clip_min, + MetaTensor* l_aux_loss, + MetaTensor* seqlen_floats, + MetaTensor* ce) { + auto gate_prob_dims = gate_prob.dims(); + auto dispatch_mask_dims = dispatch_mask.dims(); + auto tokens_mask_dims = tokens_mask.dims(); + auto dispatch_tokens_mask_dims = dispatch_tokens_mask.dims(); + + PADDLE_ENFORCE_EQ( + gate_prob_dims.size(), + 2, + errors::InvalidArgument("Input gate_prob_dims should have 2 dimensions")); + + PADDLE_ENFORCE_EQ(gate_prob_dims[0] >= gate_prob_dims[1], + true, + errors::InvalidArgument( + "The value of gate_prob_dims[0] should be greater than " + "or equal to that of gate_prob_dims[1].")); + + PADDLE_ENFORCE_EQ( + gate_prob_dims[1] <= 1024, + true, + errors::InvalidArgument( + "The value of gate_prob_dims[1] should be less than 1024.")); + + PADDLE_ENFORCE_EQ( + (dispatch_mask_dims.size() == 1) || (dispatch_mask_dims.size() == 2), + true, + errors::InvalidArgument( + "Input dispatch_mask_dims should have 1 or 2 dimensions")); + + if (dispatch_mask_dims.size() == 1) { + PADDLE_ENFORCE_EQ( + dispatch_mask_dims[0], + gate_prob_dims[1], + errors::InvalidArgument("The value of dispatch_mask_shape.back() " + "should be equal to gate_prob_shape.back().")); + } else { + PADDLE_ENFORCE_EQ( + dispatch_mask_dims[1], + gate_prob_dims[1], + errors::InvalidArgument("The value of dispatch_mask_shape.back() " + "should be equal to gate_prob_shape.back().")); + } + + PADDLE_ENFORCE_EQ( + dispatch_mask.dtype(), + phi::DataType::INT64, + errors::InvalidArgument("The input dispatch_mask type should be INT64")); + + PADDLE_ENFORCE_EQ( + tokens_mask_dims.size(), + 1, + errors::InvalidArgument("Input tokens_mask should have 1 dimensions")); + + PADDLE_ENFORCE_EQ( + tokens_mask.dtype(), + gate_prob.dtype(), + errors::InvalidArgument( + "The input tokens_mask type should be equal to gate_prob type")); + + PADDLE_ENFORCE_EQ(dispatch_tokens_mask_dims.size(), + 1, + errors::InvalidArgument( + "Input dispatch_tokens_mask should have 1 dimensions")); + + PADDLE_ENFORCE_EQ(dispatch_tokens_mask.dtype(), + phi::DataType::BOOL, + errors::InvalidArgument( + "The input dispatch_tokens_mask type should be BOOL")); + + l_aux_loss->set_dims({1}); + l_aux_loss->set_dtype(gate_prob.dtype()); + + seqlen_floats->set_dims({1}); + seqlen_floats->set_dtype(gate_prob.dtype()); + + ce->set_dims({gate_prob_dims[1]}); + ce->set_dtype(gate_prob.dtype()); +} + } // namespace phi 
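For context, the l_aux_loss / seqlen_floats / ce outputs set up in CalAuxLossInferMeta above correspond to a standard MoE load-balancing auxiliary loss. A rough NumPy sketch of the unmasked case (the function name and the simplifications, such as ignoring the token masks and clip_min, are illustrative and not part of the patch):

import numpy as np

def cal_aux_loss_reference(gate_prob, dispatch_mask, num_experts, use_group, moe_k):
    # gate_prob: [s, e] router probabilities; dispatch_mask: [s, e] 0/1 dispatch decisions
    seqlen_float = float(gate_prob.shape[0])
    ce = dispatch_mask.sum(axis=0) / seqlen_float  # fraction of tokens dispatched to each expert
    me = gate_prob.sum(axis=0) / seqlen_float      # mean router probability per expert
    l_aux = float((ce * me).sum()) * num_experts
    if use_group:
        l_aux /= moe_k
    return l_aux, seqlen_float, ce                 # mirrors (l_aux_loss, seqlen_floats, ce)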
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index dfe1af6754aa9d..f511ce4d9d6f28 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -1284,4 +1284,16 @@ void TopPSamplingInferMeta(const MetaTensor& x, MetaTensor* topk_scores, MetaTensor* topk_ids); +void CalAuxLossInferMeta(const MetaTensor& gate_prob, + const MetaTensor& dispatch_mask, + const MetaTensor& tokens_mask, + const MetaTensor& dispatch_tokens_mask, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + const float clip_min, + MetaTensor* l_aux_loss, + MetaTensor* seqlen_floats, + MetaTensor* ce); + } // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index abf1823d67c86e..a050d43d2a7d81 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -6164,6 +6164,28 @@ void ArrayPopInferMeta(const MetaTensor& array, out->set_dtype(array.dtype()); } +void BuildSrcRankAndLocalExpertIdInferMeta( + const MetaTensor& expert_num_global_tensor, + const std::vector& expert_num_global, + int64_t num_local_experts, + MetaTensor* src_rank, + MetaTensor* local_expert_id) { + int64_t token_num = + std::accumulate(expert_num_global.begin(), expert_num_global.end(), 0); + + PADDLE_ENFORCE_EQ( + expert_num_global_tensor.dtype(), + phi::DataType::INT64, + errors::InvalidArgument( + "The input expert_num_global_tensor type should be INT64")); + + src_rank->set_dims({token_num}); + src_rank->set_dtype(DataType::INT32); + + local_expert_id->set_dims({token_num}); + local_expert_id->set_dtype(DataType::INT32); +} + } // namespace phi PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 6e9454e9fdac9d..9be0e2af1f6667 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -998,4 +998,11 @@ void ArrayPopInferMeta(const MetaTensor& array, MetaTensor* out, MetaConfig config = MetaConfig()); +void BuildSrcRankAndLocalExpertIdInferMeta( + const MetaTensor& expert_num_global_tensor, + const std::vector& expert_num_global, + int64_t num_local_experts, + MetaTensor* src_rank, + MetaTensor* local_expert_id); + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 2f45770291bd58..2f0c1eb96d856d 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -229,6 +229,9 @@ if(WITH_ROCM) REMOVE_ITEM kernel_gpu "gpu/affine_grid_grad_kernel.cu" + "gpu/cal_aux_loss_kernel.cu" + "gpu/cal_aux_loss_grad_kernel.cu" + "build_src_rank_and_local_expert_id_kernel.cu" "gpu/apply_per_channel_scale_kernel.cu" "gpu/calc_reduced_attn_kernel.cu" "gpu/eigvalsh_kernel.cu" diff --git a/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h b/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h new file mode 100644 index 00000000000000..34f99e2280a257 --- /dev/null +++ b/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BuildSrcRankAndLocalExpertIdKernel(
+    const Context& dev_ctx,
+    const DenseTensor& expert_num_global_tensor,
+    const std::vector<int64_t>& expert_num_global,
+    int64_t num_local_experts,
+    DenseTensor* src_rank,
+    DenseTensor* local_expert_id);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cal_aux_loss_kernel.h b/paddle/phi/kernels/cal_aux_loss_kernel.h
new file mode 100644
index 00000000000000..e6e682d7db43fc
--- /dev/null
+++ b/paddle/phi/kernels/cal_aux_loss_kernel.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void CalAuxLossKernel(const Context& dev_ctx,
+                      const DenseTensor& gate_prob,
+                      const DenseTensor& dispatch_mask,
+                      const DenseTensor& tokens_mask,
+                      const DenseTensor& dispatch_tokens_mask,
+                      int64_t num_experts,
+                      bool use_group,
+                      int64_t moe_k,
+                      float clip_min,
+                      DenseTensor* l_aux_loss,
+                      DenseTensor* seqlen_float,
+                      DenseTensor* ce);
+
+template <typename T, typename Context>
+void CalAuxLossGradKernel(const Context& dev_ctx,
+                          const DenseTensor& gate_prob,
+                          const DenseTensor& seqlen_float,
+                          const DenseTensor& ce,
+                          const DenseTensor& out_grad,
+                          const int64_t num_experts,
+                          const bool use_group,
+                          const int64_t moe_k,
+                          DenseTensor* gate_prob_grad);
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu
new file mode 100644
index 00000000000000..8caab2d7e4badc
--- /dev/null
+++ b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu
@@ -0,0 +1,110 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include +#include +#include +#include + +#include "paddle/extension.h" +#include "paddle/phi/api/all.h" +#include "paddle/phi/core/dense_tensor.h" + +#include "paddle/phi/api/ext/spmd_infer.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h" + +namespace phi { + +template +__global__ void build_srcrank_and_local_expert_id_kernel( + T* src_rank, + T* local_expert_id, + const U* expert_num, + int64_t total_num, + int64_t num_total_experts, + int64_t num_local_experts) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_total_experts) return; + int64_t start = 0; + int64_t end = 0; + for (int64_t i = 0; i < num_total_experts; ++i) { + end += expert_num[i]; + if (i == tid) { + break; + } + start += expert_num[i]; + } + for (int64_t i = start; i != end; ++i) { + src_rank[i] = static_cast(tid / num_local_experts); + local_expert_id[i] = static_cast(tid % num_local_experts); + } +} + +template +void build_srcrank_and_local_expert_id(T* src_rank, + T* local_expert_id, + const U* expert_num, + int64_t total_num, + int64_t num_total_experts, + int64_t num_local_experts, + cudaStream_t stream) { + int64_t threads_per_block = 32; + int64_t blocks = + (num_total_experts + threads_per_block - 1) / threads_per_block; + build_srcrank_and_local_expert_id_kernel + <<>>(src_rank, + local_expert_id, + expert_num, + total_num, + num_total_experts, + num_local_experts); +} + +template +void BuildSrcRankAndLocalExpertIdKernel( + const Context& dev_ctx, + const DenseTensor& expert_num_global_tensor, + const std::vector& expert_num_global, + int64_t num_local_experts, + DenseTensor* src_rank, + DenseTensor* local_expert_id) { + int64_t token_num = + std::accumulate(expert_num_global.begin(), expert_num_global.end(), 0); + + const int64_t* expert_num_global_tensor_data = + expert_num_global_tensor.data(); + + T* src_rank_data = dev_ctx.template Alloc(src_rank); + T* local_expert_id_data = dev_ctx.template Alloc(local_expert_id); + + build_srcrank_and_local_expert_id(src_rank_data, + local_expert_id_data, + expert_num_global_tensor_data, + token_num, + expert_num_global.size(), + num_local_experts, + dev_ctx.stream()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(build_srcrank_and_local_expert_id, + GPU, + ALL_LAYOUT, + phi::BuildSrcRankAndLocalExpertIdKernel, + float) {} diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu new file mode 100644 index 00000000000000..ec599cdd38dc8e --- /dev/null +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -0,0 +1,361 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
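In plain terms, the build_srcrank_and_local_expert_id kernel registered above expands the global per-expert token counts into a per-token source rank and local expert id, assuming experts are laid out rank-major. A rough Python equivalent, with an illustrative helper name:

def build_src_rank_and_local_expert_id_ref(expert_num_global, num_local_experts):
    src_rank, local_expert_id = [], []
    for expert, count in enumerate(expert_num_global):
        src_rank += [expert // num_local_experts] * count
        local_expert_id += [expert % num_local_experts] * count
    return src_rank, local_expert_id

# e.g. expert_num_global = [2, 1, 0, 3] with num_local_experts = 2 gives
#   src_rank        = [0, 0, 0, 1, 1, 1]
#   local_expert_id = [0, 0, 1, 1, 1, 1]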
+
+#pragma once
+
+#include
+#include
+
+// #include "paddle/extension.h"
+// #include "paddle/phi/api/all.h"
+#include "paddle/phi/core/dense_tensor.h"
+// #include "paddle/phi/kernels/funcs/aligned_vector.h"
+#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
+// #include "paddle/extension.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/cal_aux_loss_kernel.h"
+
+namespace phi {
+
+template <typename T>
+__global__ void cal_aux_loss_kernel(
+    const T* gate_prob,           /*[s, e]*/
+    const int64_t row_gate_prob,  /*seq_len*/
+    const int64_t col_gate_prob,  /*expert_num*/
+    const int64_t* dispatch_mask, /*[s, e] or [e]*/
+    const int64_t row_dispatch_mask,
+    const int64_t col_dispatch_mask,
+    const T* tokens_mask, /*[s]*/
+    const bool* dispatch_tokens_mask,
+    const int64_t dispatch_tokens_mask_len, /*global_seq_len*/
+    const int64_t num_experts,
+    const bool use_group,
+    const int64_t moe_k,
+    const float clip_min,
+    T* l_aux_loss, /*output*/
+    T* seqlen_float,
+    T* ce) {
+  extern __shared__ int64_t aux_loss_shared[];
+  static __shared__ float shared_float[1];
+
+  // compute seqlen_float
+  float seqlen_float_f = 0.f;
+  if (dispatch_tokens_mask) {
+    float local_seqlen_float_f = 0.f;
+    int64_t num_k = (dispatch_tokens_mask_len + blockDim.x - 1) / blockDim.x;
+    for (int64_t k = 0; k < num_k; ++k) {
+      if (k * blockDim.x + threadIdx.x >= dispatch_tokens_mask_len) continue;
+      bool mask = dispatch_tokens_mask[k * blockDim.x + threadIdx.x];
+      local_seqlen_float_f += static_cast<float>(mask);
+    }
+    seqlen_float_f =
+        phi::funcs::BlockReduceSum(local_seqlen_float_f, 0xFFFFFFFF);
+  } else if (tokens_mask) {
+    float local_seqlen_float_f = 0.f;
+    int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x;
+    for (int64_t k = 0; k < num_k; ++k) {
+      if (k * blockDim.x + threadIdx.x >= row_gate_prob) continue;
+      T mask = tokens_mask[k * blockDim.x + threadIdx.x];
+      local_seqlen_float_f += static_cast<float>(mask);
+    }
+    seqlen_float_f =
+        phi::funcs::BlockReduceSum(local_seqlen_float_f, 0xFFFFFFFF);
+  } else {
+    seqlen_float_f = static_cast<float>(row_gate_prob) /
+                     static_cast<float>(num_experts) *
+                     static_cast<float>(col_gate_prob);
+  }
+
+  if (threadIdx.x == 0) {
+    shared_float[0] = max(seqlen_float_f, clip_min);
+  }
+  __syncthreads();
+  seqlen_float_f = shared_float[0];
+
+  __syncthreads();
+  // handle dispatch_mask
+  if (col_dispatch_mask > 1) {
+    int64_t num_k = (row_dispatch_mask + blockDim.x - 1) / blockDim.x;
+
+    for (int64_t e = 0; e < col_dispatch_mask; e++) {
+      int64_t local_sum_val = 0.f;
+      for (int64_t k = 0; k < num_k; ++k) {
+        int64_t mask_val = 0;
+        if (k * blockDim.x + threadIdx.x < row_dispatch_mask) {
+          mask_val = static_cast<int64_t>(
+              dispatch_mask[(k * blockDim.x + threadIdx.x) * col_dispatch_mask +
+                            e]);
+        }
+        local_sum_val += mask_val;
+      }
+      int64_t sum_val =
+          phi::funcs::BlockReduceSum(local_sum_val, 0xFFFFFFFF);
+      if (threadIdx.x == 0) {
+        aux_loss_shared[e] = sum_val;
+      }
+    }
+  } else {
+    if (threadIdx.x < row_dispatch_mask) {
+      aux_loss_shared[threadIdx.x] =
+          static_cast<int64_t>(dispatch_mask[threadIdx.x]);
+    }
+  }
+
+  // compute scale_val
+  float scale_val = 1.f;
+  if (tokens_mask) {
+    float sum_tokens_mask = 0.f;
+    float local_sum_tokens_mask = 0.f;
+    int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x;
+    for (int64_t k = 0; k < num_k; ++k) {
+      if (k * blockDim.x + threadIdx.x >= row_gate_prob) continue;
+      T mask = tokens_mask[k * blockDim.x + threadIdx.x];
+      local_sum_tokens_mask += static_cast<float>(mask);
+    }
+    sum_tokens_mask =
phi::funcs::BlockReduceSum(local_sum_tokens_mask, 0xFFFFFFFF); + if (threadIdx.x == 0) { + shared_float[0] = seqlen_float_f / max(sum_tokens_mask, clip_min); + } + __syncthreads(); + scale_val = shared_float[0]; + } + + // 算me和l_aux + float l_aux = 0.f; + int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x; + for (int64_t e = 0; e < col_gate_prob; e++) { + float local_sum_val = 0.f; + for (int64_t k = 0; k < num_k; ++k) { + float gate_prob_val = 0.f; + if (k * blockDim.x + threadIdx.x < row_gate_prob) { + gate_prob_val = static_cast( + gate_prob[(k * blockDim.x + threadIdx.x) * col_gate_prob + e]); + } + local_sum_val += gate_prob_val; + } + float sum_val = + phi::funcs::BlockReduceSum(local_sum_val, 0xFFFFFFFF); + if (threadIdx.x == 0) { + float ce_val = static_cast(aux_loss_shared[e]) / seqlen_float_f; + float me_val = sum_val / seqlen_float_f; + l_aux += ce_val * me_val * static_cast(num_experts); + ce[e] = static_cast(ce_val); + } + } + + if (threadIdx.x == 0) { + if (use_group) { + l_aux /= static_cast(moe_k); + } + l_aux = l_aux * scale_val; + *l_aux_loss = static_cast(l_aux); + *seqlen_float = static_cast(seqlen_float_f); + } +} + +template +void cal_aux_loss(const T* gate_prob, + const int64_t row_gate_prob, /*seq_len*/ + const int64_t col_gate_prob, /*expert_num*/ + const int64_t* dispatch_mask, + const int64_t row_dispatch_mask, + const int64_t col_dispatch_mask, + const T* tokens_mask, + const bool* dispatch_tokens_mask, + const int64_t dispatch_tokens_mask_len, /*global_seq_len*/ + const int64_t num_experts, /*global_num_experts*/ + const bool use_group, + const int64_t moe_k, + const float clip_min, + T* l_aux_loss, /*output*/ + T* seqlen_float, + T* ce, + cudaStream_t stream) { + int64_t threads = 1024; + threads = std::min(row_gate_prob, threads); + cal_aux_loss_kernel + <<<1, threads, col_gate_prob * sizeof(int64_t), stream>>>( + gate_prob, + row_gate_prob, + col_gate_prob, + dispatch_mask, + row_dispatch_mask, + col_dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + dispatch_tokens_mask_len, + num_experts, + use_group, + moe_k, + clip_min, + l_aux_loss, + seqlen_float, + ce); +} + +template +void CalAuxLossKernel(const Context& dev_ctx, + const DenseTensor& gate_prob, + const DenseTensor& dispatch_mask, + const DenseTensor& tokens_mask, + const DenseTensor& dispatch_tokens_mask, + int64_t num_experts, + bool use_group, + int64_t moe_k, + float clip_min, + DenseTensor* l_aux_loss, + DenseTensor* seqlen_float, + DenseTensor* ce) { + auto gate_prob_dims = gate_prob.dims(); + auto dispatch_mask_dims = dispatch_mask.dims(); + auto dispatch_tokens_mask_dim = dispatch_tokens_mask.dims(); + + const T* gate_prob_data = gate_prob.data(); + const int64_t* dispatch_mask_data = dispatch_mask.data(); + const T* tokens_mask_data = tokens_mask.data(); + const bool* dispatch_tokens_mask_data = dispatch_tokens_mask.data(); + + T* l_aux_loss_data = dev_ctx.template Alloc(l_aux_loss); + T* seqlen_float_data = dev_ctx.template Alloc(seqlen_float); + T* ce_data = dev_ctx.template Alloc(ce); + + int64_t row_gate_prob = gate_prob_dims[0]; + int64_t col_gate_prob = gate_prob_dims[1]; + + int64_t col_dispatch_mask = 0; + int64_t row_dispatch_mask = dispatch_mask_dims[0]; + if (dispatch_mask_dims.size() > 1) { + col_dispatch_mask = dispatch_mask_dims[1]; + } else { + col_dispatch_mask = 1; + } + + int dispatch_tokens_mask_len = dispatch_tokens_mask_dim[0]; + + cal_aux_loss(gate_prob_data, + row_gate_prob, + col_gate_prob, + dispatch_mask_data, + row_dispatch_mask, + col_dispatch_mask, 
+ tokens_mask_data, + dispatch_tokens_mask_data, + dispatch_tokens_mask_len, + num_experts, + use_group, + moe_k, + clip_min, + l_aux_loss_data, + seqlen_float_data, + ce_data, + dev_ctx.stream()); +} + +template +__global__ void cal_aux_loss_grad_kernel(const T* out_grad, + const T* gate_prob, + const int64_t row_gate_prob, + const int64_t col_gate_prob, + const T* seqlen_float, + const T* ce, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + T* gate_prob_grad) { + T ce_val = ce[threadIdx.x]; + T l_aux_grad = *out_grad; + if (use_group) { + l_aux_grad = l_aux_grad / static_cast(moe_k); + } + l_aux_grad *= static_cast(num_experts); + + gate_prob_grad[blockIdx.x * col_gate_prob + threadIdx.x] = + (ce_val * l_aux_grad) / (*seqlen_float); +} + +template +void cal_aux_loss_grad(const T* out_grad, + const T* gate_prob, + const int64_t row_gate_prob, /*seq_len*/ + const int64_t col_gate_prob, /*expert_num*/ + const T* seqlen_float, + const T* ce, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + T* gate_prob_grad, + cudaStream_t stream) { + cal_aux_loss_grad_kernel + <<>>(out_grad, + gate_prob, + row_gate_prob, + col_gate_prob, + seqlen_float, + ce, + num_experts, + use_group, + moe_k, + gate_prob_grad); +} + +template +void CalAuxLossGradKernel(const Context& dev_ctx, + const DenseTensor& gate_prob, + const DenseTensor& seqlen_float, + const DenseTensor& ce, + const DenseTensor& out_grad, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + DenseTensor* gate_prob_grad) { + auto gate_prob_dims = gate_prob.dims(); + + const T* out_grad_data = out_grad.data(); + const T* gate_prob_data = gate_prob.data(); + const T* seqlen_float_data = seqlen_float.data(); + const T* ce_data = ce.data(); + + int64_t row_gate_prob = gate_prob_dims[0]; + int64_t col_gate_prob = gate_prob_dims[1]; + + T* gate_prob_grad_data = dev_ctx.template Alloc(gate_prob_grad); + + cal_aux_loss_grad(out_grad_data, + gate_prob_data, + row_gate_prob, + col_gate_prob, + seqlen_float_data, + ce_data, + num_experts, + use_group, + moe_k, + gate_prob_grad_data, + dev_ctx.stream()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cal_aux_loss, + GPU, + ALL_LAYOUT, + phi::CalAuxLossKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL( + cal_aux_loss_grad, GPU, ALL_LAYOUT, phi::CalAuxLossGradKernel, float) {} diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 46307972aa48d6..4e56ba0e89959e 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -361,6 +361,15 @@ param: [softmax, label, loss_grad, ignore_index, rank, nranks] inplace : (softmax -> logits_grad) +- backward_op : cal_aux_loss_grad + forward : cal_aux_loss (Tensor gate_prob, Tensor dispatch_mask, Tensor tokens_mask, Tensor dispatch_tokens_mask, int64_t num_experts, bool use_group, int64_t moe_k, float clip_min) -> Tensor(l_aux_loss), Tensor(seqlen_float), Tensor(ce) + args : (Tensor gate_prob, Tensor seqlen_float, Tensor ce, Tensor out_grad, int64_t num_experts, bool use_group, int64_t moe_k) + output : Tensor(gate_prob_grad) + infer_meta : + func : CalAuxLossGradInferMeta + kernel : + func : cal_aux_loss_grad + - backward_op : cast_grad forward : cast (Tensor x, DataType dtype) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml old mode 100755 new mode 100644 index 5bd4eca9a850ac..7d39ba2010b28c --- 
a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -820,6 +820,15 @@ backward: broadcast_tensors_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : build_src_rank_and_local_expert_id + args : (Tensor expert_num_global_tensor, int64_t[] expert_num_global, int64_t num_local_experts) + output : Tensor(vector), Tensor(local_expert_id) + infer_meta : + func : BuildSrcRankAndLocalExpertIdInferMeta + kernel : + func : build_src_rank_and_local_expert_id + data_type : expert_num_global_tensor + - op : c_allreduce_sum args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel) output : Tensor(out) @@ -884,6 +893,16 @@ func : c_split param: [x, rank, nranks, use_model_parallel] +- op : cal_aux_loss + args : (Tensor gate_prob, Tensor dispatch_mask, Tensor tokens_mask, Tensor dispatch_tokens_mask, int64_t num_experts, bool use_group, int64_t moe_k, float clip_min) + output : Tensor(l_aux_loss), Tensor(seqlen_float), Tensor(ce) + infer_meta : + func : CalAuxLossInferMeta + kernel : + func : cal_aux_loss + data_type : gate_prob + backward : cal_aux_loss_grad + - op : calc_reduced_attn_scores args : (Tensor q, Tensor k, Tensor softmax_lse) output : Tensor(reduced_scores) diff --git a/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py b/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py new file mode 100644 index 00000000000000..69f0a1fca12704 --- /dev/null +++ b/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import paddle
+from paddle import _C_ops
+
+# from ....framework import LayerHelper, in_dynamic_or_pir_mode
+from paddle.base.framework import in_dynamic_or_pir_mode
+from paddle.base.layer_helper import LayerHelper
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+def build_src_rank_and_local_expert_id(
+    expert_num_global_tensor: Tensor,
+    expert_num_global: list,
+    num_local_experts: int,
+    name: str | None = None,
+) -> tuple[Tensor, Tensor]:
+    """
+    Args:
+        expert_num_global_tensor: int64 Tensor with the number of tokens routed to
+            each global expert (experts laid out rank-major).
+        expert_num_global: the same per-expert token counts as a Python list of ints.
+        num_local_experts: number of experts hosted on each rank.
+
+    Returns:
+        src_rank (int32) and local_expert_id (int32), one entry per routed token.
+    """
+    if in_dynamic_or_pir_mode():
+        return _C_ops.build_src_rank_and_local_expert_id(
+            expert_num_global_tensor, expert_num_global, num_local_experts
+        )
+
+    helper = LayerHelper('build_src_rank_and_local_expert_id', **locals())
+    vector = helper.create_variable_for_type_inference(dtype=paddle.int32)
+    local_expert_id = helper.create_variable_for_type_inference(
+        dtype=paddle.int32
+    )
+
+    inputs = {'expert_num_global_tensor': expert_num_global_tensor}
+    attrs = {
+        'expert_num_global': expert_num_global,
+        'num_local_experts': num_local_experts,
+    }
+    outputs = {'vector': vector, 'local_expert_id': local_expert_id}
+    helper.append_op(
+        type='build_src_rank_and_local_expert_id',
+        inputs=inputs,
+        attrs=attrs,
+        outputs=outputs,
+    )
+    return vector, local_expert_id
diff --git a/python/paddle/incubate/nn/functional/cal_aux_loss.py b/python/paddle/incubate/nn/functional/cal_aux_loss.py
new file mode 100644
index 00000000000000..e759a62b77af6f
--- /dev/null
+++ b/python/paddle/incubate/nn/functional/cal_aux_loss.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
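A hypothetical call of the wrapper defined above; the values and the direct module import are only for illustration and assume the op has been compiled into the build:

import paddle
from paddle.incubate.nn.functional.build_src_rank_and_local_expert_id import (
    build_src_rank_and_local_expert_id,
)

expert_num_global = [2, 1, 0, 3]  # tokens routed to each global expert (made-up numbers)
expert_num_global_tensor = paddle.to_tensor(expert_num_global, dtype='int64')
src_rank, local_expert_id = build_src_rank_and_local_expert_id(
    expert_num_global_tensor, expert_num_global, num_local_experts=2
)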
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from paddle import _C_ops + +# from ....framework import LayerHelper, in_dynamic_or_pir_mode +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + +if TYPE_CHECKING: + from paddle import Tensor + + +def cal_aux_loss( + gate_prob: Tensor, + dispatch_mask: Tensor, + tokens_mask: Tensor, + dispatch_tokens_mask: Tensor, + num_experts: int, + use_group: bool, + moe_k: int, + clip_min: float, + name: str | None = None, +) -> Tensor: + """ + Args: + gate_prob: + dispatch_mask: + tokens_mask: + dispatch_tokens_mask: + num_experts: + use_group: + moe_k: + clip_min: + + Returns: + """ + if in_dynamic_or_pir_mode(): + return _C_ops.cal_aux_loss( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min, + ) + + helper = LayerHelper('cal_aux_loss', **locals()) + l_aux_loss = helper.create_variable_for_type_inference( + dtype=gate_prob.dtype + ) + seqlen_float = helper.create_variable_for_type_inference( + dtype=gate_prob.dtype + ) + ce = helper.create_variable_for_type_inference(dtype=gate_prob.dtype) + + inputs = { + 'gate_prob': gate_prob, + 'dispatch_mask': dispatch_mask, + 'tokens_mask': tokens_mask, + 'dispatch_tokens_mask': dispatch_tokens_mask, + } + attrs = { + 'num_experts': num_experts, + 'use_group': use_group, + 'moe_k': moe_k, + 'clip_min': clip_min, + } + outputs = {'l_aux_loss': l_aux_loss, 'seqlen_float': seqlen_float, 'ce': ce} + helper.append_op( + type='cal_aux_loss', inputs=inputs, attrs=attrs, outputs=outputs + ) + return l_aux_loss From 8ade3ac2c95fa3e0df36fbaa20406120774d3a7a Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 23 May 2025 03:49:23 +0000 Subject: [PATCH 12/71] moegatedispatch init --- paddle/phi/infermeta/ternary.cc | 101 ++++++++++++++++++++++++++++++++ paddle/phi/infermeta/ternary.h | 12 ++++ paddle/phi/ops/yaml/ops.yaml | 11 ++++ 3 files changed, 124 insertions(+) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index f215a7b68c6206..bdd4e074bdfe87 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1630,6 +1630,107 @@ void MoeCombineInferMeta(const MetaTensor& x, y->set_dtype(x.dtype()); } +void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, + const MetaTensor& gate_logits, + const MetaTensor& corr_bias, + int64_t k, + int64_t capacity, + int64_t world_size, + MetaTensor* y, + MetaTensor* combine_weights, + MetaTensor* scatter_index, + MetaTensor* expert_offset, + MetaTensor* expert_id){ + auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + common::errors::InvalidArgument("The dimensions of Input(x) must be 2, but " + "received dimensions of" + "Input(x) is [%d]", + x_dims.size())); + auto gate_logits_dims= gate_logits.dims(); + PADDLE_ENFORCE_EQ( + gate_logits_dims.size(), + 2, + common::errors::InvalidArgument("The dimensions of Input(gate_logits) must be 2, but " + "received dimensions of" + "Input(gate_logits) is [%d]", + gate_logits_dims.size())); + PADDLE_ENFORCE_EQ( + gate_logits_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The first dimensions of Input(gate_logits) must be equal to the first " + "dimension of Input(x), but received Input(gate_logits) shape is [%d]," + "Input(x) shape is [%d]", + gate_logits_dims[0], + x_dims[0])); + PADDLE_ENFORCE_EQ( + gate_logits_dims[1] % world_size, + 0, + common::errors::InvalidArgument( + "The number of experts (the 
second dimension of Input(gate_logits)) must be divisible by world_size, but received " + "num_experts = %d, world_size = %d", + gate_logits_dims[1], + world_size)); + + PADDLE_ENFORCE_GE( + gate_logits_dims[1], + k, + common::errors::InvalidArgument( + "The number of experts ((the second dimension of Input(gate_logits))) must be greater than or equal to k, but received " + "num_experts = %d, k = %d", + gate_logits_dims[1], + k)); + + PADDLE_ENFORCE_EQ( + gate_logits.dtype(), + phi::DataType::FLOAT32, + common::errors::InvalidArgument( + "The dtype of Input(gate_logits) must be FLOAT32, but received %s", + gate_logits.dtype())); + + if(corr_bias){ + auto corr_bias_dims = corr_bias.dims(); + PADDLE_ENFORCE_EQ( + corr_bias_dims.size(), + 1, + common::errors::InvalidArgument( + "The dimensions of Input(corr_bias) must be 1, but received " + "dimensions of Input(corr_bias) is [%d]", + corr_bias_dims.size())); + PADDLE_ENFORCE_EQ( + corr_bias_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The dimensions of Input(corr_bias) must be equal to the first " + "dimension of Input(x), but received Input(corr_bias) first dimension is [%d]," + "Input(x) first dimension is [%d]", + corr_bias_dims[0], + x_dims[0])); + PADDLE_ENFORCE_EQ( + corr_bias.dtype(), + paddle::DataType::FLOAT32, + common::errors::InvalidArgument( + "The dtype of Input(corr_bias) must be FLOAT32, but received %s", + corr_bias.dtype())); + } + int64_t num_experts = gate_logits_dims[1]; + int64_t num_local_experts = num_experts / world_size; + int64_t num_rows = x_dims[0]; + y->set_dims({num_local_experts, world_size, capacity, x_dims[1]}); + y->set_dtype(x.dtype()); + combine_weights->set_dims({num_rows, k}); + combine_weights->set_dtype(phi::DataType::FLOAT32); + scatter_index->set_dims({k, num_rows}); + scatter_index->set_dtype(phi::DataType::INT32); + expert_offset->set_dims({num_experts}); + expert_offset->set_dtype(phi::DataType::INT64); + expert_id->set_dims({num_rows, k}); + expert_id->set_dtype(phi::DataType::INT32); +} + void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x, const MetaTensor& in_accum, const MetaTensor& in_state, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index ae7ca25ae245f3..9b2c34aed451ae 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -274,6 +274,18 @@ void MoeCombineInferMeta(const MetaTensor& x, const MetaTensor& scatter_index, MetaTensor* y); +void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, + const MetaTensor& gate_logits, + const MetaTensor& corr_bias, + int64_t k, + int64_t capacity, + int64_t world_size, + MetaTensor* y, + MetaTensor* combine_weights, + MetaTensor* scatter_index, + MetaTensor* expert_offset, + MetaTensor* expert_id); + void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x, const MetaTensor& in_accum, const MetaTensor& in_state, diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 6532a4f3812f01..e183db924f844f 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3602,6 +3602,17 @@ data_type : x backward : moe_combine_grad +- op : moe_gate_dispatch_permute + args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, int64_t world_size) + output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + infer_meta : + func : MoeGateDispatchPermuteInferMeta + kernel : + func : moe_gate_dispatch_permute + data_type : x + optional : corr_bias + # backward : 
moe_gate_dispatch_permute_grad + - op : momentum_ args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0f, bool multi_precision = false, float rescale_grad = 1.0f) output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) From b8c9636123e495c67129cd36c2baeebbc64ffbca Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 23 May 2025 03:50:12 +0000 Subject: [PATCH 13/71] insert moegatedispatch --- paddle/phi/kernels/gpu/moe_fuse_op.cuh | 448 ++++++++++++ .../gpu/moe_gate_dispatch_permute_kernel.cu | 340 ++++++++++ paddle/phi/kernels/gpu/moe_kernel_impl.h | 642 ++++++++++++++++++ .../moe_gate_dispatch_permute_kernel.h | 32 + .../paddle/incubate/nn/functional/__init__.py | 1 + .../functional/moe_gate_dispatch_permute.py | 65 ++ .../test_moe_gate_dispatch_permute.py | 37 + 7 files changed, 1565 insertions(+) create mode 100644 paddle/phi/kernels/gpu/moe_fuse_op.cuh create mode 100644 paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu create mode 100644 paddle/phi/kernels/gpu/moe_kernel_impl.h create mode 100644 paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h create mode 100644 python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py create mode 100644 test/legacy_test/test_moe_gate_dispatch_permute.py diff --git a/paddle/phi/kernels/gpu/moe_fuse_op.cuh b/paddle/phi/kernels/gpu/moe_fuse_op.cuh new file mode 100644 index 00000000000000..ba2474e6bec250 --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_fuse_op.cuh @@ -0,0 +1,448 @@ +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/common/exception.h" + +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, + const T* bias, //bias could be nullptr if not used + T* output, + int* indices, + int* source_rows, + const int num_experts, + const int k){ + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = bias ? 
inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + indices[idx] = result_kvp.key; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +template +void topk_gating_softmax_kernelLauncher(const T* input, + const T* bias, + T* output, + T* softmax, //no use + int* indices, + int* source_row, + const int num_rows, + const int num_experts, + const int k, + cudaStream_t stream){ + static constexpr int WARPS_PER_TB = 4; + static constexpr int TPB = 256; + moe_top_k<<>>( + input, bias, output, indices, source_row, num_experts, k); +} + +template +__global__ void modify_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts){ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) + return; + int ik = idx % k; + int irow = idx / k; + // const T mask = (~0) >> (8*sizeof(T)-ik); // 最后 ik 位为 1 其他位为 0 + int mask = ik; // k => 2(11) + // printf("before: idx=%d, expert-id:%d, ik=%d\n", idx, expert_id[idx], ik); + int offset = log2(k) + 1; + expert_id_out[idx] = (expert_id[idx]< +void modify_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts, + const cudaStream_t& stream){ + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + modify_expert_id<<>>( + expert_id, + expert_id_out, + k, + num_rows, + num_experts + ); +} + +template +__global__ void +unmodify_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts){ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) + return; + int ik = idx % k; + int irow = idx / k; + int offset = log2(k) + 1; + expert_id_out[idx] = (expert_id[idx]>>offset); +} + +template +void unmodify_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts, + const cudaStream_t& stream){ + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + unmodify_expert_id<<>>( + expert_id, + expert_id_out, + k, + num_rows, + num_experts + ); +} + +template +__device__ inline int find_total_elts_leq_target(const T* sorted_indices, const int arr_length, const int target) +{ + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } + else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +template +__global__ void compute_total_rows_before_expert_kernel(const T* sorted_experts, + const int sorted_experts_len, + const int64_t num_experts, + int64_t* total_rows_before_expert) +{ + + // First, compute the global tid. We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) + return; + + + // This should construct the last index where each expert occurs. 
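+  // Illustrative example: with sorted_experts = [0, 0, 1, 3, 3] and num_experts = 4,
+  // total_rows_before_expert ends up as [2, 3, 3, 5], i.e. a cumulative row count per expert.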
+ total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); + // total_rows_before_expert[0] = 0; + // total_rows_before_expert[1] = 1; + // if (sorted_experts_len > 3) { + // for (int i=0; i<35;i++){ + // total_rows_before_expert[i] = i; + // } + // } + + +} + +template +void compute_total_rows_before_expert(const T* sorted_indices, + const int total_indices, + const int64_t num_experts, + int64_t* total_rows_before_expert, + const cudaStream_t& stream) +{ + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + + compute_total_rows_before_expert_kernel<<>>( + sorted_indices, total_indices, num_experts, total_rows_before_expert); +} + +template +__global__ void initialize_moe_routing_kernel(const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, //output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + bool use_pad + ) +{ + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the + // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 
0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + if (row_in_expert >= capacity){ + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = 0; // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; //unset combine-weight + } + return; + } + int64_t num_padded = 0; + if (threadIdx.x == 0) { + // printf("going through: capacity=%lld, num_active=%lld, row=[%d->%d], row-in-expert %lld\n", + // capacity, + // num_active, + // expanded_dest_row, expanded_source_row, + // row_in_expert + // ); + if (use_pad) + num_padded = iexpert * capacity - offset; + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row + num_padded; + } + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr; + if (use_pad){ + dest_row_ptr = permuted_output + + iexpert * capacity * cols + + row_in_expert * cols; + }else{ + dest_row_ptr = permuted_output + expanded_dest_row * cols; + } + + + for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x* VecSize) { + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &dest_row_ptr[tid]); + } +} + +template +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, //output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + bool use_pad, + cudaStream_t stream) +{ + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + constexpr int max_pack_size = 16 / sizeof(T); + if (cols % max_pack_size == 0) { + initialize_moe_routing_kernel<<>>(unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + use_pad + ); + } else { + initialize_moe_routing_kernel<<>>(unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + use_pad + ); + } +} + +/** + * 原逻辑的output: + * R0E0 + * R0E1 + * R1E0 + * R1E1 + * + * 我们想对all2all和专家gemm做overlap, 所以需要将all2all拆成流水线, 为了便于后续计算, 此kernel的output: + * R0E0 + * R1E0 + * R0E1 + * R1E1 +*/ +template +__global__ void initialize_moe_routing_permute_kernel(const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, //output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + const int64_t world_size, + const int64_t num_local_experts + ) +{ + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the + // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. 
+#pragma unroll + for (int i = 0; i < LoopSize; i++) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int expanded_dest_row = blockIdx.x + i * gridDim.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + if (row_in_expert >= capacity){ + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = 0; // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; //unset combine-weight + } + continue; + } + int64_t num_padded = 0; + if (threadIdx.x == 0) { + num_padded = iexpert * capacity - offset; + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row + num_padded; + } + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr; + + const int64_t irank = iexpert / num_local_experts; + const int64_t local_iexpert = iexpert % num_local_experts; + dest_row_ptr = permuted_output + local_iexpert * world_size * capacity * cols + irank * capacity * cols + row_in_expert * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x * VecSize) { + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +template +void initialize_moe_routing_permute_kernelLauncher(const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, //output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + const int64_t world_size, + const int64_t num_local_experts, + cudaStream_t stream) +{ + const int loop_size = 2; + const int blocks = (num_rows * k) / loop_size; + assert((num_rows * k) % loop_size == 0); + const int threads = std::min(cols, 1024); + constexpr int max_pack_size = 16 / sizeof(T); + if (cols % max_pack_size == 0) { + initialize_moe_routing_permute_kernel<<>>(unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + world_size, + num_local_experts + ); + } else { + initialize_moe_routing_permute_kernel<<>>(unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + world_size, + num_local_experts + ); + } +} + diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu new file mode 100644 index 00000000000000..7d110b319d6f5d --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu @@ -0,0 +1,340 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/moe_kernel_impl.h" +#include "paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/gpu/moe_fuse_op.cuh" +namespace phi { + +// -------- getWorkspaceSize -------- // +template +size_t getWorkspaceSize(const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int k, + // const int max_seq_len, + phi::CubKeyValueSorter &sorter) +{ + + // const int buf_size = AlignTo16(k * num_rows * hidden_size); + // const int interbuf_size = AlignTo16(k * num_rows * inter_size); + // const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + int num_softmax_outs = 0; + + // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them + // in Encoder or Decoder before invoking FfnLayer forward. + size_t total_ws_bytes = 4 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + // total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + // total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ // expert_cnt + // total_ws_bytes += num_softmax_outs * sizeof(KeyT); + // const int bytes_for_fc1_result = interbuf_size * sizeof(KeyT); + const int sorter_ws_size_bytes = AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + //sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity // 用所有 bit 做排序,会降低些许性能,但是防止越界 + total_ws_bytes += sorter_ws_size_bytes; // intermediate (fc1) output + cub sorting workspace + // std::cout<<"sorter_ws_size_bytes = "<(gate_logits, + corr_bias, + combine_weights, // output + softmax_out_, // no use + expert_id, // output + source_rows_, // output + num_rows, + num_experts, + k, + stream); + +#ifdef DEBUG_MOE_OP + // phi::CastKernel(ctx, expert_scales_tensor_float, expert_scales_tensor.dtype(), &expert_scales_tensor); + print_to_screen1(combine_weights, 8, 16, std::string("expert_scales_float after topk")); + print_to_screen1(expert_id, 8, 16, std::string("expert-id before permute")); + print_to_screen1(source_rows_, 8, 16, std::string("desc->src idx before permute")); +#endif + // modifiy expert-id according to k + if (use_pad) // 为了区分 k=1 选择和 k=2 选择,修改 expert-id + modify_expert_id_launcher(expert_id, expert_id_, k, num_rows, num_experts, stream); + + // calc expert-size +/* + if (!use_pad) + cal_expert_size_and_filter_launcher(expert_id, + k * num_rows, + num_experts, + capacity, + stream); +*/ + #ifdef DEBUG_MOE_OP + print_to_screen1(expert_id, 8, 16, std::string("expert-id after modified")); +#endif + sorter.run(fc1_result_, + sorter_ws_size_bytes, + use_pad ? 
expert_id_ : expert_id, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + source_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + if (use_pad) + unmodify_expert_id_launcher(permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1(permuted_experts_, 8, 16, std::string("expert-id after permute")); + print_to_screen1(permuted_rows_, 8, 16, std::string("dest->src idx after permute")); +#endif + + compute_total_rows_before_expert( + permuted_experts_, + k * num_rows, + num_experts, + expert_offset, + stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1(expert_offset, 8, 16, std::string("expert_offset")); + int64_t num_active_host_v2; + cudaMemcpy(&num_active_host_v2, expert_offset + num_experts - 1, sizeof(int64_t), cudaMemcpyDeviceToHost); + std::cerr << "[DEBUG] num_active v2: " << num_active_host_v2 << std::endl; + print_to_screen1(permuted_experts_, 8, num_active_host_v2+2, std::string("expert-id after permute")); + // print_to_screen1(permuted_experts_, 4096, 8192, std::string("expert-id after permute")); +#endif + + if (!use_all2all_permute) { + initialize_moe_routing_kernelLauncher(x, + y, + permuted_rows_, + scatter_index, + permuted_experts_, + expert_offset, + combine_weights, + static_cast(num_rows), + static_cast(hidden_size), + static_cast(k), + capacity, + use_pad, + stream); + } else { + PD_CHECK(num_experts > 0); + PD_CHECK(world_size > 0); + initialize_moe_routing_permute_kernelLauncher(x, + y, + permuted_rows_, + scatter_index, + permuted_experts_, + expert_offset, + combine_weights, + static_cast(num_rows), + static_cast(hidden_size), + static_cast(k), + capacity, + world_size, + num_local_experts, + stream); + } + + // turn expert_offset_ptr into experts_num + // auto expert_offset_ptr = thrust::device_pointer_cast(expert_offset); + // thrust::adjacent_difference( + // expert_offset_ptr, expert_offset_ptr + num_experts, expert_offset_ptr + // ); +#ifdef DEBUG_MOE_OP + print_to_screen1(scatter_index, 8, 16, std::string("scatter_index after pad")); +#endif + // cudaMemcpy(scatter_index, permuted_rows_, sizeof(int64_t) * k * num_rows, cudaMemcpyDeviceToDevice); + // cudaMemcpy(combine_weights, expert_scales_float, sizeof(float) * k * num_rows, cudaMemcpyDeviceToDevice); + return; +} + +template +void moe_dispatch_fwd(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& gate_logits, + const paddle::optional& corr_bias, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + const DenseTensor& y, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& expert_offset, + const DenseTensor& expert_id, + bool use_pad, + int64_t use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1){ + apply_moe_dispatch_fwd(dev_ctx, + x.data(), + gate_logits.data(), + corr_bias? 
corr_bias.get_ptr()->data() : nullptr, + num_rows, + num_experts, + hidden_size, + capacity, + k, + const_cast(y.data()), + const_cast(combine_weights.data()), + const_cast(scatter_index.data()), + const_cast(expert_offset.data()), + const_cast(expert_id.data()), + use_pad, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); +} + +template +void MoEDispatchPermuteKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& gate_logits, + const paddle::optional& corr_bias, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor* y, + DenseTensor* combine_weights, + DenseTensor* scatter_index, + DenseTensor* expert_offset, + DenseTensor* expert_id){ + dev_ctx.template Alloc(expert_id); + dev_ctx.template Alloc(expert_offset); + dev_ctx.template Alloc(scatter_index); + dev_ctx.template Alloc(combine_weights); + dev_ctx.template Alloc(y); + const auto &x_shape = x.dims(); + const auto &gate_logits_shape = gate_logits.dims(); + int64_t num_rows = x_shape[0]; + int64_t hidden_size = x_shape[1]; + int64_t num_experts = gate_logits_shape[1]; + int64_t num_local_experts = num_experts / world_size; + moe_dispatch_fwd(dev_ctx, + x, + gate_logits, + corr_bias, + num_rows, + num_experts, + hidden_size, + capacity, + k, + *y, + *combine_weights, + *scatter_index, + *expert_offset, + *expert_id, + true, /*use_pad*/ + true, /*use_all2all_permute*/ + world_size, + num_local_experts); +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_permute, + GPU, + ALL_LAYOUT, + phi::MoEDispatchPermuteKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_kernel_impl.h b/paddle/phi/kernels/gpu/moe_kernel_impl.h new file mode 100644 index 00000000000000..fbf6063b284dbe --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_kernel_impl.h @@ -0,0 +1,642 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "cub/cub.cuh" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" + +namespace phi { + +static const float HALF_FLT_MAX = 65504.F; +static const float HALF_FLT_MIN = -65504.F; +static inline size_t AlignTo16(const size_t& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + CubKeyValueSorter(cudaStream_t stream = 0); + + explicit CubKeyValueSorter(const int num_experts); + + void update_num_experts(const int num_experts); + + size_t getWorkspaceSize(const size_t num_key_value_pairs, + bool descending = false); + + template + void run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream); + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; + cudaStream_t stream_; +}; + +// ===== CUB Sorting things ===== +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(cudaStream_t stream) + : num_experts_(0), num_bits_(sizeof(int) * 8), stream_(stream) {} + +CubKeyValueSorter::CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts), + num_bits_(static_cast(log2(num_experts)) + 1) {} + +void CubKeyValueSorter::update_num_experts(const int num_experts) { + num_experts_ = num_experts; + num_bits_ = static_cast(log2(num_experts)) + 3; //额外增加 3 位用于标记 topk的位置 +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, + bool descending) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + 32, + stream_); + } else { + cub::DeviceRadixSort::SortPairs(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + num_bits_, + stream_); + } + return required_storage; +} + +template +void CubKeyValueSorter::run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + std::stringstream err_ss; + err_ss << "[Error][CubKeyValueSorter::run]\n"; + err_ss + << "Error. 
The allocated workspace is too small to run this problem.\n"; + err_ss << "Expected workspace size of at least " << expected_ws_size + << " but got problem size " << workspace_size << "\n"; + throw std::runtime_error(err_ss.str()); + } + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + 32, + stream); + } else { + cub::DeviceRadixSort::SortPairs(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + num_bits_, + stream); + } +} + +template <> +void CubKeyValueSorter::run(void* workspace, + const size_t workspace_size, + const __nv_bfloat16* keys_in, + __nv_bfloat16* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) {} + +// CubKeyValueSorter sorter_(stream); + +// -------- initialize_expert_choice_route_kernel -------- // +template +__global__ void initialize_expert_choice_route_kernel( + int* expert_for_source_row, + int* source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* total_rows_before_expert, + T* attr_mask, + const int cols, + const int k, + const int batch_size) { + int start = cols * blockIdx.x; + + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + expert_for_source_row[start + i] = blockIdx.x; + source_row[start + i] = start + i; + expanded_source_row_to_expanded_dest_row[start + i] = -1; + attr_mask[start + i] = (T)1.0f; + } + if (threadIdx.x == 0) { + total_rows_before_expert[blockIdx.x] = batch_size * k * (blockIdx.x + 1); + } +} + +// -------- softmax_kernel -------- // +template +__global__ void softmax_kernel_v4( + T* qk_buf_, + const T* qk_buf_src, // shape [batch_size, seq_len] + const T* attr_mask, // shape [batch_size, seq_len] + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + float data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + qk_offset = + ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y) * seq_len + blockDim.x * i + threadIdx.x; + + float qk = static_cast(qk_buf_src[qk_offset]); + float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + data[i] = qk + mask_val; + local_max = fmax(local_max, data[i]); + } + + float max_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); + + float local_sum = 0; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + data[i] = __expf(data[i] - s_max); + local_sum += data[i]; + } + float sum_val = + blockDim.x <= 32 + ? 
phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + qk_offset = + ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; + qk_buf_[qk_offset] = (T)(data[i] * s_mean); + } +#endif +} + +template +__global__ void softmax_kernel_v4_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + using T2 = half2; + T2* qk_buf_half2 = reinterpret_cast(qk_buf_); + const T2* attr_mask_half2 = (const T2*)attr_mask; + + T2 data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + + threadIdx.x; + int mask_offset = blockIdx.y * (seq_len / 2) + blockDim.x * i + threadIdx.x; + + T2 qk = qk_buf_half2[qk_offset]; + T2 mask_val = __ldg(&attr_mask_half2[mask_offset]); + mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), + __float2half2_rn(-10000.0f)); + + data[i] = __hadd2(qk, mask_val); + + local_max = fmax( + local_max, + fmax(static_cast(data[i].x), static_cast(data[i].y))); + } + + float max_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); + + float local_sum = 0; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max))); + local_sum += static_cast(data[i].x + data[i].y); + } + + float sum_val = + blockDim.x <= 32 + ? 
phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); + + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + + threadIdx.x; + qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean)); + } +#endif +} + +template +__global__ void softmax_kernel_v5_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + using T2 = half2; + T2* qk_buf_half2 = reinterpret_cast(qk_buf_); + const T2* attr_mask_half2 = (const T2*)attr_mask; + + T2 data[NUM][ITEMS_PER_THREAD]; + + int qk_offset[NUM]; + + __shared__ float s_sum[NUM], s_max[NUM]; + float local_max[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_max[j] = -1e20f; + } + + const int MAX_NUM = min((1 + gridDim.x - 1) / gridDim.x, NUM); + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + int mask_offset[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + mask_offset[j] = (blockIdx.y + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + } + + T2 mask_val[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]); + } + + T2 qk[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk[j] = qk_buf_half2[qk_offset[j]]; + } +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]), + __float2half2_rn(-10000.0f)); + } +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = __hadd2(qk[j], mask_val[j]); + local_max[j] = fmax(local_max[j], + fmax(static_cast(data[j][i].x), + static_cast(data[j][i].y))); + } + } + if (blockDim.x <= 32) { + phi::funcs::WarpReduceMaxV2(local_max); + } else { + phi::funcs::BlockReduceMaxV2(local_max); + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_max[j] = local_max[j]; + } + } + __syncthreads(); + float local_sum[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_sum[j] = {0.f}; + } + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j]))); + } + +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + local_sum[j] += static_cast(data[j][i].x + data[j][i].y); + } + } + + if (blockDim.x <= 32) { + phi::funcs::WarpReduceSumV2(local_sum); + + } else { + phi::funcs::BlockReduceSumV2(local_sum); + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); + } + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + } + +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_buf_half2[qk_offset[j]] = + __hmul2(data[j][i], __float2half2_rn(s_sum[j])); + } + } +#endif +} + +// -------- transpose_kernel 
-------- // +template +__global__ void transposeAxis01( + T* out, T* in, const int dim0, const int dim1, const int dim2) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < dim0 * dim1 * dim2) { + const int input_dim2_index = index % dim2; + index = (index - input_dim2_index) / dim2; + const int input_dim1_index = index % dim1; + index = (index - input_dim1_index) / dim1; + const int input_dim0_index = index % dim0; + + out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + + input_dim2_index] = in[input_dim0_index * dim1 * dim2 + + input_dim1_index * dim2 + input_dim2_index]; + } +} + +// -------- padding_kernel -------- // +template +__global__ void paddingKernel(T* output1, + int* output2, + const T* input1, + const int* input2, + const int* input_lengths, + const int num_tokens, + const int batch_size, + const int max_seq_len, + const int num_experts) { + const bool IS_FP16 = std::is_same::value; + const T MIN_T_VAL = (IS_FP16) ? (T)HALF_FLT_MIN : (T)FLT_MIN; + int offset1 = blockIdx.x * num_tokens; + int offset2 = blockIdx.x * batch_size * max_seq_len; + for (int i = 0; i < batch_size; i++) { + const T* in1_ptr = input1 + offset1; + const int* in2_ptr = input2 + offset1; + int input_length = input_lengths[i]; + offset1 += input_length; + + T* out1_ptr = output1 + offset2; + int* out2_ptr = output2 + offset2; + offset2 += max_seq_len; + + for (int j = threadIdx.x; j < max_seq_len; j += max_seq_len) { + if (j < input_length) { + out1_ptr[j] = in1_ptr[j]; + out2_ptr[j] = in2_ptr[j]; + } else { + out1_ptr[j] = MIN_T_VAL; + out2_ptr[j] = 0; + } + } + } +} + +// -------- general_topk_pair_sort_kernel -------- // +template +__global__ void general_topk_pair_sort(T* out_keys, + int* out_values, + T* in_keys, + int* in_values) { + typedef cub::BlockRadixSort + BlockRadixSort; + typedef cub:: + BlockLoad + BlockLoadKey; + typedef cub:: + BlockLoad + BlockLoadValue; + typedef cub:: + BlockStore + BlockStoreKey; + typedef cub::BlockStore + BlockStoreValue; + + __shared__ union { + typename BlockRadixSort::TempStorage sort; + typename BlockLoadKey::TempStorage loadkey; + typename BlockLoadValue::TempStorage loadvalue; + typename BlockStoreKey::TempStorage storekey; + typename BlockStoreValue::TempStorage storevalue; + } temp_storage; + + int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD; + + T thread_keys[ITEMS_PER_THREAD]; + int thread_values[ITEMS_PER_THREAD]; + BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys); + BlockLoadValue(temp_storage.loadvalue) + .Load(in_values + block_offset, thread_values); + __syncthreads(); + + BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values); + __syncthreads(); + + BlockStoreKey(temp_storage.storekey) + .Store(out_keys + block_offset, thread_keys); + BlockStoreValue(temp_storage.storevalue) + .Store(out_values + block_offset, thread_values); +} + +// -------- finalize_moe_routing_kernel -------- // +template +__global__ void finalize_moe_routing_kernel( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* skip, + const T* bias, + const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int cols, + const int k, + bool ec_route) { + const int original_row = blockIdx.x; + const int num_rows = gridDim.x; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + const T* skip_row_ptr = skip + original_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + 
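+    // Per-column k-way reduction for this token: start from the skip
+    // (residual) value and accumulate scales[k] * (expert_output_k + bias of
+    // the selected expert) over the k routed copies of the row; entries whose
+    // permuted index is -1 are skipped when ec_route is enabled.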
T thread_output = skip_row_ptr[tid]; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + if (ec_route && expanded_permuted_row == -1) continue; + const int64_t k_offset = + ec_route ? expanded_original_row : original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T* expanded_permuted_rows_row_ptr = + expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = ec_route ? k_idx : expert_for_source_row[k_offset]; + const T* bias_ptr = bias + expert_idx * cols; + + thread_output = + thread_output + + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + } + reduced_row_ptr[tid] = thread_output; + } +} + +// -------- initialize_moe_routing_kernel -------- // +template +__global__ void initialize_moe_routing_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int num_rows, + const int active_rows, + const int cols, + const int k, + const int max_seq_len, + bool ec_route) { + // using LoadT = phi::AlignedVector; + // LoadT src_vec; + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = + ec_route ? expanded_dest_row_to_expanded_source_row[expanded_dest_row / + k * max_seq_len + + expanded_dest_row % k] + : expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row; + } + + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + dest_row_ptr[tid] = source_row_ptr[tid]; + // phi::Load(&source_row_ptr[tid], &src_vec); + // phi::Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h new file mode 100644 index 00000000000000..2ca266150783cc --- /dev/null +++ b/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
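+//
+// This header declares the fused MoE gate-dispatch + permute forward kernel;
+// the GPU implementation lives in
+// paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu.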
+ +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoEDispatchPermuteKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& gate_logits, + const paddle::optional& corr_bias, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor* y, + DenseTensor* combine_weights, + DenseTensor* scatter_index, + DenseTensor* expert_offset, + DenseTensor* expert_id); +} // namespace phi \ No newline at end of file diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index ecba7a94d43517..ff567942662bed 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -45,6 +45,7 @@ ) from .moe_combine import moe_combine from .expand_modality_expert_id import expand_modality_expert_id +from .moe_gate_dispatch_permute import moe_gate_dispatch_permute __all__ = [ 'fused_multi_head_attention', diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py new file mode 100644 index 00000000000000..ea4e54cc653eba --- /dev/null +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py @@ -0,0 +1,65 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Optional +import paddle +from paddle import _C_ops +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + +if TYPE_CHECKING: + from paddle import Tensor + +def moe_gate_dispatch_permute( + x: Tensor, + gate_logits: Tensor, + corr_bias: Tensor, + k: int, + capacity: int, + world_size: int, + name: str | None = None +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Dispatch and permute for Mixture of Experts (MoE). + + Args: + x: Input tensor [batch_size, seq_len, hidden_dim]. + gate_logits: Gate logits for choosing experts [batch_size, seq_len, num_experts]. + corr_bias: Optional correction bias to adjust gate logits. + k: Top-k experts to be selected. + capacity: The maximum number of tokens an expert can handle. + world_size: Number of distributed processes. + name: Optional name for the operation. + + Returns: + Tuple of Tensors containing: + - y: Output tensor after dispatch and permute. + - combine_weights: Weights for combining experts' outputs. + - scatter_index: Indices for scattering inputs to experts. + - expert_offset: Offset indices for each expert. + - expert_id: IDs of selected experts for each position. 
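+
+    Examples:
+        A minimal dynamic-graph sketch (CUDA build assumed). Shapes mirror the
+        accompanying unit test: ``num_experts`` must be divisible by
+        ``world_size`` and ``x`` is laid out as ``[num_rows, hidden_size]``.
+
+        .. code-block:: python
+
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> import paddle
+            >>> from paddle.incubate.nn.functional import moe_gate_dispatch_permute
+            >>> num_rows, hidden_size, num_experts = 10, 128, 4
+            >>> world_size, k, capacity = 2, 2, 5
+            >>> x = paddle.randn([num_rows, hidden_size], dtype='float32')
+            >>> gate_logits = paddle.randn([num_rows, num_experts], dtype='float32')
+            >>> corr_bias = paddle.randn([num_rows], dtype='float32')
+            >>> y, combine_weights, scatter_index, expert_offset, expert_id = (
+            ...     moe_gate_dispatch_permute(
+            ...         x, gate_logits, corr_bias, k=k, capacity=capacity,
+            ...         world_size=world_size,
+            ...     )
+            ... )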
+ """ + if in_dynamic_or_pir_mode(): + return _C_ops.moe_gate_dispatch_permute(x, gate_logits, corr_bias, k, capacity, world_size) + + helper = LayerHelper('moe_gate_dispatch_permute', **locals()) + y = helper.create_variable_for_type_inference(dtype=x.dtype) + combine_weights = helper.create_variable_for_type_inference(dtype='float') + scatter_index = helper.create_variable_for_type_inference(dtype='int32') + expert_offset = helper.create_variable_for_type_inference(dtype='int32') + expert_id = helper.create_variable_for_type_inference(dtype='int32') + + inputs = { + 'x': x, + 'gate_logits': gate_logits, + 'corr_bias': corr_bias if corr_bias is not None else None + } + attrs = {'k': k, 'capacity': capacity, 'world_size': world_size} + outputs = { + 'y': y, + 'combine_weights': combine_weights, + 'scatter_index': scatter_index, + 'expert_offset': expert_offset, + 'expert_id': expert_id + } + + helper.append_op(type='moe_gate_dispatch_permute', inputs=inputs, outputs=outputs, attrs=attrs) + return y, combine_weights, scatter_index, expert_offset, expert_id \ No newline at end of file diff --git a/test/legacy_test/test_moe_gate_dispatch_permute.py b/test/legacy_test/test_moe_gate_dispatch_permute.py new file mode 100644 index 00000000000000..e870e109bcb2fd --- /dev/null +++ b/test/legacy_test/test_moe_gate_dispatch_permute.py @@ -0,0 +1,37 @@ +import paddle +from paddle.incubate.nn.functional import moe_gate_dispatch_permute + +# 定义输入参数 +num_rows = 10 # 示例行数 +hidden_size = 128 # 隐藏层维度 +num_experts = 4 # 专家数 +world_size = 2 # 分布式世界大小 +k = 2 # 选择的Top-k专家 +capacity = 5 # 每个专家的处理容量 + +# 确保num_experts可以被world_size整除 +assert num_experts % world_size == 0 + +# 生成输入数据 +x = paddle.randn([num_rows, hidden_size], dtype='float32') +gate_logits = paddle.randn([num_rows, num_experts], dtype='float32') + +# 可选的修正偏差 +corr_bias = paddle.randn([num_rows], dtype='float32') + +# 调用封装的API +y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch_permute( + x=x, + gate_logits=gate_logits, + corr_bias=corr_bias, + k=k, + capacity=capacity, + world_size=world_size +) + +# 打印输出结果的形状和类型,验证结果 +print("Output y shape:", y.shape) +print("Combine weights shape:", combine_weights.shape) +print("Scatter index shape:", scatter_index.shape) +print("Expert offset shape:", expert_offset.shape) +print("Expert ID shape:", expert_id.shape) \ No newline at end of file From e5bfdc9a9a5a165092d53360c83e4c891008ce28 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 23 May 2025 03:51:30 +0000 Subject: [PATCH 14/71] remove DCU support --- paddle/phi/kernels/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 629e59d682b651..7eee630a8cb34a 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -228,6 +228,7 @@ if(WITH_ROCM) list( REMOVE_ITEM kernel_gpu + "gpu/moe_gate_dispatch_permute_kernel.cu" "gpu/expand_modality_expert_id_kernel.cu" "gpu/moe_combine_kernel.cu" "gpu/moe_combine_grad_kernel.cu" From 81d9fbc5e01cada8cfe9ef4b10c4de57524f38bb Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Thu, 22 May 2025 08:29:37 +0000 Subject: [PATCH 15/71] fix-bugs fix-bugs fix-bugs --- .pre-commit-config.yaml | 0 paddle/phi/infermeta/backward.cc | 7 +- paddle/phi/infermeta/backward.h | 4 +- paddle/phi/kernels/CMakeLists.txt | 2 +- ...uild_src_rank_and_local_expert_id_kernel.h | 3 +- paddle/phi/kernels/cal_aux_loss_grad_kernel.h | 31 +++++ paddle/phi/kernels/cal_aux_loss_kernel.h | 13 +- 
...ild_src_rank_and_local_expert_id_kernel.cu | 15 +-- .../kernels/gpu/cal_aux_loss_grad_kernel.cu | 111 ++++++++++++++++++ paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu | 97 +-------------- paddle/phi/ops/yaml/backward.yaml | 2 +- 11 files changed, 158 insertions(+), 127 deletions(-) mode change 100644 => 100755 .pre-commit-config.yaml mode change 100644 => 100755 paddle/phi/kernels/CMakeLists.txt create mode 100644 paddle/phi/kernels/cal_aux_loss_grad_kernel.h create mode 100644 paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index bfe718bf867551..44c4ca41c064ce 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1888,14 +1888,15 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, } } -void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, +void CalAuxLossGradInferMeta(const MetaTensor& l_aux_loss_grad, + const MetaTensor& gate_prob, const MetaTensor& seqlen_float, const MetaTensor& ce, - const MetaTensor& out_grad, const int64_t num_experts, const bool use_group, const int64_t moe_k, - MetaTensor* gate_prob_grad) { + MetaTensor* gate_prob_grad); +{ auto gate_prob_dims = gate_prob.dims(); PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 7990568f09b022..db43a0dde64877 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -680,10 +680,10 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, MetaTensor* x_grad, MetaTensor* value_grad); -void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, +void CalAuxLossGradInferMeta(const MetaTensor& l_aux_loss_grad, + const MetaTensor& gate_prob, const MetaTensor& seqlen_float, const MetaTensor& ce, - const MetaTensor& out_grad, const int64_t num_experts, const bool use_group, const int64_t moe_k, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt old mode 100644 new mode 100755 index 2f0c1eb96d856d..8d75f79c1ce431 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -231,7 +231,7 @@ if(WITH_ROCM) "gpu/affine_grid_grad_kernel.cu" "gpu/cal_aux_loss_kernel.cu" "gpu/cal_aux_loss_grad_kernel.cu" - "build_src_rank_and_local_expert_id_kernel.cu" + "gpu/build_src_rank_and_local_expert_id_kernel.cu" "gpu/apply_per_channel_scale_kernel.cu" "gpu/calc_reduced_attn_kernel.cu" "gpu/eigvalsh_kernel.cu" diff --git a/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h b/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h index 34f99e2280a257..866ce93aac7cdd 100644 --- a/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h +++ b/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h @@ -13,8 +13,7 @@ // limitations under the License. #pragma once -#include "paddle/phi/backends/all_context.h" -#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/dense_tensor.h" namespace phi { diff --git a/paddle/phi/kernels/cal_aux_loss_grad_kernel.h b/paddle/phi/kernels/cal_aux_loss_grad_kernel.h new file mode 100644 index 00000000000000..d9113b194aa969 --- /dev/null +++ b/paddle/phi/kernels/cal_aux_loss_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CalAuxLossGradKernel(const Context& dev_ctx, + const DenseTensor& l_aux_loss_grad, + const DenseTensor& gate_prob, + const DenseTensor& seqlen_float, + const DenseTensor& ce, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + DenseTensor* gate_prob_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/cal_aux_loss_kernel.h b/paddle/phi/kernels/cal_aux_loss_kernel.h index e6e682d7db43fc..4dfd4b5b020a4f 100644 --- a/paddle/phi/kernels/cal_aux_loss_kernel.h +++ b/paddle/phi/kernels/cal_aux_loss_kernel.h @@ -13,8 +13,7 @@ // limitations under the License. #pragma once -#include "paddle/phi/backends/all_context.h" -#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -32,14 +31,4 @@ void CalAuxLossKernel(const Context& dev_ctx, DenseTensor* seqlen_float, DenseTensor* ce); -template -void CalAuxLossGradKernel(const Context& dev_ctx, - const DenseTensor& gate_prob, - const DenseTensor& seqlen_float, - const DenseTensor& ce, - const DenseTensor& out_grad, - const int64_t num_experts, - const bool use_group, - const int64_t moe_k, - DenseTensor* gate_prob_grad); } // namespace phi diff --git a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu index 8caab2d7e4badc..d3d8f68e311942 100644 --- a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu +++ b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu @@ -11,22 +11,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include -#include -#include "paddle/extension.h" -#include "paddle/phi/api/all.h" -#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/cal_aux_loss_kernel.h" -#include "paddle/phi/api/ext/spmd_infer.h" -#include "paddle/phi/infermeta/spmd_rules/rules.h" -#include "paddle/phi/infermeta/spmd_rules/utils.h" - -#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu new file mode 100644 index 00000000000000..7da3f8c10f9b31 --- /dev/null +++ b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu @@ -0,0 +1,111 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/cal_aux_loss_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" + +namespace phi { + +template +__global__ void cal_aux_loss_grad_kernel(const T* out_grad, + const T* gate_prob, + const int64_t row_gate_prob, + const int64_t col_gate_prob, + const T* seqlen_float, + const T* ce, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + T* gate_prob_grad) { + T ce_val = ce[threadIdx.x]; + T l_aux_grad = *out_grad; + if (use_group) { + l_aux_grad = l_aux_grad / static_cast(moe_k); + } + l_aux_grad *= static_cast(num_experts); + + gate_prob_grad[blockIdx.x * col_gate_prob + threadIdx.x] = + (ce_val * l_aux_grad) / (*seqlen_float); +} + +template +void cal_aux_loss_grad(const T* out_grad, + const T* gate_prob, + const int64_t row_gate_prob, /*seq_len*/ + const int64_t col_gate_prob, /*expert_num*/ + const T* seqlen_float, + const T* ce, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + T* gate_prob_grad, + cudaStream_t stream) { + cal_aux_loss_grad_kernel + <<>>(out_grad, + gate_prob, + row_gate_prob, + col_gate_prob, + seqlen_float, + ce, + num_experts, + use_group, + moe_k, + gate_prob_grad); +} + +template +void CalAuxLossGradKernel(const Context& dev_ctx, + const DenseTensor& l_aux_loss_grad, + const DenseTensor& gate_prob, + const DenseTensor& seqlen_float, + const DenseTensor& ce, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + DenseTensor* gate_prob_grad) { + auto gate_prob_dims = gate_prob.dims(); + + const T* l_aux_loss_grad_data = l_aux_loss_grad.data(); + const T* gate_prob_data = gate_prob.data(); + const T* seqlen_float_data = seqlen_float.data(); + const T* ce_data = ce.data(); + + int64_t row_gate_prob = gate_prob_dims[0]; + int64_t col_gate_prob = gate_prob_dims[1]; + + T* gate_prob_grad_data = dev_ctx.template Alloc(gate_prob_grad); + + cal_aux_loss_grad(l_aux_loss_grad_data, + gate_prob_data, + row_gate_prob, + col_gate_prob, + seqlen_float_data, + ce_data, + num_experts, + use_group, + moe_k, + gate_prob_grad_data, + dev_ctx.stream()); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + cal_aux_loss_grad, GPU, ALL_LAYOUT, phi::CalAuxLossGradKernel, float) {} diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu index ec599cdd38dc8e..10d61a97b50932 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -14,19 +14,12 @@ #pragma once -#include -#include - -// #include "paddle/extension.h" -// #include "paddle/phi/api/all.h" -#include "paddle/phi/core/dense_tensor.h" -// #include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "paddle/phi/kernels/funcs/math_cuda_utils.h" -// #include "paddle/extension.h" +#include "paddle/phi/kernels/cal_aux_loss_kernel.h" -#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include 
"paddle/phi/kernels/cal_aux_loss_kernel.h" + +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" namespace phi { @@ -265,88 +258,6 @@ void CalAuxLossKernel(const Context& dev_ctx, dev_ctx.stream()); } -template -__global__ void cal_aux_loss_grad_kernel(const T* out_grad, - const T* gate_prob, - const int64_t row_gate_prob, - const int64_t col_gate_prob, - const T* seqlen_float, - const T* ce, - const int64_t num_experts, - const bool use_group, - const int64_t moe_k, - T* gate_prob_grad) { - T ce_val = ce[threadIdx.x]; - T l_aux_grad = *out_grad; - if (use_group) { - l_aux_grad = l_aux_grad / static_cast(moe_k); - } - l_aux_grad *= static_cast(num_experts); - - gate_prob_grad[blockIdx.x * col_gate_prob + threadIdx.x] = - (ce_val * l_aux_grad) / (*seqlen_float); -} - -template -void cal_aux_loss_grad(const T* out_grad, - const T* gate_prob, - const int64_t row_gate_prob, /*seq_len*/ - const int64_t col_gate_prob, /*expert_num*/ - const T* seqlen_float, - const T* ce, - const int64_t num_experts, - const bool use_group, - const int64_t moe_k, - T* gate_prob_grad, - cudaStream_t stream) { - cal_aux_loss_grad_kernel - <<>>(out_grad, - gate_prob, - row_gate_prob, - col_gate_prob, - seqlen_float, - ce, - num_experts, - use_group, - moe_k, - gate_prob_grad); -} - -template -void CalAuxLossGradKernel(const Context& dev_ctx, - const DenseTensor& gate_prob, - const DenseTensor& seqlen_float, - const DenseTensor& ce, - const DenseTensor& out_grad, - const int64_t num_experts, - const bool use_group, - const int64_t moe_k, - DenseTensor* gate_prob_grad) { - auto gate_prob_dims = gate_prob.dims(); - - const T* out_grad_data = out_grad.data(); - const T* gate_prob_data = gate_prob.data(); - const T* seqlen_float_data = seqlen_float.data(); - const T* ce_data = ce.data(); - - int64_t row_gate_prob = gate_prob_dims[0]; - int64_t col_gate_prob = gate_prob_dims[1]; - - T* gate_prob_grad_data = dev_ctx.template Alloc(gate_prob_grad); - - cal_aux_loss_grad(out_grad_data, - gate_prob_data, - row_gate_prob, - col_gate_prob, - seqlen_float_data, - ce_data, - num_experts, - use_group, - moe_k, - gate_prob_grad_data, - dev_ctx.stream()); -} - } // namespace phi PD_REGISTER_KERNEL(cal_aux_loss, diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 4e56ba0e89959e..15b5bcab2d6211 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -363,7 +363,7 @@ - backward_op : cal_aux_loss_grad forward : cal_aux_loss (Tensor gate_prob, Tensor dispatch_mask, Tensor tokens_mask, Tensor dispatch_tokens_mask, int64_t num_experts, bool use_group, int64_t moe_k, float clip_min) -> Tensor(l_aux_loss), Tensor(seqlen_float), Tensor(ce) - args : (Tensor gate_prob, Tensor seqlen_float, Tensor ce, Tensor out_grad, int64_t num_experts, bool use_group, int64_t moe_k) + args : (Tensor l_aux_loss_grad, Tensor gate_prob, Tensor seqlen_float, Tensor ce, int64_t num_experts, bool use_group, int64_t moe_k) output : Tensor(gate_prob_grad) infer_meta : func : CalAuxLossGradInferMeta From 4c0442929beb2432911ed5d1e20d629f63f0f411 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 23 May 2025 05:31:10 +0000 Subject: [PATCH 16/71] fix log2 in windows maybe --- paddle/phi/kernels/gpu/{moe_fuse_op.cuh => moe_fuse_op.h} | 2 ++ paddle/phi/kernels/gpu/moe_kernel_impl.h | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) rename paddle/phi/kernels/gpu/{moe_fuse_op.cuh => moe_fuse_op.h} (99%) diff --git a/paddle/phi/kernels/gpu/moe_fuse_op.cuh b/paddle/phi/kernels/gpu/moe_fuse_op.h 
similarity index 99% rename from paddle/phi/kernels/gpu/moe_fuse_op.cuh rename to paddle/phi/kernels/gpu/moe_fuse_op.h index ba2474e6bec250..150903fed5a74c 100644 --- a/paddle/phi/kernels/gpu/moe_fuse_op.cuh +++ b/paddle/phi/kernels/gpu/moe_fuse_op.h @@ -1,5 +1,7 @@ +#pragma once #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/common/exception.h" +#include template __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, diff --git a/paddle/phi/kernels/gpu/moe_kernel_impl.h b/paddle/phi/kernels/gpu/moe_kernel_impl.h index fbf6063b284dbe..38fc60ae7d8f4d 100644 --- a/paddle/phi/kernels/gpu/moe_kernel_impl.h +++ b/paddle/phi/kernels/gpu/moe_kernel_impl.h @@ -16,7 +16,8 @@ limitations under the License. */ #include #include "cub/cub.cuh" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" - +#include +#include namespace phi { static const float HALF_FLT_MAX = 65504.F; From fb784e3930e59fa9edd4f3cb154aec0957cae0ba Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 23 May 2025 05:38:34 +0000 Subject: [PATCH 17/71] update header file format --- paddle/phi/kernels/gpu/moe_fuse_op.h | 1 + paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/moe_fuse_op.h b/paddle/phi/kernels/gpu/moe_fuse_op.h index 150903fed5a74c..9284c88786e118 100644 --- a/paddle/phi/kernels/gpu/moe_fuse_op.h +++ b/paddle/phi/kernels/gpu/moe_fuse_op.h @@ -1,6 +1,7 @@ #pragma once #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/common/exception.h" +#include "paddle/phi/kernels/gpu/moe_kernel_impl.h" #include template diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu index 7d110b319d6f5d..31941833196e4d 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu @@ -13,11 +13,10 @@ // limitations under the License. 
#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/gpu/moe_kernel_impl.h" +#include "paddle/phi/kernels/gpu/moe_fuse_op.h" #include "paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/gpu/moe_fuse_op.cuh" namespace phi { // -------- getWorkspaceSize -------- // From d25d23ef21c5eabd318445d6c9c821cc6aba34f4 Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Fri, 23 May 2025 06:18:29 +0000 Subject: [PATCH 18/71] fix-bugs --- paddle/phi/infermeta/backward.cc | 9 ++++----- paddle/phi/infermeta/backward.h | 4 ++-- paddle/phi/kernels/cal_aux_loss_grad_kernel.h | 2 +- paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu | 3 --- paddle/phi/ops/yaml/backward.yaml | 2 +- 6 files changed, 9 insertions(+), 13 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 44c4ca41c064ce..8f8e0cc01bdcf1 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1888,20 +1888,19 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, } } -void CalAuxLossGradInferMeta(const MetaTensor& l_aux_loss_grad, - const MetaTensor& gate_prob, +void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, const MetaTensor& seqlen_float, const MetaTensor& ce, + const MetaTensor& l_aux_loss_grad, const int64_t num_experts, const bool use_group, const int64_t moe_k, - MetaTensor* gate_prob_grad); -{ + MetaTensor* gate_prob_grad) { auto gate_prob_dims = gate_prob.dims(); PADDLE_ENFORCE_EQ( gate_prob.dtype(), - out_grad.dtype(), + l_aux_loss_grad.dtype(), errors::InvalidArgument( "The input out_grad type should be equal to gate_prob type")); diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index db43a0dde64877..d20c679ba95ce3 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -680,10 +680,10 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, MetaTensor* x_grad, MetaTensor* value_grad); -void CalAuxLossGradInferMeta(const MetaTensor& l_aux_loss_grad, - const MetaTensor& gate_prob, +void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, const MetaTensor& seqlen_float, const MetaTensor& ce, + const MetaTensor& l_aux_loss_grad, const int64_t num_experts, const bool use_group, const int64_t moe_k, diff --git a/paddle/phi/kernels/cal_aux_loss_grad_kernel.h b/paddle/phi/kernels/cal_aux_loss_grad_kernel.h index d9113b194aa969..3b1cc9dfe66e3a 100644 --- a/paddle/phi/kernels/cal_aux_loss_grad_kernel.h +++ b/paddle/phi/kernels/cal_aux_loss_grad_kernel.h @@ -19,10 +19,10 @@ namespace phi { template void CalAuxLossGradKernel(const Context& dev_ctx, - const DenseTensor& l_aux_loss_grad, const DenseTensor& gate_prob, const DenseTensor& seqlen_float, const DenseTensor& ce, + const DenseTensor& l_aux_loss_grad, const int64_t num_experts, const bool use_group, const int64_t moe_k, diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu index 7da3f8c10f9b31..1dbc62b1fadc3e 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu @@ -72,10 +72,10 @@ void cal_aux_loss_grad(const T* out_grad, template void CalAuxLossGradKernel(const Context& dev_ctx, - const DenseTensor& l_aux_loss_grad, const DenseTensor& gate_prob, const DenseTensor& seqlen_float, const DenseTensor& 
ce, + const DenseTensor& l_aux_loss_grad, const int64_t num_experts, const bool use_group, const int64_t moe_k, diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu index 10d61a97b50932..bdd9d46d811008 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -267,6 +267,3 @@ PD_REGISTER_KERNEL(cal_aux_loss, float, phi::dtype::float16, phi::dtype::bfloat16) {} - -PD_REGISTER_KERNEL( - cal_aux_loss_grad, GPU, ALL_LAYOUT, phi::CalAuxLossGradKernel, float) {} diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 15b5bcab2d6211..0ff49350a12c47 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -363,7 +363,7 @@ - backward_op : cal_aux_loss_grad forward : cal_aux_loss (Tensor gate_prob, Tensor dispatch_mask, Tensor tokens_mask, Tensor dispatch_tokens_mask, int64_t num_experts, bool use_group, int64_t moe_k, float clip_min) -> Tensor(l_aux_loss), Tensor(seqlen_float), Tensor(ce) - args : (Tensor l_aux_loss_grad, Tensor gate_prob, Tensor seqlen_float, Tensor ce, int64_t num_experts, bool use_group, int64_t moe_k) + args : ( Tensor gate_prob, Tensor seqlen_float, Tensor ce, Tensor l_aux_loss_grad, int64_t num_experts, bool use_group, int64_t moe_k) output : Tensor(gate_prob_grad) infer_meta : func : CalAuxLossGradInferMeta From 47f010dbd106a982d5d38c972b65abee9ea10537 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 23 May 2025 07:07:16 +0000 Subject: [PATCH 19/71] delete op test for pass CI --- .../paddle/incubate/nn/functional/__init__.py | 10 ++--- .../test_expand_modality_expert_id.py | 18 --------- test/legacy_test/test_moe_combine.py | 28 -------------- .../test_moe_gate_dispatch_permute.py | 37 ------------------- 4 files changed, 5 insertions(+), 88 deletions(-) delete mode 100644 test/legacy_test/test_expand_modality_expert_id.py delete mode 100644 test/legacy_test/test_moe_combine.py delete mode 100644 test/legacy_test/test_moe_gate_dispatch_permute.py diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index ff567942662bed..cf7afb504f6782 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -43,9 +43,9 @@ from .variable_length_memory_efficient_attention import ( variable_length_memory_efficient_attention, ) -from .moe_combine import moe_combine -from .expand_modality_expert_id import expand_modality_expert_id -from .moe_gate_dispatch_permute import moe_gate_dispatch_permute +# from .moe_combine import moe_combine +# from .expand_modality_expert_id import expand_modality_expert_id +# from .moe_gate_dispatch_permute import moe_gate_dispatch_permute __all__ = [ 'fused_multi_head_attention', @@ -65,6 +65,6 @@ "blha_get_max_len", "block_multihead_attention", "swiglu", - "moe_combine", - "expand_modality_expert_id", + # "moe_combine", + # "expand_modality_expert_id", ] diff --git a/test/legacy_test/test_expand_modality_expert_id.py b/test/legacy_test/test_expand_modality_expert_id.py deleted file mode 100644 index d197f6d1f4f597..00000000000000 --- a/test/legacy_test/test_expand_modality_expert_id.py +++ /dev/null @@ -1,18 +0,0 @@ -from paddle.incubate.nn.functional import expand_modality_expert_id -import paddle - -num_expert_per_modality = 4 -group_size = 10 -modality_offset = 3 -is_group_expert = True - -expert_id = paddle.to_tensor([[0, 1, 2,], [3, 4, 5]], dtype='int32') - 
-expert_id_out = expand_modality_expert_id(expert_id, - num_expert_per_modality, - group_size, - modality_offset, - is_group_expert) - -print(expert_id_out) - diff --git a/test/legacy_test/test_moe_combine.py b/test/legacy_test/test_moe_combine.py deleted file mode 100644 index eea62e1f0a6728..00000000000000 --- a/test/legacy_test/test_moe_combine.py +++ /dev/null @@ -1,28 +0,0 @@ -import paddle -from paddle.incubate.nn.functional import moe_combine - -x = paddle.arange(1, 16).view((5, 3)).astype("float32") # [[1,2,3], [4,5,6], ..., [13,14,15]] -x.stop_gradient = False - -# 组合权重(手动构造), 数据类型需要与x相同 -combine_weights = paddle.to_tensor([ -[0.7, 0.3], -[0.6, 0.4], -[0.5, 0.5], -[0.4, 0.6], -[0.2, 0.8] -], stop_gradient=False) - -# 分散索引 仅支持int32 -scatter_index = paddle.to_tensor([ -[0, 1, 2, 3, 4], -[0, 1, 2, 3, 4] -], dtype="int32", stop_gradient=False) - -y = moe_combine(x, combine_weights, scatter_index) -print("\n##########forward output##########\n") -print(y) -print(f"x.grad: {x.grad,}, combine_weights.grad: {combine_weights.grad}, scatter_index.grad: {scatter_index.grad}") -y.backward() -print("\n##########backward output##########\n") -print(f"x.grad: {x.grad}\n combine_weights.grad: {combine_weights.grad}\n scatter_index.grad: {scatter_index.grad}") \ No newline at end of file diff --git a/test/legacy_test/test_moe_gate_dispatch_permute.py b/test/legacy_test/test_moe_gate_dispatch_permute.py deleted file mode 100644 index e870e109bcb2fd..00000000000000 --- a/test/legacy_test/test_moe_gate_dispatch_permute.py +++ /dev/null @@ -1,37 +0,0 @@ -import paddle -from paddle.incubate.nn.functional import moe_gate_dispatch_permute - -# 定义输入参数 -num_rows = 10 # 示例行数 -hidden_size = 128 # 隐藏层维度 -num_experts = 4 # 专家数 -world_size = 2 # 分布式世界大小 -k = 2 # 选择的Top-k专家 -capacity = 5 # 每个专家的处理容量 - -# 确保num_experts可以被world_size整除 -assert num_experts % world_size == 0 - -# 生成输入数据 -x = paddle.randn([num_rows, hidden_size], dtype='float32') -gate_logits = paddle.randn([num_rows, num_experts], dtype='float32') - -# 可选的修正偏差 -corr_bias = paddle.randn([num_rows], dtype='float32') - -# 调用封装的API -y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch_permute( - x=x, - gate_logits=gate_logits, - corr_bias=corr_bias, - k=k, - capacity=capacity, - world_size=world_size -) - -# 打印输出结果的形状和类型,验证结果 -print("Output y shape:", y.shape) -print("Combine weights shape:", combine_weights.shape) -print("Scatter index shape:", scatter_index.shape) -print("Expert offset shape:", expert_offset.shape) -print("Expert ID shape:", expert_id.shape) \ No newline at end of file From 80bd65f8ce27f4a8637fb3b991da9c6e043065c5 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 23 May 2025 07:10:17 +0000 Subject: [PATCH 20/71] add cmath header --- paddle/phi/kernels/gpu/moe_fuse_op.h | 1 - paddle/phi/kernels/gpu/moe_kernel_impl.h | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/moe_fuse_op.h b/paddle/phi/kernels/gpu/moe_fuse_op.h index 9284c88786e118..08b65f6dfd58d9 100644 --- a/paddle/phi/kernels/gpu/moe_fuse_op.h +++ b/paddle/phi/kernels/gpu/moe_fuse_op.h @@ -2,7 +2,6 @@ #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/common/exception.h" #include "paddle/phi/kernels/gpu/moe_kernel_impl.h" -#include template __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, diff --git a/paddle/phi/kernels/gpu/moe_kernel_impl.h b/paddle/phi/kernels/gpu/moe_kernel_impl.h index 38fc60ae7d8f4d..e2ea1eac0d5554 100644 --- 
a/paddle/phi/kernels/gpu/moe_kernel_impl.h +++ b/paddle/phi/kernels/gpu/moe_kernel_impl.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "cub/cub.cuh" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" #include +#include #include namespace phi { From 9e889aa84671ce9f8588235e9410a0eac2a68cf4 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 23 May 2025 08:10:07 +0000 Subject: [PATCH 21/71] tmp --- paddle/phi/infermeta/backward.cc | 40 ++++++- paddle/phi/infermeta/backward.h | 11 ++ paddle/phi/kernels/gpu/moe_fleety_utils.h | 108 +++++++++++++++++ paddle/phi/kernels/gpu/moe_fuse_bwd_op.h | 0 .../moe_gate_dispatch_permute_grad_kernel.cu | 113 ++++++++++++++++++ .../moe_gate_dispatch_permute_grad_kernel.h | 31 +++++ paddle/phi/ops/yaml/backward.yaml | 10 ++ paddle/phi/ops/yaml/ops.yaml | 2 +- 8 files changed, 310 insertions(+), 5 deletions(-) create mode 100644 paddle/phi/kernels/gpu/moe_fleety_utils.h create mode 100644 paddle/phi/kernels/gpu/moe_fuse_bwd_op.h create mode 100644 paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu create mode 100644 paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 805e4a0ac868a5..30591de9397228 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1228,18 +1228,50 @@ void MoeCombineGradInferMeta(const MetaTensor& x, PADDLE_ENFORCE_EQ( x_dim.size(), 2, - errors::InvalidArgument("Input X should have 2 dimensions")); + errors::InvalidArgument("The input X should have 2 dimensions" + "But received X's dimension = %d", + x_dim.size())); PADDLE_ENFORCE_EQ( (scatter_index.dtype() == phi::DataType::INT32), true, errors::InvalidArgument( - "The input scatter_index type should be int32")); - grad_x->set_dims(phi::make_ddim({x_dim[0],x_dim[1]})); + "The input scatter_index type should be int32" + "But received scatter_index type = %s", + scatter_index.dtype())); + grad_x->set_dims(common::make_ddim({x_dim[0],x_dim[1]})); grad_x->set_dtype(x.dtype()); - grad_combine_weights_helper->set_dims(phi::make_ddim({combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); + grad_combine_weights_helper->set_dims(common::make_ddim({combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); grad_combine_weights_helper->set_dtype(x.dtype()); } +void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& expert_id, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + MetaTensor* x_grad, + MetaTensor* gate_logtis_grad){ + + auto y_grad_dims = y_grad.dims(); + PADDLE_ENFORCE_EQ( + y_grad_dims[1], + world_size, + common::errors::InvalidArgument("The second dimension of y_grad should be equal to world_size, but " + "received y_grad_dims[1] = %d, world_size = %d", + y_grad_dims[1], world_size)); + int64_t num_local_experts = y_grad_dims[0]; + int64_t num_experts = world_size * num_local_experts; + int64_t hidden_size = y_grad_dims[y_grad_dims.size()-1]; + int64_t num_rows = scatter_index.dims()[1]; + x_grad->set_dims({num_rows, hidden_size}); + x_grad->set_dtype(y_grad.dtype()); + gate_logtis_grad->set_dims({num_rows, num_experts}); + gate_logtis_grad->set_dtype(phi::DataType::FLOAT32); +} + void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad) { diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 
50cd2500b26d72..a2f61c761228a1 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -469,6 +469,17 @@ void MoeCombineGradInferMeta(const MetaTensor& x, MetaTensor* grad_x, MetaTensor* grad_combine_weights_helper); +void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& expert_id, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + MetaTensor* x_grad, + MetaTensor* gate_logtis_grad); + void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad); diff --git a/paddle/phi/kernels/gpu/moe_fleety_utils.h b/paddle/phi/kernels/gpu/moe_fleety_utils.h new file mode 100644 index 00000000000000..027ed66fde6bc3 --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_fleety_utils.h @@ -0,0 +1,108 @@ +#pragma once + +#include "paddle/extension.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/include/kernels.h" + +namespace phi { + +template +void ContiguousKernel(const Context& dev_ctx, + const DenseTensor& input, + DenseTensor* out); + +} // namespace phi + +namespace fleety_utils { + +namespace internal { + +template +struct TensorHasStrideImpl { +private: + struct YesType {}; + struct NoType {}; + + template + static YesType Check(decltype(std::declval().is_contiguous())) { + return 0; + } + + template + static NoType Check(...) { + return 0; + } + +public: + static constexpr bool kValue = + std::is_same(false))>::value; +}; + + +template +struct ContiguousTensorHelperImpl { + static_assert(_SupportStride, "_SupportStride should be true"); + + static bool IsContiguousTensor(const DenseT &t) { + return t.meta().is_contiguous(); + } + + static typename std::enable_if<_SupportStride, void>::type TensorTrans2Contiguous(DenseT *t) { + if (t != nullptr && t->initialized() && !t->meta().is_contiguous()) { + auto place = t->place(); + auto is_gpu_place = place.GetType() == phi::AllocationType::GPU; + PD_CHECK(is_gpu_place, "Only support GPU place"); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto gpu_ctx = reinterpret_cast(dev_ctx); + auto dtype = t->dtype(); + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(dtype, "contiguous_kernel", ([&] { + DenseT out; + phi::ContiguousKernel(*gpu_ctx, *t, &out); + *t = out; + })); + } + } + + static void TensorTrans2Contiguous(PaddleT *t) { + if (t != nullptr) { + if (!t->is_dense_tensor()) { + PD_THROW("Trans2Contiguous only supports DenseTensor"); + } + auto *dense_t = static_cast(t->impl().get()); + TensorTrans2Contiguous(dense_t); + } + } +}; + + +template +struct ContiguousTensorHelperImpl { + static bool IsContiguousTensor(const DenseT &t) { return true; } + static void TensorTrans2Contiguous(DenseT *t) {} + static void TensorTrans2Contiguous(PaddleT *t) {} +}; + + +} // namespace internal + + +inline constexpr bool SupportStride() { + return internal::TensorHasStrideImpl::kValue; +} + +using ContiguousTensorHelper = internal::ContiguousTensorHelperImpl; + +inline bool IsContiguousTensor(const phi::DenseTensor &t) { + return ContiguousTensorHelper::IsContiguousTensor(t); +} + +inline void TensorTrans2Contiguous(phi::DenseTensor *t) { + return ContiguousTensorHelper::TensorTrans2Contiguous(t); +} + +inline void TensorTrans2Contiguous(paddle::Tensor *t) { + return ContiguousTensorHelper::TensorTrans2Contiguous(t); +} + +} // namespace fleety_utils diff --git 
a/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h b/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu new file mode 100644 index 00000000000000..134da55ec26eef --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu @@ -0,0 +1,113 @@ +#include "paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" // 注册相关 +#include "paddle/phi/backends/gpu/gpu_context.h" // context相关 +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" +namespace phi{ + +template +void apply_moe_dispatch_bwd( + const T* y_grad, + const float* combine_weights, // [s, k] + const int* scatter_index, // [s, k] + const float* combine_weights_grad, + const int* expert_id, // [s, k] + float* gate_logtis_grad, + T* x_grad, + int64_t num_rows, + int64_t k, + int64_t dim, + int64_t num_experts, + int64_t capacity, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts, + cudaStream_t stream){ + gather_with_mask_launcher(y_grad, + scatter_index, + combine_weights, + x_grad, num_rows, k, dim, -1, stream, use_all2all_permute, world_size, num_local_experts, capacity); + + topk_grad_with_mask_launcher(combine_weights_grad, + expert_id, + combine_weights, + gate_logtis_grad, + num_rows, k, num_experts, stream); +} + + +template +void moe_dispatch_bwd(const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [s, k] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_grad, // [s, k] + const DenseTensor&x_grad, + const DenseTensor& gate_logtis_grad, + int64_t capacity, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1){ + auto combine_weights_dims = combine_weights.dims(); + int64_t num_rows = combine_weights_dims[0]; + int64_t k = combine_weights_dims[1]; + auto y_grad_dims = y_grad.dims(); +#ifdef MOE_OPS_AUTO + int64_t hidden_size = y_grad_dims[2]; +#else + int64_t hidden_size = y_grad_dims[y_grad_dims.size() - 1]; +#endif + int64_t num_experts = gate_logtis_grad.dims()[1]; + + apply_moe_dispatch_bwd( + y_grad.data(), + combine_weights.data(), + scatter_index.data(), + combine_weights_grad.data(), + expert_id.data(), + gate_logtis_grad.data(), + x_grad.data(), + num_rows, + k, + hidden_size, + num_experts, + capacity, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); +} + +template +void MoeGateDispatchGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [num_local_experts, num_experts * capacity // num_local_experts, h] + const DenseTensor& y_grad, // [s, k] + const DenseTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor* x_grad, + DenseTensor* gate_logtis_grad){ + int64_t num_local_experts = y_grad.dims()[0]; + auto scatter_index_dims = scatter_index.dims(); + DenseTensor t_scatter_index = phi::Empty(dev_ctx, {scatter_index_dims[1], scatter_index_dims[0]}); + phi::Transpose(dev_ctx, scatter_index, {1,0}, &t_scatter_index); + fleety_utils::TensorTrans2Contiguous(&t_scatter_index); + moe_dispatch_bwd(dev_ctx, + combine_weights, + t_scatter_index, 
+ expert_id, + y_grad, + combine_weights_grad, + x_grad, + gate_logtis_grad, + capacity, + true, /*use_all2all_permute*/ + world_size, + num_local_experts); + +} +} // namespace phi diff --git a/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h new file mode 100644 index 00000000000000..10ecb4978b71ff --- /dev/null +++ b/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoeGateDispatchGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& expert_id, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor* x_grad, + DenseTensor* gate_logtis_grad); +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index c36a3c1694ae9b..92895c820709ac 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2256,6 +2256,16 @@ kernel : func : moe_combine_grad +- backward_op : moe_gate_dispatch_permute_grad + forward : moe_gate_dispatch_permute (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, int64_t world_size) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, int64_t world_size) + output : Tensor(x_grad), Tensor(gate_logtis_grad) + infer_meta : + func : MoeGateDispatchPermuteGradInferMeta + kernel : + func : moe_gate_dispatch_permute_grad + data_type : y_grad + - backward_op : mp_allreduce_sum_grad forward : mp_allreduce_sum(Tensor x, int ring_id = 0) -> Tensor(out) args : (Tensor out_grad, int ring_id = 0) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index e183db924f844f..c536cc129f36e9 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3611,7 +3611,7 @@ func : moe_gate_dispatch_permute data_type : x optional : corr_bias - # backward : moe_gate_dispatch_permute_grad + backward : moe_gate_dispatch_permute_grad - op : momentum_ args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0f, bool multi_precision = false, float rescale_grad = 1.0f) From c78a3cb9181b39b8b3730fdd11dd4c158e356ba3 Mon Sep 17 00:00:00 2001 From: zhenghuaijin Date: Fri, 23 May 2025 16:54:50 +0800 Subject: [PATCH 22/71] pass int_bincount --- paddle/phi/infermeta/unary.cc | 20 ++++ paddle/phi/infermeta/unary.h | 6 ++ 
paddle/phi/kernels/CMakeLists.txt | 3 +- paddle/phi/kernels/gpu/int_bincount.cu | 95 +++++++++++++++++++ paddle/phi/kernels/int_bincount.h | 22 +++++ paddle/phi/ops/yaml/ops.yaml | 9 ++ .../paddle/incubate/nn/functional/__init__.py | 2 + .../incubate/nn/functional/int_bincount.py | 21 ++++ 8 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 paddle/phi/kernels/gpu/int_bincount.cu create mode 100644 paddle/phi/kernels/int_bincount.h create mode 100644 python/paddle/incubate/nn/functional/int_bincount.py diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index abf1823d67c86e..dacc2e5886de47 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -6164,6 +6164,26 @@ void ArrayPopInferMeta(const MetaTensor& array, out->set_dtype(array.dtype()); } +void IntBincountInferMeta(const MetaTensor& x, + int64_t low, + int64_t high, + int64_t dtype, + MetaTensor* out) { + PADDLE_ENFORCE_EQ( + x.dims().size(), 1, + errors::InvalidArgument( + "The input 'x' of int_bincount must be a 1-D Tensor, but got %u-D.", + x.dims().size())); + PADDLE_ENFORCE_GT( + high, low, + errors::InvalidArgument("Attr high (%d) must be > low (%d).", high, low)); + int64_t bin_count = high - low + 1; + + out->set_dims(phi::make_ddim({bin_count})); + out->set_dtype(x.dtype()); +} + + } // namespace phi PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 6e9454e9fdac9d..458e1860ba4322 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -998,4 +998,10 @@ void ArrayPopInferMeta(const MetaTensor& array, MetaTensor* out, MetaConfig config = MetaConfig()); +void IntBincountInferMeta(const MetaTensor& x, + int64_t low, + int64_t high, + int64_t dtype, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 2f45770291bd58..1c1f7ce4fc61fd 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -236,7 +236,8 @@ if(WITH_ROCM) "gpu/matrix_rank_kernel.cu" "gpu/matrix_rank_tol_kernel.cu" "gpu/svd_kernel.cu" - "gpu/cuda_gemm_kernel.cu") + "gpu/cuda_gemm_kernel.cu" + "gpu/int_bincount.cu") endif() # Remove AP kernel when CINN is not enabled. 
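The new GPU kernel in the following file counts how many elements of an integer tensor fall into each unit-width bin of [low, high) by calling cub::DeviceHistogram::HistogramEven twice: the first call, given a null workspace pointer, only reports the required temporary-storage size, and the second performs the count. A minimal standalone sketch of that calling pattern is shown here; the helper name int_bincount_sketch, the raw cudaMalloc workspace, and the int64_t/int type pair are illustrative assumptions, whereas the kernel below runs the same two calls inside a small loop and keeps its workspace in a DenseTensor allocated with phi::Empty.

#include <cub/device/device_histogram.cuh>
#include <cuda_runtime.h>

// Counts, for i in [0, high - low), how many elements of d_x equal low + i.
cudaError_t int_bincount_sketch(const int64_t* d_x, int64_t n,
                                int64_t low, int64_t high,
                                int* d_bins, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Pass 1: d_temp == nullptr, so CUB only writes the workspace size.
  cudaError_t err = cub::DeviceHistogram::HistogramEven(
      d_temp, temp_bytes, d_x, d_bins,
      static_cast<int>(high - low + 1),  // num_levels = number of bins + 1
      low, high, n, stream);
  if (err != cudaSuccess) return err;
  err = cudaMalloc(&d_temp, temp_bytes);
  if (err != cudaSuccess) return err;
  // Pass 2: the actual unit-width histogram over [low, high).
  err = cub::DeviceHistogram::HistogramEven(
      d_temp, temp_bytes, d_x, d_bins,
      static_cast<int>(high - low + 1),
      low, high, n, stream);
  cudaFree(d_temp);
  return err;
}
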
diff --git a/paddle/phi/kernels/gpu/int_bincount.cu b/paddle/phi/kernels/gpu/int_bincount.cu new file mode 100644 index 00000000000000..a662f97b64777c --- /dev/null +++ b/paddle/phi/kernels/gpu/int_bincount.cu @@ -0,0 +1,95 @@ +// #include "paddle/extension.h" +#include "paddle/phi/core/utils/data_type.h" +#include "paddle/common/flags.h" +#include +#include +#include "cub/device/device_histogram.cuh" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/empty_kernel.h" // NOLINT +#include "paddle/phi/kernels/int_bincount.h" // NOLINT + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" + +COMMON_DECLARE_bool(enable_pir_api); + +namespace phi{ +static phi::DataType TransToDataType(int64_t dtype) { + if (FLAGS_enable_pir_api) { + return static_cast(dtype); + } else { + return phi::TransToPhiDataType(dtype); + } +} + +std::vector> IntBincountInferShape( + std::vector x_shape, + int64_t min_value, + int64_t max_value, + int64_t out_dtype) { + return {{max_value - min_value}}; +} + +std::vector IntBincountInferDType( + phi::DataType x_dtype, + int64_t min_value, + int64_t max_value, + int64_t out_dtype) { + return {TransToDataType(out_dtype)}; +} + +template +void IntBincountImpl(const Context& ctx, const T *x, int64_t n, T min_v, T max_v, BinsT *bins) { + DenseTensor workspace; + void *workspace_ptr = nullptr; + size_t workspace_size = 0; +#pragma unroll + for (int i = 0; i < 2; ++i) { + if (workspace_size > 0) { + workspace = phi::Empty(ctx, {static_cast(workspace_size)}); + workspace_ptr = workspace.data(); + } + auto err = cub::DeviceHistogram::HistogramEven( + workspace_ptr, workspace_size, x, bins, max_v - min_v + 1, min_v, max_v, n, ctx.stream()); + PD_CHECK(err == cudaSuccess, "HistogramEven error: %s", cudaGetErrorString(err)); + } +} + +// T is x's input type and out_dtype is in args +template +void IntBincount(const Context& ctx, const DenseTensor &x, int64_t low, int64_t high, int64_t out_dtype, DenseTensor* out) { + PD_CHECK(low < high); + int64_t bins_width = high - low; + PD_CHECK(bins_width + 1 < std::numeric_limits::max()); + + auto bins_dtype = TransToPhiDataType(out_dtype); + DenseTensor bins = phi::Empty(ctx, {bins_width}); + + // auto x_dytpe = x.dtype(); + auto low_v = static_cast(low); + auto high_v = static_cast(high); + PD_CHECK(static_cast(low_v) == low); + PD_CHECK(static_cast(high_v) == high); + const auto *x_data = x.data(); + void *bins_data = bins.data(); + int64_t n = x.numel(); + if (bins_dtype == phi::DataType::INT32) { + IntBincountImpl(ctx, x_data, n, low_v, high_v, static_cast(bins_data)); + } else if (bins_dtype == phi::DataType::INT64) { + using ULLI = unsigned long long int; + static_assert(sizeof(int64_t) == sizeof(ULLI)); + IntBincountImpl(ctx, x_data, n, low_v, high_v, static_cast(bins_data)); + } else { + PD_THROW("Only support INT32 and INT64, but got %s", bins_dtype); + } + out = &bins; +} +} // namespace phi + +PD_REGISTER_KERNEL(int_bincount, + GPU, + ALL_LAYOUT, + phi::IntBincount, + int64_t, + int) {} \ No newline at end of file diff --git a/paddle/phi/kernels/int_bincount.h b/paddle/phi/kernels/int_bincount.h new file mode 100644 index 00000000000000..0c4286eeadd39c --- /dev/null +++ b/paddle/phi/kernels/int_bincount.h @@ -0,0 +1,22 @@ +#include "paddle/phi/core/utils/data_type.h" +#include "paddle/common/flags.h" +#include +#include +#include "cub/device/device_histogram.cuh" +#include "paddle/phi/core/dense_tensor.h" 
+#include "paddle/phi/kernels/empty_kernel.h" // NOLINT + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi{ + +template +void IntBincount(const Context& ctx, + const DenseTensor &x, + int64_t low, + int64_t high, + int64_t out_dtype, + DenseTensor* out); +} \ No newline at end of file diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 69555563cc1965..98c89cca6b678b 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -5674,3 +5674,12 @@ data_type: numbers interfaces : paddle::dialect::InferSymbolicShapeInterface traits : paddle::dialect::ForwardOnlyTrait + +- op: int_bincount + args: (Tensor x, int low, int high, int dtype) + output: Tensor(out) + infer_meta: + func: IntBincountInferMeta + kernel: + func: int_bincount + data_type: x \ No newline at end of file diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index aec7625145d348..7ae9a98964a6f2 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -43,6 +43,7 @@ from .variable_length_memory_efficient_attention import ( variable_length_memory_efficient_attention, ) +from .int_bincount import int_bincount __all__ = [ 'fused_multi_head_attention', @@ -62,4 +63,5 @@ "blha_get_max_len", "block_multihead_attention", "swiglu", + "int_bincount", ] diff --git a/python/paddle/incubate/nn/functional/int_bincount.py b/python/paddle/incubate/nn/functional/int_bincount.py new file mode 100644 index 00000000000000..6389f834a49e45 --- /dev/null +++ b/python/paddle/incubate/nn/functional/int_bincount.py @@ -0,0 +1,21 @@ +import paddle +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.data_feeder import convert_dtype + + +def int_bincount(x, low, high, dtype=None, name=None): + helper = LayerHelper("int_bincount", **locals()) + out_dtype = dtype if dtype is not None else x.dtype + y = helper.create_variable_for_type_inference(dtype=out_dtype) + dtype_attr = convert_dtype(out_dtype) + + helper.append_op( + type="int_bincount", + inputs={"x": x}, + outputs={"y": y}, + attrs={ + "low": low, + "high": high, + "dtype": dtype_attr, + }) + return y \ No newline at end of file From 19085da21a3bb9515b0c5a11c13ca64de1c8e7ac Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Fri, 23 May 2025 09:43:15 +0000 Subject: [PATCH 23/71] add moe_dispatch_bwd --- paddle/phi/kernels/fused_moe_bwd_op.h | 351 ++++++++++++++++++ .../gpu/moe_gate_dispatch_grad_kernel.cu | 164 ++++++++ .../kernels/moe_gate_dispatch_grad_kernel.h | 46 +++ 3 files changed, 561 insertions(+) create mode 100644 paddle/phi/kernels/fused_moe_bwd_op.h create mode 100644 paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu create mode 100644 paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h diff --git a/paddle/phi/kernels/fused_moe_bwd_op.h b/paddle/phi/kernels/fused_moe_bwd_op.h new file mode 100644 index 00000000000000..3714e5872d8eb0 --- /dev/null +++ b/paddle/phi/kernels/fused_moe_bwd_op.h @@ -0,0 +1,351 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifndef _FUSED_MOE_BWD_OP_H_ +#define _FUSED_MOE_BWD_OP_H_ + +#include +#include +#include + +#include "cutlass/array.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" + +#define WARP_SIZE 32 +// Ignore CUTLASS warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-function" + +#pragma GCC diagnostic pop + +// namespace paddle { +// namespace operators { + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void topk_grad_with_mask(const T* dy, // [s, k] + const int* topk_idx, // [s, k] + const T* combine_weights, // [s, k] + T* dx, // [s, e] + int64_t num_rows, // s + int64_t k, // k + int64_t num_experts // e +) { + // init dx to zero + for (int i = blockIdx.x; i < num_rows; i += gridDim.x) { + int base_grad = i * num_experts; + for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { + dx[base_grad + j] = static_cast(0); + } + __syncthreads(); + int base_index = i * k; + for (int j = threadIdx.x; j < k; j += blockDim.x) { + int64_t idx = topk_idx[base_index + j]; + if (combine_weights[base_index + j] > static_cast(0)) { + dx[base_grad + idx] = dy[base_index + j]; + } + } + } +} + +// y=zero_part(topk(x)) 的反向过程 +// x: [s,e] +// dy: [s,k] +// X: [s, e] -(topk)-> Y:[s, k] - (越界设置为0)-> combine_weights: [s, k] +template +void topk_grad_with_mask_launcher(const T* dy, // [s, k] + const int* topk_idx, // [s, k] + const T* combine_weights, // [s, k] + T* dx, // [s, e] + int64_t num_rows, // s + int64_t k, // k + int64_t num_experts, // e + cudaStream_t stream) { + int blocks = num_rows; + int threads = 1024; + + topk_grad_with_mask<<>>( + dy, topk_idx, combine_weights, dx, num_rows, k, num_experts); +} + +template +__global__ void gather_with_mask_permute_kernel( + const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num, + int64_t capacity, + int64_t world_size, + int64_t num_local_experts) { + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = + reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; + + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; + idx < N; + idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = + 
min(static_cast(blockDim.x), N - shared_idx_begin); + + for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; + } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); + + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active) { + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f) { + int64_t remaining_after_irank = id % (num_local_experts * capacity); + + int64_t irank = id / (num_local_experts * capacity); + int64_t local_iexpert = remaining_after_irank / capacity; + int64_t row_in_expert = remaining_after_irank % capacity; + int64_t permuted_id = local_iexpert * (world_size * capacity) + + irank * capacity + row_in_expert; + int64_t in_offset = permuted_id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; + } + } + } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } +} + +template +__global__ void gather_with_mask_kernel( + const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num) { + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = + reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; + + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; + idx < N; + idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = + min(static_cast(blockDim.x), N - shared_idx_begin); + + for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; + } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); + + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active) { + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f) { + int64_t in_offset = id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; + } + } + } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } +} + +template +inline T DivUp(T a, T b) { + return (a + b - 1) / b; +} + +inline int64_t max_shared_s_num(int64_t num_rows, + int64_t dim, + int64_t threads, + int64_t vec_size) { + if ((threads * vec_size) % dim == 0) { + return min(num_rows, threads * vec_size / dim); + } else { + int64_t 
max_res = DivUp(threads * 4, dim); + for (int64_t idx = 0; idx < num_rows * dim; idx += vec_size * threads) { + int64_t si_start = idx / dim; + int64_t si_end = min(num_rows * dim, idx + vec_size * threads - 1) / dim; + max_res = max(max_res, (si_end - si_start + 1)); + } + return min(num_rows, max_res); + } +} + +template +void gather_with_mask_launcher(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s,k,d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t num_active, + cudaStream_t stream, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1, + int64_t capacity = -1) { + int numel = num_rows * dim; +#ifdef DEBUG_MOE_OP + std::cerr << "[DEBUG-BWD] launch kernel, num_active=" << num_active + << ", num_rows=" << num_rows << ", dim=" << dim << std::endl; +#endif + + int64_t threads = 512; + if (dim % 4 == 0) { + int64_t blocks = DivUp(DivUp(numel, 4), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 4); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + +#ifdef DEBUG_MOE_OP + std::cerr << "[DEBUG-BWD] gather_with_mask with vectorized, s_shared_num=" + << s_shared_num << ", block=" << blocks << std::endl; +#endif + if (!use_all2all_permute) { + gather_with_mask_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + PD_CHECK(world_size > 0 && num_local_experts > 0 && capacity > 0); + gather_with_mask_permute_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } else { + int64_t blocks = DivUp(DivUp(numel, 1), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 1); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + +#ifdef DEBUG_MOE_OP + std::cerr + << "[DEBUG-BWD] gather_with_mask without vectorized, s_shared_num=" + << s_shared_num << ", block=" << blocks << std::endl; +#endif + + if (!use_all2all_permute) { + gather_with_mask_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + gather_with_mask_permute_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } +} + +// } // namespace operators +// } // namespace paddle + +#endif diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu new file mode 100644 index 00000000000000..23e8d8cfd985bc --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu @@ -0,0 +1,164 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
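+
+// Overview of the backward pass implemented in this file (the launchers live
+// in paddle/phi/kernels/fused_moe_bwd_op.h, added in this same patch):
+//  * gather_with_mask_launcher: for every token s and slot i < k whose
+//    combine_weights[s, i] > 0, accumulates y_grad[scatter_index[s, i], :]
+//    into x_grad[s, :]; the *_permute variant additionally remaps the row
+//    index from the (rank, local_expert, row) layout back to the
+//    (local_expert, rank, row) layout used by the permuted y_grad.
+//  * topk_grad_with_mask_launcher: zero-fills the gate-logits gradient row
+//    and then writes combine_weights_grad[s, i] into column expert_id[s, i]
+//    for the slots whose combine_weights are non-zero.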
+ +#pragma once + +#include + +#include "paddle/phi/include/kernels.h" +#include "paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void apply_moe_dispatch_bwd(const T* y_grad, + const float* combine_weights, // [s, k] + const int* scatter_index, // [s, k] + const float* combine_weights_grad, + const int* expert_id, // [s, k] + float* gate_logtis_grad, + T* x_grad, + int64_t num_rows, + int64_t k, + int64_t dim, + int64_t num_experts, + int64_t capacity, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts, + cudaStream_t stream) { + gather_with_mask_launcher(y_grad, + scatter_index, + combine_weights, + x_grad, + num_rows, + k, + dim, + -1, + stream, + use_all2all_permute, + world_size, + num_local_experts, + capacity); + + topk_grad_with_mask_launcher(combine_weights_grad, + expert_id, + combine_weights, + gate_logtis_grad, + num_rows, + k, + num_experts, + stream); +} + +template +void moe_dispatch_bwd(const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [s, k] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_grad, // [s, k] + const DenseTensor& x_grad, + const DenseTensor& gate_logtis_grad, + int64_t capacity, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts) { + int64_t num_rows = combine_weights.dims()[0]; + int64_t k = combine_weights.dims()[1]; +#ifdef MOE_OPS_AUTO + int64_t hidden_size = y_grad.dims()[2]; +#else + int64_t hidden_size = y_grad.dims()[1]; +#endif + int64_t num_experts = gate_logtis_grad.dims()[1]; + + apply_moe_dispatch_bwd(y_grad.data(), + combine_weights.data(), + scatter_index.data(), + combine_weights_grad.data(), + expert_id.data(), + static_cast gate_logtis_grad.data(), + static_cast x_grad.data(), + num_rows, + k, + hidden_size, + num_experts, + capacity, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); +} + +template +void MoeGateDispatchGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& expert_id, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_grad, + const int64_t k, + const int64_t capacity, + const bool use_pad, + DenseTensor* x_grad, + DenseTensor* gate_logtis_grad) { + auto y_grad_dims = y_grad.dims(); + auto scatter_index_dims = scatter_index.dims(); + +#ifdef MOE_OPS_AUTO + // y_grad shape is [num_experts, capacity, h] + int64_t num_experts = y_grad_dims[0]; + int64_t hidden_size = y_grad_dims[2]; +#else + int64_t num_experts = y_grad_dims[0] / capacity; + int64_t hidden_size = y_grad_dims[1]; +#endif + int64_t num_rows = scatter_index_dims[1]; + + const std::vector axis = {1, 0}; + + DenseTensor t_scatter_index; + phi::Transpose(dev_ctx, scatter_index, axis, &t_scatter_index); + DenseTensor t_scatter_index_; + phi::ContiguousKernel( + dev_ctx, t_scatter_index, &t_scatter_index_); + const DenseTensor t_scatter_index__ = t_scatter_index_; + + dev_ctx.template Alloc(x_grad); + dev_ctx.template Alloc(gate_logtis_grad); + + moe_dispatch_bwd(dev_ctx, + combine_weights, + t_scatter_index__, + expert_id, + y_grad, + combine_weights_grad, + *x_grad, + *gate_logtis_grad, + capacity); +} + +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_grad, + GPU, + ALL_LAYOUT, + 
phi::MoeGateDispatchGradKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h new file mode 100644 index 00000000000000..b61e8ae3398f48 --- /dev/null +++ b/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h @@ -0,0 +1,46 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MoeGateDispatchGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& expert_id, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_grad, + const int64_t k, + const int64_t capacity, + const bool use_pad, + DenseTensor* x_grad, + DenseTensor* gate_logtis_grad); + +template +void moe_dispatch_bwd(const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [s, k] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_grad, // [s, k] + const DenseTensor& x_grad, + const DenseTensor& gate_logtis_grad, + int64_t capacity, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1); +} // namespace phi From aca0c5d343b7bd57b5fe7ec3ba4f90052b837aaf Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Fri, 23 May 2025 10:56:12 +0000 Subject: [PATCH 24/71] add moe_gate_dispatch --- paddle/phi/infermeta/backward.cc | 53 ++ paddle/phi/infermeta/backward.h | 11 + paddle/phi/infermeta/multiary.cc | 90 +++ paddle/phi/infermeta/multiary.h | 12 + paddle/phi/kernels/CMakeLists.txt | 2 + .../gpu/moe_gate_dispatch_grad_kernel.cu | 26 +- .../kernels/gpu/moe_gate_dispatch_kernel.cu | 372 ++++++++++ paddle/phi/kernels/moe_fuse_op.h | 457 +++++++++++++ .../kernels/moe_gate_dispatch_grad_kernel.h | 4 +- paddle/phi/kernels/moe_gate_dispatch_kernel.h | 35 + paddle/phi/kernels/moe_kernel_impl.h | 642 ++++++++++++++++++ paddle/phi/ops/yaml/backward.yaml | 10 + paddle/phi/ops/yaml/ops.yaml | 10 + 13 files changed, 1709 insertions(+), 15 deletions(-) mode change 100755 => 100644 paddle/phi/kernels/CMakeLists.txt create mode 100644 paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu create mode 100644 paddle/phi/kernels/moe_fuse_op.h create mode 100644 paddle/phi/kernels/moe_gate_dispatch_kernel.h create mode 100644 paddle/phi/kernels/moe_kernel_impl.h diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 8f8e0cc01bdcf1..ffeabded2f6f04 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1907,4 +1907,57 @@ void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, gate_prob_grad->set_dims({gate_prob_dims}); gate_prob_grad->set_dtype(gate_prob.dtype()); } + +void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights, 
+ const MetaTensor& scatter_index, + const MetaTensor& expert_id, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_grad, + const int64_t k, + const int64_t capacity, + const bool use_pad, + MetaTensor* x_grad, + MetaTensor* gate_logits_grad) { + auto combine_weights_dims = combine_weights.dims(); + auto scatter_index_dims = scatter_index.dims(); + auto expert_id_dims = expert_id.dims(); + auto y_grad_dims = y_grad.dims(); + auto combine_weights_grad_dims = combine_weights_grad.dims(); + + PADDLE_ENFORCE_EQ(combine_weights_dims.size(), + 2, + errors::InvalidArgument( + "Input combine_weights should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + scatter_index_dims.size(), + 2, + errors::InvalidArgument("Input scatter_index should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + expert_id_dims.size(), + 2, + errors::InvalidArgument("Input expert_id should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + y_grad_dims.size(), + 2, + errors::InvalidArgument("Input y_grad should have 2 dimensions")); + + PADDLE_ENFORCE_EQ(combine_weights_grad_dims.size(), + 2, + errors::InvalidArgument( + "Input combine_weights_grad should have 2 dimensions")); + + int64_t num_experts = y_grad_dims[0] / capacity; + int64_t hidden_size = y_grad_dims[1]; + + int64_t num_rows = scatter_index_dims[1]; + + gate_logits_grad->set_dims(common::make_ddim({num_rows, num_experts})); + gate_logits_grad->set_dtype(phi::DataType::FLOAT32); + + x_grad->set_dims(common::make_ddim({num_rows, hidden_size})); + x_grad->set_dtype(y_grad.dtype()); +} } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index d20c679ba95ce3..f134a17d7beab2 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -689,4 +689,15 @@ void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, const int64_t moe_k, MetaTensor* gate_prob_grad); +void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& expert_id, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_grad, + const int64_t k, + const int64_t capacity, + const bool use_pad, + MetaTensor* x_grad, + MetaTensor* gate_logits_grad); + } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 5ddc5fd5a3080a..b0950f27a7ada8 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -6362,5 +6362,95 @@ void CalAuxLossInferMeta(const MetaTensor& gate_prob, ce->set_dtype(gate_prob.dtype()); } +void MoeGateDispatchInferMeta(const MetaTensor& x, + const MetaTensor& gate_logits, + const MetaTensor& corr_bias, + const int64_t k, + const int64_t capacity, + const bool use_pad, + MetaTensor* y, + MetaTensor* combine_weights, + MetaTensor* scatter_index, + MetaTensor* expert_offset, + MetaTensor* expert_id) { + auto x_dims = x.dims(); + auto gate_logits_dims = gate_logits.dims(); + auto corr_bias_dims = corr_bias.dims(); + + const int64_t num_rows = x_dims[0]; + const int64_t hidden_size = x_dims[1]; + const int64_t num_experts = gate_logits_dims[1]; + + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + errors::InvalidArgument("Input x should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + gate_logits_dims.size(), + 2, + errors::InvalidArgument("Input gate_logits should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + x_dims[0], + gate_logits_dims[0], + errors::InvalidArgument( + "The 0-th dimension of x [%d] " + "must match that of the 0-th dimension gate_logits [%d].", + x_dims[0], + 
gate_logits_dims[0])); + + PADDLE_ENFORCE_EQ(gate_logits_dims[1] >= k, + true, + errors::InvalidArgument( + "The 1-th dimension of gate_logits [%d] " + "must be greater than or equal to that of k [%d].", + gate_logits_dims[1], + k)); + + PADDLE_ENFORCE_EQ( + corr_bias.dtype(), + phi::DataType::FLOAT32, + errors::InvalidArgument( + "The dtype of rotary_tensor must be float32, but got %d", + corr_bias.dtype())); + + PADDLE_ENFORCE_EQ( + corr_bias_dims.size(), + 1, + errors::InvalidArgument("Input corr_bias should have 1 dimensions")); + + PADDLE_ENFORCE_EQ( + corr_bias_dims[0], + gate_logits_dims[1], + errors::InvalidArgument( + "The 0-th dimension of x [%d] " + "must match that of the 0-th dimension gate_logits [%d].", + corr_bias_dims[0], + gate_logits_dims[1])); + + std::vector y_dims; + if (use_pad) { + y_dims = {num_experts * capacity, x_dims[1]}; + } else { + y_dims = {num_rows * k, x_dims[1]}; + } + + y->set_dims(common::make_ddim(y_dims)); + y->set_dtype(x.dtype()); + + combine_weights->set_dims(common::make_ddim({num_rows, k})); + combine_weights->set_dtype(phi::DataType::FLOAT32); + + scatter_index->set_dims(common::make_ddim({k, num_rows})); + scatter_index->set_dtype(phi::DataType::INT32); + + expert_offset->set_dims(common::make_ddim({num_experts})); + expert_offset->set_dtype(phi::DataType::INT64); + + expert_id->set_dims(common::make_ddim({num_rows, k})); + expert_id->set_dtype(phi::DataType::INT32); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index f511ce4d9d6f28..9e364f96612e3b 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -1296,4 +1296,16 @@ void CalAuxLossInferMeta(const MetaTensor& gate_prob, MetaTensor* seqlen_floats, MetaTensor* ce); +void MoeGateDispatchInferMeta(const MetaTensor& x, + const MetaTensor& gate_logits, + const MetaTensor& corr_bias, + const int64_t k, + const int64_t capacity, + const bool use_pad, + MetaTensor* y, + MetaTensor* combine_weights, + MetaTensor* scatter_index, + MetaTensor* expert_offset, + MetaTensor* expert_id); + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt old mode 100755 new mode 100644 index 8d75f79c1ce431..6ae34c57b97937 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -232,6 +232,8 @@ if(WITH_ROCM) "gpu/cal_aux_loss_kernel.cu" "gpu/cal_aux_loss_grad_kernel.cu" "gpu/build_src_rank_and_local_expert_id_kernel.cu" + "gpu/moe_gate_dispatch_kernel.cu" + "gpu/moe_gate_dispatch_grad_kernel.cu" "gpu/apply_per_channel_scale_kernel.cu" "gpu/calc_reduced_attn_kernel.cu" "gpu/eigvalsh_kernel.cu" diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu index 23e8d8cfd985bc..6c7724ab2cc8b1 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu @@ -14,13 +14,13 @@ #pragma once -#include - -#include "paddle/phi/include/kernels.h" #include "paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h" - +#include #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/contiguous_kernel.h" +#include "paddle/phi/kernels/fused_moe_bwd_op.h" +#include "paddle/phi/kernels/transpose_kernel.h" namespace phi { @@ -30,7 +30,7 @@ void apply_moe_dispatch_bwd(const T* y_grad, const int* scatter_index, 
// [s, k] const float* combine_weights_grad, const int* expert_id, // [s, k] - float* gate_logtis_grad, + float* gate_logits_grad, T* x_grad, int64_t num_rows, int64_t k, @@ -58,7 +58,7 @@ void apply_moe_dispatch_bwd(const T* y_grad, topk_grad_with_mask_launcher(combine_weights_grad, expert_id, combine_weights, - gate_logtis_grad, + gate_logits_grad, num_rows, k, num_experts, @@ -73,7 +73,7 @@ void moe_dispatch_bwd(const Context& dev_ctx, const DenseTensor& y_grad, // [num_experts * capacity, h] const DenseTensor& combine_weights_grad, // [s, k] const DenseTensor& x_grad, - const DenseTensor& gate_logtis_grad, + const DenseTensor& gate_logits_grad, int64_t capacity, bool use_all2all_permute, int64_t world_size, @@ -85,15 +85,15 @@ void moe_dispatch_bwd(const Context& dev_ctx, #else int64_t hidden_size = y_grad.dims()[1]; #endif - int64_t num_experts = gate_logtis_grad.dims()[1]; + int64_t num_experts = gate_logits_grad.dims()[1]; apply_moe_dispatch_bwd(y_grad.data(), combine_weights.data(), scatter_index.data(), combine_weights_grad.data(), expert_id.data(), - static_cast gate_logtis_grad.data(), - static_cast x_grad.data(), + const_cast(gate_logits_grad.data()), + const_cast(x_grad.data()), num_rows, k, hidden_size, @@ -116,7 +116,7 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, const int64_t capacity, const bool use_pad, DenseTensor* x_grad, - DenseTensor* gate_logtis_grad) { + DenseTensor* gate_logits_grad) { auto y_grad_dims = y_grad.dims(); auto scatter_index_dims = scatter_index.dims(); @@ -140,7 +140,7 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, const DenseTensor t_scatter_index__ = t_scatter_index_; dev_ctx.template Alloc(x_grad); - dev_ctx.template Alloc(gate_logtis_grad); + dev_ctx.template Alloc(gate_logits_grad); moe_dispatch_bwd(dev_ctx, combine_weights, @@ -149,7 +149,7 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, y_grad, combine_weights_grad, *x_grad, - *gate_logtis_grad, + *gate_logits_grad, capacity); } diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu new file mode 100644 index 00000000000000..e9e4fb7a37d2ce --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu @@ -0,0 +1,372 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
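+
+// Forward dispatch pipeline implemented by apply_moe_dispatch_fwd below:
+//  1. topk_gating_softmax_kernelLauncher selects the top-k experts per token
+//     from gate_logits (corr_bias, when given, only biases the argmax, not
+//     the stored weight), producing combine_weights, expert_id and the
+//     source-row ids.
+//  2. With use_pad, modify_expert_id_launcher tags every expert id with its
+//     top-k slot, CubKeyValueSorter sorts the (token, slot) pairs by that
+//     key, and unmodify_expert_id_launcher strips the tag again, yielding
+//     permuted_experts_ / permuted_rows_ in expert-major order.
+//  3. compute_total_rows_before_expert writes, for every expert e, the
+//     cumulative number of sorted rows assigned to experts 0..e into
+//     expert_offset.
+//  4. initialize_moe_routing_kernelLauncher copies the selected rows of x
+//     into y (padded to num_experts * capacity when use_pad) and records
+//     each destination in scatter_index; the *_permute variant emits y
+//     directly in the row layout expected by the later all-to-all.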
+ +#include "paddle/phi/kernels/moe_gate_dispatch_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/moe_fuse_op.h" +namespace phi { + +// -------- getWorkspaceSize -------- // +template +size_t getWorkspaceSize(const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int k, + // const int max_seq_len, + const phi::CubKeyValueSorter &sorter) { + // const int buf_size = AlignTo16(k * num_rows * hidden_size); + // const int interbuf_size = AlignTo16(k * num_rows * inter_size); + // const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + int num_softmax_outs = 0; + + // softmax output, permuted_rows and permuted_experts have moved to outside of + // moe kernel, allocate them in Encoder or Decoder before invoking FfnLayer + // forward. + size_t total_ws_bytes = + 4 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + // total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + // total_ws_bytes += padded_experts * sizeof(int64_t); // Hold + // total_rows_before_expert_ // expert_cnt total_ws_bytes += num_softmax_outs + // * sizeof(KeyT); const int bytes_for_fc1_result = interbuf_size * + // sizeof(KeyT); + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + // sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity + // // 用所有 bit 做排序,会降低些许性能,但是防止越界 + total_ws_bytes += sorter_ws_size_bytes; // intermediate (fc1) output + cub + // sorting workspace + // std::cout<<"sorter_ws_size_bytes = "< +void apply_moe_dispatch_fwd(const Context &dev_ctx, + const T *x, + const float *gate_logits, + const float *corr_bias, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + T *y, + float *combine_weights, + int *scatter_index, + int64_t *expert_offset, + int *expert_id, + bool use_pad, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts, + cudaStream_t stream) { + phi::CubKeyValueSorter sorter(stream); + // phi::funcs::SetConstant zero; + // zero(ctx, &finished_tensor, false); + + DenseTensor xpanded_source_row_to_expanded_dest_row_tensor = + phi::Empty(dev_ctx, IntArray({num_rows, k})); + // int* expanded_source_row_to_expanded_dest_row = + // expanded_source_row_to_expanded_dest_row_tensor.data(); + + // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, + // paddle::DataType::FLOAT32, place); float* expert_scales_float = + // expert_scales_tensor_float.data(); + + // paddle::Tensor expert_for_source_row_tensor = paddle::empty({num_rows, k}, + // paddle::DataType::INT32, place); int* expert_for_source_row = + // expert_for_source_row_tensor.data(); + DenseTensor active_cnt_tensor = + phi::Empty(dev_ctx, IntArray({1})); + + int64_t bytes = getWorkspaceSize(num_rows, + hidden_size, // hidden-size=0 + 0, // inter-size=0 + num_experts, + k, + sorter); + + DenseTensor ws_ptr_tensor = + phi::Empty(dev_ctx, IntArray({bytes})); + int8_t *ws_ptr = ws_ptr_tensor.data(); + + // Pointers + int *source_rows_; + int *permuted_rows_; + int *permuted_experts_; + int *expert_id_; + + // T* permuted_data_; + float *softmax_out_; + // int64_t* total_rows_before_expert_; + T *fc1_result_; + + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + // const int buf_size = AlignTo16(k * 
num_rows * hidden_size); + // const int interbuf_size = AlignTo16(k * num_rows * 0); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + + source_rows_ = reinterpret_cast(ws_ptr); + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + expert_id_ = permuted_experts_ + num_moe_inputs; + + // permuted_data_ = reinterpret_cast(expert_id_ + num_moe_inputs); + // total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + + // buf_size); + + // only use one number + // num_active = reinterpret_cast(permuted_experts_ + + // num_moe_inputs); + + fc1_result_ = reinterpret_cast(expert_id_ + num_moe_inputs); + softmax_out_ = nullptr; + +#ifdef DEBUG_MOE_OP + // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits + // before_topk")); print_to_screen1(finished, 2, 16, std::string("finished + // before_topk")); +#endif + + topk_gating_softmax_kernelLauncher(gate_logits, + corr_bias, + combine_weights, // output + softmax_out_, // no use + expert_id, // output + source_rows_, // output + num_rows, + num_experts, + k, + stream); + +#ifdef DEBUG_MOE_OP + // phi::CastKernel(ctx, expert_scales_tensor_float, + // expert_scales_tensor.dtype(), &expert_scales_tensor); + print_to_screen1( + combine_weights, 8, 16, std::string("expert_scales_float after topk")); + print_to_screen1( + expert_id, 8, 16, std::string("expert-id before permute")); + print_to_screen1( + source_rows_, 8, 16, std::string("desc->src idx before permute")); +#endif + // modify expert-id according to k + if (use_pad) // 为了区分 k=1 选择和 k=2 选择,修改 expert-id + modify_expert_id_launcher( + expert_id, expert_id_, k, num_rows, num_experts, stream); + + // calc expert-size + /* + if (!use_pad) + cal_expert_size_and_filter_launcher(expert_id, + k * num_rows, + num_experts, + capacity, + stream); + */ +#ifdef DEBUG_MOE_OP + print_to_screen1( + expert_id, 8, 16, std::string("expert-id after modified")); +#endif + sorter.run( + fc1_result_, + sorter_ws_size_bytes, + use_pad ? 
expert_id_ : expert_id, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + source_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + if (use_pad) + unmodify_expert_id_launcher( + permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1( + permuted_experts_, 8, 16, std::string("expert-id after permute")); + print_to_screen1( + permuted_rows_, 8, 16, std::string("dest->src idx after permute")); +#endif + + compute_total_rows_before_expert( + permuted_experts_, k * num_rows, num_experts, expert_offset, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1(expert_offset, 8, 16, std::string("expert_offset")); + int64_t num_active_host_v2; + cudaMemcpy(&num_active_host_v2, + expert_offset + num_experts - 1, + sizeof(int64_t), + cudaMemcpyDeviceToHost); + std::cerr << "[DEBUG] num_active v2: " << num_active_host_v2 << std::endl; + print_to_screen1(permuted_experts_, + 8, + num_active_host_v2 + 2, + std::string("expert-id after permute")); + // print_to_screen1(permuted_experts_, 4096, 8192, + // std::string("expert-id after permute")); +#endif + + if (!use_all2all_permute) { + initialize_moe_routing_kernelLauncher(x, + y, + permuted_rows_, + scatter_index, + permuted_experts_, + expert_offset, + combine_weights, + static_cast(num_rows), + static_cast(hidden_size), + static_cast(k), + capacity, + use_pad, + stream); + } else { + PD_CHECK(num_experts > 0); + PD_CHECK(world_size > 0); + initialize_moe_routing_permute_kernelLauncher(x, + y, + permuted_rows_, + scatter_index, + permuted_experts_, + expert_offset, + combine_weights, + static_cast(num_rows), + static_cast(hidden_size), + static_cast(k), + capacity, + world_size, + num_local_experts, + stream); + } + + // turn expert_offset_ptr into experts_num + // auto expert_offset_ptr = thrust::device_pointer_cast(expert_offset); + // thrust::adjacent_difference( + // expert_offset_ptr, expert_offset_ptr + num_experts, expert_offset_ptr + // ); +#ifdef DEBUG_MOE_OP + print_to_screen1( + scatter_index, 8, 16, std::string("scatter_index after pad")); +#endif + // cudaMemcpy(scatter_index, permuted_rows_, sizeof(int64_t) * k * num_rows, + // cudaMemcpyDeviceToDevice); cudaMemcpy(combine_weights, expert_scales_float, + // sizeof(float) * k * num_rows, cudaMemcpyDeviceToDevice); + return; +} + +template +void moe_dispatch_fwd(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &gate_logits, + const paddle::optional &corr_bias, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + const DenseTensor &y, + const DenseTensor &combine_weights, + const DenseTensor &scatter_index, + const DenseTensor &expert_offset, + const DenseTensor &expert_id, + bool use_pad, + int64_t use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1) { + apply_moe_dispatch_fwd( + dev_ctx, + x.data(), + gate_logits.data(), + corr_bias ? 
corr_bias.get_ptr()->data() : nullptr, + num_rows, + num_experts, + hidden_size, + capacity, + k, + const_cast(y.data()), + const_cast(combine_weights.data()), + const_cast(scatter_index.data()), + const_cast(expert_offset.data()), + const_cast(expert_id.data()), + use_pad, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); +} + +template +void MoeGradDispatchKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &gate_logits, + const DenseTensor &corr_bias, + const int64_t k, + const int64_t capacity, + const bool use_pad, + DenseTensor *y, + DenseTensor *combine_weights, + DenseTensor *scatter_index, + DenseTensor *expert_offset, + DenseTensor *expert_id) { + dev_ctx.template Alloc(expert_id); + dev_ctx.template Alloc(expert_offset); + dev_ctx.template Alloc(scatter_index); + dev_ctx.template Alloc(combine_weights); + dev_ctx.template Alloc(y); + + auto x_dims = x.dims(); + auto gate_logits_dims = gate_logits.dims(); + auto corr_bias_dims = corr_bias.dims(); + + const int64_t num_rows = x_dims[0]; + const int64_t hidden_size = x_dims[1]; + const int64_t num_experts = gate_logits_dims[1]; + + moe_dispatch_fwd(dev_ctx, + x, + gate_logits, + corr_bias, + num_rows, + num_experts, + hidden_size, + capacity, + k, + *y, + *combine_weights, + *scatter_index, + *expert_offset, + *expert_id, + use_pad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch, + GPU, + ALL_LAYOUT, + phi::MoeGradDispatchKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/moe_fuse_op.h b/paddle/phi/kernels/moe_fuse_op.h new file mode 100644 index 00000000000000..b2b94b9e1faf1e --- /dev/null +++ b/paddle/phi/kernels/moe_fuse_op.h @@ -0,0 +1,457 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/common/exception.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/moe_kernel_impl.h" + +template +__launch_bounds__(TPB) __global__ + void moe_top_k(const T* inputs_after_softmax, + const T* bias, // bias could be nullptr if not used + T* output, + int* indices, + int* source_rows, + const int num_experts, + const int k) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = bias ? 
inputs_after_softmax[idx] + bias[expert] + : inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = + bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] + : result_kvp.value; + indices[idx] = result_kvp.key; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +template +void topk_gating_softmax_kernelLauncher(const T* input, + const T* bias, + T* output, + T* softmax, // no use + int* indices, + int* source_row, + const int num_rows, + const int num_experts, + const int k, + cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + static constexpr int TPB = 256; + moe_top_k<<>>( + input, bias, output, indices, source_row, num_experts, k); +} + +template +__global__ void modify_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) return; + int ik = idx % k; + int irow = idx / k; + // const T mask = (~0) >> (8*sizeof(T)-ik); // 最后 ik 位为 1 其他位为 0 + int mask = ik; // k => 2(11) + // printf("before: idx=%d, expert-id:%d, ik=%d\n", idx, expert_id[idx], ik); + int offset = log2(k) + 1; + expert_id_out[idx] = (expert_id[idx] << offset) | mask; + // printf("after: idx=%d, expert-id:%d, ik=%d\n", idx, expert_id_out[idx], + // ik); +} + +template +void modify_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts, + const cudaStream_t& stream) { + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + modify_expert_id<<>>( + expert_id, expert_id_out, k, num_rows, num_experts); +} + +template +__global__ void unmodify_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) return; + int ik = idx % k; + int irow = idx / k; + int offset = log2(k) + 1; + expert_id_out[idx] = (expert_id[idx] >> offset); +} + +template +void unmodify_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts, + const cudaStream_t& stream) { + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + unmodify_expert_id<<>>( + expert_id, expert_id_out, k, num_rows, num_experts); +} + +template +__device__ inline int find_total_elts_leq_target(const T* sorted_indices, + const int arr_length, + const int target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +template +__global__ void compute_total_rows_before_expert_kernel( + const T* sorted_experts, + const int sorted_experts_len, + const int64_t num_experts, + int64_t* total_rows_before_expert) { + // First, compute the global tid. 
We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) return; + + // This should construct the last index where each expert occurs. + total_rows_before_expert[expert] = + find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); + // total_rows_before_expert[0] = 0; + // total_rows_before_expert[1] = 1; + // if (sorted_experts_len > 3) { + // for (int i=0; i<35;i++){ + // total_rows_before_expert[i] = i; + // } + // } +} + +template +void compute_total_rows_before_expert(const T* sorted_indices, + const int total_indices, + const int64_t num_experts, + int64_t* total_rows_before_expert, + const cudaStream_t& stream) { + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sorted_indices, total_indices, num_experts, total_rows_before_expert); +} + +template +__global__ void initialize_moe_routing_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + bool use_pad) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 
0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + if (row_in_expert >= capacity) { + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + 0; // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; // unset combine-weight + } + return; + } + int64_t num_padded = 0; + if (threadIdx.x == 0) { + // printf("going through: capacity=%lld, num_active=%lld, row=[%d->%d], + // row-in-expert %lld\n", + // capacity, + // num_active, + // expanded_dest_row, expanded_source_row, + // row_in_expert + // ); + if (use_pad) num_padded = iexpert * capacity - offset; + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row + num_padded; + } + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr; + if (use_pad) { + dest_row_ptr = + permuted_output + iexpert * capacity * cols + row_in_expert * cols; + } else { + dest_row_ptr = permuted_output + expanded_dest_row * cols; + } + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + phi::Load<T, VecSize>(&source_row_ptr[tid], &src_vec); + phi::Store<T, VecSize>(src_vec, &dest_row_ptr[tid]); + } +} + +template <typename T> +void initialize_moe_routing_kernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + bool use_pad, + cudaStream_t stream) { + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + constexpr int max_pack_size = 16 / sizeof(T); + if (cols % max_pack_size == 0) { + initialize_moe_routing_kernel<T, max_pack_size> + <<<blocks, threads, 0, stream>>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + use_pad); + } else { + initialize_moe_routing_kernel<T, 1><<<blocks, threads, 0, stream>>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + use_pad); + } +} + +/** + * Output order of the original routing logic: + * R0E0 + * R0E1 + * R1E0 + * R1E1 + * + * We want to overlap the all-to-all with the expert GEMMs, so the all-to-all is split into a pipeline; + * to simplify the subsequent computation, this kernel writes its output as: R0E0 R1E0 R0E1 R1E1 + */ +template <typename T, int VecSize, int LoopSize> +__global__ void initialize_moe_routing_permute_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + const int64_t world_size, + const int64_t num_local_experts) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. 
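+  // Layout note: each block handles LoopSize expanded rows (blockIdx.x + i * gridDim.x). + // Rows beyond an expert's capacity are dropped: their scatter index and combine weight are zeroed. + // Kept rows are written at local_iexpert * world_size * capacity + irank * capacity + row_in_expert, + // i.e. grouped by local expert first and then by source rank, so the all-to-all can be pipelined per local expert.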
+#pragma unroll + for (int i = 0; i < LoopSize; i++) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int expanded_dest_row = blockIdx.x + i * gridDim.x; + const int expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + if (row_in_expert >= capacity) { + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + 0; // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; // unset combine-weight + } + continue; + } + int64_t num_padded = 0; + if (threadIdx.x == 0) { + num_padded = iexpert * capacity - offset; + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row + num_padded; + } + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr; + + const int64_t irank = iexpert / num_local_experts; + const int64_t local_iexpert = iexpert % num_local_experts; + dest_row_ptr = permuted_output + + local_iexpert * world_size * capacity * cols + + irank * capacity * cols + row_in_expert * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +template +void initialize_moe_routing_permute_kernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + const int64_t world_size, + const int64_t num_local_experts, + cudaStream_t stream) { + const int loop_size = 2; + const int blocks = (num_rows * k) / loop_size; + assert((num_rows * k) % loop_size == 0); + const int threads = std::min(cols, 1024); + constexpr int max_pack_size = 16 / sizeof(T); + if (cols % max_pack_size == 0) { + initialize_moe_routing_permute_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + world_size, + num_local_experts); + } else { + initialize_moe_routing_permute_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + world_size, + num_local_experts); + } +} diff --git a/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h index b61e8ae3398f48..6c3b4d6d6f241c 100644 --- a/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h +++ b/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h @@ -28,7 +28,7 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, const int64_t capacity, const bool use_pad, DenseTensor* x_grad, - DenseTensor* gate_logtis_grad); + DenseTensor* gate_logits_grad); template void moe_dispatch_bwd(const Context& dev_ctx, @@ -38,7 +38,7 @@ void 
moe_dispatch_bwd(const Context& dev_ctx, const DenseTensor& y_grad, // [num_experts * capacity, h] const DenseTensor& combine_weights_grad, // [s, k] const DenseTensor& x_grad, - const DenseTensor& gate_logtis_grad, + const DenseTensor& gate_logits_grad, int64_t capacity, bool use_all2all_permute = false, int64_t world_size = -1, diff --git a/paddle/phi/kernels/moe_gate_dispatch_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_kernel.h new file mode 100644 index 00000000000000..a87dad8e82c925 --- /dev/null +++ b/paddle/phi/kernels/moe_gate_dispatch_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void MoeGradDispatchKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& gate_logits, + const DenseTensor& corr_bias, + const int64_t k, + const int64_t capacity, + const bool use_pad, + DenseTensor* y, + DenseTensor* combine_weights, + DenseTensor* scatter_index, + DenseTensor* expert_offset, + DenseTensor* expert_id); + +} // namespace phi diff --git a/paddle/phi/kernels/moe_kernel_impl.h b/paddle/phi/kernels/moe_kernel_impl.h new file mode 100644 index 00000000000000..bae05d8c094e20 --- /dev/null +++ b/paddle/phi/kernels/moe_kernel_impl.h @@ -0,0 +1,642 @@ +/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include +#include +#include +#include "cub/cub.cuh" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" +namespace phi { + +static const float HALF_FLT_MAX = 65504.F; +static const float HALF_FLT_MIN = -65504.F; +static inline size_t AlignTo16(const size_t& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + explicit CubKeyValueSorter(cudaStream_t stream = 0); + + explicit CubKeyValueSorter(const int num_experts); + + void update_num_experts(const int num_experts); + + size_t getWorkspaceSize(const size_t num_key_value_pairs, + bool descending = false); + + template + void run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream); + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; + cudaStream_t stream_; +}; + +// ===== CUB Sorting things ===== +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(cudaStream_t stream) + : num_experts_(0), num_bits_(sizeof(int) * 8), stream_(stream) {} + +CubKeyValueSorter::CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts), + num_bits_(static_cast(log2(num_experts)) + 1) {} + +void CubKeyValueSorter::update_num_experts(const int num_experts) { + num_experts_ = num_experts; + num_bits_ = static_cast(log2(num_experts)) + + 3; // 额外增加 3 位用于标记 topk的位置 +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, + bool descending) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + 32, + stream_); + } else { + cub::DeviceRadixSort::SortPairs(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + num_bits_, + stream_); + } + return required_storage; +} + +template +void CubKeyValueSorter::run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + std::stringstream err_ss; + err_ss << "[Error][CubKeyValueSorter::run]\n"; + err_ss + << "Error. 
The allocated workspace is too small to run this problem.\n"; + err_ss << "Expected workspace size of at least " << expected_ws_size + << " but got problem size " << workspace_size << "\n"; + throw std::runtime_error(err_ss.str()); + } + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + 32, + stream); + } else { + cub::DeviceRadixSort::SortPairs(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + num_bits_, + stream); + } +} + +template <> +void CubKeyValueSorter::run(void* workspace, + const size_t workspace_size, + const __nv_bfloat16* keys_in, + __nv_bfloat16* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) {} + +// CubKeyValueSorter sorter_(stream); + +// -------- initialize_expert_choice_route_kernel -------- // +template +__global__ void initialize_expert_choice_route_kernel( + int* expert_for_source_row, + int* source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* total_rows_before_expert, + T* attr_mask, + const int cols, + const int k, + const int batch_size) { + int start = cols * blockIdx.x; + + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + expert_for_source_row[start + i] = blockIdx.x; + source_row[start + i] = start + i; + expanded_source_row_to_expanded_dest_row[start + i] = -1; + attr_mask[start + i] = (T)1.0f; + } + if (threadIdx.x == 0) { + total_rows_before_expert[blockIdx.x] = batch_size * k * (blockIdx.x + 1); + } +} + +// -------- softmax_kernel -------- // +template +__global__ void softmax_kernel_v4( + T* qk_buf_, + const T* qk_buf_src, // shape [batch_size, seq_len] + const T* attr_mask, // shape [batch_size, seq_len] + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + float data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + qk_offset = + ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y) * seq_len + blockDim.x * i + threadIdx.x; + + float qk = static_cast(qk_buf_src[qk_offset]); + float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + data[i] = qk + mask_val; + local_max = fmax(local_max, data[i]); + } + + float max_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); + + float local_sum = 0; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + data[i] = __expf(data[i] - s_max); + local_sum += data[i]; + } + float sum_val = + blockDim.x <= 32 + ? 
phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + qk_offset = + ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; + qk_buf_[qk_offset] = (T)(data[i] * s_mean); + } +#endif +} + +template +__global__ void softmax_kernel_v4_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + using T2 = half2; + T2* qk_buf_half2 = reinterpret_cast(qk_buf_); + const T2* attr_mask_half2 = (const T2*)attr_mask; + + T2 data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + + threadIdx.x; + int mask_offset = blockIdx.y * (seq_len / 2) + blockDim.x * i + threadIdx.x; + + T2 qk = qk_buf_half2[qk_offset]; + T2 mask_val = __ldg(&attr_mask_half2[mask_offset]); + mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), + __float2half2_rn(-10000.0f)); + + data[i] = __hadd2(qk, mask_val); + + local_max = fmax( + local_max, + fmax(static_cast(data[i].x), static_cast(data[i].y))); + } + + float max_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); + + float local_sum = 0; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max))); + local_sum += static_cast(data[i].x + data[i].y); + } + + float sum_val = + blockDim.x <= 32 + ? 
phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); + + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + + threadIdx.x; + qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean)); + } +#endif +} + +template +__global__ void softmax_kernel_v5_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + using T2 = half2; + T2* qk_buf_half2 = reinterpret_cast(qk_buf_); + const T2* attr_mask_half2 = (const T2*)attr_mask; + + T2 data[NUM][ITEMS_PER_THREAD]; + + int qk_offset[NUM]; + + __shared__ float s_sum[NUM], s_max[NUM]; + float local_max[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_max[j] = -1e20f; + } + + const int MAX_NUM = min((1 + gridDim.x - 1) / gridDim.x, NUM); + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + int mask_offset[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + mask_offset[j] = (blockIdx.y + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + } + + T2 mask_val[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]); + } + + T2 qk[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk[j] = qk_buf_half2[qk_offset[j]]; + } +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]), + __float2half2_rn(-10000.0f)); + } +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = __hadd2(qk[j], mask_val[j]); + local_max[j] = fmax(local_max[j], + fmax(static_cast(data[j][i].x), + static_cast(data[j][i].y))); + } + } + if (blockDim.x <= 32) { + phi::funcs::WarpReduceMaxV2(local_max); + } else { + phi::funcs::BlockReduceMaxV2(local_max); + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_max[j] = local_max[j]; + } + } + __syncthreads(); + float local_sum[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_sum[j] = {0.f}; + } + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j]))); + } + +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + local_sum[j] += static_cast(data[j][i].x + data[j][i].y); + } + } + + if (blockDim.x <= 32) { + phi::funcs::WarpReduceSumV2(local_sum); + + } else { + phi::funcs::BlockReduceSumV2(local_sum); + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); + } + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + } + +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_buf_half2[qk_offset[j]] = + __hmul2(data[j][i], __float2half2_rn(s_sum[j])); + } + } +#endif +} + +// -------- transpose_kernel 
-------- // +template +__global__ void transposeAxis01( + T* out, T* in, const int dim0, const int dim1, const int dim2) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < dim0 * dim1 * dim2) { + const int input_dim2_index = index % dim2; + index = (index - input_dim2_index) / dim2; + const int input_dim1_index = index % dim1; + index = (index - input_dim1_index) / dim1; + const int input_dim0_index = index % dim0; + + out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + + input_dim2_index] = in[input_dim0_index * dim1 * dim2 + + input_dim1_index * dim2 + input_dim2_index]; + } +} + +// -------- padding_kernel -------- // +template +__global__ void paddingKernel(T* output1, + int* output2, + const T* input1, + const int* input2, + const int* input_lengths, + const int num_tokens, + const int batch_size, + const int max_seq_len, + const int num_experts) { + const bool IS_FP16 = std::is_same::value; + const T MIN_T_VAL = (IS_FP16) ? (T)HALF_FLT_MIN : (T)FLT_MIN; + int offset1 = blockIdx.x * num_tokens; + int offset2 = blockIdx.x * batch_size * max_seq_len; + for (int i = 0; i < batch_size; i++) { + const T* in1_ptr = input1 + offset1; + const int* in2_ptr = input2 + offset1; + int input_length = input_lengths[i]; + offset1 += input_length; + + T* out1_ptr = output1 + offset2; + int* out2_ptr = output2 + offset2; + offset2 += max_seq_len; + + for (int j = threadIdx.x; j < max_seq_len; j += max_seq_len) { + if (j < input_length) { + out1_ptr[j] = in1_ptr[j]; + out2_ptr[j] = in2_ptr[j]; + } else { + out1_ptr[j] = MIN_T_VAL; + out2_ptr[j] = 0; + } + } + } +} + +// -------- general_topk_pair_sort_kernel -------- // +template +__global__ void general_topk_pair_sort(T* out_keys, + int* out_values, + T* in_keys, + int* in_values) { + typedef cub::BlockRadixSort + BlockRadixSort; + typedef cub:: + BlockLoad + BlockLoadKey; + typedef cub:: + BlockLoad + BlockLoadValue; + typedef cub:: + BlockStore + BlockStoreKey; + typedef cub::BlockStore + BlockStoreValue; + + __shared__ union { + typename BlockRadixSort::TempStorage sort; + typename BlockLoadKey::TempStorage loadkey; + typename BlockLoadValue::TempStorage loadvalue; + typename BlockStoreKey::TempStorage storekey; + typename BlockStoreValue::TempStorage storevalue; + } temp_storage; + + int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD; + + T thread_keys[ITEMS_PER_THREAD]; + int thread_values[ITEMS_PER_THREAD]; + BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys); + BlockLoadValue(temp_storage.loadvalue) + .Load(in_values + block_offset, thread_values); + __syncthreads(); + + BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values); + __syncthreads(); + + BlockStoreKey(temp_storage.storekey) + .Store(out_keys + block_offset, thread_keys); + BlockStoreValue(temp_storage.storevalue) + .Store(out_values + block_offset, thread_values); +} + +// -------- finalize_moe_routing_kernel -------- // +template +__global__ void finalize_moe_routing_kernel( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* skip, + const T* bias, + const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int cols, + const int k, + bool ec_route) { + const int original_row = blockIdx.x; + const int num_rows = gridDim.x; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + const T* skip_row_ptr = skip + original_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + 
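// For each output column this thread owns: start from the skip connection, then + // accumulate row_scale * (expert_output + expert_bias) over the k experts routed to this row; + // with expert-choice routing (ec_route), unrouted rows (dest index -1) are skipped. +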
T thread_output = skip_row_ptr[tid]; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + if (ec_route && expanded_permuted_row == -1) continue; + const int64_t k_offset = + ec_route ? expanded_original_row : original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T* expanded_permuted_rows_row_ptr = + expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = ec_route ? k_idx : expert_for_source_row[k_offset]; + const T* bias_ptr = bias + expert_idx * cols; + + thread_output = + thread_output + + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + } + reduced_row_ptr[tid] = thread_output; + } +} + +// -------- initialize_moe_routing_kernel -------- // +template +__global__ void initialize_moe_routing_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int num_rows, + const int active_rows, + const int cols, + const int k, + const int max_seq_len, + bool ec_route) { + // using LoadT = phi::AlignedVector; + // LoadT src_vec; + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = + ec_route ? expanded_dest_row_to_expanded_source_row[expanded_dest_row / + k * max_seq_len + + expanded_dest_row % k] + : expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row; + } + + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + dest_row_ptr[tid] = source_row_ptr[tid]; + // phi::Load(&source_row_ptr[tid], &src_vec); + // phi::Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +} // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 0ff49350a12c47..4e668dcee1ec98 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2256,6 +2256,16 @@ kernel : func : mode_grad +- backward_op : moe_gate_dispatch_grad + forward : moe_gate_dispatch (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, bool use_pad) + output : Tensor(x_grad), Tensor(gate_logits_grad) + infer_meta : + func : MoeGateDispatchGradInferMeta + kernel : + func : moe_gate_dispatch_grad + data_type : y_grad + - backward_op : mp_allreduce_sum_grad forward : mp_allreduce_sum(Tensor x, int ring_id = 0) -> Tensor(out) args : (Tensor out_grad, int ring_id = 0) diff --git a/paddle/phi/ops/yaml/ops.yaml 
b/paddle/phi/ops/yaml/ops.yaml index 7d39ba2010b28c..5c2d7eee69009b 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3602,6 +3602,16 @@ backward : mode_grad interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface +- op : moe_gate_dispatch + args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) + output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + infer_meta : + func : MoeGateDispatchInferMeta + kernel : + func : moe_gate_dispatch + data_type : x + backward : moe_gate_dispatch_grad + - op : momentum_ args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0f, bool multi_precision = false, float rescale_grad = 1.0f) output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) From 3e4e3924ed02e00c4530a2a5b53f4932a10a10c1 Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Fri, 23 May 2025 11:02:13 +0000 Subject: [PATCH 25/71] fix-bugs --- paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu index e9e4fb7a37d2ce..1164e642bba93b 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu @@ -27,7 +27,7 @@ size_t getWorkspaceSize(const int num_rows, const int num_experts, const int k, // const int max_seq_len, - const phi::CubKeyValueSorter &sorter) { + phi::CubKeyValueSorter &sorter) { // const int buf_size = AlignTo16(k * num_rows * hidden_size); // const int interbuf_size = AlignTo16(k * num_rows * inter_size); // const int padded_experts = AlignTo16(num_experts); From 0242f9c970ebffef215743fa788e7c81bd91e277 Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Sun, 25 May 2025 07:44:04 +0000 Subject: [PATCH 26/71] fix optional Tensor --- paddle/phi/infermeta/multiary.cc | 94 +++++++++++-------- paddle/phi/kernels/cal_aux_loss_kernel.h | 7 +- paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu | 45 ++++----- .../kernels/gpu/moe_gate_dispatch_kernel.cu | 3 +- paddle/phi/kernels/moe_gate_dispatch_kernel.h | 2 +- 5 files changed, 79 insertions(+), 72 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index b0950f27a7ada8..a99f305e047971 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -6286,8 +6286,6 @@ void CalAuxLossInferMeta(const MetaTensor& gate_prob, MetaTensor* ce) { auto gate_prob_dims = gate_prob.dims(); auto dispatch_mask_dims = dispatch_mask.dims(); - auto tokens_mask_dims = tokens_mask.dims(); - auto dispatch_tokens_mask_dims = dispatch_tokens_mask.dims(); PADDLE_ENFORCE_EQ( gate_prob_dims.size(), @@ -6331,26 +6329,44 @@ void CalAuxLossInferMeta(const MetaTensor& gate_prob, phi::DataType::INT64, errors::InvalidArgument("The input dispatch_mask type should be INT64")); - PADDLE_ENFORCE_EQ( - tokens_mask_dims.size(), - 1, - errors::InvalidArgument("Input tokens_mask should have 1 dimensions")); + if (tokens_mask) { + auto tokens_mask_dims = tokens_mask.dims(); + PADDLE_ENFORCE_EQ( + tokens_mask_dims.size(), + 1, + errors::InvalidArgument("Input tokens_mask should have 1 dimensions")); - PADDLE_ENFORCE_EQ( - tokens_mask.dtype(), 
- gate_prob.dtype(), - errors::InvalidArgument( - "The input tokens_mask type should be equal to gate_prob type")); + PADDLE_ENFORCE_EQ( + tokens_mask.dtype(), + gate_prob.dtype(), + errors::InvalidArgument( + "The input tokens_mask type should be equal to gate_prob type")); - PADDLE_ENFORCE_EQ(dispatch_tokens_mask_dims.size(), - 1, - errors::InvalidArgument( - "Input dispatch_tokens_mask should have 1 dimensions")); + PADDLE_ENFORCE_EQ( + tokens_mask_dims[0], + gate_prob_dims[0], + errors::InvalidArgument( + "The 0-th dimension of tokens_mask [%d] " + "must match that of the 0-th dimension of gate_prob [%d].", + tokens_mask_dims[0], + gate_prob_dims[0])); + } - PADDLE_ENFORCE_EQ(dispatch_tokens_mask.dtype(), - phi::DataType::BOOL, - errors::InvalidArgument( - "The input dispatch_tokens_mask type should be BOOL")); + if (dispatch_tokens_mask) { + auto dispatch_tokens_mask_dims = dispatch_tokens_mask.dims(); + + PADDLE_ENFORCE_EQ( + dispatch_tokens_mask_dims.size(), + 1, + errors::InvalidArgument( + "Input dispatch_tokens_mask should have 1 dimensions")); + + PADDLE_ENFORCE_EQ( + dispatch_tokens_mask.dtype(), + phi::DataType::BOOL, + errors::InvalidArgument( + "The input dispatch_tokens_mask type should be BOOL")); + } l_aux_loss->set_dims({1}); l_aux_loss->set_dtype(gate_prob.dtype()); @@ -6375,7 +6391,6 @@ void MoeGateDispatchInferMeta(const MetaTensor& x, MetaTensor* expert_id) { auto x_dims = x.dims(); auto gate_logits_dims = gate_logits.dims(); - auto corr_bias_dims = corr_bias.dims(); const int64_t num_rows = x_dims[0]; const int64_t hidden_size = x_dims[1]; @@ -6408,26 +6423,29 @@ void MoeGateDispatchInferMeta(const MetaTensor& x, gate_logits_dims[1], k)); - PADDLE_ENFORCE_EQ( - corr_bias.dtype(), - phi::DataType::FLOAT32, - errors::InvalidArgument( - "The dtype of rotary_tensor must be float32, but got %d", - corr_bias.dtype())); + if (corr_bias) { + auto corr_bias_dims = corr_bias.dims(); + PADDLE_ENFORCE_EQ( + corr_bias.dtype(), + phi::DataType::FLOAT32, + errors::InvalidArgument( + "The dtype of rotary_tensor must be float32, but got %d", + corr_bias.dtype())); - PADDLE_ENFORCE_EQ( - corr_bias_dims.size(), - 1, - errors::InvalidArgument("Input corr_bias should have 1 dimensions")); + PADDLE_ENFORCE_EQ( + corr_bias_dims.size(), + 1, + errors::InvalidArgument("Input corr_bias should have 1 dimensions")); - PADDLE_ENFORCE_EQ( - corr_bias_dims[0], - gate_logits_dims[1], - errors::InvalidArgument( - "The 0-th dimension of x [%d] " - "must match that of the 0-th dimension gate_logits [%d].", - corr_bias_dims[0], - gate_logits_dims[1])); + PADDLE_ENFORCE_EQ( + corr_bias_dims[0], + gate_logits_dims[1], + errors::InvalidArgument( + "The 0-th dimension of x [%d] " + "must match that of the 0-th dimension gate_logits [%d].", + corr_bias_dims[0], + gate_logits_dims[1])); + } std::vector y_dims; if (use_pad) { diff --git a/paddle/phi/kernels/cal_aux_loss_kernel.h b/paddle/phi/kernels/cal_aux_loss_kernel.h index 4dfd4b5b020a4f..3a73a1a376cab7 100644 --- a/paddle/phi/kernels/cal_aux_loss_kernel.h +++ b/paddle/phi/kernels/cal_aux_loss_kernel.h @@ -13,7 +13,8 @@ // limitations under the License. 
#pragma once -#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" namespace phi { @@ -21,8 +22,8 @@ template void CalAuxLossKernel(const Context& dev_ctx, const DenseTensor& gate_prob, const DenseTensor& dispatch_mask, - const DenseTensor& tokens_mask, - const DenseTensor& dispatch_tokens_mask, + const paddle::optional& tokens_mask, + const paddle::optional& dispatch_tokens_mask, int64_t num_experts, bool use_group, int64_t moe_k, diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu index bdd9d46d811008..402bb88866e4d3 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -204,8 +204,8 @@ template void CalAuxLossKernel(const Context& dev_ctx, const DenseTensor& gate_prob, const DenseTensor& dispatch_mask, - const DenseTensor& tokens_mask, - const DenseTensor& dispatch_tokens_mask, + const paddle::optional& tokens_mask, + const paddle::optional& dispatch_tokens_mask, int64_t num_experts, bool use_group, int64_t moe_k, @@ -215,38 +215,27 @@ void CalAuxLossKernel(const Context& dev_ctx, DenseTensor* ce) { auto gate_prob_dims = gate_prob.dims(); auto dispatch_mask_dims = dispatch_mask.dims(); - auto dispatch_tokens_mask_dim = dispatch_tokens_mask.dims(); - const T* gate_prob_data = gate_prob.data(); - const int64_t* dispatch_mask_data = dispatch_mask.data(); - const T* tokens_mask_data = tokens_mask.data(); - const bool* dispatch_tokens_mask_data = dispatch_tokens_mask.data(); + int64_t dispatch_tokens_mask_len = 0; + if (dispatch_tokens_mask) { + dispatch_tokens_mask_len = dispatch_tokens_mask.get_ptr()->dims()[0]; + } T* l_aux_loss_data = dev_ctx.template Alloc(l_aux_loss); T* seqlen_float_data = dev_ctx.template Alloc(seqlen_float); T* ce_data = dev_ctx.template Alloc(ce); - int64_t row_gate_prob = gate_prob_dims[0]; - int64_t col_gate_prob = gate_prob_dims[1]; - - int64_t col_dispatch_mask = 0; - int64_t row_dispatch_mask = dispatch_mask_dims[0]; - if (dispatch_mask_dims.size() > 1) { - col_dispatch_mask = dispatch_mask_dims[1]; - } else { - col_dispatch_mask = 1; - } - - int dispatch_tokens_mask_len = dispatch_tokens_mask_dim[0]; - - cal_aux_loss(gate_prob_data, - row_gate_prob, - col_gate_prob, - dispatch_mask_data, - row_dispatch_mask, - col_dispatch_mask, - tokens_mask_data, - dispatch_tokens_mask_data, + cal_aux_loss(gate_prob.data(), + gate_prob_dims[0], + gate_prob_dims[1], + dispatch_mask.data(), + dispatch_mask_dims[0], + dispatch_mask_dims.size() > 1 ? dispatch_mask_dims[1] + : static_cast(1), + tokens_mask ? tokens_mask.get_ptr()->data() : nullptr, + dispatch_tokens_mask + ? 
dispatch_tokens_mask.get_ptr()->data() + : nullptr, dispatch_tokens_mask_len, num_experts, use_group, diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu index 1164e642bba93b..98a75af605cb60 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu @@ -320,7 +320,7 @@ template void MoeGradDispatchKernel(const Context &dev_ctx, const DenseTensor &x, const DenseTensor &gate_logits, - const DenseTensor &corr_bias, + const paddle::optional &corr_bias, const int64_t k, const int64_t capacity, const bool use_pad, @@ -337,7 +337,6 @@ void MoeGradDispatchKernel(const Context &dev_ctx, auto x_dims = x.dims(); auto gate_logits_dims = gate_logits.dims(); - auto corr_bias_dims = corr_bias.dims(); const int64_t num_rows = x_dims[0]; const int64_t hidden_size = x_dims[1]; diff --git a/paddle/phi/kernels/moe_gate_dispatch_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_kernel.h index a87dad8e82c925..b17d0387c7a750 100644 --- a/paddle/phi/kernels/moe_gate_dispatch_kernel.h +++ b/paddle/phi/kernels/moe_gate_dispatch_kernel.h @@ -22,7 +22,7 @@ template void MoeGradDispatchKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& gate_logits, - const DenseTensor& corr_bias, + const paddle::optional& corr_bias, const int64_t k, const int64_t capacity, const bool use_pad, From 9b09cff900c5a9f8f7a9ccb70638dab4125bd384 Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Mon, 26 May 2025 06:11:28 +0000 Subject: [PATCH 27/71] update cal_aux_loss_kernel --- paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu index 402bb88866e4d3..3eb13b55da78f8 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -44,6 +44,8 @@ __global__ void cal_aux_loss_kernel( extern __shared__ int64_t aux_loss_shared[]; static __shared__ float shared_float[1]; + float scale_val = 1.f; + // 算seqlen_float float seqlen_float_f = 0.f; if (dispatch_tokens_mask) { @@ -56,6 +58,26 @@ __global__ void cal_aux_loss_kernel( } seqlen_float_f = phi::funcs::BlockReduceSum(local_seqlen_float_f, 0xFFFFFFFF); + + // 算scale_val + if (tokens_mask && row_gate_prob != dispatch_tokens_mask_len) { + float sum_tokens_mask = 0.f; + float local_sum_tokens_mask = 0.f; + int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x; + for (int64_t k = 0; k < num_k; ++k) { + if (k * blockDim.x + threadIdx.x >= row_gate_prob) continue; + T mask = tokens_mask[k * blockDim.x + threadIdx.x]; + local_sum_tokens_mask += static_cast(mask); + } + sum_tokens_mask = + phi::funcs::BlockReduceSum(local_sum_tokens_mask, 0xFFFFFFFF); + if (threadIdx.x == 0) { + shared_float[0] = seqlen_float_f / max(sum_tokens_mask, clip_min); + } + __syncthreads(); + scale_val = shared_float[0]; + } + } else if (tokens_mask) { float local_seqlen_float_f = 0.f; int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x; @@ -107,26 +129,6 @@ __global__ void cal_aux_loss_kernel( } } - // 算scale_val - float scale_val = 1.f; - if (tokens_mask) { - float sum_tokens_mask = 0.f; - float local_sum_tokens_mask = 0.f; - int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x; - for (int64_t k = 0; k < num_k; ++k) { - if (k * blockDim.x + threadIdx.x >= row_gate_prob) continue; - T mask = tokens_mask[k * blockDim.x + 
threadIdx.x]; - local_sum_tokens_mask += static_cast(mask); - } - sum_tokens_mask = - phi::funcs::BlockReduceSum(local_sum_tokens_mask, 0xFFFFFFFF); - if (threadIdx.x == 0) { - shared_float[0] = seqlen_float_f / max(sum_tokens_mask, clip_min); - } - __syncthreads(); - scale_val = shared_float[0]; - } - // 算me和l_aux float l_aux = 0.f; int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x; From 3b3c98cf70af0d294ae8702ddc1924c737fec59f Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Mon, 26 May 2025 06:18:29 +0000 Subject: [PATCH 28/71] Finished moe_combine & expand_modality_expert_id integrate and optests. --- .../paddle/incubate/nn/functional/__init__.py | 8 +- .../functional/expand_modality_expert_id.py | 2 +- .../ernie_utils/moe_all_gather_layer.py | 264 +++++ test/legacy_test/ernie_utils/moe_layer.py | 207 ++++ .../ernie_utils/moe_layer_uneven.py | 279 +++++ test/legacy_test/ernie_utils/top2_gate.py | 1012 +++++++++++++++++ ...test_incubate_expand_modality_expert_id.py | 126 ++ test/legacy_test/test_incubate_moe_combine.py | 175 +++ 8 files changed, 2068 insertions(+), 5 deletions(-) create mode 100644 test/legacy_test/ernie_utils/moe_all_gather_layer.py create mode 100644 test/legacy_test/ernie_utils/moe_layer.py create mode 100644 test/legacy_test/ernie_utils/moe_layer_uneven.py create mode 100644 test/legacy_test/ernie_utils/top2_gate.py create mode 100644 test/legacy_test/test_incubate_expand_modality_expert_id.py create mode 100644 test/legacy_test/test_incubate_moe_combine.py diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index cf7afb504f6782..a9292038396f85 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -43,8 +43,8 @@ from .variable_length_memory_efficient_attention import ( variable_length_memory_efficient_attention, ) -# from .moe_combine import moe_combine -# from .expand_modality_expert_id import expand_modality_expert_id +from .moe_combine import moe_combine +from .expand_modality_expert_id import expand_modality_expert_id # from .moe_gate_dispatch_permute import moe_gate_dispatch_permute __all__ = [ @@ -65,6 +65,6 @@ "blha_get_max_len", "block_multihead_attention", "swiglu", - # "moe_combine", - # "expand_modality_expert_id", + "moe_combine", + "expand_modality_expert_id", ] diff --git a/python/paddle/incubate/nn/functional/expand_modality_expert_id.py b/python/paddle/incubate/nn/functional/expand_modality_expert_id.py index 086b42c83b9a4a..1d6351da47602f 100644 --- a/python/paddle/incubate/nn/functional/expand_modality_expert_id.py +++ b/python/paddle/incubate/nn/functional/expand_modality_expert_id.py @@ -41,4 +41,4 @@ def expand_modality_expert_id( 'is_group_expert': is_group_expert } helper.append_op(type='expand_modality_expert_id', inputs=inputs, attrs=attrs, outputs={'expert_id_out': expert_id_out}) - return y + return expert_id_out diff --git a/test/legacy_test/ernie_utils/moe_all_gather_layer.py b/test/legacy_test/ernie_utils/moe_all_gather_layer.py new file mode 100644 index 00000000000000..334d1b9c6a823f --- /dev/null +++ b/test/legacy_test/ernie_utils/moe_all_gather_layer.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- +# !/usr/bin/env python3 +""" +@author: kebo +@contact: kebo01@baidu.com + +@version: 1.0 +@file: moe_layer_all_gather.py +@time: 2024/09/21 15:11:10 +@Copyright (c) 2024 Baidu.com, Inc. 
All Rights Reserved + +这一行开始写关于本文件的说明与解释 + + +""" +from typing import Any, Tuple, List, Dict, Optional, Callable +import itertools +from collections import defaultdict +import logging +import contextlib +import numpy as np +import inspect + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet +from paddle import framework +import paddle.nn.functional as F +from paddle import nn +from paddle.autograd import PyLayer +from paddle.distributed.communication.group import _get_global_group +from paddle.distributed.fleet.utils import recompute +from paddle.distributed.communication.group import Group + +from .top2_gate import TopKGateFused, compute_optimal_transport +from paddle.incubate.tensor.manipulation import async_offload, async_reload + +from .moe_layer import MOELayer, fuse_logging + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 +try: + import moe_router_loss_ops +except ImportError: + moe_router_loss_ops = None + + +def profile(_): + """dumy profile""" + return contextlib.nullcontext() + + +logger = logging.getLogger(__name__) + +if False: + try: + from paddle_xpu_nn import moe_gate_dispatch as xpu_moe_gate_dispatch + except ImportError: + xpu_moe_gate_dispatch = None + logger.warning("`xpu moe dispatch` not found") +else: + try: + import moe_ops + except ImportError: + moe_ops = None + logger.warning("`moe-ops` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") + try: + import moe_ops_partial + except ImportError: + moe_ops_partial = None + logger.warning( + "`moe-ops-partial` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) + try: + import moe_ops_partial_nosoftmaxtopk + except ImportError: + moe_ops_partial_nosoftmaxtopk = None + logger.warning( + "`moe-ops-partial-nosoftmaxtopk` not found, run " + "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) + + try: + import moe_utils + except ImportError: + moe_utils = None + logger.warning("`moe_utils` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") + +class MOEAllGatherLayer(MOELayer): + """_summary_ + + Args: + MOELayer (_type_): _description_ + """ + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + shared_experts: Optional[List[nn.Layer]] = None, + dense_experts: Optional[List[nn.Layer]] = None, # no use + group: Group = None, + recompute=False, + enable_logging: bool = False, + k=2, + enable_bpr: bool = False, + all_to_all_dropout=0, + group_experts=False, + moe_statics=None, + ): + + super().__init__( + gate, + experts, + layer_idx, + shared_experts, + group, + recompute, + enable_logging, + k, + enable_bpr, + all_to_all_dropout, + group_experts, + moe_statics, + ) + + +class MOEAllGatherLayerV2(MOEAllGatherLayer): + """_summary_ + + Args: + MOELayer (_type_): _description_ + """ + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + shared_experts: Optional[List[nn.Layer]] = None, + dense_experts: Optional[List[nn.Layer]] = None, + group: Group = None, + recompute=False, + enable_logging: bool = False, + k=2, + enable_bpr: bool = False, + enable_reverse_token_drop=False, + all_to_all_dropout=0, + group_experts=False, + use_expert_out_alltoall=True, # + use_expert_alltoall_overlap=False, + use_padding=True, + dense_token_type=3, # considerd as dense tokens (no moe) + moe_statics=None, + ): + super().__init__( + gate, + experts, + 
layer_idx, + shared_experts, + dense_experts, + group, + recompute, + enable_logging, + k, + enable_bpr, + all_to_all_dropout, + group_experts, + moe_statics, + ) + self.enable_reverse_token_drop = enable_reverse_token_drop + self.is_allgather_moe_layer = True + # assert self.gate.config.sequence_parallel + world_size = self.gate.config.moe_world_size + self.use_padding = use_padding + + # global gate gather state + self.send_rank = None + self.local_expert_id = None + self.dense_token_type = dense_token_type + self.dense_experts = dense_experts + self.capacity_tensor = None + self.use_expert_out_alltoall = use_expert_out_alltoall + self.use_expert_alltoall_overlap = use_expert_alltoall_overlap + logger.info( + f"using MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, " + f"use_padding={use_padding}, use_expert_alltoall_overlap={use_expert_alltoall_overlap} " + f"enable_reverse_token_drop={self.enable_reverse_token_drop}" + ) + self.two = paddle.to_tensor(2, dtype=paddle.float32) + self.zero = paddle.to_tensor(0, dtype=paddle.float32) + + def fused_gate_logits_process_fused(self, gate_logits_lm, gate_logits_mm, token_type_ids): + """Process gate logits with the fused moe_utils ops.""" + #top_k = 1 if isinstance(self.gate, SinkHornGateFused) else self.k + top_k = self.k + num_expert_per_rank_per_modality = gate_logits_lm.shape[-1] // self.config.moe_world_size + group_size = gate_logits_lm.shape[-1] // top_k + if self.group_experts: + assert not self.use_correction_bias + gate_logits_lm = gate_logits_lm.reshape([gate_logits_lm.shape[0], top_k, -1]) + prob_lm = self.gate.act(gate_logits_lm) + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=1, axis=-1) + weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) + group_size = gate_logits_lm.shape[-1] + expert_id_lm = expert_id_lm.squeeze(-1) + else: + prob_lm = self.gate.act(gate_logits_lm) + if self.use_correction_bias: + prob_lm_ = prob_lm + self.moe_statics.e_score_correction_bias[0].detach() + else: + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, axis=-1) + + if self.use_correction_bias: + batch_idx = paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + weight_lm = prob_lm[batch_idx, expert_id_lm] # gather weights from the unbiased prob; the correction bias only affects expert selection + + # When num_expert_per_modality == 0, only the group-expert expansion is applied; the multimodal expansion is skipped. + expert_id_lm = moe_utils.expand_modality_expert_id( + expert_id_lm, + num_expert_per_modality=num_expert_per_rank_per_modality + if (token_type_ids is not None and gate_logits_mm is not None) + else 0, + group_size=group_size, + modality_offset=0, + is_group_expert=self.group_experts, + ) + expert_id_lm = expert_id_lm.reshape(weight_lm.shape) + lm_weight_and_expert_id = paddle.concat([weight_lm, expert_id_lm.astype("float32")], -1) + if token_type_ids is None or gate_logits_mm is None: + return lm_weight_and_expert_id, prob_lm.reshape([prob_lm.shape[0], -1]), None + + prob_mm = self.gate.act(gate_logits_mm) + if self.use_correction_bias: + prob_mm_ = prob_mm + self.moe_statics.e_score_correction_bias[1].detach() + else: + prob_mm_ = prob_mm + weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, axis=-1) + if self.use_correction_bias: + batch_idx = paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + weight_mm = prob_mm[batch_idx, expert_id_mm] # gather weights from the unbiased prob; the correction bias only affects expert selection + + expert_id_mm = moe_utils.expand_modality_expert_id( + expert_id_mm, + num_expert_per_modality=num_expert_per_rank_per_modality, + group_size=group_size, + modality_offset=1, + is_group_expert=False, + ) + 
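# modality_offset=1 maps these ids into each rank's multimodal expert block, keeping text and + # multimodal tokens routed to disjoint expert ranges +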
expert_id_mm = expert_id_mm.reshape(weight_mm.shape) + mm_weight_and_expert_id = paddle.concat([weight_mm, expert_id_mm.astype("float32")], -1) + weight_and_expert = paddle.where( + (token_type_ids == 0).unsqueeze(-1), + lm_weight_and_expert_id, + mm_weight_and_expert_id, + ) + return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm + + diff --git a/test/legacy_test/ernie_utils/moe_layer.py b/test/legacy_test/ernie_utils/moe_layer.py new file mode 100644 index 00000000000000..bf02a5a99c5ab4 --- /dev/null +++ b/test/legacy_test/ernie_utils/moe_layer.py @@ -0,0 +1,207 @@ +# !/usr/bin/env python3 +"""_summary_ + +Returns: + _type_: _description_ +""" +from typing import Any, Tuple, List, Optional, Callable +import logging +from collections import namedtuple +from functools import partial +import inspect +import numpy as np + +import paddle +from paddle import framework +from paddle import nn +from paddle.distributed.communication import stream +import paddle.nn.functional as F + +from paddle.autograd import PyLayer +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute +from paddle.distributed import fleet + +import paddle.distributed as dist +from paddle import Tensor + + + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 +try: + import moe_router_loss_ops +except ImportError: + moe_router_loss_ops = None +try: + from paddle.distributed import in_auto_parallel_align_mode +except: + + def in_auto_parallel_align_mode(): + """ + hack for paddlenlp develop branch. + """ + return False + + +try: + from bincount_ops import int_bincount +except ImportError: + int_bincount = None + +logger = logging.getLogger(__name__) + +try: + import moe_ops +except ImportError: + moe_ops = None + logger.warning("`moe-ops` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + ], +) + +class MOELayer(nn.Layer): + """MOELayer module which implements MixtureOfExperts as described in Gshard_. + :: + + gate = Top2Gate(model_dim, num_experts) + + moe = MOELayer(gate, expert) + output = moe(input) + l_aux = moe.l_aux + + .. 
Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + gate (paddle.nn.Layer): + gate network + expert (paddle.nn.LayerList): + expert network, LayerList 长度是 per_device 上的 expert 数。 + group (paddle.ProgressGroup) + recompute: 启用MOE内recomupte + Returns: + output + combine_weight + router-loss + """ + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + shared_experts: Optional[List[nn.Layer]] = None, + group: Group = None, + recompute=False, + enable_logging: bool = False, + k=2, + enable_bpr: bool = False, + all_to_all_dropout=0, + group_experts=False, + moe_statics=None, + ): + """ + 初始化MoE层。 + + Args: + gate (nn.Layer): 智能门控层,用于选择需要使用的专家。 + experts (List[nn.Layer]): 需要使用的专家列表。 + layer_idx (int): 当前MoE层的索引。 + group (Group): 分布式通信组。默认值为None。 + recompute (bool): 是否在每个训练迭代中重新计算MoE输出。默认值为False。 + """ + super().__init__() + self.gate = gate + self.layer_idx = layer_idx + self.recompute = recompute + logger.info(f"using moe recompute={recompute}") + for p in self.gate.parameters(): + p.is_gate = True + if isinstance(experts, nn.LayerList): + self.experts = experts + else: + logger.info(f"using fused experts, type={type(experts)}") + self.experts = experts + self.shared_experts = shared_experts + + self.group = group + self.k = k + self.all_to_all_dropout = all_to_all_dropout + self.enable_logging = enable_logging + self.use_correction_bias = moe_statics is not None + self.moe_statics = moe_statics + if self.use_correction_bias: + logger.info(f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}") + assert self.gate.config.moe_use_aux_free + + self.is_mp_moe = ( + hasattr(fleet.fleet, "_hcg") and group is fleet.get_hybrid_communicate_group().get_model_parallel_group() + ) + is_dummy_moe = dist.get_world_size(group) == 1 + + for p in experts.parameters(): + p.expert = not (self.is_mp_moe or is_dummy_moe) # type: ignore + p.no_sync = not (self.is_mp_moe or is_dummy_moe) + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe: + p.is_distributed = True + + self.world_size = dist.get_world_size(self.group) + # assert self.world_size > 1, f'moe-group not found, world_size {self.world_size}' + self.rank = dist.get_rank(self.group) + if self.world_size < 1: + self.world_size = 1 + if self.rank < 0: + self.rank = 0 + + self.num_local_experts = len(self.experts) + self.dispatch_by_task = hasattr(self.gate, "dispatch_by_task") and self.gate.dispatch_by_task + + if self.dispatch_by_task: + assert 0, f"no supported, checkout earylier code" + assert self.num_local_experts == 1 + + if enable_bpr: + logger.info(f"using BPR") + prepost_process_buffer = {} + self.input_preprocess = partial(bpr_preprocess, buffer=prepost_process_buffer) + self.output_postprocess = partial(bpr_postprocess, buffer=prepost_process_buffer) + else: + self.input_preprocess = self.output_postprocess = None + self.group_experts = group_experts + self.config = self.gate.config + self.zero = paddle.to_tensor(0, dtype=paddle.float32) + + self._rr_moe_gate_dispatch = None + self._rr_moe_combine = None + if self.config.use_recompute and self.config.skip_recompute_ops.get("moe_gate_dispatch", False): + self._rr_moe_gate_dispatch = RefinedRcomputeMoEGateDispatch() + if self.config.use_recompute and self.config.skip_recompute_ops.get("moe_combine", False): + self._rr_moe_combine = RefinedRcomputeMoECombine() + + +def fuse_logging(gate_logits, combine_weights, token_type_ids): + """fuse_logging""" + with paddle.no_grad(): + gate_expert_per_token_type_0, 
gate_expert_per_token_type_1 = None, None + gate_experts_per_token = None + ce = moe_router_loss_ops.cal_cross_entropy_info(gate_logits).mean(0) + if token_type_ids is not None: + ( + gate_expert_per_token_type_0, + gate_expert_per_token_type_1, + gate_experts_per_token, + ) = moe_router_loss_ops.cal_gate_experts_per_token_info(combine_weights, token_type_ids) + else: + gate_experts_per_token = paddle.count_nonzero(combine_weights) / (gate_logits.shape[0]) + + return gate_expert_per_token_type_0, gate_expert_per_token_type_1, gate_experts_per_token, ce + diff --git a/test/legacy_test/ernie_utils/moe_layer_uneven.py b/test/legacy_test/ernie_utils/moe_layer_uneven.py new file mode 100644 index 00000000000000..3b166205453341 --- /dev/null +++ b/test/legacy_test/ernie_utils/moe_layer_uneven.py @@ -0,0 +1,279 @@ +# !/usr/bin/env python3 +""" +moe +""" + +from ast import Import +from operator import le +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast, List +import logging +import sys +import inspect +from collections import defaultdict, namedtuple, Counter + +import numpy as np +import paddle +from paddle import nn +from paddle.distributed.communication import stream + +from paddle.autograd import PyLayer +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute +import paddle.distributed as dist +from paddle import Tensor +from paddle.nn import functional as F +from paddle.distributed import fleet + + +# from ernie_core.models.moe.moe_layer import _AllToAll +from .top2_gate import TopKGateFused + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 + +logger = logging.getLogger(__name__) + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + ], +) + + +if False: + try: + from paddle_xpu_nn import moe_combine as xpu_moe_combine + from paddle_xpu_nn import moe_combine_bwd as xpu_moe_combine_bwd + except ImportError: + xpu_moe_combine = None + xpu_moe_combine_bwd = None + logger.warning("`xpu moe combine` not found") +else: + try: + import moe_ops + except ImportError: + moe_ops = None + logger.warning("`moe-ops` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") + + try: + import moe_combine + except ImportError: + moe_combine = None + logger.warning("`moe-combine` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") + + +try: + import moe_ops_no_softmaxtopk +except ImportError: + moe_ops_no_softmaxtopk = None + logger.warning( + "moe-ops-no-softmaxtopk` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) + + +def average_grad(x, y, dy, eps=1e-12): + """ + TODO: fuse 这坨 shit + y=x/x.sum(-1, keepdim=True) 的反向过程 + """ + s, k = x.shape + xsum = x.sum(axis=-1, keepdim=True) # [s,1] + maskpos = (xsum == 0.0).expand_as(x) + + xsum_square = xsum.square() # [s,1] + left = paddle.triu(paddle.tril((1 / xsum).unsqueeze(-1).expand([s, k, k]))) # aka diag-emb [s,k,k] + right = (-x / xsum_square).unsqueeze(-1).expand([s, k, k]) + dydx = left + right + dx = paddle.matmul(dy.unsqueeze(-2).cast(dydx.dtype), dydx).squeeze(-2) # [s,1,k] @[s,k,k] -> [s,1,k] + dx = paddle.where(maskpos, paddle.zeros_like(dx), dx) + return dx + + +mask = paddle.to_tensor( + [ + [1, -1], + [-1, 1], + ] +).unsqueeze(0) + + +def average_grad_bi(x, y, dy, eps=1e-12): + """ + y=x/x.sum(-1, keepdim=True) + k=2 下面的反向过程,精度会更准一些: + dx1 = (y2 *dy1 - y2*dy2)/(y1+y2)**2 + dx2 
= (y1 *dy2 - y1*dy1)/(y1+y2)**2 + """ + s, k = x.shape + assert k == 2, k + xsum = paddle.clip(x.sum(axis=-1, keepdim=True), min=eps) # [s,1] + dydx = x.flip(axis=1).unsqueeze(-2).tile([1, 2, 1]) * mask.cast(x.dtype) / xsum.square().unsqueeze(-1) + dx = paddle.matmul(dy.unsqueeze(-2).cast(dydx.dtype), dydx).squeeze(-2) # [s,1,k] @[s,k,k] -> [s,1,k] + return dx + + +def topk_grad(x, dy, indicies): + """ + TODO: fuse 这坨 shit + y=gather(topk(x)) 的反向过程 + x: [s,e] + dy: [s,k] + """ + s, e = x.shape + _, k = dy.shape + dx = paddle.scatter_nd( + paddle.stack( + [ + paddle.arange(s).repeat_interleave(k).cast(indicies.dtype), + indicies.reshape([-1]), + ], + -1, + ), + dy.reshape([-1]), + shape=[s, e], + ) # [s,k] -> [s,e] + return dx # dx 保持高精度 + + +class GateDispatch(PyLayer): + """doc""" + + @staticmethod + def forward(ctx, x, gate_prob, k, capacity, use_pad, eps=1e-12): + """ + 对`gate_prob` 进行 softmax 并根据结果选取 topk 路由expert。 最后根据 expert 号对 `x` 进行重排。 + Args: + x: [s, d] 输入的 activateion + gate_prob: [s, e] + k: int + capacity: int #no use + Returns: + y: [s*k, d] 将所有 `x` 根据其路由的 `expert-id` 升序的排序,融合到 s 维度。 + 当截断发生时 s 会比输入 s 小。 + combine_weights: [s, k], float: 每个 token 第 k 选择的 expert 的权重。 + 当截断发生时 s 会比输入 s 小。 + scatter_index: [k, s] : 每个 token 第 k 次选择对应到 `y` 中的位置。 + expert_offset: [e]: `y`中每个 expert-id 的分割位置。 + expert_id: [s] `x` 中激活的 expert 号 + """ + ctx.k = k + ctx.eps = eps + ctx.capacity = capacity + ctx.gate_prob = gate_prob + if "corr_bias" in inspect.signature(moe_ops.moe_gate_dispatch).parameters: + compat_args = (None,) + else: + compat_args = () + y, combine_weights, scatter_index, expert_offset, expert_id = moe_ops.moe_gate_dispatch( + x, gate_prob, *compat_args, k=k, capacity=capacity, use_pad=use_pad + ) + ctx.combine_weights = combine_weights + scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] + ctx.scatter_index = scatter_index + ctx.expert_id = expert_id + num_experts = gate_prob.shape[-1] + + ctx.num_experts = num_experts + ctx.seqlen = gate_prob.shape[0] + + return y, combine_weights, scatter_index, expert_offset, expert_id + + @staticmethod + def backward(ctx, dy, dw, *_): + """ + TODO: 这坨代码可以 fuse 一手。 + 关于 softmax 对 logits 的导数,参考: + https://stats.stackexchange.com/questions/215521/ + how-to-find-derivative-of-softmax-function-for-the-purpose-of-gradient-descent/328095#328095 + """ + s, k = ctx.combine_weights.shape + grad = F.embedding(ctx.scatter_index, dy) # [s, k,d] + mask = (ctx.combine_weights > 0.0).astype(grad.dtype) # [s,k] + dx = paddle.matmul(mask.unsqueeze(1), grad).squeeze(1) # [s,1,k] @ [s,k,d] -> [s,1,d] + if ctx.gate_prob.stop_gradient: + return dx, None + + combine_weights_unnorm = ctx.combine_weights + dw = dw.astype(combine_weights_unnorm.dtype) + d_prob = topk_grad(ctx.gate_prob, dw, ctx.expert_id) + return dx, d_prob + + +class GateCombine(PyLayer): + """GateCombine""" + + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + """ + Input: + x: [seqlen * k, hidden_size] + combine_weights: [seqlen, k] + scatter_index: [seqlen, k] + Output: + y: [seqlen, hidden_size] + """ + ctx.x = x + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + if False: + assert xpu_moe_combine is not None + return xpu_moe_combine(x, combine_weights, scatter_index) + else: + assert moe_combine is not None + ret = moe_combine.moe_combine(x, combine_weights, scatter_index) + return ret + + @staticmethod + def backward(ctx, grad_y, *_): + """ + Input: + grad_y: [seqlen, hidden_size] + combine_weights: [seqlen, k] + scatter_index: 
[seqlen, k] + Output: + grad_x: [seqlen * k, hidden_size] + grad_combine_weight: [seqlen, k] + + """ + + if False: + assert xpu_moe_combine_bwd is not None + grad_x, grad_combine_weight_helper = xpu_moe_combine_bwd( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + else: + assert moe_combine is not None + grad_x, grad_combine_weight_helper = moe_combine.moe_combine_bwd( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim] + # reduce the hidden shape + # TODO: implement reduce in cuda ops + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + + + + +def combining(x, combine_weights, scatter_index, hard_gate=False): + """ + Args: + x: Tensor[seq, dim] + combine_weights: [s, k] + scatter_index: ** [k, s] ** + + Returns: + y: Tensor[s, dim] + """ + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + return x_gatherd.squeeze(-2) + ret = GateCombine.apply(x, combine_weights, scatter_index) + ret.stop_gradient = False + return ret + diff --git a/test/legacy_test/ernie_utils/top2_gate.py b/test/legacy_test/ernie_utils/top2_gate.py new file mode 100644 index 00000000000000..a9be79abb044a4 --- /dev/null +++ b/test/legacy_test/ernie_utils/top2_gate.py @@ -0,0 +1,1012 @@ +# !/usr/bin/env python3 +""" +top2gate +""" + + +from typing import Tuple +from functools import partial +import logging +import numpy as np +import math +import paddle +from paddle import Tensor +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.utils import unique_name +from paddle.nn.clip import _squared_l2_norm +from paddle.distributed import fleet + + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 +try: + import moe_router_loss_ops +except ImportError: + moe_router_loss_ops = None + +try: + from custom_setup_ops import matmul_bwd +except ImportError: + matmul_bwd = None + +try: + from bincount_ops import int_bincount +except ImportError: + int_bincount = None + +logger = logging.getLogger(__name__) + + +class CalOrthogonalLossOptEachWeightFunctor(paddle.autograd.PyLayer): + """CalOrthogonalLossOptEachWeightFunctor""" + + @staticmethod + def forward(ctx, gate_weight, moe_k, use_group, eps=1e-12): + """forward""" + if gate_weight.dtype != paddle.float32: + gate_weight = gate_weight.astype(paddle.float32) + ( + orthogonal_loss, + wnorm, + weight_scale, + normed_weight, + weight_matmul, + ) = moe_router_loss_ops.cal_orthogonal_loss_opt_each_weight(gate_weight, moe_k, use_group, eps) + ctx.save_for_backward(gate_weight, wnorm, weight_scale, normed_weight, weight_matmul) + ctx.moe_k = moe_k + ctx.use_group = use_group + ctx.eps = eps + return orthogonal_loss + + @staticmethod + def backward(ctx, out_grad): + """backward""" + gate_weight, wnorm, weight_scale, normed_weight, weight_matmul = ctx.saved_tensor() + if gate_weight.stop_gradient: + return None + moe_k = ctx.moe_k + use_group = ctx.use_group + eps = ctx.eps + return moe_router_loss_ops.cal_orthogonal_loss_opt_each_weight_grad( + out_grad, wnorm, weight_scale, normed_weight, weight_matmul, moe_k, use_group, eps + ) + + +class CalZLossFunctor(paddle.autograd.PyLayer): + """CalZLossFunctor""" + + @staticmethod + def forward(ctx, logits, loss_mask=None, clip_min=1e-6): + """forward""" + if loss_mask is not None: + assert 
loss_mask.stop_gradient + loss, max_logits, safe_sumexp, logsumexp_per_token = moe_router_loss_ops.cal_z_loss(logits, loss_mask, clip_min) + ctx.save_for_backward(logits, loss_mask, max_logits, safe_sumexp, logsumexp_per_token) + ctx.clip_min = clip_min + return loss + + @staticmethod + def backward(ctx, out_grad): + """backward""" + logits, loss_mask, max_logits, safe_sumexp, logsumexp_per_token = ctx.saved_tensor() + if logits.stop_gradient: + return None + clip_min = ctx.clip_min + return moe_router_loss_ops.cal_z_loss_grad( + out_grad, logits, loss_mask, max_logits, safe_sumexp, logsumexp_per_token, clip_min + ) + + +class CalAuxLossFunctor(paddle.autograd.PyLayer): + """CalAuxLossFunctor""" + + @staticmethod + def forward( + ctx, gate_prob, dispatch_mask, tokens_mask, dispatch_tokens_mask, num_experts, use_group, moe_k, clip_min=1e-6 + ): + """forward""" + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + loss, seqlen_float, ce = moe_router_loss_ops.cal_aux_loss( + gate_prob, dispatch_mask, tokens_mask, dispatch_tokens_mask, num_experts, use_group, moe_k, clip_min + ) + ctx.save_for_backward(gate_prob, seqlen_float, ce) + ctx.num_experts = num_experts + ctx.use_group = use_group + ctx.moe_k = moe_k + return loss + + @staticmethod + def backward(ctx, out_grad): + """backward""" + gate_prob, seqlen_float, ce = ctx.saved_tensor() + num_experts = ctx.num_experts + use_group = ctx.use_group + moe_k = ctx.moe_k + return moe_router_loss_ops.cal_aux_loss_grad( + out_grad, gate_prob, seqlen_float, ce, num_experts, use_group, moe_k + ) + + +def cal_orthogonal_loss_opt_each_weight_func(weight, moe_k, use_group, eps, xpu_matmul=None, training=True): + """cal_orthogonal_loss_opt_each_weight_func""" + weight = weight.transpose([1, 0]).contiguous() # transpose weight here + wnorm = weight.norm(axis=1) + weight = weight / paddle.maximum(wnorm, eps).unsqueeze(1) + + if use_group: + weight = weight.reshape([moe_k, -1, weight.shape[1]]) # [K, E/K, H] + eye_matrix = paddle.eye(weight.shape[1], dtype=weight.dtype).unsqueeze(0) + else: + eye_matrix = paddle.eye(weight.shape[0], dtype=weight.dtype) + + if False: + weight_matmul = xpu_matmul(weight, weight, transpose_y=True, training=training) + else: + weight_matmul = paddle.matmul(weight, weight, transpose_y=True) + + orthogonal_loss = weight_matmul - eye_matrix + orthogonal_loss = _squared_l2_norm(orthogonal_loss) / orthogonal_loss.size + return orthogonal_loss + + +def cal_z_loss_func(logits, loss_mask): + """cal_z_loss_func""" + # l_zloss = logits.exp().sum(1).log().square().mean() + if loss_mask is not None: + loss_mask = loss_mask.astype(logits.dtype) + l_zloss = (logits.logsumexp(1).square() * loss_mask).sum() / paddle.clip(loss_mask.sum(), min=1e-6) + else: + l_zloss = logits.logsumexp(1).square().mean() + # TODO group_experts 分group计算zloss + return l_zloss + + +def cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + global_aux_loss=False, + rank=None, + group=None, +): + """cal_aux_loss_func""" + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + + scale = None + if dispatch_tokens_mask is not None: + seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() + if tokens_mask is not None and gate_prob.shape[0] != dispatch_tokens_mask.shape[0]: + scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) + elif 
tokens_mask is not None: + seqlen_float = tokens_mask.sum() + else: + seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts + seqlen_float = paddle.clip(seqlen_float, min=1e-6) + + if len(dispatch_mask.shape) == 2: + dispatch_mask = dispatch_mask.sum(0) + ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float + me = paddle.sum(gate_prob, axis=0) / seqlen_float + # me = paddle.mean(gate_prob, axis=0) + # ce = paddle.mean(dispatch_mask.cast("float32"), axis=0) + if global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=group) + dist.all_gather(ce_list, ce, group=group) + me_list[rank] = me + ce_list[rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + + l_aux = paddle.sum(me * ce) * num_experts + if use_group: + l_aux = l_aux / moe_k + ''' + if scale is not None: + # 前向用局部me, 反向用全局me + l_aux = l_aux + (scale - 1) * l_aux.detach() + ''' + return l_aux + + +def masked_fill(x, mask, value): + """ + 将输入的Tensor中根据mask进行掩盖,并用value值替换。 + + Args: + x (Tensor): 输入的Tensor。 + mask (Tensor): 用于掩盖的布尔Tensor,其形状应与x相同。 + value (Union[float, int]): 需要替换的值。 + + Returns: + Tensor: 返回一个新的Tensor,其形状与x相同,并且根据mask和value进行掩盖和替换。 + + """ + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +@paddle.no_grad() +def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): + """ + Computes the optimal transport matrix and Slinkhorn distance using the + Sinkhorn-Knopp algorithm + + Inputs: + - M : cost matrix (n x m) + - r : vector of marginals (n, ) + - c : vector of marginals (m, ) + - lam : strength of the entropic regularization + - epsilon : convergence parameter + + Outputs: + - P : optimal transport matrix (n x m) + - dist : Sinkhorn distance + """ + n, _ = M.shape + # P = (- lam * M).exp() + # P /= P.sum() + P = F.softmax(-M / lam) + u = paddle.zeros(n, "float32") + # normalize this matrix + for _ in range(max_iters): + if (u - P.sum(1)).abs().max() < epsilon: + break + u = P.sum(1) + P *= (r / (u + 1e-8)).reshape((-1, 1)) + P *= (c / (P.sum(0) + 1e-8)).reshape((1, -1)) + P = paddle.where(~P.isnan(), P, paddle.zeros_like(P)) + return P, _ + + +def cast_if_needed(x, dtype): + """ + cast_if_needed + """ + return x.cast(dtype) if x.dtype != dtype else x + + +class FusedGateDetachMatmul(paddle.autograd.PyLayer): + """ + FusedGateDetachMatmul + """ + + @staticmethod + def forward(ctx, x, w): + """ + forward + """ + ctx.dtype = paddle.float32 + ctx.save_for_backward(x, w) + return F.linear(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype)) + + @staticmethod + def backward(ctx, y_grad): + """ + backward + """ + x, w = ctx.saved_tensor() + assert ctx.dtype == y_grad.dtype, "dtype not match" + x_g, w_g = matmul_bwd(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype), y_grad, False, False) + return cast_if_needed(x_g, x.dtype), cast_if_needed(w_g, w.dtype) + + +def gate_detach_matmul(x, weight, use_fuse): + """ + gate_detach_matmul + """ + if use_fuse: + return FusedGateDetachMatmul.apply(x, weight) + else: + x = cast_if_needed(x, paddle.float32) + return F.linear(x, weight) + + +class Top2Gate(nn.Layer): + """Gate module which implements Top2Gating as described in Gshard_. + :: + + gate = Top2Gate(model_dim, num_experts) + l_aux, combine_weights, dispatch_mask = gate(input) + + .. 
Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + model_dim (int): + size of model embedding dimension + num_experts (ints): + number of experts in model + """ + + def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: + """ + 初始化 MoE 层,包含参数初始化和一些其他功能。 + + Args: + layer_idx (int): 当前层的索引号。 + group: 分组名称。 + + Returns: + None: 不返回任何内容。 + """ + super().__init__() + if False: + try: + from paddle_xpu.layers.nn import xpu_matmul + + self.xpu_matmul = xpu_matmul() + except ImportError: + self.xpu_matmul = None + + self.config = config + self.fuse_gate_detach_matmul = config.fuse_gate_detach_matmul + if self.fuse_gate_detach_matmul: + assert matmul_bwd is not None, "matmul_bwd is not supported" + + self.model_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.num_experts_tensor = ( + sum(config.moe_num_experts) if config.multimodel_experts else config.moe_num_experts + ) # paddle.to_tensor(config.moe_num_experts, dtype="float32").sum() + + self.cap = config.moe_capacity + self.group = group + + self.layer_idx = layer_idx + self.global_aux_loss = config.global_aux_loss + if self.global_aux_loss: + self.rank = dist.get_rank(self.group) + + self.sinkhorn_2gate = config.sinkhorn_2gate + self.sinkhorn_temp = config.sinkhorn_temp + self.use_token_type_bias = config.moe_use_token_type_bias + self.use_correction_bias = config.moe_use_aux_free + + if config.moe_gate_act == "softmax": + self.act = partial(F.softmax, axis=-1) # [S,E] + elif config.moe_gate_act == "sigmoid": + self.act = F.sigmoid + else: + raise ValueError(f"{config.moe_gate_act} is not supported.") + self.no_jitter = True + self.expert_drop = False + self.eye_matrix = None + self.eye_matrix_size = None + self.enable_logging = config.moe_logging + self.norm_gate_logits = config.moe_norm_gate_logits + self.one = paddle.ones([], dtype="float32") + + self.moe_aux_loss_lambda = paddle.to_tensor(config.moe_aux_loss_lambda, dtype="float32") + self.moe_z_loss_lambda = paddle.to_tensor(config.moe_z_loss_lambda, dtype="float32") + self.moe_orthogonal_loss_lambda = paddle.to_tensor(config.moe_orthogonal_loss_lambda, dtype="float32") + if self.moe_aux_loss_lambda.ndim == 0: + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) + if self.moe_z_loss_lambda.ndim == 0: + self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) + if self.moe_orthogonal_loss_lambda.ndim == 0: + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze(0) + + self.experts_type_ids = None + if config.moe_orthogonal_loss_lambda: + if hasattr(fleet.fleet, "_user_defined_strategy"): + strategy = fleet.fleet._user_defined_strategy + sharding_configs = strategy.hybrid_configs["sharding_configs"] + pp_config = strategy.hybrid_configs["pp_configs"] + assert ( + not sharding_configs.comm_overlap and not pp_config.sharding_comm_overlap + ), f"orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" + + self.eps = paddle.to_tensor([1e-12], dtype="float32") + if config.multimodel_experts: + if config.moe_use_hard_gate: + self.num_experts_list = [] + self.experts_type_mask = [] + # hard-gate + group_experts 需要对gate_logits不同部分分开计算 + experts_ids = paddle.zeros([sum(self.num_experts)], dtype="int64").reshape([config.moe_world_size, -1]) + offset = 0 + for i, expert_num in enumerate(self.num_experts): + experts_ids[:, offset : offset + expert_num // config.moe_world_size] = i + offset += expert_num // config.moe_world_size + self.experts_type_ids = experts_ids.reshape([-1]) 
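+                # Worked example of the layout built above (assumed config:
+                # moe_num_experts=[4, 4], moe_world_size=2): every rank holds
+                # 2 experts of modality 0 followed by 2 of modality 1, so
+                # experts_ids is [[0, 0, 1, 1], [0, 0, 1, 1]] and
+                # experts_type_ids flattens to [0, 0, 1, 1, 0, 0, 1, 1];
+                # experts_type_mask[i] below then marks the global expert slots
+                # owned by modality i.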
+ logger.info(f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}") + for i, expert_num in enumerate(self.num_experts): + self.experts_type_mask.append( + self.experts_type_ids == i, + ) + self.num_experts_list.append(expert_num) + else: + # 非group_experts, 依赖token_type_bias实现hard-gate能力。 + assert not config.moe_group_experts, "group_experts must use hard_gate when multimodel_experts is True" + else: + self.num_experts_list = [self.num_experts] + if gate_weight is not None: + self.weight = gate_weight + assert ( + not self.config.moe_use_token_type_bias + ), "gate_weights is from outside, token_type_bias can't be used" + logger.info("moe use gate_weight from outside") + # 强制在amp下任使用fp32精度 + self._cast_to_low_precision = False # 兼容develop分支paddle + self._cast_to_low_precison = False + else: + self._create_gate_parameter() + logger.info( + f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " + f"use_token_type_bias:{self.use_token_type_bias} gate_act:{config.moe_gate_act} " + f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}" + ) + + def _create_gate_parameter(self): + """ + 创建参数权重。 + + Args: + None + + Returns: + weight (Parameter): 创建的参数权重。 + + """ + if self.config.multimodel_experts: + # support setting lambda for each expert group + self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand(len(self.num_experts)) + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand(len(self.num_experts)) + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand(len(self.num_experts)) + + for i, num_experts in enumerate(self.num_experts): + if i == 1: + with paddle.utils.unique_name.guard(f"mm_gate_{self.layer_idx}_"): + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + else: + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + p.expert_type = f"expert_type_{i}" + self.add_parameter( + "weight" if i == 0 else f"weight_{i}", # 为了对齐原 state-dict,第一个 gate-weight 不改名. + p, + ) + else: + self.weight = self.create_parameter( + shape=[self.model_dim, self.num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), # 特殊处理,有利于热启 dense-ckpt + ) + logger.info(f"moe-Gate, {self.weight}") + + if self.use_token_type_bias: + if self.config.multimodel_experts: + assert ( + not self.config.moe_use_hard_gate + ), "multimodel_experts with hard_gate is not support token_type_bias." 
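+            # Descriptive note: `self.bias` holds one row of per-expert offsets
+            # per token modality; forward() looks it up as `self.bias[token_type_ids]`
+            # and adds it to the gate logits, so the zero initialization below
+            # leaves the routing unchanged at the start of training.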
+ num_experts = sum(self.num_experts) if self.config.multimodel_experts else self.num_experts + bias_type_num = len(self.num_experts) if self.config.multimodel_experts else 1 + self.bias = self.create_parameter( + shape=[bias_type_num, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate_bias"), + initializer=paddle.nn.initializer.Assign(np.zeros([bias_type_num, num_experts])), + ), # 特殊处理,有利于热启 dense-ckpt + ) + logger.info(f"using token type bias, bias: {self.bias},") + # 强制在amp下任使用fp32精度 + self._cast_to_low_precision = False # 兼容develop分支paddle + self._cast_to_low_precison = False + + def get_gate_weight(self, transform_weight): + """ + 在`multimodel_experts` 的情况下,将多个 weights merge 成一个整体 + transform_weight: bool, 按照 local-expert id 将 多模态 weight 交叠 + """ + if not self.config.multimodel_experts: + return self.weight + if not transform_weight: + return paddle.concat( + [getattr(self, "weight" if i == 0 else f"weight_{i}") for i in range(len(self.num_experts))], -1 + ) + weight = paddle.zeros( + [ + self.model_dim, + self.config.moe_world_size, + sum(self.num_experts) // self.config.moe_world_size, + ], + dtype="float32", + ) + offset = 0 + for i, num_experts in enumerate(self.num_experts): + weight[:, :, offset : offset + num_experts // self.config.moe_world_size] = getattr( + self, "weight" if i == 0 else f"weight_{i}" + ).reshape([self.model_dim, self.config.moe_world_size, -1]) + offset += num_experts // self.config.moe_world_size + weight = weight.reshape([self.model_dim, -1]) + + return weight + + def forward( + self, + input: Tensor, + token_type_ids: Tensor = None, + transform_weight: bool = True, # [seq] + correction_bias: Tensor = None, # [seq] + ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore + """ + Args: + input: paddle.Tensor[Seq, Dim], hidden-states of layer + token_type_ids: paddle.Tensor[Seqw], token_type_ids of input + transform_weight: bool, when using multimodal experts, perform `self.get_gate_weight` if specified + Retruns: + paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights + paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask + Tuple[paddle.Tensor]: `GateOutput` + """ + num_experts = sum(self.num_experts) if self.config.multimodel_experts else self.num_experts + orig_dtype = input.dtype + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + if False: + assert not self.fuse_gate_detach_matmul, "not supported on XPU" + input_32 = input.cast("float32") + logits = self.xpu_matmul( + input_32, + weight, + training=self.training, + ) + else: + logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + + if self.use_token_type_bias: + assert token_type_ids is not None + bias = self.bias[token_type_ids] # [seq] + # logger.info(f"adding bias: {bias}") + logits = logits + bias + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + l_aux, + l_zloss, + ) = self.top2_gating(logits, correction_bias=correction_bias) + orthogonal_loss = self._cal_orthogonal_loss() + router_loss = ( + l_aux * self.moe_aux_loss_lambda + + l_zloss * self.moe_z_loss_lambda + + orthogonal_loss * self.moe_orthogonal_loss_lambda + ) + router_loss.stop_gradient = False + + combine_weights = combine_weights.cast(orig_dtype) + return capacity, dispatch_mask, combine_weights, scatter_index, router_loss, logits + + def get_capacity(self, num_tokens, cap_factor=None): + """ + return capcity + """ + num_experts = sum(self.num_experts) if self.config.multimodel_experts else 
self.num_experts + if cap_factor is not None: + cap = cap_factor + else: + if self.training: + cap = self.cap[0] + elif num_tokens < num_experts: # seqlen < num_expert + cap = self.cap[2] + else: + cap = self.cap[1] + # capacity = 2S/E + capacity = int(cap * num_tokens // num_experts) + assert capacity > 0, f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" + return capacity + + def top2_gating(self, logits, cap=None, correction_bias=None): + """ + Args: + logits: 形状为[batch, vocab_size]的logits,用于计算top2 gate。 + cap[Optional]: capacity-factor, if none, read from config + correction_bias[Optional]: used for aux-free router + + Returns: + tuple: + - capacity: 每个token可分发的最大数量。 + - dispatch_masks: 用于dispatching的mask。第一个元素是第一类token的mask;第二个元素是第二类token的mask。 + - combine_weights:用于combining的权重。第一个元素是第一类token的权重;第二个元素是第二类token的权重。 + - scatter_indexes: 用于scattering的索引。第一个元素是第一类token的索引;第二个元素是第二类token的索引。 + - loss_aux: aux loss。 + - loss_z: z loss。 + """ + # logger.info(f'gate-input: {logits}') + l_zloss = self._cal_z_loss(logits) + gates = self.act(logits) + + # gates has shape of SE + assert logits.ndim == 2, logits.shape + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + # capacity = 2S/E + capacity = self.get_capacity(logits.shape[0], cap) + + # Create a mask for 1st's expert per token + score_for_argmax = gates + correction_bias.unsqueeze(0) if correction_bias is not None else gates + indices1_s = paddle.argmax(score_for_argmax, axis=1) + mask1 = F.one_hot(indices1_s, num_classes=num_experts).cast(paddle.int64) # [0,1] + + l_aux = self._cal_aux_loss(gates, mask1.sum(axis=0), self.num_experts_tensor) + # Create a mask for 2nd's expert per token using Gumbel-max trick + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + if self.training and not self.no_jitter: + gumbels = ( + -paddle.empty_like( + logits, + ) + .exponential_() + .log() + ) # ~Gumbel(0,1) + logits_w_noise = logits + gumbels + else: + logits_w_noise = logits + + logits_except1 = masked_fill(logits_w_noise, mask1.cast(paddle.bool), float("-inf")) + score_for_argmax = ( + self.act(logits_except1) + correction_bias.unsqueeze(0) if correction_bias is not None else logits_except1 + ) + indices2_s_original = paddle.argmax(score_for_argmax, axis=1) + + if self.training and self.sinkhorn_2gate: + r = paddle.ones(num_tokens, "float32") / num_tokens + # c = paddle.ones(num_experts, "float32") / num_experts + # 非均匀c + c = capacity - mask1.cast("float32").sum(0) + c = paddle.maximum(c, paddle.zeros_like(c)) + c /= c.sum() + + pi, _ = compute_optimal_transport(-logits_except1.cast("float32").detach(), r, c, lam=self.sinkhorn_temp) + pi = masked_fill(pi, mask1.cast(paddle.bool), float("-inf")) + indices2_s = paddle.argmax(pi, axis=1) + else: + indices2_s = indices2_s_original + + + mask2 = F.one_hot(indices2_s, num_classes=self.num_experts).cast(paddle.int64) + + # Compute locations in capacity buffer + locations1 = paddle.cumsum(mask1, axis=0) - 1 # [0,1,1,0,1,0,0] -> [0,0,0,0,1,1,1,] + locations2 = paddle.cumsum(mask2, axis=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + # Remove locations outside capacity from mask + mask1 *= (locations1 < capacity).cast(paddle.int64) # [0,1,1,0,0,0,0] + mask2 *= (locations2 < capacity).cast(paddle.int64) + + # Store the capacity location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate 
probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = (gates * mask1_float).sum(axis=-1) + gates2_s = (gates * mask2_float).sum(axis=-1) + # logger.info(f'gates1_s:{gates1_s} gates2_s:{gates2_s} logits:{logits}') + + if self.norm_gate_logits: + denom_s = gates1_s + gates2_s # [0.2, 0.3] + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=1e-6) + gates1_s /= denom_s + gates2_s /= denom_s + if self.training and self.expert_drop: + # log.debug(gates2_s) + gates2_s = paddle.where( + 2 * gates2_s < paddle.rand_like(gates2_s), + paddle.zeros_like(gates2_s), + gates2_s, + ) + + # Calculate combine_weights and dispatch_mask + gates1 = gates1_s.unsqueeze(1) * mask1_float + gates2 = gates2_s.unsqueeze(1) * mask2_float + + expert1_index = paddle.argmax(gates1, -1) + combine1_weight = paddle.max(gates1, -1, keepdim=True) + scatter1_index = expert1_index * capacity + locations1_s + scatter1_index = scatter1_index.cast("int64") + dispatch1_mask = combine1_weight.cast(paddle.bool).detach() + + expert2_index = paddle.argmax(gates2, -1) + combine2_weight = paddle.max(gates2, -1, keepdim=True) + scatter2_index = expert2_index * capacity + locations2_s + scatter2_index = scatter2_index.cast("int64") + dispatch2_mask = combine2_weight.cast(paddle.bool).detach() + # logger.info(f'expert-id: {expert1_index} vs {expert2_index}, mask:{mask1_float} vs {mask2_float}') + + return ( + capacity, + paddle.concat((dispatch1_mask, dispatch2_mask), 1), + paddle.concat((combine1_weight, combine2_weight), 1), + paddle.stack((scatter1_index, scatter2_index), 1), + l_aux, + l_zloss, + ) + + def _cal_aux_loss( + self, gate_prob, dispatch_mask, num_experts=None, use_group=None, tokens_mask=None, dispatch_tokens_mask=None + ): + """ + 计算辅助损失 + + Args: + gate_prob (paddle.Tensor[local_seq, num_experts]): + dispatch_mask (paddle.Tensor[num_experts]): 每个 expert 被分配的 token 数(不考虑 token drop) + tokens_mask (paddle.Tensor[Seq]): 每个 MP 内 token-type-id + dispatch_tokens_mask (paddle.Tensor): AllGather 后的`tokens_mask` + Returns: + paddle.Tensor: 辅助损失值。 + + """ + if self.act is F.sigmoid: + gate_prob = gate_prob / gate_prob.sum(-1, keepdim=True) + + if self.use_correction_bias: + if tokens_mask is not None: + gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] + if gate_prob_this_modality.shape[0]: + _, top_idx = gate_prob_this_modality.topk(k=self.config.moe_k, axis=-1) + if int_bincount is not None: + dispatch_mask = int_bincount(top_idx, 0, gate_prob.shape[-1], paddle.int64) + else: + mask = paddle.zeros_like(gate_prob_this_modality).put_along_axis( + top_idx, paddle.to_tensor(1.0), axis=1 + ) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + else: + dispatch_mask = paddle.zeros(gate_prob.shape[-1], dtype="int64") + dist.stream.all_reduce( + dispatch_mask, + group=self.group, + use_calc_stream=True, + ) + else: + _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) + if int_bincount is not None: + dispatch_mask = int_bincount(top_idx, 0, gate_prob.shape[-1], paddle.int64) + else: + mask = paddle.zeros_like(gate_prob).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + + if num_experts is None: + num_experts = self.num_experts_tensor + if use_group is None: + use_group = self.config.moe_group_experts + + if ( + moe_router_loss_ops is not None + and (tokens_mask is None or len(tokens_mask.shape) == 1) + and (tokens_mask is None or tokens_mask.shape[0] == gate_prob.shape[0]) 
+ and (gate_prob.shape[0] >= gate_prob.shape[1]) + and (not self.global_aux_loss) + and (gate_prob.dtype == paddle.float32) + ): + return CalAuxLossFunctor.apply( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + clip_min=1e-6, + ) + else: + return cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + self.global_aux_loss, + self.rank if self.global_aux_loss else None, + self.group if self.global_aux_loss else None, + ) + + def _cal_z_loss(self, logits, loss_mask=None): + """ + 计算 Z 损失。 + + Args: + logits (torch.Tensor): Logits Tensor,形状为 [batch_size, num_classes]。 + + Returns: + torch.Tensor: Z 损失 Tensor,形状为 []。 + + """ + if ( + (moe_router_loss_ops is not None) + and (loss_mask is None or len(loss_mask.shape) == 1) + and (logits.dtype == paddle.float32) + ): + return CalZLossFunctor.apply(logits, loss_mask) + else: + return cal_z_loss_func(logits, loss_mask) + + def _cal_orthogonal_loss_opt_each_weight(self, weight, use_group): + """ + gate正交loss(优化版) + """ + if weight.dtype != paddle.float32: + weight = weight.astype(paddle.float32) + + if (moe_router_loss_ops is not None) and (weight.dtype == paddle.float32): + return CalOrthogonalLossOptEachWeightFunctor.apply(weight, self.config.moe_k, use_group) + else: + return cal_orthogonal_loss_opt_each_weight_func( + weight, self.config.moe_k, use_group, self.eps, self.xpu_matmul, self.training + ) + + def _cal_orthogonal_loss(self, weight_id=None, use_group=None): + """ + gate正交Loss + """ + if use_group is None: + use_group = self.config.moe_group_experts and self.config.moe_group_orthogonal_loss + + if weight_id is not None: + if weight_id == 0: + w_ = self.weight + else: + assert self.config.multimodel_experts + w_ = getattr(self, f"weight_{weight_id}") + return self._cal_orthogonal_loss_opt_each_weight(w_, use_group) + + orthogonal_loss = self._cal_orthogonal_loss_opt_each_weight(self.weight, use_group) + if self.config.multimodel_experts: + for i in range(1, len(self.config.moe_num_experts)): + w_ = getattr(self, f"weight_{i}") + orthogonal_loss += self._cal_orthogonal_loss_opt_each_weight(w_, use_group=False) + return orthogonal_loss + + +class TopKGateFused(Top2Gate): + """doc""" + + def forward( + self, + input: Tensor, + token_type_ids=None, + transform_weight=True, + ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore + """ + Args: + input: paddle.Tensor, hidden-states of layer + token_type_ids: paddle.Tensor[Seqw], token_type_ids of input + transform_weight: bool, when using multimodal experts, perform `self.get_gate_weight` if specified + Retruns: + paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights + paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask + Tuple[paddle.Tensor]: `GateOutput` + """ + capacity = self.get_capacity(input.shape[0]) + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + if False: + assert not self.fuse_gate_detach_matmul, "not supported on XPU" + input_32 = input.cast("float32") + logits = self.xpu_matmul( + input_32, + weight, + training=self.training, + ) + else: + logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + if self.use_token_type_bias: + assert token_type_ids is not None + assert ( + token_type_ids.max() < self.bias.shape[0] + ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" + bias = self.bias[token_type_ids] # [seq] + logits = logits + bias 
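+        # Note: unlike Top2Gate.forward, the fused path stops at raw logits.
+        # Top-k selection and token dispatch are done later by the fused
+        # dispatch op, and the router losses are accumulated by the MoE layer
+        # that consumes these logits, so only a zero placeholder is returned here.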
+ orthogonal_loss = None + # 正交 loss 拿到 moe-layer 里去计算 + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + + + return logits, capacity, router_loss + + +class DeepEPTop2Gate(TopKGateFused): + """DeepEPTop2Gate""" + + def forward(self, input, transform_weight=True, global_gate_mask=None, input_ids=None): + """forward""" + + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + + if global_gate_mask is not None: + logits = logits + global_gate_mask + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + return logits, router_loss + + def _cal_aux_loss(self, gates, dispatch_mask, input_ids=None): + """ + Calculate auxiliary loss + + Args: + gates (paddle.Tensor): Represents the output probability of each expert. + The shape is [seq_len, num_experts] + dispatch_mask: (paddle.Tensor): Represents the number of tokens for each expert. + The shape is [num_experts] + topk_indices: + Returns: + paddle.Tensor: The value of auxiliary loss. + + """ + assert len(gates.shape) == 2, "gates.shape must be [sequence_lengh, num_experts]" + if input_ids is not None: + # has_padding = (input_ids == 0).any() + assert input_ids.shape[0] == gates.shape[0], f"check input_ids shape {input_ids.shape}" + valid_mask = (input_ids != 0).astype(paddle.float32) + seqlen_float = valid_mask.sum().item() + gates = gates * valid_mask.unsqueeze(-1) + else: + seqlen_float = float(gates.shape[0]) + me = paddle.sum(gates, axis=0) / seqlen_float + ce = dispatch_mask.astype(gates.dtype).detach() / seqlen_float + + if self.global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=self.group) + dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + if seqlen_float == 0: + return paddle.to_tensor(0.0) + aux_loss = paddle.sum(me * ce) * float(self.num_experts) + return aux_loss + + def _cal_z_loss(self, logits) -> paddle.Tensor: + """ + Calculate the z loss. + + Args: + logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. + + Returns: + paddle.Tensor: The z loss value. + """ + l_zloss = paddle.logsumexp(logits, axis=1).square().mean() + return l_zloss + + def _cal_orthogonal_loss(self) -> paddle.Tensor: + """Gate weight orthogonal loss. 
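+        Computed as mean((W_n^T @ W_n - I)^2), where W_n is the gate weight with
+        each expert column L2-normalized (see the implementation below).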
+ + Returns: + Paddle.Tensor: orthogonal loss + """ + weight = F.normalize(self.weight, axis=0) + orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts))) + return orthogonal_loss diff --git a/test/legacy_test/test_incubate_expand_modality_expert_id.py b/test/legacy_test/test_incubate_expand_modality_expert_id.py new file mode 100644 index 00000000000000..699c55131b33a6 --- /dev/null +++ b/test/legacy_test/test_incubate_expand_modality_expert_id.py @@ -0,0 +1,126 @@ +import os +import unittest + +import paddle +from paddle import _C_ops +import numpy as np +from op_test import convert_float_to_uint16 +import sys +from functools import partial +from collections import namedtuple +import paddle.nn.functional as F + +from paddle.autograd import PyLayer + +from paddle import base +from paddle.base import core +import paddle.incubate.nn.functional.expand_modality_expert_id as expand_modality_expert_id +from ernie_utils.moe_all_gather_layer import MOEAllGatherLayerV2 + +def fused_gate_logits_process_ref(self, gate_logits_lm, gate_logits_mm, token_type_ids): + """process gatelogits""" + top_k = self.k + num_expert_per_rank_per_modality = gate_logits_lm.shape[-1] // self.config.moe_world_size + + @paddle.no_grad() + def shift_ids(ids, modality_offset): + # 现在认为所以模态的 expert 数都一样 + rank = ids // num_expert_per_rank_per_modality + expert_id_in_rank = ids % num_expert_per_rank_per_modality + return ( + rank * (num_expert_per_rank_per_modality * 2) + + expert_id_in_rank + + modality_offset * num_expert_per_rank_per_modality + ) + + if self.group_experts: + gate_logits_lm = gate_logits_lm.reshape([gate_logits_lm.shape[0], top_k, -1]) + prob_lm = self.gate.act(gate_logits_lm) + weight_lm, expert_id_lm = prob_lm.topk(k=1, axis=-1) + weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) + expert_id_lm = expert_id_lm.reshape([gate_logits_lm.shape[0], -1]) + group_size = gate_logits_lm.shape[-1] + scale = paddle.arange(0, top_k * group_size, group_size).unsqueeze(0) + expert_id_lm = expert_id_lm + scale + else: + prob_lm = self.gate.act(gate_logits_lm) + weight_lm, expert_id_lm = prob_lm.topk(k=top_k, axis=-1) + if token_type_ids is not None: + expert_id_lm = shift_ids(expert_id_lm, 0) + expert_id_lm.stop_gradient = True + lm_weight_and_expert_id = paddle.concat([weight_lm, expert_id_lm.astype("float32")], -1) + if token_type_ids is None: + return lm_weight_and_expert_id, prob_lm.reshape([prob_lm.shape[0], -1]), None + + prob_mm = self.gate.act(gate_logits_mm) + weight_mm, expert_id_mm = prob_mm.topk(k=top_k, axis=-1) + + expert_id_mm = shift_ids(expert_id_mm, 1) + expert_id_mm.stop_gradient = True + + mm_weight_and_expert_id = paddle.concat([weight_mm, expert_id_mm.astype("float32")], -1) + + token_type_ids_float = token_type_ids[:, None].astype("float32") + weight_and_expert = ( + 1 - token_type_ids_float + ) * lm_weight_and_expert_id + token_type_ids_float * mm_weight_and_expert_id + return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm + +def test_expand_modality_expert_id(): + def expand_id_one(expert_id, num_expert_per_modality, k, group_size, modality_offset, is_group_expert): + orig_shape = expert_id.shape + expert_id = expert_id.reshape([-1]) + xid = paddle.arange(len(expert_id)) + if is_group_expert: + eid = xid % k + expert_id += eid * group_size + + rank = expert_id // num_expert_per_modality + expert_id_in_rank = expert_id % num_expert_per_modality + ret = rank * (num_expert_per_modality * 2) + expert_id_in_rank + 
modality_offset * num_expert_per_modality + return ret.reshape(orig_shape) + + S, E, k = 100, 24, 3 + expert_id_mm = paddle.randint(0, 12, shape=[S, k]) + num_expert_per_rank_per_modality = E // 2 // 4 + group_size = E // 2 // k + print(f"num_expert_per_rank_per_modality: {num_expert_per_rank_per_modality}") + fused = expand_modality_expert_id(expert_id_mm, num_expert_per_rank_per_modality, group_size, 1, True) + + nonfused = expand_id_one(expert_id_mm, num_expert_per_rank_per_modality, k, group_size, 1, True) + # num_expert_per_rank_per_modality, group_size + assert (fused == nonfused).all().item() + + Config = namedtuple("Config", ["moe_world_size"]) + Self = namedtuple("Self", ["config", "k", "gate", "group_experts", "moe_statics", "use_correction_bias"]) + Gate = namedtuple("Gate", ["act"]) + fake_gate = Gate(act=partial(F.softmax, axis=-1)) + fake_self = Self( + config=Config( + moe_world_size=8, + ), + k=k, + gate=fake_gate, + moe_statics=None, + use_correction_bias=False, + group_experts=True, + ) + + fake_logits = paddle.randn([S, E]) + fake_logits_mm = paddle.randn([S, E]) + token_type_ids = paddle.randint(0, 2, shape=[S]) + w_and_e, prob_lm, prob_mm = MOEAllGatherLayerV2.fused_gate_logits_process_fused( + fake_self, fake_logits, fake_logits_mm, None + ) + w_and_e_ref, prob_lm_ref, prob_mm_ref = fused_gate_logits_process_ref(fake_self, fake_logits, fake_logits_mm, None) + assert (prob_lm == prob_lm_ref).all().item() + assert (w_and_e == w_and_e_ref).all().item() + w, e = w_and_e_ref.chunk(2, axis=-1) + +class Test_expand_modality_expert_id_API(unittest.TestCase): + def test_dygraph(self): + test_expand_modality_expert_id() + +if __name__ == "__main__": + + unittest.main() \ No newline at end of file diff --git a/test/legacy_test/test_incubate_moe_combine.py b/test/legacy_test/test_incubate_moe_combine.py new file mode 100644 index 00000000000000..abc9a7b95f0645 --- /dev/null +++ b/test/legacy_test/test_incubate_moe_combine.py @@ -0,0 +1,175 @@ +import os +import unittest + +import numpy as np +from op_test import convert_float_to_uint16 +import random +import paddle.nn.functional as F + +import paddle +from paddle import base +from paddle.base import core +from paddle.incubate.nn.functional import moe_combine +from ernie_utils.moe_layer_uneven import GateCombine + + +os.environ["FLAGS_flash_attn_version"] = "v1" +os.environ["FLAGS_cudnn_deterministic"] = "1" +os.environ["FLAGS_embedding_deterministic"] = "1" + + +def combining(x, combine_weights, scatter_index, hard_gate=False): + """ + Args: + x: Tensor[seq, dim] + combine_weights: [seq, k] + scatter_index: ** [seq, k] ** + + Returns: + y: Tensor[s, dim] + """ + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + if hard_gate: + return x_gatherd.squeeze(-2) + # logger.info(f'combinning: {combine_weights}') + y = (combine_weights.unsqueeze(-1) * x_gatherd).sum(1) + # y = paddle.matmul(combine_weights.unsqueeze(1), x_gatherd).squeeze() # [s,1,k] @ [s,k,dim] -> [s,1,dim] + return y + + +def baseline_result(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy): + """baseline_result""" + scatter_index = paddle.to_tensor(scatter_index_numpy) + x = paddle.to_tensor(x_numpy).cast("float32") + x.stop_gradient = False + + combine_weights = paddle.to_tensor(combine_weights_numpy).cast("float32") + combine_weights.stop_gradient = False + + scatter_index = paddle.to_tensor(scatter_index_numpy) + grad = paddle.to_tensor(grad_numpy).cast("float32") + + y = combining(x, combine_weights, scatter_index) + 
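+    # Reference semantics checked by this test (a restatement of `combining`,
+    # not new behavior):
+    #     y[s] = sum_k combine_weights[s, k] * x[scatter_index[s, k]]
+    # The fused moe_combine op is compared against this gather-and-weighted-sum,
+    # and the autograd call below provides the baseline grads for x and
+    # combine_weights.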
paddle.autograd.backward([y], [grad], True) + return [x.grad, combine_weights.grad, y] + + +def test_moe_combine(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy): + """baseline_result""" + x = paddle.to_tensor(x_numpy).cast("float32") + x.stop_gradient = False + + combine_weights = paddle.to_tensor(combine_weights_numpy).cast("float32") + combine_weights.stop_gradient = False + + scatter_index = paddle.to_tensor(scatter_index_numpy).cast("int32") + grad = paddle.to_tensor(grad_numpy).cast("float32") + + y = GateCombine.apply(x, combine_weights, scatter_index) + paddle.autograd.backward([y], [grad], True) + return [x.grad, combine_weights.grad, y] + + +def gen_test_case(S, K, Dim, capacity_factor, seed=1234): + """gen_test_case""" + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + x_numpy = np.random.rand(int(S * capacity_factor), Dim).astype(np.float32) + combine_weights_numpy = np.random.rand(S, K).astype(np.float32) + scatter_index_numpy = np.random.permutation(max(x_numpy.shape[0], S * K))[: S * K].astype("int64") + scatter_index_numpy = scatter_index_numpy.reshape([S, K]) + + combine_weights_numpy[scatter_index_numpy >= x_numpy.shape[0]] = 0 + scatter_index_numpy[scatter_index_numpy >= x_numpy.shape[0]] = 0 + grad_numpy = np.random.randn(S, Dim).astype(np.float32) + return x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy + + +def testing(test_case): + """testing""" + [bl_x_grad, bl_combine_weights_grad, bl_y] = baseline_result(*test_case) + [fused_x_grad, fused_combine_weights_grad, fused_y] = test_moe_combine(*test_case) + np.testing.assert_allclose( + fused_y.astype("float32").numpy(), bl_y.astype("float32").numpy(), err_msg="fwd precision not pass", rtol=1e-6 + ) + np.testing.assert_allclose( + fused_x_grad.astype("float32").numpy(), + bl_x_grad.astype("float32").numpy(), + rtol=1e-6, + err_msg="bwd grad precision not pass", + ) + np.testing.assert_allclose( + fused_combine_weights_grad.astype("float32").numpy(), + bl_combine_weights_grad.astype("float32").numpy(), + rtol=1e-6, + ) + +class TestFused(unittest.TestCase): + @unittest.skipIf(moe_combine is None, "test_moe_combine not installed") + def test_cap_lt_2( + self, + ): + """ + 测试精度对齐的功能 + + Args: + 无参,没有任何参数。 + + Returns: + NoneType:测试通过时返回None;测试失败时抛出异常。 + + """ + testing(gen_test_case(S=1024, K=2, Dim=4096, capacity_factor=1.8)) + + @unittest.skipIf(moe_combine is None, "test_moe_combine not installed") + def test_cap_eq_2( + self, + ): + """ + 测试精度对齐的功能 + + Args: + 无参,没有任何参数。 + + Returns: + NoneType:测试通过时返回None;测试失败时抛出异常。 + + """ + testing(gen_test_case(S=1024, K=2, Dim=4096, capacity_factor=2)) + + @unittest.skipIf(moe_combine is None, "test_moe_combine not installed") + def test_cap_gt_2( + self, + ): + """ + 测试精度对齐的功能 + + Args: + 无参,没有任何参数。 + + Returns: + NoneType:测试通过时返回None;测试失败时抛出异常。 + + """ + testing(gen_test_case(S=1024, K=2, Dim=4096, capacity_factor=2.2)) + + @unittest.skipIf(moe_combine is None, "test_moe_combine not installed") + def test_k_gt_2( + self, + ): + """ + 测试精度对齐的功能 + + Args: + 无参,没有任何参数。 + + Returns: + NoneType:测试通过时返回None;测试失败时抛出异常。 + + """ + testing(gen_test_case(S=1024, K=8, Dim=4096, capacity_factor=2)) + +if __name__ == "__main__": + + unittest.main() \ No newline at end of file From 9e84e5b6a8761a843c49af2a871f80278fbcea88 Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Mon, 26 May 2025 06:51:52 +0000 Subject: [PATCH 29/71] add python interface --- .../build_src_rank_and_local_expert_id.py | 2 +- 
.../incubate/nn/functional/cal_aux_loss.py | 2 +- .../nn/functional/moe_gate_dispatch.py | 96 +++++++++++++++++++ 3 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 python/paddle/incubate/nn/functional/moe_gate_dispatch.py diff --git a/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py b/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py index 69f0a1fca12704..25195c236b629c 100644 --- a/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py +++ b/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py @@ -64,4 +64,4 @@ def build_src_rank_and_local_expert_id( attrs=attrs, outputs=outputs, ) - return vector + return vector, local_expert_id diff --git a/python/paddle/incubate/nn/functional/cal_aux_loss.py b/python/paddle/incubate/nn/functional/cal_aux_loss.py index e759a62b77af6f..56676fa18b7d28 100644 --- a/python/paddle/incubate/nn/functional/cal_aux_loss.py +++ b/python/paddle/incubate/nn/functional/cal_aux_loss.py @@ -87,4 +87,4 @@ def cal_aux_loss( helper.append_op( type='cal_aux_loss', inputs=inputs, attrs=attrs, outputs=outputs ) - return l_aux_loss + return l_aux_loss, seqlen_float, ce diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch.py new file mode 100644 index 00000000000000..e0c9051804bfae --- /dev/null +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
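+"""Python wrapper for the fused ``moe_gate_dispatch`` custom op.
+
+A minimal usage sketch (assumes the compiled op is available; shapes follow the
+conventions used elsewhere in this patch, and passing ``None`` for ``corr_bias``
+mirrors the optional correction-bias path):
+
+    import paddle
+    from paddle.incubate.nn.functional.moe_gate_dispatch import moe_gate_dispatch
+
+    x = paddle.randn([16, 64])            # [seqlen, hidden_size]
+    gate_logits = paddle.randn([16, 8])   # [seqlen, num_experts]
+    y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch(
+        x, gate_logits, None, k=2, capacity=4, use_pad=True
+    )
+"""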
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle +from paddle import _C_ops + +# from ....framework import LayerHelper, in_dynamic_or_pir_mode +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + +if TYPE_CHECKING: + from paddle import Tensor + + +def moe_gate_dispatch( + x: Tensor, + gate_logits: Tensor, + corr_bias: Tensor, + k: int, + capacity: int, + use_pad: bool, + name: str | None = None, +) -> Tensor: + """ + Args: + x, + gate_logits, + corr_bias, + k, + capacity, + use_pad + + Returns: + y, + combine_weights, + scatter_index, + expert_offset, + expert_id + """ + if in_dynamic_or_pir_mode(): + return _C_ops.moe_gate_dispatch( + x, gate_logits, corr_bias, k, capacity, use_pad + ) + + helper = LayerHelper('moe_gate_dispatch', **locals()) + y = helper.create_variable_for_type_inference(dtype=x.dtype) + combine_weights = helper.create_variable_for_type_inference( + dtype=paddle.float32 + ) + scatter_index = helper.create_variable_for_type_inference( + dtype=paddle.int32 + ) + expert_offset = helper.create_variable_for_type_inference( + dtype=paddle.int64 + ) + expert_id = helper.create_variable_for_type_inference(dtype=paddle.int32) + + inputs = { + 'x': x, + 'gate_logits': gate_logits, + 'corr_bias': corr_bias, + } + attrs = { + 'k': k, + 'capacity': capacity, + 'use_pad': use_pad, + } + outputs = { + 'y': y, + 'combine_weights': combine_weights, + 'scatter_index': scatter_index, + 'expert_offset': expert_offset, + 'expert_id': expert_id, + } + helper.append_op( + type='moe_gate_dispatch', + inputs=inputs, + attrs=attrs, + outputs=outputs, + ) + return y, combine_weights, scatter_index, expert_offset, expert_id From 6b815883bb9691bb936bf3070feccd734933a9ca Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Mon, 26 May 2025 08:50:05 +0000 Subject: [PATCH 30/71] nosoftmax forward has finished --- paddle/phi/infermeta/ternary.cc | 110 ++++ paddle/phi/infermeta/ternary.h | 17 + paddle/phi/kernels/CMakeLists.txt | 1 + paddle/phi/kernels/gpu/moe_fuse_op.h | 344 ++++++++++++ paddle/phi/kernels/gpu/moe_kernel_impl.h | 14 +- .../moe_ops_partial_nosoftmaxtopk_kernel.cu | 523 ++++++++++++++++++ .../moe_ops_partial_nosoftmaxtopk_kernel.h | 38 ++ paddle/phi/ops/yaml/ops.yaml | 10 + .../moe_ops_partial_nosoftmaxtopk.py | 63 +++ 9 files changed, 1113 insertions(+), 7 deletions(-) create mode 100644 paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu create mode 100644 paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h create mode 100644 python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index bdd4e074bdfe87..5b8f7447eb3361 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1630,6 +1630,116 @@ void MoeCombineInferMeta(const MetaTensor& x, y->set_dtype(x.dtype()); } +void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + MetaTensor* y, + MetaTensor* combine_weights_out, + MetaTensor* scatter_index, + MetaTensor* scatter_index_rev, + MetaTensor* expert_offset, + MetaTensor* expert_nums_local){ + auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + common::errors::InvalidArgument("The dimensions of 
Input(x) must be 2, but " + "received dimensions of" + "Input(x) is [%d]", + x_dims.size())); + auto combine_weights_dims = combine_weights.dims(); + PADDLE_ENFORCE_EQ( + combine_weights_dims.size(), + 2, + common::errors::InvalidArgument("The dimensions of Input(combine_weights) must be 2, but " + "received dimensions of" + "Input(combine_weights) is [%d]", + combine_weights_dims.size())); + PADDLE_ENFORCE_EQ( + combine_weights_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The first dimensions of Input(combine_weights) must be equal to the first " + "dimension of Input(x), but received Input(combine_weights) shape is [%d]," + "Input(x) shape is [%d]", + combine_weights_dims[0], + x_dims[0])); + PADDLE_ENFORCE_GT( + expert_end_index, + expert_start_index, + common::errors::InvalidArgument( + "expert_end_index must be greater than expert_start_index, but received " + "expert_end_index = %d, expert_start_index = %d", + expert_end_index, + expert_start_index)); + PADDLE_ENFORCE_EQ( + combine_weights.dtype(), + phi::DataType::FLOAT32, + common::errors::InvalidArgument( + "The dtype of Input(combine_weights) must be FLOAT32, but received %s", + combine_weights.dtype())); + PADDLE_ENFORCE_EQ( + expert_id.dtype(), + phi::DataType::INT32, + common::errors::InvalidArgument( + "The dtype of Input(expert_id) must be INT32, but received %s", + expert_id.dtype())); + PADDLE_ENFORCE_GT( + k, + 0, + common::errors::InvalidArgument( + "k must be greater than 0, but received k = %d", + k)); + PADDLE_ENFORCE_GT( + x_dims[0], + 0, + common::errors::InvalidArgument( + "num_rows must be greater than 0, but received num_rows = %d", + x_dims[0])); + PADDLE_ENFORCE_GE( + num_experts, + k, + common::errors::InvalidArgument( + "num_experts must be greater than or equal to k, but received num_experts = %d, k = %d", + num_experts, + k)); + PADDLE_ENFORCE_EQ( + !reverse_token_drop || !use_pad, + true, + common::errors::InvalidArgument( + "use_pad must be false when reverse_token_drop is true, but received use_pad = %d, reverse_token_drop = %d", + use_pad, + reverse_token_drop)); + PADDLE_ENFORCE_EQ( + combine_weights.dtype(), + phi::DataType::FLOAT32, + common::errors::InvalidArgument( + "The dtype of Input(combine_weights) must be FLOAT32, but received %s", + combine_weights.dtype())); +int64_t num_experts_diff = expert_end_index - expert_start_index; +int64_t num_rows = x_dims[0]; +// if (use_pad) +// y->set_dims({num_experts_diff * capacity, x_dims[1]}) ; +y->set_dims({-1, x_dims[1]}); +y->set_dtype(x.dtype()); +scatter_index->set_dims({k, num_rows}); +scatter_index->set_dtype(phi::DataType::INT32); +scatter_index_rev->set_dims({num_experts*capacity}); +scatter_index_rev->set_dtype(phi::DataType::INT32); +expert_offset->set_dims({num_experts}); +expert_offset->set_dtype(phi::DataType::INT64); +expert_nums_local->set_dims({num_experts}); +expert_nums_local->set_dtype(phi::DataType::INT64); +combine_weights_out->share_meta(combine_weights); +} + void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, const MetaTensor& gate_logits, const MetaTensor& corr_bias, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 9b2c34aed451ae..ad729e7bd30afe 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -274,6 +274,23 @@ void MoeCombineInferMeta(const MetaTensor& x, const MetaTensor& scatter_index, MetaTensor* y); +void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& expert_id, 
+ int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + MetaTensor* y, + MetaTensor* combine_weights_out, + MetaTensor* scatter_index, + MetaTensor* scatter_index_rev, + MetaTensor* expert_offset, + MetaTensor* expert_nums_local); + void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, const MetaTensor& gate_logits, const MetaTensor& corr_bias, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 7eee630a8cb34a..b2f485c4d242ae 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -228,6 +228,7 @@ if(WITH_ROCM) list( REMOVE_ITEM kernel_gpu + "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" "gpu/moe_gate_dispatch_permute_kernel.cu" "gpu/expand_modality_expert_id_kernel.cu" "gpu/moe_combine_kernel.cu" diff --git a/paddle/phi/kernels/gpu/moe_fuse_op.h b/paddle/phi/kernels/gpu/moe_fuse_op.h index 08b65f6dfd58d9..43249d5fc35404 100644 --- a/paddle/phi/kernels/gpu/moe_fuse_op.h +++ b/paddle/phi/kernels/gpu/moe_fuse_op.h @@ -2,6 +2,13 @@ #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/common/exception.h" #include "paddle/phi/kernels/gpu/moe_kernel_impl.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/common/enforce.h" +#include // 包含常用的 thrust 算法 +#include +#include +#include +#include template __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, @@ -448,3 +455,340 @@ void initialize_moe_routing_permute_kernelLauncher(const T* unpermuted_input } } +// moe_ops_partial_nosoftmaxtopk utils + +template +void compute_global_expert_offset(const T* expert_id, //[len] + T* sort_buffer, //[len] + int64_t* expert_offset,//[num_experts] + const int64_t len, + const int64_t num_experts, + const int64_t capacity, + const cudaStream_t& stream, + const phi::memory_utils::ThrustAllocator& allocator){ + auto ptr = thrust::device_pointer_cast(expert_id); + auto outptr = thrust::device_pointer_cast(sort_buffer); + auto offsetptr = thrust::device_pointer_cast(expert_offset); + const auto& exec_policy = thrust::cuda::par(allocator).on(stream); + thrust::copy(exec_policy, ptr, ptr + len, outptr); + thrust::sort(exec_policy, outptr, outptr + len); + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sort_buffer, len, num_experts, expert_offset); + thrust::adjacent_difference(exec_policy, offsetptr, offsetptr + num_experts, offsetptr); + // thrust::transform(offsetptr, + // offsetptr + num_experts, + // thrust::constant_iterator(capacity), + // offsetptr, + // thrust::minimum() + // ); +} + +template +__global__ void modify_and_mask_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int num_experts, + const int expert_start_index, + const int expert_end_index + ){ + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) + return; + int ik = idx % k; + int irow = idx / k; + // const T mask = (~0) >> (8*sizeof(T)-ik); // 最后 ik 位为 1 其他位为 0 + int mask = ik; // k => 2(11) + // printf("before: idx=%d, expert-id:%d, ik=%d, s=%d, e=%d\n", idx, expert_id[idx], ik, expert_start_index, expert_end_index); + int offset = log2(k) + 1; + if (expert_id[idx] < expert_start_index || expert_id[idx] >= expert_end_index){ + expert_id_out[idx] = (num_experts << offset) ; // -1 means + }else{ + expert_id_out[idx] = 
(expert_id[idx]< +void modify_and_mask_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int num_experts, + const int expert_start_index, + const int expert_end_index, + const cudaStream_t& stream){ + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + modify_and_mask_expert_id<<>>( + expert_id, + expert_id_out, + k, + num_rows, + num_experts, + expert_start_index, + expert_end_index + ); +} + +template +void compute_local_expert_offset(const T* sorted_expert_id, //[len] + int64_t* expert_offset,//[num_experts] + int64_t* expert_num, + const int64_t len, + const int64_t num_experts, + const int64_t capacity, + const cudaStream_t& stream, + const phi::memory_utils::ThrustAllocator& allocator){ + auto offset_ptr = thrust::device_pointer_cast(expert_offset); + auto expert_num_ptr = thrust::device_pointer_cast(expert_num); + const auto& exec_policy = thrust::cuda::par(allocator).on(stream); + thrust::fill(exec_policy, offset_ptr, offset_ptr + num_experts, static_cast(0)); + + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sorted_expert_id, len, num_experts, expert_offset); + // 不考虑 capcity 影响 + thrust::adjacent_difference(exec_policy, offset_ptr, offset_ptr + num_experts, expert_num_ptr); +} + +template +__global__ void cal_expert_size_and_filter( + T* expert_id, + const int64_t* expert_offset, + int64_t len, + int64_t num_experts, + int64_t capcity, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse){ + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= len) + return; + int64_t off = reverse? 
expert_offset[expert_end_index-1] : 0; + if (reverse){ + for (int64_t i = expert_end_index - 1; i >= expert_start_index; --i){ + if (idx >= expert_offset[i]) + break; + off = expert_offset[i]; + } + }else{ + for (int64_t i = expert_start_index; i != expert_end_index; ++i){ + if (idx < expert_offset[i]) + break; + off = expert_offset[i]; + } + } + if (reverse){ + if(((off-1) - idx) >= capcity){ + expert_id[idx] = num_experts; + } + }else{ + if ((idx - off) >= capcity){ + expert_id[idx] = num_experts; + } + } +} + +template +void cal_expert_size_and_filter_launcher(T* expert_id, + const int64_t* expert_offset, + int64_t len, + int64_t num_experts, + int64_t capcity, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse, + const cudaStream_t& stream){ + if (len <= 0) + return; + const int64_t threads = std::min(static_cast(1024), len); + const int64_t blocks = (len + threads - 1) / threads; + cal_expert_size_and_filter<<>>( + expert_id, + expert_offset, + len, + num_experts, + capcity, + expert_start_index, + expert_end_index, + reverse + ); +} + +template +__global__ void build_seqsort_kv_pairs_kernel( T* seqsort_key, + T* seqsort_value, + const int* expanded_dest_row_to_expanded_source_row, + // int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, //output + const int num_rows, + const int k, + const int64_t num_active, + const int64_t capacity, + int64_t expert_start_index, + bool use_pad) +{ + const int expanded_dest_row = blockIdx.x * blockDim.x + threadIdx.x; + if (expanded_dest_row >= num_rows * k){ + return; + } + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + // printf("DEBUG %d=>%d, num_active=%lld, offset=%lld, cap=%lld \n", expanded_dest_row, expanded_source_row, num_active, row_in_expert, capacity); + // 从此以后不会发生截断,后续的 seqsort 也不会截断。 + // printf("expanded_dest_row:%d row_in_expert:%lld capacity:%lld num_active:%lld\n", expanded_dest_row, row_in_expert, capacity, num_active); + if ((use_pad && row_in_expert >= capacity) || expanded_dest_row >= num_active){ + // expanded_source_row_to_expanded_dest_row[expanded_source_row] = 0; // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; //unset combine-weight + return; + } + + // auto num_padded = use_pad ? 
(iexpert - expert_start_index) * capacity - offset : 0; + // expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row + num_padded; + + // Duplicate and permute rows + T source_row = expanded_source_row % num_rows; + + if (use_pad){ + // printf("inner print: k=%d num_row=%d before minus %d\n", k, num_rows, source_row); + seqsort_key [(iexpert - expert_start_index) * capacity + row_in_expert] = source_row; // 为保证 padding 位置(0)在最后, 所以对 pos-id 取减去其最大值 + seqsort_value[(iexpert - expert_start_index) * capacity + row_in_expert] = expanded_source_row; + }else{ + seqsort_key[expanded_dest_row] = source_row; + seqsort_value[expanded_dest_row] = expanded_source_row; + } +} + + + +template +void build_seqsort_kv_pairs_kernel_launcher(T* seqsort_key, // 实现初始化为 num-rows,保证 sort 到最后 + T* seqsort_value, + const int* expanded_dest_row_to_expanded_source_row, + // int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, //output + const int num_rows, + const int k, + const int64_t num_active, // -1 expert pos + const int64_t capacity, + const int64_t expert_start_index, + bool use_pad, + cudaStream_t stream) +{ + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + build_seqsort_kv_pairs_kernel<<>>(seqsort_key, + seqsort_value, + expanded_dest_row_to_expanded_source_row, + // expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + k, + num_active, + capacity, + expert_start_index, + use_pad + ); + +} + + +template +__global__ void copy_unpermuted_to_permuted_kernel(const T* unpermuted_input, + T* permuted_output, + const int* padded_out_to_unpermuted_input, + const int* padded_out_to_expanded_input, + int* expanded_input_to_padded_out, + const int64_t padded_len, + const int64_t num_rows, + const int64_t k, + const int64_t cols) +{ + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int padded_dest_row = blockIdx.x; + if (padded_out_to_unpermuted_input[padded_dest_row] == num_rows){ + // padded_out_to_unpermuted_input[padded_dest_row] = -1; + return; // padded place + } + const int source_row = padded_out_to_unpermuted_input[padded_dest_row]; + const int source_row_expanded = padded_out_to_expanded_input[padded_dest_row]; + if (threadIdx.x == 0){ + expanded_input_to_padded_out[source_row_expanded] = padded_dest_row; + } + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* padded_dest_row_ptr = permuted_output + padded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x* VecSize) { + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &padded_dest_row_ptr[tid]); + } + PADDLE_ENFORCE((padded_dest_row < padded_len)&&(source_row_expanded < num_rows * k), + "The index is out of bounds, " + "origin_input[%d] -> distributed_input:[%d], should < [%ld],[%ld] \n", + source_row_expanded, padded_dest_row, num_rows*k, padded_len); + + // for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + // padded_dest_row_ptr[tid] = source_row_ptr[tid]; // copy + // } +} + +template +void copy_unpermuted_to_permuted_kernelLauncher(const T* unpermuted_input, + T* permuted_output, + const int* padded_out_to_unpermuted_input, + const int* padded_out_to_expanded_input, + int* expanded_input_to_padded_out, + const int64_t padded_len, + const int64_t num_rows, //unpermuted_input_len + const int64_t k, + const int64_t 
num_cols, + cudaStream_t stream) +{ + auto blocks = padded_len; + auto threads = std::min(num_cols, static_cast(1024)); + constexpr int64_t max_pack_size = 16 / sizeof(T); + if (num_cols % max_pack_size == 0) { + copy_unpermuted_to_permuted_kernel<<>>( + unpermuted_input, + permuted_output, + padded_out_to_unpermuted_input, + padded_out_to_expanded_input, + expanded_input_to_padded_out, + padded_len, + num_rows, + k, + num_cols); + }else{ + copy_unpermuted_to_permuted_kernel<<>>( + unpermuted_input, + permuted_output, + padded_out_to_unpermuted_input, + padded_out_to_expanded_input, + expanded_input_to_padded_out, + padded_len, + num_rows, + k, + num_cols); + } +} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_kernel_impl.h b/paddle/phi/kernels/gpu/moe_kernel_impl.h index e2ea1eac0d5554..ca9834cd827690 100644 --- a/paddle/phi/kernels/gpu/moe_kernel_impl.h +++ b/paddle/phi/kernels/gpu/moe_kernel_impl.h @@ -30,19 +30,19 @@ static inline size_t AlignTo16(const size_t& input) { class CubKeyValueSorter { public: - CubKeyValueSorter(); + inline CubKeyValueSorter(); - CubKeyValueSorter(cudaStream_t stream = 0); + inline CubKeyValueSorter(cudaStream_t stream = 0); - explicit CubKeyValueSorter(const int num_experts); + inline explicit CubKeyValueSorter(const int num_experts); - void update_num_experts(const int num_experts); + inline void update_num_experts(const int num_experts); - size_t getWorkspaceSize(const size_t num_key_value_pairs, + inline size_t getWorkspaceSize(const size_t num_key_value_pairs, bool descending = false); template - void run(void* workspace, + inline void run(void* workspace, const size_t workspace_size, const KeyT* keys_in, KeyT* keys_out, @@ -154,7 +154,7 @@ void CubKeyValueSorter::run(void* workspace, } template <> -void CubKeyValueSorter::run(void* workspace, +inline void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const __nv_bfloat16* keys_in, __nv_bfloat16* keys_out, diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu new file mode 100644 index 00000000000000..a006edc8e5403f --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu @@ -0,0 +1,523 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. 
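 *
 * Rough summary (inferred from the code below, not an authoritative
 * description) of the forward dispatch implemented in this file:
 *   1. compute_global_expert_offset: cumulative token counts per expert over
 *      all experts, ignoring capacity;
 *   2. modify_and_mask_expert_id: tokens routed to experts outside
 *      [expert_start_index, expert_end_index) are tagged so they sort last;
 *   3. a CUB radix sort of expanded token ids by (masked) expert id, followed
 *      by compute_local_expert_offset for the local per-expert counts;
 *   4. without use_pad, cal_expert_size_and_filter marks tokens beyond
 *      `capacity` and a second sort pushes them to the end;
 *   5. build_seqsort_kv_pairs plus a per-expert-slot sort restore source-row
 *      order inside each expert slot and zero the combine weights of dropped
 *      tokens;
 *   6. copy_unpermuted_to_permuted copies the selected rows of x into y and
 *      records scatter_index / scatter_index_rev.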
*/ + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/moe_fuse_op.h" +#include "paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/slice_kernel.h" +#include "paddle/phi/kernels/gpu/moe_kernel_impl.h" + +namespace phi { + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// already defined need to revise! +// static inline size_t AlignTo16(const size_t &input){ +// static constexpr int ALIGNMENT = 16; +// return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +// } + +// -------- getWorkspaceSize -------- // +template +size_t getWorkspaceSize(const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int capacity, + const int k, + // const int max_seq_len, + bool use_pad, + phi::CubKeyValueSorter &sorter) +{ + + // const int buf_size = AlignTo16(k * num_rows * hidden_size); + const int interbuf_size = AlignTo16(k * num_rows * inter_size); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + const int num_dispatched_size = AlignTo16(num_experts * capacity); + int num_softmax_outs = 0; + + // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them + // in Encoder or Decoder before invoking FfnLayer forward. + size_t total_ws_bytes = 4 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += 2 * num_dispatched_size * sizeof(int); + total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ // expert_cnt + // total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + total_ws_bytes += num_softmax_outs * sizeof(KeyT); + const int bytes_for_fc1_result = interbuf_size * sizeof(KeyT); + const int sorter_ws_size_bytes = + std::max(AlignTo16(sorter.getWorkspaceSize(k * num_rows)), + AlignTo16(sorter.getWorkspaceSize(capacity))); + //sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity // 用所有 bit 做排序,会降低些许性能,但是防止越界 + int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) + { + int remaining_bytes = AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + // std::cout<<"num_softmax_outs --"<< num_softmax_outs << std::endl; + total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + // std::cout<<"buf_size --"<< buf_size<<" "< +void apply_moe_dispatch_fwd( + const Context& dev_ctx, + const DenseTensor& x, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + thrust::host_vector& expert_offset_host, + DenseTensor *y, + float *combine_weights, + int *scatter_index, + int * scatter_index_rev, + int64_t *expert_offset_global, + int64_t* expert_nums_local, + int *expert_id, + bool use_pad, + cudaStream_t stream) +{ + phi::CubKeyValueSorter sorter(stream); + // paddle::Tensor expanded_source_row_to_expanded_dest_row_tensor = + // paddle::empty({num_rows, k}, paddle::DataType::INT32, place); + // int* 
expanded_source_row_to_expanded_dest_row = + // expanded_source_row_to_expanded_dest_row_tensor.data(); + + // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, paddle::DataType::FLOAT32, place); + // float* expert_scales_float = expert_scales_tensor_float.data(); + + // paddle::Tensor expert_for_source_row_tensor = paddle::empty({num_rows, k}, paddle::DataType::INT32, place); + // int* expert_for_source_row = expert_for_source_row_tensor.data(); + // paddle::Tensor active_cnt_tensor = paddle::empty({1}, paddle::DataType::INT32, place); + + int64_t bytes = getWorkspaceSize(num_rows, + hidden_size, // hidden-size=0 + 0, // inter-size=0 + num_experts, + capacity, + k, + use_pad, + sorter); + + DenseTensor ws_ptr_tensor = phi::Empty(dev_ctx, {bytes}); + int8_t *ws_ptr = ws_ptr_tensor.data(); + + phi::memory_utils::ThrustAllocator allocator(dev_ctx.GetPlace(), dev_ctx.stream()); + + // Pointers + int *source_rows_; + int *permuted_rows_; + int *permuted_experts_; + int *expert_id_; + int *source_rows_for_seqsort_; + int *source_rows_for_seqsort_out_; + int *source_pos_for_seqsort_; + int *source_pos_for_seqsort_out_; + int64_t *expert_offset_; // local-expert-offset + + char *sorter_ws_; + // T* permuted_data_; + float *softmax_out_; + // int64_t* total_rows_before_expert_; + T *fc1_result_; + + const int sorter_ws_size_bytes = AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + const int sorter_ws_size_bytes_seqsort = AlignTo16(sorter.getWorkspaceSize(capacity)); + + const int buf_size = AlignTo16(k * num_rows * hidden_size); + // const int interbuf_size = AlignTo16(k * num_rows * 0); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + const int num_dispatched_size = AlignTo16(num_experts * capacity); + + // 4:ints [k*row] + source_rows_ = reinterpret_cast(ws_ptr); + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + expert_id_ = permuted_experts_ + num_moe_inputs; + // 4:ints: [E*C] + source_rows_for_seqsort_ = expert_id_ + num_moe_inputs; + source_rows_for_seqsort_out_ = source_rows_for_seqsort_ + num_dispatched_size; + // 1:ints: [E] + expert_offset_ = reinterpret_cast (source_rows_for_seqsort_out_ + num_dispatched_size); + // permuted_data_ = reinterpret_cast(expert_offset_ + padded_experts); + // total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + buf_size); + + // only use one number + // num_active = reinterpret_cast(permuted_experts_ + num_moe_inputs); + fc1_result_ = reinterpret_cast(expert_offset_ + padded_experts); + // fc1_result_ = reinterpret_cast(permuted_data_ + buf_size); + +#ifdef DEBUG_MOE_OP + // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits before_topk")); + // print_to_screen1(finished, 2, 16, std::string("finished before_topk")); +#endif + + thrust::transform( + thrust::cuda::par.on(stream), + thrust::device_pointer_cast(source_rows_), + thrust::device_pointer_cast(source_rows_) + num_rows * k, + thrust::counting_iterator(0), + thrust::device_pointer_cast(source_rows_), + [num_rows, k] __device__ (int i, int cnt) { + int k_idx = cnt % k; + int block_row = cnt / k; + return k_idx * num_rows + block_row; + } + ); + +#ifdef DEBUG_MOE_OP + // phi::CastKernel(ctx, expert_scales_tensor_float, expert_scales_tensor.dtype(), &expert_scales_tensor); + print_to_screen1(combine_weights, 8, 16, std::string("expert_scales_float after topk")); + print_to_screen1(expert_id, 8, 16, std::string("expert-id before permute")); + 
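  // Descriptive note: source_rows_ was filled by the thrust::transform above
  // with the "expanded source row" index: for token `row` and its ik-th expert
  // choice the stored value is ik * num_rows + row. For example, with
  // num_rows = 4 and k = 2 the buffer holds 0, 4, 1, 5, 2, 6, 3, 7. The dump
  // below shows this layout before the expert-id sort permutes it.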
print_to_screen1(source_rows_, 8, 16, std::string("desc->src idx before permute")); +#endif + + // compute global expert offset, **not** consider capacity + // 必须在 modify_and_mask_expert_id_launcher 之前算出**全局 expert-offset** + + compute_global_expert_offset(expert_id, + expert_id_, //buffer + expert_offset_global, + num_rows * k, + num_experts, + capacity, + stream, + allocator); + + // modifiy expert-id according to k + modify_and_mask_expert_id_launcher(expert_id, + expert_id_, + k, + num_rows, + static_cast(num_experts), + static_cast(expert_start_index), + static_cast(expert_end_index), + stream); + + + #ifdef DEBUG_MOE_OP + print_to_screen1(expert_id_, 8, 16, std::string("expert-id after modified 22")); +#endif + sorter.run(fc1_result_, + sorter_ws_size_bytes, + expert_id_, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + source_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + unmodify_expert_id_launcher(permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1(permuted_experts_, 8, 16, std::string("expert-id after permute")); + print_to_screen1(permuted_rows_, 8, 16, std::string("dest->src idx after permute")); +#endif + + compute_local_expert_offset( + permuted_experts_, + expert_offset_, + expert_nums_local, + num_rows * k, + num_experts, + capacity, + stream, + allocator); + + CUDACHECK(cudaMemcpyAsync(expert_offset_host.data(), + expert_offset_, + num_experts * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + +#ifdef DEBUG_MOE_OP + std::cerr << "[DEBUG] num_active v2: " << expert_offset_host.back() << std::endl; + print_to_screen1(expert_offset_global, 8, 16, std::string("expert_offset global")); + print_to_screen1(expert_offset_, 8, 16, std::string("expert_offset local")); + print_to_screen1(permuted_experts_, 8, 16, std::string("expert-id after permute")); + // print_to_screen1(permuted_experts_, 4096, 8192, std::string("expert-id after permute")); +#endif + + // calc expert-size + // 不 use-pad 的情况下,在此处标记截断位置。之后需要再 sort 一遍把截断 id 放到句尾 + if (!use_pad){ // 2sort + cal_expert_size_and_filter_launcher(permuted_experts_, + expert_offset_, + expert_offset_host.back(), + num_experts, + capacity, + expert_start_index, + expert_end_index, + reverse_token_drop, + stream); + //2sort + sorter.run(fc1_result_, + sorter_ws_size_bytes, + permuted_experts_, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + permuted_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + compute_local_expert_offset( + permuted_experts_, + expert_offset_, + expert_nums_local, + num_rows * k, + num_experts, + capacity, + stream, + allocator); + + CUDACHECK(cudaMemcpyAsync(expert_offset_host.data(), + expert_offset_, + num_experts * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + +#ifdef DEBUG_MOE_OP + std::cerr << "[DEBUG](after 2sort) num_active v2: " << expert_offset_host.back() << std::endl; + print_to_screen1(expert_id_, 8, 16, std::string(" permuted_experts")); + print_to_screen1(permuted_experts_, 8, 16, std::string(" permuted_experts")); + print_to_screen1(permuted_rows_, 8,16, std::string(" dest->src idx")); +#endif + } + + thrust::fill( + thrust::cuda::par.on(stream), + thrust::device_ptr(scatter_index_rev), + 
thrust::device_ptr(scatter_index_rev) + num_experts * capacity, + num_rows + ); + build_seqsort_kv_pairs_kernel_launcher(scatter_index_rev, //padded_to_unpermuted_input + source_rows_for_seqsort_, //seqsort-value + permuted_rows_, + // scatter_index, // 对截断位置置0 + permuted_experts_, + expert_offset_, + combine_weights, // 对截断位置置0 + static_cast(num_rows), + static_cast(k), + expert_offset_host.back(), //num_active + capacity, + expert_start_index, // expert start index + use_pad, + stream); + +#ifdef DEBUG_MOE_OP + + // print_to_screen1(scatter_index, 8, 16, std::string("scatter_index after build_seqsort_kv_pairs_kernel_launcher")); + print_to_screen1(source_rows_for_seqsort_, 8, 16, std::string("source_rows_for_seqsort_ after build_seqsort_kv_pairs_kernel_launcher")); + print_to_screen1(scatter_index_rev, 8, 16, std::string("scatter_index_rev after build_seqsort_kv_pairs_kernel_launcher")); +#endif + if (use_pad){ + for (auto iexpert = 0; iexpert != expert_end_index - expert_start_index; ++iexpert){ + sorter.run(fc1_result_, + sorter_ws_size_bytes_seqsort, + scatter_index_rev + (iexpert * capacity), // key in + scatter_index_rev + (iexpert * capacity), // key out + source_rows_for_seqsort_ + (iexpert * capacity), // value in + source_rows_for_seqsort_ + (iexpert * capacity), // value out //[num_row, k]: id在原 activation 中的位置 + capacity, // num_rows + false, + stream); + } + }else{ + auto sort_iter = thrust::make_zip_iterator(thrust::make_tuple( + thrust::device_pointer_cast(permuted_experts_), //key1 + thrust::device_pointer_cast(scatter_index_rev), //key2 + thrust::device_pointer_cast(source_rows_for_seqsort_) + )); + thrust::stable_sort( + thrust::cuda::par.on(stream), + sort_iter, + sort_iter + expert_offset_host.back(), + []__device__(auto lhs, auto rhs){ + if (thrust::get<0>(lhs) < thrust::get<0>(rhs)) + return true; + else if(thrust::get<0>(lhs) > thrust::get<0>(rhs)) + return false; + else + return thrust::get<1>(lhs) < thrust::get<1>(rhs); + } + ); + } + if (use_pad) { + int64_t num_experts_diff = expert_end_index - expert_start_index; + y->Resize({num_experts_diff * capacity, x.dims()[1]}); + dev_ctx.template Alloc(y); + } else { + y->Resize({expert_offset_host.back(), x.dims()[1]}); + dev_ctx.template Alloc(y); + } + copy_unpermuted_to_permuted_kernelLauncher(x.data(), + y->data(), //out + scatter_index_rev, //padded_out_to_unpermuted_input + source_rows_for_seqsort_, //padded_out_to_expanded_input + scatter_index, //out + use_pad? 
(expert_end_index - expert_start_index) * capacity : expert_offset_host.back(), //num_active + num_rows, + k, + hidden_size, + stream); + // cudaDeviceSynchronize(); //debug + // turn expert_offset_ptr into experts_num + return; +} + +template +void moe_dispatch_fwd(const Context& dev_ctx, + const DenseTensor& x, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + thrust::host_vector& expert_offset_host, + DenseTensor* y, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& scatter_index_rev, + const DenseTensor& expert_offset, + const DenseTensor& expert_nums_local, + const DenseTensor& expert_id, + bool use_pad){ + apply_moe_dispatch_fwd( + dev_ctx, + x, + num_rows, + num_experts, + hidden_size, + capacity, + k, + expert_start_index, + expert_end_index, + reverse_token_drop, + expert_offset_host, + y, + const_cast(combine_weights.data()), + const_cast(scatter_index.data()), + const_cast(scatter_index_rev.data()), + const_cast(expert_offset.data()), + const_cast(expert_nums_local.data()), + const_cast(expert_id.data()), + use_pad, + dev_ctx.stream()); +} + +template +void MoeGateDispatchPartialNoSoftMaxTopkKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + DenseTensor* y, + DenseTensor* combine_weights_out, + DenseTensor* scatter_index, + DenseTensor* scatter_index_rev, + DenseTensor* expert_offset, + DenseTensor* expert_nums_local){ + dev_ctx.template Alloc(scatter_index); + dev_ctx.template Alloc(scatter_index_rev); + dev_ctx.template Alloc(expert_offset); + dev_ctx.template Alloc(expert_nums_local); + phi::Copy(dev_ctx, combine_weights, dev_ctx.GetPlace(), false, combine_weights_out); + const auto &x_shape = x.dims(); + int64_t num_rows = x_shape[0]; + int64_t hidden_size = x_shape[1]; + thrust::host_vector expert_offset_host(num_experts); + int64_t num_experts_diff = expert_end_index - expert_start_index; + moe_dispatch_fwd(dev_ctx, + x, + num_rows, + num_experts, + hidden_size, + capacity, + k, + expert_start_index, + expert_end_index, + reverse_token_drop, + expert_offset_host, + y, + *combine_weights_out, + *scatter_index, + *scatter_index_rev, + *expert_offset, //global-offset + *expert_nums_local, + expert_id, + use_pad + ); + if(use_pad){ + // scatter_index_rev = scatter_index_rev.slice(0, num_experts_diff * capacity); + *scatter_index_rev = phi::Slice(dev_ctx, *scatter_index_rev, {0}, {0}, {num_experts_diff * capacity}); + }else{ + if (expert_offset_host.back() > 0){ + // y = y.slice(0, expert_offset_host.back()); + // scatter_index_rev = scatter_index_rev.slice(0, expert_offset_host.back()); + *y = phi::Slice(dev_ctx, *y, {0}, {0}, {expert_offset_host.back()}); + *scatter_index_rev = phi::Slice(dev_ctx, *scatter_index_rev, {0}, {0}, {expert_offset_host.back()}); + }else{ + *y = phi::Empty(dev_ctx, {1, x_shape[1]}); + *scatter_index_rev = phi::Empty(dev_ctx, {}); //special treatment + } + } +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_partial_nosoftmaxtopk, + GPU, + ALL_LAYOUT, + phi::MoeGateDispatchPartialNoSoftMaxTopkKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} \ No newline at end of file diff --git 
a/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h new file mode 100644 index 00000000000000..fbe517b531066c --- /dev/null +++ b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi{ +template +void MoeGateDispatchPartialNoSoftMaxTopkKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + DenseTensor* y, + DenseTensor* combine_weights_out, + DenseTensor* scatter_index, + DenseTensor* scatter_index_rev, + DenseTensor* expert_offset, + DenseTensor* expert_nums_local); + +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index e183db924f844f..e7f80f5c5d8e82 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3602,6 +3602,16 @@ data_type : x backward : moe_combine_grad +- op : moe_gate_dispatch_partial_nosoftmaxtopk + args : (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop) + output : Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local) + infer_meta : + func : MoeGateDispatchPartialNoSoftmaxTopKInferMeta + kernel : + func : moe_gate_dispatch_partial_nosoftmaxtopk + data_type : x + inplace : (combine_weights -> combine_weights_out) + - op : moe_gate_dispatch_permute args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, int64_t world_size) output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) diff --git a/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py b/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py new file mode 100644 index 00000000000000..cb06fc6a856ac1 --- /dev/null +++ b/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py @@ -0,0 +1,63 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Optional +import paddle +from paddle import _C_ops +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + +if TYPE_CHECKING: + from paddle import Tensor + +def moe_ops_partial_nosoftmaxtopk( + x: Tensor, + combine_weights: Tensor, + expert_id: Tensor, + k: int, + capacity: int, + num_experts: int, + use_pad: bool, + expert_start_index: int, + expert_end_index: int, + reverse_token_drop: bool, + name: str | None = None +) -> tuple[Tensor, Tensor, Tensor, 
Tensor, Tensor, Tensor]: + if in_dynamic_or_pir_mode(): + return _C_ops.moe_gate_dispatch_partial_nosoftmaxtopk(x, combine_weights, expert_id, k, capacity, num_experts, use_pad, expert_start_index, expert_end_index, reverse_token_drop) + helper = LayerHelper("moe_ops_partial_nosoftmaxtopk", **locals()) + y = helper.create_variable_for_type_inference(dtype=x.dtype) + combine_weights_out = helper.create_variable_for_type_inference(dtype=combine_weights.dtype) + scatter_index = helper.create_variable_for_type_inference(dtype='int32') + scatter_index_rev = helper.create_variable_for_type_inference(dtype='int32') + expert_offset = helper.create_variable_for_type_inference(dtype='int64') + expert_nums_local = helper.create_variable_for_type_inference(dtype='int64') + inputs = { + "x": x, + "combine_weights": combine_weights, + "expert_id": expert_id, + } + outputs = { + "y": y, + "combine_weights_out": combine_weights_out, + "scatter_index": scatter_index, + "scatter_index_rev": scatter_index_rev, + "expert_offset": expert_offset, + "expert_nums_local": expert_nums_local, + } + attrs = { + "k": k, + "capacity": capacity, + "num_experts": num_experts, + "use_pad": use_pad, + "expert_start_index": expert_start_index, + "expert_end_index": expert_end_index, + "reverse_token_drop": reverse_token_drop, + } + helper.append_op( + type="moe_ops_partial_nosoftmaxtopk", + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) + return (y, combine_weights_out, scatter_index, scatter_index_rev, expert_offset, expert_nums_local) + + \ No newline at end of file From 829bad3028697aae2f41651adefbe39e62bb738d Mon Sep 17 00:00:00 2001 From: zhenghuaijin Date: Mon, 26 May 2025 19:38:22 +0800 Subject: [PATCH 31/71] finishi fused_rms_norm fwd --- paddle/phi/infermeta/backward.cc | 14 + paddle/phi/infermeta/backward.h | 9 + paddle/phi/infermeta/binary.cc | 14 + paddle/phi/infermeta/binary.h | 6 + paddle/phi/infermeta/multiary.cc | 1 + paddle/phi/infermeta/multiary.h | 2 + paddle/phi/kernels/CMakeLists.txt | 3 +- .../phi/kernels/gpu/layer_norm_cuda_kernel.cu | 108 ++ paddle/phi/kernels/layer_norm_cuda_kernel.h | 1084 +++++++++++++++++ paddle/phi/ops/yaml/backward.yaml | 10 + paddle/phi/ops/yaml/ops.yaml | 12 +- .../paddle/incubate/nn/functional/__init__.py | 2 + .../nn/functional/fused_rms_norm_ext.py | 33 + 13 files changed, 1296 insertions(+), 2 deletions(-) create mode 100644 paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu create mode 100644 paddle/phi/kernels/layer_norm_cuda_kernel.h create mode 100644 python/paddle/incubate/nn/functional/fused_rms_norm_ext.py diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a7d368ea869b22..9883a070ad3d20 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1887,4 +1887,18 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, value_grad->share_lod(values); } } + +void FusedRMSNormGradInferMeta(const MetaTensor &x, + const MetaTensor &scale, + const MetaTensor &invvar, + const MetaTensor &dy, + float epsilon, + MetaTensor* grad_x, + MetaTensor* grad_scale){ + + grad_x->set_dims(x.dims()); + grad_x->set_dtype(x.dtype()); + grad_scale->set_dims(scale.dims()); + grad_scale->set_dtype(scale.dtype()); +} } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index bca0c6f53906f9..bbcfa711d5cf79 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -680,4 +680,13 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, MetaTensor* 
x_grad, MetaTensor* value_grad); +void FusedRMSNormGradInferMeta(const MetaTensor &x, + const MetaTensor &scale, + const MetaTensor &invvar, + const MetaTensor &dy, + float epsilon, + MetaTensor* grad_x, + MetaTensor* grad_scale); + + } // namespace phi diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index a23fa98a79af7f..7b75627d338ac5 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -4592,6 +4592,20 @@ void WeightDequantizeInferMeta(const MetaTensor& x, out->set_dtype(scale.dtype()); } +void FusedRMSNormInferMeta(const MetaTensor &x, + const MetaTensor &scale, + float epsilon, + MetaTensor* y, + MetaTensor* invvar){ + // Y: same shape, dtype, layout as X + y->set_dims(x.dims()); + y->set_dtype(x.dtype()); + // mean & invvar: 1-D length = x.dims()[0] + int64_t rows = x.dims()[0]; + invvar->set_dims(DDim({rows})); + invvar->set_dtype(DataType::FLOAT32); +} + } // namespace phi PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 0b4d20862f4773..799fc1267e83c6 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -790,5 +790,11 @@ void WeightDequantizeInferMeta(const MetaTensor& x, const std::string& algo, const int32_t group_size, MetaTensor* out); +void FusedRMSNormInferMeta(const MetaTensor &x, + const MetaTensor &scale, + float epsilon, + MetaTensor* y, + MetaTensor* invvar); + } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 1ef5cd2679006a..5b431cb829f58d 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -6273,5 +6273,6 @@ void TopPSamplingInferMeta(const MetaTensor& x, } } + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index dfe1af6754aa9d..441b1156b42418 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -1284,4 +1284,6 @@ void TopPSamplingInferMeta(const MetaTensor& x, MetaTensor* topk_scores, MetaTensor* topk_ids); + + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 1c1f7ce4fc61fd..c521e865a077e8 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -237,7 +237,8 @@ if(WITH_ROCM) "gpu/matrix_rank_tol_kernel.cu" "gpu/svd_kernel.cu" "gpu/cuda_gemm_kernel.cu" - "gpu/int_bincount.cu") + "gpu/int_bincount.cu" + "gpu/layer_norm_cuda_kernel.cu") endif() # Remove AP kernel when CINN is not enabled. diff --git a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu new file mode 100644 index 00000000000000..a311c77e42f15e --- /dev/null +++ b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu @@ -0,0 +1,108 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include "paddle/common/exception.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/empty_kernel.h" // NOLINT + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/layer_norm_cuda_kernel.h" // NOLINT + +namespace phi { +// #define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor") + +static void GetRowsCols(const std::vector &shape, + int *p_rows, + int *p_cols) { + int rows = 1; + for (int i = 0; i + 1 < shape.size(); ++i) { + rows *= shape[i]; + } + int cols = shape[shape.size() - 1]; + *p_rows = rows; + *p_cols = cols; +} + +template +void RMSLnFwd(const Context& ctx, + const DenseTensor &x, + const DenseTensor &scale, + float epsilon, + DenseTensor* y, + DenseTensor* invvar) { + const auto &scale_shape = scale.dims(); + const auto &x_shape = x.dims(); + PD_CHECK(scale_shape.size() == 1); + PD_CHECK(scale_shape[0] == x_shape[x_shape.size() - 1]); + + int rows, cols; + rows = x_shape[0]; + cols = x_shape[1]; + // GetRowsCols(x_shape, &rows, &cols); + + *y = phi::EmptyLike(ctx, x); + *invvar = phi::Empty(ctx, {rows}); + + cuda_rms_norm(ctx, x, scale, rows, cols, epsilon, y, invvar); +} + +template +void RMSLnBwd(const Context& ctx, + const DenseTensor &x, + const DenseTensor &scale, + const DenseTensor &invvar, + const DenseTensor &dy, + float epsilon, + DenseTensor* grad_x, + DenseTensor* grad_scale) { + int rows, cols; + const auto &x_shape = x.dims(); + rows = x_shape[0]; + cols = x_shape[1]; + *grad_x = phi::EmptyLike(ctx, x); + *grad_scale = phi::EmptyLike(ctx, scale); + + cuda_rms_norm_gradient( + ctx, + x, + scale, + invvar, + dy, + rows, + cols, + epsilon, + grad_x, + grad_scale + ); +} + +} // namespace phi + + +PD_REGISTER_KERNEL(fused_rms_norm, + GPU, + ALL_LAYOUT, + phi::RMSLnFwd, + float, + double) {} + +PD_REGISTER_KERNEL(fused_rms_norm_grad, + GPU, + ALL_LAYOUT, + phi::RMSLnBwd, + float, + double) {} diff --git a/paddle/phi/kernels/layer_norm_cuda_kernel.h b/paddle/phi/kernels/layer_norm_cuda_kernel.h new file mode 100644 index 00000000000000..b54d44e2eb0825 --- /dev/null +++ b/paddle/phi/kernels/layer_norm_cuda_kernel.h @@ -0,0 +1,1084 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" +#include "paddle/common/exception.h" +#include "paddle/phi/core/dense_tensor.h" + +#include // NOLINT +#include // NOLINT + +namespace phi{ +#define DEFAULT_THROW(NAME, TYPE) \ + default: \ + do { \ + PD_THROW(#NAME, " not implemented for '", TYPE, "'"); \ + } while (0); \ + break + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ + switch (TYPEIN) { \ + case float: { \ + using scalar_t_in = float; \ + switch (TYPEOUT) { \ + case float: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + DEFAULT_THROW(NAME, TYPEOUT); \ + } \ + break; \ + } \ + DEFAULT_THROW(NAME, TYPEIN); \ + } + +#define WARP_SIZE 32 + +template +__device__ __forceinline__ T WARP_SHFL_XOR(T value, + int laneMask, + int width = WARP_SIZE, + unsigned int mask = 0xffffffff) { + return __shfl_xor_sync(mask, value, laneMask, width); +} + +template +__device__ __forceinline__ T WARP_SHFL(T value, + int srcLane, + int width = WARP_SIZE, + unsigned int mask = 0xffffffff) { + return __shfl_sync(mask, value, srcLane, width); +} + +template +__device__ void cuWelfordOnlineSum(const U curr, + U& mu, // NOLINT + U& sigma2, // NOLINT + U& count) { // NOLINT + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, + const U sigma2B, + const U countB, + U& mu, // NOLINT + U& sigma2, // NOLINT + U& count) { // NOLINT + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ +void cuRMSOnlineSum( + const U curr, + U& sigma2) +{ + sigma2 = sigma2 + curr * curr; +} + +template __device__ +void cuChanRMSOnlineSum( + const U sigma2B, + U& sigma2) +{ + sigma2 = sigma2 + sigma2B; +} + + +template +__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, // NOLINT + U& sigma2, // NOLINT + U* buf, bool rms_only) { + // Assumptions: + // 1) blockDim.x == WARP_SIZE + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
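  //
  // Note on the reduction below: each thread accumulates mean/variance with
  // Welford's online update (cuWelfordOnlineSum); partial results are merged
  // across lanes and warps with Chan's parallel combination (cuChanOnlineSum):
  //   mu_AB     = (nA * mu_A + nB * mu_B) / (nA + nB)
  //   sigma2_AB = sigma2_A + sigma2_B + delta^2 * nA * nB / (nA + nB),
  //     where delta = mu_B - mu_A and sigma2 holds the sum of squared
  //     deviations until it is divided by n2 at the end.
  // When rms_only is true, only the running sum of squares is kept
  // (cuRMSOnlineSum / cuChanRMSOnlineSum) and mu is not used.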
+ // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + if (!rms_only) { + cuWelfordOnlineSum(curr,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U sigma2B = WARP_SHFL(sigma2, srcLaneB); + if (!rms_only) { + U muB = WARP_SHFL(mu, srcLaneB); + U countB = WARP_SHFL(count, srcLaneB); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; // NOLINT + U* ibuf = (U*)(ubuf + blockDim.y); // NOLINT + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + if (!rms_only) { + ubuf[2*wrt_y] = mu; + ibuf[wrt_y] = count; + } + ubuf[2 * wrt_y + 1] = sigma2; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U sigma2B = ubuf[2 * threadIdx.y + 1]; + if (!rms_only) { + U muB = ubuf[2*threadIdx.y]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } else { + cuChanRMSOnlineSum(sigma2B,sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { + mu = WARP_SHFL(mu, 0); + } + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, // NOLINT + float& sigma2, // NOLINT + float* buf, bool rms_only) { + // Assumptions: + // 1) blockDim.x == WARP_SIZE + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
+ // + // compute variance and mean over n2 + float count = 0.0f; + mu = float(0); // NOLINT + sigma2 = float(0); // NOLINT + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const auto* lvals = vals + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { // NOLINT + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + if (!rms_only) { + cuWelfordOnlineSum(curr,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + float2 curr = __half22float2(*((__half2*)(lvals + l + k))); // NOLINT + if (!rms_only) { + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr.x, sigma2); + cuRMSOnlineSum(curr.y, sigma2); + } + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr,mu,sigma2,count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float sigma2B = WARP_SHFL(sigma2, srcLaneB); + if (!rms_only) { + float muB = WARP_SHFL(mu, srcLaneB); + float countB = WARP_SHFL(count, srcLaneB); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* ubuf = (float*)buf; // NOLINT + float* ibuf = (float*)(ubuf + blockDim.y); // NOLINT + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y + 1] = sigma2; + if (!rms_only) { + ubuf[2*wrt_y] = mu; + ibuf[wrt_y] = count; + } + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float sigma2B = ubuf[2 * threadIdx.y + 1]; + if (!rms_only) { + float muB = ubuf[2*threadIdx.y]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1] / float(n2); // NOLINT + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { + mu = WARP_SHFL(mu, 0); + } + sigma2 = WARP_SHFL(sigma2 / float(n2), 0); // NOLINT + } + } +} + +template +__inline__ __device__ U rsqrt(U v) { + return U(1) / sqrt(v); +} +template <> +__inline__ __device__ float rsqrt(float v) { + return rsqrtf(v); +} +template <> +__inline__ __device__ double rsqrt(double v) { + return rsqrt(v); +} + +namespace { // NOLINT +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. 
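+// Only a float specialization is provided in this file, so the kernels below
+// are expected to accumulate with U = float; any other accumulation type would
+// need its own SharedMemory specialization.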
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ float* getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} // namespace + +template +__device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, bool rms_only) { + // Assumptions: + // 1) blockDim.x == WARP_SIZE + // 2) Tensors are contiguous + // + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only); + const T* lvals = vals + i1 * n2; + V* ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && (beta != NULL || rms_only) ) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = static_cast(static_cast(gamma[i]) * c_invvar * (curr - mu) + static_cast(beta[i])); + } else { + ovals[i] = static_cast(static_cast(gamma[i]) * c_invvar * curr); + } + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = static_cast(c_invvar * (curr - mu)); + } else { + ovals[i] = static_cast(c_invvar * curr); + } + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + mean[i1] = mu; + } + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta + ) +{ + cuApplyLayerNorm_(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false); +} + + +template __global__ +void cuApplyRMSNorm( + V* __restrict__ output_vals, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma) +{ + cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); +} + +template +__device__ void cuLoadWriteStridedInputs(const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, bool rms_only) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + 
warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar; + } + } else { + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } +} + +template +__device__ void cuLoadAddStridedInputs(const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, bool rms_only) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar; + } + } + } + } +} + +template +__global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta, bool rms_only) { + const int numsegs_n1 = + (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = + (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1; + const int row_stride = blockDim.x + 1; + const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); + const int thr_load_row_off = + (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * + // blockDim.y + (blockDim.y - + // 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; // NOLINT + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg, + thr_load_row_off, + thr_load_col_off, + i2_off, + row_stride, + warp_buf1, + warp_buf2, + input, + dout, + i1_end, + n2, + mean, + invvar, rms_only); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; + i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs(i1_block, + thr_load_row_off, + thr_load_col_off, + i2_off, + row_stride, + warp_buf1, + warp_buf2, + input, + dout, + i1_end, + n2, + mean, + invvar,rms_only); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k * blockDim.y; + int idx1 = row1 * row_stride + threadIdx.x; + if (!rms_only) { + acc1 += warp_buf1[idx1]; + } + acc2 += warp_buf2[idx1]; + } + + if (!rms_only) { + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + } + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + if (!rms_only) { + warp_buf1[idx1] += warp_buf1[idx2]; + } + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + if (!rms_only) { + part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + } + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta, bool rms_only) { + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + if (!rms_only) { + sum_beta += part_grad_beta_ptr[warp_offset*n2]; + } + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) 
{ + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + if (!rms_only) { + buf[write_idx+nbsize3] = sum_beta; + } + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + if (!rms_only) { + sum_beta += buf[read_idx+nbsize3]; + } + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + if (!rms_only) { + grad_beta[i2] = sum_beta; + } + } + } +} + +template +__global__ void cuComputeGradInput(const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input, bool rms_only) { + + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + U c_mean; + if (!rms_only) { + c_mean = mean[i1]; + } + const U c_invvar = invvar[i1]; + const T* k_input = input + i1 * n2; + const V* k_dout = dout + i1 * n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + const U gamma_tmp = static_cast(gamma[l + k]); + if (!rms_only) { + sum_loss1 += c_loss * gamma_tmp; + sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * gamma_tmp * (c_h) * c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + const U gamma_tmp = static_cast(gamma[l]); + if (!rms_only) { + sum_loss1 += c_loss * gamma_tmp; + sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * gamma_tmp * (c_h) * c_invvar; + } + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } + } + } + // intra-warp reductions + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + if (!rms_only) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + } + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + if (!rms_only) { + buf[2*wrt_i] = sum_loss1; + } + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + if (!rms_only) { + sum_loss1 += buf[2*read_i]; + } + sum_loss2 += 
buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + if (!rms_only) { + buf[2*threadIdx.x] = sum_loss1; + } + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + if (!rms_only) { + sum_loss1 = buf[2*threadIdx.x]; + } + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * static_cast(gamma[l]); + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h) * c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h) * c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + +static cudaDeviceProp GetDevicePropImpl() { + int device = -1; + PD_CHECK(cudaGetDevice(&device) == cudaSuccess); + cudaDeviceProp prop; + PD_CHECK(cudaGetDeviceProperties(&prop, device) == cudaSuccess); + return prop; +} + +static cudaDeviceProp* GetDeviceProp() { + static auto prop = GetDevicePropImpl(); + return ∝ +} + +template +void HostApplyLayerNorm(V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta, + cudaStream_t stream) { + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyLayerNorm<<>>( + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); +} + +template +void HostApplyRMSNorm( + V* output, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, cudaStream_t stream) +{ + // auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32,4,1); + // const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? 
+ threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyRMSNorm<<>>( + output, invvar, input, n1, n2, U(epsilon), gamma); +} + +// template +// void cuda_layer_norm(const Context& ctx, +// const DenseTensor& x, +// const DenseTensor& scale, +// const DenseTensor& bias, +// int rows, +// int cols, +// float epsilon, +// DenseTensor* y, +// DenseTensor* mean, +// DenseTensor* invvar) { +// DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( +// x.dtype(), +// y->dtype(), +// "cuda_layer_norm_kernel", +// HostApplyLayerNorm(y->data(), +// mean->data(), +// invvar->data(), +// const_cast(x.data()), +// rows, +// cols, +// epsilon, +// const_cast(scale.data()), +// const_cast(bias.data()), +// ctx.stream())); +// } + +template +void cuda_rms_norm(const Context& ctx, + const DenseTensor& x, + const DenseTensor& scale, + int rows, + int cols, + float epsilon, + DenseTensor* y, + DenseTensor* invvar) { + HostApplyRMSNorm(y->data(), + invvar->data(), + const_cast(x.data()), + rows, + cols, + epsilon, + const_cast(scale.data()), + ctx.stream()); +} + +template +void HostRMSNormGradient( const Context& ctx, + const V* dout, + const U* invvar, + const DenseTensor& input, + int n1, + int n2, + const V* gamma, + double epsilon, + T* grad_input, + V* grad_gamma, + cudaStream_t stream) { + if (gamma != NULL) { + const int part_size = 16; + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = + 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + auto place = input.place(); + DenseTensor part_grad_gamma = phi::Empty(ctx, {part_size, n2}); + + cuComputePartGradGammaBeta<<>>( + dout, + input.data(), + n1, + n2, + invvar, // unused + invvar, + U(epsilon), + part_grad_gamma.data(), + part_grad_gamma.data(), /* unused */ + true + ); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.data(), + part_grad_gamma.data(), /* unused */ + part_size, + n1, + n2, + grad_gamma, + grad_gamma, /* unused */ + true); + } + + // compute grad_input + const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? 
threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<>>(dout, + input.data(), + n1, + n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + grad_input, + true); +} + +template +void cuda_rms_norm_gradient(const Context& ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& invvar, + const DenseTensor& dy, + int rows, + int cols, + float epsilon, + DenseTensor* grad_x, + DenseTensor* grad_scale) { + HostRMSNormGradient( + ctx, + dy.data(), + invvar.data(), + x, + rows, + cols, + scale.data(), + epsilon, + grad_x->data(), + grad_scale->data(), + ctx.stream()); +} + +} \ No newline at end of file diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 52050ccd2bd805..392fe60ca30eda 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -3926,3 +3926,13 @@ param : [condition] composite: where_double_grad(condition, grad_x_grad, grad_y_grad, grad_out_grad) optional: grad_x_grad, grad_y_grad + +# - backward_op: fused_rms_norm_grad +# forward: fused_rms_norm (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) +# args: (Tensor x, Tensor scale, Tensor y, Tensor invvar, float epsilon) +# output: Tensor(grad_x), Tensor(grad_scale) +# infer_meta: +# func: FusedRMSNormGradInferMeta +# kernel: +# func: fused_rms_norm_grad +# data_type: x \ No newline at end of file diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 98c89cca6b678b..d3c4368b21b6a7 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -5682,4 +5682,14 @@ func: IntBincountInferMeta kernel: func: int_bincount - data_type: x \ No newline at end of file + data_type: x + +- op: fused_rms_norm + args: (Tensor x, Tensor scale, float epsilon) + output: Tensor(y), Tensor(invvar) + infer_meta: + func: FusedRMSNormInferMeta + kernel: + func: fused_rms_norm + data_type: x + # backward: fused_rms_norm_grad \ No newline at end of file diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index 7ae9a98964a6f2..e26ec823a6af10 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -44,6 +44,7 @@ variable_length_memory_efficient_attention, ) from .int_bincount import int_bincount +from .fused_rms_norm_ext import fused_rms_norm_ext __all__ = [ 'fused_multi_head_attention', @@ -64,4 +65,5 @@ "block_multihead_attention", "swiglu", "int_bincount", + "fused_rms_norm_ext", ] diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py new file mode 100644 index 00000000000000..35fa675d94bb05 --- /dev/null +++ b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py @@ -0,0 +1,33 @@ +# File: python/paddle/incubate/nn/functional/layer_norm_cuda.py +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.data_feeder import convert_dtype +import paddle + +def fused_rms_norm_ext(x, scale, bias=None, epsilon=1e-5, name=None): + """ + Applies Layer Normalization over the last dimension of the input tensor using CUDA implementation. + Args: + x (Tensor): Input tensor of shape [rows, cols] or higher dimensions (flattened to 2D). + scale (Tensor): Scale tensor of shape [cols]. + bias (Tensor, optional): Bias tensor of shape [cols]. If None, no bias is added. + epsilon (float): Small constant to avoid division by zero. + name (str, optional): Name of the operator. 
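+
+    Note:
+        Only ``x`` and ``scale`` are passed to the underlying op (``bias`` is
+        currently unused), and the call returns ``(y, invvar)``.
+
+    Example (illustrative shapes; assumes the ``fused_rms_norm`` op is registered):
+        .. code-block:: python
+
+            # x: [rows, cols], scale: [cols]
+            x = paddle.randn([4, 128], dtype='float32')
+            scale = paddle.ones([128], dtype='float32')
+            y, invvar = fused_rms_norm_ext(x, scale, epsilon=1e-5)
+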
+ Returns: + y (Tensor): Normalized tensor of same shape as x. + mean (Tensor): Tensor of shape [rows], the mean of each row. + invvar (Tensor): Tensor of shape [rows], the inverse standard deviation of each row. + """ + helper = LayerHelper('fused_rms_norm', **locals()) + dtype = convert_dtype(x.dtype) + y = helper.create_variable_for_type_inference(dtype) + invvar = helper.create_variable_for_type_inference('float32') + + inputs = {'x': x, 'scale': scale} + + helper.append_op( + type='fused_rms_norm', + inputs=inputs, + outputs={'y': y, 'invvar': invvar}, + attrs={'epsilon': epsilon} + ) + return y, invvar \ No newline at end of file From 4f5ede6f527ee1b94fe960dd0ccb2c84566cc514 Mon Sep 17 00:00:00 2001 From: zhenghuaijin Date: Tue, 27 May 2025 02:22:29 +0800 Subject: [PATCH 32/71] finish rms_norm bwd --- paddle/phi/infermeta/backward.cc | 12 ++++++------ paddle/phi/infermeta/backward.h | 6 ++---- .../phi/kernels/gpu/layer_norm_cuda_kernel.cu | 12 ++++++------ paddle/phi/ops/yaml/backward.yaml | 17 ++++++++--------- paddle/phi/ops/yaml/ops.yaml | 2 +- .../incubate/nn/functional/int_bincount.py | 2 ++ 6 files changed, 25 insertions(+), 26 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 9883a070ad3d20..c1ca519362564e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1893,12 +1893,12 @@ void FusedRMSNormGradInferMeta(const MetaTensor &x, const MetaTensor &invvar, const MetaTensor &dy, float epsilon, - MetaTensor* grad_x, - MetaTensor* grad_scale){ + MetaTensor* x_grad, + MetaTensor* scale_grad){ - grad_x->set_dims(x.dims()); - grad_x->set_dtype(x.dtype()); - grad_scale->set_dims(scale.dims()); - grad_scale->set_dtype(scale.dtype()); + x_grad->set_dims(x.dims()); + x_grad->set_dtype(x.dtype()); + scale_grad->set_dims(scale.dims()); + scale_grad->set_dtype(scale.dtype()); } } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index bbcfa711d5cf79..20dbacca2329b7 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -685,8 +685,6 @@ void FusedRMSNormGradInferMeta(const MetaTensor &x, const MetaTensor &invvar, const MetaTensor &dy, float epsilon, - MetaTensor* grad_x, - MetaTensor* grad_scale); - - + MetaTensor* x_grad, + MetaTensor* scale_grad); } // namespace phi diff --git a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu index a311c77e42f15e..9f40cf79c2a7d0 100644 --- a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu +++ b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu @@ -67,14 +67,14 @@ void RMSLnBwd(const Context& ctx, const DenseTensor &invvar, const DenseTensor &dy, float epsilon, - DenseTensor* grad_x, - DenseTensor* grad_scale) { + DenseTensor* x_grad, + DenseTensor* scale_grad) { int rows, cols; const auto &x_shape = x.dims(); rows = x_shape[0]; cols = x_shape[1]; - *grad_x = phi::EmptyLike(ctx, x); - *grad_scale = phi::EmptyLike(ctx, scale); + *x_grad = phi::EmptyLike(ctx, x); + *scale_grad = phi::EmptyLike(ctx, scale); cuda_rms_norm_gradient( ctx, @@ -85,8 +85,8 @@ void RMSLnBwd(const Context& ctx, rows, cols, epsilon, - grad_x, - grad_scale + x_grad, + scale_grad ); } diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 392fe60ca30eda..dc9eb43788e5ad 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -3927,12 +3927,11 @@ composite: where_double_grad(condition, grad_x_grad, grad_y_grad, 
grad_out_grad) optional: grad_x_grad, grad_y_grad -# - backward_op: fused_rms_norm_grad -# forward: fused_rms_norm (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) -# args: (Tensor x, Tensor scale, Tensor y, Tensor invvar, float epsilon) -# output: Tensor(grad_x), Tensor(grad_scale) -# infer_meta: -# func: FusedRMSNormGradInferMeta -# kernel: -# func: fused_rms_norm_grad -# data_type: x \ No newline at end of file +- backward_op: fused_rms_norm_grad + forward: fused_rms_norm (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) + args: (Tensor x, Tensor scale, Tensor y_grad, Tensor invvar_grad, float epsilon) + output: Tensor(x_grad), Tensor(scale_grad) + infer_meta: + func: FusedRMSNormGradInferMeta + kernel: + func: fused_rms_norm_grad \ No newline at end of file diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index d3c4368b21b6a7..89d1eb1d1f348e 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -5692,4 +5692,4 @@ kernel: func: fused_rms_norm data_type: x - # backward: fused_rms_norm_grad \ No newline at end of file + backward: fused_rms_norm_grad \ No newline at end of file diff --git a/python/paddle/incubate/nn/functional/int_bincount.py b/python/paddle/incubate/nn/functional/int_bincount.py index 6389f834a49e45..7c51cb0ac0aaa2 100644 --- a/python/paddle/incubate/nn/functional/int_bincount.py +++ b/python/paddle/incubate/nn/functional/int_bincount.py @@ -4,6 +4,8 @@ def int_bincount(x, low, high, dtype=None, name=None): + if in_dynamic_or_pir_mode(): + return _C_ops.moe_gate_dispatch_permute(x, gate_logits, corr_bias, k, capacity, world_size) helper = LayerHelper("int_bincount", **locals()) out_dtype = dtype if dtype is not None else x.dtype y = helper.create_variable_for_type_inference(dtype=out_dtype) From 8797589dc7518d54916e0935deadee99a7406bbd Mon Sep 17 00:00:00 2001 From: zhenghuaijin Date: Tue, 27 May 2025 02:24:43 +0800 Subject: [PATCH 33/71] finish rms norm bwd --- python/paddle/incubate/nn/functional/int_bincount.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/nn/functional/int_bincount.py b/python/paddle/incubate/nn/functional/int_bincount.py index 7c51cb0ac0aaa2..32a6697d1c210f 100644 --- a/python/paddle/incubate/nn/functional/int_bincount.py +++ b/python/paddle/incubate/nn/functional/int_bincount.py @@ -4,8 +4,8 @@ def int_bincount(x, low, high, dtype=None, name=None): - if in_dynamic_or_pir_mode(): - return _C_ops.moe_gate_dispatch_permute(x, gate_logits, corr_bias, k, capacity, world_size) + # if in_dynamic_or_pir_mode(): + # return _C_ops.moe_gate_dispatch_permute(x, gate_logits, corr_bias, k, capacity, world_size) helper = LayerHelper("int_bincount", **locals()) out_dtype = dtype if dtype is not None else x.dtype y = helper.create_variable_for_type_inference(dtype=out_dtype) From b321097bc6d0059827e65f6951c57729a33eb91b Mon Sep 17 00:00:00 2001 From: feixi21 <1802550529@qq.com> Date: Tue, 27 May 2025 03:37:55 +0000 Subject: [PATCH 34/71] add optional in ops.yaml --- paddle/phi/ops/yaml/ops.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 5c2d7eee69009b..99a8f3ea1a5710 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -901,6 +901,7 @@ kernel : func : cal_aux_loss data_type : gate_prob + optional: tokens_mask, dispatch_tokens_mask backward : cal_aux_loss_grad - op : calc_reduced_attn_scores @@ -3610,6 +3611,7 @@ kernel : func : 
moe_gate_dispatch data_type : x + optional : corr_bias backward : moe_gate_dispatch_grad - op : momentum_ From 3e406ca8166bb3158d226273834b515e7a961f62 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Tue, 27 May 2025 06:15:33 +0000 Subject: [PATCH 35/71] nosoftmax bwd has finished --- paddle/phi/infermeta/backward.cc | 54 ++++ paddle/phi/infermeta/backward.h | 16 ++ paddle/phi/infermeta/ternary.cc | 4 +- paddle/phi/kernels/gpu/moe_fuse_bwd_op.h | 241 ++++++++++++++++++ ...e_ops_partial_nosoftmaxtopk_grad_kernel.cu | 143 +++++++++++ .../moe_ops_partial_nosoftmaxtopk_kernel.cu | 8 +- ...oe_ops_partial_nosoftmaxtopk_grad_kernel.h | 36 +++ paddle/phi/ops/yaml/backward.yaml | 10 + paddle/phi/ops/yaml/ops.yaml | 3 +- .../moe_ops_partial_nosoftmaxtopk.py | 47 +++- 10 files changed, 554 insertions(+), 8 deletions(-) create mode 100644 paddle/phi/kernels/gpu/moe_fuse_bwd_op.h create mode 100644 paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu create mode 100644 paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 805e4a0ac868a5..2a54371232f3e4 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1240,6 +1240,60 @@ void MoeCombineGradInferMeta(const MetaTensor& x, grad_combine_weights_helper->set_dtype(x.dtype()); } +void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(const MetaTensor& combine_weights_out, + const MetaTensor& scatter_index, + const MetaTensor& scatter_index_rev, + const MetaTensor& expert_offset, + const MetaTensor& expert_offset_local, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + MetaTensor* x_grad, + MetaTensor* combine_weights_grad){ + printf("check infer\n"); + printf("combine shape: %d, scatter shape: %d\n", combine_weights_out.dims().size(), scatter_index.dims().size()); + printf("sizeof(combine_weights_out): %d\n", sizeof(combine_weights_out)); + printf("sizeof(y_grad): %d\n", sizeof(y_grad)); + printf("sizeof combine_weights_out_grad: %d\n", sizeof(combine_weights_out_grad)); + // printf("size of combine_weights_out_grad: %d\n", combine_weights_out_grad.size()); + printf("combine_weights_out_grad shape: %d\n", combine_weights_out_grad.dims().size()); + int64_t num_experts = expert_offset.dims()[0]; + int64_t hidden_size = y_grad.dims()[1]; + int64_t num_rows = scatter_index.dims()[1]; + PADDLE_ENFORCE_GT( + num_experts, + 0, + common::errors::InvalidArgument("Input num_experts should be greater than 0")); + PADDLE_ENFORCE_EQ( + (expert_offset.dtype()==phi::DataType::INT64), + true, + common::errors::InvalidArgument("Input expert_offset type should be int64")); + if(use_pad){ + PADDLE_ENFORCE_GE( + num_experts, + y_grad.dims()[0] / capacity, + common::errors::InvalidArgument( + "Number of experts should be greater than or equal to y_grad.dims()[0]/capacity")); + } else { + PADDLE_ENFORCE_GT(y_grad.dims()[0], + 0, + common::errors::InvalidArgument("Input y_grad.dims()[0] should be greater than 0")); + } + printf("y_grad shape: %d", y_grad.dims().size()); + printf("combine_weights_out_grad shape: %d, y_grad shape: %d", combine_weights_out_grad.dims().size(), y_grad.dims().size()); + printf("allocate combine_weights_grad\n"); + combine_weights_grad->set_dims(combine_weights_out_grad.dims()); + combine_weights_grad->set_dtype(phi::DataType::FLOAT32); + printf("allocate x_grad\n"); + 
x_grad->set_dims({num_rows, hidden_size}); + x_grad->set_dtype(y_grad.dtype()); + printf("check infer over\n"); +} + void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad) { diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 50cd2500b26d72..6f328f2a5d56a7 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -468,6 +468,22 @@ void MoeCombineGradInferMeta(const MetaTensor& x, const MetaTensor& grad_y, MetaTensor* grad_x, MetaTensor* grad_combine_weights_helper); +//Tensor combine_weights_out, Tensor scatter_index, Tensor scatter_index_rev, Tensor expert_offset, Tensor expert_offset_local, Tensor y_grad, Tensor combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t expert_start_index, int64_t expert_end_index) +// output : Tensor(x_grad), Tensor(combine_weights_grad) +void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(const MetaTensor& combine_weights_out, + const MetaTensor& scatter_index, + const MetaTensor& scatter_index_rev, + const MetaTensor& expert_offset, + const MetaTensor& expert_offset_local, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + MetaTensor* x_grad, + MetaTensor* combine_weights_grad); void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 5b8f7447eb3361..cd79ef9207a58d 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1737,7 +1737,9 @@ expert_offset->set_dims({num_experts}); expert_offset->set_dtype(phi::DataType::INT64); expert_nums_local->set_dims({num_experts}); expert_nums_local->set_dtype(phi::DataType::INT64); -combine_weights_out->share_meta(combine_weights); +combine_weights_out->set_dims(combine_weights_dims); +combine_weights_out->set_dtype(combine_weights.dtype()); +// combine_weights_out->share_meta(combine_weights); } void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h b/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h new file mode 100644 index 00000000000000..0862bce2a57d89 --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h @@ -0,0 +1,241 @@ +#pragma once +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/common/exception.h" +#include "paddle/phi/kernels/gpu/moe_kernel_impl.h" + + +template +__global__ void gather_with_mask_permute_kernel(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num, + int64_t capacity, + int64_t world_size, + int64_t num_local_experts + ){ + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; + + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; \ + idx < N; idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = min(static_cast(blockDim.x), N - shared_idx_begin); + + for 
(int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; + } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); + + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active){ + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f){ + int64_t remaining_after_irank = id % (num_local_experts * capacity); + + int64_t irank = id / (num_local_experts * capacity); + int64_t local_iexpert = remaining_after_irank / capacity; + int64_t row_in_expert = remaining_after_irank % capacity; + int64_t permuted_id = local_iexpert * (world_size * capacity) + irank * capacity + row_in_expert; + int64_t in_offset = permuted_id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; + } + } + } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } +} + +template +__global__ void gather_with_mask_kernel(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num + ){ + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; + + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; \ + idx < N; idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = min(static_cast(blockDim.x), N - shared_idx_begin); + + for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; + } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); + + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active){ + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f){ + int64_t in_offset = id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; + } + } + } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } +} + +template +inline T DivUp(T a, T b) { + return (a + b - 1) / b; +} + +inline int64_t max_shared_s_num(int64_t num_rows, int64_t dim, int64_t threads, int64_t vec_size) { + if ((threads * vec_size) % dim == 0) { + return min(num_rows, threads * vec_size / dim); + } else { + int64_t max_res = DivUp(threads * 4, dim); + for (int64_t idx = 0; idx < num_rows * 
dim; idx += vec_size * threads) { + int64_t si_start = idx / dim; + int64_t si_end = min(num_rows * dim, idx + vec_size * threads - 1) / dim; + max_res = max(max_res, (si_end - si_start + 1)); + } + return min(num_rows, max_res); + } +} + +template +void gather_with_mask_launcher(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s,k,d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t num_active, + cudaStream_t stream, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1, + int64_t capacity = -1 +){ + int numel = num_rows * dim; + + int64_t threads = 512; + if (dim % 4 == 0) { + int64_t blocks = DivUp(DivUp(numel, 4), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 4); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + + if (!use_all2all_permute) { + gather_with_mask_kernel<<>>( + dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + PD_CHECK(world_size > 0 && num_local_experts > 0 && capacity > 0); + gather_with_mask_permute_kernel<<>>( + dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } else { + int64_t blocks = DivUp(DivUp(numel, 1), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 1); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + +#ifdef DEBUG_MOE_OP + std::cerr << "[DEBUG-BWD] gather_with_mask without vectorized, s_shared_num=" << s_shared_num << ", block=" << blocks << std::endl; +#endif + + if (!use_all2all_permute) { + gather_with_mask_kernel<<>>( + dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + gather_with_mask_permute_kernel<<>>( + dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } +} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu new file mode 100644 index 00000000000000..6994e7a749bd7c --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu @@ -0,0 +1,143 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
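+//
+// Backward sketch for moe_gate_dispatch_partial_nosoftmaxtopk:
+//   x_grad[s, :] = sum_k 1[combine_weights[s, k] > 0 and
+//                          scatter_index[s, k] < num_active]
+//                  * y_grad[scatter_index[s, k], :]
+//   combine_weights_grad[s, k] = combine_weights_out_grad[s, k]
+//                                if combine_weights[s, k] > 0 else 0
+// gather_with_mask_launcher (moe_fuse_bwd_op.h) implements the first equation;
+// the masked copy of the weight gradient is done with thrust::transform below.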
+ +#pragma once +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" +#include "paddle/phi/kernels/contiguous_kernel.h" +#include "paddle/phi/kernels/gpu/moe_fuse_bwd_op.h" +#include +#include + +namespace phi{ + +template +void apply_moe_dispatch_bwd( + const T* y_grad, + const float* combine_weights, // [s, k] + const int* scatter_index, // [s, k] + const float* combine_weights_out_grad, + float* combine_weights_in_grad, + T* x_grad, + int64_t num_rows, + int64_t k, + int64_t dim, + int64_t num_experts, + int64_t num_active, + cudaStream_t stream){ + printf("apply_moe_dispatch_bwd\n"); + gather_with_mask_launcher(y_grad, + scatter_index, + combine_weights, + x_grad, num_rows, k, dim, num_active, stream); + auto out_grad_ptr = thrust::device_pointer_cast(combine_weights_out_grad); + auto in_grad_ptr = thrust::device_pointer_cast(combine_weights_in_grad); + auto combine_weight_ptr = thrust::device_pointer_cast(combine_weights); + printf("kernel over\n"); + thrust::transform( + thrust::cuda::par.on(stream), + out_grad_ptr, + out_grad_ptr + num_rows * k, + combine_weight_ptr, + in_grad_ptr, + [] __device__ (float g, float w){ + return w > static_cast(0) ? g : static_cast(0); + } + ); + // topk_grad_with_mask_launcher(combine_weights_grad, + // expert_id, + // combine_weights, + // gate_logtis_grad, + // num_rows, k, num_experts, stream); +} + +template +void moe_dispatch_bwd(const Context& dev_ctx, + const DenseTensor &combine_weights, // [s, k] + const DenseTensor &scatter_index, // [k, s] + const DenseTensor &y_grad, // [num_experts * capacity, h] + const DenseTensor &combine_weights_out_grad, // [s, k] + DenseTensor *x_grad, + DenseTensor *combine_weights_in_grad, + int64_t num_experts){ + int64_t num_rows = combine_weights.dims()[0]; + int64_t k = combine_weights.dims()[1]; + int64_t hidden_size = y_grad.dims()[1]; + int64_t num_active = y_grad.dims()[0]; + + apply_moe_dispatch_bwd( + y_grad.data(), + combine_weights.data(), + scatter_index.data(), + combine_weights_out_grad.data(), + combine_weights_in_grad->data(), + x_grad->data(), + num_rows, + k, + hidden_size, + num_experts, + num_active, + dev_ctx.stream()); +} + +template +void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights_out, + const DenseTensor& scatter_index, + const DenseTensor& scatter_index_rev, + const DenseTensor& expert_offset, + const DenseTensor& expert_offset_local, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + DenseTensor* x_grad, + DenseTensor* combine_weights_grad){ + printf("MoeGateDispatchPartialNoSoftMaxTopkGradKernel begin\n"); + dev_ctx.template Alloc(x_grad); + dev_ctx.template Alloc(combine_weights_grad); + // DenseTensor t_scatter_index; + // printf("check pass\n"); + // phi::Transpose(dev_ctx, scatter_index, {1,0}, &t_scatter_index); + // DenseTensor t_scatter_index_out; + // phi::ContiguousKernel(dev_ctx, t_scatter_index, &t_scatter_index_out); + // t_scatter_index = t_scatter_index_out; + // int64_t num_experts = expert_offset.dims()[0]; + printf("dive into moe_dispatch_bwd\n"); + // moe_dispatch_bwd(dev_ctx, + // combine_weights_out, + // t_scatter_index, + // y_grad, + // 
combine_weights_out_grad, + // x_grad, + // combine_weights_grad, + // num_experts); + +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_partial_nosoftmaxtopk_grad, + GPU, + ALL_LAYOUT, + phi::MoeGateDispatchPartialNoSoftMaxTopkGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu index a006edc8e5403f..9354f3f1cde87d 100644 --- a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu @@ -24,6 +24,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/slice_kernel.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/gpu/moe_kernel_impl.h" namespace phi { @@ -470,6 +471,7 @@ void MoeGateDispatchPartialNoSoftMaxTopkKernel(const Context& dev_ctx, dev_ctx.template Alloc(scatter_index_rev); dev_ctx.template Alloc(expert_offset); dev_ctx.template Alloc(expert_nums_local); + dev_ctx.template Alloc(combine_weights_out); phi::Copy(dev_ctx, combine_weights, dev_ctx.GetPlace(), false, combine_weights_out); const auto &x_shape = x.dims(); int64_t num_rows = x_shape[0]; @@ -498,13 +500,11 @@ void MoeGateDispatchPartialNoSoftMaxTopkKernel(const Context& dev_ctx, ); if(use_pad){ // scatter_index_rev = scatter_index_rev.slice(0, num_experts_diff * capacity); - *scatter_index_rev = phi::Slice(dev_ctx, *scatter_index_rev, {0}, {0}, {num_experts_diff * capacity}); + *scatter_index_rev = phi::Slice(dev_ctx, *scatter_index_rev, {0}, {0}, {num_experts_diff * capacity}); }else{ if (expert_offset_host.back() > 0){ - // y = y.slice(0, expert_offset_host.back()); // scatter_index_rev = scatter_index_rev.slice(0, expert_offset_host.back()); - *y = phi::Slice(dev_ctx, *y, {0}, {0}, {expert_offset_host.back()}); - *scatter_index_rev = phi::Slice(dev_ctx, *scatter_index_rev, {0}, {0}, {expert_offset_host.back()}); + *scatter_index_rev = phi::Slice(dev_ctx, *scatter_index_rev, {0}, {0}, {expert_offset_host.back()}); }else{ *y = phi::Empty(dev_ctx, {1, x_shape[1]}); *scatter_index_rev = phi::Empty(dev_ctx, {}); //special treatment diff --git a/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h new file mode 100644 index 00000000000000..9929687d42d4fa --- /dev/null +++ b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
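+//
+// Declares the grad kernel of moe_gate_dispatch_partial_nosoftmaxtopk.
+// Expected shapes (see MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta):
+//   y_grad:                   [num_dispatched (num_experts * capacity when use_pad), hidden_size]
+//   combine_weights_out_grad: [num_rows, k]
+//   x_grad:                   [num_rows, hidden_size]
+//   combine_weights_grad:     [num_rows, k], float32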
+ +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi{ +template +void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights_out, + const DenseTensor& scatter_index, + const DenseTensor& scatter_index_rev, + const DenseTensor& expert_offset, + const DenseTensor& expert_offset_local, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + DenseTensor* x_grad, + DenseTensor* combine_weights_grad); + +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index c36a3c1694ae9b..3d77c3a1871075 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2256,6 +2256,16 @@ kernel : func : moe_combine_grad +- backward_op : moe_gate_dispatch_partial_nosoftmaxtopk_grad + forward : moe_gate_dispatch_partial_nosoftmaxtopk (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop) -> Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local) + args : (Tensor combine_weights_out, Tensor scatter_index, Tensor scatter_index_rev, Tensor expert_offset, Tensor expert_nums_local, Tensor y_grad, Tensor combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t expert_start_index, int64_t expert_end_index) + output : Tensor(x_grad), Tensor(combine_weights_grad) + infer_meta : + func : MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta + kernel : + func : moe_gate_dispatch_partial_nosoftmaxtopk_grad + data_type : y_grad + - backward_op : mp_allreduce_sum_grad forward : mp_allreduce_sum(Tensor x, int ring_id = 0) -> Tensor(out) args : (Tensor out_grad, int ring_id = 0) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index e7f80f5c5d8e82..26b11e8500f6e8 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3610,7 +3610,8 @@ kernel : func : moe_gate_dispatch_partial_nosoftmaxtopk data_type : x - inplace : (combine_weights -> combine_weights_out) + # inplace : (combine_weights -> combine_weights_out) + backward : moe_gate_dispatch_partial_nosoftmaxtopk_grad - op : moe_gate_dispatch_permute args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, int64_t world_size) diff --git a/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py b/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py index cb06fc6a856ac1..b1d639c7d253aa 100644 --- a/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py +++ b/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py @@ -58,6 +58,49 @@ def moe_ops_partial_nosoftmaxtopk( outputs=outputs, attrs=attrs, ) - return (y, combine_weights_out, scatter_index, scatter_index_rev, expert_offset, expert_nums_local) + return y, combine_weights_out, scatter_index, scatter_index_rev, expert_offset, expert_nums_local - \ No newline at end of file + + +import paddle +import numpy as np + +# 假设自定义算子已经在 Paddle 环境中注册 +# 创建一些假数据 +num_rows = 4 +feature_dim = 8 +num_experts = 3 +k = 2 +capacity = 5 + +# 输入张量 +x = paddle.to_tensor(np.random.rand(num_rows, feature_dim).astype('float32'), stop_gradient=False) + +# 合并权重张量 +combine_weights = 
paddle.to_tensor(np.random.rand(num_rows, k).astype('float32'), stop_gradient=False) + +# 专家ID张量 +expert_id = paddle.to_tensor(np.random.randint(0, num_experts, size=(num_rows, k)).astype('int32'), stop_gradient=False) + +print("x type:", x.dtype) +print("combine_weights type:", combine_weights.dtype) +print("expert_id type:", expert_id.dtype) +# 其他参数 +use_pad = True +expert_start_index = 0 +expert_end_index = num_experts +reverse_token_drop = False + +# 调用自定义算子 +y, combine_weights_out, scatter_index, scatter_index_rev, expert_offset, expert_nums_local = moe_ops_partial_nosoftmaxtopk( + x=x, + combine_weights=combine_weights, + expert_id=expert_id, + k=k, + capacity=capacity, + num_experts=num_experts, + use_pad=use_pad, + expert_start_index=expert_start_index, + expert_end_index=expert_end_index, + reverse_token_drop=reverse_token_drop +) \ No newline at end of file From 9105be0c86fb5b5ec968a39042d9b94163be17ea Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Tue, 27 May 2025 06:18:08 +0000 Subject: [PATCH 36/71] update python api --- .../moe_ops_partial_nosoftmaxtopk.py | 83 +++++++++++-------- 1 file changed, 47 insertions(+), 36 deletions(-) diff --git a/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py b/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py index b1d639c7d253aa..5fbf53515b75af 100644 --- a/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py +++ b/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py @@ -62,45 +62,56 @@ def moe_ops_partial_nosoftmaxtopk( -import paddle -import numpy as np +# import paddle +# import numpy as np + +# num_rows = 4 +# feature_dim = 8 +# num_experts = 3 +# k = 2 +# capacity = 5 + +# # 输入张量 +# x = paddle.to_tensor(np.random.rand(num_rows, feature_dim).astype('float32'), stop_gradient=False) -# 假设自定义算子已经在 Paddle 环境中注册 -# 创建一些假数据 -num_rows = 4 -feature_dim = 8 -num_experts = 3 -k = 2 -capacity = 5 +# # 合并权重张量 +# combine_weights = paddle.to_tensor(np.random.rand(num_rows, k).astype('float32'), stop_gradient=False) -# 输入张量 -x = paddle.to_tensor(np.random.rand(num_rows, feature_dim).astype('float32'), stop_gradient=False) +# # 专家ID张量 +# expert_id = paddle.to_tensor(np.random.randint(0, num_experts, size=(num_rows, k)).astype('int32'), stop_gradient=False) -# 合并权重张量 -combine_weights = paddle.to_tensor(np.random.rand(num_rows, k).astype('float32'), stop_gradient=False) +# print("x type:", x.dtype) +# print("combine_weights type:", combine_weights.dtype) +# print("expert_id type:", expert_id.dtype) +# # 其他参数 +# use_pad = True +# expert_start_index = 0 +# expert_end_index = num_experts +# reverse_token_drop = False -# 专家ID张量 -expert_id = paddle.to_tensor(np.random.randint(0, num_experts, size=(num_rows, k)).astype('int32'), stop_gradient=False) +# # 调用自定义算子 +# y, combine_weights_out, scatter_index, scatter_index_rev, expert_offset, expert_nums_local = moe_ops_partial_nosoftmaxtopk( +# x=x, +# combine_weights=combine_weights, +# expert_id=expert_id, +# k=k, +# capacity=capacity, +# num_experts=num_experts, +# use_pad=use_pad, +# expert_start_index=expert_start_index, +# expert_end_index=expert_end_index, +# reverse_token_drop=reverse_token_drop +# ) -print("x type:", x.dtype) -print("combine_weights type:", combine_weights.dtype) -print("expert_id type:", expert_id.dtype) -# 其他参数 -use_pad = True -expert_start_index = 0 -expert_end_index = num_experts -reverse_token_drop = False +# # 打印结果 +# print("y:", y.numpy()) +# print("combine_weights_out:", combine_weights_out.numpy()) +# 
print("scatter_index:", scatter_index.numpy()) +# print("scatter_index_rev:", scatter_index_rev.numpy()) +# print("expert_offset:", expert_offset.numpy()) +# print("expert_nums_local:", expert_nums_local.numpy()) -# 调用自定义算子 -y, combine_weights_out, scatter_index, scatter_index_rev, expert_offset, expert_nums_local = moe_ops_partial_nosoftmaxtopk( - x=x, - combine_weights=combine_weights, - expert_id=expert_id, - k=k, - capacity=capacity, - num_experts=num_experts, - use_pad=use_pad, - expert_start_index=expert_start_index, - expert_end_index=expert_end_index, - reverse_token_drop=reverse_token_drop -) \ No newline at end of file +# a = paddle.sum(y)+paddle.sum(combine_weights_out) +# a.backward() +# print("\n##########backward output##########\n") +# print(f"x.grad: {x.grad}\n combine_weights.grad: {combine_weights.grad}\n expert_id.grad: {expert_id.grad}") \ No newline at end of file From 0f096361be93b7d79dc1f19dff13a3caa8e04296 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Tue, 27 May 2025 07:10:50 +0000 Subject: [PATCH 37/71] Verified cal_aux_loss op and bwd. --- paddle/phi/api/lib/data_transform.cc | 3 + paddle/phi/infermeta/multiary.cc | 5 +- paddle/phi/kernels/funcs/math_cuda_utils.h | 4 +- paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu | 21 ++- paddle/phi/kernels/gpu/cub_kv_sorter.cu | 50 ++++++ .../kernels/gpu/moe_gate_dispatch_kernel.cu | 3 + .../gpu/moe_gate_dispatch_permute_kernel.cu | 2 + paddle/phi/kernels/gpu/moe_kernel_impl.h | 53 +----- paddle/phi/kernels/moe_kernel_impl.h | 52 +----- paddle/phi/ops/yaml/ops.yaml | 1 + .../paddle/incubate/nn/functional/__init__.py | 2 + test/legacy_test/ernie_utils/top2_gate.py | 158 +----------------- test/legacy_test/test_incubate_fused_loss.py | 152 +++++++++++++++++ third_party/openblas | 2 +- 14 files changed, 250 insertions(+), 258 deletions(-) create mode 100644 paddle/phi/kernels/gpu/cub_kv_sorter.cu create mode 100644 test/legacy_test/test_incubate_fused_loss.py diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index c4feaded4174bb..19a690d093e4a8 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -418,6 +418,9 @@ std::shared_ptr PrepareData( out = Trans2Contiguous(out); return std::make_shared(std::move(out)); } + VLOG(6) << "tensor_in: "<< tensor_in; + auto tmp = std::static_pointer_cast(tensor_in); + VLOG(6) << "tensor_in.size(): " << tmp->dims(); return std::static_pointer_cast(tensor_in); } phi::DenseTensor out = TransformData( diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index a99f305e047971..b7576d437edfa4 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -6368,10 +6368,10 @@ void CalAuxLossInferMeta(const MetaTensor& gate_prob, "The input dispatch_tokens_mask type should be BOOL")); } - l_aux_loss->set_dims({1}); + l_aux_loss->set_dims(phi::make_ddim({})); l_aux_loss->set_dtype(gate_prob.dtype()); - seqlen_floats->set_dims({1}); + seqlen_floats->set_dims(phi::make_ddim({})); seqlen_floats->set_dtype(gate_prob.dtype()); ce->set_dims({gate_prob_dims[1]}); @@ -6393,7 +6393,6 @@ void MoeGateDispatchInferMeta(const MetaTensor& x, auto gate_logits_dims = gate_logits.dims(); const int64_t num_rows = x_dims[0]; - const int64_t hidden_size = x_dims[1]; const int64_t num_experts = gate_logits_dims[1]; PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index f14b2af8c72609..e5361b836e3c81 100644 --- 
a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -102,7 +102,7 @@ __device__ __forceinline__ float exp_func(float a) { template <> __device__ __forceinline__ half exp_func(half a) { -#if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +#if defined(__HIPCC__) || (__CUDA_ARCH__ > 600) return hexp(a); #else return FromFloat(expf(ToFloat(a))); @@ -144,7 +144,7 @@ struct KeyValuePair { const half2 a2 = __halves2half2(key, value); const half2 b2 = __halves2half2(a.key, a.value); #ifdef PADDLE_WITH_CUDA -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +#if (__CUDA_ARCH__ > 600) const half2 res = __hadd2(a2, b2); #else float a2_1 = __low2float(a2); diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu index 3eb13b55da78f8..60ee81cdd93d77 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -219,13 +219,26 @@ void CalAuxLossKernel(const Context& dev_ctx, auto dispatch_mask_dims = dispatch_mask.dims(); int64_t dispatch_tokens_mask_len = 0; + auto dispatch_tokens_mask_ptr = dispatch_tokens_mask.get_ptr(); if (dispatch_tokens_mask) { - dispatch_tokens_mask_len = dispatch_tokens_mask.get_ptr()->dims()[0]; + const auto mask_dims = dispatch_tokens_mask_ptr->dims(); + const auto dim_size = mask_dims.size(); + const bool is_not_zero_size = (dim_size > 0); + if (is_not_zero_size) { + dispatch_tokens_mask_len = dispatch_tokens_mask_ptr->dims()[0]; + } else { + dispatch_tokens_mask_len = 0; + } } + /* T* l_aux_loss_data = dev_ctx.template Alloc(l_aux_loss); T* seqlen_float_data = dev_ctx.template Alloc(seqlen_float); T* ce_data = dev_ctx.template Alloc(ce); + */ + dev_ctx.template Alloc(l_aux_loss); + dev_ctx.template Alloc(seqlen_float); + dev_ctx.template Alloc(ce); cal_aux_loss(gate_prob.data(), gate_prob_dims[0], @@ -243,9 +256,9 @@ void CalAuxLossKernel(const Context& dev_ctx, use_group, moe_k, clip_min, - l_aux_loss_data, - seqlen_float_data, - ce_data, + l_aux_loss->data(), + seqlen_float->data(), + ce->data(), dev_ctx.stream()); } diff --git a/paddle/phi/kernels/gpu/cub_kv_sorter.cu b/paddle/phi/kernels/gpu/cub_kv_sorter.cu new file mode 100644 index 00000000000000..ce087e4712c5bc --- /dev/null +++ b/paddle/phi/kernels/gpu/cub_kv_sorter.cu @@ -0,0 +1,50 @@ +#include "moe_kernel_impl.h" +namespace phi{ + // ===== CUB Sorting things ===== +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(cudaStream_t stream) + : num_experts_(0), num_bits_(sizeof(int) * 8), stream_(stream) {} + +CubKeyValueSorter::CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts), + num_bits_(static_cast(log2(num_experts)) + 1) {} + +void CubKeyValueSorter::update_num_experts(const int num_experts) { + num_experts_ = num_experts; + num_bits_ = static_cast(log2(num_experts)) + + 3; // 额外增加 3 位用于标记 topk的位置 +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, + bool descending) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + 32, + stream_); + } else { + cub::DeviceRadixSort::SortPairs(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + num_bits_, + stream_); + } + 
return required_storage; +} +} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu index 98a75af605cb60..e45cbab45932a5 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu @@ -20,6 +20,7 @@ namespace phi { // -------- getWorkspaceSize -------- // +namespace { template size_t getWorkspaceSize(const int num_rows, const int hidden_size, @@ -56,6 +57,8 @@ size_t getWorkspaceSize(const int num_rows, // "< void apply_moe_dispatch_fwd(const Context &dev_ctx, diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu index 31941833196e4d..535b9bff6ea95d 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu @@ -19,6 +19,7 @@ #include "paddle/phi/kernels/empty_kernel.h" namespace phi { +namespace { // -------- getWorkspaceSize -------- // template size_t getWorkspaceSize(const int num_rows, @@ -49,6 +50,7 @@ size_t getWorkspaceSize(const int num_rows, // std::cout<<"sorter_ws_size_bytes = "< Date: Tue, 27 May 2025 08:00:34 +0000 Subject: [PATCH 38/71] Verified build_src_rank_and_local_expert_id --- ...ild_src_rank_and_local_expert_id_kernel.cu | 12 ++-- .../paddle/incubate/nn/functional/__init__.py | 2 + ...bate_build_src_rank_and_local_expert_id.py | 58 +++++++++++++++++++ 3 files changed, 67 insertions(+), 5 deletions(-) create mode 100644 test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py diff --git a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu index d3d8f68e311942..f9508cb47023ac 100644 --- a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu +++ b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu @@ -78,10 +78,11 @@ void BuildSrcRankAndLocalExpertIdKernel( const int64_t* expert_num_global_tensor_data = expert_num_global_tensor.data(); - T* src_rank_data = dev_ctx.template Alloc(src_rank); - T* local_expert_id_data = dev_ctx.template Alloc(local_expert_id); + // Hard coded as ernie-core did. 
+ int* src_rank_data = dev_ctx.template Alloc(src_rank); + int* local_expert_id_data = dev_ctx.template Alloc(local_expert_id); - build_srcrank_and_local_expert_id(src_rank_data, + build_srcrank_and_local_expert_id(src_rank_data, local_expert_id_data, expert_num_global_tensor_data, token_num, @@ -92,8 +93,9 @@ void BuildSrcRankAndLocalExpertIdKernel( } // namespace phi -PD_REGISTER_KERNEL(build_srcrank_and_local_expert_id, +PD_REGISTER_KERNEL(build_src_rank_and_local_expert_id, GPU, ALL_LAYOUT, phi::BuildSrcRankAndLocalExpertIdKernel, - float) {} + int32_t, + int64_t) {} diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index d0fbd85e74b321..e1f32ad3d31f2e 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -47,6 +47,7 @@ from .expand_modality_expert_id import expand_modality_expert_id from .cal_aux_loss import cal_aux_loss # from .moe_gate_dispatch_permute import moe_gate_dispatch_permute +from .build_src_rank_and_local_expert_id import build_src_rank_and_local_expert_id __all__ = [ 'fused_multi_head_attention', @@ -69,4 +70,5 @@ "moe_combine", "expand_modality_expert_id", "cal_aux_loss" + "build_src_rank_and_local_expert_id" ] diff --git a/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py b/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py new file mode 100644 index 00000000000000..4b89731112bbf5 --- /dev/null +++ b/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py @@ -0,0 +1,58 @@ +import os +import unittest + +from op_test import convert_float_to_uint16 +import random +import paddle.nn.functional as F + +import paddle +import numpy as np +import random +import logging + +import paddle +from paddle.nn.clip import _squared_l2_norm + +from ernie_utils.top2_gate import ( + CalAuxLossFunctor, + cal_aux_loss_func, +) +from paddle.incubate.nn.functional import build_src_rank_and_local_expert_id +from ernie_utils.moe_layer import fuse_logging + +logger = logging.getLogger(__name__) + + + +class TestFusedCalculateAuxLoss(unittest.TestCase): + def test_build_src_rank_and_local_expert_id(self): + def orig_func(expert_num_global_list, num_local_experts): + send_rank_cpu = np.concatenate( # TOO SLOW!!! 
break every thing + [np.full([j], i // num_local_experts, dtype="int32") for i, j in enumerate(expert_num_global_list)], + 0, + ) + local_expert_id_cpu = np.concatenate( + [np.full([j], i % num_local_experts, dtype="int32") for i, j in enumerate(expert_num_global_list)], + 0, + ) + send_rank = paddle.to_tensor(send_rank_cpu) + local_expert_id = paddle.to_tensor(local_expert_id_cpu) + return send_rank, local_expert_id + + def fused_func(expert_num_global_tensor, expert_num_global, num_local_experts): + return build_src_rank_and_local_expert_id( + expert_num_global_tensor, expert_num_global, num_local_experts + ) + + expert_num_global = np.random.randint(0, 512, size=[12 * 8],dtype="int32") + expert_num_global_tensor = paddle.to_tensor(expert_num_global, dtype="int64") + + s1, l1 = orig_func(expert_num_global, 12) + s2, l2 = fused_func(expert_num_global_tensor, expert_num_global, 12) + assert ((s1 - s2) == 0).all(), (s1, s2) + assert ((l1 - l2) == 0).all(), (l1, l2) + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 8aa9b3359bfb34e49e10950d2ff286ca3639158f Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Tue, 27 May 2025 08:19:30 +0000 Subject: [PATCH 39/71] gate_dispatch_permute has finished --- paddle/phi/infermeta/backward.cc | 6 +- paddle/phi/infermeta/backward.h | 2 +- paddle/phi/kernels/gpu/moe_fleety_utils.h | 108 ------ paddle/phi/kernels/gpu/moe_fuse_bwd_op.h | 314 ++++++++++++++++++ .../moe_gate_dispatch_permute_grad_kernel.cu | 52 +-- .../moe_gate_dispatch_permute_grad_kernel.h | 2 +- paddle/phi/ops/yaml/backward.yaml | 2 +- .../functional/moe_gate_dispatch_permute.py | 47 ++- 8 files changed, 399 insertions(+), 134 deletions(-) delete mode 100644 paddle/phi/kernels/gpu/moe_fleety_utils.h diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 30591de9397228..2b7af9021dcd4f 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1253,7 +1253,7 @@ void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, int64_t capacity, int64_t world_size, MetaTensor* x_grad, - MetaTensor* gate_logtis_grad){ + MetaTensor* gate_logits_grad){ auto y_grad_dims = y_grad.dims(); PADDLE_ENFORCE_EQ( @@ -1268,8 +1268,8 @@ void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, int64_t num_rows = scatter_index.dims()[1]; x_grad->set_dims({num_rows, hidden_size}); x_grad->set_dtype(y_grad.dtype()); - gate_logtis_grad->set_dims({num_rows, num_experts}); - gate_logtis_grad->set_dtype(phi::DataType::FLOAT32); + gate_logits_grad->set_dims({num_rows, num_experts}); + gate_logits_grad->set_dtype(phi::DataType::FLOAT32); } void MultiDotGradInferMeta(const std::vector& x, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index a2f61c761228a1..d336bb0130c0ae 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -478,7 +478,7 @@ void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, int64_t capacity, int64_t world_size, MetaTensor* x_grad, - MetaTensor* gate_logtis_grad); + MetaTensor* gate_logits_grad); void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, diff --git a/paddle/phi/kernels/gpu/moe_fleety_utils.h b/paddle/phi/kernels/gpu/moe_fleety_utils.h deleted file mode 100644 index 027ed66fde6bc3..00000000000000 --- a/paddle/phi/kernels/gpu/moe_fleety_utils.h +++ /dev/null @@ -1,108 +0,0 @@ -#pragma once - -#include "paddle/extension.h" -#include 
"paddle/phi/core/dense_tensor.h" -#include "paddle/phi/include/kernels.h" - -namespace phi { - -template -void ContiguousKernel(const Context& dev_ctx, - const DenseTensor& input, - DenseTensor* out); - -} // namespace phi - -namespace fleety_utils { - -namespace internal { - -template -struct TensorHasStrideImpl { -private: - struct YesType {}; - struct NoType {}; - - template - static YesType Check(decltype(std::declval().is_contiguous())) { - return 0; - } - - template - static NoType Check(...) { - return 0; - } - -public: - static constexpr bool kValue = - std::is_same(false))>::value; -}; - - -template -struct ContiguousTensorHelperImpl { - static_assert(_SupportStride, "_SupportStride should be true"); - - static bool IsContiguousTensor(const DenseT &t) { - return t.meta().is_contiguous(); - } - - static typename std::enable_if<_SupportStride, void>::type TensorTrans2Contiguous(DenseT *t) { - if (t != nullptr && t->initialized() && !t->meta().is_contiguous()) { - auto place = t->place(); - auto is_gpu_place = place.GetType() == phi::AllocationType::GPU; - PD_CHECK(is_gpu_place, "Only support GPU place"); - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto gpu_ctx = reinterpret_cast(dev_ctx); - auto dtype = t->dtype(); - - PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(dtype, "contiguous_kernel", ([&] { - DenseT out; - phi::ContiguousKernel(*gpu_ctx, *t, &out); - *t = out; - })); - } - } - - static void TensorTrans2Contiguous(PaddleT *t) { - if (t != nullptr) { - if (!t->is_dense_tensor()) { - PD_THROW("Trans2Contiguous only supports DenseTensor"); - } - auto *dense_t = static_cast(t->impl().get()); - TensorTrans2Contiguous(dense_t); - } - } -}; - - -template -struct ContiguousTensorHelperImpl { - static bool IsContiguousTensor(const DenseT &t) { return true; } - static void TensorTrans2Contiguous(DenseT *t) {} - static void TensorTrans2Contiguous(PaddleT *t) {} -}; - - -} // namespace internal - - -inline constexpr bool SupportStride() { - return internal::TensorHasStrideImpl::kValue; -} - -using ContiguousTensorHelper = internal::ContiguousTensorHelperImpl; - -inline bool IsContiguousTensor(const phi::DenseTensor &t) { - return ContiguousTensorHelper::IsContiguousTensor(t); -} - -inline void TensorTrans2Contiguous(phi::DenseTensor *t) { - return ContiguousTensorHelper::TensorTrans2Contiguous(t); -} - -inline void TensorTrans2Contiguous(paddle::Tensor *t) { - return ContiguousTensorHelper::TensorTrans2Contiguous(t); -} - -} // namespace fleety_utils diff --git a/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h b/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h index e69de29bb2d1d6..c349b0ca73330d 100644 --- a/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h +++ b/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h @@ -0,0 +1,314 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/phi/kernels/funcs/aligned_vector.h" + +template +__global__ void topk_grad_with_mask(const T* dy, // [s, k] + const int* topk_idx, // [s, k] + const T* combine_weights, // [s, k] + T* dx, // [s, e] + int64_t num_rows, // s + int64_t k, // k + int64_t num_experts // e + ){ + // init dx to zero + for (int i = blockIdx.x; i < num_rows; i+=gridDim.x){ + int base_grad = i * num_experts; + for (int j = threadIdx.x; j < num_experts; j+=blockDim.x){ + dx[base_grad + j] = static_cast(0); + } + __syncthreads(); + int base_index = i * k; + for (int j = threadIdx.x; j < k; j+=blockDim.x){ + int64_t idx = topk_idx[base_index + j]; + if (combine_weights[base_index + j] > static_cast(0)){ + dx[base_grad + idx] = dy[base_index + j]; + } + } + } +} + + +// y=zero_part(topk(x)) 的反向过程 +// x: [s,e] +// dy: [s,k] +// X: [s, e] -(topk)-> Y:[s, k] - (越界设置为0)-> conbine_weights: [s, k] +template +void topk_grad_with_mask_launcher( + const T* dy, // [s, k] + const int* topk_idx, // [s, k] + const T* combine_weights, // [s, k] + T* dx, // [s, e] + int64_t num_rows, // s + int64_t k, // k + int64_t num_experts, // e + cudaStream_t stream){ + + int blocks = num_rows; + int threads = 1024; + + topk_grad_with_mask<<>>(dy, + topk_idx, + combine_weights, + dx, + num_rows, + k, + num_experts + ); +} + +template +__global__ void gather_with_mask_permute_kernel(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num, + int64_t capacity, + int64_t world_size, + int64_t num_local_experts + ){ + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; + + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; \ + idx < N; idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = min(static_cast(blockDim.x), N - shared_idx_begin); + + for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; + } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); + + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active){ + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f){ + int64_t remaining_after_irank = id % (num_local_experts * capacity); + + int64_t irank = id / (num_local_experts * capacity); + int64_t local_iexpert = remaining_after_irank / capacity; + int64_t row_in_expert = remaining_after_irank % capacity; + int64_t permuted_id = local_iexpert * (world_size * capacity) + irank * capacity + row_in_expert; + int64_t in_offset = permuted_id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; + } + } + } + 
phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } +} + +template +__global__ void gather_with_mask_kernel(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num + ){ + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; + + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; \ + idx < N; idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = min(static_cast(blockDim.x), N - shared_idx_begin); + + for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; + } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); + + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active){ + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f){ + int64_t in_offset = id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; + } + } + } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } +} + +template +inline T DivUp(T a, T b) { + return (a + b - 1) / b; +} + +inline int64_t max_shared_s_num(int64_t num_rows, int64_t dim, int64_t threads, int64_t vec_size) { + if ((threads * vec_size) % dim == 0) { + return min(num_rows, threads * vec_size / dim); + } else { + int64_t max_res = DivUp(threads * 4, dim); + for (int64_t idx = 0; idx < num_rows * dim; idx += vec_size * threads) { + int64_t si_start = idx / dim; + int64_t si_end = min(num_rows * dim, idx + vec_size * threads - 1) / dim; + max_res = max(max_res, (si_end - si_start + 1)); + } + return min(num_rows, max_res); + } +} + + +template +void gather_with_mask_launcher(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s,k,d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t num_active, + cudaStream_t stream, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1, + int64_t capacity = -1 +){ + int numel = num_rows * dim; +// #ifdef DEBUG_MOE_OP +// std::cerr << "[DEBUG-BWD] launch kernel, num_active=" << num_active << ", num_rows=" << num_rows << ", dim=" << dim << std::endl; +// #endif + + int64_t threads = 512; + if (dim % 4 == 0) { + int64_t blocks = DivUp(DivUp(numel, 4), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 4); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + +// #ifdef DEBUG_MOE_OP +// std::cerr << "[DEBUG-BWD] gather_with_mask with vectorized, s_shared_num=" << s_shared_num << ", 
block=" << blocks << std::endl; +// #endif + if (!use_all2all_permute) { + gather_with_mask_kernel<<>>( + dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + PD_CHECK(world_size > 0 && num_local_experts > 0 && capacity > 0); + gather_with_mask_permute_kernel<<>>( + dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } else { + int64_t blocks = DivUp(DivUp(numel, 1), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 1); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + +// #ifdef DEBUG_MOE_OP +// std::cerr << "[DEBUG-BWD] gather_with_mask without vectorized, s_shared_num=" << s_shared_num << ", block=" << blocks << std::endl; +// #endif + + if (!use_all2all_permute) { + gather_with_mask_kernel<<>>( + dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + gather_with_mask_permute_kernel<<>>( + dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } +} diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu index 134da55ec26eef..e8c7146285bb4a 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu @@ -1,8 +1,11 @@ #include "paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h" -#include "paddle/phi/core/kernel_registry.h" // 注册相关 -#include "paddle/phi/backends/gpu/gpu_context.h" // context相关 +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" +#include "paddle/phi/kernels/contiguous_kernel.h" +#include "paddle/phi/kernels/gpu/moe_fuse_bwd_op.h" +#include "paddle/phi/kernels/transpose_kernel.h" namespace phi{ template @@ -12,7 +15,7 @@ void apply_moe_dispatch_bwd( const int* scatter_index, // [s, k] const float* combine_weights_grad, const int* expert_id, // [s, k] - float* gate_logtis_grad, + float* gate_logits_grad, T* x_grad, int64_t num_rows, int64_t k, @@ -31,7 +34,7 @@ void apply_moe_dispatch_bwd( topk_grad_with_mask_launcher(combine_weights_grad, expert_id, combine_weights, - gate_logtis_grad, + gate_logits_grad, num_rows, k, num_experts, stream); } @@ -43,8 +46,8 @@ void moe_dispatch_bwd(const Context& dev_ctx, const DenseTensor& expert_id, // [s, k] const DenseTensor& y_grad, // [num_experts * capacity, h] const DenseTensor& combine_weights_grad, // [s, k] - const DenseTensor&x_grad, - const DenseTensor& gate_logtis_grad, + DenseTensor& x_grad, + DenseTensor& gate_logits_grad, int64_t capacity, bool use_all2all_permute = false, int64_t world_size = -1, @@ -53,12 +56,8 @@ void moe_dispatch_bwd(const Context& dev_ctx, int64_t num_rows = combine_weights_dims[0]; int64_t k = combine_weights_dims[1]; auto y_grad_dims = y_grad.dims(); -#ifdef MOE_OPS_AUTO - int64_t hidden_size = y_grad_dims[2]; -#else int64_t hidden_size = y_grad_dims[y_grad_dims.size() - 1]; -#endif - int64_t num_experts = gate_logtis_grad.dims()[1]; + int64_t num_experts = gate_logits_grad.dims()[1]; apply_moe_dispatch_bwd( y_grad.data(), @@ -66,7 +65,7 @@ void moe_dispatch_bwd(const Context& dev_ctx, 
scatter_index.data(), combine_weights_grad.data(), expert_id.data(), - gate_logtis_grad.data(), + gate_logits_grad.data(), x_grad.data(), num_rows, k, @@ -90,20 +89,26 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, int64_t capacity, int64_t world_size, DenseTensor* x_grad, - DenseTensor* gate_logtis_grad){ + DenseTensor* gate_logits_grad){ int64_t num_local_experts = y_grad.dims()[0]; auto scatter_index_dims = scatter_index.dims(); - DenseTensor t_scatter_index = phi::Empty(dev_ctx, {scatter_index_dims[1], scatter_index_dims[0]}); - phi::Transpose(dev_ctx, scatter_index, {1,0}, &t_scatter_index); - fleety_utils::TensorTrans2Contiguous(&t_scatter_index); + + DenseTensor t_scatter_index; + phi::Transpose(dev_ctx, scatter_index, {1,0}, &t_scatter_index); + DenseTensor t_scatter_index_; + phi::ContiguousKernel( + dev_ctx, t_scatter_index, &t_scatter_index_); + + dev_ctx.template Alloc(x_grad); + dev_ctx.template Alloc(gate_logits_grad); moe_dispatch_bwd(dev_ctx, combine_weights, - t_scatter_index, + t_scatter_index_, expert_id, y_grad, combine_weights_grad, - x_grad, - gate_logtis_grad, + *x_grad, + *gate_logits_grad, capacity, true, /*use_all2all_permute*/ world_size, @@ -111,3 +116,12 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, } } // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_permute_grad, + GPU, + ALL_LAYOUT, + phi::MoeGateDispatchGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} \ No newline at end of file diff --git a/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h index 10ecb4978b71ff..cce8c2c7cc0c4a 100644 --- a/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h +++ b/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h @@ -27,5 +27,5 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, int64_t capacity, int64_t world_size, DenseTensor* x_grad, - DenseTensor* gate_logtis_grad); + DenseTensor* gate_logits_grad); } // namespace phi \ No newline at end of file diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 92895c820709ac..ee1790d2faa404 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2259,7 +2259,7 @@ - backward_op : moe_gate_dispatch_permute_grad forward : moe_gate_dispatch_permute (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, int64_t world_size) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, int64_t world_size) - output : Tensor(x_grad), Tensor(gate_logtis_grad) + output : Tensor(x_grad), Tensor(gate_logits_grad) infer_meta : func : MoeGateDispatchPermuteGradInferMeta kernel : diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py index ea4e54cc653eba..0874e2a9d71d79 100644 --- a/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py @@ -62,4 +62,49 @@ def moe_gate_dispatch_permute( } helper.append_op(type='moe_gate_dispatch_permute', inputs=inputs, outputs=outputs, attrs=attrs) - return y, combine_weights, scatter_index, expert_offset, expert_id \ No newline at end of file + return y, combine_weights, scatter_index, expert_offset, expert_id + +# # 
定义输入参数 +# num_rows = 10 # 示例行数 +# hidden_size = 128 # 隐藏层维度 +# num_experts = 4 # 专家数 +# world_size = 2 # 分布式世界大小 +# k = 2 # 选择的Top-k专家 +# capacity = 5 # 每个专家的处理容量 + +# # 确保num_experts可以被world_size整除 +# assert num_experts % world_size == 0 + +# # 生成输入数据 +# x = paddle.randn([num_rows, hidden_size], dtype='float32') +# gate_logits = paddle.randn([num_rows, num_experts], dtype='float32') +# x.stop_gradient = False +# gate_logits.stop_gradient = False + +# # 可选的修正偏差 +# # corr_bias = paddle.randn([num_rows], dtype='float32') +# corr_bias = None + +# # 调用封装的API +# y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch_permute( +# x=x, +# gate_logits=gate_logits, +# corr_bias=corr_bias, +# k=k, +# capacity=capacity, +# world_size=world_size +# ) + +# # 打印输出结果的形状和类型,验证结果 +# print("Output y shape:", y.shape) +# print("Combine weights shape:", combine_weights.shape) +# print("Scatter index shape:", scatter_index.shape) +# print("Expert offset shape:", expert_offset.shape) +# print("Expert ID shape:", expert_id.shape) + +# a = paddle.sum(y)+paddle.sum(combine_weights)+paddle.sum(scatter_index)+paddle.sum(expert_offset)+paddle.sum(expert_id) +# a.backward() + +# print("Gradient of x:", x.grad) +# print("Gradient of gate_logits:", gate_logits.grad) +# print("Gradient of corr_bias:", corr_bias.grad) \ No newline at end of file From 7e529c38fe714cd8895e59e1b288b21e54cbe1a5 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 02:23:00 +0000 Subject: [PATCH 40/71] Verified fused_rms_norm_ext(with bwd) and int_bincount. --- paddle/phi/api/lib/data_transform.cc | 3 - paddle/phi/infermeta/backward.cc | 1 + paddle/phi/kernels/gpu/int_bincount.cu | 15 ++- .../phi/kernels/gpu/layer_norm_cuda_kernel.cu | 9 +- paddle/phi/kernels/layer_norm_cuda_kernel.h | 1 - paddle/phi/ops/yaml/backward.yaml | 2 +- paddle/phi/ops/yaml/ops.yaml | 2 +- .../nn/functional/fused_rms_norm_ext.py | 11 +- .../incubate/nn/functional/int_bincount.py | 10 +- .../test_incubate_fused_rmsnorm_ext.py | 119 ++++++++++++++++++ .../legacy_test/test_incubate_int_bincount.py | 30 +++++ 11 files changed, 179 insertions(+), 24 deletions(-) create mode 100644 test/legacy_test/test_incubate_fused_rmsnorm_ext.py create mode 100644 test/legacy_test/test_incubate_int_bincount.py diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 19a690d093e4a8..c4feaded4174bb 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -418,9 +418,6 @@ std::shared_ptr PrepareData( out = Trans2Contiguous(out); return std::make_shared(std::move(out)); } - VLOG(6) << "tensor_in: "<< tensor_in; - auto tmp = std::static_pointer_cast(tensor_in); - VLOG(6) << "tensor_in.size(): " << tmp->dims(); return std::static_pointer_cast(tensor_in); } phi::DenseTensor out = TransformData( diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index e53e3b35c8df87..fdba6df2b3ed5f 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1982,6 +1982,7 @@ void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights, x_grad->set_dims(common::make_ddim({num_rows, hidden_size})); x_grad->set_dtype(y_grad.dtype()); +} void FusedRMSNormGradInferMeta(const MetaTensor &x, const MetaTensor &scale, const MetaTensor &invvar, diff --git a/paddle/phi/kernels/gpu/int_bincount.cu b/paddle/phi/kernels/gpu/int_bincount.cu index a662f97b64777c..287ef0cbe2d733 100644 --- a/paddle/phi/kernels/gpu/int_bincount.cu +++ 
b/paddle/phi/kernels/gpu/int_bincount.cu @@ -63,8 +63,8 @@ void IntBincount(const Context& ctx, const DenseTensor &x, int64_t low, int64_t int64_t bins_width = high - low; PD_CHECK(bins_width + 1 < std::numeric_limits::max()); - auto bins_dtype = TransToPhiDataType(out_dtype); - DenseTensor bins = phi::Empty(ctx, {bins_width}); + auto bins_dtype = TransToDataType(out_dtype); + // auto x_dytpe = x.dtype(); auto low_v = static_cast(low); @@ -72,18 +72,21 @@ void IntBincount(const Context& ctx, const DenseTensor &x, int64_t low, int64_t PD_CHECK(static_cast(low_v) == low); PD_CHECK(static_cast(high_v) == high); const auto *x_data = x.data(); - void *bins_data = bins.data(); int64_t n = x.numel(); if (bins_dtype == phi::DataType::INT32) { - IntBincountImpl(ctx, x_data, n, low_v, high_v, static_cast(bins_data)); + ctx.template Alloc(out); + uint32_t *out_ptr = static_cast(out->data()); + IntBincountImpl(ctx, x_data, n, low_v, high_v, out_ptr); } else if (bins_dtype == phi::DataType::INT64) { using ULLI = unsigned long long int; + ctx.template Alloc(out); static_assert(sizeof(int64_t) == sizeof(ULLI)); - IntBincountImpl(ctx, x_data, n, low_v, high_v, static_cast(bins_data)); + // WARNING: unsafe type cast used in original impl. + ULLI* out_ptr = static_cast (out->data()); + IntBincountImpl(ctx, x_data, n, low_v, high_v, out_ptr); } else { PD_THROW("Only support INT32 and INT64, but got %s", bins_dtype); } - out = &bins; } } // namespace phi diff --git a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu index 9f40cf79c2a7d0..3f9df7fcb9a1db 100644 --- a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu +++ b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu @@ -65,7 +65,7 @@ void RMSLnBwd(const Context& ctx, const DenseTensor &x, const DenseTensor &scale, const DenseTensor &invvar, - const DenseTensor &dy, + const DenseTensor &y_grad, float epsilon, DenseTensor* x_grad, DenseTensor* scale_grad) { @@ -73,15 +73,14 @@ void RMSLnBwd(const Context& ctx, const auto &x_shape = x.dims(); rows = x_shape[0]; cols = x_shape[1]; - *x_grad = phi::EmptyLike(ctx, x); - *scale_grad = phi::EmptyLike(ctx, scale); - + ctx.template Alloc(x_grad); + ctx.template Alloc(scale_grad); cuda_rms_norm_gradient( ctx, x, scale, invvar, - dy, + y_grad, rows, cols, epsilon, diff --git a/paddle/phi/kernels/layer_norm_cuda_kernel.h b/paddle/phi/kernels/layer_norm_cuda_kernel.h index b54d44e2eb0825..1dcfebf890be98 100644 --- a/paddle/phi/kernels/layer_norm_cuda_kernel.h +++ b/paddle/phi/kernels/layer_norm_cuda_kernel.h @@ -1011,7 +1011,6 @@ void HostRMSNormGradient( const Context& ctx, const int nshared2 = nshared2_a > nshared2_b ? 
nshared2_a : nshared2_b; auto place = input.place(); DenseTensor part_grad_gamma = phi::Empty(ctx, {part_size, n2}); - cuComputePartGradGammaBeta<<>>( dout, input.data(), diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 58463ac518f11e..3c74657a806387 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -3950,7 +3950,7 @@ - backward_op: fused_rms_norm_grad forward: fused_rms_norm (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) - args: (Tensor x, Tensor scale, Tensor y_grad, Tensor invvar_grad, float epsilon) + args: (Tensor x, Tensor scale,Tensor invvar, Tensor y_grad, float epsilon) output: Tensor(x_grad), Tensor(scale_grad) infer_meta: func: FusedRMSNormGradInferMeta diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 80c354bcfa1a67..b1733a357b5a1b 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -5729,7 +5729,7 @@ traits : paddle::dialect::ForwardOnlyTrait - op: int_bincount - args: (Tensor x, int low, int high, int dtype) + args: (Tensor x, int64_t low, int64_t high, int64_t dtype) output: Tensor(out) infer_meta: func: IntBincountInferMeta diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py index 35fa675d94bb05..bb0a2d9a8245fe 100644 --- a/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py +++ b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py @@ -1,9 +1,10 @@ # File: python/paddle/incubate/nn/functional/layer_norm_cuda.py -from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.data_feeder import convert_dtype import paddle +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper +from paddle import _C_ops -def fused_rms_norm_ext(x, scale, bias=None, epsilon=1e-5, name=None): +def fused_rms_norm_ext(x, scale, epsilon=1e-5, name=None): """ Applies Layer Normalization over the last dimension of the input tensor using CUDA implementation. Args: @@ -17,6 +18,10 @@ def fused_rms_norm_ext(x, scale, bias=None, epsilon=1e-5, name=None): mean (Tensor): Tensor of shape [rows], the mean of each row. invvar (Tensor): Tensor of shape [rows], the inverse standard deviation of each row. 
""" + if in_dynamic_or_pir_mode(): + return _C_ops.fused_rms_norm( + x,scale,epsilon + ) helper = LayerHelper('fused_rms_norm', **locals()) dtype = convert_dtype(x.dtype) y = helper.create_variable_for_type_inference(dtype) diff --git a/python/paddle/incubate/nn/functional/int_bincount.py b/python/paddle/incubate/nn/functional/int_bincount.py index 32a6697d1c210f..171fd29ed68483 100644 --- a/python/paddle/incubate/nn/functional/int_bincount.py +++ b/python/paddle/incubate/nn/functional/int_bincount.py @@ -1,11 +1,13 @@ import paddle -from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.data_feeder import convert_dtype +from paddle import _C_ops +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper def int_bincount(x, low, high, dtype=None, name=None): - # if in_dynamic_or_pir_mode(): - # return _C_ops.moe_gate_dispatch_permute(x, gate_logits, corr_bias, k, capacity, world_size) + if in_dynamic_or_pir_mode(): + return _C_ops.int_bincount(x, low, high, dtype) + helper = LayerHelper("int_bincount", **locals()) out_dtype = dtype if dtype is not None else x.dtype y = helper.create_variable_for_type_inference(dtype=out_dtype) diff --git a/test/legacy_test/test_incubate_fused_rmsnorm_ext.py b/test/legacy_test/test_incubate_fused_rmsnorm_ext.py new file mode 100644 index 00000000000000..cf9877cb6f5994 --- /dev/null +++ b/test/legacy_test/test_incubate_fused_rmsnorm_ext.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle.incubate.nn.functional import fused_rms_norm_ext + +# 假设 fused_rms_norm_ext 已经被导入 +# from your_module import fused_rms_norm_ext + +class TestFusedRMSNorm(unittest.TestCase): + def setUp(self): + # 设置随机种子以确保结果可复现 + paddle.seed(2023) + np.random.seed(2023) + + def rms_norm_reference(self, x, scale, bias=None, epsilon=1e-5): + """ + 使用 Paddle 原生操作实现 RMS Normalization 作为参考 + """ + # 计算均方根 + variance = paddle.mean(paddle.square(x), axis=-1, keepdim=True) + # 计算 RMS + rms = paddle.sqrt(variance + epsilon) + # 归一化 + y = x / rms + # 应用缩放 + y = y * scale.reshape([1, -1]) + # 应用偏置(如果有) + if bias is not None: + y = y + bias.reshape([1, -1]) + + # 返回归一化后的张量、均值(RMS Norm 中为0)和逆标准差 + return y, (1.0 / rms).squeeze(-1) + + def test_2d_input(self): + # 测试 2D 输入 + rows, cols = 32, 64 + x = paddle.randn([rows, cols]) + scale = paddle.randn([cols]) + + # 使用我们的实现 + y_fused, invvar_fused = fused_rms_norm_ext(x, scale) + + # 使用参考实现 + y_ref, invvar_ref = self.rms_norm_reference(x, scale) + + # 验证结果 + np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5) + + + def test_without_bias(self): + # 测试没有偏置的情况 + rows, cols = 32, 64 + x = paddle.randn([rows, cols]) + scale = paddle.randn([cols]) + + # 使用我们的实现 + y_fused, invvar_fused = fused_rms_norm_ext(x, scale) + + # 使用参考实现 + y_ref, invvar_ref = self.rms_norm_reference(x, scale) + + # 验证结果 + np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5) + + def test_backward(self): + # 测试反向传播 + rows, cols = 16, 32 + x = paddle.randn([rows, cols], dtype='float32') + x.stop_gradient = False + scale = paddle.randn([cols], dtype='float32') + scale.stop_gradient = False + + # 前向传播 + y_fused, invvar = fused_rms_norm_ext(x, scale) + + # 计算损失并反向传播 + loss = paddle.mean(y_fused) + loss.backward() + + # 获取梯度 + x_grad_fused = x.grad.clone() + scale_grad_fused = scale.grad.clone() + + # 重置梯度 + x.clear_gradient() + scale.clear_gradient() + + # 使用参考实现 + y_ref, invvar_ref = self.rms_norm_reference(x, scale) + loss_ref = paddle.mean(y_ref) + loss_ref.backward() + + # 获取参考梯度 + x_grad_ref = x.grad + scale_grad_ref = scale.grad + + # 验证梯度 + np.testing.assert_allclose(x_grad_fused, x_grad_ref, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose(scale_grad_fused, scale_grad_ref, rtol=1e-4, atol=1e-4) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/legacy_test/test_incubate_int_bincount.py b/test/legacy_test/test_incubate_int_bincount.py new file mode 100644 index 00000000000000..daf983d4a65a97 --- /dev/null +++ b/test/legacy_test/test_incubate_int_bincount.py @@ -0,0 +1,30 @@ +import paddle +import numpy as np +import unittest +from paddle.incubate.nn.functional import int_bincount + +class TestIntBincount(unittest.TestCase): + def setUp(self): + paddle.set_device('gpu') + + def test_basic(self): + x = paddle.to_tensor([1, 2, 3, 1, 2, 3], dtype=paddle.int32) + out = int_bincount(x, low=1, high=4, dtype=paddle.int32) + expected = np.array([2, 2, 2,0]) + np.testing.assert_array_equal(out.numpy(), expected) + + def test_empty_input(self): + x = paddle.to_tensor([], dtype=paddle.int32) + out = int_bincount(x, low=0, high=10, dtype=paddle.int32) + self.assertEqual(out.shape, [11]) + self.assertEqual(out.sum().item(), 0) + + def test_different_dtypes(self): + x = paddle.to_tensor([1, 3, 5, 3, 
1], dtype=paddle.int64) + out = int_bincount(x, low=1, high=6, dtype=paddle.int64) + expected = np.array([2, 0, 2, 0, 1, 0]) + np.testing.assert_array_equal(out.numpy(), expected) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 5d2e65da31185d3b5b66943abfdb5e6f74143c87 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 03:05:21 +0000 Subject: [PATCH 41/71] Add stage2 fwd and bwd optests. --- ...moe_gate_dispatch_partial_nosoftmaxtopk.py | 199 ++++++++++++++++++ ...st_incubate_moe_gate_dispatch_w_permute.py | 150 +++++++++++++ ...ncubate_moe_gate_dispatch_w_permute_bwd.py | 142 +++++++++++++ 3 files changed, 491 insertions(+) create mode 100644 test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py create mode 100644 test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py create mode 100644 test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py new file mode 100644 index 00000000000000..05fbfe3aed7eb8 --- /dev/null +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py @@ -0,0 +1,199 @@ +import unittest +import sys +from functools import partial +import numpy as np +from collections import namedtuple + +import paddle +from paddle.autograd import PyLayer +import paddle.nn.functional as F +from ernie_utils.moe_layer_uneven import GateDispatch +from paddle.incubate.nn.functional import moe_gate_dispatch_partial_nosoftmaxtopk, moe_gate_dispatch + + +def test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op(): + import moe_ops_partial_nosoftmaxtopk + + s, d, e = 4, 100, 8 + k, cap = 4, 3 + local_expert_num = 2 + + # x = paddle.randn([s, d]) + # gate_logits = paddle.randn([s, e]) + x = paddle.arange(1, s + 1).unsqueeze(-1).expand([s, d]).astype("bfloat16") + x_ = x.clone().detach() + + t = ((paddle.arange(0, e)).unsqueeze(0) + paddle.arange(0, -s, -1).unsqueeze(-1)) % e + gate_logits = (1 / (t + 1)).astype("float32") + # gate_logits = F.softmax(paddle.randn([s,e]),-1).astype('float32') + gate_logits_ = gate_logits.clone().detach() + s = x.shape[0] + d = x.shape[1] + e = gate_logits.shape[1] + x.stop_gradient = False + x_.stop_gradient = False + gate_logits.stop_gradient = False + gate_logits_.stop_gradient = False + print(f"gate_logits:{gate_logits}") + + def check_ascend(index_rev, chunks): + for idx in index_rev.split(chunks.tolist()): + if len(idx) > 2: + assert (paddle.diff(idx) >= 0).all(), (index_rev,) + + ys, comm, scatter_idx = [], [], [] + for ilocal_expert in range(0, e, local_expert_num): + combine_weihgts, expert_id = gate_logits.topk(k=k, axis=1) + ( + y, + combine_weihgts, + scatter_index, + scatter_index_rev, + expert_offset, + expert_num_local, + ) = moe_gate_dispatch_partial_nosoftmaxtopk( + x, + combine_weihgts, + expert_id.astype("int32"), + k=k, + capacity=cap, + num_experts=gate_logits.shape[-1], + use_pad=False, + expert_start_index=ilocal_expert, + expert_end_index=ilocal_expert + local_expert_num, # k # cap + reverse_token_drop=False, + ) + check_ascend(scatter_index_rev, expert_num_local) + print(f"y:{y.mean(-1)}") + print(f"combine_weihgts:{combine_weihgts}") + print(f"expert_num_local:{expert_num_local}") + print(f"scatter_index:{scatter_index.transpose([1,0])}") + print(f"scatter_index_rev:{scatter_index_rev}") + + ys.append(y) + comm.append(combine_weihgts) + scatter_idx.append(scatter_index) + + 
comm_sum = paddle.stack(comm).sum(0) + ys_sum = paddle.concat(ys) + + y_, combine_weihgts_, scatter_index_, expert_offset_, expert_id_ = moe_gate_dispatch( + x_, + gate_logits_, + None, + k=k, + capacity=cap, + use_pad=True, # k # cap + ) + valid_y = y_.sum(-1) > 0.0 + y_2 = y_[valid_y].squeeze() + + print( + f""" + y: {ys_sum.astype("float32").mean(axis=-1)} + y_: {y_2.astype("float32").mean(axis=-1)} + + comm-weight: {comm_sum} + comm-weight_: {combine_weihgts_} + + expert_id:{expert_id} + scatter_index:{scatter_index} + scatter_index_rev: {scatter_index_rev} + expert_num_global:{expert_offset} + expert_num_local:{expert_num_local} + """ + ) + + print(f"<<< begin backward>>>") + + assert combine_weihgts_.shape == combine_weihgts.shape, (combine_weihgts_.shape, combine_weihgts.shape) + + dysum, dcombine_weights_sum = paddle.ones_like(ys_sum), paddle.randn(comm_sum.shape).astype(comm_sum.dtype) + dy_, dcombine_weights_ = paddle.ones_like(y_), paddle.ones_like(combine_weihgts_) + dy_[~valid_y] = 0 + + y_shapes = [len(y) for y in ys] + for dyy, yy, commm in zip( + paddle.split(dysum, y_shapes), + ys, + comm, + ): + print(f"dyy:{dyy.shape}, {yy.shape} {commm.shape}") + paddle.autograd.backward([yy, commm], [dyy, dcombine_weights_sum]) + print(x.grad.astype("float32").mean(axis=-1)) + print(f"bwd original:{y_.shape} {dy_.shape}") + paddle.autograd.backward([y_, combine_weihgts_], [dy_, dcombine_weights_]) + + print(x_.grad.astype("float32").mean(axis=-1)) + + print( + f""" + x: {x.grad.astype('float32').mean(axis=-1)} + x_: {x_.grad.astype('float32').mean(axis=-1)} + """ + ) + + + + +def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop(): + import moe_ops_partial_nosoftmaxtopk + + S, E, D = 3, 4, 3 + k = 2 + capacity = 2 + x = (paddle.arange(S) + 1).unsqueeze(-1).expand([S, D]).astype("bfloat16") + cw = paddle.randn([S, k]) + eid = paddle.to_tensor([[0, 1], [0, 1], [0, 2]], dtype="int32") # 1 # 2 # 3 + ( + y, + cw_, + idx, + idx_rev, + num_ex_global, + num_ex_local, + ) = moe_gate_dispatch_partial_nosoftmaxtopk( + x, cw, eid, k, capacity, E, False, 0, 2, reverse_token_drop=True + ) + + y0, y1 = y.split([i for i in num_ex_local.tolist() if i > 0]) + assert y0[:, 0].astype("int32").tolist() == [2, 3], y0[:, 0] + assert y1[:, 0].astype("int32").tolist() == [1, 2] + + +def test_moe_ops_partial_nosoftmax_topk_empty_output(): + import moe_ops_partial_nosoftmaxtopk + + S, E, D = 3, 4, 3 + k = 2 + capacity = 2 + x = (paddle.arange(S) + 1).unsqueeze(-1).expand([S, D]).astype("bfloat16") + cw = paddle.randn([S, k]) + eid = paddle.to_tensor([[0, 1], [0, 1], [0, 2]], dtype="int32") # 1 # 2 # 3 + ( + y, + cw_, + idx, + idx_rev, + num_ex_global, + num_ex_local, + ) = moe_gate_dispatch_partial_nosoftmaxtopk( + x, cw, eid, k, capacity, E, False, 3, 4, reverse_token_drop=True + ) + assert all([i == 0 for i in num_ex_local.tolist()]), num_ex_local + + +class TestAddition(unittest.TestCase): + + def test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op(self): + test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op() + + def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop(self): + test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop() + + def test_moe_ops_partial_nosoftmax_topk_empty_output(self): + test_moe_ops_partial_nosoftmax_topk_empty_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py new file mode 100644 index 00000000000000..0b1909d287be76 --- 
/dev/null +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py @@ -0,0 +1,150 @@ +# !/usr/bin/env python3 +import os +import sys +import unittest +import contextlib + +import numpy as np +import random +import time +import paddle +from paddle import _C_ops +from paddle.autograd import PyLayer +import paddle.nn.functional as F +from paddle.incubate.nn.functional import moe_gate_dispatch, moe_gate_dispatch_permute + + +os.environ["FLAGS_flash_attn_version"] = "v1" +os.environ["FLAGS_cudnn_deterministic"] = "1" +os.environ["FLAGS_embedding_deterministic"] = "1" + + +class TestFused(unittest.TestCase): + + def test_moe_ops(self): + """ + test `moe-ops` w/ bias + """ + S, E, D = 8192, 64, 128 + k = 4 + x = paddle.randn([S, D], dtype="bfloat16") + gate_logits = paddle.randn([S, E], dtype="float32") + x_ = x.clone() + gate_logits_ = gate_logits.clone() + x.stop_gradient = True + x_.stop_gradient = True + gate_logits.stop_gradient = True + gate_logits_.stop_gradient = True + bias = paddle.zeros([E], dtype="float32") + cap = 512 + + y, combine_weihgts, scatter_index, expert_offset_, expert_id_ = moe_gate_dispatch( + x, + gate_logits, + None, + k=k, + capacity=cap, + use_pad=True, # k # cap + ) + + y_, combine_weihgts_, scatter_index_, expert_offset_, expert_id_ = moe_gate_dispatch( + x_, + gate_logits_, + bias + 1, # +1也不会破坏路由结果 + k=k, + capacity=cap, + use_pad=True, # k # cap + ) + bias_unbalanced = bias.clone() + bias_unbalanced[0] += 1 + y__, combine_weihgts__, scatter_index__, expert_offset__, expert_id__ = moe_gate_dispatch( + x_, + gate_logits_, + bias_unbalanced, + k=k, + capacity=cap, + use_pad=True, # k # cap + ) + np.testing.assert_equal( + y.astype("float32").numpy(), y_.astype("float32").numpy(), err_msg="incubate w bias not match" + ) + # bias 不影响 prob 概率 + np.testing.assert_equal( + combine_weihgts.astype("float32").numpy(), + combine_weihgts_.astype("float32").numpy(), + err_msg="incubate w bias not match", + ) + np.testing.assert_( + (y.astype("float32").numpy(0) != y__.astype("float32").numpy()).any(), + ) + + +class TestDispatchPermute(unittest.TestCase): + def get_detached_input(self, input, prob): + ret_input = input.detach() + ret_prob = prob.detach() + ret_input.stop_gradient = input.stop_gradient + ret_prob.stop_gradient = prob.stop_gradient + return ret_input, ret_prob + + def get_stage_input_list(self, x, world_size, stage): + print(world_size, stage, x.shape) + x = x.reshape([world_size * stage, -1, x.shape[-1]]) + stage_input_list = [] + x_list = paddle.split(x, num_or_sections=(world_size * stage), axis=0) + for stage_id in range(stage): + stage_input_list.append(paddle.unsqueeze(paddle.concat(x_list[stage_id::stage], axis=0), axis=0)) + stage_input_list = paddle.concat(stage_input_list, axis=0) + return stage_input_list + + def test_moe_permute_ops(self): + paddle.seed(2025) + + test_cases = [(8, 4, 2), (64, 16, 32), (1024, 1024, 1024), (8, 2, 4), (4096, 4096, 4096)] + cases = list(zip(*test_cases)) + for _, case in enumerate(cases): + world_size, num_experts, num_tokens, k, hidden_size = case + capacity = num_tokens // k + stages = num_experts // world_size + + input = paddle.randn([num_tokens, hidden_size], dtype="float32") + prob_logits = paddle.randn([num_tokens, num_experts], dtype="float32") + prob = F.softmax(prob_logits, axis=-1) + input.stop_gradient = False + prob.stop_gradient = False + + compat_args = (None,) + + ref_input, ref_prob = self.get_detached_input(input, prob) + ( + ref_dispatched_input, + ref_combine_weights_unnorm, + 
ref_scatter_index, + ref_dispatch_mask, + _, + ) = moe_gate_dispatch(ref_input, ref_prob, *compat_args, k=k, capacity=capacity, use_pad=True) + + ref_stage_input_list = self.get_stage_input_list(ref_dispatched_input, world_size, stages) + + test_input, test_prob = self.get_detached_input(input, prob) + ( + test_dispatched_input, + test_combine_weights_unnorm, + test_scatter_index, + test_dispatch_mask, + _, + ) = moe_gate_dispatch_permute( + test_input, test_prob, *compat_args, k=k, capacity=capacity, world_size=world_size + ) + + np.testing.assert_equal( + test_dispatched_input.shape, ref_stage_input_list.shape, err_msg="moe_permute_ops not match" + ) + np.testing.assert_equal( + test_dispatched_input._md5sum(), ref_stage_input_list._md5sum(), err_msg="moe_permute_ops not match" + ) + + +if __name__ == "__main__": + + unittest.main() diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py new file mode 100644 index 00000000000000..a1a9f61aee3440 --- /dev/null +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py @@ -0,0 +1,142 @@ +# !/usr/bin/env python3 +import os +import sys +import unittest +import contextlib + +import numpy as np +import random +import time +import paddle +from paddle import _C_ops +from paddle.autograd import PyLayer +import paddle.nn.functional as F +from paddle.incubate.nn.functional import moe_gate_dispatch, moe_gate_dispatch_permute + +batch_size = 4 +hidden_size = 2 +k = 16 +capacity = 2 +num_experts = 16 + +world_size = 2 + + +class TestLayer(paddle.nn.Layer): + def forward(self, x, gate_prob, k, capacity): + y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch( + x, gate_prob, None, k, capacity, True + ) + return y, combine_weights, scatter_index, expert_offset, expert_id + + +class TestLayerPermute(paddle.nn.Layer): + def forward(self, x, gate_prob, k, capacity): + y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch_permute( + x, gate_prob, None, k, capacity, world_size=world_size + ) + return y, combine_weights, scatter_index, expert_offset, expert_id + + +def check_backward_correctness(layer_cls): + paddle.seed(1024) + + dtype = "bfloat16" + layer = layer_cls() + input = paddle.randn([batch_size, hidden_size]) + + gate_weight = paddle.randn([hidden_size, num_experts]) + logits = paddle.matmul(input, gate_weight) + gate_prob = F.softmax(logits, axis=-1) + print(f"gate_prob: {gate_prob}") + + input = paddle.cast(input, "bfloat16") + input.stop_gradient = False + gate_prob.stop_gradient = False + + output, combine_weights, scatter_index, expert_offset, expert_id = layer(input, gate_prob, k, capacity) + + print(f"output: {output}") + print(f"combine_weights: {combine_weights}") + print(f"scatter_index: {scatter_index}") + print(f"expert_offset: {expert_offset}") + print(f"expert_id: {expert_id}") + + # output_g = paddle.randn(output.shape).astype(output.dtype) + # combine_weights_g = paddle.randn(combine_weights.shape).astype(combine_weights.dtype) + output_g = paddle.ones_like(output) + combine_weights_g = paddle.ones_like(combine_weights) + print(f"output_g: {output_g}") + print(f"combine_weights_g: {combine_weights_g}") + + paddle.autograd.backward( + tensors=[output, combine_weights], + grad_tensors=[output_g, combine_weights_g], + ) + # 数值估算 + epsilon = 0.005 + input_numpy = input.detach().astype("float32").numpy() + num_grad = paddle.zeros_like(input) + flattened = num_grad.reshape([-1]) + + 
for i in range(input.numel()): + input_pos = input_numpy.copy() + input_neg = input_numpy.copy() + input_pos.flat[i] += epsilon + input_neg.flat[i] -= epsilon + + output_pos, _, _, _, _ = layer(paddle.to_tensor(input_pos), gate_prob, k, capacity) + output_neg, _, _, _, _ = layer(paddle.to_tensor(input_neg), gate_prob, k, capacity) + + ''' + flattened[i] = (output_pos.astype("float32").numpy() - output_neg.astype("float32").numpy()).sum() / ( + 2 * epsilon + ) + ''' + grad_value = (output_pos - output_neg).sum() / (2 * epsilon) + flattened[i] = grad_value + + flattened = flattened.reshape(input.shape) + + print(f"input gradient: {input.grad}") + print(f"numerical gradient: {flattened}") + np.testing.assert_allclose( + input.grad.astype("float32").numpy(), flattened.astype("float32").numpy(), rtol=1e-5, atol=0 + ) + + # 数值估算 gate_prob + epsilon = 0.0005 + gate_prob_numpy = gate_prob.detach().astype("float32").numpy() + num_grad = paddle.zeros_like(gate_prob) + flattened = num_grad.reshape([-1]) + + for i in range(gate_prob.numel()): + input_pos = gate_prob_numpy.copy() + input_neg = gate_prob_numpy.copy() + input_pos.flat[i] += epsilon + input_neg.flat[i] -= epsilon + + _, output_pos, _, _, _ = layer(input, paddle.to_tensor(input_pos), k, capacity) + _, output_neg, _, _, _ = layer(input, paddle.to_tensor(input_neg), k, capacity) + + flattened[i] = (output_pos.numpy() - output_neg.numpy()).sum() / (2 * epsilon) + + flattened = flattened.reshape(gate_prob.shape) + + print(f"gate_prob gradient: {gate_prob.grad}") + print(f"numerical gradient: {flattened}") + np.testing.assert_allclose( + gate_prob.grad.astype("float32").numpy(), flattened.astype("float32").numpy(), rtol=1e-4, atol=0 + ) + + +class TestFused(unittest.TestCase): + def test_moe_backward(self): + check_backward_correctness(TestLayer) + + def test_moe_permute_backward(self): + check_backward_correctness(TestLayerPermute) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From b0dc90c6f885fe00d06c212f72d8dd0998f65ca0 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 03:42:11 +0000 Subject: [PATCH 42/71] Clean print --- paddle/phi/infermeta/backward.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 16c7e66241cee7..710c300f7a3af7 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1258,13 +1258,6 @@ void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(const MetaTensor& combine_ int64_t expert_end_index, MetaTensor* x_grad, MetaTensor* combine_weights_grad){ - printf("check infer\n"); - printf("combine shape: %d, scatter shape: %d\n", combine_weights_out.dims().size(), scatter_index.dims().size()); - printf("sizeof(combine_weights_out): %d\n", sizeof(combine_weights_out)); - printf("sizeof(y_grad): %d\n", sizeof(y_grad)); - printf("sizeof combine_weights_out_grad: %d\n", sizeof(combine_weights_out_grad)); - // printf("size of combine_weights_out_grad: %d\n", combine_weights_out_grad.size()); - printf("combine_weights_out_grad shape: %d\n", combine_weights_out_grad.dims().size()); int64_t num_experts = expert_offset.dims()[0]; int64_t hidden_size = y_grad.dims()[1]; int64_t num_rows = scatter_index.dims()[1]; From 9cd02e8b092bf6be573b85f0027d074ea101b563 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 06:00:20 +0000 Subject: [PATCH 43/71] Fix conflict, move some headers. 
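Note on the MoE dispatch helpers consolidated by this patch: each (token, top-k slot) pair is sorted by a single packed integer key that carries the expert id in the high bits and the slot index in the low bits (modify_expert_id / modify_and_mask_expert_id write (expert_id << (log2(k) + 1)) | ik, and unmodify_expert_id shifts the slot bits back off), so the radix sort groups rows by expert while preserving the top-k slot order. Below is a minimal Python sketch of that packing scheme, for illustration only; pack_expert_id / unpack_expert_id are hypothetical names and not functions in this patch.

import math

def pack_expert_id(expert_id, ik, k):
    # low bits carry the top-k slot index, high bits carry the expert id,
    # mirroring modify_expert_id: (expert_id << (log2(k) + 1)) | ik
    offset = int(math.log2(k)) + 1
    return (expert_id << offset) | ik

def unpack_expert_id(packed, k):
    # mirrors unmodify_expert_id: drop the slot bits again
    return packed >> (int(math.log2(k)) + 1)

k = 4
pairs = [(2, 0), (0, 1), (2, 3), (0, 0)]  # (expert id, top-k slot)
keys = sorted(pack_expert_id(e, ik, k) for e, ik in pairs)
# sorting the packed keys groups rows by expert, then by top-k slot
assert [unpack_expert_id(key, k) for key in keys] == [0, 0, 2, 2]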
--- paddle/phi/infermeta/ternary.cc | 2 +- paddle/phi/kernels/gpu/cub_kv_sorter.cu | 50 -- paddle/phi/kernels/gpu/moe_fuse_op.h | 794 ------------------ .../moe_gate_dispatch_permute_grad_kernel.cu | 2 +- .../gpu/moe_gate_dispatch_permute_kernel.cu | 2 +- paddle/phi/kernels/gpu/moe_kernel_impl.h | 601 ------------- ...e_ops_partial_nosoftmaxtopk_grad_kernel.cu | 2 +- .../moe_ops_partial_nosoftmaxtopk_kernel.cu | 7 +- .../phi/kernels/{gpu => }/moe_fuse_bwd_op.h | 2 +- paddle/phi/kernels/moe_fuse_op.h | 370 +++++++- paddle/phi/kernels/moe_kernel_impl.h | 79 +- 11 files changed, 431 insertions(+), 1480 deletions(-) delete mode 100644 paddle/phi/kernels/gpu/cub_kv_sorter.cu delete mode 100644 paddle/phi/kernels/gpu/moe_fuse_op.h delete mode 100644 paddle/phi/kernels/gpu/moe_kernel_impl.h rename paddle/phi/kernels/{gpu => }/moe_fuse_bwd_op.h (99%) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index cd79ef9207a58d..a30097e46b6a7e 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1723,7 +1723,7 @@ void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(const MetaTensor& x, common::errors::InvalidArgument( "The dtype of Input(combine_weights) must be FLOAT32, but received %s", combine_weights.dtype())); -int64_t num_experts_diff = expert_end_index - expert_start_index; +//int64_t num_experts_diff = expert_end_index - expert_start_index; int64_t num_rows = x_dims[0]; // if (use_pad) // y->set_dims({num_experts_diff * capacity, x_dims[1]}) ; diff --git a/paddle/phi/kernels/gpu/cub_kv_sorter.cu b/paddle/phi/kernels/gpu/cub_kv_sorter.cu deleted file mode 100644 index ce087e4712c5bc..00000000000000 --- a/paddle/phi/kernels/gpu/cub_kv_sorter.cu +++ /dev/null @@ -1,50 +0,0 @@ -#include "moe_kernel_impl.h" -namespace phi{ - // ===== CUB Sorting things ===== -CubKeyValueSorter::CubKeyValueSorter() - : num_experts_(0), num_bits_(sizeof(int) * 8) {} - -CubKeyValueSorter::CubKeyValueSorter(cudaStream_t stream) - : num_experts_(0), num_bits_(sizeof(int) * 8), stream_(stream) {} - -CubKeyValueSorter::CubKeyValueSorter(const int num_experts) - : num_experts_(num_experts), - num_bits_(static_cast(log2(num_experts)) + 1) {} - -void CubKeyValueSorter::update_num_experts(const int num_experts) { - num_experts_ = num_experts; - num_bits_ = static_cast(log2(num_experts)) + - 3; // 额外增加 3 位用于标记 topk的位置 -} - -size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, - bool descending) { - num_key_value_pairs_ = num_key_value_pairs; - size_t required_storage = 0; - int* null_int = nullptr; - if (descending) { - cub::DeviceRadixSort::SortPairsDescending(NULL, - required_storage, - null_int, - null_int, - null_int, - null_int, - num_key_value_pairs, - 0, - 32, - stream_); - } else { - cub::DeviceRadixSort::SortPairs(NULL, - required_storage, - null_int, - null_int, - null_int, - null_int, - num_key_value_pairs, - 0, - num_bits_, - stream_); - } - return required_storage; -} -} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_fuse_op.h b/paddle/phi/kernels/gpu/moe_fuse_op.h deleted file mode 100644 index 43249d5fc35404..00000000000000 --- a/paddle/phi/kernels/gpu/moe_fuse_op.h +++ /dev/null @@ -1,794 +0,0 @@ -#pragma once -#include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "paddle/common/exception.h" -#include "paddle/phi/kernels/gpu/moe_kernel_impl.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/common/enforce.h" -#include // 包含常用的 thrust 算法 -#include -#include -#include -#include - -template 
-__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, - const T* bias, //bias could be nullptr if not used - T* output, - int* indices, - int* source_rows, - const int num_experts, - const int k){ - using cub_kvp = cub::KeyValuePair; - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; - - cub_kvp thread_kvp; - cub::ArgMax arg_max; - - const int num_rows = gridDim.x; - const int block_row = blockIdx.x; - const int thread_read_offset = blockIdx.x * num_experts; - for (int k_idx = 0; k_idx < k; ++k_idx) { - thread_kvp.key = 0; - thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities - - cub_kvp inp_kvp; - for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { - const int idx = thread_read_offset + expert; - inp_kvp.key = expert; - inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ; - - for (int prior_k = 0; prior_k < k_idx; ++prior_k) { - const int prior_winning_expert = indices[k * block_row + prior_k]; - - if (prior_winning_expert == expert) { - inp_kvp = thread_kvp; - } - } - - thread_kvp = arg_max(inp_kvp, thread_kvp); - } - - const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); - if (threadIdx.x == 0) { - const int idx = k * block_row + k_idx; - output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; - indices[idx] = result_kvp.key; - source_rows[idx] = k_idx * num_rows + block_row; - } - __syncthreads(); - } -} - -template -void topk_gating_softmax_kernelLauncher(const T* input, - const T* bias, - T* output, - T* softmax, //no use - int* indices, - int* source_row, - const int num_rows, - const int num_experts, - const int k, - cudaStream_t stream){ - static constexpr int WARPS_PER_TB = 4; - static constexpr int TPB = 256; - moe_top_k<<>>( - input, bias, output, indices, source_row, num_experts, k); -} - -template -__global__ void modify_expert_id(const T* expert_id, - T* expert_id_out, - const int k, - const int num_rows, - const int64_t num_experts){ - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= k * num_rows) - return; - int ik = idx % k; - int irow = idx / k; - // const T mask = (~0) >> (8*sizeof(T)-ik); // 最后 ik 位为 1 其他位为 0 - int mask = ik; // k => 2(11) - // printf("before: idx=%d, expert-id:%d, ik=%d\n", idx, expert_id[idx], ik); - int offset = log2(k) + 1; - expert_id_out[idx] = (expert_id[idx]< -void modify_expert_id_launcher(const T* expert_id, - T* expert_id_out, - const int k, - const int num_rows, - const int64_t num_experts, - const cudaStream_t& stream){ - int max = 1024; - const int threads = std::min(max, num_rows * k); - const int blocks = (num_rows * k + threads - 1) / threads; - - modify_expert_id<<>>( - expert_id, - expert_id_out, - k, - num_rows, - num_experts - ); -} - -template -__global__ void -unmodify_expert_id(const T* expert_id, - T* expert_id_out, - const int k, - const int num_rows, - const int64_t num_experts){ - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= k * num_rows) - return; - int ik = idx % k; - int irow = idx / k; - int offset = log2(k) + 1; - expert_id_out[idx] = (expert_id[idx]>>offset); -} - -template -void unmodify_expert_id_launcher(const T* expert_id, - T* expert_id_out, - const int k, - const int num_rows, - const int64_t num_experts, - const cudaStream_t& stream){ - int max = 1024; - const int threads = std::min(max, num_rows * k); - const int blocks = (num_rows * k + threads - 
1) / threads; - - unmodify_expert_id<<>>( - expert_id, - expert_id_out, - k, - num_rows, - num_experts - ); -} - -template -__device__ inline int find_total_elts_leq_target(const T* sorted_indices, const int arr_length, const int target) -{ - int64_t low = 0, high = arr_length - 1, target_location = -1; - while (low <= high) { - int64_t mid = (low + high) / 2; - - if (sorted_indices[mid] > target) { - high = mid - 1; - } - else { - low = mid + 1; - target_location = mid; - } - } - return target_location + 1; -} - -template -__global__ void compute_total_rows_before_expert_kernel(const T* sorted_experts, - const int sorted_experts_len, - const int64_t num_experts, - int64_t* total_rows_before_expert) -{ - - // First, compute the global tid. We only need 1 thread per expert. - const int expert = blockIdx.x * blockDim.x + threadIdx.x; - if (expert >= num_experts) - return; - - - // This should construct the last index where each expert occurs. - total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); - // total_rows_before_expert[0] = 0; - // total_rows_before_expert[1] = 1; - // if (sorted_experts_len > 3) { - // for (int i=0; i<35;i++){ - // total_rows_before_expert[i] = i; - // } - // } - - -} - -template -void compute_total_rows_before_expert(const T* sorted_indices, - const int total_indices, - const int64_t num_experts, - int64_t* total_rows_before_expert, - const cudaStream_t& stream) -{ - const int threads = std::min(static_cast(1024), num_experts); - const int blocks = (num_experts + threads - 1) / threads; - - - compute_total_rows_before_expert_kernel<<>>( - sorted_indices, total_indices, num_experts, total_rows_before_expert); -} - -template -__global__ void initialize_moe_routing_kernel(const T* unpermuted_input, - T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, - const int* permuted_experts, - const int64_t* expert_offset, - float* combine_weights, //output - const int num_rows, - const int cols, - const int k, - const int64_t capacity, - bool use_pad - ) -{ - - // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the - // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 - // thread block will be responsible for all k summations. - using LoadT = phi::AlignedVector; - LoadT src_vec; - const int expanded_dest_row = blockIdx.x; - const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - const int64_t iexpert = permuted_experts[expanded_dest_row]; - const int64_t offset = iexpert == 0 ? 
0 : (expert_offset[iexpert - 1]); - const int64_t row_in_expert = expanded_dest_row - offset; - if (row_in_expert >= capacity){ - if (threadIdx.x == 0) { - expanded_source_row_to_expanded_dest_row[expanded_source_row] = 0; // unset scatter-idx - auto ik = expanded_source_row / num_rows; - auto isent = expanded_source_row % num_rows; // transpose - combine_weights[isent * k + ik] = 0.f; //unset combine-weight - } - return; - } - int64_t num_padded = 0; - if (threadIdx.x == 0) { - // printf("going through: capacity=%lld, num_active=%lld, row=[%d->%d], row-in-expert %lld\n", - // capacity, - // num_active, - // expanded_dest_row, expanded_source_row, - // row_in_expert - // ); - if (use_pad) - num_padded = iexpert * capacity - offset; - expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row + num_padded; - } - // Duplicate and permute rows - const int source_row = expanded_source_row % num_rows; - - const T* source_row_ptr = unpermuted_input + source_row * cols; - T* dest_row_ptr; - if (use_pad){ - dest_row_ptr = permuted_output + - iexpert * capacity * cols + - row_in_expert * cols; - }else{ - dest_row_ptr = permuted_output + expanded_dest_row * cols; - } - - - for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x* VecSize) { - phi::Load(&source_row_ptr[tid], &src_vec); - phi::Store(src_vec, &dest_row_ptr[tid]); - } -} - -template -void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, - T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, - const int* permuted_experts, - const int64_t* expert_offset, - float* combine_weights, //output - const int num_rows, - const int cols, - const int k, - const int64_t capacity, - bool use_pad, - cudaStream_t stream) -{ - const int blocks = num_rows * k; - const int threads = std::min(cols, 1024); - constexpr int max_pack_size = 16 / sizeof(T); - if (cols % max_pack_size == 0) { - initialize_moe_routing_kernel<<>>(unpermuted_input, - permuted_output, - expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, - permuted_experts, - expert_offset, - combine_weights, - num_rows, - cols, - k, - capacity, - use_pad - ); - } else { - initialize_moe_routing_kernel<<>>(unpermuted_input, - permuted_output, - expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, - permuted_experts, - expert_offset, - combine_weights, - num_rows, - cols, - k, - capacity, - use_pad - ); - } -} - -/** - * 原逻辑的output: - * R0E0 - * R0E1 - * R1E0 - * R1E1 - * - * 我们想对all2all和专家gemm做overlap, 所以需要将all2all拆成流水线, 为了便于后续计算, 此kernel的output: - * R0E0 - * R1E0 - * R0E1 - * R1E1 -*/ -template -__global__ void initialize_moe_routing_permute_kernel(const T* unpermuted_input, - T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, - const int* permuted_experts, - const int64_t* expert_offset, - float* combine_weights, //output - const int num_rows, - const int cols, - const int k, - const int64_t capacity, - const int64_t world_size, - const int64_t num_local_experts - ) -{ - // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the - // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 - // thread block will be responsible for all k summations. 
-#pragma unroll - for (int i = 0; i < LoopSize; i++) { - using LoadT = phi::AlignedVector; - LoadT src_vec; - const int expanded_dest_row = blockIdx.x + i * gridDim.x; - const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - const int64_t iexpert = permuted_experts[expanded_dest_row]; - const int64_t offset = iexpert == 0 ? 0 : (expert_offset[iexpert - 1]); - const int64_t row_in_expert = expanded_dest_row - offset; - if (row_in_expert >= capacity){ - if (threadIdx.x == 0) { - expanded_source_row_to_expanded_dest_row[expanded_source_row] = 0; // unset scatter-idx - auto ik = expanded_source_row / num_rows; - auto isent = expanded_source_row % num_rows; // transpose - combine_weights[isent * k + ik] = 0.f; //unset combine-weight - } - continue; - } - int64_t num_padded = 0; - if (threadIdx.x == 0) { - num_padded = iexpert * capacity - offset; - expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row + num_padded; - } - // Duplicate and permute rows - const int source_row = expanded_source_row % num_rows; - - const T* source_row_ptr = unpermuted_input + source_row * cols; - T* dest_row_ptr; - - const int64_t irank = iexpert / num_local_experts; - const int64_t local_iexpert = iexpert % num_local_experts; - dest_row_ptr = permuted_output + local_iexpert * world_size * capacity * cols + irank * capacity * cols + row_in_expert * cols; - - for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x * VecSize) { - phi::Load(&source_row_ptr[tid], &src_vec); - phi::Store(src_vec, &dest_row_ptr[tid]); - } - } -} - -template -void initialize_moe_routing_permute_kernelLauncher(const T* unpermuted_input, - T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, - const int* permuted_experts, - const int64_t* expert_offset, - float* combine_weights, //output - const int num_rows, - const int cols, - const int k, - const int64_t capacity, - const int64_t world_size, - const int64_t num_local_experts, - cudaStream_t stream) -{ - const int loop_size = 2; - const int blocks = (num_rows * k) / loop_size; - assert((num_rows * k) % loop_size == 0); - const int threads = std::min(cols, 1024); - constexpr int max_pack_size = 16 / sizeof(T); - if (cols % max_pack_size == 0) { - initialize_moe_routing_permute_kernel<<>>(unpermuted_input, - permuted_output, - expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, - permuted_experts, - expert_offset, - combine_weights, - num_rows, - cols, - k, - capacity, - world_size, - num_local_experts - ); - } else { - initialize_moe_routing_permute_kernel<<>>(unpermuted_input, - permuted_output, - expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, - permuted_experts, - expert_offset, - combine_weights, - num_rows, - cols, - k, - capacity, - world_size, - num_local_experts - ); - } -} - -// moe_ops_partial_nosoftmaxtopk utils - -template -void compute_global_expert_offset(const T* expert_id, //[len] - T* sort_buffer, //[len] - int64_t* expert_offset,//[num_experts] - const int64_t len, - const int64_t num_experts, - const int64_t capacity, - const cudaStream_t& stream, - const phi::memory_utils::ThrustAllocator& allocator){ - auto ptr = thrust::device_pointer_cast(expert_id); - auto outptr = thrust::device_pointer_cast(sort_buffer); - auto offsetptr = thrust::device_pointer_cast(expert_offset); - const auto& exec_policy = thrust::cuda::par(allocator).on(stream); - 
thrust::copy(exec_policy, ptr, ptr + len, outptr); - thrust::sort(exec_policy, outptr, outptr + len); - const int threads = std::min(static_cast(1024), num_experts); - const int blocks = (num_experts + threads - 1) / threads; - - compute_total_rows_before_expert_kernel<<>>( - sort_buffer, len, num_experts, expert_offset); - thrust::adjacent_difference(exec_policy, offsetptr, offsetptr + num_experts, offsetptr); - // thrust::transform(offsetptr, - // offsetptr + num_experts, - // thrust::constant_iterator(capacity), - // offsetptr, - // thrust::minimum() - // ); -} - -template -__global__ void modify_and_mask_expert_id(const T* expert_id, - T* expert_id_out, - const int k, - const int num_rows, - const int num_experts, - const int expert_start_index, - const int expert_end_index - ){ - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= k * num_rows) - return; - int ik = idx % k; - int irow = idx / k; - // const T mask = (~0) >> (8*sizeof(T)-ik); // 最后 ik 位为 1 其他位为 0 - int mask = ik; // k => 2(11) - // printf("before: idx=%d, expert-id:%d, ik=%d, s=%d, e=%d\n", idx, expert_id[idx], ik, expert_start_index, expert_end_index); - int offset = log2(k) + 1; - if (expert_id[idx] < expert_start_index || expert_id[idx] >= expert_end_index){ - expert_id_out[idx] = (num_experts << offset) ; // -1 means - }else{ - expert_id_out[idx] = (expert_id[idx]< -void modify_and_mask_expert_id_launcher(const T* expert_id, - T* expert_id_out, - const int k, - const int num_rows, - const int num_experts, - const int expert_start_index, - const int expert_end_index, - const cudaStream_t& stream){ - int max = 1024; - const int threads = std::min(max, num_rows * k); - const int blocks = (num_rows * k + threads - 1) / threads; - - modify_and_mask_expert_id<<>>( - expert_id, - expert_id_out, - k, - num_rows, - num_experts, - expert_start_index, - expert_end_index - ); -} - -template -void compute_local_expert_offset(const T* sorted_expert_id, //[len] - int64_t* expert_offset,//[num_experts] - int64_t* expert_num, - const int64_t len, - const int64_t num_experts, - const int64_t capacity, - const cudaStream_t& stream, - const phi::memory_utils::ThrustAllocator& allocator){ - auto offset_ptr = thrust::device_pointer_cast(expert_offset); - auto expert_num_ptr = thrust::device_pointer_cast(expert_num); - const auto& exec_policy = thrust::cuda::par(allocator).on(stream); - thrust::fill(exec_policy, offset_ptr, offset_ptr + num_experts, static_cast(0)); - - const int threads = std::min(static_cast(1024), num_experts); - const int blocks = (num_experts + threads - 1) / threads; - - compute_total_rows_before_expert_kernel<<>>( - sorted_expert_id, len, num_experts, expert_offset); - // 不考虑 capcity 影响 - thrust::adjacent_difference(exec_policy, offset_ptr, offset_ptr + num_experts, expert_num_ptr); -} - -template -__global__ void cal_expert_size_and_filter( - T* expert_id, - const int64_t* expert_offset, - int64_t len, - int64_t num_experts, - int64_t capcity, - int64_t expert_start_index, - int64_t expert_end_index, - bool reverse){ - const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= len) - return; - int64_t off = reverse? 
expert_offset[expert_end_index-1] : 0; - if (reverse){ - for (int64_t i = expert_end_index - 1; i >= expert_start_index; --i){ - if (idx >= expert_offset[i]) - break; - off = expert_offset[i]; - } - }else{ - for (int64_t i = expert_start_index; i != expert_end_index; ++i){ - if (idx < expert_offset[i]) - break; - off = expert_offset[i]; - } - } - if (reverse){ - if(((off-1) - idx) >= capcity){ - expert_id[idx] = num_experts; - } - }else{ - if ((idx - off) >= capcity){ - expert_id[idx] = num_experts; - } - } -} - -template -void cal_expert_size_and_filter_launcher(T* expert_id, - const int64_t* expert_offset, - int64_t len, - int64_t num_experts, - int64_t capcity, - int64_t expert_start_index, - int64_t expert_end_index, - bool reverse, - const cudaStream_t& stream){ - if (len <= 0) - return; - const int64_t threads = std::min(static_cast(1024), len); - const int64_t blocks = (len + threads - 1) / threads; - cal_expert_size_and_filter<<>>( - expert_id, - expert_offset, - len, - num_experts, - capcity, - expert_start_index, - expert_end_index, - reverse - ); -} - -template -__global__ void build_seqsort_kv_pairs_kernel( T* seqsort_key, - T* seqsort_value, - const int* expanded_dest_row_to_expanded_source_row, - // int* expanded_source_row_to_expanded_dest_row, - const int* permuted_experts, - const int64_t* expert_offset, - float* combine_weights, //output - const int num_rows, - const int k, - const int64_t num_active, - const int64_t capacity, - int64_t expert_start_index, - bool use_pad) -{ - const int expanded_dest_row = blockIdx.x * blockDim.x + threadIdx.x; - if (expanded_dest_row >= num_rows * k){ - return; - } - const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - const int64_t iexpert = permuted_experts[expanded_dest_row]; - const int64_t offset = iexpert == 0 ? 0 : (expert_offset[iexpert - 1]); - const int64_t row_in_expert = expanded_dest_row - offset; - // printf("DEBUG %d=>%d, num_active=%lld, offset=%lld, cap=%lld \n", expanded_dest_row, expanded_source_row, num_active, row_in_expert, capacity); - // 从此以后不会发生截断,后续的 seqsort 也不会截断。 - // printf("expanded_dest_row:%d row_in_expert:%lld capacity:%lld num_active:%lld\n", expanded_dest_row, row_in_expert, capacity, num_active); - if ((use_pad && row_in_expert >= capacity) || expanded_dest_row >= num_active){ - // expanded_source_row_to_expanded_dest_row[expanded_source_row] = 0; // unset scatter-idx - auto ik = expanded_source_row / num_rows; - auto isent = expanded_source_row % num_rows; // transpose - combine_weights[isent * k + ik] = 0.f; //unset combine-weight - return; - } - - // auto num_padded = use_pad ? 
(iexpert - expert_start_index) * capacity - offset : 0; - // expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row + num_padded; - - // Duplicate and permute rows - T source_row = expanded_source_row % num_rows; - - if (use_pad){ - // printf("inner print: k=%d num_row=%d before minus %d\n", k, num_rows, source_row); - seqsort_key [(iexpert - expert_start_index) * capacity + row_in_expert] = source_row; // 为保证 padding 位置(0)在最后, 所以对 pos-id 取减去其最大值 - seqsort_value[(iexpert - expert_start_index) * capacity + row_in_expert] = expanded_source_row; - }else{ - seqsort_key[expanded_dest_row] = source_row; - seqsort_value[expanded_dest_row] = expanded_source_row; - } -} - - - -template -void build_seqsort_kv_pairs_kernel_launcher(T* seqsort_key, // 实现初始化为 num-rows,保证 sort 到最后 - T* seqsort_value, - const int* expanded_dest_row_to_expanded_source_row, - // int* expanded_source_row_to_expanded_dest_row, - const int* permuted_experts, - const int64_t* expert_offset, - float* combine_weights, //output - const int num_rows, - const int k, - const int64_t num_active, // -1 expert pos - const int64_t capacity, - const int64_t expert_start_index, - bool use_pad, - cudaStream_t stream) -{ - int max = 1024; - const int threads = std::min(max, num_rows * k); - const int blocks = (num_rows * k + threads - 1) / threads; - build_seqsort_kv_pairs_kernel<<>>(seqsort_key, - seqsort_value, - expanded_dest_row_to_expanded_source_row, - // expanded_source_row_to_expanded_dest_row, - permuted_experts, - expert_offset, - combine_weights, - num_rows, - k, - num_active, - capacity, - expert_start_index, - use_pad - ); - -} - - -template -__global__ void copy_unpermuted_to_permuted_kernel(const T* unpermuted_input, - T* permuted_output, - const int* padded_out_to_unpermuted_input, - const int* padded_out_to_expanded_input, - int* expanded_input_to_padded_out, - const int64_t padded_len, - const int64_t num_rows, - const int64_t k, - const int64_t cols) -{ - using LoadT = phi::AlignedVector; - LoadT src_vec; - const int padded_dest_row = blockIdx.x; - if (padded_out_to_unpermuted_input[padded_dest_row] == num_rows){ - // padded_out_to_unpermuted_input[padded_dest_row] = -1; - return; // padded place - } - const int source_row = padded_out_to_unpermuted_input[padded_dest_row]; - const int source_row_expanded = padded_out_to_expanded_input[padded_dest_row]; - if (threadIdx.x == 0){ - expanded_input_to_padded_out[source_row_expanded] = padded_dest_row; - } - - const T* source_row_ptr = unpermuted_input + source_row * cols; - T* padded_dest_row_ptr = permuted_output + padded_dest_row * cols; - - for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x* VecSize) { - phi::Load(&source_row_ptr[tid], &src_vec); - phi::Store(src_vec, &padded_dest_row_ptr[tid]); - } - PADDLE_ENFORCE((padded_dest_row < padded_len)&&(source_row_expanded < num_rows * k), - "The index is out of bounds, " - "origin_input[%d] -> distributed_input:[%d], should < [%ld],[%ld] \n", - source_row_expanded, padded_dest_row, num_rows*k, padded_len); - - // for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { - // padded_dest_row_ptr[tid] = source_row_ptr[tid]; // copy - // } -} - -template -void copy_unpermuted_to_permuted_kernelLauncher(const T* unpermuted_input, - T* permuted_output, - const int* padded_out_to_unpermuted_input, - const int* padded_out_to_expanded_input, - int* expanded_input_to_padded_out, - const int64_t padded_len, - const int64_t num_rows, //unpermuted_input_len - const int64_t k, - const int64_t 
num_cols, - cudaStream_t stream) -{ - auto blocks = padded_len; - auto threads = std::min(num_cols, static_cast(1024)); - constexpr int64_t max_pack_size = 16 / sizeof(T); - if (num_cols % max_pack_size == 0) { - copy_unpermuted_to_permuted_kernel<<>>( - unpermuted_input, - permuted_output, - padded_out_to_unpermuted_input, - padded_out_to_expanded_input, - expanded_input_to_padded_out, - padded_len, - num_rows, - k, - num_cols); - }else{ - copy_unpermuted_to_permuted_kernel<<>>( - unpermuted_input, - permuted_output, - padded_out_to_unpermuted_input, - padded_out_to_expanded_input, - expanded_input_to_padded_out, - padded_len, - num_rows, - k, - num_cols); - } -} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu index e8c7146285bb4a..03b446e8dad0bf 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu @@ -4,7 +4,7 @@ #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/contiguous_kernel.h" -#include "paddle/phi/kernels/gpu/moe_fuse_bwd_op.h" +#include "paddle/phi/kernels/moe_fuse_bwd_op.h" #include "paddle/phi/kernels/transpose_kernel.h" namespace phi{ diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu index 535b9bff6ea95d..0a7eba28955832 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/gpu/moe_fuse_op.h" +#include "paddle/phi/kernels/moe_fuse_op.h" #include "paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/empty_kernel.h" diff --git a/paddle/phi/kernels/gpu/moe_kernel_impl.h b/paddle/phi/kernels/gpu/moe_kernel_impl.h deleted file mode 100644 index c054024ffbbb67..00000000000000 --- a/paddle/phi/kernels/gpu/moe_kernel_impl.h +++ /dev/null @@ -1,601 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include -#include "cub/cub.cuh" -#include "paddle/phi/kernels/funcs/math_cuda_utils.h" -#include -#include -#include -#include -namespace phi { - -static const float HALF_FLT_MAX = 65504.F; -static const float HALF_FLT_MIN = -65504.F; -static inline size_t AlignTo16(const size_t& input) { - static constexpr int ALIGNMENT = 16; - return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); -} - -class CubKeyValueSorter { - public: - inline CubKeyValueSorter(); - - inline CubKeyValueSorter(cudaStream_t stream = 0); - - inline explicit CubKeyValueSorter(const int num_experts); - - inline void update_num_experts(const int num_experts); - - inline size_t getWorkspaceSize(const size_t num_key_value_pairs, - bool descending = false); - - template - inline void run(void* workspace, - const size_t workspace_size, - const KeyT* keys_in, - KeyT* keys_out, - const int* values_in, - int* values_out, - const size_t num_key_value_pairs, - bool descending, - cudaStream_t stream); - - private: - size_t num_key_value_pairs_; - int num_experts_; - int num_bits_; - cudaStream_t stream_; -}; - - - - -template -inline void CubKeyValueSorter::run(void* workspace, - const size_t workspace_size, - const KeyT* keys_in, - KeyT* keys_out, - const int* values_in, - int* values_out, - const size_t num_key_value_pairs, - bool descending, - cudaStream_t stream) { - size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); - size_t actual_ws_size = workspace_size; - - if (expected_ws_size > workspace_size) { - std::stringstream err_ss; - err_ss << "[Error][CubKeyValueSorter::run]\n"; - err_ss - << "Error. The allocated workspace is too small to run this problem.\n"; - err_ss << "Expected workspace size of at least " << expected_ws_size - << " but got problem size " << workspace_size << "\n"; - throw std::runtime_error(err_ss.str()); - } - if (descending) { - cub::DeviceRadixSort::SortPairsDescending(workspace, - actual_ws_size, - keys_in, - keys_out, - values_in, - values_out, - num_key_value_pairs, - 0, - 32, - stream); - } else { - cub::DeviceRadixSort::SortPairs(workspace, - actual_ws_size, - keys_in, - keys_out, - values_in, - values_out, - num_key_value_pairs, - 0, - num_bits_, - stream); - } -} - -template <> -inline void CubKeyValueSorter::run(void* workspace, - const size_t workspace_size, - const __nv_bfloat16* keys_in, - __nv_bfloat16* keys_out, - const int* values_in, - int* values_out, - const size_t num_key_value_pairs, - bool descending, - cudaStream_t stream) {} - -// CubKeyValueSorter sorter_(stream); - -// -------- initialize_expert_choice_route_kernel -------- // -template -__global__ void initialize_expert_choice_route_kernel( - int* expert_for_source_row, - int* source_row, - int* expanded_source_row_to_expanded_dest_row, - int64_t* total_rows_before_expert, - T* attr_mask, - const int cols, - const int k, - const int batch_size) { - int start = cols * blockIdx.x; - - for (int i = threadIdx.x; i < cols; i += blockDim.x) { - expert_for_source_row[start + i] = blockIdx.x; - source_row[start + i] = start + i; - expanded_source_row_to_expanded_dest_row[start + i] = -1; - attr_mask[start + i] = (T)1.0f; - } - if (threadIdx.x == 0) { - total_rows_before_expert[blockIdx.x] = batch_size * k * (blockIdx.x + 1); - } -} - -// -------- softmax_kernel -------- // -template -__global__ void softmax_kernel_v4( - T* qk_buf_, - const T* qk_buf_src, // shape [batch_size, seq_len] - const T* attr_mask, // shape [batch_size, seq_len] - const int batch_size, - const int seq_len) { -#if 
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - float data[ITEMS_PER_THREAD]; - int qk_offset; - __shared__ float s_mean, s_max; - float local_max = -1e20f; - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { - qk_offset = - ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; - int mask_offset = (blockIdx.y) * seq_len + blockDim.x * i + threadIdx.x; - - float qk = static_cast(qk_buf_src[qk_offset]); - float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); - - mask_val = (1.0f - mask_val) * -10000.0f; - - data[i] = qk + mask_val; - local_max = fmax(local_max, data[i]); - } - - float max_val = - blockDim.x <= 32 - ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) - : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); - if (threadIdx.x == 0) { - s_max = max_val; - } - __syncthreads(); - - float local_sum = 0; - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { - data[i] = __expf(data[i] - s_max); - local_sum += data[i]; - } - float sum_val = - blockDim.x <= 32 - ? phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) - : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); - if (threadIdx.x == 0) { - s_mean = sum_val + 1e-6f; - s_mean = __fdividef(1.0f, s_mean); - } - __syncthreads(); - - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { - qk_offset = - ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; - qk_buf_[qk_offset] = (T)(data[i] * s_mean); - } -#endif -} - -template -__global__ void softmax_kernel_v4_half2(T* qk_buf_, - const T* attr_mask, - const int batch_size, - const int seq_len) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - using T2 = half2; - T2* qk_buf_half2 = reinterpret_cast(qk_buf_); - const T2* attr_mask_half2 = (const T2*)attr_mask; - - T2 data[ITEMS_PER_THREAD]; - int qk_offset; - __shared__ float s_mean, s_max; - float local_max = -1e20f; - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; - i++) { - qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + - threadIdx.x; - int mask_offset = blockIdx.y * (seq_len / 2) + blockDim.x * i + threadIdx.x; - - T2 qk = qk_buf_half2[qk_offset]; - T2 mask_val = __ldg(&attr_mask_half2[mask_offset]); - mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), - __float2half2_rn(-10000.0f)); - - data[i] = __hadd2(qk, mask_val); - - local_max = fmax( - local_max, - fmax(static_cast(data[i].x), static_cast(data[i].y))); - } - - float max_val = - blockDim.x <= 32 - ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) - : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); - if (threadIdx.x == 0) { - s_max = max_val; - } - __syncthreads(); - - float local_sum = 0; - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; - i++) { - data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max))); - local_sum += static_cast(data[i].x + data[i].y); - } - - float sum_val = - blockDim.x <= 32 - ? 
phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) - : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); - - if (threadIdx.x == 0) { - s_mean = sum_val + 1e-6f; - s_mean = __fdividef(1.0f, s_mean); - } - __syncthreads(); - - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; - i++) { - qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + - threadIdx.x; - qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean)); - } -#endif -} - -template -__global__ void softmax_kernel_v5_half2(T* qk_buf_, - const T* attr_mask, - const int batch_size, - const int seq_len) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - using T2 = half2; - T2* qk_buf_half2 = reinterpret_cast(qk_buf_); - const T2* attr_mask_half2 = (const T2*)attr_mask; - - T2 data[NUM][ITEMS_PER_THREAD]; - - int qk_offset[NUM]; - - __shared__ float s_sum[NUM], s_max[NUM]; - float local_max[NUM]; -#pragma unroll - for (int j = 0; j < NUM; j++) { - local_max[j] = -1e20f; - } - - const int MAX_NUM = min((1 + gridDim.x - 1) / gridDim.x, NUM); - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; - i++) { - int mask_offset[NUM]; -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - qk_offset[j] = - ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + - blockDim.x * i + threadIdx.x; - mask_offset[j] = (blockIdx.y + j * gridDim.x) * (seq_len / 2) + - blockDim.x * i + threadIdx.x; - } - - T2 mask_val[NUM]; -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]); - } - - T2 qk[NUM]; -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - qk[j] = qk_buf_half2[qk_offset[j]]; - } -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]), - __float2half2_rn(-10000.0f)); - } -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - data[j][i] = __hadd2(qk[j], mask_val[j]); - local_max[j] = fmax(local_max[j], - fmax(static_cast(data[j][i].x), - static_cast(data[j][i].y))); - } - } - if (blockDim.x <= 32) { - phi::funcs::WarpReduceMaxV2(local_max); - } else { - phi::funcs::BlockReduceMaxV2(local_max); - } - - if (threadIdx.x == 0) { -#pragma unroll - for (int j = 0; j < NUM; j++) { - s_max[j] = local_max[j]; - } - } - __syncthreads(); - float local_sum[NUM]; -#pragma unroll - for (int j = 0; j < NUM; j++) { - local_sum[j] = {0.f}; - } - - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; - i++) { -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j]))); - } - -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - local_sum[j] += static_cast(data[j][i].x + data[j][i].y); - } - } - - if (blockDim.x <= 32) { - phi::funcs::WarpReduceSumV2(local_sum); - - } else { - phi::funcs::BlockReduceSumV2(local_sum); - } - - if (threadIdx.x == 0) { -#pragma unroll - for (int j = 0; j < NUM; j++) { - s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); - } - } - __syncthreads(); - - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; - i++) { -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - qk_offset[j] = - ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + - blockDim.x * i + threadIdx.x; - } - -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - qk_buf_half2[qk_offset[j]] = - __hmul2(data[j][i], __float2half2_rn(s_sum[j])); - } - } -#endif -} - -// -------- transpose_kernel 
-------- // -template -__global__ void transposeAxis01( - T* out, T* in, const int dim0, const int dim1, const int dim2) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < dim0 * dim1 * dim2) { - const int input_dim2_index = index % dim2; - index = (index - input_dim2_index) / dim2; - const int input_dim1_index = index % dim1; - index = (index - input_dim1_index) / dim1; - const int input_dim0_index = index % dim0; - - out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + - input_dim2_index] = in[input_dim0_index * dim1 * dim2 + - input_dim1_index * dim2 + input_dim2_index]; - } -} - -// -------- padding_kernel -------- // -template -__global__ void paddingKernel(T* output1, - int* output2, - const T* input1, - const int* input2, - const int* input_lengths, - const int num_tokens, - const int batch_size, - const int max_seq_len, - const int num_experts) { - const bool IS_FP32 = std::is_same::value; - const T MIN_T_VAL = (!IS_FP32) ? (T)HALF_FLT_MIN : (T)FLT_MIN; - int offset1 = blockIdx.x * num_tokens; - int offset2 = blockIdx.x * batch_size * max_seq_len; - for (int i = 0; i < batch_size; i++) { - const T* in1_ptr = input1 + offset1; - const int* in2_ptr = input2 + offset1; - int input_length = input_lengths[i]; - offset1 += input_length; - - T* out1_ptr = output1 + offset2; - int* out2_ptr = output2 + offset2; - offset2 += max_seq_len; - - for (int j = threadIdx.x; j < max_seq_len; j += max_seq_len) { - if (j < input_length) { - out1_ptr[j] = in1_ptr[j]; - out2_ptr[j] = in2_ptr[j]; - } else { - out1_ptr[j] = MIN_T_VAL; - out2_ptr[j] = 0; - } - } - } -} - -// -------- general_topk_pair_sort_kernel -------- // -template -__global__ void general_topk_pair_sort(T* out_keys, - int* out_values, - T* in_keys, - int* in_values) { - typedef cub::BlockRadixSort - BlockRadixSort; - typedef cub:: - BlockLoad - BlockLoadKey; - typedef cub:: - BlockLoad - BlockLoadValue; - typedef cub:: - BlockStore - BlockStoreKey; - typedef cub::BlockStore - BlockStoreValue; - - __shared__ union { - typename BlockRadixSort::TempStorage sort; - typename BlockLoadKey::TempStorage loadkey; - typename BlockLoadValue::TempStorage loadvalue; - typename BlockStoreKey::TempStorage storekey; - typename BlockStoreValue::TempStorage storevalue; - } temp_storage; - - int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD; - - T thread_keys[ITEMS_PER_THREAD]; - int thread_values[ITEMS_PER_THREAD]; - BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys); - BlockLoadValue(temp_storage.loadvalue) - .Load(in_values + block_offset, thread_values); - __syncthreads(); - - BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values); - __syncthreads(); - - BlockStoreKey(temp_storage.storekey) - .Store(out_keys + block_offset, thread_keys); - BlockStoreValue(temp_storage.storevalue) - .Store(out_values + block_offset, thread_values); -} - -// -------- finalize_moe_routing_kernel -------- // -template -__global__ void finalize_moe_routing_kernel( - const T* expanded_permuted_rows, - T* reduced_unpermuted_output, - const T* skip, - const T* bias, - const T* scales, - const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, - const int cols, - const int k, - bool ec_route) { - const int original_row = blockIdx.x; - const int num_rows = gridDim.x; - T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; - const T* skip_row_ptr = skip + original_row * cols; - - for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { 
- T thread_output = skip_row_ptr[tid]; - for (int k_idx = 0; k_idx < k; ++k_idx) { - const int expanded_original_row = original_row + k_idx * num_rows; - const int expanded_permuted_row = - expanded_source_row_to_expanded_dest_row[expanded_original_row]; - - if (ec_route && expanded_permuted_row == -1) continue; - const int64_t k_offset = - ec_route ? expanded_original_row : original_row * k + k_idx; - const T row_scale = scales[k_offset]; - const T* expanded_permuted_rows_row_ptr = - expanded_permuted_rows + expanded_permuted_row * cols; - - const int expert_idx = ec_route ? k_idx : expert_for_source_row[k_offset]; - const T* bias_ptr = bias + expert_idx * cols; - - thread_output = - thread_output + - row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); - } - reduced_row_ptr[tid] = thread_output; - } -} - -// -------- initialize_moe_routing_kernel -------- // -template -__global__ void initialize_moe_routing_kernel( - const T* unpermuted_input, - T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, - const int num_rows, - const int active_rows, - const int cols, - const int k, - const int max_seq_len, - bool ec_route) { - // using LoadT = phi::AlignedVector; - // LoadT src_vec; - - // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way - // reduction and unpermuting. I need the reverse map for that reduction to - // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 - // thread block will be responsible for all k summations. - const int expanded_dest_row = blockIdx.x; - const int expanded_source_row = - ec_route ? expanded_dest_row_to_expanded_source_row[expanded_dest_row / - k * max_seq_len + - expanded_dest_row % k] - : expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - if (threadIdx.x == 0) { - expanded_source_row_to_expanded_dest_row[expanded_source_row] = - expanded_dest_row; - } - - if (blockIdx.x < active_rows) { - // Duplicate and permute rows - const int source_row = expanded_source_row % num_rows; - - const T* source_row_ptr = unpermuted_input + source_row * cols; - T* dest_row_ptr = permuted_output + expanded_dest_row * cols; - - for (int tid = threadIdx.x * VecSize; tid < cols; - tid += blockDim.x * VecSize) { - dest_row_ptr[tid] = source_row_ptr[tid]; - // phi::Load(&source_row_ptr[tid], &src_vec); - // phi::Store(src_vec, &dest_row_ptr[tid]); - } - } -} - -} // namespace phi \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu index 6994e7a749bd7c..1c45398a60a2f7 100644 --- a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu @@ -19,7 +19,7 @@ #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/contiguous_kernel.h" -#include "paddle/phi/kernels/gpu/moe_fuse_bwd_op.h" +#include "paddle/phi/kernels/moe_fuse_bwd_op.h" #include #include diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu index 9354f3f1cde87d..6570adec9a813a 100644 --- a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu @@ -19,13 +19,13 @@ * with minor changes. 
*/ #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/gpu/moe_fuse_op.h" +#include "paddle/phi/kernels/moe_fuse_op.h" #include "paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/slice_kernel.h" #include "paddle/phi/core/tensor_utils.h" -#include "paddle/phi/kernels/gpu/moe_kernel_impl.h" +#include "paddle/phi/kernels/moe_kernel_impl.h" namespace phi { @@ -44,7 +44,7 @@ namespace phi { // static constexpr int ALIGNMENT = 16; // return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); // } - +namespace{ // -------- getWorkspaceSize -------- // template size_t getWorkspaceSize(const int num_rows, @@ -88,6 +88,7 @@ size_t getWorkspaceSize(const int num_rows, // std::cout<<"buf_size --"<< buf_size<<" "< void apply_moe_dispatch_fwd( diff --git a/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h similarity index 99% rename from paddle/phi/kernels/gpu/moe_fuse_bwd_op.h rename to paddle/phi/kernels/moe_fuse_bwd_op.h index 41e386256f371e..9b7f669729dea4 100644 --- a/paddle/phi/kernels/gpu/moe_fuse_bwd_op.h +++ b/paddle/phi/kernels/moe_fuse_bwd_op.h @@ -15,7 +15,7 @@ #pragma once #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/common/exception.h" -#include "paddle/phi/kernels/gpu/moe_kernel_impl.h" +#include "paddle/phi/kernels/moe_kernel_impl.h" template diff --git a/paddle/phi/kernels/moe_fuse_op.h b/paddle/phi/kernels/moe_fuse_op.h index b2b94b9e1faf1e..bbda9e9e21e45c 100644 --- a/paddle/phi/kernels/moe_fuse_op.h +++ b/paddle/phi/kernels/moe_fuse_op.h @@ -1,19 +1,12 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- #pragma once +#include // 包含常用的 thrust 算法 +#include +#include +#include +#include +#include "paddle/common/enforce.h" #include "paddle/common/exception.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/moe_kernel_impl.h" @@ -455,3 +448,352 @@ void initialize_moe_routing_permute_kernelLauncher( num_local_experts); } } + +// moe_ops_partial_nosoftmaxtopk utils + +template +void compute_global_expert_offset( + const T* expert_id, //[len] + T* sort_buffer, //[len] + int64_t* expert_offset, //[num_experts] + const int64_t len, + const int64_t num_experts, + const int64_t capacity, + const cudaStream_t& stream, + const phi::memory_utils::ThrustAllocator& allocator) { + auto ptr = thrust::device_pointer_cast(expert_id); + auto outptr = thrust::device_pointer_cast(sort_buffer); + auto offsetptr = thrust::device_pointer_cast(expert_offset); + const auto& exec_policy = thrust::cuda::par(allocator).on(stream); + thrust::copy(exec_policy, ptr, ptr + len, outptr); + thrust::sort(exec_policy, outptr, outptr + len); + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sort_buffer, len, num_experts, expert_offset); + thrust::adjacent_difference( + exec_policy, offsetptr, offsetptr + num_experts, offsetptr); + // thrust::transform(offsetptr, + // offsetptr + num_experts, + // thrust::constant_iterator(capacity), + // offsetptr, + // thrust::minimum() + // ); +} + +template +__global__ void modify_and_mask_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int num_experts, + const int expert_start_index, + const int expert_end_index) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) return; + int ik = idx % k; + int irow = idx / k; + // const T mask = (~0) >> (8*sizeof(T)-ik); // 最后 ik 位为 1 其他位为 0 + int mask = ik; // k => 2(11) + // printf("before: idx=%d, expert-id:%d, ik=%d, s=%d, e=%d\n", idx, + // expert_id[idx], ik, expert_start_index, expert_end_index); + int offset = log2(k) + 1; + if (expert_id[idx] < expert_start_index || + expert_id[idx] >= expert_end_index) { + expert_id_out[idx] = (num_experts << offset); // -1 means + } else { + expert_id_out[idx] = (expert_id[idx] << offset) | mask; + } + // printf("after: idx=%d, expert-id:%d, ik=%d\n", idx, expert_id_out[idx], + // ik); +} + +template +void modify_and_mask_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int num_experts, + const int expert_start_index, + const int expert_end_index, + const cudaStream_t& stream) { + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + modify_and_mask_expert_id + <<>>(expert_id, + expert_id_out, + k, + num_rows, + num_experts, + expert_start_index, + expert_end_index); +} + +template +void compute_local_expert_offset( + const T* sorted_expert_id, //[len] + int64_t* expert_offset, //[num_experts] + int64_t* expert_num, + const int64_t len, + const int64_t num_experts, + const int64_t capacity, + const cudaStream_t& stream, + const phi::memory_utils::ThrustAllocator& allocator) { + auto offset_ptr = thrust::device_pointer_cast(expert_offset); + auto expert_num_ptr = thrust::device_pointer_cast(expert_num); + const auto& exec_policy = thrust::cuda::par(allocator).on(stream); + thrust::fill( + 
exec_policy, offset_ptr, offset_ptr + num_experts, static_cast(0)); + + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sorted_expert_id, len, num_experts, expert_offset); + // 不考虑 capcity 影响 + thrust::adjacent_difference( + exec_policy, offset_ptr, offset_ptr + num_experts, expert_num_ptr); +} + +template +__global__ void cal_expert_size_and_filter(T* expert_id, + const int64_t* expert_offset, + int64_t len, + int64_t num_experts, + int64_t capcity, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse) { + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= len) return; + int64_t off = reverse ? expert_offset[expert_end_index - 1] : 0; + if (reverse) { + for (int64_t i = expert_end_index - 1; i >= expert_start_index; --i) { + if (idx >= expert_offset[i]) break; + off = expert_offset[i]; + } + } else { + for (int64_t i = expert_start_index; i != expert_end_index; ++i) { + if (idx < expert_offset[i]) break; + off = expert_offset[i]; + } + } + if (reverse) { + if (((off - 1) - idx) >= capcity) { + expert_id[idx] = num_experts; + } + } else { + if ((idx - off) >= capcity) { + expert_id[idx] = num_experts; + } + } +} + +template +void cal_expert_size_and_filter_launcher(T* expert_id, + const int64_t* expert_offset, + int64_t len, + int64_t num_experts, + int64_t capcity, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse, + const cudaStream_t& stream) { + if (len <= 0) return; + const int64_t threads = std::min(static_cast(1024), len); + const int64_t blocks = (len + threads - 1) / threads; + cal_expert_size_and_filter + <<>>(expert_id, + expert_offset, + len, + num_experts, + capcity, + expert_start_index, + expert_end_index, + reverse); +} + +template +__global__ void build_seqsort_kv_pairs_kernel( + T* seqsort_key, + T* seqsort_value, + const int* expanded_dest_row_to_expanded_source_row, + // int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int k, + const int64_t num_active, + const int64_t capacity, + int64_t expert_start_index, + bool use_pad) { + const int expanded_dest_row = blockIdx.x * blockDim.x + threadIdx.x; + if (expanded_dest_row >= num_rows * k) { + return; + } + const int expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + // printf("DEBUG %d=>%d, num_active=%lld, offset=%lld, cap=%lld \n", + // expanded_dest_row, expanded_source_row, num_active, row_in_expert, + // capacity); 从此以后不会发生截断,后续的 seqsort 也不会截断。 + // printf("expanded_dest_row:%d row_in_expert:%lld capacity:%lld + // num_active:%lld\n", expanded_dest_row, row_in_expert, capacity, + // num_active); + if ((use_pad && row_in_expert >= capacity) || + expanded_dest_row >= num_active) { + // expanded_source_row_to_expanded_dest_row[expanded_source_row] = 0; // + // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; // unset combine-weight + return; + } + + // auto num_padded = use_pad ? 
(iexpert - expert_start_index) * capacity - + // offset : 0; expanded_source_row_to_expanded_dest_row[expanded_source_row] = + // expanded_dest_row + num_padded; + + // Duplicate and permute rows + T source_row = expanded_source_row % num_rows; + + if (use_pad) { + // printf("inner print: k=%d num_row=%d before minus %d\n", k, num_rows, + // source_row); + seqsort_key[(iexpert - expert_start_index) * capacity + row_in_expert] = + source_row; // 为保证 padding 位置(0)在最后, 所以对 pos-id + // 取减去其最大值 + seqsort_value[(iexpert - expert_start_index) * capacity + row_in_expert] = + expanded_source_row; + } else { + seqsort_key[expanded_dest_row] = source_row; + seqsort_value[expanded_dest_row] = expanded_source_row; + } +} + +template +void build_seqsort_kv_pairs_kernel_launcher( + T* seqsort_key, // 实现初始化为 num-rows,保证 sort 到最后 + T* seqsort_value, + const int* expanded_dest_row_to_expanded_source_row, + // int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int k, + const int64_t num_active, // -1 expert pos + const int64_t capacity, + const int64_t expert_start_index, + bool use_pad, + cudaStream_t stream) { + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + build_seqsort_kv_pairs_kernel<<>>( + seqsort_key, + seqsort_value, + expanded_dest_row_to_expanded_source_row, + // expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + k, + num_active, + capacity, + expert_start_index, + use_pad); +} + +template +__global__ void copy_unpermuted_to_permuted_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* padded_out_to_unpermuted_input, + const int* padded_out_to_expanded_input, + int* expanded_input_to_padded_out, + const int64_t padded_len, + const int64_t num_rows, + const int64_t k, + const int64_t cols) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int padded_dest_row = blockIdx.x; + if (padded_out_to_unpermuted_input[padded_dest_row] == num_rows) { + // padded_out_to_unpermuted_input[padded_dest_row] = -1; + return; // padded place + } + const int source_row = padded_out_to_unpermuted_input[padded_dest_row]; + const int source_row_expanded = padded_out_to_expanded_input[padded_dest_row]; + if (threadIdx.x == 0) { + expanded_input_to_padded_out[source_row_expanded] = padded_dest_row; + } + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* padded_dest_row_ptr = permuted_output + padded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &padded_dest_row_ptr[tid]); + } + PADDLE_ENFORCE( + (padded_dest_row < padded_len) && (source_row_expanded < num_rows * k), + "The index is out of bounds, " + "origin_input[%d] -> distributed_input:[%d], should < [%ld],[%ld] \n", + source_row_expanded, + padded_dest_row, + num_rows * k, + padded_len); + + // for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + // padded_dest_row_ptr[tid] = source_row_ptr[tid]; // copy + // } +} + +template +void copy_unpermuted_to_permuted_kernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* padded_out_to_unpermuted_input, + const int* padded_out_to_expanded_input, + int* expanded_input_to_padded_out, + const int64_t padded_len, + const int64_t num_rows, // 
unpermuted_input_len + const int64_t k, + const int64_t num_cols, + cudaStream_t stream) { + auto blocks = padded_len; + auto threads = std::min(num_cols, static_cast(1024)); + constexpr int64_t max_pack_size = 16 / sizeof(T); + if (num_cols % max_pack_size == 0) { + copy_unpermuted_to_permuted_kernel + <<>>(unpermuted_input, + permuted_output, + padded_out_to_unpermuted_input, + padded_out_to_expanded_input, + expanded_input_to_padded_out, + padded_len, + num_rows, + k, + num_cols); + } else { + copy_unpermuted_to_permuted_kernel + <<>>(unpermuted_input, + permuted_output, + padded_out_to_unpermuted_input, + padded_out_to_expanded_input, + expanded_input_to_padded_out, + padded_len, + num_rows, + k, + num_cols); + } +} \ No newline at end of file diff --git a/paddle/phi/kernels/moe_kernel_impl.h b/paddle/phi/kernels/moe_kernel_impl.h index f3f3a2984fdce9..4db8ee46954116 100644 --- a/paddle/phi/kernels/moe_kernel_impl.h +++ b/paddle/phi/kernels/moe_kernel_impl.h @@ -1,8 +1,11 @@ -/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -10,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include -#include -#include #include #include "cub/cub.cuh" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" +#include +#include +#include +#include namespace phi { static const float HALF_FLT_MAX = 65504.F; @@ -27,19 +31,19 @@ static inline size_t AlignTo16(const size_t& input) { class CubKeyValueSorter { public: - CubKeyValueSorter(); + inline CubKeyValueSorter(); - explicit CubKeyValueSorter(cudaStream_t stream = 0); + inline CubKeyValueSorter(cudaStream_t stream = 0); - explicit CubKeyValueSorter(const int num_experts); + inline explicit CubKeyValueSorter(const int num_experts); - void update_num_experts(const int num_experts); + inline void update_num_experts(const int num_experts); - size_t getWorkspaceSize(const size_t num_key_value_pairs, + inline size_t getWorkspaceSize(const size_t num_key_value_pairs, bool descending = false); template - void run(void* workspace, + inline void run(void* workspace, const size_t workspace_size, const KeyT* keys_in, KeyT* keys_out, @@ -56,6 +60,55 @@ class CubKeyValueSorter { cudaStream_t stream_; }; + +// ===== CUB Sorting things ===== +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(cudaStream_t stream) + : num_experts_(0), num_bits_(sizeof(int) * 8), stream_(stream) {} + +CubKeyValueSorter::CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts), + num_bits_(static_cast(log2(num_experts)) + 1) {} + +void CubKeyValueSorter::update_num_experts(const int num_experts) { + num_experts_ = num_experts; + num_bits_ = static_cast(log2(num_experts)) + 3; //额外增加 3 位用于标记 topk的位置 +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, + bool descending) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + if (descending) { 
+ cub::DeviceRadixSort::SortPairsDescending(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + 32, + stream_); + } else { + cub::DeviceRadixSort::SortPairs(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + num_bits_, + stream_); + } + return required_storage; +} + + template inline void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, @@ -425,8 +478,8 @@ __global__ void paddingKernel(T* output1, const int batch_size, const int max_seq_len, const int num_experts) { - const bool IS_FP16 = std::is_same::value; - const T MIN_T_VAL = (IS_FP16) ? (T)HALF_FLT_MIN : (T)FLT_MIN; + const bool IS_FP32 = std::is_same::value; + const T MIN_T_VAL = (!IS_FP32) ? (T)HALF_FLT_MIN : (T)FLT_MIN; int offset1 = blockIdx.x * num_tokens; int offset2 = blockIdx.x * batch_size * max_seq_len; for (int i = 0; i < batch_size; i++) { @@ -591,4 +644,4 @@ __global__ void initialize_moe_routing_kernel( } } -} // namespace phi +} // namespace phi \ No newline at end of file From e8e494e97a531266aa2fd9c77d132e69fdd033d3 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 06:02:17 +0000 Subject: [PATCH 44/71] sync with dev --- .pre-commit-config.yaml | 175 ---------------------------------------- third_party/openblas | 2 +- 2 files changed, 1 insertion(+), 176 deletions(-) delete mode 100755 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100755 index a72994c245fa32..00000000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,175 +0,0 @@ -# Exclude all third-party libraries and auto-generated files globally -exclude: | - (?x)^( - patches/.+| - paddle/fluid/framework/fleet/heter_ps/cudf/.+| - paddle/fluid/distributed/ps/thirdparty/round_robin.h| - python/paddle/utils/gast/.+| - third_party/.+ - )$ -repos: - # Common hooks - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: check-added-large-files - - id: check-merge-conflict - - id: check-symlinks - - id: detect-private-key - - id: end-of-file-fixer - - id: sort-simple-yaml - files: (ops|backward|op_[a-z_]+)\.yaml$ - - id: trailing-whitespace - - repo: https://github.com/Lucas-C/pre-commit-hooks.git - rev: v1.5.1 - hooks: - - id: remove-crlf - - id: remove-tabs - name: Tabs remover (C++) - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$ - args: [--whitespaces-count, '2'] - - id: remove-tabs - name: Tabs remover (Python) - files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ - args: [--whitespaces-count, '4'] - # Exclude some unit test files that require tabs. - exclude: | - (?x)^( - test/dygraph_to_static/test_error.py - )$ - - repo: local - hooks: - - id: copyright_checker - name: copyright_checker - entry: python ./tools/codestyle/copyright.py - language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|pyi|sh)$ - exclude: | - (?x)^( - paddle/utils/.*| - paddle/cinn/utils/registry.h - )$ - - repo: https://github.com/PFCCLab/typos-pre-commit-mirror.git - rev: v1.30.2 - hooks: - - id: typos - args: [--force-exclude] - # For Python files - - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.1.0 - hooks: - - id: black - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.11 - hooks: - - id: ruff - args: [--fix, --exit-non-zero-on-fix, --no-cache] - # For C++ files - - repo: local - hooks: - - id: clang-format - name: clang-format - description: Format files with ClangFormat. 
- entry: bash ./tools/codestyle/clang_format.sh -i - language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$ - - repo: local - hooks: - - id: cpplint-cpp-source - name: cpplint - description: Check C++ code style using cpplint.py. - entry: bash ./tools/codestyle/cpplint_pre_commit.sh - language: system - files: \.(cc|cxx|cpp|cu|h|hpp|hxx)$ - args: - - --extensions=cc,cxx,cpp,cu,cuh,h,hpp,hxx,kps - - --filter=-readability/fn_size,-build/include_what_you_use,-build/c++11,-whitespace/parens - - --quiet - # Exclude third-party libraries - exclude: | - (?x)^( - paddle/utils/flat_hash_map\.h - )$ - - repo: local - hooks: - - id: clang-tidy - name: clang-tidy - description: Parallel clang-tidy runner. - entry: python ./tools/codestyle/clang-tidy.py - language: system - files: \.(c|cc|cxx|cpp|h|hpp|hxx)$ - args: - - -p=build/ - - -extra-arg=-Wno-unknown-warning-option - - -extra-arg=-Wno-pessimizing-move - - -extra-arg=-Wno-braced-scalar-init - - -extra-arg=-Wno-dangling-gsl - - -extra-arg=-Wno-deprecated-copy - - -extra-arg=-Wno-final-dtor-non-final-class - - -extra-arg=-Wno-implicit-int-float-conversion - - -extra-arg=-Wno-inconsistent-missing-override - - -extra-arg=-Wno-infinite-recursion - - -extra-arg=-Wno-mismatched-tags - - -extra-arg=-Wno-self-assign - - -extra-arg=-Wno-sign-compare - - -extra-arg=-Wno-sometimes-uninitialized - - -extra-arg=-Wno-tautological-overlap-compare - - -extra-arg=-Wno-unused-const-variable - - -extra-arg=-Wno-unused-lambda-capture - - -extra-arg=-Wno-unused-private-field - - -extra-arg=-Wno-unused-value - - -extra-arg=-Wno-unused-variable - - -extra-arg=-Wno-overloaded-virtual - - -extra-arg=-Wno-defaulted-function-deleted - - -extra-arg=-Wno-delete-non-abstract-non-virtual-dtor - - -extra-arg=-Wno-return-type-c-linkage - # For CMake files - - repo: local - hooks: - - id: auto-generate-cmakelists - name: auto-generate-cmakelists - entry: bash ./tools/gen_ut_cmakelists.hook - language: system - files: testslist.csv$ - - repo: https://github.com/cheshirekow/cmake-format-precommit - rev: v0.6.13 - hooks: - - id: cmake-format - # exclude paddle/fluid/operators/CMakeLists.txt, see the comment - # https://github.com/PaddlePaddle/Paddle/pull/43057#pullrequestreview-993471860 - exclude: | - (?x)^( - paddle/fluid/operators/CMakeLists.txt - )$ - - repo: https://github.com/PFCCLab/cmake-lint-paddle - rev: v1.5.1 - hooks: - - id: cmakelint - args: [--config=./tools/codestyle/.cmakelintrc] - # Exclude some files has false positive warnings - # Need to fix them in the future - exclude: | - (?x)^( - cmake/external/onnxruntime.cmake - )$ - # For YAML files - - repo: https://github.com/PFCCLab/yamlfmt-pre-commit-mirror.git - rev: v0.16.0 - hooks: - - id: yamlfmt - files: | - (?x)^( - \.github/.+\.(yaml|yml)| - \.pre-commit-config\.yaml| - \.yamlfmt - ) - # Others - - repo: local - hooks: - - id: sort-txt-file - name: sort-txt-file - description: Sorts each line string in a text file - entry: python ./tools/codestyle/sort_txt_file.py - language: python - files: test/white_list/pir_op_test_white_list - args: [] diff --git a/third_party/openblas b/third_party/openblas index 5ef8b1964658f9..5f36f18148603f 160000 --- a/third_party/openblas +++ b/third_party/openblas @@ -1 +1 @@ -Subproject commit 5ef8b1964658f9cb6a6324a06f6a1a022609b0c5 +Subproject commit 5f36f18148603facb6c3540e673610d6b24cbfbb From f3d732043c0d625c88728c04cb22f14aa085ce2d Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 06:18:41 +0000 Subject: [PATCH 45/71] Add incubate port. 
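A minimal usage sketch of one of the newly exported wrappers, for reviewers. This is illustrative only: the shapes, dtypes, and sample routing below are assumptions rather than part of the change; the keyword names follow the Python signature of moe_gate_dispatch_partial_nosoftmaxtopk introduced in this series.

    import paddle
    from paddle.incubate.nn.functional import (
        moe_gate_dispatch_partial_nosoftmaxtopk,
    )

    # 3 tokens of width 4 routed to 4 experts with externally computed top-2
    # weights; the op only dispatches, it does not run softmax/top-k itself.
    S, E, D, k, capacity = 3, 4, 4, 2, 2
    x = paddle.randn([S, D], dtype="float32")
    combine_weights = paddle.rand([S, k], dtype="float32")
    expert_id = paddle.to_tensor([[0, 1], [0, 1], [0, 2]], dtype="int32")

    (y, combine_weights_out, scatter_index, scatter_index_rev,
     expert_offset, expert_nums_local) = moe_gate_dispatch_partial_nosoftmaxtopk(
        x, combine_weights, expert_id,
        k=k, capacity=capacity, num_experts=E, use_pad=True,
        expert_start_index=0, expert_end_index=E,
        reverse_token_drop=False,
    )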
--- paddle/phi/kernels/moe_fuse_bwd_op.h | 2 +- python/paddle/incubate/nn/functional/__init__.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h index 9b7f669729dea4..e8fdc561c7ef79 100644 --- a/paddle/phi/kernels/moe_fuse_bwd_op.h +++ b/paddle/phi/kernels/moe_fuse_bwd_op.h @@ -17,7 +17,7 @@ #include "paddle/common/exception.h" #include "paddle/phi/kernels/moe_kernel_impl.h" - +p/ template __global__ void gather_with_mask_permute_kernel(const T* dy, // [s*k, d] const int* scatter_index, // [s, k] diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index a64a911a9c29e8..66286cbe957900 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -50,6 +50,9 @@ from .build_src_rank_and_local_expert_id import build_src_rank_and_local_expert_id from .int_bincount import int_bincount from .fused_rms_norm_ext import fused_rms_norm_ext +from .moe_gate_dispatch import moe_gate_dispatch +from .moe_gate_dispatch_permute import moe_gate_dispatch_permute +from .moe_ops_partial_nosoftmaxtopk import moe_gate_dispatch_partial_nosoftmaxtopk __all__ = [ 'fused_multi_head_attention', @@ -75,4 +78,8 @@ "build_src_rank_and_local_expert_id" "int_bincount", "fused_rms_norm_ext", + "moe_gate_dispatch", + "moe_gate_dispatch_permute", + "moe_gate_dispatch_partial_nosoftmaxtopk", + ] From c76a357dc3e212c902a0fe48fe724cb6ce5d00c1 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 06:21:30 +0000 Subject: [PATCH 46/71] fix miscs --- paddle/phi/kernels/moe_fuse_bwd_op.h | 1 - third_party/openblas | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/paddle/phi/kernels/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h index e8fdc561c7ef79..b2b30638d59966 100644 --- a/paddle/phi/kernels/moe_fuse_bwd_op.h +++ b/paddle/phi/kernels/moe_fuse_bwd_op.h @@ -17,7 +17,6 @@ #include "paddle/common/exception.h" #include "paddle/phi/kernels/moe_kernel_impl.h" -p/ template __global__ void gather_with_mask_permute_kernel(const T* dy, // [s*k, d] const int* scatter_index, // [s, k] diff --git a/third_party/openblas b/third_party/openblas index 5f36f18148603f..5ef8b1964658f9 160000 --- a/third_party/openblas +++ b/third_party/openblas @@ -1 +1 @@ -Subproject commit 5f36f18148603facb6c3540e673610d6b24cbfbb +Subproject commit 5ef8b1964658f9cb6a6324a06f6a1a022609b0c5 From d4a34722ae2f686d74821fb6edb64155184aee11 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 06:34:44 +0000 Subject: [PATCH 47/71] Fix module issue --- python/paddle/incubate/nn/functional/__init__.py | 2 +- ...axtopk.py => moe_gate_dispatch_partial_nosoftmaxtopk.py} | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename python/paddle/incubate/nn/functional/{moe_ops_partial_nosoftmaxtopk.py => moe_gate_dispatch_partial_nosoftmaxtopk.py} (95%) diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index 66286cbe957900..a408c2cd16192a 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -52,7 +52,7 @@ from .fused_rms_norm_ext import fused_rms_norm_ext from .moe_gate_dispatch import moe_gate_dispatch from .moe_gate_dispatch_permute import moe_gate_dispatch_permute -from .moe_ops_partial_nosoftmaxtopk import moe_gate_dispatch_partial_nosoftmaxtopk +from 
.moe_gate_dispatch_partial_nosoftmaxtopk import moe_gate_dispatch_partial_nosoftmaxtopk __all__ = [ 'fused_multi_head_attention', diff --git a/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py similarity index 95% rename from python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py rename to python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py index 5fbf53515b75af..07853b7a4e49fb 100644 --- a/python/paddle/incubate/nn/functional/moe_ops_partial_nosoftmaxtopk.py +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from paddle import Tensor -def moe_ops_partial_nosoftmaxtopk( +def moe_gate_dispatch_partial_nosoftmaxtopk( x: Tensor, combine_weights: Tensor, expert_id: Tensor, @@ -23,7 +23,7 @@ def moe_ops_partial_nosoftmaxtopk( ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: if in_dynamic_or_pir_mode(): return _C_ops.moe_gate_dispatch_partial_nosoftmaxtopk(x, combine_weights, expert_id, k, capacity, num_experts, use_pad, expert_start_index, expert_end_index, reverse_token_drop) - helper = LayerHelper("moe_ops_partial_nosoftmaxtopk", **locals()) + helper = LayerHelper("moe_gate_dispatch_partial_nosoftmaxtopk", **locals()) y = helper.create_variable_for_type_inference(dtype=x.dtype) combine_weights_out = helper.create_variable_for_type_inference(dtype=combine_weights.dtype) scatter_index = helper.create_variable_for_type_inference(dtype='int32') @@ -53,7 +53,7 @@ def moe_ops_partial_nosoftmaxtopk( "reverse_token_drop": reverse_token_drop, } helper.append_op( - type="moe_ops_partial_nosoftmaxtopk", + type="moe_gate_dispatch_partial_nosoftmaxtopk", inputs=inputs, outputs=outputs, attrs=attrs, From 36a450f9f73563bd1f25d5593df8d4f97454ac8b Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 07:11:52 +0000 Subject: [PATCH 48/71] Add missing yamls --- paddle/phi/ops/yaml/backward.yaml | 10 ++++++++++ paddle/phi/ops/yaml/ops.yaml | 11 +++++++++++ 2 files changed, 21 insertions(+) diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 27af6dddb4fbc3..151e0006a57369 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2280,6 +2280,16 @@ kernel : func : moe_combine_grad +- backward_op : moe_gate_dispatch_grad + forward : moe_gate_dispatch (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, bool use_pad) + output : Tensor(x_grad), Tensor(gate_logits_grad) + infer_meta : + func : MoeGateDispatchGradInferMeta + kernel : + func : moe_gate_dispatch_grad + data_type : y_grad + - backward_op : moe_gate_dispatch_partial_nosoftmaxtopk_grad forward : moe_gate_dispatch_partial_nosoftmaxtopk (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop) -> Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local) args : (Tensor combine_weights_out, Tensor scatter_index, Tensor scatter_index_rev, Tensor expert_offset, 
Tensor expert_nums_local, Tensor y_grad, Tensor combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t expert_start_index, int64_t expert_end_index) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 568ac628b037df..03586b2f51353b 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3637,6 +3637,17 @@ data_type : x backward : moe_combine_grad +- op : moe_gate_dispatch + args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) + output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + infer_meta : + func : MoeGateDispatchInferMeta + kernel : + func : moe_gate_dispatch + data_type : x + optional : corr_bias + backward : moe_gate_dispatch_grad + - op : moe_gate_dispatch_partial_nosoftmaxtopk args : (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop) output : Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local) From 89402c3bd76464440a5d50c4002e482ce99b9216 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 09:17:05 +0000 Subject: [PATCH 49/71] Fix stale package problems --- .../kernels/gpu/moe_gate_dispatch_kernel.cu | 2 ++ .../gpu/moe_gate_dispatch_permute_kernel.cu | 2 ++ .../ernie_utils/moe_all_gather_layer.py | 36 +++---------------- .../ernie_utils/moe_layer_uneven.py | 27 +++++--------- test/legacy_test/test_incubate_moe_combine.py | 3 +- 5 files changed, 19 insertions(+), 51 deletions(-) diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu index e45cbab45932a5..7912c38665fd16 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu @@ -17,6 +17,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/moe_fuse_op.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { // -------- getWorkspaceSize -------- // @@ -338,6 +339,7 @@ void MoeGradDispatchKernel(const Context &dev_ctx, dev_ctx.template Alloc(combine_weights); dev_ctx.template Alloc(y); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); auto x_dims = x.dims(); auto gate_logits_dims = gate_logits.dims(); diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu index 0a7eba28955832..288ec03554b499 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu @@ -17,6 +17,7 @@ #include "paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { namespace { @@ -304,6 +305,7 @@ void MoEDispatchPermuteKernel(const Context& dev_ctx, dev_ctx.template Alloc(scatter_index); dev_ctx.template Alloc(combine_weights); dev_ctx.template Alloc(y); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); const auto &x_shape = x.dims(); const auto &gate_logits_shape = gate_logits.dims(); int64_t num_rows = x_shape[0]; diff --git a/test/legacy_test/ernie_utils/moe_all_gather_layer.py 
b/test/legacy_test/ernie_utils/moe_all_gather_layer.py index 334d1b9c6a823f..87eecb45b97ff0 100644 --- a/test/legacy_test/ernie_utils/moe_all_gather_layer.py +++ b/test/legacy_test/ernie_utils/moe_all_gather_layer.py @@ -34,17 +34,13 @@ from .top2_gate import TopKGateFused, compute_optimal_transport from paddle.incubate.tensor.manipulation import async_offload, async_reload - +from paddle.incubate.nn.functional import expand_modality_expert_id from .moe_layer import MOELayer, fuse_logging try: from src.utils.misc import global_training_logs except ModuleNotFoundError: global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 -try: - import moe_router_loss_ops -except ImportError: - moe_router_loss_ops = None def profile(_): @@ -61,32 +57,8 @@ def profile(_): xpu_moe_gate_dispatch = None logger.warning("`xpu moe dispatch` not found") else: - try: - import moe_ops - except ImportError: - moe_ops = None - logger.warning("`moe-ops` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") - try: - import moe_ops_partial - except ImportError: - moe_ops_partial = None - logger.warning( - "`moe-ops-partial` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install" - ) - try: - import moe_ops_partial_nosoftmaxtopk - except ImportError: - moe_ops_partial_nosoftmaxtopk = None - logger.warning( - "`moe-ops-partial-nosoftmaxtopk` not found, run " - "`python3 src/ernie_core/ops/moe/setup.py install` to install" - ) + pass - try: - import moe_utils - except ImportError: - moe_utils = None - logger.warning("`moe_utils` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") class MOEAllGatherLayer(MOELayer): """_summary_ @@ -221,7 +193,7 @@ def fused_gate_logits_process_fused(self, gate_logits_lm, gate_logits_mm, token_ weight_lm = prob_lm[batch_idx, expert_id_lm] # use correct bias # num_expert_per_modality == 0 时只执行 group-expert expand,不执行 multimodal-expand - expert_id_lm = moe_utils.expand_modality_expert_id( + expert_id_lm = expand_modality_expert_id( expert_id_lm, num_expert_per_modality=num_expert_per_rank_per_modality if (token_type_ids is not None and gate_logits_mm is not None) @@ -245,7 +217,7 @@ def fused_gate_logits_process_fused(self, gate_logits_lm, gate_logits_mm, token_ batch_idx = paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) weight_mm = prob_mm[batch_idx, expert_id_mm] # use correct bias - expert_id_mm = moe_utils.expand_modality_expert_id( + expert_id_mm = expand_modality_expert_id( expert_id_mm, num_expert_per_modality=num_expert_per_rank_per_modality, group_size=group_size, diff --git a/test/legacy_test/ernie_utils/moe_layer_uneven.py b/test/legacy_test/ernie_utils/moe_layer_uneven.py index 3b166205453341..b44809721b2b06 100644 --- a/test/legacy_test/ernie_utils/moe_layer_uneven.py +++ b/test/legacy_test/ernie_utils/moe_layer_uneven.py @@ -13,6 +13,7 @@ import numpy as np import paddle +from paddle import _C_ops from paddle import nn from paddle.distributed.communication import stream @@ -55,25 +56,12 @@ logger.warning("`xpu moe combine` not found") else: try: - import moe_ops - except ImportError: - moe_ops = None - logger.warning("`moe-ops` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") - - try: - import moe_combine + from paddle.incubate.nn.functional import moe_combine except ImportError: moe_combine = None logger.warning("`moe-combine` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") -try: - import moe_ops_no_softmaxtopk 
-except ImportError: - moe_ops_no_softmaxtopk = None - logger.warning( - "moe-ops-no-softmaxtopk` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install" - ) def average_grad(x, y, dy, eps=1e-12): @@ -225,11 +213,12 @@ def forward(ctx, x, combine_weights, scatter_index): return xpu_moe_combine(x, combine_weights, scatter_index) else: assert moe_combine is not None - ret = moe_combine.moe_combine(x, combine_weights, scatter_index) + ret = moe_combine(x, combine_weights, scatter_index) return ret @staticmethod def backward(ctx, grad_y, *_): + ''' """ Input: grad_y: [seqlen, hidden_size] @@ -248,14 +237,16 @@ def backward(ctx, grad_y, *_): ) else: assert moe_combine is not None - grad_x, grad_combine_weight_helper = moe_combine.moe_combine_bwd( + grad_x, grad_combine_weight_helper = _C_ops.moe_combine_grad( ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y ) # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim] # reduce the hidden shape # TODO: implement reduce in cuda ops - grad_combine_weight = grad_combine_weight_helper.sum(-1) - return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + #grad_combine_weight = grad_combine_weight_helper.sum(-1) + #return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + return grad_x, grad_combine_weight_helper + ''' diff --git a/test/legacy_test/test_incubate_moe_combine.py b/test/legacy_test/test_incubate_moe_combine.py index abc9a7b95f0645..1e6d62d422aeef 100644 --- a/test/legacy_test/test_incubate_moe_combine.py +++ b/test/legacy_test/test_incubate_moe_combine.py @@ -66,7 +66,8 @@ def test_moe_combine(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_n grad = paddle.to_tensor(grad_numpy).cast("float32") y = GateCombine.apply(x, combine_weights, scatter_index) - paddle.autograd.backward([y], [grad], True) + #paddle.autograd.backward([y], [grad], True) + grad.backward() return [x.grad, combine_weights.grad, y] From 422850eb523e92cd6e0f585bd1a3214ffc8b33a4 Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 09:39:47 +0000 Subject: [PATCH 50/71] fix moe_combine bug. 
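The Python change below restores the sum-over-hidden reduction when mapping the helper gradient back to combine_weights. A condensed sketch of the backward contract this fix relies on (the free-standing helper name is hypothetical; shapes follow the grad infermeta, where the helper gradient is [seqlen, k, hidden]):

    from paddle import _C_ops

    def gate_combine_backward(x, combine_weights, scatter_index, grad_y):
        # grad_x has the same shape as x; grad_helper is [seqlen, k, hidden].
        grad_x, grad_helper = _C_ops.moe_combine_grad(
            x, combine_weights, scatter_index, grad_y
        )
        # Reduce the hidden axis to recover d(combine_weights): [seqlen, k].
        grad_combine_weights = grad_helper.sum(-1).reshape(combine_weights.shape)
        return grad_x, grad_combine_weights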
--- paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu | 3 +++ paddle/phi/kernels/gpu/moe_combine_kernel.cu | 2 ++ .../gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu | 1 - test/legacy_test/ernie_utils/moe_layer_uneven.py | 8 +++----- test/legacy_test/test_incubate_moe_combine.py | 4 ++-- ...st_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py | 3 --- 6 files changed, 10 insertions(+), 11 deletions(-) diff --git a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu index 6d1247cd856120..d85ff0c5b5ffb2 100644 --- a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu @@ -1,6 +1,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/moe_combine_grad_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { template @@ -129,6 +130,8 @@ void MoeCombineGradKernel(const Context& dev_ctx, DenseTensor* grad_combine_weights_helper) { dev_ctx.template Alloc(grad_x); dev_ctx.template Alloc(grad_combine_weights_helper); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())), 0, grad_combine_weights_helper); auto x_shape = x.dims(); auto combine_weights_shape = combine_weights.dims(); moe_combine_bwd(dev_ctx, diff --git a/paddle/phi/kernels/gpu/moe_combine_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_kernel.cu index ae57420fc985f5..79d9e9d5515082 100644 --- a/paddle/phi/kernels/gpu/moe_combine_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_kernel.cu @@ -1,6 +1,7 @@ #include "paddle/phi/kernels/moe_combine_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { @@ -91,6 +92,7 @@ void moe_combine_fwd(const Context& dev_ctx, DenseTensor* y) { dev_ctx.template Alloc(y); // T cannot support phi::dtype::float8 very // well, maybe replaced with x.dtype(); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); auto combine_weights_shape = combine_weights.dims(); auto x_shape = x.dims(); moe_combine_fwd(dev_ctx, diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu index 1c45398a60a2f7..9e2d0275bc0a6d 100644 --- a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu @@ -110,7 +110,6 @@ void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(const Context& dev_ctx, int64_t expert_end_index, DenseTensor* x_grad, DenseTensor* combine_weights_grad){ - printf("MoeGateDispatchPartialNoSoftMaxTopkGradKernel begin\n"); dev_ctx.template Alloc(x_grad); dev_ctx.template Alloc(combine_weights_grad); // DenseTensor t_scatter_index; diff --git a/test/legacy_test/ernie_utils/moe_layer_uneven.py b/test/legacy_test/ernie_utils/moe_layer_uneven.py index b44809721b2b06..a6be1bbb84ba5e 100644 --- a/test/legacy_test/ernie_utils/moe_layer_uneven.py +++ b/test/legacy_test/ernie_utils/moe_layer_uneven.py @@ -218,7 +218,6 @@ def forward(ctx, x, combine_weights, scatter_index): @staticmethod def backward(ctx, grad_y, *_): - ''' """ Input: grad_y: [seqlen, hidden_size] @@ -243,10 +242,9 @@ def backward(ctx, grad_y, *_): # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim] # reduce the hidden shape # 
TODO: implement reduce in cuda ops - #grad_combine_weight = grad_combine_weight_helper.sum(-1) - #return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None - return grad_x, grad_combine_weight_helper - ''' + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + #return grad_x, grad_combine_weight_helper diff --git a/test/legacy_test/test_incubate_moe_combine.py b/test/legacy_test/test_incubate_moe_combine.py index 1e6d62d422aeef..6c39a1af222933 100644 --- a/test/legacy_test/test_incubate_moe_combine.py +++ b/test/legacy_test/test_incubate_moe_combine.py @@ -66,8 +66,8 @@ def test_moe_combine(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_n grad = paddle.to_tensor(grad_numpy).cast("float32") y = GateCombine.apply(x, combine_weights, scatter_index) - #paddle.autograd.backward([y], [grad], True) - grad.backward() + paddle.autograd.backward([y], [grad], True) + #grad.backward() return [x.grad, combine_weights.grad, y] diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py index 05fbfe3aed7eb8..b688280b779706 100644 --- a/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py @@ -12,7 +12,6 @@ def test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op(): - import moe_ops_partial_nosoftmaxtopk s, d, e = 4, 100, 8 k, cap = 4, 3 @@ -137,7 +136,6 @@ def check_ascend(index_rev, chunks): def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop(): - import moe_ops_partial_nosoftmaxtopk S, E, D = 3, 4, 3 k = 2 @@ -162,7 +160,6 @@ def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop(): def test_moe_ops_partial_nosoftmax_topk_empty_output(): - import moe_ops_partial_nosoftmaxtopk S, E, D = 3, 4, 3 k = 2 From 4c6ab1313d660c62cd9d1a12b8b4f7fefb1fbbad Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 10:39:07 +0000 Subject: [PATCH 51/71] Fix miscs --- .../phi/kernels/build_src_rank_and_local_expert_id_kernel.h | 4 +--- .../kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h b/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h index 866ce93aac7cdd..8fd7f13e6f6649 100644 --- a/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h +++ b/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h @@ -16,14 +16,12 @@ #include "paddle/phi/core/dense_tensor.h" namespace phi { - template -void BuildSrcRankAndLocalExpertIdInferMeta( +void BuildSrcRankAndLocalExpertIdKernel( const Context& dev_ctx, const DenseTensor& expert_num_global_tensor, const std::vector& expert_num_global, int64_t num_local_experts, DenseTensor* src_rank, DenseTensor* local_expert_id); - } // namespace phi diff --git a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu index f9508cb47023ac..f35a5b7f6421c5 100644 --- a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu +++ b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/cal_aux_loss_kernel.h" - #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" From 81aac420a8f59d015fccccca5b572c7eb7e93b3b Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 11:28:30 +0000 Subject: [PATCH 52/71] Align with original initializations. --- .pre-commit-config.yaml | 175 ++++++++++++++++++ paddle/phi/infermeta/backward.cc | 5 - ...e_ops_partial_nosoftmaxtopk_grad_kernel.cu | 3 +- .../moe_ops_partial_nosoftmaxtopk_kernel.cu | 7 + ...moe_gate_dispatch_partial_nosoftmaxtopk.py | 1 + 5 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000000..a72994c245fa32 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,175 @@ +# Exclude all third-party libraries and auto-generated files globally +exclude: | + (?x)^( + patches/.+| + paddle/fluid/framework/fleet/heter_ps/cudf/.+| + paddle/fluid/distributed/ps/thirdparty/round_robin.h| + python/paddle/utils/gast/.+| + third_party/.+ + )$ +repos: + # Common hooks + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-symlinks + - id: detect-private-key + - id: end-of-file-fixer + - id: sort-simple-yaml + files: (ops|backward|op_[a-z_]+)\.yaml$ + - id: trailing-whitespace + - repo: https://github.com/Lucas-C/pre-commit-hooks.git + rev: v1.5.1 + hooks: + - id: remove-crlf + - id: remove-tabs + name: Tabs remover (C++) + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$ + args: [--whitespaces-count, '2'] + - id: remove-tabs + name: Tabs remover (Python) + files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ + args: [--whitespaces-count, '4'] + # Exclude some unit test files that require tabs. + exclude: | + (?x)^( + test/dygraph_to_static/test_error.py + )$ + - repo: local + hooks: + - id: copyright_checker + name: copyright_checker + entry: python ./tools/codestyle/copyright.py + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|pyi|sh)$ + exclude: | + (?x)^( + paddle/utils/.*| + paddle/cinn/utils/registry.h + )$ + - repo: https://github.com/PFCCLab/typos-pre-commit-mirror.git + rev: v1.30.2 + hooks: + - id: typos + args: [--force-exclude] + # For Python files + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 25.1.0 + hooks: + - id: black + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.11 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix, --no-cache] + # For C++ files + - repo: local + hooks: + - id: clang-format + name: clang-format + description: Format files with ClangFormat. + entry: bash ./tools/codestyle/clang_format.sh -i + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$ + - repo: local + hooks: + - id: cpplint-cpp-source + name: cpplint + description: Check C++ code style using cpplint.py. + entry: bash ./tools/codestyle/cpplint_pre_commit.sh + language: system + files: \.(cc|cxx|cpp|cu|h|hpp|hxx)$ + args: + - --extensions=cc,cxx,cpp,cu,cuh,h,hpp,hxx,kps + - --filter=-readability/fn_size,-build/include_what_you_use,-build/c++11,-whitespace/parens + - --quiet + # Exclude third-party libraries + exclude: | + (?x)^( + paddle/utils/flat_hash_map\.h + )$ + - repo: local + hooks: + - id: clang-tidy + name: clang-tidy + description: Parallel clang-tidy runner. 
+ entry: python ./tools/codestyle/clang-tidy.py + language: system + files: \.(c|cc|cxx|cpp|h|hpp|hxx)$ + args: + - -p=build/ + - -extra-arg=-Wno-unknown-warning-option + - -extra-arg=-Wno-pessimizing-move + - -extra-arg=-Wno-braced-scalar-init + - -extra-arg=-Wno-dangling-gsl + - -extra-arg=-Wno-deprecated-copy + - -extra-arg=-Wno-final-dtor-non-final-class + - -extra-arg=-Wno-implicit-int-float-conversion + - -extra-arg=-Wno-inconsistent-missing-override + - -extra-arg=-Wno-infinite-recursion + - -extra-arg=-Wno-mismatched-tags + - -extra-arg=-Wno-self-assign + - -extra-arg=-Wno-sign-compare + - -extra-arg=-Wno-sometimes-uninitialized + - -extra-arg=-Wno-tautological-overlap-compare + - -extra-arg=-Wno-unused-const-variable + - -extra-arg=-Wno-unused-lambda-capture + - -extra-arg=-Wno-unused-private-field + - -extra-arg=-Wno-unused-value + - -extra-arg=-Wno-unused-variable + - -extra-arg=-Wno-overloaded-virtual + - -extra-arg=-Wno-defaulted-function-deleted + - -extra-arg=-Wno-delete-non-abstract-non-virtual-dtor + - -extra-arg=-Wno-return-type-c-linkage + # For CMake files + - repo: local + hooks: + - id: auto-generate-cmakelists + name: auto-generate-cmakelists + entry: bash ./tools/gen_ut_cmakelists.hook + language: system + files: testslist.csv$ + - repo: https://github.com/cheshirekow/cmake-format-precommit + rev: v0.6.13 + hooks: + - id: cmake-format + # exclude paddle/fluid/operators/CMakeLists.txt, see the comment + # https://github.com/PaddlePaddle/Paddle/pull/43057#pullrequestreview-993471860 + exclude: | + (?x)^( + paddle/fluid/operators/CMakeLists.txt + )$ + - repo: https://github.com/PFCCLab/cmake-lint-paddle + rev: v1.5.1 + hooks: + - id: cmakelint + args: [--config=./tools/codestyle/.cmakelintrc] + # Exclude some files has false positive warnings + # Need to fix them in the future + exclude: | + (?x)^( + cmake/external/onnxruntime.cmake + )$ + # For YAML files + - repo: https://github.com/PFCCLab/yamlfmt-pre-commit-mirror.git + rev: v0.16.0 + hooks: + - id: yamlfmt + files: | + (?x)^( + \.github/.+\.(yaml|yml)| + \.pre-commit-config\.yaml| + \.yamlfmt + ) + # Others + - repo: local + hooks: + - id: sort-txt-file + name: sort-txt-file + description: Sorts each line string in a text file + entry: python ./tools/codestyle/sort_txt_file.py + language: python + files: test/white_list/pir_op_test_white_list + args: [] diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 710c300f7a3af7..72823d7acbf68c 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1280,15 +1280,10 @@ void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(const MetaTensor& combine_ 0, common::errors::InvalidArgument("Input y_grad.dims()[0] should be greater than 0")); } - printf("y_grad shape: %d", y_grad.dims().size()); - printf("combine_weights_out_grad shape: %d, y_grad shape: %d", combine_weights_out_grad.dims().size(), y_grad.dims().size()); - printf("allocate combine_weights_grad\n"); combine_weights_grad->set_dims(combine_weights_out_grad.dims()); combine_weights_grad->set_dtype(phi::DataType::FLOAT32); - printf("allocate x_grad\n"); x_grad->set_dims({num_rows, hidden_size}); x_grad->set_dtype(y_grad.dtype()); - printf("check infer over\n"); } void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu index 9e2d0275bc0a6d..b0123a4d0e9ab8 100644 --- 
a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu @@ -22,6 +22,7 @@ #include "paddle/phi/kernels/moe_fuse_bwd_op.h" #include #include +#include "paddle/phi/kernels/full_kernel.h" namespace phi{ @@ -112,6 +113,7 @@ void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(const Context& dev_ctx, DenseTensor* combine_weights_grad){ dev_ctx.template Alloc(x_grad); dev_ctx.template Alloc(combine_weights_grad); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(combine_weights_grad->dims())), 0, combine_weights_grad); // DenseTensor t_scatter_index; // printf("check pass\n"); // phi::Transpose(dev_ctx, scatter_index, {1,0}, &t_scatter_index); @@ -119,7 +121,6 @@ void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(const Context& dev_ctx, // phi::ContiguousKernel(dev_ctx, t_scatter_index, &t_scatter_index_out); // t_scatter_index = t_scatter_index_out; // int64_t num_experts = expert_offset.dims()[0]; - printf("dive into moe_dispatch_bwd\n"); // moe_dispatch_bwd(dev_ctx, // combine_weights_out, // t_scatter_index, diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu index 6570adec9a813a..6d009350263aed 100644 --- a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu @@ -26,6 +26,7 @@ #include "paddle/phi/kernels/slice_kernel.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/moe_kernel_impl.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { @@ -392,6 +393,7 @@ void apply_moe_dispatch_fwd( y->Resize({expert_offset_host.back(), x.dims()[1]}); dev_ctx.template Alloc(y); } + phi::Full(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); copy_unpermuted_to_permuted_kernelLauncher(x.data(), y->data(), //out scatter_index_rev, //padded_out_to_unpermuted_input @@ -473,6 +475,11 @@ void MoeGateDispatchPartialNoSoftMaxTopkKernel(const Context& dev_ctx, dev_ctx.template Alloc(expert_offset); dev_ctx.template Alloc(expert_nums_local); dev_ctx.template Alloc(combine_weights_out); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(scatter_index->dims())), 0, scatter_index); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(scatter_index_rev->dims())), 0, scatter_index_rev); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(expert_offset->dims())), 0, expert_offset); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(expert_nums_local->dims())), 0, expert_nums_local); + phi::Full(dev_ctx, phi::IntArray(common::vectorize(combine_weights_out->dims())), 0, combine_weights_out); phi::Copy(dev_ctx, combine_weights, dev_ctx.GetPlace(), false, combine_weights_out); const auto &x_shape = x.dims(); int64_t num_rows = x_shape[0]; diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py index b688280b779706..0a0f7a586a3f6b 100644 --- a/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py @@ -166,6 +166,7 @@ def test_moe_ops_partial_nosoftmax_topk_empty_output(): capacity = 2 x = (paddle.arange(S) + 1).unsqueeze(-1).expand([S, D]).astype("bfloat16") cw = paddle.randn([S, k]) + paddle.device.synchronize() eid = paddle.to_tensor([[0, 1], [0, 1], [0, 2]], dtype="int32") # 1 # 2 # 3 ( y, From 
82acb48f2748db5606039631e936df3343bd424b Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 12:30:25 +0000 Subject: [PATCH 53/71] fix typos and pre-commit warnings --- paddle/phi/infermeta/backward.cc | 112 +-- paddle/phi/infermeta/backward.h | 50 +- paddle/phi/infermeta/binary.cc | 10 +- paddle/phi/infermeta/binary.h | 11 +- paddle/phi/infermeta/ternary.cc | 256 ++++--- paddle/phi/infermeta/ternary.h | 33 +- paddle/phi/infermeta/unary.cc | 22 +- paddle/phi/infermeta/unary.h | 2 +- .../expand_modality_expert_id_kernel.h | 16 +- ...ild_src_rank_and_local_expert_id_kernel.cu | 12 +- paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu | 16 +- .../gpu/expand_modality_expert_id_kernel.cu | 59 +- paddle/phi/kernels/gpu/int_bincount.cu | 85 ++- .../phi/kernels/gpu/layer_norm_cuda_kernel.cu | 49 +- .../kernels/gpu/moe_combine_grad_kernel.cu | 48 +- paddle/phi/kernels/gpu/moe_combine_kernel.cu | 75 +- .../kernels/gpu/moe_gate_dispatch_kernel.cu | 12 +- .../moe_gate_dispatch_permute_grad_kernel.cu | 169 ++-- .../gpu/moe_gate_dispatch_permute_kernel.cu | 325 ++++---- ...e_ops_partial_nosoftmaxtopk_grad_kernel.cu | 207 ++--- .../moe_ops_partial_nosoftmaxtopk_kernel.cu | 719 ++++++++++-------- paddle/phi/kernels/int_bincount.h | 39 +- paddle/phi/kernels/layer_norm_cuda_kernel.h | 386 +++++----- paddle/phi/kernels/moe_combine_grad_kernel.h | 2 +- paddle/phi/kernels/moe_combine_kernel.h | 2 +- paddle/phi/kernels/moe_fuse_bwd_op.h | 476 ++++++------ paddle/phi/kernels/moe_fuse_op.h | 38 +- .../moe_gate_dispatch_permute_grad_kernel.h | 2 +- .../moe_gate_dispatch_permute_kernel.h | 8 +- paddle/phi/kernels/moe_kernel_impl.h | 71 +- ...oe_ops_partial_nosoftmaxtopk_grad_kernel.h | 35 +- .../moe_ops_partial_nosoftmaxtopk_kernel.h | 39 +- paddle/phi/ops/yaml/backward.yaml | 18 +- paddle/phi/ops/yaml/ops.yaml | 32 +- .../paddle/incubate/nn/functional/__init__.py | 39 +- .../functional/expand_modality_expert_id.py | 64 +- .../nn/functional/fused_rms_norm_ext.py | 27 +- .../incubate/nn/functional/int_bincount.py | 29 +- .../incubate/nn/functional/moe_combine.py | 30 +- ...moe_gate_dispatch_partial_nosoftmaxtopk.py | 53 +- .../functional/moe_gate_dispatch_permute.py | 60 +- .../ernie_utils/moe_all_gather_layer.py | 111 ++- test/legacy_test/ernie_utils/moe_layer.py | 97 ++- .../ernie_utils/moe_layer_uneven.py | 102 ++- test/legacy_test/ernie_utils/top2_gate.py | 327 ++++++-- ...bate_build_src_rank_and_local_expert_id.py | 56 +- ...test_incubate_expand_modality_expert_id.py | 115 ++- test/legacy_test/test_incubate_fused_loss.py | 116 ++- .../test_incubate_fused_rmsnorm_ext.py | 58 +- .../legacy_test/test_incubate_int_bincount.py | 33 +- test/legacy_test/test_incubate_moe_combine.py | 51 +- ...moe_gate_dispatch_partial_nosoftmaxtopk.py | 73 +- ...st_incubate_moe_gate_dispatch_w_permute.py | 122 ++- ...ncubate_moe_gate_dispatch_w_permute_bwd.py | 76 +- 54 files changed, 3006 insertions(+), 2069 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 72823d7acbf68c..9db6ba0857d42b 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1222,7 +1222,7 @@ void MoeCombineGradInferMeta(const MetaTensor& x, const MetaTensor& scatter_index, const MetaTensor& y, MetaTensor* grad_x, - MetaTensor* grad_combine_weights_helper){ + MetaTensor* grad_combine_weights_helper) { auto x_dim = x.dims(); auto combine_weights_shape = combine_weights.dims(); PADDLE_ENFORCE_EQ( @@ -1232,53 +1232,55 @@ void MoeCombineGradInferMeta(const MetaTensor& x, "But 
received X's dimension = %d", x_dim.size())); PADDLE_ENFORCE_EQ( - (scatter_index.dtype() == phi::DataType::INT32), - true, - errors::InvalidArgument( - "The input scatter_index type should be int32" - "But received scatter_index type = %s", - scatter_index.dtype())); - grad_x->set_dims(common::make_ddim({x_dim[0],x_dim[1]})); + (scatter_index.dtype() == phi::DataType::INT32), + true, + errors::InvalidArgument("The input scatter_index type should be int32" + "But received scatter_index type = %s", + scatter_index.dtype())); + grad_x->set_dims(common::make_ddim({x_dim[0], x_dim[1]})); grad_x->set_dtype(x.dtype()); - grad_combine_weights_helper->set_dims(common::make_ddim({combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); + grad_combine_weights_helper->set_dims(common::make_ddim( + {combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); grad_combine_weights_helper->set_dtype(x.dtype()); } -void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(const MetaTensor& combine_weights_out, - const MetaTensor& scatter_index, - const MetaTensor& scatter_index_rev, - const MetaTensor& expert_offset, - const MetaTensor& expert_offset_local, - const MetaTensor& y_grad, - const MetaTensor& combine_weights_out_grad, - int64_t k, - int64_t capacity, - bool use_pad, - int64_t expert_start_index, - int64_t expert_end_index, - MetaTensor* x_grad, - MetaTensor* combine_weights_grad){ +void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta( + const MetaTensor& combine_weights_out, + const MetaTensor& scatter_index, + const MetaTensor& scatter_index_rev, + const MetaTensor& expert_offset, + const MetaTensor& expert_offset_local, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + MetaTensor* x_grad, + MetaTensor* combine_weights_grad) { int64_t num_experts = expert_offset.dims()[0]; int64_t hidden_size = y_grad.dims()[1]; int64_t num_rows = scatter_index.dims()[1]; - PADDLE_ENFORCE_GT( - num_experts, - 0, - common::errors::InvalidArgument("Input num_experts should be greater than 0")); - PADDLE_ENFORCE_EQ( - (expert_offset.dtype()==phi::DataType::INT64), - true, - common::errors::InvalidArgument("Input expert_offset type should be int64")); - if(use_pad){ - PADDLE_ENFORCE_GE( - num_experts, - y_grad.dims()[0] / capacity, - common::errors::InvalidArgument( - "Number of experts should be greater than or equal to y_grad.dims()[0]/capacity")); + PADDLE_ENFORCE_GT(num_experts, + 0, + common::errors::InvalidArgument( + "Input num_experts should be greater than 0")); + PADDLE_ENFORCE_EQ((expert_offset.dtype() == phi::DataType::INT64), + true, + common::errors::InvalidArgument( + "Input expert_offset type should be int64")); + if (use_pad) { + PADDLE_ENFORCE_GE(num_experts, + y_grad.dims()[0] / capacity, + common::errors::InvalidArgument( + "Number of experts should be greater than or equal " + "to y_grad.dims()[0]/capacity")); } else { PADDLE_ENFORCE_GT(y_grad.dims()[0], - 0, - common::errors::InvalidArgument("Input y_grad.dims()[0] should be greater than 0")); + 0, + common::errors::InvalidArgument( + "Input y_grad.dims()[0] should be greater than 0")); } combine_weights_grad->set_dims(combine_weights_out_grad.dims()); combine_weights_grad->set_dtype(phi::DataType::FLOAT32); @@ -1295,18 +1297,19 @@ void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, int64_t capacity, int64_t world_size, MetaTensor* x_grad, - MetaTensor* gate_logits_grad){ - + 
MetaTensor* gate_logits_grad) { auto y_grad_dims = y_grad.dims(); PADDLE_ENFORCE_EQ( - y_grad_dims[1], - world_size, - common::errors::InvalidArgument("The second dimension of y_grad should be equal to world_size, but " - "received y_grad_dims[1] = %d, world_size = %d", - y_grad_dims[1], world_size)); + y_grad_dims[1], + world_size, + common::errors::InvalidArgument( + "The second dimension of y_grad should be equal to world_size, but " + "received y_grad_dims[1] = %d, world_size = %d", + y_grad_dims[1], + world_size)); int64_t num_local_experts = y_grad_dims[0]; int64_t num_experts = world_size * num_local_experts; - int64_t hidden_size = y_grad_dims[y_grad_dims.size()-1]; + int64_t hidden_size = y_grad_dims[y_grad_dims.size() - 1]; int64_t num_rows = scatter_index.dims()[1]; x_grad->set_dims({num_rows, hidden_size}); x_grad->set_dtype(y_grad.dtype()); @@ -2057,14 +2060,13 @@ void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights, x_grad->set_dims(common::make_ddim({num_rows, hidden_size})); x_grad->set_dtype(y_grad.dtype()); } -void FusedRMSNormGradInferMeta(const MetaTensor &x, - const MetaTensor &scale, - const MetaTensor &invvar, - const MetaTensor &dy, - float epsilon, - MetaTensor* x_grad, - MetaTensor* scale_grad){ - +void FusedRMSNormGradInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& invvar, + const MetaTensor& dy, + float epsilon, + MetaTensor* x_grad, + MetaTensor* scale_grad) { x_grad->set_dims(x.dims()); x_grad->set_dtype(x.dtype()); scale_grad->set_dims(scale.dims()); diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 267c321d9d5782..72c4c0e69a377c 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -468,22 +468,26 @@ void MoeCombineGradInferMeta(const MetaTensor& x, const MetaTensor& grad_y, MetaTensor* grad_x, MetaTensor* grad_combine_weights_helper); -//Tensor combine_weights_out, Tensor scatter_index, Tensor scatter_index_rev, Tensor expert_offset, Tensor expert_offset_local, Tensor y_grad, Tensor combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t expert_start_index, int64_t expert_end_index) -// output : Tensor(x_grad), Tensor(combine_weights_grad) -void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(const MetaTensor& combine_weights_out, - const MetaTensor& scatter_index, - const MetaTensor& scatter_index_rev, - const MetaTensor& expert_offset, - const MetaTensor& expert_offset_local, - const MetaTensor& y_grad, - const MetaTensor& combine_weights_out_grad, - int64_t k, - int64_t capacity, - bool use_pad, - int64_t expert_start_index, - int64_t expert_end_index, - MetaTensor* x_grad, - MetaTensor* combine_weights_grad); +// Tensor combine_weights_out, Tensor scatter_index, Tensor scatter_index_rev, +// Tensor expert_offset, Tensor expert_offset_local, Tensor y_grad, Tensor +// combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t +// expert_start_index, int64_t expert_end_index) +// output : Tensor(x_grad), Tensor(combine_weights_grad) +void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta( + const MetaTensor& combine_weights_out, + const MetaTensor& scatter_index, + const MetaTensor& scatter_index_rev, + const MetaTensor& expert_offset, + const MetaTensor& expert_offset_local, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + MetaTensor* x_grad, + MetaTensor* 
combine_weights_grad); void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, const MetaTensor& scatter_index, @@ -734,11 +738,11 @@ void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights, MetaTensor* x_grad, MetaTensor* gate_logits_grad); -void FusedRMSNormGradInferMeta(const MetaTensor &x, - const MetaTensor &scale, - const MetaTensor &invvar, - const MetaTensor &dy, - float epsilon, - MetaTensor* x_grad, - MetaTensor* scale_grad); +void FusedRMSNormGradInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& invvar, + const MetaTensor& dy, + float epsilon, + MetaTensor* x_grad, + MetaTensor* scale_grad); } // namespace phi diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 7b75627d338ac5..39aeeacd82a204 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -4592,11 +4592,11 @@ void WeightDequantizeInferMeta(const MetaTensor& x, out->set_dtype(scale.dtype()); } -void FusedRMSNormInferMeta(const MetaTensor &x, - const MetaTensor &scale, - float epsilon, - MetaTensor* y, - MetaTensor* invvar){ +void FusedRMSNormInferMeta(const MetaTensor& x, + const MetaTensor& scale, + float epsilon, + MetaTensor* y, + MetaTensor* invvar) { // Y: same shape, dtype, layout as X y->set_dims(x.dims()); y->set_dtype(x.dtype()); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 799fc1267e83c6..81041d00c73903 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -790,11 +790,10 @@ void WeightDequantizeInferMeta(const MetaTensor& x, const std::string& algo, const int32_t group_size, MetaTensor* out); -void FusedRMSNormInferMeta(const MetaTensor &x, - const MetaTensor &scale, - float epsilon, - MetaTensor* y, - MetaTensor* invvar); - +void FusedRMSNormInferMeta(const MetaTensor& x, + const MetaTensor& scale, + float epsilon, + MetaTensor* y, + MetaTensor* invvar); } // namespace phi diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index a30097e46b6a7e..df96838f1132bd 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1615,131 +1615,133 @@ void MultiClassNMSInferMeta(const MetaTensor& bboxes, void MoeCombineInferMeta(const MetaTensor& x, const MetaTensor& combine_weights, const MetaTensor& scatter_index, - MetaTensor* y){ + MetaTensor* y) { auto x_dim = x.dims(); auto combine_weights_shape = combine_weights.dims(); - PADDLE_ENFORCE_EQ( - x_dim.size(), - 2, - common::errors::InvalidArgument("The dimensions of Input(x) must be 1, but " - "received dimensions of" - "Input(x) is [%d]", - x_dim.size())); + PADDLE_ENFORCE_EQ(x_dim.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(x) must be 1, but " + "received dimensions of" + "Input(x) is [%d]", + x_dim.size())); // maybe there is more conditions here.... 
y->set_dims(phi::make_ddim({combine_weights_shape[0], x_dim[1]})); y->set_dtype(x.dtype()); } -void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(const MetaTensor& x, - const MetaTensor& combine_weights, - const MetaTensor& expert_id, - int64_t k, - int64_t capacity, - int64_t num_experts, - bool use_pad, - int64_t expert_start_index, - int64_t expert_end_index, - bool reverse_token_drop, - MetaTensor* y, - MetaTensor* combine_weights_out, - MetaTensor* scatter_index, - MetaTensor* scatter_index_rev, - MetaTensor* expert_offset, - MetaTensor* expert_nums_local){ +void MoeGateDispatchPartialNoSoftmaxTopKInferMeta( + const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + MetaTensor* y, + MetaTensor* combine_weights_out, + MetaTensor* scatter_index, + MetaTensor* scatter_index_rev, + MetaTensor* expert_offset, + MetaTensor* expert_nums_local) { auto x_dims = x.dims(); - PADDLE_ENFORCE_EQ( - x_dims.size(), - 2, - common::errors::InvalidArgument("The dimensions of Input(x) must be 2, but " - "received dimensions of" - "Input(x) is [%d]", - x_dims.size())); + PADDLE_ENFORCE_EQ(x_dims.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(x) must be 2, but " + "received dimensions of" + "Input(x) is [%d]", + x_dims.size())); auto combine_weights_dims = combine_weights.dims(); PADDLE_ENFORCE_EQ( combine_weights_dims.size(), 2, - common::errors::InvalidArgument("The dimensions of Input(combine_weights) must be 2, but " - "received dimensions of" - "Input(combine_weights) is [%d]", - combine_weights_dims.size())); - PADDLE_ENFORCE_EQ( - combine_weights_dims[0], - x_dims[0], common::errors::InvalidArgument( - "The first dimensions of Input(combine_weights) must be equal to the first " - "dimension of Input(x), but received Input(combine_weights) shape is [%d]," - "Input(x) shape is [%d]", - combine_weights_dims[0], - x_dims[0])); - PADDLE_ENFORCE_GT( - expert_end_index, - expert_start_index, - common::errors::InvalidArgument( - "expert_end_index must be greater than expert_start_index, but received " - "expert_end_index = %d, expert_start_index = %d", - expert_end_index, - expert_start_index)); + "The dimensions of Input(combine_weights) must be 2, but " + "received dimensions of" + "Input(combine_weights) is [%d]", + combine_weights_dims.size())); + PADDLE_ENFORCE_EQ(combine_weights_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The first dimensions of Input(combine_weights) must " + "be equal to the first " + "dimension of Input(x), but received " + "Input(combine_weights) shape is [%d]," + "Input(x) shape is [%d]", + combine_weights_dims[0], + x_dims[0])); + PADDLE_ENFORCE_GT(expert_end_index, + expert_start_index, + common::errors::InvalidArgument( + "expert_end_index must be greater than " + "expert_start_index, but received " + "expert_end_index = %d, expert_start_index = %d", + expert_end_index, + expert_start_index)); PADDLE_ENFORCE_EQ( combine_weights.dtype(), phi::DataType::FLOAT32, - common::errors::InvalidArgument( - "The dtype of Input(combine_weights) must be FLOAT32, but received %s", - combine_weights.dtype())); + common::errors::InvalidArgument("The dtype of Input(combine_weights) " + "must be FLOAT32, but received %s", + combine_weights.dtype())); PADDLE_ENFORCE_EQ( expert_id.dtype(), phi::DataType::INT32, common::errors::InvalidArgument( "The dtype of 
Input(expert_id) must be INT32, but received %s", expert_id.dtype())); - PADDLE_ENFORCE_GT( - k, - 0, - common::errors::InvalidArgument( - "k must be greater than 0, but received k = %d", - k)); + PADDLE_ENFORCE_GT(k, + 0, + common::errors::InvalidArgument( + "k must be greater than 0, but received k = %d", k)); PADDLE_ENFORCE_GT( x_dims[0], 0, common::errors::InvalidArgument( "num_rows must be greater than 0, but received num_rows = %d", x_dims[0])); - PADDLE_ENFORCE_GE( - num_experts, - k, - common::errors::InvalidArgument( - "num_experts must be greater than or equal to k, but received num_experts = %d, k = %d", - num_experts, - k)); + PADDLE_ENFORCE_GE(num_experts, + k, + common::errors::InvalidArgument( + "num_experts must be greater than or equal to k, but " + "received num_experts = %d, k = %d", + num_experts, + k)); PADDLE_ENFORCE_EQ( !reverse_token_drop || !use_pad, true, common::errors::InvalidArgument( - "use_pad must be false when reverse_token_drop is true, but received use_pad = %d, reverse_token_drop = %d", + "use_pad must be false when reverse_token_drop is true, but received " + "use_pad = %d, reverse_token_drop = %d", use_pad, reverse_token_drop)); PADDLE_ENFORCE_EQ( combine_weights.dtype(), phi::DataType::FLOAT32, - common::errors::InvalidArgument( - "The dtype of Input(combine_weights) must be FLOAT32, but received %s", - combine_weights.dtype())); -//int64_t num_experts_diff = expert_end_index - expert_start_index; -int64_t num_rows = x_dims[0]; -// if (use_pad) -// y->set_dims({num_experts_diff * capacity, x_dims[1]}) ; -y->set_dims({-1, x_dims[1]}); -y->set_dtype(x.dtype()); -scatter_index->set_dims({k, num_rows}); -scatter_index->set_dtype(phi::DataType::INT32); -scatter_index_rev->set_dims({num_experts*capacity}); -scatter_index_rev->set_dtype(phi::DataType::INT32); -expert_offset->set_dims({num_experts}); -expert_offset->set_dtype(phi::DataType::INT64); -expert_nums_local->set_dims({num_experts}); -expert_nums_local->set_dtype(phi::DataType::INT64); -combine_weights_out->set_dims(combine_weights_dims); -combine_weights_out->set_dtype(combine_weights.dtype()); -// combine_weights_out->share_meta(combine_weights); + common::errors::InvalidArgument("The dtype of Input(combine_weights) " + "must be FLOAT32, but received %s", + combine_weights.dtype())); + // int64_t num_experts_diff = expert_end_index - expert_start_index; + int64_t num_rows = x_dims[0]; + // if (use_pad) + // y->set_dims({num_experts_diff * capacity, x_dims[1]}) ; + y->set_dims({-1, x_dims[1]}); + y->set_dtype(x.dtype()); + scatter_index->set_dims({k, num_rows}); + scatter_index->set_dtype(phi::DataType::INT32); + scatter_index_rev->set_dims({num_experts * capacity}); + scatter_index_rev->set_dtype(phi::DataType::INT32); + expert_offset->set_dims({num_experts}); + expert_offset->set_dtype(phi::DataType::INT64); + expert_nums_local->set_dims({num_experts}); + expert_nums_local->set_dtype(phi::DataType::INT64); + combine_weights_out->set_dims(combine_weights_dims); + combine_weights_out->set_dtype(combine_weights.dtype()); + // combine_weights_out->share_meta(combine_weights); } void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, @@ -1752,50 +1754,53 @@ void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, MetaTensor* combine_weights, MetaTensor* scatter_index, MetaTensor* expert_offset, - MetaTensor* expert_id){ + MetaTensor* expert_id) { auto x_dims = x.dims(); - PADDLE_ENFORCE_EQ( - x_dims.size(), - 2, - common::errors::InvalidArgument("The dimensions of Input(x) must be 2, but " - 
"received dimensions of" - "Input(x) is [%d]", - x_dims.size())); - auto gate_logits_dims= gate_logits.dims(); - PADDLE_ENFORCE_EQ( - gate_logits_dims.size(), - 2, - common::errors::InvalidArgument("The dimensions of Input(gate_logits) must be 2, but " - "received dimensions of" - "Input(gate_logits) is [%d]", - gate_logits_dims.size())); - PADDLE_ENFORCE_EQ( - gate_logits_dims[0], - x_dims[0], - common::errors::InvalidArgument( - "The first dimensions of Input(gate_logits) must be equal to the first " - "dimension of Input(x), but received Input(gate_logits) shape is [%d]," - "Input(x) shape is [%d]", - gate_logits_dims[0], - x_dims[0])); + PADDLE_ENFORCE_EQ(x_dims.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(x) must be 2, but " + "received dimensions of" + "Input(x) is [%d]", + x_dims.size())); + auto gate_logits_dims = gate_logits.dims(); + PADDLE_ENFORCE_EQ(gate_logits_dims.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(gate_logits) must be 2, but " + "received dimensions of" + "Input(gate_logits) is [%d]", + gate_logits_dims.size())); + PADDLE_ENFORCE_EQ(gate_logits_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The first dimensions of Input(gate_logits) must be " + "equal to the first " + "dimension of Input(x), but received " + "Input(gate_logits) shape is [%d]," + "Input(x) shape is [%d]", + gate_logits_dims[0], + x_dims[0])); PADDLE_ENFORCE_EQ( gate_logits_dims[1] % world_size, 0, common::errors::InvalidArgument( - "The number of experts (the second dimension of Input(gate_logits)) must be divisible by world_size, but received " + "The number of experts (the second dimension of Input(gate_logits)) " + "must be divisible by world_size, but received " "num_experts = %d, world_size = %d", gate_logits_dims[1], world_size)); - PADDLE_ENFORCE_GE( - gate_logits_dims[1], - k, - common::errors::InvalidArgument( - "The number of experts ((the second dimension of Input(gate_logits))) must be greater than or equal to k, but received " - "num_experts = %d, k = %d", - gate_logits_dims[1], - k)); - + PADDLE_ENFORCE_GE(gate_logits_dims[1], + k, + common::errors::InvalidArgument( + "The number of experts ((the second dimension of " + "Input(gate_logits))) must be greater than or equal to " + "k, but received " + "num_experts = %d, k = %d", + gate_logits_dims[1], + k)); + PADDLE_ENFORCE_EQ( gate_logits.dtype(), phi::DataType::FLOAT32, @@ -1803,7 +1808,7 @@ void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, "The dtype of Input(gate_logits) must be FLOAT32, but received %s", gate_logits.dtype())); - if(corr_bias){ + if (corr_bias) { auto corr_bias_dims = corr_bias.dims(); PADDLE_ENFORCE_EQ( corr_bias_dims.size(), @@ -1817,7 +1822,8 @@ void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, x_dims[0], common::errors::InvalidArgument( "The dimensions of Input(corr_bias) must be equal to the first " - "dimension of Input(x), but received Input(corr_bias) first dimension is [%d]," + "dimension of Input(x), but received Input(corr_bias) first " + "dimension is [%d]," "Input(x) first dimension is [%d]", corr_bias_dims[0], x_dims[0])); diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index ad729e7bd30afe..c462d51ecd0d9a 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -274,22 +274,23 @@ void MoeCombineInferMeta(const MetaTensor& x, const MetaTensor& scatter_index, MetaTensor* y); -void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(const MetaTensor& x, - const 
MetaTensor& combine_weights, - const MetaTensor& expert_id, - int64_t k, - int64_t capacity, - int64_t num_experts, - bool use_pad, - int64_t expert_start_index, - int64_t expert_end_index, - bool reverse_token_drop, - MetaTensor* y, - MetaTensor* combine_weights_out, - MetaTensor* scatter_index, - MetaTensor* scatter_index_rev, - MetaTensor* expert_offset, - MetaTensor* expert_nums_local); +void MoeGateDispatchPartialNoSoftmaxTopKInferMeta( + const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + MetaTensor* y, + MetaTensor* combine_weights_out, + MetaTensor* scatter_index, + MetaTensor* scatter_index_rev, + MetaTensor* expert_offset, + MetaTensor* expert_nums_local); void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, const MetaTensor& gate_logits, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a839f1c2d2a370..8f5f177f1c3b97 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1370,7 +1370,7 @@ void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, int64_t group_size, int64_t modality_offset, bool is_group_expert, - MetaTensor* expert_id_out){ + MetaTensor* expert_id_out) { auto expert_id_dims = expert_id.dims(); PADDLE_ENFORCE_EQ( expert_id_dims.size(), @@ -1381,12 +1381,13 @@ void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, expert_id_dims.size(), expert_id_dims)); PADDLE_ENFORCE_EQ( - expert_id.dtype() == DataType::INT32 || expert_id.dtype() == DataType::INT64, - true, - common::errors::InvalidArgument( - "The dtype of expert_id should be INT32 or INT64. But received" - "dtype=%s.", - DataTypeToString(expert_id.dtype()))); + expert_id.dtype() == DataType::INT32 || + expert_id.dtype() == DataType::INT64, + true, + common::errors::InvalidArgument( + "The dtype of expert_id should be INT32 or INT64. But received" + "dtype=%s.", + DataTypeToString(expert_id.dtype()))); int64_t seqlen = expert_id_dims[0]; int64_t k = expert_id_dims[1]; @@ -6221,12 +6222,14 @@ void IntBincountInferMeta(const MetaTensor& x, int64_t dtype, MetaTensor* out) { PADDLE_ENFORCE_EQ( - x.dims().size(), 1, + x.dims().size(), + 1, errors::InvalidArgument( "The input 'x' of int_bincount must be a 1-D Tensor, but got %u-D.", x.dims().size())); PADDLE_ENFORCE_GT( - high, low, + high, + low, errors::InvalidArgument("Attr high (%d) must be > low (%d).", high, low)); int64_t bin_count = high - low + 1; @@ -6234,7 +6237,6 @@ void IntBincountInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } - } // namespace phi PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 10681997a51497..e6c16debb0a7ee 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -1011,7 +1011,7 @@ void BuildSrcRankAndLocalExpertIdInferMeta( int64_t num_local_experts, MetaTensor* src_rank, MetaTensor* local_expert_id); - + void IntBincountInferMeta(const MetaTensor& x, int64_t low, int64_t high, diff --git a/paddle/phi/kernels/expand_modality_expert_id_kernel.h b/paddle/phi/kernels/expand_modality_expert_id_kernel.h index 0f6ce161dce05f..1d0d308d33fb3f 100644 --- a/paddle/phi/kernels/expand_modality_expert_id_kernel.h +++ b/paddle/phi/kernels/expand_modality_expert_id_kernel.h @@ -1,3 +1,17 @@ +// Copyright (c) 2025 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include "paddle/phi/core/dense_tensor.h" @@ -12,4 +26,4 @@ void ExpandModalityExpertIDKernel(const Context& dev_ctx, bool is_group_expert, DenseTensor* expert_id_out); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu index f35a5b7f6421c5..07ba0fc48d0784 100644 --- a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu +++ b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu @@ -81,12 +81,12 @@ void BuildSrcRankAndLocalExpertIdKernel( int* local_expert_id_data = dev_ctx.template Alloc(local_expert_id); build_srcrank_and_local_expert_id(src_rank_data, - local_expert_id_data, - expert_num_global_tensor_data, - token_num, - expert_num_global.size(), - num_local_experts, - dev_ctx.stream()); + local_expert_id_data, + expert_num_global_tensor_data, + token_num, + expert_num_global.size(), + num_local_experts, + dev_ctx.stream()); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu index 60ee81cdd93d77..d5fb1a32836ed5 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -221,14 +221,14 @@ void CalAuxLossKernel(const Context& dev_ctx, int64_t dispatch_tokens_mask_len = 0; auto dispatch_tokens_mask_ptr = dispatch_tokens_mask.get_ptr(); if (dispatch_tokens_mask) { - const auto mask_dims = dispatch_tokens_mask_ptr->dims(); - const auto dim_size = mask_dims.size(); - const bool is_not_zero_size = (dim_size > 0); - if (is_not_zero_size) { - dispatch_tokens_mask_len = dispatch_tokens_mask_ptr->dims()[0]; - } else { - dispatch_tokens_mask_len = 0; - } + const auto mask_dims = dispatch_tokens_mask_ptr->dims(); + const auto dim_size = mask_dims.size(); + const bool is_not_zero_size = (dim_size > 0); + if (is_not_zero_size) { + dispatch_tokens_mask_len = dispatch_tokens_mask_ptr->dims()[0]; + } else { + dispatch_tokens_mask_len = 0; + } } /* diff --git a/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu b/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu index 4e06d7274325a4..b9f9aa7674c0bd 100644 --- a/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu @@ -1,39 +1,56 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/phi/kernels/expand_modality_expert_id_kernel.h" +#include #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include namespace phi { -template +template void expand_modality_expert_id(const T* expert_id, T* expert_id_out, int64_t seqlen, int64_t k, int64_t num_expert_per_modality, int64_t group_size, - int64_t modality_offset, + int64_t modality_offset, bool is_group_expert, - cudaStream_t stream){ - thrust::transform( + cudaStream_t stream) { + thrust::transform( thrust::cuda::par.on(stream), thrust::device_pointer_cast(expert_id), thrust::device_pointer_cast(expert_id) + seqlen * k, thrust::counting_iterator(0), thrust::device_pointer_cast(expert_id_out), - [k, num_expert_per_modality, group_size, modality_offset, is_group_expert] __device__(T e, T idx) { - if (is_group_expert){ - e += idx % k * group_size; - } - if (num_expert_per_modality <= 0) - return static_cast(e); - T rank = e / num_expert_per_modality; - T expert_id_in_rank = e % num_expert_per_modality; - return static_cast(rank * (num_expert_per_modality * 2) // HRAD code: only support 2 modality - + expert_id_in_rank - + modality_offset * num_expert_per_modality); - } - ); + [k, + num_expert_per_modality, + group_size, + modality_offset, + is_group_expert] __device__(T e, T idx) { + if (is_group_expert) { + e += idx % k * group_size; + } + if (num_expert_per_modality <= 0) return static_cast(e); + T rank = e / num_expert_per_modality; + T expert_id_in_rank = e % num_expert_per_modality; + return static_cast(rank * (num_expert_per_modality * + 2) // HRAD code: only support 2 modality + + expert_id_in_rank + + modality_offset * num_expert_per_modality); + }); } template @@ -43,7 +60,7 @@ void ExpandModalityExpertIDKernel(const Context& dev_ctx, int64_t group_size, int64_t modality_offset, bool is_group_expert, - DenseTensor* expert_id_out){ + DenseTensor* expert_id_out) { dev_ctx.template Alloc(expert_id_out); auto expert_id_shape = expert_id.dims(); int64_t seqlen = expert_id_shape[0]; @@ -58,11 +75,11 @@ void ExpandModalityExpertIDKernel(const Context& dev_ctx, is_group_expert, dev_ctx.stream()); } -} // namespace phi +} // namespace phi PD_REGISTER_KERNEL(expand_modality_expert_id, GPU, ALL_LAYOUT, phi::ExpandModalityExpertIDKernel, int, - int64_t) {} \ No newline at end of file + int64_t) {} diff --git a/paddle/phi/kernels/gpu/int_bincount.cu b/paddle/phi/kernels/gpu/int_bincount.cu index 287ef0cbe2d733..265a267929675c 100644 --- a/paddle/phi/kernels/gpu/int_bincount.cu +++ b/paddle/phi/kernels/gpu/int_bincount.cu @@ -1,12 +1,27 @@ +// NOLINT +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
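For context on the IntBincountImpl loop further down in this file: it follows CUB's usual two-pass convention, where the first HistogramEven call is made with a null temporary-storage pointer purely to learn the required workspace size, and the second call does the real work. A minimal standalone sketch of that pattern, assuming float samples and an int counter type (the names here are illustrative, not part of the patch):

#include <cuda_runtime.h>
#include <cub/device/device_histogram.cuh>

// Two-pass CUB call: size query first, then the actual histogram.
void histogram_even_sketch(const float* d_samples, int num_samples,
                           int* d_bins, int num_bins, float lower, float upper,
                           cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Pass 1: with d_temp == nullptr, CUB only writes the required size.
  cub::DeviceHistogram::HistogramEven(d_temp, temp_bytes, d_samples, d_bins,
                                      num_bins + 1, lower, upper, num_samples,
                                      stream);
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: identical arguments, now with real temporary storage.
  cub::DeviceHistogram::HistogramEven(d_temp, temp_bytes, d_samples, d_bins,
                                      num_bins + 1, lower, upper, num_samples,
                                      stream);
  cudaFree(d_temp);
}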
+ // #include "paddle/extension.h" -#include "paddle/phi/core/utils/data_type.h" -#include "paddle/common/flags.h" -#include +#include "paddle/phi/kernels/int_bincount.h" // NOLINT #include +#include #include "cub/device/device_histogram.cuh" +#include "paddle/common/flags.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/kernels/empty_kernel.h" // NOLINT -#include "paddle/phi/kernels/int_bincount.h" // NOLINT #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" @@ -14,7 +29,7 @@ COMMON_DECLARE_bool(enable_pir_api); -namespace phi{ +namespace phi { static phi::DataType TransToDataType(int64_t dtype) { if (FLAGS_enable_pir_api) { return static_cast(dtype); @@ -31,41 +46,54 @@ std::vector> IntBincountInferShape( return {{max_value - min_value}}; } -std::vector IntBincountInferDType( - phi::DataType x_dtype, - int64_t min_value, - int64_t max_value, - int64_t out_dtype) { +std::vector IntBincountInferDType(phi::DataType x_dtype, + int64_t min_value, + int64_t max_value, + int64_t out_dtype) { return {TransToDataType(out_dtype)}; } template -void IntBincountImpl(const Context& ctx, const T *x, int64_t n, T min_v, T max_v, BinsT *bins) { +void IntBincountImpl( + const Context &ctx, const T *x, int64_t n, T min_v, T max_v, BinsT *bins) { DenseTensor workspace; void *workspace_ptr = nullptr; size_t workspace_size = 0; #pragma unroll for (int i = 0; i < 2; ++i) { if (workspace_size > 0) { - workspace = phi::Empty(ctx, {static_cast(workspace_size)}); + workspace = + phi::Empty(ctx, {static_cast(workspace_size)}); workspace_ptr = workspace.data(); } - auto err = cub::DeviceHistogram::HistogramEven( - workspace_ptr, workspace_size, x, bins, max_v - min_v + 1, min_v, max_v, n, ctx.stream()); - PD_CHECK(err == cudaSuccess, "HistogramEven error: %s", cudaGetErrorString(err)); + auto err = cub::DeviceHistogram::HistogramEven(workspace_ptr, + workspace_size, + x, + bins, + max_v - min_v + 1, + min_v, + max_v, + n, + ctx.stream()); + PD_CHECK( + err == cudaSuccess, "HistogramEven error: %s", cudaGetErrorString(err)); } } // T is x's input type and out_dtype is in args -template -void IntBincount(const Context& ctx, const DenseTensor &x, int64_t low, int64_t high, int64_t out_dtype, DenseTensor* out) { +template +void IntBincount(const Context &ctx, + const DenseTensor &x, + int64_t low, + int64_t high, + int64_t out_dtype, + DenseTensor *out) { PD_CHECK(low < high); int64_t bins_width = high - low; PD_CHECK(bins_width + 1 < std::numeric_limits::max()); auto bins_dtype = TransToDataType(out_dtype); - // auto x_dytpe = x.dtype(); auto low_v = static_cast(low); auto high_v = static_cast(high); @@ -75,24 +103,21 @@ void IntBincount(const Context& ctx, const DenseTensor &x, int64_t low, int64_t int64_t n = x.numel(); if (bins_dtype == phi::DataType::INT32) { ctx.template Alloc(out); - uint32_t *out_ptr = static_cast(out->data()); - IntBincountImpl(ctx, x_data, n, low_v, high_v, out_ptr); + uint32_t *out_ptr = static_cast(out->data()); + IntBincountImpl( + ctx, x_data, n, low_v, high_v, out_ptr); } else if (bins_dtype == phi::DataType::INT64) { - using ULLI = unsigned long long int; + using ULLI = unsigned long long int; // NOLINT ctx.template Alloc(out); - static_assert(sizeof(int64_t) == sizeof(ULLI)); + static_assert(sizeof(int64_t) == sizeof(ULLI)); // WARNING: unsafe type cast used in original impl. 
- ULLI* out_ptr = static_cast (out->data()); + ULLI *out_ptr = static_cast(out->data()); IntBincountImpl(ctx, x_data, n, low_v, high_v, out_ptr); } else { PD_THROW("Only support INT32 and INT64, but got %s", bins_dtype); } } -} // namespace phi +} // namespace phi -PD_REGISTER_KERNEL(int_bincount, - GPU, - ALL_LAYOUT, - phi::IntBincount, - int64_t, - int) {} \ No newline at end of file +PD_REGISTER_KERNEL( + int_bincount, GPU, ALL_LAYOUT, phi::IntBincount, int64_t, int) {} diff --git a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu index 3f9df7fcb9a1db..173d4f20b96ffe 100644 --- a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu +++ b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu @@ -37,13 +37,13 @@ static void GetRowsCols(const std::vector &shape, *p_cols = cols; } -template -void RMSLnFwd(const Context& ctx, +template +void RMSLnFwd(const Context &ctx, const DenseTensor &x, const DenseTensor &scale, float epsilon, - DenseTensor* y, - DenseTensor* invvar) { + DenseTensor *y, + DenseTensor *invvar) { const auto &scale_shape = scale.dims(); const auto &x_shape = x.dims(); PD_CHECK(scale_shape.size() == 1); @@ -54,21 +54,21 @@ void RMSLnFwd(const Context& ctx, cols = x_shape[1]; // GetRowsCols(x_shape, &rows, &cols); - *y = phi::EmptyLike(ctx, x); + *y = phi::EmptyLike(ctx, x); *invvar = phi::Empty(ctx, {rows}); cuda_rms_norm(ctx, x, scale, rows, cols, epsilon, y, invvar); } -template -void RMSLnBwd(const Context& ctx, +template +void RMSLnBwd(const Context &ctx, const DenseTensor &x, const DenseTensor &scale, const DenseTensor &invvar, const DenseTensor &y_grad, float epsilon, - DenseTensor* x_grad, - DenseTensor* scale_grad) { + DenseTensor *x_grad, + DenseTensor *scale_grad) { int rows, cols; const auto &x_shape = x.dims(); rows = x_shape[0]; @@ -76,32 +76,13 @@ void RMSLnBwd(const Context& ctx, ctx.template Alloc(x_grad); ctx.template Alloc(scale_grad); cuda_rms_norm_gradient( - ctx, - x, - scale, - invvar, - y_grad, - rows, - cols, - epsilon, - x_grad, - scale_grad - ); + ctx, x, scale, invvar, y_grad, rows, cols, epsilon, x_grad, scale_grad); } -} // namespace phi +} // namespace phi +PD_REGISTER_KERNEL( + fused_rms_norm, GPU, ALL_LAYOUT, phi::RMSLnFwd, float, double) {} -PD_REGISTER_KERNEL(fused_rms_norm, - GPU, - ALL_LAYOUT, - phi::RMSLnFwd, - float, - double) {} - -PD_REGISTER_KERNEL(fused_rms_norm_grad, - GPU, - ALL_LAYOUT, - phi::RMSLnBwd, - float, - double) {} +PD_REGISTER_KERNEL( + fused_rms_norm_grad, GPU, ALL_LAYOUT, phi::RMSLnBwd, float, double) {} diff --git a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu index d85ff0c5b5ffb2..d7f603746ad39d 100644 --- a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu @@ -1,6 +1,20 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
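Before the grad kernel below: given the assumed forward contract y[s, :] = sum_i combine_weights[s, i] * x[scatter_index[s, i], :], the backward pass scatters grad_y back into grad_x and stores per-(row, expert, channel) products in grad_combine_weights_helper, whose [seqlen, k, hidden] shape is set by MoeCombineGradInferMeta earlier in this patch. A minimal CPU sketch of that assumed behavior; the [seqlen, k] row-major index layout and the exact content of the helper are assumptions, not taken from the patch:

#include <cstdint>
#include <vector>

// Assumes grad_x and grad_w_helper are pre-zeroed (the kernel zero-fills them).
// grad_x[idx, :]        += w[s, i] * grad_y[s, :]
// grad_w_helper[s, i, :] = x[idx, :] * grad_y[s, :]
//   (a later reduction over the hidden axis would give the combine_weights grad)
void moe_combine_backward_sketch(const std::vector<float>& x,       // [n_dispatched, hidden]
                                 const std::vector<float>& w,       // [seqlen, k]
                                 const std::vector<int32_t>& idx,   // [seqlen, k]
                                 const std::vector<float>& grad_y,  // [seqlen, hidden]
                                 int64_t seqlen, int64_t k, int64_t hidden,
                                 std::vector<float>* grad_x,        // [n_dispatched, hidden]
                                 std::vector<float>* grad_w_helper) {  // [seqlen, k, hidden]
  for (int64_t s = 0; s < seqlen; ++s)
    for (int64_t i = 0; i < k; ++i) {
      const int64_t row = idx[s * k + i];
      for (int64_t h = 0; h < hidden; ++h) {
        (*grad_x)[row * hidden + h] += w[s * k + i] * grad_y[s * hidden + h];
        (*grad_w_helper)[(s * k + i) * hidden + h] =
            x[row * hidden + h] * grad_y[s * hidden + h];
      }
    }
}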
+ +#include "paddle/phi/kernels/moe_combine_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/moe_combine_grad_kernel.h" #include "paddle/phi/kernels/full_kernel.h" namespace phi { @@ -109,16 +123,17 @@ void moe_combine_bwd(const Context& dev_ctx, const int64_t k, const int64_t seqlen, const int64_t hidden_size) { - apply_moe_combine_bwd(x.data(), - combine_weights.data(), - scatter_index.data(), - grad_y.data(), - const_cast(grad_x->data()), - const_cast(grad_combine_weights_helper->data()), - k, - seqlen, - hidden_size, - dev_ctx.stream()); + apply_moe_combine_bwd( + x.data(), + combine_weights.data(), + scatter_index.data(), + grad_y.data(), + const_cast(grad_x->data()), + const_cast(grad_combine_weights_helper->data()), + k, + seqlen, + hidden_size, + dev_ctx.stream()); } template void MoeCombineGradKernel(const Context& dev_ctx, @@ -130,8 +145,13 @@ void MoeCombineGradKernel(const Context& dev_ctx, DenseTensor* grad_combine_weights_helper) { dev_ctx.template Alloc(grad_x); dev_ctx.template Alloc(grad_combine_weights_helper); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())), 0, grad_combine_weights_helper); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())), + 0, + grad_combine_weights_helper); auto x_shape = x.dims(); auto combine_weights_shape = combine_weights.dims(); moe_combine_bwd(dev_ctx, @@ -154,4 +174,4 @@ PD_REGISTER_KERNEL(moe_combine_grad, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} \ No newline at end of file + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/moe_combine_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_kernel.cu index 79d9e9d5515082..0c670f530f21c2 100644 --- a/paddle/phi/kernels/gpu/moe_combine_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_combine_kernel.cu @@ -1,3 +1,17 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
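The forward kernel in this file gathers, for every output row, the k dispatched rows named by scatter_index and reduces them with combine_weights. A minimal CPU sketch of that assumed contract follows; the float element type and the [seqlen, k] row-major layout of scatter_index are assumptions, not taken from the patch:

#include <cstdint>
#include <vector>

// y[s, :] = sum_{i < k} combine_weights[s, i] * x[scatter_index[s, i], :]
void moe_combine_sketch(const std::vector<float>& x,                // [n_dispatched, hidden]
                        const std::vector<float>& combine_weights,  // [seqlen, k]
                        const std::vector<int32_t>& scatter_index,  // [seqlen, k]
                        int64_t seqlen, int64_t k, int64_t hidden,
                        std::vector<float>* y) {                    // [seqlen, hidden]
  y->assign(seqlen * hidden, 0.0f);
  for (int64_t s = 0; s < seqlen; ++s)
    for (int64_t i = 0; i < k; ++i) {
      const int64_t src = scatter_index[s * k + i];
      const float weight = combine_weights[s * k + i];
      for (int64_t h = 0; h < hidden; ++h)
        (*y)[s * hidden + h] += weight * x[src * hidden + h];
    }
}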
+ #include "paddle/phi/kernels/moe_combine_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" @@ -66,7 +80,7 @@ void apply_moe_combine_fwd(const T* x, } template -void moe_combine_fwd(const Context& dev_ctx, +void moe_combine_fwd(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& combine_weights, const DenseTensor& scatter_index, @@ -75,36 +89,37 @@ void moe_combine_fwd(const Context& dev_ctx, const int64_t seqlen, const int64_t hidden_size) { apply_moe_combine_fwd(x.data(), - combine_weights.data(), - scatter_index.data(), - const_cast(y.data()), - k, - seqlen, - hidden_size, - dev_ctx.stream()); - } + combine_weights.data(), + scatter_index.data(), + const_cast(y.data()), + k, + seqlen, + hidden_size, + dev_ctx.stream()); +} - template - void MoeCombineKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& combine_weights, - const DenseTensor& scatter_index, - DenseTensor* y) { - dev_ctx.template Alloc(y); // T cannot support phi::dtype::float8 very - // well, maybe replaced with x.dtype(); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); - auto combine_weights_shape = combine_weights.dims(); - auto x_shape = x.dims(); - moe_combine_fwd(dev_ctx, - x, - combine_weights, - scatter_index, - *y, - combine_weights_shape[1], // k - combine_weights_shape[0], // seqlen - x_shape[1]); // hidden_size - } +template +void MoeCombineKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + DenseTensor* y) { + dev_ctx.template Alloc(y); // T cannot support phi::dtype::float8 very + // well, maybe replaced with x.dtype(); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); + auto combine_weights_shape = combine_weights.dims(); + auto x_shape = x.dims(); + moe_combine_fwd(dev_ctx, + x, + combine_weights, + scatter_index, + *y, + combine_weights_shape[1], // k + combine_weights_shape[0], // seqlen + x_shape[1]); // hidden_size } +} // namespace phi PD_REGISTER_KERNEL(moe_combine, GPU, @@ -113,4 +128,4 @@ PD_REGISTER_KERNEL(moe_combine, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} \ No newline at end of file + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu index 7912c38665fd16..23bbca3cd8614a 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu @@ -1,3 +1,4 @@ +// NOLINT // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,8 +17,8 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/moe_fuse_op.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_fuse_op.h" namespace phi { // -------- getWorkspaceSize -------- // @@ -29,7 +30,8 @@ size_t getWorkspaceSize(const int num_rows, const int num_experts, const int k, // const int max_seq_len, - phi::CubKeyValueSorter &sorter) { + phi::CubKeyValueSorter &sorter // NOLINT +) { // const int buf_size = AlignTo16(k * num_rows * hidden_size); // const int interbuf_size = AlignTo16(k * num_rows * inter_size); // const int padded_experts = AlignTo16(num_experts); @@ -58,8 +60,7 @@ size_t getWorkspaceSize(const int num_rows, // "< void apply_moe_dispatch_fwd(const Context &dev_ctx, @@ -339,7 +340,8 @@ void MoeGradDispatchKernel(const Context &dev_ctx, dev_ctx.template Alloc(combine_weights); dev_ctx.template Alloc(y); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); auto x_dims = x.dims(); auto gate_logits_dims = gate_logits.dims(); diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu index 03b446e8dad0bf..48d07e2bff02ac 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu @@ -1,57 +1,81 @@ +// NOLINT +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
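apply_moe_dispatch_bwd below splits the backward pass across two launchers: gather_with_mask_launcher builds x_grad by gathering y_grad rows back through scatter_index, with combine_weights serving to tell kept copies from capacity-dropped ones, and topk_grad_with_mask_launcher routes combine_weights_grad into gate_logits_grad for the selected experts. A CPU sketch of the first step under that reading; whether the gathered rows are additionally scaled by combine_weights is not visible in this hunk, so the unscaled masked sum here is an assumption:

#include <cstdint>
#include <vector>

// ASSUMPTION: kept copies are summed unscaled; a zero combine weight marks a
// dropped copy whose gradient does not flow back to x.
void gather_with_mask_sketch(const std::vector<float>& y_grad,          // [n_dispatched, hidden]
                             const std::vector<int>& scatter_index,     // [num_rows, k]
                             const std::vector<float>& combine_weights, // [num_rows, k]
                             int64_t num_rows, int64_t k, int64_t hidden,
                             std::vector<float>* x_grad) {              // [num_rows, hidden]
  x_grad->assign(num_rows * hidden, 0.0f);
  for (int64_t s = 0; s < num_rows; ++s)
    for (int64_t i = 0; i < k; ++i) {
      if (combine_weights[s * k + i] == 0.0f) continue;  // dropped copy
      const int64_t dest = scatter_index[s * k + i];
      for (int64_t h = 0; h < hidden; ++h)
        (*x_grad)[s * hidden + h] += y_grad[dest * hidden + h];
    }
}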
+ #include "paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/transpose_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/contiguous_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/moe_fuse_bwd_op.h" #include "paddle/phi/kernels/transpose_kernel.h" -namespace phi{ +namespace phi { template -void apply_moe_dispatch_bwd( - const T* y_grad, - const float* combine_weights, // [s, k] - const int* scatter_index, // [s, k] - const float* combine_weights_grad, - const int* expert_id, // [s, k] - float* gate_logits_grad, - T* x_grad, - int64_t num_rows, - int64_t k, - int64_t dim, - int64_t num_experts, - int64_t capacity, - bool use_all2all_permute, - int64_t world_size, - int64_t num_local_experts, - cudaStream_t stream){ - gather_with_mask_launcher(y_grad, - scatter_index, - combine_weights, - x_grad, num_rows, k, dim, -1, stream, use_all2all_permute, world_size, num_local_experts, capacity); +void apply_moe_dispatch_bwd(const T* y_grad, + const float* combine_weights, // [s, k] + const int* scatter_index, // [s, k] + const float* combine_weights_grad, + const int* expert_id, // [s, k] + float* gate_logits_grad, + T* x_grad, + int64_t num_rows, + int64_t k, + int64_t dim, + int64_t num_experts, + int64_t capacity, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts, + cudaStream_t stream) { + gather_with_mask_launcher(y_grad, + scatter_index, + combine_weights, + x_grad, + num_rows, + k, + dim, + -1, + stream, + use_all2all_permute, + world_size, + num_local_experts, + capacity); - topk_grad_with_mask_launcher(combine_weights_grad, - expert_id, - combine_weights, - gate_logits_grad, - num_rows, k, num_experts, stream); + topk_grad_with_mask_launcher(combine_weights_grad, + expert_id, + combine_weights, + gate_logits_grad, + num_rows, + k, + num_experts, + stream); } - template void moe_dispatch_bwd(const Context& dev_ctx, - const DenseTensor& combine_weights, // [s, k] - const DenseTensor& scatter_index, // [k, s] - const DenseTensor& expert_id, // [s, k] - const DenseTensor& y_grad, // [num_experts * capacity, h] - const DenseTensor& combine_weights_grad, // [s, k] - DenseTensor& x_grad, - DenseTensor& gate_logits_grad, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [s, k] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_grad, // [s, k] + DenseTensor& x_grad, // NOLINT + DenseTensor& gate_logits_grad, // NOLINT int64_t capacity, bool use_all2all_permute = false, int64_t world_size = -1, - int64_t num_local_experts = -1){ + int64_t num_local_experts = -1) { auto combine_weights_dims = combine_weights.dims(); int64_t num_rows = combine_weights_dims[0]; int64_t k = combine_weights_dims[1]; @@ -59,42 +83,44 @@ void moe_dispatch_bwd(const Context& dev_ctx, int64_t hidden_size = y_grad_dims[y_grad_dims.size() - 1]; int64_t num_experts = gate_logits_grad.dims()[1]; - apply_moe_dispatch_bwd( - y_grad.data(), - combine_weights.data(), - scatter_index.data(), - combine_weights_grad.data(), - expert_id.data(), - gate_logits_grad.data(), - x_grad.data(), - num_rows, - k, - hidden_size, - num_experts, - capacity, - use_all2all_permute, - world_size, - 
num_local_experts, - dev_ctx.stream()); + apply_moe_dispatch_bwd(y_grad.data(), + combine_weights.data(), + scatter_index.data(), + combine_weights_grad.data(), + expert_id.data(), + gate_logits_grad.data(), + x_grad.data(), + num_rows, + k, + hidden_size, + num_experts, + capacity, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); } template -void MoeGateDispatchGradKernel(const Context& dev_ctx, - const DenseTensor& combine_weights, // [s, k] - const DenseTensor& scatter_index, // [k, s] - const DenseTensor& expert_id, // [num_local_experts, num_experts * capacity // num_local_experts, h] - const DenseTensor& y_grad, // [s, k] - const DenseTensor& combine_weights_grad, - int64_t k, - int64_t capacity, - int64_t world_size, - DenseTensor* x_grad, - DenseTensor* gate_logits_grad){ +void MoeGateDispatchGradKernel( + const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [num_local_experts, num_experts * capacity + // // num_local_experts, h] + const DenseTensor& y_grad, // [s, k] + const DenseTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor* x_grad, + DenseTensor* gate_logits_grad) { int64_t num_local_experts = y_grad.dims()[0]; auto scatter_index_dims = scatter_index.dims(); DenseTensor t_scatter_index; - phi::Transpose(dev_ctx, scatter_index, {1,0}, &t_scatter_index); + phi::Transpose( + dev_ctx, scatter_index, {1, 0}, &t_scatter_index); DenseTensor t_scatter_index_; phi::ContiguousKernel( dev_ctx, t_scatter_index, &t_scatter_index_); @@ -113,9 +139,8 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, true, /*use_all2all_permute*/ world_size, num_local_experts); - } -} // namespace phi +} // namespace phi PD_REGISTER_KERNEL(moe_gate_dispatch_permute_grad, GPU, @@ -124,4 +149,4 @@ PD_REGISTER_KERNEL(moe_gate_dispatch_permute_grad, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} \ No newline at end of file + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu index 288ec03554b499..8d35c763770695 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu @@ -1,3 +1,4 @@ +// NOLINT // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,12 +13,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
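The forward path in this file sorts (expert id, source row) pairs with CubKeyValueSorter and then derives expert_offset through compute_total_rows_before_expert. The sketch below states the contract this is assumed to satisfy: expert_offset[e] is the inclusive prefix count of dispatched rows routed to experts 0..e, so its last entry is the total number of active rows, matching the debug read-back later in this kernel. The helper name and host-side types are illustrative only:

#include <cstdint>
#include <vector>

// Inclusive prefix counts over already-sorted expert ids.
std::vector<int64_t> total_rows_before_expert_sketch(
    const std::vector<int>& sorted_expert_ids, int64_t num_experts) {
  std::vector<int64_t> expert_offset(num_experts, 0);
  for (int e : sorted_expert_ids) {
    if (e >= 0 && e < num_experts) ++expert_offset[e];  // per-expert row count
  }
  for (int64_t e = 1; e < num_experts; ++e) {
    expert_offset[e] += expert_offset[e - 1];  // running (inclusive) total
  }
  return expert_offset;
}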
-#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/moe_fuse_op.h" #include "paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_fuse_op.h" namespace phi { namespace { @@ -29,32 +30,40 @@ size_t getWorkspaceSize(const int num_rows, const int num_experts, const int k, // const int max_seq_len, - phi::CubKeyValueSorter &sorter) -{ - + phi::CubKeyValueSorter &sorter // NOLINT +) { // const int buf_size = AlignTo16(k * num_rows * hidden_size); // const int interbuf_size = AlignTo16(k * num_rows * inter_size); // const int padded_experts = AlignTo16(num_experts); const int num_moe_inputs = AlignTo16(k * num_rows); int num_softmax_outs = 0; - // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them - // in Encoder or Decoder before invoking FfnLayer forward. - size_t total_ws_bytes = 4 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + // softmax output, permuted_rows and permuted_experts have moved to outside of + // moe kernel, allocate them in Encoder or Decoder before invoking FfnLayer + // forward. + size_t total_ws_bytes = + 4 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ // total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data - // total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ // expert_cnt - // total_ws_bytes += num_softmax_outs * sizeof(KeyT); - // const int bytes_for_fc1_result = interbuf_size * sizeof(KeyT); - const int sorter_ws_size_bytes = AlignTo16(sorter.getWorkspaceSize(k * num_rows)); - //sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity // 用所有 bit 做排序,会降低些许性能,但是防止越界 - total_ws_bytes += sorter_ws_size_bytes; // intermediate (fc1) output + cub sorting workspace - // std::cout<<"sorter_ws_size_bytes = "< -void apply_moe_dispatch_fwd(const Context& dev_ctx, +void apply_moe_dispatch_fwd(const Context &dev_ctx, const T *x, const float *gate_logits, const float *corr_bias, @@ -72,30 +81,35 @@ void apply_moe_dispatch_fwd(const Context& dev_ctx, bool use_all2all_permute, int64_t world_size, int64_t num_local_experts, - cudaStream_t stream){ + cudaStream_t stream) { phi::CubKeyValueSorter sorter(stream); // phi::funcs::SetConstant zero; // zero(ctx, &finished_tensor, false); - DenseTensor xpanded_source_row_to_expanded_dest_row_tensor = phi::Empty(dev_ctx, IntArray({num_rows, k})); + DenseTensor xpanded_source_row_to_expanded_dest_row_tensor = + phi::Empty(dev_ctx, IntArray({num_rows, k})); // int* expanded_source_row_to_expanded_dest_row = // expanded_source_row_to_expanded_dest_row_tensor.data(); - // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, paddle::DataType::FLOAT32, place); - // float* expert_scales_float = expert_scales_tensor_float.data(); + // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, + // paddle::DataType::FLOAT32, place); float* expert_scales_float = + // expert_scales_tensor_float.data(); - // paddle::Tensor expert_for_source_row_tensor = paddle::empty({num_rows, k}, paddle::DataType::INT32, place); - // int* expert_for_source_row = expert_for_source_row_tensor.data(); - DenseTensor active_cnt_tensor = phi::Empty(dev_ctx, IntArray({1})); + // paddle::Tensor expert_for_source_row_tensor 
= paddle::empty({num_rows, k}, + // paddle::DataType::INT32, place); int* expert_for_source_row = + // expert_for_source_row_tensor.data(); + DenseTensor active_cnt_tensor = + phi::Empty(dev_ctx, IntArray({1})); int64_t bytes = getWorkspaceSize(num_rows, - hidden_size, // hidden-size=0 - 0, // inter-size=0 + hidden_size, // hidden-size=0 + 0, // inter-size=0 num_experts, k, sorter); - DenseTensor ws_ptr_tensor = phi::Empty(dev_ctx, IntArray({bytes})); + DenseTensor ws_ptr_tensor = + phi::Empty(dev_ctx, IntArray({bytes})); int8_t *ws_ptr = ws_ptr_tensor.data(); // Pointers @@ -109,7 +123,8 @@ void apply_moe_dispatch_fwd(const Context& dev_ctx, // int64_t* total_rows_before_expert_; T *fc1_result_; - const int sorter_ws_size_bytes = AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); // const int buf_size = AlignTo16(k * num_rows * hidden_size); // const int interbuf_size = AlignTo16(k * num_rows * 0); const int padded_experts = AlignTo16(num_experts); @@ -121,86 +136,102 @@ void apply_moe_dispatch_fwd(const Context& dev_ctx, expert_id_ = permuted_experts_ + num_moe_inputs; // permuted_data_ = reinterpret_cast(expert_id_ + num_moe_inputs); - // total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + buf_size); + // total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + + // buf_size); // only use one number - // num_active = reinterpret_cast(permuted_experts_ + num_moe_inputs); + // num_active = reinterpret_cast(permuted_experts_ + + // num_moe_inputs); fc1_result_ = reinterpret_cast(expert_id_ + num_moe_inputs); softmax_out_ = nullptr; - + #ifdef DEBUG_MOE_OP - // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits before_topk")); - // print_to_screen1(finished, 2, 16, std::string("finished before_topk")); + // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits + // before_topk")); print_to_screen1(finished, 2, 16, std::string("finished + // before_topk")); #endif - topk_gating_softmax_kernelLauncher(gate_logits, + topk_gating_softmax_kernelLauncher(gate_logits, corr_bias, - combine_weights, // output - softmax_out_, // no use - expert_id, // output - source_rows_, // output + combine_weights, // output + softmax_out_, // no use + expert_id, // output + source_rows_, // output num_rows, num_experts, k, stream); #ifdef DEBUG_MOE_OP - // phi::CastKernel(ctx, expert_scales_tensor_float, expert_scales_tensor.dtype(), &expert_scales_tensor); - print_to_screen1(combine_weights, 8, 16, std::string("expert_scales_float after topk")); - print_to_screen1(expert_id, 8, 16, std::string("expert-id before permute")); - print_to_screen1(source_rows_, 8, 16, std::string("desc->src idx before permute")); + // phi::CastKernel(ctx, expert_scales_tensor_float, + // expert_scales_tensor.dtype(), &expert_scales_tensor); + print_to_screen1( + combine_weights, 8, 16, std::string("expert_scales_float after topk")); + print_to_screen1( + expert_id, 8, 16, std::string("expert-id before permute")); + print_to_screen1( + source_rows_, 8, 16, std::string("desc->src idx before permute")); #endif - // modifiy expert-id according to k - if (use_pad) // 为了区分 k=1 选择和 k=2 选择,修改 expert-id - modify_expert_id_launcher(expert_id, expert_id_, k, num_rows, num_experts, stream); + // modify expert-id according to k + if (use_pad) // 为了区分 k=1 选择和 k=2 选择,修改 expert-id + modify_expert_id_launcher( + expert_id, expert_id_, k, num_rows, num_experts, stream); - // calc expert-size -/* - if (!use_pad) - 
cal_expert_size_and_filter_launcher(expert_id, - k * num_rows, - num_experts, - capacity, - stream); -*/ - #ifdef DEBUG_MOE_OP - print_to_screen1(expert_id, 8, 16, std::string("expert-id after modified")); + // calc expert-size + /* + if (!use_pad) + cal_expert_size_and_filter_launcher(expert_id, + k * num_rows, + num_experts, + capacity, + stream); + */ +#ifdef DEBUG_MOE_OP + print_to_screen1( + expert_id, 8, 16, std::string("expert-id after modified")); #endif - sorter.run(fc1_result_, - sorter_ws_size_bytes, - use_pad ? expert_id_ : expert_id, // key in - permuted_experts_, // key out // [num_row, k]: expert-id - source_rows_, // value in - permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 - k * num_rows, // num_rows - false, - stream); - - if (use_pad) - unmodify_expert_id_launcher(permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); + sorter.run( + fc1_result_, + sorter_ws_size_bytes, + use_pad ? expert_id_ : expert_id, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + source_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + if (use_pad) + unmodify_expert_id_launcher( + permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); #ifdef DEBUG_MOE_OP - print_to_screen1(permuted_experts_, 8, 16, std::string("expert-id after permute")); - print_to_screen1(permuted_rows_, 8, 16, std::string("dest->src idx after permute")); + print_to_screen1( + permuted_experts_, 8, 16, std::string("expert-id after permute")); + print_to_screen1( + permuted_rows_, 8, 16, std::string("dest->src idx after permute")); #endif compute_total_rows_before_expert( - permuted_experts_, - k * num_rows, - num_experts, - expert_offset, - stream); - + permuted_experts_, k * num_rows, num_experts, expert_offset, stream); + #ifdef DEBUG_MOE_OP print_to_screen1(expert_offset, 8, 16, std::string("expert_offset")); int64_t num_active_host_v2; - cudaMemcpy(&num_active_host_v2, expert_offset + num_experts - 1, sizeof(int64_t), cudaMemcpyDeviceToHost); + cudaMemcpy(&num_active_host_v2, + expert_offset + num_experts - 1, + sizeof(int64_t), + cudaMemcpyDeviceToHost); std::cerr << "[DEBUG] num_active v2: " << num_active_host_v2 << std::endl; - print_to_screen1(permuted_experts_, 8, num_active_host_v2+2, std::string("expert-id after permute")); - // print_to_screen1(permuted_experts_, 4096, 8192, std::string("expert-id after permute")); + print_to_screen1(permuted_experts_, + 8, + num_active_host_v2 + 2, + std::string("expert-id after permute")); + // print_to_screen1(permuted_experts_, 4096, 8192, + // std::string("expert-id after permute")); #endif - + if (!use_all2all_permute) { initialize_moe_routing_kernelLauncher(x, y, @@ -208,10 +239,10 @@ void apply_moe_dispatch_fwd(const Context& dev_ctx, scatter_index, permuted_experts_, expert_offset, - combine_weights, + combine_weights, static_cast(num_rows), static_cast(hidden_size), - static_cast(k), + static_cast(k), capacity, use_pad, stream); @@ -224,88 +255,92 @@ void apply_moe_dispatch_fwd(const Context& dev_ctx, scatter_index, permuted_experts_, expert_offset, - combine_weights, + combine_weights, static_cast(num_rows), static_cast(hidden_size), - static_cast(k), + static_cast(k), capacity, world_size, num_local_experts, stream); } - + // turn expert_offset_ptr into experts_num // auto expert_offset_ptr = thrust::device_pointer_cast(expert_offset); // thrust::adjacent_difference( // expert_offset_ptr, expert_offset_ptr + 
num_experts, expert_offset_ptr // ); #ifdef DEBUG_MOE_OP - print_to_screen1(scatter_index, 8, 16, std::string("scatter_index after pad")); + print_to_screen1( + scatter_index, 8, 16, std::string("scatter_index after pad")); #endif - // cudaMemcpy(scatter_index, permuted_rows_, sizeof(int64_t) * k * num_rows, cudaMemcpyDeviceToDevice); - // cudaMemcpy(combine_weights, expert_scales_float, sizeof(float) * k * num_rows, cudaMemcpyDeviceToDevice); + // cudaMemcpy(scatter_index, permuted_rows_, sizeof(int64_t) * k * num_rows, + // cudaMemcpyDeviceToDevice); cudaMemcpy(combine_weights, expert_scales_float, + // sizeof(float) * k * num_rows, cudaMemcpyDeviceToDevice); return; } template -void moe_dispatch_fwd(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& gate_logits, - const paddle::optional& corr_bias, +void moe_dispatch_fwd(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &gate_logits, + const paddle::optional &corr_bias, int64_t num_rows, int64_t num_experts, int64_t hidden_size, int64_t capacity, int64_t k, - const DenseTensor& y, - const DenseTensor& combine_weights, - const DenseTensor& scatter_index, - const DenseTensor& expert_offset, - const DenseTensor& expert_id, + const DenseTensor &y, + const DenseTensor &combine_weights, + const DenseTensor &scatter_index, + const DenseTensor &expert_offset, + const DenseTensor &expert_id, bool use_pad, int64_t use_all2all_permute = false, int64_t world_size = -1, - int64_t num_local_experts = -1){ - apply_moe_dispatch_fwd(dev_ctx, - x.data(), - gate_logits.data(), - corr_bias? corr_bias.get_ptr()->data() : nullptr, - num_rows, - num_experts, - hidden_size, - capacity, - k, - const_cast(y.data()), - const_cast(combine_weights.data()), - const_cast(scatter_index.data()), - const_cast(expert_offset.data()), - const_cast(expert_id.data()), - use_pad, - use_all2all_permute, - world_size, - num_local_experts, - dev_ctx.stream()); + int64_t num_local_experts = -1) { + apply_moe_dispatch_fwd( + dev_ctx, + x.data(), + gate_logits.data(), + corr_bias ? 
corr_bias.get_ptr()->data() : nullptr, + num_rows, + num_experts, + hidden_size, + capacity, + k, + const_cast(y.data()), + const_cast(combine_weights.data()), + const_cast(scatter_index.data()), + const_cast(expert_offset.data()), + const_cast(expert_id.data()), + use_pad, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); } template -void MoEDispatchPermuteKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& gate_logits, - const paddle::optional& corr_bias, +void MoEDispatchPermuteKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &gate_logits, + const paddle::optional &corr_bias, int64_t k, int64_t capacity, int64_t world_size, - DenseTensor* y, - DenseTensor* combine_weights, - DenseTensor* scatter_index, - DenseTensor* expert_offset, - DenseTensor* expert_id){ + DenseTensor *y, + DenseTensor *combine_weights, + DenseTensor *scatter_index, + DenseTensor *expert_offset, + DenseTensor *expert_id) { dev_ctx.template Alloc(expert_id); dev_ctx.template Alloc(expert_offset); dev_ctx.template Alloc(scatter_index); dev_ctx.template Alloc(combine_weights); dev_ctx.template Alloc(y); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); const auto &x_shape = x.dims(); const auto &gate_logits_shape = gate_logits.dims(); int64_t num_rows = x_shape[0]; @@ -313,23 +348,23 @@ void MoEDispatchPermuteKernel(const Context& dev_ctx, int64_t num_experts = gate_logits_shape[1]; int64_t num_local_experts = num_experts / world_size; moe_dispatch_fwd(dev_ctx, - x, - gate_logits, - corr_bias, - num_rows, - num_experts, - hidden_size, - capacity, - k, - *y, - *combine_weights, - *scatter_index, - *expert_offset, - *expert_id, - true, /*use_pad*/ - true, /*use_all2all_permute*/ - world_size, - num_local_experts); + x, + gate_logits, + corr_bias, + num_rows, + num_experts, + hidden_size, + capacity, + k, + *y, + *combine_weights, + *scatter_index, + *expert_offset, + *expert_id, + true, /*use_pad*/ + true, /*use_all2all_permute*/ + world_size, + num_local_experts); } } // namespace phi @@ -340,4 +375,4 @@ PD_REGISTER_KERNEL(moe_gate_dispatch_permute, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} \ No newline at end of file + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu index b0123a4d0e9ab8..f1e1ec0d752bcc 100644 --- a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu @@ -13,123 +13,128 @@ // limitations under the License. 
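The hunk that follows reformats apply_moe_dispatch_bwd in moe_ops_partial_nosoftmaxtopk_grad_kernel.cu. Its thrust::transform passes the upstream combine-weights gradient through only where the forward combine weight is positive, so positions zeroed by capacity dropping receive no gradient. A minimal CPU sketch of that elementwise rule, under illustrative names (MaskCombineWeightGradSketch and its parameter names are not part of the patch):

#include <cstddef>
#include <vector>

// Mirrors the device lambda used with thrust::transform in the hunk below:
// g is the upstream gradient, w the forward combine weight; the result is
// g where w > 0 and 0 elsewhere.
std::vector<float> MaskCombineWeightGradSketch(
    const std::vector<float>& combine_weights,        // [s * k] forward output
    const std::vector<float>& combine_weights_grad) {  // [s * k] upstream grad
  std::vector<float> grad_in(combine_weights.size(), 0.0f);
  for (std::size_t i = 0; i < combine_weights.size(); ++i) {
    grad_in[i] = combine_weights[i] > 0.0f ? combine_weights_grad[i] : 0.0f;
  }
  return grad_in;
}

The gradient with respect to x is produced separately by gather_with_mask_launcher in the same function.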
#pragma once -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h" +#include +#include #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/transpose_kernel.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/contiguous_kernel.h" -#include "paddle/phi/kernels/moe_fuse_bwd_op.h" -#include -#include +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_fuse_bwd_op.h" +#include "paddle/phi/kernels/transpose_kernel.h" -namespace phi{ +namespace phi { template -void apply_moe_dispatch_bwd( - const T* y_grad, - const float* combine_weights, // [s, k] - const int* scatter_index, // [s, k] - const float* combine_weights_out_grad, - float* combine_weights_in_grad, - T* x_grad, - int64_t num_rows, - int64_t k, - int64_t dim, - int64_t num_experts, - int64_t num_active, - cudaStream_t stream){ - printf("apply_moe_dispatch_bwd\n"); - gather_with_mask_launcher(y_grad, - scatter_index, - combine_weights, - x_grad, num_rows, k, dim, num_active, stream); - auto out_grad_ptr = thrust::device_pointer_cast(combine_weights_out_grad); - auto in_grad_ptr = thrust::device_pointer_cast(combine_weights_in_grad); - auto combine_weight_ptr = thrust::device_pointer_cast(combine_weights); - printf("kernel over\n"); - thrust::transform( - thrust::cuda::par.on(stream), - out_grad_ptr, - out_grad_ptr + num_rows * k, - combine_weight_ptr, - in_grad_ptr, - [] __device__ (float g, float w){ - return w > static_cast(0) ? g : static_cast(0); - } - ); - // topk_grad_with_mask_launcher(combine_weights_grad, - // expert_id, - // combine_weights, - // gate_logtis_grad, - // num_rows, k, num_experts, stream); +void apply_moe_dispatch_bwd(const T* y_grad, + const float* combine_weights, // [s, k] + const int* scatter_index, // [s, k] + const float* combine_weights_out_grad, + float* combine_weights_in_grad, + T* x_grad, + int64_t num_rows, + int64_t k, + int64_t dim, + int64_t num_experts, + int64_t num_active, + cudaStream_t stream) { + gather_with_mask_launcher(y_grad, + scatter_index, + combine_weights, + x_grad, + num_rows, + k, + dim, + num_active, + stream); + auto out_grad_ptr = thrust::device_pointer_cast(combine_weights_out_grad); + auto in_grad_ptr = thrust::device_pointer_cast(combine_weights_in_grad); + auto combine_weight_ptr = thrust::device_pointer_cast(combine_weights); + thrust::transform(thrust::cuda::par.on(stream), + out_grad_ptr, + out_grad_ptr + num_rows * k, + combine_weight_ptr, + in_grad_ptr, + [] __device__(float g, float w) { + return w > static_cast(0) ? 
g + : static_cast(0); + }); + // topk_grad_with_mask_launcher(combine_weights_grad, + // expert_id, + // combine_weights, + // gate_logtis_grad, + // num_rows, k, num_experts, stream); } template void moe_dispatch_bwd(const Context& dev_ctx, - const DenseTensor &combine_weights, // [s, k] - const DenseTensor &scatter_index, // [k, s] - const DenseTensor &y_grad, // [num_experts * capacity, h] - const DenseTensor &combine_weights_out_grad, // [s, k] - DenseTensor *x_grad, - DenseTensor *combine_weights_in_grad, - int64_t num_experts){ - int64_t num_rows = combine_weights.dims()[0]; - int64_t k = combine_weights.dims()[1]; - int64_t hidden_size = y_grad.dims()[1]; - int64_t num_active = y_grad.dims()[0]; + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_out_grad, // [s, k] + DenseTensor* x_grad, + DenseTensor* combine_weights_in_grad, + int64_t num_experts) { + int64_t num_rows = combine_weights.dims()[0]; + int64_t k = combine_weights.dims()[1]; + int64_t hidden_size = y_grad.dims()[1]; + int64_t num_active = y_grad.dims()[0]; - apply_moe_dispatch_bwd( - y_grad.data(), - combine_weights.data(), - scatter_index.data(), - combine_weights_out_grad.data(), - combine_weights_in_grad->data(), - x_grad->data(), - num_rows, - k, - hidden_size, - num_experts, - num_active, - dev_ctx.stream()); + apply_moe_dispatch_bwd(y_grad.data(), + combine_weights.data(), + scatter_index.data(), + combine_weights_out_grad.data(), + combine_weights_in_grad->data(), + x_grad->data(), + num_rows, + k, + hidden_size, + num_experts, + num_active, + dev_ctx.stream()); } template -void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(const Context& dev_ctx, - const DenseTensor& combine_weights_out, - const DenseTensor& scatter_index, - const DenseTensor& scatter_index_rev, - const DenseTensor& expert_offset, - const DenseTensor& expert_offset_local, - const DenseTensor& y_grad, - const DenseTensor& combine_weights_out_grad, - int64_t k, - int64_t capacity, - bool use_pad, - int64_t expert_start_index, - int64_t expert_end_index, - DenseTensor* x_grad, - DenseTensor* combine_weights_grad){ +void MoeGateDispatchPartialNoSoftMaxTopkGradKernel( + const Context& dev_ctx, + const DenseTensor& combine_weights_out, + const DenseTensor& scatter_index, + const DenseTensor& scatter_index_rev, + const DenseTensor& expert_offset, + const DenseTensor& expert_offset_local, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + DenseTensor* x_grad, + DenseTensor* combine_weights_grad) { dev_ctx.template Alloc(x_grad); dev_ctx.template Alloc(combine_weights_grad); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(combine_weights_grad->dims())), 0, combine_weights_grad); - // DenseTensor t_scatter_index; - // printf("check pass\n"); - // phi::Transpose(dev_ctx, scatter_index, {1,0}, &t_scatter_index); - // DenseTensor t_scatter_index_out; - // phi::ContiguousKernel(dev_ctx, t_scatter_index, &t_scatter_index_out); - // t_scatter_index = t_scatter_index_out; - // int64_t num_experts = expert_offset.dims()[0]; - // moe_dispatch_bwd(dev_ctx, - // combine_weights_out, - // t_scatter_index, - // y_grad, - // combine_weights_out_grad, - // x_grad, - // combine_weights_grad, - // num_experts); - + phi::Full( + dev_ctx, + 
phi::IntArray(common::vectorize(combine_weights_grad->dims())), + 0, + combine_weights_grad); + DenseTensor t_scatter_index; + phi::Transpose( + dev_ctx, scatter_index, {1, 0}, &t_scatter_index); + DenseTensor t_scatter_index_out; + phi::ContiguousKernel( + dev_ctx, t_scatter_index, &t_scatter_index_out); + t_scatter_index = t_scatter_index_out; + int64_t num_experts = expert_offset.dims()[0]; + moe_dispatch_bwd(dev_ctx, + combine_weights_out, + t_scatter_index, + y_grad, + combine_weights_out_grad, + x_grad, + combine_weights_grad, + num_experts); } } // namespace phi @@ -140,4 +145,4 @@ PD_REGISTER_KERNEL(moe_gate_dispatch_partial_nosoftmaxtopk_grad, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} \ No newline at end of file + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu index 6d009350263aed..d870fb85bc166b 100644 --- a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu @@ -1,3 +1,4 @@ +// NOLINT // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,38 +15,40 @@ /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ -/*This code is copied fron NVIDIA apex: +/*This code is copied from NVIDIA apex: * https://github.com/NVIDIA/apex * with minor changes. */ -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/moe_fuse_op.h" #include "paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/slice_kernel.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" -#include "paddle/phi/kernels/moe_kernel_impl.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_fuse_op.h" +#include "paddle/phi/kernels/moe_kernel_impl.h" +#include "paddle/phi/kernels/slice_kernel.h" namespace phi { -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ - cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - -// already defined need to revise! +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// already defined need to revise! 
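The commented-out AlignTo16 below matches a helper already defined elsewhere in these kernels (hence the "already defined" note above); getWorkspaceSize and apply_moe_dispatch_fwd rely on that rounding to keep every sub-buffer carved out of the single int8_t workspace 16-byte aligned. A minimal sketch of the rounding rule, with AlignTo16Sketch as an illustrative stand-in rather than the helper used by the patch:

#include <cstddef>

// Round a byte count up to the next multiple of 16 so that consecutive
// sub-buffers taken from one int8_t* workspace stay 16-byte aligned for the
// int, int64_t, and T views created later.
inline std::size_t AlignTo16Sketch(std::size_t bytes) {
  constexpr std::size_t kAlignment = 16;
  return kAlignment * ((bytes + kAlignment - 1) / kAlignment);
}

// e.g. AlignTo16Sketch(33) == 48; getWorkspaceSize sums such aligned sizes and
// the kernel then allocates that total as a single int8_t workspace tensor.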
// static inline size_t AlignTo16(const size_t &input){ // static constexpr int ALIGNMENT = 16; // return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); // } -namespace{ +namespace { // -------- getWorkspaceSize -------- // template size_t getWorkspaceSize(const int num_rows, @@ -56,9 +59,7 @@ size_t getWorkspaceSize(const int num_rows, const int k, // const int max_seq_len, bool use_pad, - phi::CubKeyValueSorter &sorter) -{ - + phi::CubKeyValueSorter &sorter) { // NOLINT // const int buf_size = AlignTo16(k * num_rows * hidden_size); const int interbuf_size = AlignTo16(k * num_rows * inter_size); const int padded_experts = AlignTo16(num_experts); @@ -66,35 +67,45 @@ size_t getWorkspaceSize(const int num_rows, const int num_dispatched_size = AlignTo16(num_experts * capacity); int num_softmax_outs = 0; - // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them - // in Encoder or Decoder before invoking FfnLayer forward. - size_t total_ws_bytes = 4 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + // softmax output, permuted_rows and permuted_experts have moved to outside of + // moe kernel, allocate them in Encoder or Decoder before invoking FfnLayer + // forward. + size_t total_ws_bytes = + 4 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ total_ws_bytes += 2 * num_dispatched_size * sizeof(int); - total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ // expert_cnt + total_ws_bytes += + padded_experts * + sizeof(int64_t); // Hold total_rows_before_expert_ // expert_cnt // total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data total_ws_bytes += num_softmax_outs * sizeof(KeyT); const int bytes_for_fc1_result = interbuf_size * sizeof(KeyT); - const int sorter_ws_size_bytes = - std::max(AlignTo16(sorter.getWorkspaceSize(k * num_rows)), - AlignTo16(sorter.getWorkspaceSize(capacity))); - //sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity // 用所有 bit 做排序,会降低些许性能,但是防止越界 + const int sorter_ws_size_bytes = + std::max(AlignTo16(sorter.getWorkspaceSize(k * num_rows)), + AlignTo16(sorter.getWorkspaceSize(capacity))); + // sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity + // // 用所有 bit 做排序,会降低些许性能,但是防止越界 int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; - if (sorter_ws_size_bytes > bytes_for_fc1_result) - { - int remaining_bytes = AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int remaining_bytes = + AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); bytes_for_intermediate_and_sorting += remaining_bytes; } // std::cout<<"num_softmax_outs --"<< num_softmax_outs << std::endl; - total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace - // std::cout<<"buf_size --"<< buf_size<<" "< void apply_moe_dispatch_fwd( - const Context& dev_ctx, - const DenseTensor& x, + const Context &dev_ctx, + const DenseTensor &x, int64_t num_rows, int64_t num_experts, int64_t hidden_size, @@ -103,33 +114,34 @@ void apply_moe_dispatch_fwd( int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop, - thrust::host_vector& expert_offset_host, + thrust::host_vector &expert_offset_host, // NOLINT DenseTensor *y, float *combine_weights, int *scatter_index, - int * scatter_index_rev, + int *scatter_index_rev, int64_t *expert_offset_global, - int64_t* expert_nums_local, + int64_t 
*expert_nums_local, int *expert_id, bool use_pad, - cudaStream_t stream) -{ + cudaStream_t stream) { phi::CubKeyValueSorter sorter(stream); // paddle::Tensor expanded_source_row_to_expanded_dest_row_tensor = // paddle::empty({num_rows, k}, paddle::DataType::INT32, place); // int* expanded_source_row_to_expanded_dest_row = // expanded_source_row_to_expanded_dest_row_tensor.data(); - // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, paddle::DataType::FLOAT32, place); - // float* expert_scales_float = expert_scales_tensor_float.data(); + // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, + // paddle::DataType::FLOAT32, place); float* expert_scales_float = + // expert_scales_tensor_float.data(); - // paddle::Tensor expert_for_source_row_tensor = paddle::empty({num_rows, k}, paddle::DataType::INT32, place); - // int* expert_for_source_row = expert_for_source_row_tensor.data(); - // paddle::Tensor active_cnt_tensor = paddle::empty({1}, paddle::DataType::INT32, place); + // paddle::Tensor expert_for_source_row_tensor = paddle::empty({num_rows, k}, + // paddle::DataType::INT32, place); int* expert_for_source_row = + // expert_for_source_row_tensor.data(); paddle::Tensor active_cnt_tensor + // = paddle::empty({1}, paddle::DataType::INT32, place); int64_t bytes = getWorkspaceSize(num_rows, - hidden_size, // hidden-size=0 - 0, // inter-size=0 + hidden_size, // hidden-size=0 + 0, // inter-size=0 num_experts, capacity, k, @@ -139,7 +151,8 @@ void apply_moe_dispatch_fwd( DenseTensor ws_ptr_tensor = phi::Empty(dev_ctx, {bytes}); int8_t *ws_ptr = ws_ptr_tensor.data(); - phi::memory_utils::ThrustAllocator allocator(dev_ctx.GetPlace(), dev_ctx.stream()); + phi::memory_utils::ThrustAllocator allocator(dev_ctx.GetPlace(), + dev_ctx.stream()); // Pointers int *source_rows_; @@ -150,7 +163,7 @@ void apply_moe_dispatch_fwd( int *source_rows_for_seqsort_out_; int *source_pos_for_seqsort_; int *source_pos_for_seqsort_out_; - int64_t *expert_offset_; // local-expert-offset + int64_t *expert_offset_; // local-expert-offset char *sorter_ws_; // T* permuted_data_; @@ -158,14 +171,16 @@ void apply_moe_dispatch_fwd( // int64_t* total_rows_before_expert_; T *fc1_result_; - const int sorter_ws_size_bytes = AlignTo16(sorter.getWorkspaceSize(k * num_rows)); - const int sorter_ws_size_bytes_seqsort = AlignTo16(sorter.getWorkspaceSize(capacity)); + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + const int sorter_ws_size_bytes_seqsort = + AlignTo16(sorter.getWorkspaceSize(capacity)); const int buf_size = AlignTo16(k * num_rows * hidden_size); // const int interbuf_size = AlignTo16(k * num_rows * 0); const int padded_experts = AlignTo16(num_experts); const int num_moe_inputs = AlignTo16(k * num_rows); - const int num_dispatched_size = AlignTo16(num_experts * capacity); + const int num_dispatched_size = AlignTo16(num_experts * capacity); // 4:ints [k*row] source_rows_ = reinterpret_cast(ws_ptr); @@ -176,111 +191,127 @@ void apply_moe_dispatch_fwd( source_rows_for_seqsort_ = expert_id_ + num_moe_inputs; source_rows_for_seqsort_out_ = source_rows_for_seqsort_ + num_dispatched_size; // 1:ints: [E] - expert_offset_ = reinterpret_cast (source_rows_for_seqsort_out_ + num_dispatched_size); + expert_offset_ = reinterpret_cast(source_rows_for_seqsort_out_ + + num_dispatched_size); // permuted_data_ = reinterpret_cast(expert_offset_ + padded_experts); - // total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + buf_size); + // 
total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + + // buf_size); // only use one number - // num_active = reinterpret_cast(permuted_experts_ + num_moe_inputs); + // num_active = reinterpret_cast(permuted_experts_ + + // num_moe_inputs); fc1_result_ = reinterpret_cast(expert_offset_ + padded_experts); // fc1_result_ = reinterpret_cast(permuted_data_ + buf_size); - + #ifdef DEBUG_MOE_OP - // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits before_topk")); - // print_to_screen1(finished, 2, 16, std::string("finished before_topk")); + // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits + // before_topk")); print_to_screen1(finished, 2, 16, std::string("finished + // before_topk")); #endif - thrust::transform( - thrust::cuda::par.on(stream), - thrust::device_pointer_cast(source_rows_), - thrust::device_pointer_cast(source_rows_) + num_rows * k, - thrust::counting_iterator(0), - thrust::device_pointer_cast(source_rows_), - [num_rows, k] __device__ (int i, int cnt) { - int k_idx = cnt % k; - int block_row = cnt / k; - return k_idx * num_rows + block_row; - } - ); - + thrust::transform(thrust::cuda::par.on(stream), + thrust::device_pointer_cast(source_rows_), + thrust::device_pointer_cast(source_rows_) + num_rows * k, + thrust::counting_iterator(0), + thrust::device_pointer_cast(source_rows_), + [num_rows, k] __device__(int i, int cnt) { + int k_idx = cnt % k; + int block_row = cnt / k; + return k_idx * num_rows + block_row; + }); + #ifdef DEBUG_MOE_OP - // phi::CastKernel(ctx, expert_scales_tensor_float, expert_scales_tensor.dtype(), &expert_scales_tensor); - print_to_screen1(combine_weights, 8, 16, std::string("expert_scales_float after topk")); - print_to_screen1(expert_id, 8, 16, std::string("expert-id before permute")); - print_to_screen1(source_rows_, 8, 16, std::string("desc->src idx before permute")); + // phi::CastKernel(ctx, expert_scales_tensor_float, + // expert_scales_tensor.dtype(), &expert_scales_tensor); + print_to_screen1( + combine_weights, 8, 16, std::string("expert_scales_float after topk")); + print_to_screen1( + expert_id, 8, 16, std::string("expert-id before permute")); + print_to_screen1( + source_rows_, 8, 16, std::string("desc->src idx before permute")); #endif // compute global expert offset, **not** consider capacity // 必须在 modify_and_mask_expert_id_launcher 之前算出**全局 expert-offset** - compute_global_expert_offset(expert_id, - expert_id_, //buffer - expert_offset_global, - num_rows * k, - num_experts, - capacity, - stream, - allocator); - - // modifiy expert-id according to k - modify_and_mask_expert_id_launcher(expert_id, - expert_id_, - k, - num_rows, - static_cast(num_experts), - static_cast(expert_start_index), - static_cast(expert_end_index), - stream); - - - #ifdef DEBUG_MOE_OP - print_to_screen1(expert_id_, 8, 16, std::string("expert-id after modified 22")); + compute_global_expert_offset(expert_id, + expert_id_, // buffer + expert_offset_global, + num_rows * k, + num_experts, + capacity, + stream, + allocator); + + // modify expert-id according to k + modify_and_mask_expert_id_launcher(expert_id, + expert_id_, + k, + num_rows, + static_cast(num_experts), + static_cast(expert_start_index), + static_cast(expert_end_index), + stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1( + expert_id_, 8, 16, std::string("expert-id after modified 22")); #endif - sorter.run(fc1_result_, - sorter_ws_size_bytes, - expert_id_, // key in - permuted_experts_, // key out // [num_row, k]: expert-id - source_rows_, // value in - permuted_rows_, 
// value out //[num_row, k]: id在原 activation 中的位置 - k * num_rows, // num_rows - false, - stream); - - unmodify_expert_id_launcher(permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); + sorter.run( + fc1_result_, + sorter_ws_size_bytes, + expert_id_, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + source_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + unmodify_expert_id_launcher( + permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); #ifdef DEBUG_MOE_OP - print_to_screen1(permuted_experts_, 8, 16, std::string("expert-id after permute")); - print_to_screen1(permuted_rows_, 8, 16, std::string("dest->src idx after permute")); + print_to_screen1( + permuted_experts_, 8, 16, std::string("expert-id after permute")); + print_to_screen1( + permuted_rows_, 8, 16, std::string("dest->src idx after permute")); #endif - compute_local_expert_offset( - permuted_experts_, - expert_offset_, - expert_nums_local, - num_rows * k, - num_experts, - capacity, - stream, - allocator); + compute_local_expert_offset(permuted_experts_, + expert_offset_, + expert_nums_local, + num_rows * k, + num_experts, + capacity, + stream, + allocator); CUDACHECK(cudaMemcpyAsync(expert_offset_host.data(), - expert_offset_, - num_experts * sizeof(int64_t), - cudaMemcpyDeviceToHost, - stream)); + expert_offset_, + num_experts * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); CUDACHECK(cudaStreamSynchronize(stream)); #ifdef DEBUG_MOE_OP - std::cerr << "[DEBUG] num_active v2: " << expert_offset_host.back() << std::endl; - print_to_screen1(expert_offset_global, 8, 16, std::string("expert_offset global")); + std::cerr << "[DEBUG] num_active v2: " << expert_offset_host.back() + << std::endl; + print_to_screen1( + expert_offset_global, 8, 16, std::string("expert_offset global")); print_to_screen1(expert_offset_, 8, 16, std::string("expert_offset local")); - print_to_screen1(permuted_experts_, 8, 16, std::string("expert-id after permute")); - // print_to_screen1(permuted_experts_, 4096, 8192, std::string("expert-id after permute")); + print_to_screen1(permuted_experts_, + 8, + 16, + std::string("expert-id after permute")); + // print_to_screen1(permuted_experts_, 4096, 8192, + // std::string("expert-id after permute")); #endif - + // calc expert-size - // 不 use-pad 的情况下,在此处标记截断位置。之后需要再 sort 一遍把截断 id 放到句尾 - if (!use_pad){ // 2sort + // 不 use-pad 的情况下,在此处标记截断位置。之后需要再 sort 一遍把截断 id + // 放到句尾 + if (!use_pad) { // 2sort cal_expert_size_and_filter_launcher(permuted_experts_, expert_offset_, expert_offset_host.back(), @@ -290,232 +321,278 @@ void apply_moe_dispatch_fwd( expert_end_index, reverse_token_drop, stream); - //2sort - sorter.run(fc1_result_, - sorter_ws_size_bytes, - permuted_experts_, // key in - permuted_experts_, // key out // [num_row, k]: expert-id - permuted_rows_, // value in - permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 - k * num_rows, // num_rows - false, - stream); - - compute_local_expert_offset( - permuted_experts_, - expert_offset_, - expert_nums_local, - num_rows * k, - num_experts, - capacity, - stream, - allocator); - - CUDACHECK(cudaMemcpyAsync(expert_offset_host.data(), - expert_offset_, - num_experts * sizeof(int64_t), - cudaMemcpyDeviceToHost, - stream)); + // 2sort + sorter.run( + fc1_result_, + sorter_ws_size_bytes, + permuted_experts_, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + permuted_rows_, // value in + 
permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + compute_local_expert_offset(permuted_experts_, + expert_offset_, + expert_nums_local, + num_rows * k, + num_experts, + capacity, + stream, + allocator); + + CUDACHECK(cudaMemcpyAsync(expert_offset_host.data(), + expert_offset_, + num_experts * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); CUDACHECK(cudaStreamSynchronize(stream)); #ifdef DEBUG_MOE_OP - std::cerr << "[DEBUG](after 2sort) num_active v2: " << expert_offset_host.back() << std::endl; - print_to_screen1(expert_id_, 8, 16, std::string(" permuted_experts")); - print_to_screen1(permuted_experts_, 8, 16, std::string(" permuted_experts")); - print_to_screen1(permuted_rows_, 8,16, std::string(" dest->src idx")); -#endif + std::cerr << "[DEBUG](after 2sort) num_active v2: " + << expert_offset_host.back() << std::endl; + print_to_screen1( + expert_id_, 8, 16, std::string(" permuted_experts")); + print_to_screen1(permuted_experts_, + 8, + 16, + std::string(" permuted_experts")); + print_to_screen1( + permuted_rows_, 8, 16, std::string(" dest->src idx")); +#endif } thrust::fill( - thrust::cuda::par.on(stream), - thrust::device_ptr(scatter_index_rev), - thrust::device_ptr(scatter_index_rev) + num_experts * capacity, - num_rows - ); - build_seqsort_kv_pairs_kernel_launcher(scatter_index_rev, //padded_to_unpermuted_input - source_rows_for_seqsort_, //seqsort-value - permuted_rows_, - // scatter_index, // 对截断位置置0 - permuted_experts_, - expert_offset_, - combine_weights, // 对截断位置置0 - static_cast(num_rows), - static_cast(k), - expert_offset_host.back(), //num_active - capacity, - expert_start_index, // expert start index - use_pad, - stream); + thrust::cuda::par.on(stream), + thrust::device_ptr(scatter_index_rev), + thrust::device_ptr(scatter_index_rev) + num_experts * capacity, + num_rows); + build_seqsort_kv_pairs_kernel_launcher( + scatter_index_rev, // padded_to_unpermuted_input + source_rows_for_seqsort_, // seqsort-value + permuted_rows_, + // scatter_index, // 对截断位置置0 + permuted_experts_, + expert_offset_, + combine_weights, // 对截断位置置0 + static_cast(num_rows), + static_cast(k), + expert_offset_host.back(), // num_active + capacity, + expert_start_index, // expert start index + use_pad, + stream); #ifdef DEBUG_MOE_OP - // print_to_screen1(scatter_index, 8, 16, std::string("scatter_index after build_seqsort_kv_pairs_kernel_launcher")); - print_to_screen1(source_rows_for_seqsort_, 8, 16, std::string("source_rows_for_seqsort_ after build_seqsort_kv_pairs_kernel_launcher")); - print_to_screen1(scatter_index_rev, 8, 16, std::string("scatter_index_rev after build_seqsort_kv_pairs_kernel_launcher")); + // print_to_screen1(scatter_index, 8, 16, std::string("scatter_index + // after build_seqsort_kv_pairs_kernel_launcher")); + print_to_screen1(source_rows_for_seqsort_, + 8, + 16, + std::string("source_rows_for_seqsort_ after " + "build_seqsort_kv_pairs_kernel_launcher")); + print_to_screen1( + scatter_index_rev, + 8, + 16, + std::string( + "scatter_index_rev after build_seqsort_kv_pairs_kernel_launcher")); #endif - if (use_pad){ - for (auto iexpert = 0; iexpert != expert_end_index - expert_start_index; ++iexpert){ + if (use_pad) { + for (auto iexpert = 0; iexpert != expert_end_index - expert_start_index; + ++iexpert) { sorter.run(fc1_result_, - sorter_ws_size_bytes_seqsort, - scatter_index_rev + (iexpert * capacity), // key in - scatter_index_rev + (iexpert * capacity), // key out - source_rows_for_seqsort_ + (iexpert * 
capacity), // value in - source_rows_for_seqsort_ + (iexpert * capacity), // value out //[num_row, k]: id在原 activation 中的位置 - capacity, // num_rows - false, - stream); + sorter_ws_size_bytes_seqsort, + scatter_index_rev + (iexpert * capacity), // key in + scatter_index_rev + (iexpert * capacity), // key out + source_rows_for_seqsort_ + (iexpert * capacity), // value in + source_rows_for_seqsort_ + + (iexpert * capacity), // value out //[num_row, k]: id在原 + // activation 中的位置 + capacity, // num_rows + false, + stream); } - }else{ + } else { auto sort_iter = thrust::make_zip_iterator(thrust::make_tuple( - thrust::device_pointer_cast(permuted_experts_), //key1 - thrust::device_pointer_cast(scatter_index_rev), //key2 - thrust::device_pointer_cast(source_rows_for_seqsort_) - )); - thrust::stable_sort( - thrust::cuda::par.on(stream), - sort_iter, - sort_iter + expert_offset_host.back(), - []__device__(auto lhs, auto rhs){ - if (thrust::get<0>(lhs) < thrust::get<0>(rhs)) - return true; - else if(thrust::get<0>(lhs) > thrust::get<0>(rhs)) - return false; - else - return thrust::get<1>(lhs) < thrust::get<1>(rhs); - } - ); + thrust::device_pointer_cast(permuted_experts_), // key1 + thrust::device_pointer_cast(scatter_index_rev), // key2 + thrust::device_pointer_cast(source_rows_for_seqsort_))); + thrust::stable_sort(thrust::cuda::par.on(stream), + sort_iter, + sort_iter + expert_offset_host.back(), + [] __device__(auto lhs, auto rhs) { + if (thrust::get<0>(lhs) < thrust::get<0>(rhs)) + return true; + else if (thrust::get<0>(lhs) > thrust::get<0>(rhs)) + return false; + else + return thrust::get<1>(lhs) < thrust::get<1>(rhs); + }); } if (use_pad) { - int64_t num_experts_diff = expert_end_index - expert_start_index; - y->Resize({num_experts_diff * capacity, x.dims()[1]}); - dev_ctx.template Alloc(y); + int64_t num_experts_diff = expert_end_index - expert_start_index; + y->Resize({num_experts_diff * capacity, x.dims()[1]}); + dev_ctx.template Alloc(y); } else { - y->Resize({expert_offset_host.back(), x.dims()[1]}); - dev_ctx.template Alloc(y); + y->Resize({expert_offset_host.back(), x.dims()[1]}); + dev_ctx.template Alloc(y); } - phi::Full(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); - copy_unpermuted_to_permuted_kernelLauncher(x.data(), - y->data(), //out - scatter_index_rev, //padded_out_to_unpermuted_input - source_rows_for_seqsort_, //padded_out_to_expanded_input - scatter_index, //out - use_pad? (expert_end_index - expert_start_index) * capacity : expert_offset_host.back(), //num_active - num_rows, - k, - hidden_size, - stream); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); + copy_unpermuted_to_permuted_kernelLauncher( + x.data(), + y->data(), // out + scatter_index_rev, // padded_out_to_unpermuted_input + source_rows_for_seqsort_, // padded_out_to_expanded_input + scatter_index, // out + use_pad ? 
(expert_end_index - expert_start_index) * capacity + : expert_offset_host.back(), // num_active + num_rows, + k, + hidden_size, + stream); // cudaDeviceSynchronize(); //debug // turn expert_offset_ptr into experts_num return; } template -void moe_dispatch_fwd(const Context& dev_ctx, - const DenseTensor& x, - int64_t num_rows, - int64_t num_experts, - int64_t hidden_size, - int64_t capacity, - int64_t k, - int64_t expert_start_index, - int64_t expert_end_index, - bool reverse_token_drop, - thrust::host_vector& expert_offset_host, - DenseTensor* y, - const DenseTensor& combine_weights, - const DenseTensor& scatter_index, - const DenseTensor& scatter_index_rev, - const DenseTensor& expert_offset, - const DenseTensor& expert_nums_local, - const DenseTensor& expert_id, - bool use_pad){ +void moe_dispatch_fwd( + const Context &dev_ctx, + const DenseTensor &x, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + thrust::host_vector &expert_offset_host, // NOLINT + DenseTensor *y, + const DenseTensor &combine_weights, + const DenseTensor &scatter_index, + const DenseTensor &scatter_index_rev, + const DenseTensor &expert_offset, + const DenseTensor &expert_nums_local, + const DenseTensor &expert_id, + bool use_pad) { apply_moe_dispatch_fwd( - dev_ctx, - x, - num_rows, - num_experts, - hidden_size, - capacity, - k, - expert_start_index, - expert_end_index, - reverse_token_drop, - expert_offset_host, - y, - const_cast(combine_weights.data()), - const_cast(scatter_index.data()), - const_cast(scatter_index_rev.data()), - const_cast(expert_offset.data()), - const_cast(expert_nums_local.data()), - const_cast(expert_id.data()), - use_pad, - dev_ctx.stream()); + dev_ctx, + x, + num_rows, + num_experts, + hidden_size, + capacity, + k, + expert_start_index, + expert_end_index, + reverse_token_drop, + expert_offset_host, + y, + const_cast(combine_weights.data()), + const_cast(scatter_index.data()), + const_cast(scatter_index_rev.data()), + const_cast(expert_offset.data()), + const_cast(expert_nums_local.data()), + const_cast(expert_id.data()), + use_pad, + dev_ctx.stream()); } template -void MoeGateDispatchPartialNoSoftMaxTopkKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& combine_weights, - const DenseTensor& expert_id, - int64_t k, - int64_t capacity, - int64_t num_experts, - bool use_pad, - int64_t expert_start_index, - int64_t expert_end_index, - bool reverse_token_drop, - DenseTensor* y, - DenseTensor* combine_weights_out, - DenseTensor* scatter_index, - DenseTensor* scatter_index_rev, - DenseTensor* expert_offset, - DenseTensor* expert_nums_local){ +void MoeGateDispatchPartialNoSoftMaxTopkKernel( + const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &combine_weights, + const DenseTensor &expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + DenseTensor *y, + DenseTensor *combine_weights_out, + DenseTensor *scatter_index, + DenseTensor *scatter_index_rev, + DenseTensor *expert_offset, + DenseTensor *expert_nums_local) { dev_ctx.template Alloc(scatter_index); dev_ctx.template Alloc(scatter_index_rev); dev_ctx.template Alloc(expert_offset); dev_ctx.template Alloc(expert_nums_local); dev_ctx.template Alloc(combine_weights_out); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(scatter_index->dims())), 0, 
scatter_index); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(scatter_index_rev->dims())), 0, scatter_index_rev); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(expert_offset->dims())), 0, expert_offset); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(expert_nums_local->dims())), 0, expert_nums_local); - phi::Full(dev_ctx, phi::IntArray(common::vectorize(combine_weights_out->dims())), 0, combine_weights_out); - phi::Copy(dev_ctx, combine_weights, dev_ctx.GetPlace(), false, combine_weights_out); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(scatter_index->dims())), + 0, + scatter_index); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(scatter_index_rev->dims())), + 0, + scatter_index_rev); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(expert_offset->dims())), + 0, + expert_offset); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(expert_nums_local->dims())), + 0, + expert_nums_local); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(combine_weights_out->dims())), + 0, + combine_weights_out); + phi::Copy( + dev_ctx, combine_weights, dev_ctx.GetPlace(), false, combine_weights_out); const auto &x_shape = x.dims(); int64_t num_rows = x_shape[0]; int64_t hidden_size = x_shape[1]; thrust::host_vector expert_offset_host(num_experts); int64_t num_experts_diff = expert_end_index - expert_start_index; - moe_dispatch_fwd(dev_ctx, - x, - num_rows, - num_experts, - hidden_size, - capacity, - k, - expert_start_index, - expert_end_index, - reverse_token_drop, - expert_offset_host, - y, - *combine_weights_out, - *scatter_index, - *scatter_index_rev, - *expert_offset, //global-offset - *expert_nums_local, - expert_id, - use_pad - ); - if(use_pad){ - // scatter_index_rev = scatter_index_rev.slice(0, num_experts_diff * capacity); - *scatter_index_rev = phi::Slice(dev_ctx, *scatter_index_rev, {0}, {0}, {num_experts_diff * capacity}); - }else{ - if (expert_offset_host.back() > 0){ - // scatter_index_rev = scatter_index_rev.slice(0, expert_offset_host.back()); - *scatter_index_rev = phi::Slice(dev_ctx, *scatter_index_rev, {0}, {0}, {expert_offset_host.back()}); - }else{ + moe_dispatch_fwd(dev_ctx, + x, + num_rows, + num_experts, + hidden_size, + capacity, + k, + expert_start_index, + expert_end_index, + reverse_token_drop, + expert_offset_host, + y, + *combine_weights_out, + *scatter_index, + *scatter_index_rev, + *expert_offset, // global-offset + *expert_nums_local, + expert_id, + use_pad); + if (use_pad) { + // scatter_index_rev = scatter_index_rev.slice(0, num_experts_diff * + // capacity); + *scatter_index_rev = phi::Slice( + dev_ctx, *scatter_index_rev, {0}, {0}, {num_experts_diff * capacity}); + } else { + if (expert_offset_host.back() > 0) { + // scatter_index_rev = scatter_index_rev.slice(0, + // expert_offset_host.back()); + *scatter_index_rev = phi::Slice( + dev_ctx, *scatter_index_rev, {0}, {0}, {expert_offset_host.back()}); + } else { *y = phi::Empty(dev_ctx, {1, x_shape[1]}); - *scatter_index_rev = phi::Empty(dev_ctx, {}); //special treatment + *scatter_index_rev = + phi::Empty(dev_ctx, {}); // special treatment } } } @@ -528,4 +605,4 @@ PD_REGISTER_KERNEL(moe_gate_dispatch_partial_nosoftmaxtopk, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} \ No newline at end of file + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/int_bincount.h b/paddle/phi/kernels/int_bincount.h index 0c4286eeadd39c..18c44cc520505e 100644 --- a/paddle/phi/kernels/int_bincount.h +++ b/paddle/phi/kernels/int_bincount.h @@ -1,22 +1,37 
@@ -#include "paddle/phi/core/utils/data_type.h" -#include "paddle/common/flags.h" -#include +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once #include +#include #include "cub/device/device_histogram.cuh" +#include "paddle/common/flags.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/kernels/empty_kernel.h" // NOLINT #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" -namespace phi{ +namespace phi { -template -void IntBincount(const Context& ctx, - const DenseTensor &x, - int64_t low, - int64_t high, - int64_t out_dtype, - DenseTensor* out); -} \ No newline at end of file +template +void IntBincount(const Context& ctx, + const DenseTensor& x, + int64_t low, + int64_t high, + int64_t out_dtype, + DenseTensor* out); +} diff --git a/paddle/phi/kernels/layer_norm_cuda_kernel.h b/paddle/phi/kernels/layer_norm_cuda_kernel.h index 1dcfebf890be98..2e96debcacf3d1 100644 --- a/paddle/phi/kernels/layer_norm_cuda_kernel.h +++ b/paddle/phi/kernels/layer_norm_cuda_kernel.h @@ -1,3 +1,4 @@ +// NOLINT // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,15 +15,14 @@ #pragma once -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/selected_rows.h" #include "paddle/common/exception.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" #include // NOLINT #include // NOLINT -namespace phi{ +namespace phi { #define DEFAULT_THROW(NAME, TYPE) \ default: \ do { \ @@ -32,36 +32,36 @@ namespace phi{ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ switch (TYPEIN) { \ - case float: { \ + case float: { \ using scalar_t_in = float; \ switch (TYPEOUT) { \ - case float: { \ + case float: { \ using scalar_t_out = float; \ __VA_ARGS__; \ break; \ } \ - DEFAULT_THROW(NAME, TYPEOUT); \ + DEFAULT_THROW(NAME, TYPEOUT); \ } \ break; \ } \ - DEFAULT_THROW(NAME, TYPEIN); \ + DEFAULT_THROW(NAME, TYPEIN); \ } #define WARP_SIZE 32 template __device__ __forceinline__ T WARP_SHFL_XOR(T value, - int laneMask, - int width = WARP_SIZE, - unsigned int mask = 0xffffffff) { + int laneMask, + int width = WARP_SIZE, + unsigned int mask = 0xffffffff) { return __shfl_xor_sync(mask, value, laneMask, width); } template __device__ __forceinline__ T WARP_SHFL(T value, - int srcLane, - int width = WARP_SIZE, - unsigned int mask = 0xffffffff) { + int srcLane, + int width = WARP_SIZE, + unsigned int mask = 0xffffffff) { return __shfl_sync(mask, value, srcLane, width); } @@ -101,23 +101,16 @@ __device__ void cuChanOnlineSum(const U muB, } } -template __device__ -void cuRMSOnlineSum( - const U curr, - U& sigma2) -{ +template +__device__ void cuRMSOnlineSum(const U curr, U& sigma2) { // NOLINT sigma2 = sigma2 + curr * curr; } -template __device__ -void cuChanRMSOnlineSum( - const U sigma2B, - U& sigma2) -{ +template +__device__ void cuChanRMSOnlineSum(const U sigma2B, U& sigma2) { // NOLINT sigma2 = sigma2 + sigma2B; } - template __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, @@ -125,7 +118,8 @@ __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int i1, U& mu, // NOLINT U& sigma2, // NOLINT - U* buf, bool rms_only) { + U* buf, + bool rms_only) { // Assumptions: // 1) blockDim.x == WARP_SIZE // 2) Tensor is contiguous @@ -147,7 +141,7 @@ __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, for (int k = 0; k < 4; ++k) { U curr = static_cast(lvals[l + k]); if (!rms_only) { - cuWelfordOnlineSum(curr,mu,sigma2,count); + cuWelfordOnlineSum(curr, mu, sigma2, count); } else { cuRMSOnlineSum(curr, sigma2); } @@ -156,7 +150,7 @@ __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, for (; l < n2; ++l) { U curr = static_cast(lvals[l]); if (!rms_only) { - cuWelfordOnlineSum(curr,mu,sigma2,count); + cuWelfordOnlineSum(curr, mu, sigma2, count); } else { cuRMSOnlineSum(curr, sigma2); } @@ -168,7 +162,7 @@ __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, if (!rms_only) { U muB = WARP_SHFL(mu, srcLaneB); U countB = WARP_SHFL(count, srcLaneB); - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } else { cuChanRMSOnlineSum(sigma2B, sigma2); } @@ -184,7 +178,7 @@ __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, threadIdx.y < 2 * offset) { const int wrt_y = threadIdx.y - offset; if (!rms_only) { - ubuf[2*wrt_y] = mu; + ubuf[2 * wrt_y] = mu; ibuf[wrt_y] = count; } ubuf[2 * wrt_y + 1] = sigma2; @@ -194,11 +188,11 @@ __device__ void cuWelfordMuSigma2(const T* __restrict__ vals, if (threadIdx.x == 0 && threadIdx.y < offset) { U sigma2B = ubuf[2 * threadIdx.y + 1]; if (!rms_only) { - U muB = ubuf[2*threadIdx.y]; + U muB = ubuf[2 * threadIdx.y]; U countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } else { - cuChanRMSOnlineSum(sigma2B,sigma2); + cuChanRMSOnlineSum(sigma2B, sigma2); } } __syncthreads(); @@ -233,7 +227,8 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, const int i1, float& mu, // NOLINT float& sigma2, // 
NOLINT - float* buf, bool rms_only) { + float* buf, + bool rms_only) { // Assumptions: // 1) blockDim.x == WARP_SIZE // 2) Tensor is contiguous @@ -257,7 +252,7 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, if (thrx == 0) { float curr = static_cast(lvals[0]); if (!rms_only) { - cuWelfordOnlineSum(curr,mu,sigma2,count); + cuWelfordOnlineSum(curr, mu, sigma2, count); } else { cuRMSOnlineSum(curr, sigma2); } @@ -269,8 +264,8 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, for (int k = 0; k < 8; k += 2) { float2 curr = __half22float2(*((__half2*)(lvals + l + k))); // NOLINT if (!rms_only) { - cuWelfordOnlineSum(curr.x,mu,sigma2,count); - cuWelfordOnlineSum(curr.y,mu,sigma2,count); + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); } else { cuRMSOnlineSum(curr.x, sigma2); cuRMSOnlineSum(curr.y, sigma2); @@ -280,7 +275,7 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, for (; l < n2; ++l) { float curr = static_cast(lvals[l]); if (!rms_only) { - cuWelfordOnlineSum(curr,mu,sigma2,count); + cuWelfordOnlineSum(curr, mu, sigma2, count); } else { cuRMSOnlineSum(curr, sigma2); } @@ -292,7 +287,7 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, if (!rms_only) { float muB = WARP_SHFL(mu, srcLaneB); float countB = WARP_SHFL(count, srcLaneB); - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } else { cuChanRMSOnlineSum(sigma2B, sigma2); } @@ -309,7 +304,7 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, const int wrt_y = threadIdx.y - offset; ubuf[2 * wrt_y + 1] = sigma2; if (!rms_only) { - ubuf[2*wrt_y] = mu; + ubuf[2 * wrt_y] = mu; ibuf[wrt_y] = count; } } @@ -318,9 +313,9 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, if (threadIdx.x == 0 && threadIdx.y < offset) { float sigma2B = ubuf[2 * threadIdx.y + 1]; if (!rms_only) { - float muB = ubuf[2*threadIdx.y]; + float muB = ubuf[2 * threadIdx.y]; float countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } else { cuChanRMSOnlineSum(sigma2B, sigma2); } @@ -393,14 +388,15 @@ struct SharedMemory { template __device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, - U* __restrict__ mean, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const V* __restrict__ beta, bool rms_only) { + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, + bool rms_only) { // Assumptions: // 1) blockDim.x == WARP_SIZE // 2) Tensors are contiguous @@ -415,11 +411,13 @@ __device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, U c_invvar = rsqrt(sigma2 + epsilon); const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL && (beta != NULL || rms_only) ) { + if (gamma != NULL && (beta != NULL || rms_only)) { for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); if (!rms_only) { - ovals[i] = static_cast(static_cast(gamma[i]) * c_invvar * (curr - mu) + static_cast(beta[i])); + ovals[i] = + static_cast(static_cast(gamma[i]) * c_invvar * (curr - mu) + + 
static_cast(beta[i])); } else { ovals[i] = static_cast(static_cast(gamma[i]) * c_invvar * curr); } @@ -444,34 +442,30 @@ __device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, } } -template __global__ -void cuApplyLayerNorm( - V* __restrict__ output_vals, - U* __restrict__ mean, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const V* __restrict__ beta - ) -{ - cuApplyLayerNorm_(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false); +template +__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta) { + cuApplyLayerNorm_( + output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false); } - -template __global__ -void cuApplyRMSNorm( - V* __restrict__ output_vals, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma) -{ - cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); +template +__global__ void cuApplyRMSNorm(V* __restrict__ output_vals, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma) { + cuApplyLayerNorm_( + output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); } template @@ -487,7 +481,8 @@ __device__ void cuLoadWriteStridedInputs(const int i1_block, const int i1_end, const int n2, const U* __restrict__ mean, - const U* __restrict__ invvar, bool rms_only) { + const U* __restrict__ invvar, + bool rms_only) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { U curr_mean; @@ -504,9 +499,10 @@ __device__ void cuLoadWriteStridedInputs(const int i1_block, U curr_dout = static_cast(dout[load_idx]); if (!rms_only) { warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + warp_buf2[write_idx] = + curr_dout * (curr_input - curr_mean) * curr_invvar; } else { - warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar; + warp_buf2[write_idx] = curr_dout * (curr_input)*curr_invvar; } } else { if (!rms_only) { @@ -539,7 +535,8 @@ __device__ void cuLoadAddStridedInputs(const int i1_block, const int i1_end, const int n2, const U* __restrict__ mean, - const U* __restrict__ invvar, bool rms_only) { + const U* __restrict__ invvar, + bool rms_only) { int i1 = i1_block + thr_load_row_off; if (i1 < i1_end) { U curr_mean; @@ -556,9 +553,10 @@ __device__ void cuLoadAddStridedInputs(const int i1_block, U curr_dout = static_cast(dout[load_idx]); if (!rms_only) { warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; } else { - warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar; + warp_buf2[write_idx] += curr_dout * (curr_input)*curr_invvar; } } } @@ -574,7 +572,8 @@ __global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, const U* __restrict__ invvar, U epsilon, U* part_grad_gamma, - U* part_grad_beta, bool rms_only) { + U* part_grad_beta, + bool rms_only) { const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; @@ -607,7 +606,8 @@ __global__ void 
cuComputePartGradGammaBeta(const V* __restrict__ dout, i1_end, n2, mean, - invvar, rms_only); + invvar, + rms_only); for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) { cuLoadAddStridedInputs(i1_block, @@ -622,7 +622,8 @@ __global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, i1_end, n2, mean, - invvar,rms_only); + invvar, + rms_only); } __syncthreads(); // inter-warp reductions @@ -639,7 +640,7 @@ __global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, } if (!rms_only) { - warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; } warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; __syncthreads(); @@ -664,7 +665,7 @@ __global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, int idx1 = row1 * row_stride + threadIdx.x; int idx2 = row2 * row_stride + threadIdx.x; if (!rms_only) { - part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; } part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; } @@ -677,7 +678,8 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const int n1, const int n2, V* grad_gamma, - V* grad_beta, bool rms_only) { + V* grad_beta, + bool rms_only) { // sum partial gradients for gamma and beta SharedMemory shared; U* buf = shared.getPointer(); @@ -695,7 +697,7 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; if (!rms_only) { - sum_beta += part_grad_beta_ptr[warp_offset*n2]; + sum_beta += part_grad_beta_ptr[warp_offset * n2]; } } // inter-warp reductions @@ -706,7 +708,7 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; buf[write_idx] = sum_gamma; if (!rms_only) { - buf[write_idx+nbsize3] = sum_beta; + buf[write_idx + nbsize3] = sum_beta; } } __syncthreads(); @@ -715,7 +717,7 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; if (!rms_only) { - sum_beta += buf[read_idx+nbsize3]; + sum_beta += buf[read_idx + nbsize3]; } } __syncthreads(); @@ -739,8 +741,8 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, const U* __restrict__ invvar, U epsilon, const V* gamma, - T* grad_input, bool rms_only) { - + T* grad_input, + bool rms_only) { for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); @@ -764,7 +766,7 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, sum_loss1 += c_loss * gamma_tmp; sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; } else { - sum_loss2 += c_loss * gamma_tmp * (c_h) * c_invvar; + sum_loss2 += c_loss * gamma_tmp * (c_h)*c_invvar; } } } @@ -776,7 +778,7 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, sum_loss1 += c_loss * gamma_tmp; sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; } else { - sum_loss2 += c_loss * gamma_tmp * (c_h) * c_invvar; + sum_loss2 += c_loss * gamma_tmp * (c_h)*c_invvar; } } } else { @@ -789,7 +791,7 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, sum_loss1 += c_loss; sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; } else { - sum_loss2 += c_loss * (c_h) * c_invvar; + sum_loss2 += c_loss * (c_h)*c_invvar; } } } @@ -800,7 +802,7 @@ 
__global__ void cuComputeGradInput(const V* __restrict__ dout, sum_loss1 += c_loss; sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; } else { - sum_loss2 += c_loss * (c_h) * c_invvar; + sum_loss2 += c_loss * (c_h)*c_invvar; } } } @@ -820,7 +822,7 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; if (!rms_only) { - buf[2*wrt_i] = sum_loss1; + buf[2 * wrt_i] = sum_loss1; } buf[2 * wrt_i + 1] = sum_loss2; } @@ -829,7 +831,7 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; if (!rms_only) { - sum_loss1 += buf[2*read_i]; + sum_loss1 += buf[2 * read_i]; } sum_loss2 += buf[2 * read_i + 1]; } @@ -837,14 +839,14 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, } if (threadIdx.y == 0) { if (!rms_only) { - buf[2*threadIdx.x] = sum_loss1; + buf[2 * threadIdx.x] = sum_loss1; } buf[2 * threadIdx.x + 1] = sum_loss2; } __syncthreads(); if (threadIdx.y != 0) { if (!rms_only) { - sum_loss1 = buf[2*threadIdx.x]; + sum_loss1 = buf[2 * threadIdx.x]; } sum_loss2 = buf[2 * threadIdx.x + 1]; } @@ -862,7 +864,7 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, f_grad_input -= sum_loss1; f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; + f_grad_input -= (c_h)*c_invvar * sum_loss2; } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); @@ -876,7 +878,7 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout, f_grad_input -= sum_loss1; f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; + f_grad_input -= (c_h)*c_invvar * sum_loss2; } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); @@ -920,31 +922,29 @@ void HostApplyLayerNorm(V* output, output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); } -template -void HostApplyRMSNorm( - V* output, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma, cudaStream_t stream) -{ - // auto stream = at::cuda::getCurrentCUDAStream().stream(); - const dim3 threads(32,4,1); - // const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? - threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : - 0; - cuApplyRMSNorm<<>>( +template +void HostApplyRMSNorm(V* output, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + cudaStream_t stream) { + // auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32, 4, 1); + // const uint64_t maxGridY = + // at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? 
threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyRMSNorm<<>>( output, invvar, input, n1, n2, U(epsilon), gamma); } // template -// void cuda_layer_norm(const Context& ctx, +// void cuda_layer_norm(const Context& ctx, // const DenseTensor& x, // const DenseTensor& scale, // const DenseTensor& bias, @@ -970,37 +970,37 @@ void HostApplyRMSNorm( // ctx.stream())); // } -template -void cuda_rms_norm(const Context& ctx, - const DenseTensor& x, - const DenseTensor& scale, - int rows, - int cols, - float epsilon, - DenseTensor* y, - DenseTensor* invvar) { +template +void cuda_rms_norm(const Context& ctx, + const DenseTensor& x, + const DenseTensor& scale, + int rows, + int cols, + float epsilon, + DenseTensor* y, + DenseTensor* invvar) { HostApplyRMSNorm(y->data(), - invvar->data(), - const_cast(x.data()), - rows, - cols, - epsilon, - const_cast(scale.data()), - ctx.stream()); + invvar->data(), + const_cast(x.data()), + rows, + cols, + epsilon, + const_cast(scale.data()), + ctx.stream()); } template -void HostRMSNormGradient( const Context& ctx, - const V* dout, - const U* invvar, - const DenseTensor& input, - int n1, - int n2, - const V* gamma, - double epsilon, - T* grad_input, - V* grad_gamma, - cudaStream_t stream) { +void HostRMSNormGradient(const Context& ctx, + const V* dout, + const U* invvar, + const DenseTensor& input, + int n1, + int n2, + const V* gamma, + double epsilon, + T* grad_input, + V* grad_gamma, + cudaStream_t stream) { if (gamma != NULL) { const int part_size = 16; const dim3 threads2(32, 4, 1); @@ -1010,19 +1010,19 @@ void HostRMSNormGradient( const Context& ctx, const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; auto place = input.place(); - DenseTensor part_grad_gamma = phi::Empty(ctx, {part_size, n2}); + DenseTensor part_grad_gamma = + phi::Empty(ctx, {part_size, n2}); cuComputePartGradGammaBeta<<>>( dout, input.data(), n1, n2, - invvar, // unused + invvar, // unused invvar, U(epsilon), part_grad_gamma.data(), - part_grad_gamma.data(), /* unused */ - true - ); + part_grad_gamma.data(), /* unused */ + true); const dim3 threads3(32, 8, 1); const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); @@ -1034,8 +1034,8 @@ void HostRMSNormGradient( const Context& ctx, n1, n2, grad_gamma, - grad_gamma, /* unused */ - true); + grad_gamma, /* unused */ + true); } // compute grad_input @@ -1043,41 +1043,41 @@ void HostRMSNormGradient( const Context& ctx, const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); const dim3 threads1(32, 4, 1); int nshared = threads1.y > 1 ? 
threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput<<>>(dout, - input.data(), - n1, - n2, - invvar, /* unused */ - invvar, - U(epsilon), - gamma, - grad_input, - true); + cuComputeGradInput<<>>( + dout, + input.data(), + n1, + n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + grad_input, + true); } -template -void cuda_rms_norm_gradient(const Context& ctx, - const DenseTensor& x, - const DenseTensor& scale, - const DenseTensor& invvar, - const DenseTensor& dy, - int rows, - int cols, - float epsilon, - DenseTensor* grad_x, - DenseTensor* grad_scale) { - HostRMSNormGradient( - ctx, - dy.data(), - invvar.data(), - x, - rows, - cols, - scale.data(), - epsilon, - grad_x->data(), - grad_scale->data(), - ctx.stream()); +template +void cuda_rms_norm_gradient(const Context& ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& invvar, + const DenseTensor& dy, + int rows, + int cols, + float epsilon, + DenseTensor* grad_x, + DenseTensor* grad_scale) { + HostRMSNormGradient(ctx, + dy.data(), + invvar.data(), + x, + rows, + cols, + scale.data(), + epsilon, + grad_x->data(), + grad_scale->data(), + ctx.stream()); } -} \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/moe_combine_grad_kernel.h b/paddle/phi/kernels/moe_combine_grad_kernel.h index 43682c941f87fe..7468d9e944ce34 100644 --- a/paddle/phi/kernels/moe_combine_grad_kernel.h +++ b/paddle/phi/kernels/moe_combine_grad_kernel.h @@ -24,4 +24,4 @@ void MoeCombineGradKernel(const Context& dev_ctx, const DenseTensor& grad_y, DenseTensor* grad_x, DenseTensor* grad_combine_weights_helper); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/moe_combine_kernel.h b/paddle/phi/kernels/moe_combine_kernel.h index 8225241c018759..8057833db0f604 100644 --- a/paddle/phi/kernels/moe_combine_kernel.h +++ b/paddle/phi/kernels/moe_combine_kernel.h @@ -22,4 +22,4 @@ void MoeCombineKernel(const Context& dev_ctx, const DenseTensor& combine_weights, const DenseTensor& scatter_index, DenseTensor* out); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h index b2b30638d59966..e5af322d04f2f5 100644 --- a/paddle/phi/kernels/moe_fuse_bwd_op.h +++ b/paddle/phi/kernels/moe_fuse_bwd_op.h @@ -13,133 +13,140 @@ // limitations under the License. 
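The gather_with_mask kernels that follow reduce, for every destination row, the dy rows selected by scatter_index, skipping entries whose combine weight is non-positive or whose index falls at or beyond num_active. A minimal NumPy sketch of that reduction (the function name and the use of NumPy are illustrative only, not part of this patch):

import numpy as np

def gather_with_mask_reference(dy, scatter_index, combine_weights, num_active=-1):
    # dy: [s*k, d]; scatter_index, combine_weights: [s, k]  ->  dx: [s, d]
    s, k = scatter_index.shape
    dx = np.zeros((s, dy.shape[1]), dtype=dy.dtype)
    for si in range(s):
        for i in range(k):
            idx = scatter_index[si, i]
            if num_active >= 0 and idx >= num_active:
                continue  # rows at or past num_active carry no data
            if combine_weights[si, i] > 0.0:
                dx[si] += dy[idx]
    return dx

The permute variant differs only in how idx is remapped into the expert-major, per-rank layout (via capacity, world_size and num_local_experts) before the corresponding row of dy is read.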
#pragma once -#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/common/exception.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/moe_kernel_impl.h" -template -__global__ void gather_with_mask_permute_kernel(const T* dy, // [s*k, d] - const int* scatter_index, // [s, k] - const float* combine_weights, // [s, k] - T* dx, // [s, d] - int64_t num_rows, // s - int64_t k, // k - int64_t dim, // d - int64_t N, - int64_t num_active, // skip > num_active pos is num_active specified - int64_t s_shared_num, - int64_t capacity, - int64_t world_size, - int64_t num_local_experts - ){ - extern __shared__ char shared[]; - int* scatter_index_shared = reinterpret_cast(shared); - float* combine_weights_shared = reinterpret_cast(shared + s_shared_num * k * sizeof(int)); - int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; +template +__global__ void gather_with_mask_permute_kernel( + const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num, + int64_t capacity, + int64_t world_size, + int64_t num_local_experts) { + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = + reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; - for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; \ - idx < N; idx += blockDim.x * gridDim.x * vec_size) { - int64_t si = idx / dim; - int64_t di_begin = idx % dim; - int64_t si_shared_begin = shared_idx_begin / dim; - int64_t shared_stride = min(static_cast(blockDim.x), N - shared_idx_begin); + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; + idx < N; + idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = + min(static_cast(blockDim.x), N - shared_idx_begin); - for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { - if (si_shared_begin * k + i >= num_rows * k) { - break; - } - scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; - combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; - } - __syncthreads(); - - phi::AlignedVector in_vec; - phi::AlignedVector out_vec; - for (int ii = 0; ii < vec_size; ++ii) { - out_vec[ii] = static_cast(0); + for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); - for (int64_t i = 0; i < k; ++i) { - int64_t scatter_offset = (si - si_shared_begin) * k + i; - int id = scatter_index_shared[scatter_offset]; - if (num_active >= 0 && id >= num_active){ - continue; - } - if (combine_weights_shared[scatter_offset] > 0.f){ - int64_t remaining_after_irank = id % (num_local_experts * capacity); + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if 
(num_active >= 0 && id >= num_active) { + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f) { + int64_t remaining_after_irank = id % (num_local_experts * capacity); - int64_t irank = id / (num_local_experts * capacity); - int64_t local_iexpert = remaining_after_irank / capacity; - int64_t row_in_expert = remaining_after_irank % capacity; - int64_t permuted_id = local_iexpert * (world_size * capacity) + irank * capacity + row_in_expert; - int64_t in_offset = permuted_id * dim + di_begin; - phi::Load(dy + in_offset, &in_vec); - for (int64_t j = 0; j < vec_size; ++j) { - out_vec[j] += in_vec[j]; - } + int64_t irank = id / (num_local_experts * capacity); + int64_t local_iexpert = remaining_after_irank / capacity; + int64_t row_in_expert = remaining_after_irank % capacity; + int64_t permuted_id = local_iexpert * (world_size * capacity) + + irank * capacity + row_in_expert; + int64_t in_offset = permuted_id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; } } - phi::Store(out_vec, dx + idx); - shared_idx_begin += blockDim.x * gridDim.x * vec_size; } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } } -template -__global__ void gather_with_mask_kernel(const T* dy, // [s*k, d] - const int* scatter_index, // [s, k] - const float* combine_weights, // [s, k] - T* dx, // [s, d] - int64_t num_rows, // s - int64_t k, // k - int64_t dim, // d - int64_t N, - int64_t num_active, // skip > num_active pos is num_active specified - int64_t s_shared_num - ){ - extern __shared__ char shared[]; - int* scatter_index_shared = reinterpret_cast(shared); - float* combine_weights_shared = reinterpret_cast(shared + s_shared_num * k * sizeof(int)); - int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; +template +__global__ void gather_with_mask_kernel( + const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num) { + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = + reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; - for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; \ - idx < N; idx += blockDim.x * gridDim.x * vec_size) { - int64_t si = idx / dim; - int64_t di_begin = idx % dim; - int64_t si_shared_begin = shared_idx_begin / dim; - int64_t shared_stride = min(static_cast(blockDim.x), N - shared_idx_begin); + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; + idx < N; + idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = + min(static_cast(blockDim.x), N - shared_idx_begin); - for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { - if (si_shared_begin * k + i >= num_rows * k) { - break; - } - scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; - combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; - } - __syncthreads(); - - phi::AlignedVector in_vec; - phi::AlignedVector out_vec; - for (int ii = 0; ii < vec_size; ++ii) { - out_vec[ii] = static_cast(0); + for (int64_t i = 
threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); - for (int64_t i = 0; i < k; ++i) { - int64_t scatter_offset = (si - si_shared_begin) * k + i; - int id = scatter_index_shared[scatter_offset]; - if (num_active >= 0 && id >= num_active){ - continue; - } - if (combine_weights_shared[scatter_offset] > 0.f){ - int64_t in_offset = id * dim + di_begin; - phi::Load(dy + in_offset, &in_vec); - for (int64_t j = 0; j < vec_size; ++j) { - out_vec[j] += in_vec[j]; - } + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active) { + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f) { + int64_t in_offset = id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; } } - phi::Store(out_vec, dx + idx); - shared_idx_begin += blockDim.x * gridDim.x * vec_size; } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } } template @@ -147,7 +154,10 @@ inline T DivUp(T a, T b) { return (a + b - 1) / b; } -inline int64_t max_shared_s_num(int64_t num_rows, int64_t dim, int64_t threads, int64_t vec_size) { +inline int64_t max_shared_s_num(int64_t num_rows, + int64_t dim, + int64_t threads, + int64_t vec_size) { if ((threads * vec_size) % dim == 0) { return min(num_rows, threads * vec_size / dim); } else { @@ -161,149 +171,141 @@ inline int64_t max_shared_s_num(int64_t num_rows, int64_t dim, int64_t threads, } } -template -void gather_with_mask_launcher(const T* dy, // [s*k, d] - const int* scatter_index, // [s, k] - const float* combine_weights, // [s, k] - T* dx, // [s,k,d] - int64_t num_rows, // s - int64_t k, // k - int64_t dim, // d - int64_t num_active, - cudaStream_t stream, - bool use_all2all_permute = false, - int64_t world_size = -1, - int64_t num_local_experts = -1, - int64_t capacity = -1 -){ - int numel = num_rows * dim; +template +void gather_with_mask_launcher(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s,k,d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t num_active, + cudaStream_t stream, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1, + int64_t capacity = -1) { + int numel = num_rows * dim; - int64_t threads = 512; - if (dim % 4 == 0) { - int64_t blocks = DivUp(DivUp(numel, 4), threads); - int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 4); - size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + int64_t threads = 512; + if (dim % 4 == 0) { + int64_t blocks = DivUp(DivUp(numel, 4), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 4); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); - if (!use_all2all_permute) { - gather_with_mask_kernel<<>>( - dy, - scatter_index, - combine_weights, - dx, - num_rows, - k, - dim, - numel, - num_active, - s_shared_num); - } else { - PD_CHECK(world_size > 0 && num_local_experts > 0 && capacity > 0); - 
gather_with_mask_permute_kernel<<>>( - dy, - scatter_index, - combine_weights, - dx, - num_rows, - k, - dim, - numel, - num_active, - s_shared_num, - capacity, - world_size, - num_local_experts); - } + if (!use_all2all_permute) { + gather_with_mask_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); } else { - int64_t blocks = DivUp(DivUp(numel, 1), threads); - int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 1); - size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + PD_CHECK(world_size > 0 && num_local_experts > 0 && capacity > 0); + gather_with_mask_permute_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } else { + int64_t blocks = DivUp(DivUp(numel, 1), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 1); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); #ifdef DEBUG_MOE_OP - std::cerr << "[DEBUG-BWD] gather_with_mask without vectorized, s_shared_num=" << s_shared_num << ", block=" << blocks << std::endl; + std::cerr + << "[DEBUG-BWD] gather_with_mask without vectorized, s_shared_num=" + << s_shared_num << ", block=" << blocks << std::endl; #endif - if (!use_all2all_permute) { - gather_with_mask_kernel<<>>( - dy, - scatter_index, - combine_weights, - dx, - num_rows, - k, - dim, - numel, - num_active, - s_shared_num); - } else { - gather_with_mask_permute_kernel<<>>( - dy, - scatter_index, - combine_weights, - dx, - num_rows, - k, - dim, - numel, - num_active, - s_shared_num, - capacity, - world_size, - num_local_experts); - } + if (!use_all2all_permute) { + gather_with_mask_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + gather_with_mask_permute_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); } + } } -template -__global__ void topk_grad_with_mask(const T* dy, // [s, k] - const int* topk_idx, // [s, k] - const T* combine_weights, // [s, k] - T* dx, // [s, e] - int64_t num_rows, // s - int64_t k, // k - int64_t num_experts // e - ){ - // init dx to zero - for (int i = blockIdx.x; i < num_rows; i+=gridDim.x){ - int base_grad = i * num_experts; - for (int j = threadIdx.x; j < num_experts; j+=blockDim.x){ - dx[base_grad + j] = static_cast(0); - } - __syncthreads(); - int base_index = i * k; - for (int j = threadIdx.x; j < k; j+=blockDim.x){ - int64_t idx = topk_idx[base_index + j]; - if (combine_weights[base_index + j] > static_cast(0)){ - dx[base_grad + idx] = dy[base_index + j]; - } - } +template +__global__ void topk_grad_with_mask(const T* dy, // [s, k] + const int* topk_idx, // [s, k] + const T* combine_weights, // [s, k] + T* dx, // [s, e] + int64_t num_rows, // s + int64_t k, // k + int64_t num_experts // e +) { + // init dx to zero + for (int i = blockIdx.x; i < num_rows; i += gridDim.x) { + int base_grad = i * num_experts; + for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { + dx[base_grad + j] = static_cast(0); + } + __syncthreads(); + int base_index = i * k; + for (int j = threadIdx.x; j < k; j += blockDim.x) { + int64_t idx = topk_idx[base_index + j]; + if (combine_weights[base_index + j] > static_cast(0)) { + dx[base_grad + idx] = dy[base_index + j]; + } } + } } - // 
y=zero_part(topk(x)) 的反向过程 // x: [s,e] // dy: [s,k] -// X: [s, e] -(topk)-> Y:[s, k] - (越界设置为0)-> conbine_weights: [s, k] -template -void topk_grad_with_mask_launcher( - const T* dy, // [s, k] - const int* topk_idx, // [s, k] - const T* combine_weights, // [s, k] - T* dx, // [s, e] - int64_t num_rows, // s - int64_t k, // k - int64_t num_experts, // e - cudaStream_t stream){ +// X: [s, e] -(topk)-> Y:[s, k] - (越界设置为0)-> combine_weights: [s, k] +template +void topk_grad_with_mask_launcher(const T* dy, // [s, k] + const int* topk_idx, // [s, k] + const T* combine_weights, // [s, k] + T* dx, // [s, e] + int64_t num_rows, // s + int64_t k, // k + int64_t num_experts, // e + cudaStream_t stream) { + int blocks = num_rows; + int threads = 1024; - int blocks = num_rows; - int threads = 1024; - - topk_grad_with_mask<<>>(dy, - topk_idx, - combine_weights, - dx, - num_rows, - k, - num_experts - ); -} \ No newline at end of file + topk_grad_with_mask<<>>( + dy, topk_idx, combine_weights, dx, num_rows, k, num_experts); +} diff --git a/paddle/phi/kernels/moe_fuse_op.h b/paddle/phi/kernels/moe_fuse_op.h index bbda9e9e21e45c..40c2fc3fcf57a5 100644 --- a/paddle/phi/kernels/moe_fuse_op.h +++ b/paddle/phi/kernels/moe_fuse_op.h @@ -1,3 +1,17 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
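The topk_grad_with_mask launcher above (moe_fuse_bwd_op.h) scatters dy back into a dense [s, num_experts] gradient, writing only the positions whose combine weight is positive and leaving everything else zero. An equivalent Python sketch, for illustration only (names and the use of NumPy are not part of the patch):

import numpy as np

def topk_grad_with_mask_reference(dy, topk_idx, combine_weights, num_experts):
    # dy, topk_idx, combine_weights: [s, k]  ->  dx: [s, num_experts]
    s, k = dy.shape
    dx = np.zeros((s, num_experts), dtype=dy.dtype)
    for i in range(s):
        for j in range(k):
            if combine_weights[i, j] > 0:
                dx[i, topk_idx[i, j]] = dy[i, j]
    return dx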
+ #pragma once #include // 包含常用的 thrust 算法 #include @@ -453,9 +467,9 @@ void initialize_moe_routing_permute_kernelLauncher( template void compute_global_expert_offset( - const T* expert_id, //[len] - T* sort_buffer, //[len] - int64_t* expert_offset, //[num_experts] + const T* expert_id, // [len] + T* sort_buffer, // [len] + int64_t* expert_offset, // [num_experts] const int64_t len, const int64_t num_experts, const int64_t capacity, @@ -534,8 +548,8 @@ void modify_and_mask_expert_id_launcher(const T* expert_id, template void compute_local_expert_offset( - const T* sorted_expert_id, //[len] - int64_t* expert_offset, //[num_experts] + const T* sorted_expert_id, // [len] + int64_t* expert_offset, // [num_experts] int64_t* expert_num, const int64_t len, const int64_t num_experts, @@ -553,7 +567,7 @@ void compute_local_expert_offset( compute_total_rows_before_expert_kernel<<>>( sorted_expert_id, len, num_experts, expert_offset); - // 不考虑 capcity 影响 + // 不考虑 capacity 影响 thrust::adjacent_difference( exec_policy, offset_ptr, offset_ptr + num_experts, expert_num_ptr); } @@ -563,7 +577,7 @@ __global__ void cal_expert_size_and_filter(T* expert_id, const int64_t* expert_offset, int64_t len, int64_t num_experts, - int64_t capcity, + int64_t capacity, int64_t expert_start_index, int64_t expert_end_index, bool reverse) { @@ -582,11 +596,11 @@ __global__ void cal_expert_size_and_filter(T* expert_id, } } if (reverse) { - if (((off - 1) - idx) >= capcity) { + if (((off - 1) - idx) >= capacity) { expert_id[idx] = num_experts; } } else { - if ((idx - off) >= capcity) { + if ((idx - off) >= capacity) { expert_id[idx] = num_experts; } } @@ -597,7 +611,7 @@ void cal_expert_size_and_filter_launcher(T* expert_id, const int64_t* expert_offset, int64_t len, int64_t num_experts, - int64_t capcity, + int64_t capacity, int64_t expert_start_index, int64_t expert_end_index, bool reverse, @@ -610,7 +624,7 @@ void cal_expert_size_and_filter_launcher(T* expert_id, expert_offset, len, num_experts, - capcity, + capacity, expert_start_index, expert_end_index, reverse); @@ -796,4 +810,4 @@ void copy_unpermuted_to_permuted_kernelLauncher( k, num_cols); } -} \ No newline at end of file +} diff --git a/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h index cce8c2c7cc0c4a..f8cd9bee0d6083 100644 --- a/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h +++ b/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h @@ -28,4 +28,4 @@ void MoeGateDispatchGradKernel(const Context& dev_ctx, int64_t world_size, DenseTensor* x_grad, DenseTensor* gate_logits_grad); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h index 2ca266150783cc..1d6c1f5fed0b33 100644 --- a/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h +++ b/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h @@ -25,8 +25,8 @@ void MoEDispatchPermuteKernel(const Context& dev_ctx, int64_t capacity, int64_t world_size, DenseTensor* y, - DenseTensor* combine_weights, - DenseTensor* scatter_index, - DenseTensor* expert_offset, + DenseTensor* combine_weights, + DenseTensor* scatter_index, + DenseTensor* expert_offset, DenseTensor* expert_id); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/moe_kernel_impl.h b/paddle/phi/kernels/moe_kernel_impl.h index 4db8ee46954116..2881463e8d3045 100644 --- 
a/paddle/phi/kernels/moe_kernel_impl.h +++ b/paddle/phi/kernels/moe_kernel_impl.h @@ -1,5 +1,5 @@ +// NOLINT /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include +#include +#include #include #include "cub/cub.cuh" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" -#include -#include -#include -#include namespace phi { static const float HALF_FLT_MAX = 65504.F; @@ -33,25 +33,25 @@ class CubKeyValueSorter { public: inline CubKeyValueSorter(); - inline CubKeyValueSorter(cudaStream_t stream = 0); + inline CubKeyValueSorter(cudaStream_t stream = 0); // NOLINT inline explicit CubKeyValueSorter(const int num_experts); inline void update_num_experts(const int num_experts); inline size_t getWorkspaceSize(const size_t num_key_value_pairs, - bool descending = false); + bool descending = false); template inline void run(void* workspace, - const size_t workspace_size, - const KeyT* keys_in, - KeyT* keys_out, - const int* values_in, - int* values_out, - const size_t num_key_value_pairs, - bool descending, - cudaStream_t stream); + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream); private: size_t num_key_value_pairs_; @@ -60,13 +60,12 @@ class CubKeyValueSorter { cudaStream_t stream_; }; - // ===== CUB Sorting things ===== CubKeyValueSorter::CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} CubKeyValueSorter::CubKeyValueSorter(cudaStream_t stream) - : num_experts_(0), num_bits_(sizeof(int) * 8), stream_(stream) {} + : num_experts_(0), num_bits_(sizeof(int) * 8), stream_(stream) {} CubKeyValueSorter::CubKeyValueSorter(const int num_experts) : num_experts_(num_experts), @@ -74,7 +73,8 @@ CubKeyValueSorter::CubKeyValueSorter(const int num_experts) void CubKeyValueSorter::update_num_experts(const int num_experts) { num_experts_ = num_experts; - num_bits_ = static_cast(log2(num_experts)) + 3; //额外增加 3 位用于标记 topk的位置 + num_bits_ = static_cast(log2(num_experts)) + + 3; // 额外增加 3 位用于标记 topk的位置 } size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, @@ -108,17 +108,16 @@ size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, return required_storage; } - template inline void CubKeyValueSorter::run(void* workspace, - const size_t workspace_size, - const KeyT* keys_in, - KeyT* keys_out, - const int* values_in, - int* values_out, - const size_t num_key_value_pairs, - bool descending, - cudaStream_t stream) { + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) { size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); size_t actual_ws_size = workspace_size; @@ -158,14 +157,14 @@ inline void CubKeyValueSorter::run(void* workspace, template <> inline void CubKeyValueSorter::run(void* workspace, - const size_t workspace_size, - const __nv_bfloat16* keys_in, - __nv_bfloat16* keys_out, - const int* values_in, - int* values_out, - const size_t num_key_value_pairs, - bool descending, - cudaStream_t stream) {} + const size_t workspace_size, + 
const __nv_bfloat16* keys_in, + __nv_bfloat16* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) {} // CubKeyValueSorter sorter_(stream); @@ -644,4 +643,4 @@ __global__ void initialize_moe_routing_kernel( } } -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h index 9929687d42d4fa..c5cdcbfe6f4443 100644 --- a/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h +++ b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h @@ -15,22 +15,23 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -namespace phi{ +namespace phi { template -void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(const Context& dev_ctx, - const DenseTensor& combine_weights_out, - const DenseTensor& scatter_index, - const DenseTensor& scatter_index_rev, - const DenseTensor& expert_offset, - const DenseTensor& expert_offset_local, - const DenseTensor& y_grad, - const DenseTensor& combine_weights_out_grad, - int64_t k, - int64_t capacity, - bool use_pad, - int64_t expert_start_index, - int64_t expert_end_index, - DenseTensor* x_grad, - DenseTensor* combine_weights_grad); +void MoeGateDispatchPartialNoSoftMaxTopkGradKernel( + const Context& dev_ctx, + const DenseTensor& combine_weights_out, + const DenseTensor& scatter_index, + const DenseTensor& scatter_index_rev, + const DenseTensor& expert_offset, + const DenseTensor& expert_offset_local, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + DenseTensor* x_grad, + DenseTensor* combine_weights_grad); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h index fbe517b531066c..0ecf6afda63e91 100644 --- a/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h +++ b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h @@ -15,24 +15,25 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -namespace phi{ +namespace phi { template -void MoeGateDispatchPartialNoSoftMaxTopkKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& combine_weights, - const DenseTensor& expert_id, - int64_t k, - int64_t capacity, - int64_t num_experts, - bool use_pad, - int64_t expert_start_index, - int64_t expert_end_index, - bool reverse_token_drop, - DenseTensor* y, - DenseTensor* combine_weights_out, - DenseTensor* scatter_index, - DenseTensor* scatter_index_rev, - DenseTensor* expert_offset, - DenseTensor* expert_nums_local); +void MoeGateDispatchPartialNoSoftMaxTopkKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + DenseTensor* y, + DenseTensor* combine_weights_out, + DenseTensor* scatter_index, + DenseTensor* scatter_index_rev, + DenseTensor* expert_offset, + DenseTensor* expert_nums_local); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 151e0006a57369..2f67cd8ebcb81d 100644 --- 
a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -3871,6 +3871,15 @@ func: check_model_nan_inf data_type: out_grad +- backward_op: fused_rms_norm_grad + forward: fused_rms_norm (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) + args: (Tensor x, Tensor scale,Tensor invvar, Tensor y_grad, float epsilon) + output: Tensor(x_grad), Tensor(scale_grad) + infer_meta: + func: FusedRMSNormGradInferMeta + kernel: + func: fused_rms_norm_grad + - backward_op: im2sequence_grad forward: im2sequence (Tensor x, Tensor y, int[] kernels, int[] strides = {1, 1}, int[] paddings = {0, 0, 0, 0}, int[] out_stride = {1, 1}) -> Tensor (out) @@ -3977,12 +3986,3 @@ param : [condition] composite: where_double_grad(condition, grad_x_grad, grad_y_grad, grad_out_grad) optional: grad_x_grad, grad_y_grad - -- backward_op: fused_rms_norm_grad - forward: fused_rms_norm (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) - args: (Tensor x, Tensor scale,Tensor invvar, Tensor y_grad, float epsilon) - output: Tensor(x_grad), Tensor(scale_grad) - infer_meta: - func: FusedRMSNormGradInferMeta - kernel: - func: fused_rms_norm_grad \ No newline at end of file diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 03586b2f51353b..7905b122556c3c 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3647,7 +3647,7 @@ data_type : x optional : corr_bias backward : moe_gate_dispatch_grad - + - op : moe_gate_dispatch_partial_nosoftmaxtopk args : (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop) output : Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local) @@ -5739,16 +5739,15 @@ interfaces : paddle::dialect::InferSymbolicShapeInterface traits : paddle::dialect::ForwardOnlyTrait -- op: number_count - args: (Tensor numbers, int upper_range) - output: Tensor(out) +- op: fused_rms_norm + args: (Tensor x, Tensor scale, float epsilon) + output: Tensor(y), Tensor(invvar) infer_meta: - func: NumberCountInferMeta + func: FusedRMSNormInferMeta kernel: - func: number_count - data_type: numbers - interfaces : paddle::dialect::InferSymbolicShapeInterface - traits : paddle::dialect::ForwardOnlyTrait + func: fused_rms_norm + data_type: x + backward: fused_rms_norm_grad - op: int_bincount args: (Tensor x, int64_t low, int64_t high, int64_t dtype) @@ -5759,12 +5758,13 @@ func: int_bincount data_type: x -- op: fused_rms_norm - args: (Tensor x, Tensor scale, float epsilon) - output: Tensor(y), Tensor(invvar) +- op: number_count + args: (Tensor numbers, int upper_range) + output: Tensor(out) infer_meta: - func: FusedRMSNormInferMeta + func: NumberCountInferMeta kernel: - func: fused_rms_norm - data_type: x - backward: fused_rms_norm_grad \ No newline at end of file + func: number_count + data_type: numbers + interfaces : paddle::dialect::InferSymbolicShapeInterface + traits : paddle::dialect::ForwardOnlyTrait diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index a408c2cd16192a..bed678f8fa606d 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 # Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,15 +16,22 @@ from .blha_get_max_len import blha_get_max_len from .block_multihead_attention import ( block_multihead_attention, - block_multihead_attention_xpu, # noqa: F401 + block_multihead_attention_xpu, ) + +# from .moe_gate_dispatch_permute import moe_gate_dispatch_permute +from .build_src_rank_and_local_expert_id import ( + build_src_rank_and_local_expert_id, +) +from .cal_aux_loss import cal_aux_loss +from .expand_modality_expert_id import expand_modality_expert_id from .fused_bias_act import fused_bias_act from .fused_dot_product_attention import ( - cudnn_flash_attention, # noqa: F401 - fused_dot_product_attention, # noqa: F401 + cudnn_flash_attention, + fused_dot_product_attention, ) from .fused_dropout_add import fused_dropout_add -from .fused_gate_attention import fused_gate_attention # noqa: F401 +from .fused_gate_attention import fused_gate_attention from .fused_layer_norm import fused_layer_norm from .fused_matmul_bias import ( fused_linear, @@ -31,6 +39,7 @@ fused_matmul_bias, ) from .fused_rms_norm import fused_rms_norm +from .fused_rms_norm_ext import fused_rms_norm_ext from .fused_rotary_position_embedding import fused_rotary_position_embedding from .fused_transformer import ( fused_bias_dropout_residual_layer_norm, @@ -38,21 +47,18 @@ fused_multi_head_attention, fused_multi_transformer, ) +from .int_bincount import int_bincount from .masked_multihead_attention import masked_multihead_attention +from .moe_combine import moe_combine +from .moe_gate_dispatch import moe_gate_dispatch +from .moe_gate_dispatch_partial_nosoftmaxtopk import ( + moe_gate_dispatch_partial_nosoftmaxtopk, +) +from .moe_gate_dispatch_permute import moe_gate_dispatch_permute from .swiglu import swiglu from .variable_length_memory_efficient_attention import ( variable_length_memory_efficient_attention, ) -from .moe_combine import moe_combine -from .expand_modality_expert_id import expand_modality_expert_id -from .cal_aux_loss import cal_aux_loss -# from .moe_gate_dispatch_permute import moe_gate_dispatch_permute -from .build_src_rank_and_local_expert_id import build_src_rank_and_local_expert_id -from .int_bincount import int_bincount -from .fused_rms_norm_ext import fused_rms_norm_ext -from .moe_gate_dispatch import moe_gate_dispatch -from .moe_gate_dispatch_permute import moe_gate_dispatch_permute -from .moe_gate_dispatch_partial_nosoftmaxtopk import moe_gate_dispatch_partial_nosoftmaxtopk __all__ = [ 'fused_multi_head_attention', @@ -74,12 +80,9 @@ "swiglu", "moe_combine", "expand_modality_expert_id", - "cal_aux_loss" - "build_src_rank_and_local_expert_id" - "int_bincount", + "cal_aux_loss" "build_src_rank_and_local_expert_id" "int_bincount", "fused_rms_norm_ext", "moe_gate_dispatch", "moe_gate_dispatch_permute", "moe_gate_dispatch_partial_nosoftmaxtopk", - ] diff --git a/python/paddle/incubate/nn/functional/expand_modality_expert_id.py b/python/paddle/incubate/nn/functional/expand_modality_expert_id.py index 1d6351da47602f..e91a02ef795783 100644 --- a/python/paddle/incubate/nn/functional/expand_modality_expert_id.py +++ b/python/paddle/incubate/nn/functional/expand_modality_expert_id.py @@ -1,7 +1,23 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations + from typing import TYPE_CHECKING + from paddle import _C_ops -import paddle + # from ....framework import LayerHelper, in_dynamic_or_pir_mode from paddle.base.framework import in_dynamic_or_pir_mode from paddle.base.layer_helper import LayerHelper @@ -9,36 +25,48 @@ if TYPE_CHECKING: from paddle import Tensor + def expand_modality_expert_id( - expert_id: Tensor, - num_expert_per_modality: int, - group_size: int, - modality_offset: int, + expert_id: Tensor, + num_expert_per_modality: int, + group_size: int, + modality_offset: int, is_group_expert: bool, - name: str | None = None + name: str | None = None, ) -> Tensor: """ Args: expert_id: - num_expert_per_modality: + num_expert_per_modality: group_size: modality_offset: is_group_expert: - + Returns: """ if in_dynamic_or_pir_mode(): - return _C_ops.expand_modality_expert_id(expert_id, num_expert_per_modality, group_size, modality_offset, is_group_expert) + return _C_ops.expand_modality_expert_id( + expert_id, + num_expert_per_modality, + group_size, + modality_offset, + is_group_expert, + ) helper = LayerHelper('expand_modality_expert_id', **locals()) - expert_id_out = helper.create_variable_for_type_inference(dtype=expert_id.dtype) - inputs = { - 'expert_id': expert_id - } + expert_id_out = helper.create_variable_for_type_inference( + dtype=expert_id.dtype + ) + inputs = {'expert_id': expert_id} attrs = { - 'num_expert_per_modality': num_expert_per_modality, - 'group_size': group_size, - 'modality_offset': modality_offset, - 'is_group_expert': is_group_expert + 'num_expert_per_modality': num_expert_per_modality, + 'group_size': group_size, + 'modality_offset': modality_offset, + 'is_group_expert': is_group_expert, } - helper.append_op(type='expand_modality_expert_id', inputs=inputs, attrs=attrs, outputs={'expert_id_out': expert_id_out}) + helper.append_op( + type='expand_modality_expert_id', + inputs=inputs, + attrs=attrs, + outputs={'expert_id_out': expert_id_out}, + ) return expert_id_out diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py index bb0a2d9a8245fe..61f9cce5a440f6 100644 --- a/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py +++ b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py @@ -1,8 +1,23 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
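A hedged usage sketch for the expand_modality_expert_id wrapper above; every shape, dtype and attribute value here is an assumption for illustration and is not taken from the patch:

import paddle
from paddle.incubate.nn.functional import expand_modality_expert_id

expert_id = paddle.randint(0, 8, shape=[16, 2])  # assumed [num_tokens, k] int64 expert ids
expanded_id = expand_modality_expert_id(
    expert_id,
    num_expert_per_modality=4,  # 0 skips the multimodal offset, per the MoE layer code later in this patch
    group_size=8,
    modality_offset=0,
    is_group_expert=False,
)

In the MoE layer changes later in this patch, the result is reshaped back to the top-k weight shape before being concatenated with the gate weights.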
+ # File: python/paddle/incubate/nn/functional/layer_norm_cuda.py -import paddle +from paddle import _C_ops +from paddle.base.data_feeder import convert_dtype from paddle.base.framework import in_dynamic_or_pir_mode from paddle.base.layer_helper import LayerHelper -from paddle import _C_ops + def fused_rms_norm_ext(x, scale, epsilon=1e-5, name=None): """ @@ -19,9 +34,7 @@ def fused_rms_norm_ext(x, scale, epsilon=1e-5, name=None): invvar (Tensor): Tensor of shape [rows], the inverse standard deviation of each row. """ if in_dynamic_or_pir_mode(): - return _C_ops.fused_rms_norm( - x,scale,epsilon - ) + return _C_ops.fused_rms_norm(x, scale, epsilon) helper = LayerHelper('fused_rms_norm', **locals()) dtype = convert_dtype(x.dtype) y = helper.create_variable_for_type_inference(dtype) @@ -33,6 +46,6 @@ def fused_rms_norm_ext(x, scale, epsilon=1e-5, name=None): type='fused_rms_norm', inputs=inputs, outputs={'y': y, 'invvar': invvar}, - attrs={'epsilon': epsilon} + attrs={'epsilon': epsilon}, ) - return y, invvar \ No newline at end of file + return y, invvar diff --git a/python/paddle/incubate/nn/functional/int_bincount.py b/python/paddle/incubate/nn/functional/int_bincount.py index 171fd29ed68483..9e444ae5992a30 100644 --- a/python/paddle/incubate/nn/functional/int_bincount.py +++ b/python/paddle/incubate/nn/functional/int_bincount.py @@ -1,5 +1,19 @@ -import paddle +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from paddle import _C_ops +from paddle.base.data_feeder import convert_dtype from paddle.base.framework import in_dynamic_or_pir_mode from paddle.base.layer_helper import LayerHelper @@ -7,7 +21,7 @@ def int_bincount(x, low, high, dtype=None, name=None): if in_dynamic_or_pir_mode(): return _C_ops.int_bincount(x, low, high, dtype) - + helper = LayerHelper("int_bincount", **locals()) out_dtype = dtype if dtype is not None else x.dtype y = helper.create_variable_for_type_inference(dtype=out_dtype) @@ -18,8 +32,9 @@ def int_bincount(x, low, high, dtype=None, name=None): inputs={"x": x}, outputs={"y": y}, attrs={ - "low": low, - "high": high, - "dtype": dtype_attr, - }) - return y \ No newline at end of file + "low": low, + "high": high, + "dtype": dtype_attr, + }, + ) + return y diff --git a/python/paddle/incubate/nn/functional/moe_combine.py b/python/paddle/incubate/nn/functional/moe_combine.py index 78964c2cf9a1d6..e9e23915ce0a5e 100644 --- a/python/paddle/incubate/nn/functional/moe_combine.py +++ b/python/paddle/incubate/nn/functional/moe_combine.py @@ -1,7 +1,23 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations + from typing import TYPE_CHECKING + from paddle import _C_ops -import paddle + # from ....framework import LayerHelper, in_dynamic_or_pir_mode from paddle.base.framework import in_dynamic_or_pir_mode from paddle.base.layer_helper import LayerHelper @@ -9,15 +25,19 @@ if TYPE_CHECKING: from paddle import Tensor + def moe_combine( - x: Tensor, combine_weights: Tensor, scatter_index: Tensor, name: str | None = None + x: Tensor, + combine_weights: Tensor, + scatter_index: Tensor, + name: str | None = None, ) -> Tensor: """ Args: x: Input tensor [seq, dim] combine_weights: Combination weights [s, k] scatter_index: Scatter indices [k, s] dtype=int32 - + Returns: Output Combined output [s, dim] """ @@ -28,7 +48,7 @@ def moe_combine( inputs = { 'x': x, 'combine_weights': combine_weights, - 'scatter_index': scatter_index + 'scatter_index': scatter_index, } helper.append_op(type='moe_combine', inputs=inputs, outputs={'y': y}) - return y \ No newline at end of file + return y diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py index 07853b7a4e49fb..e8146589b1ad96 100644 --- a/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py @@ -1,6 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
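A hedged usage sketch for the moe_combine wrapper above, following its docstring shapes (combine_weights is [s, k], scatter_index is int32 [k, s], output is [s, dim]); the concrete sizes, the random data, and the assumption that x holds s*k dispatched rows are illustrative only:

import paddle
from paddle.incubate.nn.functional import moe_combine

s, k, dim = 8, 2, 16                                    # assumed sizes
x = paddle.randn([s * k, dim])                          # dispatched rows, assuming seq == s * k
combine_weights = paddle.rand([s, k])
scatter_index = paddle.randint(0, s * k, shape=[k, s]).astype('int32')
y = moe_combine(x, combine_weights, scatter_index)      # [s, dim]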
+ from __future__ import annotations -from typing import TYPE_CHECKING, Optional -import paddle + +from typing import TYPE_CHECKING + from paddle import _C_ops from paddle.base.framework import in_dynamic_or_pir_mode from paddle.base.layer_helper import LayerHelper @@ -8,6 +23,7 @@ if TYPE_CHECKING: from paddle import Tensor + def moe_gate_dispatch_partial_nosoftmaxtopk( x: Tensor, combine_weights: Tensor, @@ -19,13 +35,26 @@ def moe_gate_dispatch_partial_nosoftmaxtopk( expert_start_index: int, expert_end_index: int, reverse_token_drop: bool, - name: str | None = None + name: str | None = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: if in_dynamic_or_pir_mode(): - return _C_ops.moe_gate_dispatch_partial_nosoftmaxtopk(x, combine_weights, expert_id, k, capacity, num_experts, use_pad, expert_start_index, expert_end_index, reverse_token_drop) + return _C_ops.moe_gate_dispatch_partial_nosoftmaxtopk( + x, + combine_weights, + expert_id, + k, + capacity, + num_experts, + use_pad, + expert_start_index, + expert_end_index, + reverse_token_drop, + ) helper = LayerHelper("moe_gate_dispatch_partial_nosoftmaxtopk", **locals()) y = helper.create_variable_for_type_inference(dtype=x.dtype) - combine_weights_out = helper.create_variable_for_type_inference(dtype=combine_weights.dtype) + combine_weights_out = helper.create_variable_for_type_inference( + dtype=combine_weights.dtype + ) scatter_index = helper.create_variable_for_type_inference(dtype='int32') scatter_index_rev = helper.create_variable_for_type_inference(dtype='int32') expert_offset = helper.create_variable_for_type_inference(dtype='int64') @@ -58,8 +87,14 @@ def moe_gate_dispatch_partial_nosoftmaxtopk( outputs=outputs, attrs=attrs, ) - return y, combine_weights_out, scatter_index, scatter_index_rev, expert_offset, expert_nums_local - + return ( + y, + combine_weights_out, + scatter_index, + scatter_index_rev, + expert_offset, + expert_nums_local, + ) # import paddle @@ -101,7 +136,7 @@ def moe_gate_dispatch_partial_nosoftmaxtopk( # expert_start_index=expert_start_index, # expert_end_index=expert_end_index, # reverse_token_drop=reverse_token_drop -# ) +# ) # # 打印结果 # print("y:", y.numpy()) @@ -114,4 +149,4 @@ def moe_gate_dispatch_partial_nosoftmaxtopk( # a = paddle.sum(y)+paddle.sum(combine_weights_out) # a.backward() # print("\n##########backward output##########\n") -# print(f"x.grad: {x.grad}\n combine_weights.grad: {combine_weights.grad}\n expert_id.grad: {expert_id.grad}") \ No newline at end of file +# print(f"x.grad: {x.grad}\n combine_weights.grad: {combine_weights.grad}\n expert_id.grad: {expert_id.grad}") diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py index 0874e2a9d71d79..23e762a4421805 100644 --- a/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py @@ -1,6 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations -from typing import TYPE_CHECKING, Optional -import paddle + +from typing import TYPE_CHECKING + from paddle import _C_ops from paddle.base.framework import in_dynamic_or_pir_mode from paddle.base.layer_helper import LayerHelper @@ -8,14 +23,15 @@ if TYPE_CHECKING: from paddle import Tensor + def moe_gate_dispatch_permute( - x: Tensor, - gate_logits: Tensor, - corr_bias: Tensor, - k: int, - capacity: int, + x: Tensor, + gate_logits: Tensor, + corr_bias: Tensor, + k: int, + capacity: int, world_size: int, - name: str | None = None + name: str | None = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """ Dispatch and permute for Mixture of Experts (MoE). @@ -38,7 +54,9 @@ def moe_gate_dispatch_permute( - expert_id: IDs of selected experts for each position. """ if in_dynamic_or_pir_mode(): - return _C_ops.moe_gate_dispatch_permute(x, gate_logits, corr_bias, k, capacity, world_size) + return _C_ops.moe_gate_dispatch_permute( + x, gate_logits, corr_bias, k, capacity, world_size + ) helper = LayerHelper('moe_gate_dispatch_permute', **locals()) y = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -50,7 +68,7 @@ def moe_gate_dispatch_permute( inputs = { 'x': x, 'gate_logits': gate_logits, - 'corr_bias': corr_bias if corr_bias is not None else None + 'corr_bias': corr_bias if corr_bias is not None else None, } attrs = {'k': k, 'capacity': capacity, 'world_size': world_size} outputs = { @@ -58,12 +76,18 @@ def moe_gate_dispatch_permute( 'combine_weights': combine_weights, 'scatter_index': scatter_index, 'expert_offset': expert_offset, - 'expert_id': expert_id + 'expert_id': expert_id, } - helper.append_op(type='moe_gate_dispatch_permute', inputs=inputs, outputs=outputs, attrs=attrs) + helper.append_op( + type='moe_gate_dispatch_permute', + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) return y, combine_weights, scatter_index, expert_offset, expert_id + # # 定义输入参数 # num_rows = 10 # 示例行数 # hidden_size = 128 # 隐藏层维度 @@ -87,11 +111,11 @@ def moe_gate_dispatch_permute( # # 调用封装的API # y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch_permute( -# x=x, -# gate_logits=gate_logits, -# corr_bias=corr_bias, -# k=k, -# capacity=capacity, +# x=x, +# gate_logits=gate_logits, +# corr_bias=corr_bias, +# k=k, +# capacity=capacity, # world_size=world_size # ) @@ -107,4 +131,4 @@ def moe_gate_dispatch_permute( # print("Gradient of x:", x.grad) # print("Gradient of gate_logits:", gate_logits.grad) -# print("Gradient of corr_bias:", corr_bias.grad) \ No newline at end of file +# print("Gradient of corr_bias:", corr_bias.grad) diff --git a/test/legacy_test/ernie_utils/moe_all_gather_layer.py b/test/legacy_test/ernie_utils/moe_all_gather_layer.py index 87eecb45b97ff0..2f5e4ef911689a 100644 --- a/test/legacy_test/ernie_utils/moe_all_gather_layer.py +++ b/test/legacy_test/ernie_utils/moe_all_gather_layer.py @@ -1,5 +1,20 @@ -# -*- coding: utf-8 -*- +# ruff: noqa: FA100 # !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ @author: kebo @contact: kebo01@baidu.com @@ -13,29 +28,16 @@ """ -from typing import Any, Tuple, List, Dict, Optional, Callable -import itertools -from collections import defaultdict -import logging import contextlib -import numpy as np -import inspect +import logging +from typing import List, Optional import paddle -import paddle.distributed as dist -from paddle.distributed import fleet -from paddle import framework -import paddle.nn.functional as F from paddle import nn -from paddle.autograd import PyLayer -from paddle.distributed.communication.group import _get_global_group -from paddle.distributed.fleet.utils import recompute from paddle.distributed.communication.group import Group - -from .top2_gate import TopKGateFused, compute_optimal_transport -from paddle.incubate.tensor.manipulation import async_offload, async_reload from paddle.incubate.nn.functional import expand_modality_expert_id -from .moe_layer import MOELayer, fuse_logging + +from .moe_layer import MOELayer try: from src.utils.misc import global_training_logs @@ -98,7 +100,7 @@ def __init__( group_experts, moe_statics, ) - + class MOEAllGatherLayerV2(MOEAllGatherLayer): """_summary_ @@ -125,7 +127,7 @@ def __init__( use_expert_out_alltoall=True, # use_expert_alltoall_overlap=False, use_padding=True, - dense_token_type=3, # considerd as dense tokens (no moe) + dense_token_type=3, # considered as dense tokens (no moe) moe_statics=None, ): super().__init__( @@ -158,22 +160,28 @@ def __init__( self.use_expert_out_alltoall = use_expert_out_alltoall self.use_expert_alltoall_overlap = use_expert_alltoall_overlap logger.info( - f"uisng MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, " + f"using MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, " f"use_padding={use_padding}, use_expert_alltoall_overlap={use_expert_alltoall_overlap} " f"enable_reverse_token_drop={self.enable_reverse_token_drop}" ) self.two = paddle.to_tensor(2, dtype=paddle.float32) self.zero = paddle.to_tensor(0, dtype=paddle.float32) - - def fused_gate_logits_process_fused(self, gate_logits_lm, gate_logits_mm, token_type_ids): + + def fused_gate_logits_process_fused( + self, gate_logits_lm, gate_logits_mm, token_type_ids + ): """process gatelogits w/ moe utils""" - #top_k = 1 if isinstance(self.gate, SinkHornGateFused) else self.k + # top_k = 1 if isinstance(self.gate, SinkHornGateFused) else self.k top_k = self.k - num_expert_per_rank_per_modality = gate_logits_lm.shape[-1] // self.config.moe_world_size + num_expert_per_rank_per_modality = ( + gate_logits_lm.shape[-1] // self.config.moe_world_size + ) group_size = gate_logits_lm.shape[-1] // top_k if self.group_experts: assert not self.use_correction_bias - gate_logits_lm = gate_logits_lm.reshape([gate_logits_lm.shape[0], top_k, -1]) + gate_logits_lm = gate_logits_lm.reshape( + [gate_logits_lm.shape[0], top_k, -1] + ) prob_lm = self.gate.act(gate_logits_lm) prob_lm_ = prob_lm weight_lm, expert_id_lm = prob_lm_.topk(k=1, axis=-1) @@ -183,38 +191,59 @@ def fused_gate_logits_process_fused(self, gate_logits_lm, gate_logits_mm, token_ else: prob_lm = 
self.gate.act(gate_logits_lm) if self.use_correction_bias: - prob_lm_ = prob_lm + self.moe_statics.e_score_correction_bias[0].detach() + prob_lm_ = ( + prob_lm + + self.moe_statics.e_score_correction_bias[0].detach() + ) else: prob_lm_ = prob_lm weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, axis=-1) if self.use_correction_bias: - batch_idx = paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + batch_idx = ( + paddle.arange(prob_lm_.shape[0]) + .unsqueeze(-1) + .expand_as(expert_id_lm) + ) weight_lm = prob_lm[batch_idx, expert_id_lm] # use correct bias # num_expert_per_modality == 0 时只执行 group-expert expand,不执行 multimodal-expand expert_id_lm = expand_modality_expert_id( expert_id_lm, - num_expert_per_modality=num_expert_per_rank_per_modality - if (token_type_ids is not None and gate_logits_mm is not None) - else 0, + num_expert_per_modality=( + num_expert_per_rank_per_modality + if (token_type_ids is not None and gate_logits_mm is not None) + else 0 + ), group_size=group_size, modality_offset=0, is_group_expert=self.group_experts, ) expert_id_lm = expert_id_lm.reshape(weight_lm.shape) - lm_weight_and_expert_id = paddle.concat([weight_lm, expert_id_lm.astype("float32")], -1) + lm_weight_and_expert_id = paddle.concat( + [weight_lm, expert_id_lm.astype("float32")], -1 + ) if token_type_ids is None or gate_logits_mm is None: - return lm_weight_and_expert_id, prob_lm.reshape([prob_lm.shape[0], -1]), None + return ( + lm_weight_and_expert_id, + prob_lm.reshape([prob_lm.shape[0], -1]), + None, + ) prob_mm = self.gate.act(gate_logits_mm) if self.use_correction_bias: - prob_mm_ = prob_mm + self.moe_statics.e_score_correction_bias[1].detach() + prob_mm_ = ( + prob_mm + self.moe_statics.e_score_correction_bias[1].detach() + ) else: prob_mm_ = prob_mm weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, axis=-1) if self.use_correction_bias: - batch_idx = paddle.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + batch_idx = ( + paddle.arange(prob_lm_.shape[0]) + .unsqueeze(-1) + .expand_as(expert_id_lm) + ) weight_mm = prob_mm[batch_idx, expert_id_mm] # use correct bias expert_id_mm = expand_modality_expert_id( @@ -225,12 +254,16 @@ def fused_gate_logits_process_fused(self, gate_logits_lm, gate_logits_mm, token_ is_group_expert=False, ) expert_id_mm = expert_id_mm.reshape(weight_mm.shape) - mm_weight_and_expert_id = paddle.concat([weight_mm, expert_id_mm.astype("float32")], -1) + mm_weight_and_expert_id = paddle.concat( + [weight_mm, expert_id_mm.astype("float32")], -1 + ) weight_and_expert = paddle.where( (token_type_ids == 0).unsqueeze(-1), lm_weight_and_expert_id, mm_weight_and_expert_id, ) - return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm - - + return ( + weight_and_expert, + prob_lm.reshape([prob_lm.shape[0], -1]), + prob_mm, + ) diff --git a/test/legacy_test/ernie_utils/moe_layer.py b/test/legacy_test/ernie_utils/moe_layer.py index bf02a5a99c5ab4..25f5007e7461d0 100644 --- a/test/legacy_test/ernie_utils/moe_layer.py +++ b/test/legacy_test/ernie_utils/moe_layer.py @@ -1,31 +1,34 @@ +# ruff: noqa: FA100 # !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """_summary_ Returns: _type_: _description_ """ -from typing import Any, Tuple, List, Optional, Callable import logging from collections import namedtuple -from functools import partial -import inspect -import numpy as np +from typing import List, Optional import paddle -from paddle import framework +import paddle.distributed as dist from paddle import nn -from paddle.distributed.communication import stream -import paddle.nn.functional as F - -from paddle.autograd import PyLayer -from paddle.distributed.communication.group import Group -from paddle.distributed.fleet.utils import recompute from paddle.distributed import fleet - -import paddle.distributed as dist -from paddle import Tensor - - +from paddle.distributed.communication.group import Group try: from src.utils.misc import global_training_logs @@ -57,7 +60,10 @@ def in_auto_parallel_align_mode(): import moe_ops except ImportError: moe_ops = None - logger.warning("`moe-ops` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") + logger.warning( + "`moe-ops` not found, run " + "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) GateOutput = namedtuple( "GateOutput", @@ -68,6 +74,7 @@ def in_auto_parallel_align_mode(): ], ) + class MOELayer(nn.Layer): """MOELayer module which implements MixtureOfExperts as described in Gshard_. :: @@ -139,11 +146,15 @@ def __init__( self.use_correction_bias = moe_statics is not None self.moe_statics = moe_statics if self.use_correction_bias: - logger.info(f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}") + logger.info( + f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}" + ) assert self.gate.config.moe_use_aux_free self.is_mp_moe = ( - hasattr(fleet.fleet, "_hcg") and group is fleet.get_hybrid_communicate_group().get_model_parallel_group() + hasattr(fleet.fleet, "_hcg") + and group + is fleet.get_hybrid_communicate_group().get_model_parallel_group() ) is_dummy_moe = dist.get_world_size(group) == 1 @@ -163,29 +174,45 @@ def __init__( self.rank = 0 self.num_local_experts = len(self.experts) - self.dispatch_by_task = hasattr(self.gate, "dispatch_by_task") and self.gate.dispatch_by_task + self.dispatch_by_task = ( + hasattr(self.gate, "dispatch_by_task") + and self.gate.dispatch_by_task + ) if self.dispatch_by_task: - assert 0, f"no supported, checkout earylier code" + assert 0, "no supported, checkout earylier code" assert self.num_local_experts == 1 + ''' dummy skip if enable_bpr: - logger.info(f"using BPR") + logger.info("using BPR") prepost_process_buffer = {} - self.input_preprocess = partial(bpr_preprocess, buffer=prepost_process_buffer) - self.output_postprocess = partial(bpr_postprocess, buffer=prepost_process_buffer) + self.input_preprocess = partial( + bpr_preprocess, buffer=prepost_process_buffer + ) + self.output_postprocess = partial( + bpr_postprocess, buffer=prepost_process_buffer + ) else: self.input_preprocess = self.output_postprocess = None + ''' + self.input_preprocess = self.output_postprocess = None self.group_experts = group_experts self.config = self.gate.config self.zero = 
paddle.to_tensor(0, dtype=paddle.float32) self._rr_moe_gate_dispatch = None self._rr_moe_combine = None - if self.config.use_recompute and self.config.skip_recompute_ops.get("moe_gate_dispatch", False): + ''' dummy skip + if self.config.use_recompute and self.config.skip_recompute_ops.get( + "moe_gate_dispatch", False + ): self._rr_moe_gate_dispatch = RefinedRcomputeMoEGateDispatch() - if self.config.use_recompute and self.config.skip_recompute_ops.get("moe_combine", False): + if self.config.use_recompute and self.config.skip_recompute_ops.get( + "moe_combine", False + ): self._rr_moe_combine = RefinedRcomputeMoECombine() + ''' def fuse_logging(gate_logits, combine_weights, token_type_ids): @@ -199,9 +226,17 @@ def fuse_logging(gate_logits, combine_weights, token_type_ids): gate_expert_per_token_type_0, gate_expert_per_token_type_1, gate_experts_per_token, - ) = moe_router_loss_ops.cal_gate_experts_per_token_info(combine_weights, token_type_ids) + ) = moe_router_loss_ops.cal_gate_experts_per_token_info( + combine_weights, token_type_ids + ) else: - gate_experts_per_token = paddle.count_nonzero(combine_weights) / (gate_logits.shape[0]) - - return gate_expert_per_token_type_0, gate_expert_per_token_type_1, gate_experts_per_token, ce - + gate_experts_per_token = paddle.count_nonzero(combine_weights) / ( + gate_logits.shape[0] + ) + + return ( + gate_expert_per_token_type_0, + gate_expert_per_token_type_1, + gate_experts_per_token, + ce, + ) diff --git a/test/legacy_test/ernie_utils/moe_layer_uneven.py b/test/legacy_test/ernie_utils/moe_layer_uneven.py index a6be1bbb84ba5e..4bdf42c377d75c 100644 --- a/test/legacy_test/ernie_utils/moe_layer_uneven.py +++ b/test/legacy_test/ernie_utils/moe_layer_uneven.py @@ -1,33 +1,34 @@ # !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
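# Illustrative sketch of the fallback branch of `fuse_logging` above: it simply
# averages the number of non-zero combine weights per token. The tensor values
# below are made up for illustration only.
import paddle

# three tokens routed with top-2 gating; the third token was dropped (all-zero weights)
combine_weights = paddle.to_tensor([[0.6, 0.4], [1.0, 0.0], [0.0, 0.0]])
experts_per_token = paddle.count_nonzero(combine_weights) / combine_weights.shape[0]
# 3 non-zero weights over 3 tokens -> 1.0 experts used per token on average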
+ """ moe """ -from ast import Import -from operator import le -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast, List -import logging -import sys import inspect -from collections import defaultdict, namedtuple, Counter +import logging +from collections import namedtuple -import numpy as np import paddle from paddle import _C_ops -from paddle import nn -from paddle.distributed.communication import stream - from paddle.autograd import PyLayer -from paddle.distributed.communication.group import Group -from paddle.distributed.fleet.utils import recompute -import paddle.distributed as dist -from paddle import Tensor -from paddle.nn import functional as F -from paddle.distributed import fleet - # from ernie_core.models.moe.moe_layer import _AllToAll -from .top2_gate import TopKGateFused +from paddle.incubate.nn.functional import moe_gate_dispatch +from paddle.nn import functional as F try: from src.utils.misc import global_training_logs @@ -48,8 +49,10 @@ if False: try: - from paddle_xpu_nn import moe_combine as xpu_moe_combine - from paddle_xpu_nn import moe_combine_bwd as xpu_moe_combine_bwd + from paddle_xpu_nn import ( + moe_combine as xpu_moe_combine, + moe_combine_bwd as xpu_moe_combine_bwd, + ) except ImportError: xpu_moe_combine = None xpu_moe_combine_bwd = None @@ -59,9 +62,10 @@ from paddle.incubate.nn.functional import moe_combine except ImportError: moe_combine = None - logger.warning("`moe-combine` not found, run " "`python3 src/ernie_core/ops/moe/setup.py install` to install") - - + logger.warning( + "`moe-combine` not found, run " + "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) def average_grad(x, y, dy, eps=1e-12): @@ -74,10 +78,14 @@ def average_grad(x, y, dy, eps=1e-12): maskpos = (xsum == 0.0).expand_as(x) xsum_square = xsum.square() # [s,1] - left = paddle.triu(paddle.tril((1 / xsum).unsqueeze(-1).expand([s, k, k]))) # aka diag-emb [s,k,k] + left = paddle.triu( + paddle.tril((1 / xsum).unsqueeze(-1).expand([s, k, k])) + ) # aka diag-emb [s,k,k] right = (-x / xsum_square).unsqueeze(-1).expand([s, k, k]) dydx = left + right - dx = paddle.matmul(dy.unsqueeze(-2).cast(dydx.dtype), dydx).squeeze(-2) # [s,1,k] @[s,k,k] -> [s,1,k] + dx = paddle.matmul(dy.unsqueeze(-2).cast(dydx.dtype), dydx).squeeze( + -2 + ) # [s,1,k] @[s,k,k] -> [s,1,k] dx = paddle.where(maskpos, paddle.zeros_like(dx), dx) return dx @@ -100,12 +108,18 @@ def average_grad_bi(x, y, dy, eps=1e-12): s, k = x.shape assert k == 2, k xsum = paddle.clip(x.sum(axis=-1, keepdim=True), min=eps) # [s,1] - dydx = x.flip(axis=1).unsqueeze(-2).tile([1, 2, 1]) * mask.cast(x.dtype) / xsum.square().unsqueeze(-1) - dx = paddle.matmul(dy.unsqueeze(-2).cast(dydx.dtype), dydx).squeeze(-2) # [s,1,k] @[s,k,k] -> [s,1,k] + dydx = ( + x.flip(axis=1).unsqueeze(-2).tile([1, 2, 1]) + * mask.cast(x.dtype) + / xsum.square().unsqueeze(-1) + ) + dx = paddle.matmul(dy.unsqueeze(-2).cast(dydx.dtype), dydx).squeeze( + -2 + ) # [s,1,k] @[s,k,k] -> [s,1,k] return dx -def topk_grad(x, dy, indicies): +def topk_grad(x, dy, indices): """ TODO: fuse 这坨 shit y=gather(topk(x)) 的反向过程 @@ -117,8 +131,8 @@ def topk_grad(x, dy, indicies): dx = paddle.scatter_nd( paddle.stack( [ - paddle.arange(s).repeat_interleave(k).cast(indicies.dtype), - indicies.reshape([-1]), + paddle.arange(s).repeat_interleave(k).cast(indices.dtype), + indices.reshape([-1]), ], -1, ), @@ -153,12 +167,19 @@ def forward(ctx, x, gate_prob, k, capacity, use_pad, eps=1e-12): ctx.eps = eps ctx.capacity = capacity ctx.gate_prob = gate_prob - if "corr_bias" 
in inspect.signature(moe_ops.moe_gate_dispatch).parameters: + if "corr_bias" in inspect.signature(moe_gate_dispatch).parameters: compat_args = (None,) else: compat_args = () - y, combine_weights, scatter_index, expert_offset, expert_id = moe_ops.moe_gate_dispatch( - x, gate_prob, *compat_args, k=k, capacity=capacity, use_pad=use_pad + y, combine_weights, scatter_index, expert_offset, expert_id = ( + moe_gate_dispatch( + x, + gate_prob, + *compat_args, + k=k, + capacity=capacity, + use_pad=use_pad, + ) ) ctx.combine_weights = combine_weights scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] @@ -182,7 +203,9 @@ def backward(ctx, dy, dw, *_): s, k = ctx.combine_weights.shape grad = F.embedding(ctx.scatter_index, dy) # [s, k,d] mask = (ctx.combine_weights > 0.0).astype(grad.dtype) # [s,k] - dx = paddle.matmul(mask.unsqueeze(1), grad).squeeze(1) # [s,1,k] @ [s,k,d] -> [s,1,d] + dx = paddle.matmul(mask.unsqueeze(1), grad).squeeze( + 1 + ) # [s,1,k] @ [s,k,d] -> [s,1,d] if ctx.gate_prob.stop_gradient: return dx, None @@ -243,10 +266,12 @@ def backward(ctx, grad_y, *_): # reduce the hidden shape # TODO: implement reduce in cuda ops grad_combine_weight = grad_combine_weight_helper.sum(-1) - return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None - #return grad_x, grad_combine_weight_helper - - + return ( + grad_x, + grad_combine_weight.reshape(ctx.combine_weights.shape), + None, + ) + # return grad_x, grad_combine_weight_helper def combining(x, combine_weights, scatter_index, hard_gate=False): @@ -265,4 +290,3 @@ def combining(x, combine_weights, scatter_index, hard_gate=False): ret = GateCombine.apply(x, combine_weights, scatter_index) ret.stop_gradient = False return ret - diff --git a/test/legacy_test/ernie_utils/top2_gate.py b/test/legacy_test/ernie_utils/top2_gate.py index 5c3186b92b5785..1121e7625e049a 100644 --- a/test/legacy_test/ernie_utils/top2_gate.py +++ b/test/legacy_test/ernie_utils/top2_gate.py @@ -1,24 +1,38 @@ +# ruff: noqa: FA100 # !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
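# Minimal unfused sketch of the combine semantics relied on by GateCombine above,
# following the gather + weighted-sum form of `combining`; shapes and values here
# are illustrative: y[s] = sum_k combine_weights[s, k] * x[scatter_index[s, k]].
import paddle
import paddle.nn.functional as F

x = paddle.arange(8, dtype="float32").reshape([4, 2])               # dispatched rows [num_slots, hidden]
scatter_index = paddle.to_tensor([[0, 2], [1, 3]], dtype="int64")   # [seq, k]
combine_weights = paddle.to_tensor([[0.7, 0.3], [0.5, 0.5]])        # [seq, k]
gathered = F.embedding(scatter_index, x)                            # [seq, k, hidden]
y_ref = (gathered * combine_weights.unsqueeze(-1)).sum(axis=1)      # [seq, hidden]
# y_ref[0] == 0.7 * x[0] + 0.3 * x[2]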
+ """ top2gate """ -from typing import Tuple -from functools import partial import logging +from functools import partial +from typing import Tuple + import numpy as np -import math + import paddle -from paddle import Tensor import paddle.distributed as dist import paddle.nn.functional as F -from paddle import nn -from paddle.utils import unique_name -from paddle.nn.clip import _squared_l2_norm +from paddle import Tensor, nn from paddle.distributed import fleet from paddle.incubate.nn.functional import cal_aux_loss - +from paddle.utils import unique_name try: from src.utils.misc import global_training_logs @@ -42,22 +56,33 @@ logger = logging.getLogger(__name__) - - - - class CalAuxLossFunctor(paddle.autograd.PyLayer): """CalAuxLossFunctor""" @staticmethod def forward( - ctx, gate_prob, dispatch_mask, tokens_mask, dispatch_tokens_mask, num_experts, use_group, moe_k, clip_min=1e-6 + ctx, + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min=1e-6, ): """forward""" if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: tokens_mask = tokens_mask.astype(gate_prob.dtype) loss, seqlen_float, ce = cal_aux_loss( - gate_prob, dispatch_mask, tokens_mask, dispatch_tokens_mask, num_experts, use_group, moe_k, clip_min + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min, ) ''' ctx.save_for_backward(gate_prob, seqlen_float, ce) @@ -101,7 +126,10 @@ def cal_aux_loss_func( scale = None if dispatch_tokens_mask is not None: seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() - if tokens_mask is not None and gate_prob.shape[0] != dispatch_tokens_mask.shape[0]: + if ( + tokens_mask is not None + and gate_prob.shape[0] != dispatch_tokens_mask.shape[0] + ): scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) elif tokens_mask is not None: seqlen_float = tokens_mask.sum() @@ -150,7 +178,9 @@ def masked_fill(x, mask, value): @paddle.no_grad() -def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): +def compute_optimal_transport( + M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10 +): """ Computes the optimal transport matrix and Slinkhorn distance using the Sinkhorn-Knopp algorithm @@ -201,7 +231,9 @@ def forward(ctx, x, w): """ ctx.dtype = paddle.float32 ctx.save_for_backward(x, w) - return F.linear(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype)) + return F.linear( + cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype) + ) @staticmethod def backward(ctx, y_grad): @@ -210,7 +242,13 @@ def backward(ctx, y_grad): """ x, w = ctx.saved_tensor() assert ctx.dtype == y_grad.dtype, "dtype not match" - x_g, w_g = matmul_bwd(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype), y_grad, False, False) + x_g, w_g = matmul_bwd( + cast_if_needed(x, ctx.dtype), + cast_if_needed(w, ctx.dtype), + y_grad, + False, + False, + ) return cast_if_needed(x_g, x.dtype), cast_if_needed(w_g, w.dtype) @@ -269,7 +307,9 @@ def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: self.model_dim = config.hidden_size self.num_experts = config.moe_num_experts self.num_experts_tensor = ( - sum(config.moe_num_experts) if config.multimodel_experts else config.moe_num_experts + sum(config.moe_num_experts) + if config.multimodel_experts + else config.moe_num_experts ) # paddle.to_tensor(config.moe_num_experts, dtype="float32").sum() self.cap = config.moe_capacity @@ -299,15 +339,23 @@ def __init__(self, config, 
layer_idx: int, group, gate_weight=None) -> None: self.norm_gate_logits = config.moe_norm_gate_logits self.one = paddle.ones([], dtype="float32") - self.moe_aux_loss_lambda = paddle.to_tensor(config.moe_aux_loss_lambda, dtype="float32") - self.moe_z_loss_lambda = paddle.to_tensor(config.moe_z_loss_lambda, dtype="float32") - self.moe_orthogonal_loss_lambda = paddle.to_tensor(config.moe_orthogonal_loss_lambda, dtype="float32") + self.moe_aux_loss_lambda = paddle.to_tensor( + config.moe_aux_loss_lambda, dtype="float32" + ) + self.moe_z_loss_lambda = paddle.to_tensor( + config.moe_z_loss_lambda, dtype="float32" + ) + self.moe_orthogonal_loss_lambda = paddle.to_tensor( + config.moe_orthogonal_loss_lambda, dtype="float32" + ) if self.moe_aux_loss_lambda.ndim == 0: self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) if self.moe_z_loss_lambda.ndim == 0: self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) if self.moe_orthogonal_loss_lambda.ndim == 0: - self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze(0) + self.moe_orthogonal_loss_lambda = ( + self.moe_orthogonal_loss_lambda.unsqueeze(0) + ) self.experts_type_ids = None if config.moe_orthogonal_loss_lambda: @@ -316,8 +364,9 @@ def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: sharding_configs = strategy.hybrid_configs["sharding_configs"] pp_config = strategy.hybrid_configs["pp_configs"] assert ( - not sharding_configs.comm_overlap and not pp_config.sharding_comm_overlap - ), f"orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" + not sharding_configs.comm_overlap + and not pp_config.sharding_comm_overlap + ), "orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" self.eps = paddle.to_tensor([1e-12], dtype="float32") if config.multimodel_experts: @@ -325,13 +374,19 @@ def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: self.num_experts_list = [] self.experts_type_mask = [] # hard-gate + group_experts 需要对gate_logits不同部分分开计算 - experts_ids = paddle.zeros([sum(self.num_experts)], dtype="int64").reshape([config.moe_world_size, -1]) + experts_ids = paddle.zeros( + [sum(self.num_experts)], dtype="int64" + ).reshape([config.moe_world_size, -1]) offset = 0 for i, expert_num in enumerate(self.num_experts): - experts_ids[:, offset : offset + expert_num // config.moe_world_size] = i + experts_ids[ + :, offset : offset + expert_num // config.moe_world_size + ] = i offset += expert_num // config.moe_world_size self.experts_type_ids = experts_ids.reshape([-1]) - logger.info(f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}") + logger.info( + f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" + ) for i, expert_num in enumerate(self.num_experts): self.experts_type_mask.append( self.experts_type_ids == i, @@ -339,7 +394,9 @@ def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: self.num_experts_list.append(expert_num) else: # 非group_experts, 依赖token_type_bias实现hard-gate能力。 - assert not config.moe_group_experts, "group_experts must use hard_gate when multimodel_experts is True" + assert ( + not config.moe_group_experts + ), "group_experts must use hard_gate when multimodel_experts is True" else: self.num_experts_list = [self.num_experts] if gate_weight is not None: @@ -350,7 +407,7 @@ def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: logger.info("moe use gate_weight from outside") # 强制在amp下任使用fp32精度 self._cast_to_low_precision = 
False # 兼容develop分支paddle - self._cast_to_low_precison = False + self._cast_to_low_precision = False else: self._create_gate_parameter() logger.info( @@ -372,34 +429,50 @@ def _create_gate_parameter(self): """ if self.config.multimodel_experts: # support setting lambda for each expert group - self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand(len(self.num_experts)) - self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand(len(self.num_experts)) - self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand(len(self.num_experts)) + self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_orthogonal_loss_lambda = ( + self.moe_orthogonal_loss_lambda.expand(len(self.num_experts)) + ) for i, num_experts in enumerate(self.num_experts): if i == 1: - with paddle.utils.unique_name.guard(f"mm_gate_{self.layer_idx}_"): + with paddle.utils.unique_name.guard( + f"mm_gate_{self.layer_idx}_" + ): p = self.create_parameter( shape=[self.model_dim, num_experts], dtype="float32", - attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), ) else: p = self.create_parameter( shape=[self.model_dim, num_experts], dtype="float32", - attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), ) p.expert_type = f"expert_type_{i}" self.add_parameter( - "weight" if i == 0 else f"weight_{i}", # 为了对齐原 state-dict,第一个 gate-weight 不改名. + ( + "weight" if i == 0 else f"weight_{i}" + ), # 为了对齐原 state-dict,第一个 gate-weight 不改名. p, ) else: self.weight = self.create_parameter( shape=[self.model_dim, self.num_experts], dtype="float32", - attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), # 特殊处理,有利于热启 dense-ckpt + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), # 特殊处理,有利于热启 dense-ckpt ) logger.info(f"moe-Gate, {self.weight}") @@ -408,20 +481,28 @@ def _create_gate_parameter(self): assert ( not self.config.moe_use_hard_gate ), "multimodel_experts with hard_gate is not support token_type_bias." 
- num_experts = sum(self.num_experts) if self.config.multimodel_experts else self.num_experts - bias_type_num = len(self.num_experts) if self.config.multimodel_experts else 1 + num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + bias_type_num = ( + len(self.num_experts) if self.config.multimodel_experts else 1 + ) self.bias = self.create_parameter( shape=[bias_type_num, num_experts], dtype="float32", attr=paddle.ParamAttr( name=unique_name.generate("moe_gate_bias"), - initializer=paddle.nn.initializer.Assign(np.zeros([bias_type_num, num_experts])), + initializer=paddle.nn.initializer.Assign( + np.zeros([bias_type_num, num_experts]) + ), ), # 特殊处理,有利于热启 dense-ckpt ) logger.info(f"using token type bias, bias: {self.bias},") # 强制在amp下任使用fp32精度 self._cast_to_low_precision = False # 兼容develop分支paddle - self._cast_to_low_precison = False + self._cast_to_low_precision = False def get_gate_weight(self, transform_weight): """ @@ -432,7 +513,11 @@ def get_gate_weight(self, transform_weight): return self.weight if not transform_weight: return paddle.concat( - [getattr(self, "weight" if i == 0 else f"weight_{i}") for i in range(len(self.num_experts))], -1 + [ + getattr(self, "weight" if i == 0 else f"weight_{i}") + for i in range(len(self.num_experts)) + ], + -1, ) weight = paddle.zeros( [ @@ -444,9 +529,13 @@ def get_gate_weight(self, transform_weight): ) offset = 0 for i, num_experts in enumerate(self.num_experts): - weight[:, :, offset : offset + num_experts // self.config.moe_world_size] = getattr( - self, "weight" if i == 0 else f"weight_{i}" - ).reshape([self.model_dim, self.config.moe_world_size, -1]) + weight[ + :, + :, + offset : offset + num_experts // self.config.moe_world_size, + ] = getattr(self, "weight" if i == 0 else f"weight_{i}").reshape( + [self.model_dim, self.config.moe_world_size, -1] + ) offset += num_experts // self.config.moe_world_size weight = weight.reshape([self.model_dim, -1]) @@ -464,12 +553,16 @@ def forward( input: paddle.Tensor[Seq, Dim], hidden-states of layer token_type_ids: paddle.Tensor[Seqw], token_type_ids of input transform_weight: bool, when using multimodal experts, perform `self.get_gate_weight` if specified - Retruns: + Returns: paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask Tuple[paddle.Tensor]: `GateOutput` """ - num_experts = sum(self.num_experts) if self.config.multimodel_experts else self.num_experts + num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) orig_dtype = input.dtype weight = self.get_gate_weight(transform_weight) with paddle.amp.auto_cast(False): @@ -482,7 +575,9 @@ def forward( training=self.training, ) else: - logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + logits = gate_detach_matmul( + input, weight, self.fuse_gate_detach_matmul + ) if self.use_token_type_bias: assert token_type_ids is not None @@ -506,13 +601,24 @@ def forward( router_loss.stop_gradient = False combine_weights = combine_weights.cast(orig_dtype) - return capacity, dispatch_mask, combine_weights, scatter_index, router_loss, logits + return ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + router_loss, + logits, + ) def get_capacity(self, num_tokens, cap_factor=None): """ - return capcity + return capacity """ - num_experts = sum(self.num_experts) if self.config.multimodel_experts else self.num_experts + num_experts = ( + sum(self.num_experts) + if 
self.config.multimodel_experts + else self.num_experts + ) if cap_factor is not None: cap = cap_factor else: @@ -524,7 +630,9 @@ def get_capacity(self, num_tokens, cap_factor=None): cap = self.cap[1] # capacity = 2S/E capacity = int(cap * num_tokens // num_experts) - assert capacity > 0, f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" + assert ( + capacity > 0 + ), f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" return capacity def top2_gating(self, logits, cap=None, correction_bias=None): @@ -555,11 +663,19 @@ def top2_gating(self, logits, cap=None, correction_bias=None): capacity = self.get_capacity(logits.shape[0], cap) # Create a mask for 1st's expert per token - score_for_argmax = gates + correction_bias.unsqueeze(0) if correction_bias is not None else gates + score_for_argmax = ( + gates + correction_bias.unsqueeze(0) + if correction_bias is not None + else gates + ) indices1_s = paddle.argmax(score_for_argmax, axis=1) - mask1 = F.one_hot(indices1_s, num_classes=num_experts).cast(paddle.int64) # [0,1] + mask1 = F.one_hot(indices1_s, num_classes=num_experts).cast( + paddle.int64 + ) # [0,1] - l_aux = self._cal_aux_loss(gates, mask1.sum(axis=0), self.num_experts_tensor) + l_aux = self._cal_aux_loss( + gates, mask1.sum(axis=0), self.num_experts_tensor + ) # Create a mask for 2nd's expert per token using Gumbel-max trick # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ if self.training and not self.no_jitter: @@ -574,9 +690,13 @@ def top2_gating(self, logits, cap=None, correction_bias=None): else: logits_w_noise = logits - logits_except1 = masked_fill(logits_w_noise, mask1.cast(paddle.bool), float("-inf")) + logits_except1 = masked_fill( + logits_w_noise, mask1.cast(paddle.bool), float("-inf") + ) score_for_argmax = ( - self.act(logits_except1) + correction_bias.unsqueeze(0) if correction_bias is not None else logits_except1 + self.act(logits_except1) + correction_bias.unsqueeze(0) + if correction_bias is not None + else logits_except1 ) indices2_s_original = paddle.argmax(score_for_argmax, axis=1) @@ -588,17 +708,25 @@ def top2_gating(self, logits, cap=None, correction_bias=None): c = paddle.maximum(c, paddle.zeros_like(c)) c /= c.sum() - pi, _ = compute_optimal_transport(-logits_except1.cast("float32").detach(), r, c, lam=self.sinkhorn_temp) + pi, _ = compute_optimal_transport( + -logits_except1.cast("float32").detach(), + r, + c, + lam=self.sinkhorn_temp, + ) pi = masked_fill(pi, mask1.cast(paddle.bool), float("-inf")) indices2_s = paddle.argmax(pi, axis=1) else: indices2_s = indices2_s_original - - mask2 = F.one_hot(indices2_s, num_classes=self.num_experts).cast(paddle.int64) + mask2 = F.one_hot(indices2_s, num_classes=self.num_experts).cast( + paddle.int64 + ) # Compute locations in capacity buffer - locations1 = paddle.cumsum(mask1, axis=0) - 1 # [0,1,1,0,1,0,0] -> [0,0,0,0,1,1,1,] + locations1 = ( + paddle.cumsum(mask1, axis=0) - 1 + ) # [0,1,1,0,1,0,0] -> [0,0,0,0,1,1,1,] locations2 = paddle.cumsum(mask2, axis=0) - 1 # Update 2nd's location by accounting for locations of 1st locations2 += paddle.sum(mask1, axis=0, keepdim=True) @@ -659,7 +787,13 @@ def top2_gating(self, logits, cap=None, correction_bias=None): ) def _cal_aux_loss( - self, gate_prob, dispatch_mask, num_experts=None, use_group=None, tokens_mask=None, dispatch_tokens_mask=None + self, + gate_prob, + dispatch_mask, + num_experts=None, + use_group=None, + tokens_mask=None, + dispatch_tokens_mask=None, ): """ 计算辅助损失 @@ -680,16 +814,24 @@ def _cal_aux_loss( if 
tokens_mask is not None: gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] if gate_prob_this_modality.shape[0]: - _, top_idx = gate_prob_this_modality.topk(k=self.config.moe_k, axis=-1) + _, top_idx = gate_prob_this_modality.topk( + k=self.config.moe_k, axis=-1 + ) if int_bincount is not None: - dispatch_mask = int_bincount(top_idx, 0, gate_prob.shape[-1], paddle.int64) + dispatch_mask = int_bincount( + top_idx, 0, gate_prob.shape[-1], paddle.int64 + ) else: - mask = paddle.zeros_like(gate_prob_this_modality).put_along_axis( - top_idx, paddle.to_tensor(1.0), axis=1 + mask = paddle.zeros_like( + gate_prob_this_modality + ).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + dispatch_mask = paddle.sum( + mask.cast(paddle.int64), axis=0 ) - dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) else: - dispatch_mask = paddle.zeros(gate_prob.shape[-1], dtype="int64") + dispatch_mask = paddle.zeros( + gate_prob.shape[-1], dtype="int64" + ) dist.stream.all_reduce( dispatch_mask, group=self.group, @@ -698,9 +840,13 @@ def _cal_aux_loss( else: _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) if int_bincount is not None: - dispatch_mask = int_bincount(top_idx, 0, gate_prob.shape[-1], paddle.int64) + dispatch_mask = int_bincount( + top_idx, 0, gate_prob.shape[-1], paddle.int64 + ) else: - mask = paddle.zeros_like(gate_prob).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + mask = paddle.zeros_like(gate_prob).put_along_axis( + top_idx, paddle.to_tensor(1.0), axis=1 + ) dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) if num_experts is None: @@ -711,7 +857,10 @@ def _cal_aux_loss( if ( moe_router_loss_ops is not None and (tokens_mask is None or len(tokens_mask.shape) == 1) - and (tokens_mask is None or tokens_mask.shape[0] == gate_prob.shape[0]) + and ( + tokens_mask is None + or tokens_mask.shape[0] == gate_prob.shape[0] + ) and (gate_prob.shape[0] >= gate_prob.shape[1]) and (not self.global_aux_loss) and (gate_prob.dtype == paddle.float32) @@ -741,7 +890,6 @@ def _cal_aux_loss( ) - class TopKGateFused(Top2Gate): """doc""" @@ -756,7 +904,7 @@ def forward( input: paddle.Tensor, hidden-states of layer token_type_ids: paddle.Tensor[Seqw], token_type_ids of input transform_weight: bool, when using multimodal experts, perform `self.get_gate_weight` if specified - Retruns: + Returns: paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask Tuple[paddle.Tensor]: `GateOutput` @@ -773,7 +921,9 @@ def forward( training=self.training, ) else: - logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + logits = gate_detach_matmul( + input, weight, self.fuse_gate_detach_matmul + ) if self.use_token_type_bias: assert token_type_ids is not None assert ( @@ -786,19 +936,26 @@ def forward( router_loss = paddle.zeros([1], dtype="float32") router_loss.stop_gradient = False - return logits, capacity, router_loss class DeepEPTop2Gate(TopKGateFused): """DeepEPTop2Gate""" - def forward(self, input, transform_weight=True, global_gate_mask=None, input_ids=None): + def forward( + self, + input, + transform_weight=True, + global_gate_mask=None, + input_ids=None, + ): """forward""" weight = self.get_gate_weight(transform_weight) with paddle.amp.auto_cast(False): - logits = gate_detach_matmul(input, weight, self.fuse_gate_detach_matmul) + logits = gate_detach_matmul( + input, weight, self.fuse_gate_detach_matmul + ) if global_gate_mask is not None: logits = logits + global_gate_mask @@ -820,10 
+977,14 @@ def _cal_aux_loss(self, gates, dispatch_mask, input_ids=None): paddle.Tensor: The value of auxiliary loss. """ - assert len(gates.shape) == 2, "gates.shape must be [sequence_lengh, num_experts]" + assert ( + len(gates.shape) == 2 + ), "gates.shape must be [sequence_length, num_experts]" if input_ids is not None: # has_padding = (input_ids == 0).any() - assert input_ids.shape[0] == gates.shape[0], f"check input_ids shape {input_ids.shape}" + assert ( + input_ids.shape[0] == gates.shape[0] + ), f"check input_ids shape {input_ids.shape}" valid_mask = (input_ids != 0).astype(paddle.float32) seqlen_float = valid_mask.sum().item() gates = gates * valid_mask.unsqueeze(-1) @@ -866,5 +1027,9 @@ def _cal_orthogonal_loss(self) -> paddle.Tensor: Paddle.Tensor: orthogonal loss """ weight = F.normalize(self.weight, axis=0) - orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts))) + orthogonal_loss = paddle.mean( + paddle.square( + paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts) + ) + ) return orthogonal_loss diff --git a/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py b/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py index 4b89731112bbf5..e0a35a3f85233d 100644 --- a/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py +++ b/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py @@ -1,51 +1,62 @@ -import os -import unittest +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from op_test import convert_float_to_uint16 -import random -import paddle.nn.functional as F +import logging +import unittest -import paddle import numpy as np -import random -import logging import paddle -from paddle.nn.clip import _squared_l2_norm - -from ernie_utils.top2_gate import ( - CalAuxLossFunctor, - cal_aux_loss_func, -) from paddle.incubate.nn.functional import build_src_rank_and_local_expert_id -from ernie_utils.moe_layer import fuse_logging logger = logging.getLogger(__name__) - class TestFusedCalculateAuxLoss(unittest.TestCase): def test_build_src_rank_and_local_expert_id(self): def orig_func(expert_num_global_list, num_local_experts): send_rank_cpu = np.concatenate( # TOO SLOW!!! 
break every thing - [np.full([j], i // num_local_experts, dtype="int32") for i, j in enumerate(expert_num_global_list)], + [ + np.full([j], i // num_local_experts, dtype="int32") + for i, j in enumerate(expert_num_global_list) + ], 0, ) local_expert_id_cpu = np.concatenate( - [np.full([j], i % num_local_experts, dtype="int32") for i, j in enumerate(expert_num_global_list)], + [ + np.full([j], i % num_local_experts, dtype="int32") + for i, j in enumerate(expert_num_global_list) + ], 0, ) send_rank = paddle.to_tensor(send_rank_cpu) local_expert_id = paddle.to_tensor(local_expert_id_cpu) return send_rank, local_expert_id - def fused_func(expert_num_global_tensor, expert_num_global, num_local_experts): + def fused_func( + expert_num_global_tensor, expert_num_global, num_local_experts + ): return build_src_rank_and_local_expert_id( expert_num_global_tensor, expert_num_global, num_local_experts ) - expert_num_global = np.random.randint(0, 512, size=[12 * 8],dtype="int32") - expert_num_global_tensor = paddle.to_tensor(expert_num_global, dtype="int64") + expert_num_global = np.random.randint( + 0, 512, size=[12 * 8], dtype="int32" + ) + expert_num_global_tensor = paddle.to_tensor( + expert_num_global, dtype="int64" + ) s1, l1 = orig_func(expert_num_global, 12) s2, l2 = fused_func(expert_num_global_tensor, expert_num_global, 12) @@ -53,6 +64,5 @@ def fused_func(expert_num_global_tensor, expert_num_global, num_local_experts): assert ((l1 - l2) == 0).all(), (l1, l2) - if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/legacy_test/test_incubate_expand_modality_expert_id.py b/test/legacy_test/test_incubate_expand_modality_expert_id.py index 699c55131b33a6..9f1d41e49697fe 100644 --- a/test/legacy_test/test_incubate_expand_modality_expert_id.py +++ b/test/legacy_test/test_incubate_expand_modality_expert_id.py @@ -1,26 +1,36 @@ -import os +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
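# Worked example of the reference mapping in `orig_func` above (numbers are
# illustrative): with two experts hosted per rank, the routed-token counts per
# global expert expand into a per-token source rank and local expert id.
import numpy as np

expert_num_global_list = [2, 0, 3]   # tokens routed to global experts 0, 1, 2
num_local_experts = 2                # experts hosted on each rank
send_rank = np.concatenate(
    [np.full([j], i // num_local_experts, dtype="int32")
     for i, j in enumerate(expert_num_global_list)]
)
local_expert_id = np.concatenate(
    [np.full([j], i % num_local_experts, dtype="int32")
     for i, j in enumerate(expert_num_global_list)]
)
# send_rank       -> [0, 0, 1, 1, 1]   (experts 0-1 live on rank 0, expert 2 on rank 1)
# local_expert_id -> [0, 0, 0, 0, 0]   (expert 0 of rank 0, then expert 0 of rank 1)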
+ import unittest +from collections import namedtuple +from functools import partial + +from ernie_utils.moe_all_gather_layer import MOEAllGatherLayerV2 import paddle -from paddle import _C_ops -import numpy as np -from op_test import convert_float_to_uint16 -import sys -from functools import partial -from collections import namedtuple import paddle.nn.functional as F +from paddle.incubate.nn.functional import expand_modality_expert_id -from paddle.autograd import PyLayer -from paddle import base -from paddle.base import core -import paddle.incubate.nn.functional.expand_modality_expert_id as expand_modality_expert_id -from ernie_utils.moe_all_gather_layer import MOEAllGatherLayerV2 - -def fused_gate_logits_process_ref(self, gate_logits_lm, gate_logits_mm, token_type_ids): +def fused_gate_logits_process_ref( + self, gate_logits_lm, gate_logits_mm, token_type_ids +): """process gatelogits""" top_k = self.k - num_expert_per_rank_per_modality = gate_logits_lm.shape[-1] // self.config.moe_world_size + num_expert_per_rank_per_modality = ( + gate_logits_lm.shape[-1] // self.config.moe_world_size + ) @paddle.no_grad() def shift_ids(ids, modality_offset): @@ -34,7 +44,9 @@ def shift_ids(ids, modality_offset): ) if self.group_experts: - gate_logits_lm = gate_logits_lm.reshape([gate_logits_lm.shape[0], top_k, -1]) + gate_logits_lm = gate_logits_lm.reshape( + [gate_logits_lm.shape[0], top_k, -1] + ) prob_lm = self.gate.act(gate_logits_lm) weight_lm, expert_id_lm = prob_lm.topk(k=1, axis=-1) weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) @@ -48,9 +60,15 @@ def shift_ids(ids, modality_offset): if token_type_ids is not None: expert_id_lm = shift_ids(expert_id_lm, 0) expert_id_lm.stop_gradient = True - lm_weight_and_expert_id = paddle.concat([weight_lm, expert_id_lm.astype("float32")], -1) + lm_weight_and_expert_id = paddle.concat( + [weight_lm, expert_id_lm.astype("float32")], -1 + ) if token_type_ids is None: - return lm_weight_and_expert_id, prob_lm.reshape([prob_lm.shape[0], -1]), None + return ( + lm_weight_and_expert_id, + prob_lm.reshape([prob_lm.shape[0], -1]), + None, + ) prob_mm = self.gate.act(gate_logits_mm) weight_mm, expert_id_mm = prob_mm.topk(k=top_k, axis=-1) @@ -58,16 +76,27 @@ def shift_ids(ids, modality_offset): expert_id_mm = shift_ids(expert_id_mm, 1) expert_id_mm.stop_gradient = True - mm_weight_and_expert_id = paddle.concat([weight_mm, expert_id_mm.astype("float32")], -1) + mm_weight_and_expert_id = paddle.concat( + [weight_mm, expert_id_mm.astype("float32")], -1 + ) token_type_ids_float = token_type_ids[:, None].astype("float32") weight_and_expert = ( - 1 - token_type_ids_float - ) * lm_weight_and_expert_id + token_type_ids_float * mm_weight_and_expert_id + (1 - token_type_ids_float) * lm_weight_and_expert_id + + token_type_ids_float * mm_weight_and_expert_id + ) return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm + def test_expand_modality_expert_id(): - def expand_id_one(expert_id, num_expert_per_modality, k, group_size, modality_offset, is_group_expert): + def expand_id_one( + expert_id, + num_expert_per_modality, + k, + group_size, + modality_offset, + is_group_expert, + ): orig_shape = expert_id.shape expert_id = expert_id.reshape([-1]) xid = paddle.arange(len(expert_id)) @@ -77,22 +106,42 @@ def expand_id_one(expert_id, num_expert_per_modality, k, group_size, modality_of rank = expert_id // num_expert_per_modality expert_id_in_rank = expert_id % num_expert_per_modality - ret = rank * (num_expert_per_modality * 2) + expert_id_in_rank + 
modality_offset * num_expert_per_modality + ret = ( + rank * (num_expert_per_modality * 2) + + expert_id_in_rank + + modality_offset * num_expert_per_modality + ) return ret.reshape(orig_shape) S, E, k = 100, 24, 3 expert_id_mm = paddle.randint(0, 12, shape=[S, k]) num_expert_per_rank_per_modality = E // 2 // 4 group_size = E // 2 // k - print(f"num_expert_per_rank_per_modality: {num_expert_per_rank_per_modality}") - fused = expand_modality_expert_id(expert_id_mm, num_expert_per_rank_per_modality, group_size, 1, True) + print( + f"num_expert_per_rank_per_modality: {num_expert_per_rank_per_modality}" + ) + fused = expand_modality_expert_id( + expert_id_mm, num_expert_per_rank_per_modality, group_size, 1, True + ) - nonfused = expand_id_one(expert_id_mm, num_expert_per_rank_per_modality, k, group_size, 1, True) + nonfused = expand_id_one( + expert_id_mm, num_expert_per_rank_per_modality, k, group_size, 1, True + ) # num_expert_per_rank_per_modality, group_size assert (fused == nonfused).all().item() Config = namedtuple("Config", ["moe_world_size"]) - Self = namedtuple("Self", ["config", "k", "gate", "group_experts", "moe_statics", "use_correction_bias"]) + Self = namedtuple( + "Self", + [ + "config", + "k", + "gate", + "group_experts", + "moe_statics", + "use_correction_bias", + ], + ) Gate = namedtuple("Gate", ["act"]) fake_gate = Gate(act=partial(F.softmax, axis=-1)) fake_self = Self( @@ -109,18 +158,24 @@ def expand_id_one(expert_id, num_expert_per_modality, k, group_size, modality_of fake_logits = paddle.randn([S, E]) fake_logits_mm = paddle.randn([S, E]) token_type_ids = paddle.randint(0, 2, shape=[S]) - w_and_e, prob_lm, prob_mm = MOEAllGatherLayerV2.fused_gate_logits_process_fused( + w_and_e, prob_lm, prob_mm = ( + MOEAllGatherLayerV2.fused_gate_logits_process_fused( + fake_self, fake_logits, fake_logits_mm, None + ) + ) + w_and_e_ref, prob_lm_ref, prob_mm_ref = fused_gate_logits_process_ref( fake_self, fake_logits, fake_logits_mm, None ) - w_and_e_ref, prob_lm_ref, prob_mm_ref = fused_gate_logits_process_ref(fake_self, fake_logits, fake_logits_mm, None) assert (prob_lm == prob_lm_ref).all().item() assert (w_and_e == w_and_e_ref).all().item() w, e = w_and_e_ref.chunk(2, axis=-1) + class Test_expand_modality_expert_id_API(unittest.TestCase): def test_dygraph(self): test_expand_modality_expert_id() + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/legacy_test/test_incubate_fused_loss.py b/test/legacy_test/test_incubate_fused_loss.py index 65b792815d08dd..e6fe14a2d295f2 100644 --- a/test/legacy_test/test_incubate_fused_loss.py +++ b/test/legacy_test/test_incubate_fused_loss.py @@ -1,29 +1,32 @@ -import os -import unittest - -from op_test import convert_float_to_uint16 -import random -import paddle.nn.functional as F +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
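# Worked example of the per-element mapping implemented by `expand_id_one` above
# (ignoring the optional group-expert pre-shift); the numbers are illustrative:
# each rank hosts its slice of LM experts followed by its multimodal experts, so
# a modality-local id is re-based into that interleaved layout.
num_expert_per_modality = 3   # experts per rank per modality
modality_offset = 1           # 1 selects the multimodal slice
expert_id = 4                 # id inside the multimodal expert table
rank = expert_id // num_expert_per_modality        # -> 1
id_in_rank = expert_id % num_expert_per_modality   # -> 1
global_id = (
    rank * (num_expert_per_modality * 2)
    + id_in_rank
    + modality_offset * num_expert_per_modality
)
# -> 1 * 6 + 1 + 3 = 10, i.e. slot 10 in the per-rank [LM | MM] layout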
-import paddle -import numpy as np -import random import logging +import unittest -import paddle -from paddle.nn.clip import _squared_l2_norm - +import numpy as np from ernie_utils.top2_gate import ( - CalAuxLossFunctor, cal_aux_loss_func, ) + +import paddle +import paddle.nn.functional as F from paddle.incubate.nn.functional import cal_aux_loss -from ernie_utils.moe_layer import fuse_logging logger = logging.getLogger(__name__) - class TestFusedCalculateAuxLoss(unittest.TestCase): def setUp(self): paddle.seed(42) @@ -48,19 +51,40 @@ def run_and_check( input_for_test.stop_gradient = False loss_ref = cal_aux_loss_func( - input_for_ref, dispatch_mask_for_ref, tokens_mask, dispatch_tokens_mask, num_experts, use_group, moe_k + input_for_ref, + dispatch_mask_for_ref, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + ) + loss, _, _ = cal_aux_loss( + input_for_test, + dispatch_mask_for_test, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + 1e-6, ) - loss,_,_= cal_aux_loss( - input_for_test, dispatch_mask_for_test, tokens_mask, dispatch_tokens_mask, num_experts, use_group, moe_k, 1e-6) loss_ref.backward() loss.backward() np.testing.assert_equal(loss.shape, loss_ref.shape) np.testing.assert_equal(loss.dtype, loss_ref.dtype) - np.testing.assert_equal(input_for_ref.grad.shape, input_for_test.grad.shape) - np.testing.assert_equal(input_for_ref.grad.dtype, input_for_test.grad.dtype) + np.testing.assert_equal( + input_for_ref.grad.shape, input_for_test.grad.shape + ) + np.testing.assert_equal( + input_for_ref.grad.dtype, input_for_test.grad.dtype + ) np.testing.assert_allclose( - loss.astype("float32").numpy(), loss_ref.astype("float32").numpy(), atol=self.atol, rtol=self.rtol + loss.astype("float32").numpy(), + loss_ref.astype("float32").numpy(), + atol=self.atol, + rtol=self.rtol, ) np.testing.assert_allclose( input_for_test.grad.astype("float32").numpy(), @@ -81,16 +105,30 @@ def run_single_case( for use_dispatch_tokens_mask in [True, False]: paddle.seed(48) gate_prob = paddle.randn([seq_len, expert_num]) - dispatch_mask = paddle.randint(0, seq_len, [expert_num]).astype("int64") - tokens_mask = paddle.randint(0, 1, [seq_len]).astype(gate_prob.dtype) if use_tokens_mask else None + dispatch_mask = paddle.randint( + 0, seq_len, [expert_num] + ).astype("int64") + tokens_mask = ( + paddle.randint(0, 1, [seq_len]).astype(gate_prob.dtype) + if use_tokens_mask + else None + ) dispatch_tokens_mask = ( - paddle.randint(0, 1, [seq_len * 2]).astype("bool") if use_dispatch_tokens_mask else None + paddle.randint(0, 1, [seq_len * 2]).astype("bool") + if use_dispatch_tokens_mask + else None ) self.run_and_check( - gate_prob, dispatch_mask, tokens_mask, dispatch_tokens_mask, g_num_experts, moe_k, use_group + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + g_num_experts, + moe_k, + use_group, ) - def test_trival_cases(self): + def test_trivial_cases(self): self.run_single_case(seq_len=1, expert_num=1) self.run_single_case(seq_len=3, expert_num=2) self.run_single_case(seq_len=13, expert_num=3) @@ -114,17 +152,31 @@ def test_trival_cases(self): self.run_single_case(seq_len=256 * 1024, expert_num=48) self.run_single_case(seq_len=512 * 1024, expert_num=128) - def run_special_case(self, global_seq_len, seq_len, global_expert_num, expert_num, moe_k): + def run_special_case( + self, global_seq_len, seq_len, global_expert_num, expert_num, moe_k + ): for use_group in [True, False]: paddle.seed(48) seq_len = 4096 expert_num = 48 gate_prob = 
F.softmax(paddle.randn([seq_len, expert_num]), axis=-1) - dispatch_mask = paddle.randint(0, seq_len, [seq_len, expert_num]).astype("int64") - tokens_mask = paddle.randint(0, 1, [seq_len]).astype(gate_prob.dtype) - dispatch_tokens_mask = paddle.randint(0, 1, [global_seq_len]).astype("bool") + dispatch_mask = paddle.randint( + 0, seq_len, [seq_len, expert_num] + ).astype("int64") + tokens_mask = paddle.randint(0, 1, [seq_len]).astype( + gate_prob.dtype + ) + dispatch_tokens_mask = paddle.randint( + 0, 1, [global_seq_len] + ).astype("bool") self.run_and_check( - gate_prob, dispatch_mask, tokens_mask, dispatch_tokens_mask, global_expert_num, moe_k, use_group + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + global_expert_num, + moe_k, + use_group, ) def test_special_cases(self): @@ -149,4 +201,4 @@ def test_special_cases(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/legacy_test/test_incubate_fused_rmsnorm_ext.py b/test/legacy_test/test_incubate_fused_rmsnorm_ext.py index cf9877cb6f5994..89dc90aeb2f18f 100644 --- a/test/legacy_test/test_incubate_fused_rmsnorm_ext.py +++ b/test/legacy_test/test_incubate_fused_rmsnorm_ext.py @@ -13,20 +13,22 @@ # limitations under the License. import unittest + import numpy as np + import paddle -import paddle.nn.functional as F from paddle.incubate.nn.functional import fused_rms_norm_ext # 假设 fused_rms_norm_ext 已经被导入 # from your_module import fused_rms_norm_ext + class TestFusedRMSNorm(unittest.TestCase): def setUp(self): # 设置随机种子以确保结果可复现 paddle.seed(2023) np.random.seed(2023) - + def rms_norm_reference(self, x, scale, bias=None, epsilon=1e-5): """ 使用 Paddle 原生操作实现 RMS Normalization 作为参考 @@ -42,43 +44,46 @@ def rms_norm_reference(self, x, scale, bias=None, epsilon=1e-5): # 应用偏置(如果有) if bias is not None: y = y + bias.reshape([1, -1]) - + # 返回归一化后的张量、均值(RMS Norm 中为0)和逆标准差 return y, (1.0 / rms).squeeze(-1) - + def test_2d_input(self): # 测试 2D 输入 rows, cols = 32, 64 x = paddle.randn([rows, cols]) scale = paddle.randn([cols]) - + # 使用我们的实现 y_fused, invvar_fused = fused_rms_norm_ext(x, scale) - + # 使用参考实现 y_ref, invvar_ref = self.rms_norm_reference(x, scale) - + # 验证结果 np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5) - - + np.testing.assert_allclose( + invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 + ) + def test_without_bias(self): # 测试没有偏置的情况 rows, cols = 32, 64 x = paddle.randn([rows, cols]) scale = paddle.randn([cols]) - + # 使用我们的实现 y_fused, invvar_fused = fused_rms_norm_ext(x, scale) - + # 使用参考实现 y_ref, invvar_ref = self.rms_norm_reference(x, scale) - + # 验证结果 np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5) - + np.testing.assert_allclose( + invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 + ) + def test_backward(self): # 测试反向传播 rows, cols = 16, 32 @@ -86,22 +91,22 @@ def test_backward(self): x.stop_gradient = False scale = paddle.randn([cols], dtype='float32') scale.stop_gradient = False - + # 前向传播 y_fused, invvar = fused_rms_norm_ext(x, scale) - + # 计算损失并反向传播 loss = paddle.mean(y_fused) loss.backward() - + # 获取梯度 x_grad_fused = x.grad.clone() scale_grad_fused = scale.grad.clone() - + # 重置梯度 x.clear_gradient() scale.clear_gradient() - + # 使用参考实现 y_ref, invvar_ref = self.rms_norm_reference(x, scale) loss_ref = paddle.mean(y_ref) @@ -110,10 +115,15 @@ def test_backward(self): # 获取参考梯度 
x_grad_ref = x.grad scale_grad_ref = scale.grad - + # 验证梯度 - np.testing.assert_allclose(x_grad_fused, x_grad_ref, rtol=1e-4, atol=1e-4) - np.testing.assert_allclose(scale_grad_fused, scale_grad_ref, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose( + x_grad_fused, x_grad_ref, rtol=1e-4, atol=1e-4 + ) + np.testing.assert_allclose( + scale_grad_fused, scale_grad_ref, rtol=1e-4, atol=1e-4 + ) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/legacy_test/test_incubate_int_bincount.py b/test/legacy_test/test_incubate_int_bincount.py index daf983d4a65a97..46f43cf791c35b 100644 --- a/test/legacy_test/test_incubate_int_bincount.py +++ b/test/legacy_test/test_incubate_int_bincount.py @@ -1,30 +1,47 @@ -import paddle -import numpy as np +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest + +import numpy as np + +import paddle from paddle.incubate.nn.functional import int_bincount + class TestIntBincount(unittest.TestCase): def setUp(self): paddle.set_device('gpu') - + def test_basic(self): x = paddle.to_tensor([1, 2, 3, 1, 2, 3], dtype=paddle.int32) out = int_bincount(x, low=1, high=4, dtype=paddle.int32) - expected = np.array([2, 2, 2,0]) + expected = np.array([2, 2, 2, 0]) np.testing.assert_array_equal(out.numpy(), expected) - + def test_empty_input(self): x = paddle.to_tensor([], dtype=paddle.int32) out = int_bincount(x, low=0, high=10, dtype=paddle.int32) self.assertEqual(out.shape, [11]) self.assertEqual(out.sum().item(), 0) - + def test_different_dtypes(self): x = paddle.to_tensor([1, 3, 5, 3, 1], dtype=paddle.int64) out = int_bincount(x, low=1, high=6, dtype=paddle.int64) expected = np.array([2, 0, 2, 0, 1, 0]) np.testing.assert_array_equal(out.numpy(), expected) - + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/legacy_test/test_incubate_moe_combine.py b/test/legacy_test/test_incubate_moe_combine.py index 6c39a1af222933..2c765e13671230 100644 --- a/test/legacy_test/test_incubate_moe_combine.py +++ b/test/legacy_test/test_incubate_moe_combine.py @@ -1,17 +1,27 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
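A note on what the combine tests in this file check: moe_combine is assumed here to implement the standard top-k combine, y[i, :] = sum_k combine_weights[i, k] * x[scatter_index[i, k], :], which is also what the GateCombine reference it is compared against computes. A NumPy sketch under that assumption (moe_combine_ref is an illustrative name, not an API in the patch):

import numpy as np

def moe_combine_ref(x, combine_weights, scatter_index):
    # x: [num_dispatched_rows, d]; combine_weights, scatter_index: [s, k] -> y: [s, d]
    s, k = combine_weights.shape
    y = np.zeros((s, x.shape[1]), dtype=x.dtype)
    for i in range(s):
        for j in range(k):
            w = combine_weights[i, j]
            if w != 0:  # dropped slots carry zero weight and contribute nothing
                y[i] += w * x[scatter_index[i, j]]
    return y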
+ import os +import random import unittest import numpy as np -from op_test import convert_float_to_uint16 -import random -import paddle.nn.functional as F +from ernie_utils.moe_layer_uneven import GateCombine import paddle -from paddle import base -from paddle.base import core +import paddle.nn.functional as F from paddle.incubate.nn.functional import moe_combine -from ernie_utils.moe_layer_uneven import GateCombine - os.environ["FLAGS_flash_attn_version"] = "v1" os.environ["FLAGS_cudnn_deterministic"] = "1" @@ -37,7 +47,9 @@ def combining(x, combine_weights, scatter_index, hard_gate=False): return y -def baseline_result(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy): +def baseline_result( + x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy +): """baseline_result""" scatter_index = paddle.to_tensor(scatter_index_numpy) x = paddle.to_tensor(x_numpy).cast("float32") @@ -54,7 +66,9 @@ def baseline_result(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_nu return [x.grad, combine_weights.grad, y] -def test_moe_combine(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy): +def test_moe_combine( + x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy +): """baseline_result""" x = paddle.to_tensor(x_numpy).cast("float32") x.stop_gradient = False @@ -67,7 +81,7 @@ def test_moe_combine(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_n y = GateCombine.apply(x, combine_weights, scatter_index) paddle.autograd.backward([y], [grad], True) - #grad.backward() + # grad.backward() return [x.grad, combine_weights.grad, y] @@ -78,7 +92,9 @@ def gen_test_case(S, K, Dim, capacity_factor, seed=1234): paddle.seed(seed) x_numpy = np.random.rand(int(S * capacity_factor), Dim).astype(np.float32) combine_weights_numpy = np.random.rand(S, K).astype(np.float32) - scatter_index_numpy = np.random.permutation(max(x_numpy.shape[0], S * K))[: S * K].astype("int64") + scatter_index_numpy = np.random.permutation(max(x_numpy.shape[0], S * K))[ + : S * K + ].astype("int64") scatter_index_numpy = scatter_index_numpy.reshape([S, K]) combine_weights_numpy[scatter_index_numpy >= x_numpy.shape[0]] = 0 @@ -90,9 +106,14 @@ def gen_test_case(S, K, Dim, capacity_factor, seed=1234): def testing(test_case): """testing""" [bl_x_grad, bl_combine_weights_grad, bl_y] = baseline_result(*test_case) - [fused_x_grad, fused_combine_weights_grad, fused_y] = test_moe_combine(*test_case) + [fused_x_grad, fused_combine_weights_grad, fused_y] = test_moe_combine( + *test_case + ) np.testing.assert_allclose( - fused_y.astype("float32").numpy(), bl_y.astype("float32").numpy(), err_msg="fwd precision not pass", rtol=1e-6 + fused_y.astype("float32").numpy(), + bl_y.astype("float32").numpy(), + err_msg="fwd precision not pass", + rtol=1e-6, ) np.testing.assert_allclose( fused_x_grad.astype("float32").numpy(), @@ -106,6 +127,7 @@ def testing(test_case): rtol=1e-6, ) + class TestFused(unittest.TestCase): @unittest.skipIf(moe_combine is None, "test_moe_combine not installed") def test_cap_lt_2( @@ -171,6 +193,7 @@ def test_k_gt_2( """ testing(gen_test_case(S=1024, K=8, Dim=4096, capacity_factor=2)) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py index 0a0f7a586a3f6b..0a19402605211d 100644 --- a/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py +++ 
b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py @@ -1,14 +1,25 @@ +# ruff: noqa: C419 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest -import sys -from functools import partial -import numpy as np -from collections import namedtuple import paddle -from paddle.autograd import PyLayer -import paddle.nn.functional as F -from ernie_utils.moe_layer_uneven import GateDispatch -from paddle.incubate.nn.functional import moe_gate_dispatch_partial_nosoftmaxtopk, moe_gate_dispatch +from paddle.incubate.nn.functional import ( + moe_gate_dispatch, + moe_gate_dispatch_partial_nosoftmaxtopk, +) def test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op(): @@ -22,7 +33,10 @@ def test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op(): x = paddle.arange(1, s + 1).unsqueeze(-1).expand([s, d]).astype("bfloat16") x_ = x.clone().detach() - t = ((paddle.arange(0, e)).unsqueeze(0) + paddle.arange(0, -s, -1).unsqueeze(-1)) % e + t = ( + (paddle.arange(0, e)).unsqueeze(0) + + paddle.arange(0, -s, -1).unsqueeze(-1) + ) % e gate_logits = (1 / (t + 1)).astype("float32") # gate_logits = F.softmax(paddle.randn([s,e]),-1).astype('float32') gate_logits_ = gate_logits.clone().detach() @@ -76,13 +90,15 @@ def check_ascend(index_rev, chunks): comm_sum = paddle.stack(comm).sum(0) ys_sum = paddle.concat(ys) - y_, combine_weihgts_, scatter_index_, expert_offset_, expert_id_ = moe_gate_dispatch( - x_, - gate_logits_, - None, - k=k, - capacity=cap, - use_pad=True, # k # cap + y_, combine_weihgts_, scatter_index_, expert_offset_, expert_id_ = ( + moe_gate_dispatch( + x_, + gate_logits_, + None, + k=k, + capacity=cap, + use_pad=True, # k # cap + ) ) valid_y = y_.sum(-1) > 0.0 y_2 = y_[valid_y].squeeze() @@ -103,12 +119,19 @@ def check_ascend(index_rev, chunks): """ ) - print(f"<<< begin backward>>>") + print("<<< begin backward>>>") - assert combine_weihgts_.shape == combine_weihgts.shape, (combine_weihgts_.shape, combine_weihgts.shape) + assert combine_weihgts_.shape == combine_weihgts.shape, ( + combine_weihgts_.shape, + combine_weihgts.shape, + ) - dysum, dcombine_weights_sum = paddle.ones_like(ys_sum), paddle.randn(comm_sum.shape).astype(comm_sum.dtype) - dy_, dcombine_weights_ = paddle.ones_like(y_), paddle.ones_like(combine_weihgts_) + dysum, dcombine_weights_sum = paddle.ones_like(ys_sum), paddle.randn( + comm_sum.shape + ).astype(comm_sum.dtype) + dy_, dcombine_weights_ = paddle.ones_like(y_), paddle.ones_like( + combine_weihgts_ + ) dy_[~valid_y] = 0 y_shapes = [len(y) for y in ys] @@ -133,8 +156,6 @@ def check_ascend(index_rev, chunks): ) - - def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop(): S, E, D = 3, 4, 3 @@ -142,7 +163,9 @@ def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop(): capacity = 2 x = (paddle.arange(S) + 1).unsqueeze(-1).expand([S, D]).astype("bfloat16") cw = paddle.randn([S, k]) - eid = paddle.to_tensor([[0, 1], [0, 1], [0, 2]], dtype="int32") # 1 # 2 # 3 + eid = paddle.to_tensor( + 
[[0, 1], [0, 1], [0, 2]], dtype="int32" + ) # 1 # 2 # 3 ( y, cw_, @@ -167,7 +190,9 @@ def test_moe_ops_partial_nosoftmax_topk_empty_output(): x = (paddle.arange(S) + 1).unsqueeze(-1).expand([S, D]).astype("bfloat16") cw = paddle.randn([S, k]) paddle.device.synchronize() - eid = paddle.to_tensor([[0, 1], [0, 1], [0, 2]], dtype="int32") # 1 # 2 # 3 + eid = paddle.to_tensor( + [[0, 1], [0, 1], [0, 2]], dtype="int32" + ) # 1 # 2 # 3 ( y, cw_, diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py index 0b1909d287be76..56d9ddd397a776 100644 --- a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py @@ -1,18 +1,30 @@ # !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os -import sys import unittest -import contextlib import numpy as np -import random -import time + import paddle -from paddle import _C_ops -from paddle.autograd import PyLayer import paddle.nn.functional as F -from paddle.incubate.nn.functional import moe_gate_dispatch, moe_gate_dispatch_permute - +from paddle.incubate.nn.functional import ( + moe_gate_dispatch, + moe_gate_dispatch_permute, +) os.environ["FLAGS_flash_attn_version"] = "v1" os.environ["FLAGS_cudnn_deterministic"] = "1" @@ -38,26 +50,36 @@ def test_moe_ops(self): bias = paddle.zeros([E], dtype="float32") cap = 512 - y, combine_weihgts, scatter_index, expert_offset_, expert_id_ = moe_gate_dispatch( - x, - gate_logits, - None, - k=k, - capacity=cap, - use_pad=True, # k # cap + y, combine_weihgts, scatter_index, expert_offset_, expert_id_ = ( + moe_gate_dispatch( + x, + gate_logits, + None, + k=k, + capacity=cap, + use_pad=True, # k # cap + ) ) - y_, combine_weihgts_, scatter_index_, expert_offset_, expert_id_ = moe_gate_dispatch( - x_, - gate_logits_, - bias + 1, # +1也不会破坏路由结果 - k=k, - capacity=cap, - use_pad=True, # k # cap + y_, combine_weihgts_, scatter_index_, expert_offset_, expert_id_ = ( + moe_gate_dispatch( + x_, + gate_logits_, + bias + 1, # +1也不会破坏路由结果 + k=k, + capacity=cap, + use_pad=True, # k # cap + ) ) bias_unbalanced = bias.clone() bias_unbalanced[0] += 1 - y__, combine_weihgts__, scatter_index__, expert_offset__, expert_id__ = moe_gate_dispatch( + ( + y__, + combine_weihgts__, + scatter_index__, + expert_offset__, + expert_id__, + ) = moe_gate_dispatch( x_, gate_logits_, bias_unbalanced, @@ -66,7 +88,9 @@ def test_moe_ops(self): use_pad=True, # k # cap ) np.testing.assert_equal( - y.astype("float32").numpy(), y_.astype("float32").numpy(), err_msg="incubate w bias not match" + y.astype("float32").numpy(), + y_.astype("float32").numpy(), + err_msg="incubate w bias not match", ) # bias 不影响 prob 概率 np.testing.assert_equal( @@ -75,7 +99,9 @@ def test_moe_ops(self): err_msg="incubate w bias not match", ) np.testing.assert_( - (y.astype("float32").numpy(0) != y__.astype("float32").numpy()).any(), + ( + 
y.astype("float32").numpy(0) != y__.astype("float32").numpy() + ).any(), ) @@ -93,14 +119,24 @@ def get_stage_input_list(self, x, world_size, stage): stage_input_list = [] x_list = paddle.split(x, num_or_sections=(world_size * stage), axis=0) for stage_id in range(stage): - stage_input_list.append(paddle.unsqueeze(paddle.concat(x_list[stage_id::stage], axis=0), axis=0)) + stage_input_list.append( + paddle.unsqueeze( + paddle.concat(x_list[stage_id::stage], axis=0), axis=0 + ) + ) stage_input_list = paddle.concat(stage_input_list, axis=0) return stage_input_list def test_moe_permute_ops(self): paddle.seed(2025) - test_cases = [(8, 4, 2), (64, 16, 32), (1024, 1024, 1024), (8, 2, 4), (4096, 4096, 4096)] + test_cases = [ + (8, 4, 2), + (64, 16, 32), + (1024, 1024, 1024), + (8, 2, 4), + (4096, 4096, 4096), + ] cases = list(zip(*test_cases)) for _, case in enumerate(cases): world_size, num_experts, num_tokens, k, hidden_size = case @@ -108,7 +144,9 @@ def test_moe_permute_ops(self): stages = num_experts // world_size input = paddle.randn([num_tokens, hidden_size], dtype="float32") - prob_logits = paddle.randn([num_tokens, num_experts], dtype="float32") + prob_logits = paddle.randn( + [num_tokens, num_experts], dtype="float32" + ) prob = F.softmax(prob_logits, axis=-1) input.stop_gradient = False prob.stop_gradient = False @@ -122,9 +160,18 @@ def test_moe_permute_ops(self): ref_scatter_index, ref_dispatch_mask, _, - ) = moe_gate_dispatch(ref_input, ref_prob, *compat_args, k=k, capacity=capacity, use_pad=True) + ) = moe_gate_dispatch( + ref_input, + ref_prob, + *compat_args, + k=k, + capacity=capacity, + use_pad=True, + ) - ref_stage_input_list = self.get_stage_input_list(ref_dispatched_input, world_size, stages) + ref_stage_input_list = self.get_stage_input_list( + ref_dispatched_input, world_size, stages + ) test_input, test_prob = self.get_detached_input(input, prob) ( @@ -134,14 +181,23 @@ def test_moe_permute_ops(self): test_dispatch_mask, _, ) = moe_gate_dispatch_permute( - test_input, test_prob, *compat_args, k=k, capacity=capacity, world_size=world_size + test_input, + test_prob, + *compat_args, + k=k, + capacity=capacity, + world_size=world_size, ) np.testing.assert_equal( - test_dispatched_input.shape, ref_stage_input_list.shape, err_msg="moe_permute_ops not match" + test_dispatched_input.shape, + ref_stage_input_list.shape, + err_msg="moe_permute_ops not match", ) np.testing.assert_equal( - test_dispatched_input._md5sum(), ref_stage_input_list._md5sum(), err_msg="moe_permute_ops not match" + test_dispatched_input._md5sum(), + ref_stage_input_list._md5sum(), + err_msg="moe_permute_ops not match", ) diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py index a1a9f61aee3440..14cb8358078f51 100644 --- a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py @@ -1,17 +1,29 @@ # !/usr/bin/env python3 -import os -import sys + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
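The shape and md5 assertions above say that moe_gate_dispatch_permute is expected to return the expert-major output of moe_gate_dispatch regrouped by [stage][rank], mirroring get_stage_input_list. A NumPy sketch of that regrouping (regroup_by_stage is an illustrative helper; it assumes the row count divides evenly by world_size * stages):

import numpy as np

def regroup_by_stage(rows, world_size, stages):
    # rows: [num_experts * capacity, d], expert-major, as produced with use_pad=True.
    chunks = np.split(rows, world_size * stages, axis=0)
    # stage s gathers every `stages`-th chunk, i.e. one chunk per rank.
    return np.stack(
        [np.concatenate(chunks[s::stages], axis=0) for s in range(stages)]
    )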
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest -import contextlib import numpy as np -import random -import time + import paddle -from paddle import _C_ops -from paddle.autograd import PyLayer import paddle.nn.functional as F -from paddle.incubate.nn.functional import moe_gate_dispatch, moe_gate_dispatch_permute +from paddle.incubate.nn.functional import ( + moe_gate_dispatch, + moe_gate_dispatch_permute, +) batch_size = 4 hidden_size = 2 @@ -24,16 +36,18 @@ class TestLayer(paddle.nn.Layer): def forward(self, x, gate_prob, k, capacity): - y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch( - x, gate_prob, None, k, capacity, True + y, combine_weights, scatter_index, expert_offset, expert_id = ( + moe_gate_dispatch(x, gate_prob, None, k, capacity, True) ) return y, combine_weights, scatter_index, expert_offset, expert_id class TestLayerPermute(paddle.nn.Layer): def forward(self, x, gate_prob, k, capacity): - y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch_permute( - x, gate_prob, None, k, capacity, world_size=world_size + y, combine_weights, scatter_index, expert_offset, expert_id = ( + moe_gate_dispatch_permute( + x, gate_prob, None, k, capacity, world_size=world_size + ) ) return y, combine_weights, scatter_index, expert_offset, expert_id @@ -54,7 +68,9 @@ def check_backward_correctness(layer_cls): input.stop_gradient = False gate_prob.stop_gradient = False - output, combine_weights, scatter_index, expert_offset, expert_id = layer(input, gate_prob, k, capacity) + output, combine_weights, scatter_index, expert_offset, expert_id = layer( + input, gate_prob, k, capacity + ) print(f"output: {output}") print(f"combine_weights: {combine_weights}") @@ -85,8 +101,12 @@ def check_backward_correctness(layer_cls): input_pos.flat[i] += epsilon input_neg.flat[i] -= epsilon - output_pos, _, _, _, _ = layer(paddle.to_tensor(input_pos), gate_prob, k, capacity) - output_neg, _, _, _, _ = layer(paddle.to_tensor(input_neg), gate_prob, k, capacity) + output_pos, _, _, _, _ = layer( + paddle.to_tensor(input_pos), gate_prob, k, capacity + ) + output_neg, _, _, _, _ = layer( + paddle.to_tensor(input_neg), gate_prob, k, capacity + ) ''' flattened[i] = (output_pos.astype("float32").numpy() - output_neg.astype("float32").numpy()).sum() / ( @@ -94,14 +114,17 @@ def check_backward_correctness(layer_cls): ) ''' grad_value = (output_pos - output_neg).sum() / (2 * epsilon) - flattened[i] = grad_value + flattened[i] = grad_value flattened = flattened.reshape(input.shape) print(f"input gradient: {input.grad}") print(f"numerical gradient: {flattened}") np.testing.assert_allclose( - input.grad.astype("float32").numpy(), flattened.astype("float32").numpy(), rtol=1e-5, atol=0 + input.grad.astype("float32").numpy(), + flattened.astype("float32").numpy(), + rtol=1e-5, + atol=0, ) # 数值估算 gate_prob @@ -116,17 +139,26 @@ def check_backward_correctness(layer_cls): input_pos.flat[i] += epsilon input_neg.flat[i] -= epsilon - _, output_pos, _, _, _ = layer(input, paddle.to_tensor(input_pos), k, capacity) - _, output_neg, _, _, _ = layer(input, paddle.to_tensor(input_neg), k, capacity) + _, 
output_pos, _, _, _ = layer( + input, paddle.to_tensor(input_pos), k, capacity + ) + _, output_neg, _, _, _ = layer( + input, paddle.to_tensor(input_neg), k, capacity + ) - flattened[i] = (output_pos.numpy() - output_neg.numpy()).sum() / (2 * epsilon) + flattened[i] = (output_pos.numpy() - output_neg.numpy()).sum() / ( + 2 * epsilon + ) flattened = flattened.reshape(gate_prob.shape) print(f"gate_prob gradient: {gate_prob.grad}") print(f"numerical gradient: {flattened}") np.testing.assert_allclose( - gate_prob.grad.astype("float32").numpy(), flattened.astype("float32").numpy(), rtol=1e-4, atol=0 + gate_prob.grad.astype("float32").numpy(), + flattened.astype("float32").numpy(), + rtol=1e-4, + atol=0, ) @@ -139,4 +171,4 @@ def test_moe_permute_backward(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From bba06362ab51d1c6a4e2702d9aac65e371491e7e Mon Sep 17 00:00:00 2001 From: Pan Zhaowu Date: Wed, 28 May 2025 13:46:58 +0000 Subject: [PATCH 54/71] Fix miscs --- paddle/phi/kernels/moe_fuse_bwd_op.h | 2 ++ paddle/phi/kernels/moe_kernel_impl.h | 3 +++ python/paddle/incubate/nn/functional/__init__.py | 4 +++- .../test_incubate_moe_gate_dispatch_w_permute_bwd.py | 5 +++-- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h index e5af322d04f2f5..1b2f474198741f 100644 --- a/paddle/phi/kernels/moe_fuse_bwd_op.h +++ b/paddle/phi/kernels/moe_fuse_bwd_op.h @@ -17,6 +17,7 @@ #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/moe_kernel_impl.h" +#ifdef PADDLE_WITH_CUDA template __global__ void gather_with_mask_permute_kernel( const T* dy, // [s*k, d] @@ -309,3 +310,4 @@ void topk_grad_with_mask_launcher(const T* dy, // [s, k] topk_grad_with_mask<<>>( dy, topk_idx, combine_weights, dx, num_rows, k, num_experts); } +#endif diff --git a/paddle/phi/kernels/moe_kernel_impl.h b/paddle/phi/kernels/moe_kernel_impl.h index 2881463e8d3045..68b84efc9fdfcb 100644 --- a/paddle/phi/kernels/moe_kernel_impl.h +++ b/paddle/phi/kernels/moe_kernel_impl.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
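For context on the backward tests above: they estimate gradients by central differences, perturbing one input element at a time and comparing d(sum of outputs)/dx against autograd. A generic sketch of that estimator (numerical_grad and its default step size are illustrative choices, not taken from the patch):

import numpy as np

def numerical_grad(f, x, eps=5e-3):
    # f maps an ndarray to a scalar; returns the central-difference gradient estimate.
    x = np.asarray(x, dtype=np.float64)
    grad = np.zeros_like(x)
    for i in range(x.size):
        xp, xn = x.copy(), x.copy()
        xp.flat[i] += eps
        xn.flat[i] -= eps
        grad.flat[i] = (f(xp) - f(xn)) / (2.0 * eps)
    return grad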
*/ #pragma once +#ifdef PADDLE_WITH_CUDA #include #include #include @@ -644,3 +645,5 @@ __global__ void initialize_moe_routing_kernel( } } // namespace phi + +#endif diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index bed678f8fa606d..05ec5f17620df3 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -80,7 +80,9 @@ "swiglu", "moe_combine", "expand_modality_expert_id", - "cal_aux_loss" "build_src_rank_and_local_expert_id" "int_bincount", + "cal_aux_loss", + "build_src_rank_and_local_expert_id", + "int_bincount", "fused_rms_norm_ext", "moe_gate_dispatch", "moe_gate_dispatch_permute", diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py index 14cb8358078f51..bf03ffa20d4c1d 100644 --- a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py @@ -146,9 +146,10 @@ def check_backward_correctness(layer_cls): input, paddle.to_tensor(input_neg), k, capacity ) - flattened[i] = (output_pos.numpy() - output_neg.numpy()).sum() / ( - 2 * epsilon + grad_value = paddle.to_tensor( + (output_pos.numpy() - output_neg.numpy()).sum() / (2 * epsilon) ) + flattened[i] = grad_value flattened = flattened.reshape(gate_prob.shape) From fcc4f8114e36c846b1e339cbcd1482669ddc41f0 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Thu, 29 May 2025 08:42:23 +0000 Subject: [PATCH 55/71] try to pass CI --- paddle/phi/kernels/CMakeLists.txt | 2 +- .../gpu/moe_gate_dispatch_grad_kernel.cu | 1 + paddle/phi/kernels/moe_fuse_op.h | 1 + .../nn/functional/moe_gate_dispatch.py | 22 ++++---- ...moe_gate_dispatch_partial_nosoftmaxtopk.py | 55 ------------------- .../functional/moe_gate_dispatch_permute.py | 46 ---------------- 6 files changed, 14 insertions(+), 113 deletions(-) diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 648cf71d2be159..40ac78b7e8c565 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -229,7 +229,7 @@ if(WITH_ROCM) REMOVE_ITEM kernel_gpu "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" - "gpu/moe_gate_dispatch_permute_kernel_grad.cu" + "gpu/moe_gate_dispatch_permute_grad_kernel.cu" "gpu/moe_gate_dispatch_permute_kernel.cu" "gpu/expand_modality_expert_id_kernel.cu" "gpu/moe_combine_kernel.cu" diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu index 6c7724ab2cc8b1..ba37d065f94373 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu @@ -160,5 +160,6 @@ PD_REGISTER_KERNEL(moe_gate_dispatch_grad, ALL_LAYOUT, phi::MoeGateDispatchGradKernel, float, + double, phi::dtype::float16, phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/moe_fuse_op.h b/paddle/phi/kernels/moe_fuse_op.h index 40c2fc3fcf57a5..0fddd734abb2fa 100644 --- a/paddle/phi/kernels/moe_fuse_op.h +++ b/paddle/phi/kernels/moe_fuse_op.h @@ -23,6 +23,7 @@ #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/moe_kernel_impl.h" +#include "paddle/phi/kernels/funcs/math_function.h" template __launch_bounds__(TPB) __global__ diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch.py index 
e0c9051804bfae..2e50f6c1698f22 100644 --- a/python/paddle/incubate/nn/functional/moe_gate_dispatch.py +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch.py @@ -38,19 +38,19 @@ def moe_gate_dispatch( ) -> Tensor: """ Args: - x, - gate_logits, - corr_bias, - k, - capacity, - use_pad + x: + gate_logits: + corr_bias: + k: + capacity: + use_pad: Returns: - y, - combine_weights, - scatter_index, - expert_offset, - expert_id + y: + combine_weights: + scatter_index: + expert_offset: + expert_id: """ if in_dynamic_or_pir_mode(): return _C_ops.moe_gate_dispatch( diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py index e8146589b1ad96..ef591637fb2502 100644 --- a/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py @@ -95,58 +95,3 @@ def moe_gate_dispatch_partial_nosoftmaxtopk( expert_offset, expert_nums_local, ) - - -# import paddle -# import numpy as np - -# num_rows = 4 -# feature_dim = 8 -# num_experts = 3 -# k = 2 -# capacity = 5 - -# # 输入张量 -# x = paddle.to_tensor(np.random.rand(num_rows, feature_dim).astype('float32'), stop_gradient=False) - -# # 合并权重张量 -# combine_weights = paddle.to_tensor(np.random.rand(num_rows, k).astype('float32'), stop_gradient=False) - -# # 专家ID张量 -# expert_id = paddle.to_tensor(np.random.randint(0, num_experts, size=(num_rows, k)).astype('int32'), stop_gradient=False) - -# print("x type:", x.dtype) -# print("combine_weights type:", combine_weights.dtype) -# print("expert_id type:", expert_id.dtype) -# # 其他参数 -# use_pad = True -# expert_start_index = 0 -# expert_end_index = num_experts -# reverse_token_drop = False - -# # 调用自定义算子 -# y, combine_weights_out, scatter_index, scatter_index_rev, expert_offset, expert_nums_local = moe_ops_partial_nosoftmaxtopk( -# x=x, -# combine_weights=combine_weights, -# expert_id=expert_id, -# k=k, -# capacity=capacity, -# num_experts=num_experts, -# use_pad=use_pad, -# expert_start_index=expert_start_index, -# expert_end_index=expert_end_index, -# reverse_token_drop=reverse_token_drop -# ) - -# # 打印结果 -# print("y:", y.numpy()) -# print("combine_weights_out:", combine_weights_out.numpy()) -# print("scatter_index:", scatter_index.numpy()) -# print("scatter_index_rev:", scatter_index_rev.numpy()) -# print("expert_offset:", expert_offset.numpy()) -# print("expert_nums_local:", expert_nums_local.numpy()) - -# a = paddle.sum(y)+paddle.sum(combine_weights_out) -# a.backward() -# print("\n##########backward output##########\n") -# print(f"x.grad: {x.grad}\n combine_weights.grad: {combine_weights.grad}\n expert_id.grad: {expert_id.grad}") diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py index 23e762a4421805..9721590f1443f0 100644 --- a/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py @@ -86,49 +86,3 @@ def moe_gate_dispatch_permute( attrs=attrs, ) return y, combine_weights, scatter_index, expert_offset, expert_id - - -# # 定义输入参数 -# num_rows = 10 # 示例行数 -# hidden_size = 128 # 隐藏层维度 -# num_experts = 4 # 专家数 -# world_size = 2 # 分布式世界大小 -# k = 2 # 选择的Top-k专家 -# capacity = 5 # 每个专家的处理容量 - -# # 确保num_experts可以被world_size整除 -# assert num_experts % world_size == 0 - -# # 生成输入数据 -# x = paddle.randn([num_rows, hidden_size], 
dtype='float32') -# gate_logits = paddle.randn([num_rows, num_experts], dtype='float32') -# x.stop_gradient = False -# gate_logits.stop_gradient = False - -# # 可选的修正偏差 -# # corr_bias = paddle.randn([num_rows], dtype='float32') -# corr_bias = None - -# # 调用封装的API -# y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch_permute( -# x=x, -# gate_logits=gate_logits, -# corr_bias=corr_bias, -# k=k, -# capacity=capacity, -# world_size=world_size -# ) - -# # 打印输出结果的形状和类型,验证结果 -# print("Output y shape:", y.shape) -# print("Combine weights shape:", combine_weights.shape) -# print("Scatter index shape:", scatter_index.shape) -# print("Expert offset shape:", expert_offset.shape) -# print("Expert ID shape:", expert_id.shape) - -# a = paddle.sum(y)+paddle.sum(combine_weights)+paddle.sum(scatter_index)+paddle.sum(expert_offset)+paddle.sum(expert_id) -# a.backward() - -# print("Gradient of x:", x.grad) -# print("Gradient of gate_logits:", gate_logits.grad) -# print("Gradient of corr_bias:", corr_bias.grad) From 29e0b021f0e34305ee2371c7da910b3c490e6e2b Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Thu, 29 May 2025 08:54:24 +0000 Subject: [PATCH 56/71] format header file --- paddle/phi/kernels/fused_moe_bwd_op.h | 351 ------------------ .../gpu/moe_gate_dispatch_grad_kernel.cu | 2 +- paddle/phi/kernels/moe_fuse_bwd_op.h | 4 + paddle/phi/kernels/moe_fuse_op.h | 3 + 4 files changed, 8 insertions(+), 352 deletions(-) delete mode 100644 paddle/phi/kernels/fused_moe_bwd_op.h diff --git a/paddle/phi/kernels/fused_moe_bwd_op.h b/paddle/phi/kernels/fused_moe_bwd_op.h deleted file mode 100644 index 3714e5872d8eb0..00000000000000 --- a/paddle/phi/kernels/fused_moe_bwd_op.h +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
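The header deleted below duplicated backward helpers that already live in moe_fuse_bwd_op.h. For orientation while reading the removal, a NumPy sketch of what its topk_grad_with_mask kernel computes (topk_grad_with_mask_ref is an illustrative name, not part of the patch):

import numpy as np

def topk_grad_with_mask_ref(dy, topk_idx, combine_weights, num_experts):
    # dy, topk_idx, combine_weights: [s, k] -> dx: [s, num_experts].
    # Gradient flows back only to the selected experts; masked slots
    # (combine_weights == 0) receive zero.
    s, k = dy.shape
    dx = np.zeros((s, num_experts), dtype=dy.dtype)
    for i in range(s):
        for j in range(k):
            if combine_weights[i, j] > 0:
                dx[i, topk_idx[i, j]] = dy[i, j]
    return dx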
- -#pragma once -#ifndef _FUSED_MOE_BWD_OP_H_ -#define _FUSED_MOE_BWD_OP_H_ - -#include -#include -#include - -#include "cutlass/array.h" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/epilogue/thread/linear_combination_relu.h" -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" -#include "paddle/phi/backends/gpu/gpu_info.h" -#include "paddle/phi/kernels/funcs/aligned_vector.h" - -#define WARP_SIZE 32 -// Ignore CUTLASS warnings about type punning -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#pragma GCC diagnostic ignored "-Wunused-function" - -#pragma GCC diagnostic pop - -// namespace paddle { -// namespace operators { - -#define CUDA_KERNEL_LOOP(i, n) \ - for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - -template -__global__ void topk_grad_with_mask(const T* dy, // [s, k] - const int* topk_idx, // [s, k] - const T* combine_weights, // [s, k] - T* dx, // [s, e] - int64_t num_rows, // s - int64_t k, // k - int64_t num_experts // e -) { - // init dx to zero - for (int i = blockIdx.x; i < num_rows; i += gridDim.x) { - int base_grad = i * num_experts; - for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { - dx[base_grad + j] = static_cast(0); - } - __syncthreads(); - int base_index = i * k; - for (int j = threadIdx.x; j < k; j += blockDim.x) { - int64_t idx = topk_idx[base_index + j]; - if (combine_weights[base_index + j] > static_cast(0)) { - dx[base_grad + idx] = dy[base_index + j]; - } - } - } -} - -// y=zero_part(topk(x)) 的反向过程 -// x: [s,e] -// dy: [s,k] -// X: [s, e] -(topk)-> Y:[s, k] - (越界设置为0)-> combine_weights: [s, k] -template -void topk_grad_with_mask_launcher(const T* dy, // [s, k] - const int* topk_idx, // [s, k] - const T* combine_weights, // [s, k] - T* dx, // [s, e] - int64_t num_rows, // s - int64_t k, // k - int64_t num_experts, // e - cudaStream_t stream) { - int blocks = num_rows; - int threads = 1024; - - topk_grad_with_mask<<>>( - dy, topk_idx, combine_weights, dx, num_rows, k, num_experts); -} - -template -__global__ void gather_with_mask_permute_kernel( - const T* dy, // [s*k, d] - const int* scatter_index, // [s, k] - const float* combine_weights, // [s, k] - T* dx, // [s, d] - int64_t num_rows, // s - int64_t k, // k - int64_t dim, // d - int64_t N, - int64_t num_active, // skip > num_active pos is num_active specified - int64_t s_shared_num, - int64_t capacity, - int64_t world_size, - int64_t num_local_experts) { - extern __shared__ char shared[]; - int* scatter_index_shared = reinterpret_cast(shared); - float* combine_weights_shared = - reinterpret_cast(shared + s_shared_num * k * sizeof(int)); - int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; - - for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; - idx < N; - idx += blockDim.x * gridDim.x * vec_size) { - int64_t si = idx / dim; - int64_t di_begin = idx % dim; - int64_t si_shared_begin = shared_idx_begin / dim; - int64_t shared_stride = - min(static_cast(blockDim.x), N - shared_idx_begin); - - for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { - if (si_shared_begin * k + i >= num_rows * k) { - break; - } - scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; - combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; - } - __syncthreads(); - - phi::AlignedVector in_vec; - phi::AlignedVector out_vec; - for (int ii = 0; ii < 
vec_size; ++ii) { - out_vec[ii] = static_cast(0); - } - - for (int64_t i = 0; i < k; ++i) { - int64_t scatter_offset = (si - si_shared_begin) * k + i; - int id = scatter_index_shared[scatter_offset]; - if (num_active >= 0 && id >= num_active) { - continue; - } - if (combine_weights_shared[scatter_offset] > 0.f) { - int64_t remaining_after_irank = id % (num_local_experts * capacity); - - int64_t irank = id / (num_local_experts * capacity); - int64_t local_iexpert = remaining_after_irank / capacity; - int64_t row_in_expert = remaining_after_irank % capacity; - int64_t permuted_id = local_iexpert * (world_size * capacity) + - irank * capacity + row_in_expert; - int64_t in_offset = permuted_id * dim + di_begin; - phi::Load(dy + in_offset, &in_vec); - for (int64_t j = 0; j < vec_size; ++j) { - out_vec[j] += in_vec[j]; - } - } - } - phi::Store(out_vec, dx + idx); - shared_idx_begin += blockDim.x * gridDim.x * vec_size; - } -} - -template -__global__ void gather_with_mask_kernel( - const T* dy, // [s*k, d] - const int* scatter_index, // [s, k] - const float* combine_weights, // [s, k] - T* dx, // [s, d] - int64_t num_rows, // s - int64_t k, // k - int64_t dim, // d - int64_t N, - int64_t num_active, // skip > num_active pos is num_active specified - int64_t s_shared_num) { - extern __shared__ char shared[]; - int* scatter_index_shared = reinterpret_cast(shared); - float* combine_weights_shared = - reinterpret_cast(shared + s_shared_num * k * sizeof(int)); - int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; - - for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; - idx < N; - idx += blockDim.x * gridDim.x * vec_size) { - int64_t si = idx / dim; - int64_t di_begin = idx % dim; - int64_t si_shared_begin = shared_idx_begin / dim; - int64_t shared_stride = - min(static_cast(blockDim.x), N - shared_idx_begin); - - for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { - if (si_shared_begin * k + i >= num_rows * k) { - break; - } - scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; - combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; - } - __syncthreads(); - - phi::AlignedVector in_vec; - phi::AlignedVector out_vec; - for (int ii = 0; ii < vec_size; ++ii) { - out_vec[ii] = static_cast(0); - } - - for (int64_t i = 0; i < k; ++i) { - int64_t scatter_offset = (si - si_shared_begin) * k + i; - int id = scatter_index_shared[scatter_offset]; - if (num_active >= 0 && id >= num_active) { - continue; - } - if (combine_weights_shared[scatter_offset] > 0.f) { - int64_t in_offset = id * dim + di_begin; - phi::Load(dy + in_offset, &in_vec); - for (int64_t j = 0; j < vec_size; ++j) { - out_vec[j] += in_vec[j]; - } - } - } - phi::Store(out_vec, dx + idx); - shared_idx_begin += blockDim.x * gridDim.x * vec_size; - } -} - -template -inline T DivUp(T a, T b) { - return (a + b - 1) / b; -} - -inline int64_t max_shared_s_num(int64_t num_rows, - int64_t dim, - int64_t threads, - int64_t vec_size) { - if ((threads * vec_size) % dim == 0) { - return min(num_rows, threads * vec_size / dim); - } else { - int64_t max_res = DivUp(threads * 4, dim); - for (int64_t idx = 0; idx < num_rows * dim; idx += vec_size * threads) { - int64_t si_start = idx / dim; - int64_t si_end = min(num_rows * dim, idx + vec_size * threads - 1) / dim; - max_res = max(max_res, (si_end - si_start + 1)); - } - return min(num_rows, max_res); - } -} - -template -void gather_with_mask_launcher(const T* dy, // [s*k, d] - const int* scatter_index, // [s, k] - const float* 
combine_weights, // [s, k] - T* dx, // [s,k,d] - int64_t num_rows, // s - int64_t k, // k - int64_t dim, // d - int64_t num_active, - cudaStream_t stream, - bool use_all2all_permute = false, - int64_t world_size = -1, - int64_t num_local_experts = -1, - int64_t capacity = -1) { - int numel = num_rows * dim; -#ifdef DEBUG_MOE_OP - std::cerr << "[DEBUG-BWD] launch kernel, num_active=" << num_active - << ", num_rows=" << num_rows << ", dim=" << dim << std::endl; -#endif - - int64_t threads = 512; - if (dim % 4 == 0) { - int64_t blocks = DivUp(DivUp(numel, 4), threads); - int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 4); - size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); - -#ifdef DEBUG_MOE_OP - std::cerr << "[DEBUG-BWD] gather_with_mask with vectorized, s_shared_num=" - << s_shared_num << ", block=" << blocks << std::endl; -#endif - if (!use_all2all_permute) { - gather_with_mask_kernel - <<>>(dy, - scatter_index, - combine_weights, - dx, - num_rows, - k, - dim, - numel, - num_active, - s_shared_num); - } else { - PD_CHECK(world_size > 0 && num_local_experts > 0 && capacity > 0); - gather_with_mask_permute_kernel - <<>>(dy, - scatter_index, - combine_weights, - dx, - num_rows, - k, - dim, - numel, - num_active, - s_shared_num, - capacity, - world_size, - num_local_experts); - } - } else { - int64_t blocks = DivUp(DivUp(numel, 1), threads); - int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 1); - size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); - -#ifdef DEBUG_MOE_OP - std::cerr - << "[DEBUG-BWD] gather_with_mask without vectorized, s_shared_num=" - << s_shared_num << ", block=" << blocks << std::endl; -#endif - - if (!use_all2all_permute) { - gather_with_mask_kernel - <<>>(dy, - scatter_index, - combine_weights, - dx, - num_rows, - k, - dim, - numel, - num_active, - s_shared_num); - } else { - gather_with_mask_permute_kernel - <<>>(dy, - scatter_index, - combine_weights, - dx, - num_rows, - k, - dim, - numel, - num_active, - s_shared_num, - capacity, - world_size, - num_local_experts); - } - } -} - -// } // namespace operators -// } // namespace paddle - -#endif diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu index ba37d065f94373..7612d36435880d 100644 --- a/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu @@ -19,7 +19,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/contiguous_kernel.h" -#include "paddle/phi/kernels/fused_moe_bwd_op.h" +#include "paddle/phi/kernels/moe_fuse_bwd_op.h" #include "paddle/phi/kernels/transpose_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h index 1b2f474198741f..65c6f95e2944bf 100644 --- a/paddle/phi/kernels/moe_fuse_bwd_op.h +++ b/paddle/phi/kernels/moe_fuse_bwd_op.h @@ -17,6 +17,8 @@ #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/moe_kernel_impl.h" +namespace phi{ + #ifdef PADDLE_WITH_CUDA template __global__ void gather_with_mask_permute_kernel( @@ -311,3 +313,5 @@ void topk_grad_with_mask_launcher(const T* dy, // [s, k] dy, topk_idx, combine_weights, dx, num_rows, k, num_experts); } #endif + +} // namepsace phi \ No newline at end of file diff --git a/paddle/phi/kernels/moe_fuse_op.h b/paddle/phi/kernels/moe_fuse_op.h index 0fddd734abb2fa..2458902745455a 100644 
--- a/paddle/phi/kernels/moe_fuse_op.h +++ b/paddle/phi/kernels/moe_fuse_op.h @@ -25,6 +25,8 @@ #include "paddle/phi/kernels/moe_kernel_impl.h" #include "paddle/phi/kernels/funcs/math_function.h" +namespace phi{ + template __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, @@ -812,3 +814,4 @@ void copy_unpermuted_to_permuted_kernelLauncher( num_cols); } } +} // namespace phi From c5cab082e135df6f3b327a661231fba3ce779dee Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Thu, 29 May 2025 11:23:25 +0000 Subject: [PATCH 57/71] remove win32 supported --- paddle/phi/kernels/CMakeLists.txt | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 40ac78b7e8c565..4b386afd7bc3f2 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -46,7 +46,23 @@ file( if(APPLE OR WIN32) list(REMOVE_ITEM kernel_cu "fusion/gpu/fusion_group_kernel.cu") - list(REMOVE_ITEM kernel_cu "sparse/gpu/conv_kernel_igemm.cu") + list(REMOVE_ITEM + kernel_cu + "sparse/gpu/conv_kernel_igemm.cu" + "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" + "gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu" + "gpu/moe_gate_dispatch_permute_grad_kernel.cu" + "gpu/moe_gate_dispatch_permute_kernel.cu" + "gpu/expand_modality_expert_id_kernel.cu" + "gpu/moe_combine_kernel.cu" + "gpu/moe_combine_grad_kernel.cu" + "gpu/cal_aux_loss_kernel.cu" + "gpu/cal_aux_loss_grad_kernel.cu" + "gpu/build_src_rank_and_local_expert_id_kernel.cu" + "gpu/moe_gate_dispatch_kernel.cu" + "gpu/moe_gate_dispatch_grad_kernel.cu" + "gpu/int_bincount.cu" + "gpu/layer_norm_cuda_kernel.cu") endif() if(NOT WITH_DGC) @@ -229,6 +245,7 @@ if(WITH_ROCM) REMOVE_ITEM kernel_gpu "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" + "gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu" "gpu/moe_gate_dispatch_permute_grad_kernel.cu" "gpu/moe_gate_dispatch_permute_kernel.cu" "gpu/expand_modality_expert_id_kernel.cu" From 6a1a3180a9e5e850b62db61823710db1abc7101e Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 30 May 2025 03:27:23 +0000 Subject: [PATCH 58/71] check OP type --- paddle/phi/kernels/CMakeLists.txt | 4 ++-- .../gpu/build_src_rank_and_local_expert_id_kernel.cu | 1 + paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu | 10 ++++++++-- paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu | 1 + 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 4b386afd7bc3f2..bc98026bb3a850 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -46,9 +46,9 @@ file( if(APPLE OR WIN32) list(REMOVE_ITEM kernel_cu "fusion/gpu/fusion_group_kernel.cu") + list(REMOVE_ITEM kernel_cu "sparse/gpu/conv_kernel_igemm.cu") list(REMOVE_ITEM - kernel_cu - "sparse/gpu/conv_kernel_igemm.cu" + kernel_gpu "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" "gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu" "gpu/moe_gate_dispatch_permute_grad_kernel.cu" diff --git a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu index 07ba0fc48d0784..901b09207ab6f8 100644 --- a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu +++ b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu @@ -95,5 +95,6 @@ PD_REGISTER_KERNEL(build_src_rank_and_local_expert_id, GPU, ALL_LAYOUT, phi::BuildSrcRankAndLocalExpertIdKernel, + int, int32_t, int64_t) {} diff 
--git a/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu index 1dbc62b1fadc3e..95767290789931 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu @@ -107,5 +107,11 @@ void CalAuxLossGradKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - cal_aux_loss_grad, GPU, ALL_LAYOUT, phi::CalAuxLossGradKernel, float) {} +PD_REGISTER_KERNEL(cal_aux_loss_grad, + GPU, + ALL_LAYOUT, + phi::CalAuxLossGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu index d5fb1a32836ed5..21cbda4fe0303c 100644 --- a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -269,5 +269,6 @@ PD_REGISTER_KERNEL(cal_aux_loss, ALL_LAYOUT, phi::CalAuxLossKernel, float, + double, phi::dtype::float16, phi::dtype::bfloat16) {} From d290b3b84adbadff85e05ce80f8389754e23d90b Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 30 May 2025 03:32:24 +0000 Subject: [PATCH 59/71] remove optest for WIN & APPLE --- test/legacy_test/CMakeLists.txt | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 3282d661b1f6e5..4f8042b86fb60a 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -477,6 +477,17 @@ if(NOT WITH_GPU OR APPLE) list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass) list(REMOVE_ITEM TEST_OPS test_sparse_conv_igemm_op) + list(REMOVE_ITEM + TEST_OPS + test_incubate_build_src_rank_and_local_expert_id + test_incubate_expand_modality_expert_id + test_incubate_fused_loss + test_incubate_fused_rmsnorm_ext + test_incubate_int_bincount + test_incubate_moe_combine + test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk + test_incubate_moe_gate_dispatch_w_permute_bwd + test_incubate_moe_gate_dispatch_w_permute) endif() if(NOT WITH_CUDNN_FRONTEND) From 31d646515c0737ec04fc9e223ffe89c36880609d Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 30 May 2025 05:00:02 +0000 Subject: [PATCH 60/71] fix bug for (int32_t and int) --- .../phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu index 901b09207ab6f8..26837ada694e2d 100644 --- a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu +++ b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu @@ -96,5 +96,4 @@ PD_REGISTER_KERNEL(build_src_rank_and_local_expert_id, ALL_LAYOUT, phi::BuildSrcRankAndLocalExpertIdKernel, int, - int32_t, int64_t) {} From 0b990d6bacfad39388d01ec63e272f2e7f176aec Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 30 May 2025 08:57:19 +0000 Subject: [PATCH 61/71] rename fused_rms_norm op --- paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu | 4 ++-- paddle/phi/ops/yaml/backward.yaml | 6 +++--- paddle/phi/ops/yaml/ops.yaml | 6 +++--- python/paddle/incubate/nn/functional/fused_rms_norm_ext.py | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu index 173d4f20b96ffe..0735485b1d435e 100644 --- a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu +++ 
b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu @@ -82,7 +82,7 @@ void RMSLnBwd(const Context &ctx, } // namespace phi PD_REGISTER_KERNEL( - fused_rms_norm, GPU, ALL_LAYOUT, phi::RMSLnFwd, float, double) {} + fused_rms_norm_ext, GPU, ALL_LAYOUT, phi::RMSLnFwd, float, double) {} PD_REGISTER_KERNEL( - fused_rms_norm_grad, GPU, ALL_LAYOUT, phi::RMSLnBwd, float, double) {} + fused_rms_norm_ext_grad, GPU, ALL_LAYOUT, phi::RMSLnBwd, float, double) {} diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 2f67cd8ebcb81d..3a237a78c29b61 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -3871,14 +3871,14 @@ func: check_model_nan_inf data_type: out_grad -- backward_op: fused_rms_norm_grad - forward: fused_rms_norm (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) +- backward_op: fused_rms_norm_ext_grad + forward: fused_rms_norm_ext (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) args: (Tensor x, Tensor scale,Tensor invvar, Tensor y_grad, float epsilon) output: Tensor(x_grad), Tensor(scale_grad) infer_meta: func: FusedRMSNormGradInferMeta kernel: - func: fused_rms_norm_grad + func: fused_rms_norm_ext_grad - backward_op: im2sequence_grad forward: im2sequence (Tensor x, Tensor y, int[] kernels, int[] strides = {1, 1}, int[] paddings diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 7905b122556c3c..ec2ec08e1d2390 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -5739,15 +5739,15 @@ interfaces : paddle::dialect::InferSymbolicShapeInterface traits : paddle::dialect::ForwardOnlyTrait -- op: fused_rms_norm +- op: fused_rms_norm_ext args: (Tensor x, Tensor scale, float epsilon) output: Tensor(y), Tensor(invvar) infer_meta: func: FusedRMSNormInferMeta kernel: - func: fused_rms_norm + func: fused_rms_norm_ext data_type: x - backward: fused_rms_norm_grad + backward: fused_rms_norm_ext_grad - op: int_bincount args: (Tensor x, int64_t low, int64_t high, int64_t dtype) diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py index 61f9cce5a440f6..dd3cb392793e46 100644 --- a/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py +++ b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py @@ -34,8 +34,8 @@ def fused_rms_norm_ext(x, scale, epsilon=1e-5, name=None): invvar (Tensor): Tensor of shape [rows], the inverse standard deviation of each row. 
""" if in_dynamic_or_pir_mode(): - return _C_ops.fused_rms_norm(x, scale, epsilon) - helper = LayerHelper('fused_rms_norm', **locals()) + return _C_ops.fused_rms_norm_ext(x, scale, epsilon) + helper = LayerHelper('fused_rms_norm_ext', **locals()) dtype = convert_dtype(x.dtype) y = helper.create_variable_for_type_inference(dtype) invvar = helper.create_variable_for_type_inference('float32') @@ -43,7 +43,7 @@ def fused_rms_norm_ext(x, scale, epsilon=1e-5, name=None): inputs = {'x': x, 'scale': scale} helper.append_op( - type='fused_rms_norm', + type='fused_rms_norm_ext', inputs=inputs, outputs={'y': y, 'invvar': invvar}, attrs={'epsilon': epsilon}, From 2fe0ef3f983ea72b40adc6ce882564af2c2157a4 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 30 May 2025 08:58:03 +0000 Subject: [PATCH 62/71] select op test env not for Volta --- test/legacy_test/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 4f8042b86fb60a..611e7461001c75 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -477,6 +477,13 @@ if(NOT WITH_GPU OR APPLE) list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass) list(REMOVE_ITEM TEST_OPS test_sparse_conv_igemm_op) +endif() + +# if win32 or APPLE or NOT WITH_GPU or CUDA_ARCH_NAME==Volta, skip some op test +if(NOT WITH_GPU + OR WIN32 + OR APPLE + OR (${CUDA_ARCH_NAME} STREQUAL "Volta")) list(REMOVE_ITEM TEST_OPS test_incubate_build_src_rank_and_local_expert_id From e8a30df0556dde8d37d9cd500b2e560929eda031 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Fri, 30 May 2025 09:01:36 +0000 Subject: [PATCH 63/71] fix openblas mistake --- third_party/openblas | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/openblas b/third_party/openblas index 5ef8b1964658f9..5f36f18148603f 160000 --- a/third_party/openblas +++ b/third_party/openblas @@ -1 +1 @@ -Subproject commit 5ef8b1964658f9cb6a6324a06f6a1a022609b0c5 +Subproject commit 5f36f18148603facb6c3540e673610d6b24cbfbb From 4fd618618f1839de22abd9d96b1c080b72f135f4 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Tue, 3 Jun 2025 09:32:57 +0000 Subject: [PATCH 64/71] CMake code format --- paddle/phi/kernels/CMakeLists.txt | 33 ++++++++++++++++--------------- test/legacy_test/CMakeLists.txt | 23 ++++++++++----------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index bc98026bb3a850..71ea42a412a0d0 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -47,22 +47,23 @@ file( if(APPLE OR WIN32) list(REMOVE_ITEM kernel_cu "fusion/gpu/fusion_group_kernel.cu") list(REMOVE_ITEM kernel_cu "sparse/gpu/conv_kernel_igemm.cu") - list(REMOVE_ITEM - kernel_gpu - "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" - "gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu" - "gpu/moe_gate_dispatch_permute_grad_kernel.cu" - "gpu/moe_gate_dispatch_permute_kernel.cu" - "gpu/expand_modality_expert_id_kernel.cu" - "gpu/moe_combine_kernel.cu" - "gpu/moe_combine_grad_kernel.cu" - "gpu/cal_aux_loss_kernel.cu" - "gpu/cal_aux_loss_grad_kernel.cu" - "gpu/build_src_rank_and_local_expert_id_kernel.cu" - "gpu/moe_gate_dispatch_kernel.cu" - "gpu/moe_gate_dispatch_grad_kernel.cu" - "gpu/int_bincount.cu" - "gpu/layer_norm_cuda_kernel.cu") + list( + REMOVE_ITEM + kernel_gpu + "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" + "gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu" + 
"gpu/moe_gate_dispatch_permute_grad_kernel.cu" + "gpu/moe_gate_dispatch_permute_kernel.cu" + "gpu/expand_modality_expert_id_kernel.cu" + "gpu/moe_combine_kernel.cu" + "gpu/moe_combine_grad_kernel.cu" + "gpu/cal_aux_loss_kernel.cu" + "gpu/cal_aux_loss_grad_kernel.cu" + "gpu/build_src_rank_and_local_expert_id_kernel.cu" + "gpu/moe_gate_dispatch_kernel.cu" + "gpu/moe_gate_dispatch_grad_kernel.cu" + "gpu/int_bincount.cu" + "gpu/layer_norm_cuda_kernel.cu") endif() if(NOT WITH_DGC) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 611e7461001c75..9b2f557bb10a45 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -484,17 +484,18 @@ if(NOT WITH_GPU OR WIN32 OR APPLE OR (${CUDA_ARCH_NAME} STREQUAL "Volta")) - list(REMOVE_ITEM - TEST_OPS - test_incubate_build_src_rank_and_local_expert_id - test_incubate_expand_modality_expert_id - test_incubate_fused_loss - test_incubate_fused_rmsnorm_ext - test_incubate_int_bincount - test_incubate_moe_combine - test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk - test_incubate_moe_gate_dispatch_w_permute_bwd - test_incubate_moe_gate_dispatch_w_permute) + list( + REMOVE_ITEM + TEST_OPS + test_incubate_build_src_rank_and_local_expert_id + test_incubate_expand_modality_expert_id + test_incubate_fused_loss + test_incubate_fused_rmsnorm_ext + test_incubate_int_bincount + test_incubate_moe_combine + test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk + test_incubate_moe_gate_dispatch_w_permute_bwd + test_incubate_moe_gate_dispatch_w_permute) endif() if(NOT WITH_CUDNN_FRONTEND) From ebd2244d027e9a74541f406b3bbe21c7859b9606 Mon Sep 17 00:00:00 2001 From: pesionzhao Date: Tue, 3 Jun 2025 09:35:25 +0000 Subject: [PATCH 65/71] fix bugs in CPU --- paddle/phi/kernels/gpu/int_bincount.cu | 1 - paddle/phi/kernels/int_bincount.h | 11 +---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/paddle/phi/kernels/gpu/int_bincount.cu b/paddle/phi/kernels/gpu/int_bincount.cu index 265a267929675c..733a0647bd0381 100644 --- a/paddle/phi/kernels/gpu/int_bincount.cu +++ b/paddle/phi/kernels/gpu/int_bincount.cu @@ -19,7 +19,6 @@ #include #include "cub/device/device_histogram.cuh" #include "paddle/common/flags.h" -#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/kernels/empty_kernel.h" // NOLINT diff --git a/paddle/phi/kernels/int_bincount.h b/paddle/phi/kernels/int_bincount.h index 18c44cc520505e..29dfb582a14211 100644 --- a/paddle/phi/kernels/int_bincount.h +++ b/paddle/phi/kernels/int_bincount.h @@ -13,17 +13,8 @@ // limitations under the License. 
 #pragma once
-#include
-#include
-#include "cub/device/device_histogram.cuh"
-#include "paddle/common/flags.h"
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/utils/data_type.h"
-#include "paddle/phi/kernels/empty_kernel.h" // NOLINT
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/backends/gpu/gpu_launch_config.h"
-#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/dense_tensor.h"
 
 namespace phi {

From 10f6058aa8ca41decc42b198dd427d58f44987ed Mon Sep 17 00:00:00 2001
From: pesionzhao
Date: Tue, 3 Jun 2025 09:36:18 +0000
Subject: [PATCH 66/71] CodeStyle format

---
 .../kernels/gpu/cal_aux_loss_grad_kernel.cu | 6 +++---
 paddle/phi/kernels/moe_fuse_bwd_op.h | 4 ++--
 paddle/phi/kernels/moe_fuse_op.h | 6 +++---
 .../ernie_utils/moe_all_gather_layer.py | 21 ++++++++++++-------
 test/legacy_test/ernie_utils/moe_layer.py | 11 ++++++----
 test/legacy_test/ernie_utils/top2_gate.py | 6 +++---
 6 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu
index 95767290789931..f0d9951e3654c8 100644
--- a/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu
@@ -107,10 +107,10 @@ void CalAuxLossGradKernel(const Context& dev_ctx,
 
 } // namespace phi
 
-PD_REGISTER_KERNEL(cal_aux_loss_grad,
+PD_REGISTER_KERNEL(cal_aux_loss_grad,
                    GPU,
-                   ALL_LAYOUT,
-                   phi::CalAuxLossGradKernel,
+                   ALL_LAYOUT,
+                   phi::CalAuxLossGradKernel,
                    float,
                    double,
                    phi::dtype::float16,
diff --git a/paddle/phi/kernels/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h
index 65c6f95e2944bf..419f4d3a727cbd 100644
--- a/paddle/phi/kernels/moe_fuse_bwd_op.h
+++ b/paddle/phi/kernels/moe_fuse_bwd_op.h
@@ -17,7 +17,7 @@
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/moe_kernel_impl.h"
 
-namespace phi{
+namespace phi {
 
 #ifdef PADDLE_WITH_CUDA
 template
@@ -314,4 +314,4 @@ void topk_grad_with_mask_launcher(const T* dy, // [s, k]
 }
 #endif
 
-} // namepsace phi
\ No newline at end of file
+} // namespace phi
diff --git a/paddle/phi/kernels/moe_fuse_op.h b/paddle/phi/kernels/moe_fuse_op.h
index 2458902745455a..a06c1347ec215b 100644
--- a/paddle/phi/kernels/moe_fuse_op.h
+++ b/paddle/phi/kernels/moe_fuse_op.h
@@ -22,10 +22,10 @@
 #include "paddle/common/exception.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
-#include "paddle/phi/kernels/moe_kernel_impl.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/moe_kernel_impl.h"
 
-namespace phi{
+namespace phi {
 
 template
 __launch_bounds__(TPB) __global__
@@ -814,4 +814,4 @@ void copy_unpermuted_to_permuted_kernelLauncher(
       num_cols);
   }
 }
-} // namespace phi
+} // namespace phi
diff --git a/test/legacy_test/ernie_utils/moe_all_gather_layer.py b/test/legacy_test/ernie_utils/moe_all_gather_layer.py
index 2f5e4ef911689a..53186898788748 100644
--- a/test/legacy_test/ernie_utils/moe_all_gather_layer.py
+++ b/test/legacy_test/ernie_utils/moe_all_gather_layer.py
@@ -28,13 +28,14 @@
 """
+
+from __future__ import annotations
+
 import contextlib
 import logging
-from typing import List, Optional
 
 import paddle
 from paddle import nn
-from paddle.distributed.communication.group import Group
 from paddle.incubate.nn.functional import expand_modality_expert_id
 
 from .moe_layer import MOELayer
@@ -43,6 +44,10 @@
     from src.utils.misc import global_training_logs
 except ModuleNotFoundError:
     global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from paddle.distributed.communication.group import Group
 
 
 def profile(_):
@@ -72,10 +77,10 @@ class MOEAllGatherLayer(MOELayer):
     def __init__(
         self,
         gate: nn.Layer,
-        experts: List[nn.Layer],
+        experts: list[nn.Layer],
         layer_idx,
-        shared_experts: Optional[List[nn.Layer]] = None,
-        dense_experts: Optional[List[nn.Layer]] = None, # no use
+        shared_experts: list[nn.Layer] | None = None,
+        dense_experts: list[nn.Layer] | None = None, # no use
         group: Group = None,
         recompute=False,
         enable_logging: bool = False,
@@ -112,10 +117,10 @@ class MOEAllGatherLayerV2(MOEAllGatherLayer):
     def __init__(
         self,
         gate: nn.Layer,
-        experts: List[nn.Layer],
+        experts: list[nn.Layer],
         layer_idx,
-        shared_experts: Optional[List[nn.Layer]] = None,
-        dense_experts: Optional[List[nn.Layer]] = None,
+        shared_experts: list[nn.Layer] | None = None,
+        dense_experts: list[nn.Layer] | None = None,
         group: Group = None,
         recompute=False,
         enable_logging: bool = False,
diff --git a/test/legacy_test/ernie_utils/moe_layer.py b/test/legacy_test/ernie_utils/moe_layer.py
index 25f5007e7461d0..b5fbf11791d090 100644
--- a/test/legacy_test/ernie_utils/moe_layer.py
+++ b/test/legacy_test/ernie_utils/moe_layer.py
@@ -20,16 +20,19 @@
 Returns:
     _type_: _description_
 """
+from __future__ import annotations
+
 import logging
 from collections import namedtuple
-from typing import List, Optional
+from typing import TYPE_CHECKING
 
 import paddle
 import paddle.distributed as dist
 from paddle import nn
 from paddle.distributed import fleet
-from paddle.distributed.communication.group import Group
 
+if TYPE_CHECKING:
+    from paddle.distributed.communication.group import Group
 try:
     from src.utils.misc import global_training_logs
 except ModuleNotFoundError:
@@ -103,9 +106,9 @@ class MOELayer(nn.Layer):
     def __init__(
         self,
         gate: nn.Layer,
-        experts: List[nn.Layer],
+        experts: list[nn.Layer],
         layer_idx,
-        shared_experts: Optional[List[nn.Layer]] = None,
+        shared_experts: list[nn.Layer] | None = None,
         group: Group = None,
         recompute=False,
         enable_logging: bool = False,
diff --git a/test/legacy_test/ernie_utils/top2_gate.py b/test/legacy_test/ernie_utils/top2_gate.py
index 1121e7625e049a..8ab34b5f04c19b 100644
--- a/test/legacy_test/ernie_utils/top2_gate.py
+++ b/test/legacy_test/ernie_utils/top2_gate.py
@@ -19,10 +19,10 @@
 top2gate
 """
+from __future__ import annotations
 
 import logging
 from functools import partial
-from typing import Tuple
 
 import numpy as np
 
@@ -547,7 +547,7 @@ def forward(
         token_type_ids: Tensor = None,
         transform_weight: bool = True, # [seq]
        correction_bias: Tensor = None, # [seq]
-    ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
+    ) -> tuple[Tensor, Tensor, Tensor]: # type: ignore
         """
         Args:
             input: paddle.Tensor[Seq, Dim], hidden-states of layer
@@ -898,7 +898,7 @@ def forward(
         input: Tensor,
         token_type_ids=None,
         transform_weight=True,
-    ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
+    ) -> tuple[Tensor, Tensor, Tensor]: # type: ignore
         """
         Args:
            input: paddle.Tensor, hidden-states of layer

From 0637a027db4e6c4c83e09b1c201fa7ff11d04c8e Mon Sep 17 00:00:00 2001
From: pesionzhao
Date: Tue, 3 Jun 2025 10:30:22 +0000
Subject: [PATCH 67/71] fix bugs in CPU

---
 paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu | 2 +-
 paddle/phi/kernels/{ => gpu}/layer_norm_cuda_kernel.h | 0
 2 files changed, 1 insertion(+), 1 deletion(-)
 rename paddle/phi/kernels/{ => gpu}/layer_norm_cuda_kernel.h (100%)

diff --git a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu
index 0735485b1d435e..93e831dedf6410 100644
--- a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu
+++ b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu
@@ -20,7 +20,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/layer_norm_cuda_kernel.h" // NOLINT
+#include "paddle/phi/kernels/gpu/layer_norm_cuda_kernel.h" // NOLINT
 
 namespace phi {
 // #define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor")
diff --git a/paddle/phi/kernels/layer_norm_cuda_kernel.h b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.h
similarity index 100%
rename from paddle/phi/kernels/layer_norm_cuda_kernel.h
rename to paddle/phi/kernels/gpu/layer_norm_cuda_kernel.h

From e4ecf9bdaff0e518dc0cf47363c75003febde31a Mon Sep 17 00:00:00 2001
From: pesionzhao
Date: Tue, 3 Jun 2025 13:22:39 +0000
Subject: [PATCH 68/71] fix bugs in CPU

---
 paddle/phi/kernels/moe_fuse_op.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paddle/phi/kernels/moe_fuse_op.h b/paddle/phi/kernels/moe_fuse_op.h
index a06c1347ec215b..80d51844b49efc 100644
--- a/paddle/phi/kernels/moe_fuse_op.h
+++ b/paddle/phi/kernels/moe_fuse_op.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#ifdef PADDLE_WITH_CUDA
 #include // 包含常用的 thrust 算法
 #include
 #include
@@ -815,3 +816,4 @@ void copy_unpermuted_to_permuted_kernelLauncher(
   }
 }
 } // namespace phi
+#endif

From 54dda4542cc459098d88b144765ca6076296609b Mon Sep 17 00:00:00 2001
From: pesionzhao
Date: Tue, 3 Jun 2025 13:23:53 +0000
Subject: [PATCH 69/71] skip some op when CUDA<12.0

---
 paddle/phi/kernels/CMakeLists.txt | 4 +++-
 test/legacy_test/CMakeLists.txt | 5 +++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index 71ea42a412a0d0..56bc0ab2e7d7dc 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -44,7 +44,9 @@ file(
   RELATIVE
   "${CMAKE_CURRENT_SOURCE_DIR}"
   "gpu/*.cu"
   "gpu/*.cu.cc")
-if(APPLE OR WIN32)
+if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0))
+   OR APPLE
+   OR WIN32)
   list(REMOVE_ITEM kernel_cu "fusion/gpu/fusion_group_kernel.cu")
   list(REMOVE_ITEM kernel_cu "sparse/gpu/conv_kernel_igemm.cu")
   list(
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 9b2f557bb10a45..5e8669479aa082 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -479,11 +479,12 @@ if(NOT WITH_GPU
   list(REMOVE_ITEM TEST_OPS test_sparse_conv_igemm_op)
 endif()
 
-# if win32 or APPLE or NOT WITH_GPU or CUDA_ARCH_NAME==Volta, skip some op test
+# New Op only supported by CUDA>=12.0 and Linux, CUDA_ARCH_NAME==Volta, skip some op test
 if(NOT WITH_GPU
    OR WIN32
    OR APPLE
-   OR (${CUDA_ARCH_NAME} STREQUAL "Volta"))
+   OR (${CUDA_ARCH_NAME} STREQUAL "Volta")
+   OR ((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0)))
   list(
     REMOVE_ITEM
     TEST_OPS

From 0d1b3d02072d825665b1451022add0740d35dba2 Mon Sep 17 00:00:00 2001
From: pesionzhao
Date: Tue, 3 Jun 2025 13:34:05 +0000
Subject: [PATCH 70/71] skip op when CUDA<12.0

---
 paddle/phi/kernels/CMakeLists.txt | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index 56bc0ab2e7d7dc..37bd657297ec13 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -44,11 +44,15 @@ file(
   RELATIVE
   "${CMAKE_CURRENT_SOURCE_DIR}"
   "gpu/*.cu"
   "gpu/*.cu.cc")
+if(APPLE OR WIN32)
+  list(REMOVE_ITEM kernel_cu "fusion/gpu/fusion_group_kernel.cu")
+  list(REMOVE_ITEM kernel_cu "sparse/gpu/conv_kernel_igemm.cu")
+endif()
+
+# New Op only supported by CUDA>=12.0 and Linux
 if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0))
    OR APPLE
    OR WIN32)
-  list(REMOVE_ITEM kernel_cu "fusion/gpu/fusion_group_kernel.cu")
-  list(REMOVE_ITEM kernel_cu "sparse/gpu/conv_kernel_igemm.cu")
   list(
     REMOVE_ITEM
     kernel_gpu

From 8e0817af12268b7869a638eb7ae9bedd8512ee83 Mon Sep 17 00:00:00 2001
From: pesionzhao
Date: Wed, 4 Jun 2025 01:54:15 +0000
Subject: [PATCH 71/71] fix bugs in CPU

---
 paddle/phi/kernels/moe_fuse_bwd_op.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h
index 419f4d3a727cbd..e1a008baecf225 100644
--- a/paddle/phi/kernels/moe_fuse_bwd_op.h
+++ b/paddle/phi/kernels/moe_fuse_bwd_op.h
@@ -13,13 +13,13 @@
 // limitations under the License.
 
 #pragma once
+#ifdef PADDLE_WITH_CUDA
 #include "paddle/common/exception.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/moe_kernel_impl.h"
 
 namespace phi {
 
-#ifdef PADDLE_WITH_CUDA
 template
 __global__ void gather_with_mask_permute_kernel(
     const T* dy, // [s*k, d]
@@ -312,6 +312,6 @@ void topk_grad_with_mask_launcher(const T* dy, // [s, k]
   topk_grad_with_mask<<>>(
       dy, topk_idx, combine_weights, dx, num_rows, k, num_experts);
 }
-#endif
 
 } // namespace phi
+#endif