diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a7d368ea869b22..9db6ba0857d42b 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1217,6 +1217,106 @@ void MeshgridGradInferMeta(const std::vector& inputs, } } +void MoeCombineGradInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& y, + MetaTensor* grad_x, + MetaTensor* grad_combine_weights_helper) { + auto x_dim = x.dims(); + auto combine_weights_shape = combine_weights.dims(); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 2, + errors::InvalidArgument("The input X should have 2 dimensions" + "But received X's dimension = %d", + x_dim.size())); + PADDLE_ENFORCE_EQ( + (scatter_index.dtype() == phi::DataType::INT32), + true, + errors::InvalidArgument("The input scatter_index type should be int32" + "But received scatter_index type = %s", + scatter_index.dtype())); + grad_x->set_dims(common::make_ddim({x_dim[0], x_dim[1]})); + grad_x->set_dtype(x.dtype()); + grad_combine_weights_helper->set_dims(common::make_ddim( + {combine_weights_shape[0], combine_weights_shape[1], x_dim[1]})); + grad_combine_weights_helper->set_dtype(x.dtype()); +} + +void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta( + const MetaTensor& combine_weights_out, + const MetaTensor& scatter_index, + const MetaTensor& scatter_index_rev, + const MetaTensor& expert_offset, + const MetaTensor& expert_offset_local, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + MetaTensor* x_grad, + MetaTensor* combine_weights_grad) { + int64_t num_experts = expert_offset.dims()[0]; + int64_t hidden_size = y_grad.dims()[1]; + int64_t num_rows = scatter_index.dims()[1]; + PADDLE_ENFORCE_GT(num_experts, + 0, + common::errors::InvalidArgument( + "Input num_experts should be greater than 0")); + PADDLE_ENFORCE_EQ((expert_offset.dtype() == phi::DataType::INT64), + true, + common::errors::InvalidArgument( + "Input expert_offset type should be int64")); + if (use_pad) { + PADDLE_ENFORCE_GE(num_experts, + y_grad.dims()[0] / capacity, + common::errors::InvalidArgument( + "Number of experts should be greater than or equal " + "to y_grad.dims()[0]/capacity")); + } else { + PADDLE_ENFORCE_GT(y_grad.dims()[0], + 0, + common::errors::InvalidArgument( + "Input y_grad.dims()[0] should be greater than 0")); + } + combine_weights_grad->set_dims(combine_weights_out_grad.dims()); + combine_weights_grad->set_dtype(phi::DataType::FLOAT32); + x_grad->set_dims({num_rows, hidden_size}); + x_grad->set_dtype(y_grad.dtype()); +} + +void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& expert_id, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + MetaTensor* x_grad, + MetaTensor* gate_logits_grad) { + auto y_grad_dims = y_grad.dims(); + PADDLE_ENFORCE_EQ( + y_grad_dims[1], + world_size, + common::errors::InvalidArgument( + "The second dimension of y_grad should be equal to world_size, but " + "received y_grad_dims[1] = %d, world_size = %d", + y_grad_dims[1], + world_size)); + int64_t num_local_experts = y_grad_dims[0]; + int64_t num_experts = world_size * num_local_experts; + int64_t hidden_size = y_grad_dims[y_grad_dims.size() - 1]; + int64_t num_rows = scatter_index.dims()[1]; + 
x_grad->set_dims({num_rows, hidden_size}); + x_grad->set_dtype(y_grad.dtype()); + gate_logits_grad->set_dims({num_rows, num_experts}); + gate_logits_grad->set_dtype(phi::DataType::FLOAT32); +} + void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad) { @@ -1887,4 +1987,89 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, value_grad->share_lod(values); } } + +void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, + const MetaTensor& seqlen_float, + const MetaTensor& ce, + const MetaTensor& l_aux_loss_grad, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + MetaTensor* gate_prob_grad) { + auto gate_prob_dims = gate_prob.dims(); + + PADDLE_ENFORCE_EQ( + gate_prob.dtype(), + l_aux_loss_grad.dtype(), + errors::InvalidArgument( + "The input out_grad type should be equal to gate_prob type")); + + gate_prob_grad->set_dims({gate_prob_dims}); + gate_prob_grad->set_dtype(gate_prob.dtype()); +} + +void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& expert_id, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_grad, + const int64_t k, + const int64_t capacity, + const bool use_pad, + MetaTensor* x_grad, + MetaTensor* gate_logits_grad) { + auto combine_weights_dims = combine_weights.dims(); + auto scatter_index_dims = scatter_index.dims(); + auto expert_id_dims = expert_id.dims(); + auto y_grad_dims = y_grad.dims(); + auto combine_weights_grad_dims = combine_weights_grad.dims(); + + PADDLE_ENFORCE_EQ(combine_weights_dims.size(), + 2, + errors::InvalidArgument( + "Input combine_weights should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + scatter_index_dims.size(), + 2, + errors::InvalidArgument("Input scatter_index should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + expert_id_dims.size(), + 2, + errors::InvalidArgument("Input expert_id should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + y_grad_dims.size(), + 2, + errors::InvalidArgument("Input y_grad should have 2 dimensions")); + + PADDLE_ENFORCE_EQ(combine_weights_grad_dims.size(), + 2, + errors::InvalidArgument( + "Input combine_weights_grad should have 2 dimensions")); + + int64_t num_experts = y_grad_dims[0] / capacity; + int64_t hidden_size = y_grad_dims[1]; + + int64_t num_rows = scatter_index_dims[1]; + + gate_logits_grad->set_dims(common::make_ddim({num_rows, num_experts})); + gate_logits_grad->set_dtype(phi::DataType::FLOAT32); + + x_grad->set_dims(common::make_ddim({num_rows, hidden_size})); + x_grad->set_dtype(y_grad.dtype()); +} +void FusedRMSNormGradInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& invvar, + const MetaTensor& dy, + float epsilon, + MetaTensor* x_grad, + MetaTensor* scale_grad) { + x_grad->set_dims(x.dims()); + x_grad->set_dtype(x.dtype()); + scale_grad->set_dims(scale.dims()); + scale_grad->set_dtype(scale.dtype()); +} } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index bca0c6f53906f9..72c4c0e69a377c 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -462,6 +462,44 @@ void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query, MetaTensor* value_grad, MetaTensor* bias_grad); +void MoeCombineGradInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& grad_y, + MetaTensor* grad_x, + MetaTensor* grad_combine_weights_helper); +// Tensor combine_weights_out, Tensor 
scatter_index, Tensor scatter_index_rev, +// Tensor expert_offset, Tensor expert_offset_local, Tensor y_grad, Tensor +// combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t +// expert_start_index, int64_t expert_end_index) +// output : Tensor(x_grad), Tensor(combine_weights_grad) +void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta( + const MetaTensor& combine_weights_out, + const MetaTensor& scatter_index, + const MetaTensor& scatter_index_rev, + const MetaTensor& expert_offset, + const MetaTensor& expert_offset_local, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + MetaTensor* x_grad, + MetaTensor* combine_weights_grad); + +void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& expert_id, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + MetaTensor* x_grad, + MetaTensor* gate_logits_grad); + void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad); @@ -680,4 +718,31 @@ void SetValueGradInferMeta(const MetaTensor& out_grad, MetaTensor* x_grad, MetaTensor* value_grad); +void CalAuxLossGradInferMeta(const MetaTensor& gate_prob, + const MetaTensor& seqlen_float, + const MetaTensor& ce, + const MetaTensor& l_aux_loss_grad, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + MetaTensor* gate_prob_grad); + +void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + const MetaTensor& expert_id, + const MetaTensor& y_grad, + const MetaTensor& combine_weights_grad, + const int64_t k, + const int64_t capacity, + const bool use_pad, + MetaTensor* x_grad, + MetaTensor* gate_logits_grad); + +void FusedRMSNormGradInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& invvar, + const MetaTensor& dy, + float epsilon, + MetaTensor* x_grad, + MetaTensor* scale_grad); } // namespace phi diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index a23fa98a79af7f..39aeeacd82a204 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -4592,6 +4592,20 @@ void WeightDequantizeInferMeta(const MetaTensor& x, out->set_dtype(scale.dtype()); } +void FusedRMSNormInferMeta(const MetaTensor& x, + const MetaTensor& scale, + float epsilon, + MetaTensor* y, + MetaTensor* invvar) { + // Y: same shape, dtype, layout as X + y->set_dims(x.dims()); + y->set_dtype(x.dtype()); + // mean & invvar: 1-D length = x.dims()[0] + int64_t rows = x.dims()[0]; + invvar->set_dims(DDim({rows})); + invvar->set_dtype(DataType::FLOAT32); +} + } // namespace phi PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 0b4d20862f4773..81041d00c73903 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -790,5 +790,10 @@ void WeightDequantizeInferMeta(const MetaTensor& x, const std::string& algo, const int32_t group_size, MetaTensor* out); +void FusedRMSNormInferMeta(const MetaTensor& x, + const MetaTensor& scale, + float epsilon, + MetaTensor* y, + MetaTensor* invvar); } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 1ef5cd2679006a..b7576d437edfa4 100644 --- 
a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -6273,5 +6273,201 @@ void TopPSamplingInferMeta(const MetaTensor& x, } } +void CalAuxLossInferMeta(const MetaTensor& gate_prob, + const MetaTensor& dispatch_mask, + const MetaTensor& tokens_mask, + const MetaTensor& dispatch_tokens_mask, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + const float clip_min, + MetaTensor* l_aux_loss, + MetaTensor* seqlen_floats, + MetaTensor* ce) { + auto gate_prob_dims = gate_prob.dims(); + auto dispatch_mask_dims = dispatch_mask.dims(); + + PADDLE_ENFORCE_EQ( + gate_prob_dims.size(), + 2, + errors::InvalidArgument("Input gate_prob_dims should have 2 dimensions")); + + PADDLE_ENFORCE_EQ(gate_prob_dims[0] >= gate_prob_dims[1], + true, + errors::InvalidArgument( + "The value of gate_prob_dims[0] should be greater than " + "or equal to that of gate_prob_dims[1].")); + + PADDLE_ENFORCE_EQ( + gate_prob_dims[1] <= 1024, + true, + errors::InvalidArgument( + "The value of gate_prob_dims[1] should be less than 1024.")); + + PADDLE_ENFORCE_EQ( + (dispatch_mask_dims.size() == 1) || (dispatch_mask_dims.size() == 2), + true, + errors::InvalidArgument( + "Input dispatch_mask_dims should have 1 or 2 dimensions")); + + if (dispatch_mask_dims.size() == 1) { + PADDLE_ENFORCE_EQ( + dispatch_mask_dims[0], + gate_prob_dims[1], + errors::InvalidArgument("The value of dispatch_mask_shape.back() " + "should be equal to gate_prob_shape.back().")); + } else { + PADDLE_ENFORCE_EQ( + dispatch_mask_dims[1], + gate_prob_dims[1], + errors::InvalidArgument("The value of dispatch_mask_shape.back() " + "should be equal to gate_prob_shape.back().")); + } + + PADDLE_ENFORCE_EQ( + dispatch_mask.dtype(), + phi::DataType::INT64, + errors::InvalidArgument("The input dispatch_mask type should be INT64")); + + if (tokens_mask) { + auto tokens_mask_dims = tokens_mask.dims(); + PADDLE_ENFORCE_EQ( + tokens_mask_dims.size(), + 1, + errors::InvalidArgument("Input tokens_mask should have 1 dimensions")); + + PADDLE_ENFORCE_EQ( + tokens_mask.dtype(), + gate_prob.dtype(), + errors::InvalidArgument( + "The input tokens_mask type should be equal to gate_prob type")); + + PADDLE_ENFORCE_EQ( + tokens_mask_dims[0], + gate_prob_dims[0], + errors::InvalidArgument( + "The 0-th dimension of tokens_mask [%d] " + "must match that of the 0-th dimension of gate_prob [%d].", + tokens_mask_dims[0], + gate_prob_dims[0])); + } + + if (dispatch_tokens_mask) { + auto dispatch_tokens_mask_dims = dispatch_tokens_mask.dims(); + + PADDLE_ENFORCE_EQ( + dispatch_tokens_mask_dims.size(), + 1, + errors::InvalidArgument( + "Input dispatch_tokens_mask should have 1 dimensions")); + + PADDLE_ENFORCE_EQ( + dispatch_tokens_mask.dtype(), + phi::DataType::BOOL, + errors::InvalidArgument( + "The input dispatch_tokens_mask type should be BOOL")); + } + + l_aux_loss->set_dims(phi::make_ddim({})); + l_aux_loss->set_dtype(gate_prob.dtype()); + + seqlen_floats->set_dims(phi::make_ddim({})); + seqlen_floats->set_dtype(gate_prob.dtype()); + + ce->set_dims({gate_prob_dims[1]}); + ce->set_dtype(gate_prob.dtype()); +} + +void MoeGateDispatchInferMeta(const MetaTensor& x, + const MetaTensor& gate_logits, + const MetaTensor& corr_bias, + const int64_t k, + const int64_t capacity, + const bool use_pad, + MetaTensor* y, + MetaTensor* combine_weights, + MetaTensor* scatter_index, + MetaTensor* expert_offset, + MetaTensor* expert_id) { + auto x_dims = x.dims(); + auto gate_logits_dims = gate_logits.dims(); + + const int64_t num_rows = x_dims[0]; + 
const int64_t num_experts = gate_logits_dims[1]; + + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + errors::InvalidArgument("Input x should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + gate_logits_dims.size(), + 2, + errors::InvalidArgument("Input gate_logits should have 2 dimensions")); + + PADDLE_ENFORCE_EQ( + x_dims[0], + gate_logits_dims[0], + errors::InvalidArgument( + "The 0-th dimension of x [%d] " + "must match that of the 0-th dimension gate_logits [%d].", + x_dims[0], + gate_logits_dims[0])); + + PADDLE_ENFORCE_EQ(gate_logits_dims[1] >= k, + true, + errors::InvalidArgument( + "The 1-th dimension of gate_logits [%d] " + "must be greater than or equal to that of k [%d].", + gate_logits_dims[1], + k)); + + if (corr_bias) { + auto corr_bias_dims = corr_bias.dims(); + PADDLE_ENFORCE_EQ( + corr_bias.dtype(), + phi::DataType::FLOAT32, + errors::InvalidArgument( + "The dtype of rotary_tensor must be float32, but got %d", + corr_bias.dtype())); + + PADDLE_ENFORCE_EQ( + corr_bias_dims.size(), + 1, + errors::InvalidArgument("Input corr_bias should have 1 dimensions")); + + PADDLE_ENFORCE_EQ( + corr_bias_dims[0], + gate_logits_dims[1], + errors::InvalidArgument( + "The 0-th dimension of x [%d] " + "must match that of the 0-th dimension gate_logits [%d].", + corr_bias_dims[0], + gate_logits_dims[1])); + } + + std::vector y_dims; + if (use_pad) { + y_dims = {num_experts * capacity, x_dims[1]}; + } else { + y_dims = {num_rows * k, x_dims[1]}; + } + + y->set_dims(common::make_ddim(y_dims)); + y->set_dtype(x.dtype()); + + combine_weights->set_dims(common::make_ddim({num_rows, k})); + combine_weights->set_dtype(phi::DataType::FLOAT32); + + scatter_index->set_dims(common::make_ddim({k, num_rows})); + scatter_index->set_dtype(phi::DataType::INT32); + + expert_offset->set_dims(common::make_ddim({num_experts})); + expert_offset->set_dtype(phi::DataType::INT64); + + expert_id->set_dims(common::make_ddim({num_rows, k})); + expert_id->set_dtype(phi::DataType::INT32); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index dfe1af6754aa9d..9e364f96612e3b 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -1284,4 +1284,28 @@ void TopPSamplingInferMeta(const MetaTensor& x, MetaTensor* topk_scores, MetaTensor* topk_ids); +void CalAuxLossInferMeta(const MetaTensor& gate_prob, + const MetaTensor& dispatch_mask, + const MetaTensor& tokens_mask, + const MetaTensor& dispatch_tokens_mask, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + const float clip_min, + MetaTensor* l_aux_loss, + MetaTensor* seqlen_floats, + MetaTensor* ce); + +void MoeGateDispatchInferMeta(const MetaTensor& x, + const MetaTensor& gate_logits, + const MetaTensor& corr_bias, + const int64_t k, + const int64_t capacity, + const bool use_pad, + MetaTensor* y, + MetaTensor* combine_weights, + MetaTensor* scatter_index, + MetaTensor* expert_offset, + MetaTensor* expert_id); + } // namespace phi diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index a66797a4d22437..df96838f1132bd 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1612,6 +1612,243 @@ void MultiClassNMSInferMeta(const MetaTensor& bboxes, nms_rois_num->set_dtype(DataType::INT32); } +void MoeCombineInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + MetaTensor* y) { + auto x_dim = 
x.dims(); + auto combine_weights_shape = combine_weights.dims(); + PADDLE_ENFORCE_EQ(x_dim.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(x) must be 1, but " + "received dimensions of" + "Input(x) is [%d]", + x_dim.size())); + // maybe there is more conditions here.... + y->set_dims(phi::make_ddim({combine_weights_shape[0], x_dim[1]})); + y->set_dtype(x.dtype()); +} + +void MoeGateDispatchPartialNoSoftmaxTopKInferMeta( + const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + MetaTensor* y, + MetaTensor* combine_weights_out, + MetaTensor* scatter_index, + MetaTensor* scatter_index_rev, + MetaTensor* expert_offset, + MetaTensor* expert_nums_local) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ(x_dims.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(x) must be 2, but " + "received dimensions of" + "Input(x) is [%d]", + x_dims.size())); + auto combine_weights_dims = combine_weights.dims(); + PADDLE_ENFORCE_EQ( + combine_weights_dims.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(combine_weights) must be 2, but " + "received dimensions of" + "Input(combine_weights) is [%d]", + combine_weights_dims.size())); + PADDLE_ENFORCE_EQ(combine_weights_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The first dimensions of Input(combine_weights) must " + "be equal to the first " + "dimension of Input(x), but received " + "Input(combine_weights) shape is [%d]," + "Input(x) shape is [%d]", + combine_weights_dims[0], + x_dims[0])); + PADDLE_ENFORCE_GT(expert_end_index, + expert_start_index, + common::errors::InvalidArgument( + "expert_end_index must be greater than " + "expert_start_index, but received " + "expert_end_index = %d, expert_start_index = %d", + expert_end_index, + expert_start_index)); + PADDLE_ENFORCE_EQ( + combine_weights.dtype(), + phi::DataType::FLOAT32, + common::errors::InvalidArgument("The dtype of Input(combine_weights) " + "must be FLOAT32, but received %s", + combine_weights.dtype())); + PADDLE_ENFORCE_EQ( + expert_id.dtype(), + phi::DataType::INT32, + common::errors::InvalidArgument( + "The dtype of Input(expert_id) must be INT32, but received %s", + expert_id.dtype())); + PADDLE_ENFORCE_GT(k, + 0, + common::errors::InvalidArgument( + "k must be greater than 0, but received k = %d", k)); + PADDLE_ENFORCE_GT( + x_dims[0], + 0, + common::errors::InvalidArgument( + "num_rows must be greater than 0, but received num_rows = %d", + x_dims[0])); + PADDLE_ENFORCE_GE(num_experts, + k, + common::errors::InvalidArgument( + "num_experts must be greater than or equal to k, but " + "received num_experts = %d, k = %d", + num_experts, + k)); + PADDLE_ENFORCE_EQ( + !reverse_token_drop || !use_pad, + true, + common::errors::InvalidArgument( + "use_pad must be false when reverse_token_drop is true, but received " + "use_pad = %d, reverse_token_drop = %d", + use_pad, + reverse_token_drop)); + PADDLE_ENFORCE_EQ( + combine_weights.dtype(), + phi::DataType::FLOAT32, + common::errors::InvalidArgument("The dtype of Input(combine_weights) " + "must be FLOAT32, but received %s", + combine_weights.dtype())); + // int64_t num_experts_diff = expert_end_index - expert_start_index; + int64_t num_rows = x_dims[0]; + // if (use_pad) + // y->set_dims({num_experts_diff * capacity, x_dims[1]}) ; + y->set_dims({-1, x_dims[1]}); + 
y->set_dtype(x.dtype()); + scatter_index->set_dims({k, num_rows}); + scatter_index->set_dtype(phi::DataType::INT32); + scatter_index_rev->set_dims({num_experts * capacity}); + scatter_index_rev->set_dtype(phi::DataType::INT32); + expert_offset->set_dims({num_experts}); + expert_offset->set_dtype(phi::DataType::INT64); + expert_nums_local->set_dims({num_experts}); + expert_nums_local->set_dtype(phi::DataType::INT64); + combine_weights_out->set_dims(combine_weights_dims); + combine_weights_out->set_dtype(combine_weights.dtype()); + // combine_weights_out->share_meta(combine_weights); +} + +void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, + const MetaTensor& gate_logits, + const MetaTensor& corr_bias, + int64_t k, + int64_t capacity, + int64_t world_size, + MetaTensor* y, + MetaTensor* combine_weights, + MetaTensor* scatter_index, + MetaTensor* expert_offset, + MetaTensor* expert_id) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ(x_dims.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(x) must be 2, but " + "received dimensions of" + "Input(x) is [%d]", + x_dims.size())); + auto gate_logits_dims = gate_logits.dims(); + PADDLE_ENFORCE_EQ(gate_logits_dims.size(), + 2, + common::errors::InvalidArgument( + "The dimensions of Input(gate_logits) must be 2, but " + "received dimensions of" + "Input(gate_logits) is [%d]", + gate_logits_dims.size())); + PADDLE_ENFORCE_EQ(gate_logits_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The first dimensions of Input(gate_logits) must be " + "equal to the first " + "dimension of Input(x), but received " + "Input(gate_logits) shape is [%d]," + "Input(x) shape is [%d]", + gate_logits_dims[0], + x_dims[0])); + PADDLE_ENFORCE_EQ( + gate_logits_dims[1] % world_size, + 0, + common::errors::InvalidArgument( + "The number of experts (the second dimension of Input(gate_logits)) " + "must be divisible by world_size, but received " + "num_experts = %d, world_size = %d", + gate_logits_dims[1], + world_size)); + + PADDLE_ENFORCE_GE(gate_logits_dims[1], + k, + common::errors::InvalidArgument( + "The number of experts ((the second dimension of " + "Input(gate_logits))) must be greater than or equal to " + "k, but received " + "num_experts = %d, k = %d", + gate_logits_dims[1], + k)); + + PADDLE_ENFORCE_EQ( + gate_logits.dtype(), + phi::DataType::FLOAT32, + common::errors::InvalidArgument( + "The dtype of Input(gate_logits) must be FLOAT32, but received %s", + gate_logits.dtype())); + + if (corr_bias) { + auto corr_bias_dims = corr_bias.dims(); + PADDLE_ENFORCE_EQ( + corr_bias_dims.size(), + 1, + common::errors::InvalidArgument( + "The dimensions of Input(corr_bias) must be 1, but received " + "dimensions of Input(corr_bias) is [%d]", + corr_bias_dims.size())); + PADDLE_ENFORCE_EQ( + corr_bias_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The dimensions of Input(corr_bias) must be equal to the first " + "dimension of Input(x), but received Input(corr_bias) first " + "dimension is [%d]," + "Input(x) first dimension is [%d]", + corr_bias_dims[0], + x_dims[0])); + PADDLE_ENFORCE_EQ( + corr_bias.dtype(), + paddle::DataType::FLOAT32, + common::errors::InvalidArgument( + "The dtype of Input(corr_bias) must be FLOAT32, but received %s", + corr_bias.dtype())); + } + int64_t num_experts = gate_logits_dims[1]; + int64_t num_local_experts = num_experts / world_size; + int64_t num_rows = x_dims[0]; + y->set_dims({num_local_experts, world_size, capacity, x_dims[1]}); + y->set_dtype(x.dtype()); + 
combine_weights->set_dims({num_rows, k}); + combine_weights->set_dtype(phi::DataType::FLOAT32); + scatter_index->set_dims({k, num_rows}); + scatter_index->set_dtype(phi::DataType::INT32); + expert_offset->set_dims({num_experts}); + expert_offset->set_dtype(phi::DataType::INT64); + expert_id->set_dims({num_rows, k}); + expert_id->set_dtype(phi::DataType::INT32); +} + void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x, const MetaTensor& in_accum, const MetaTensor& in_state, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 14dd2685949573..c462d51ecd0d9a 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -269,6 +269,41 @@ void MatrixRankAtolRtolInferMeta(const MetaTensor& x, bool hermitian, MetaTensor* out); +void MoeCombineInferMeta(const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& scatter_index, + MetaTensor* y); + +void MoeGateDispatchPartialNoSoftmaxTopKInferMeta( + const MetaTensor& x, + const MetaTensor& combine_weights, + const MetaTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + MetaTensor* y, + MetaTensor* combine_weights_out, + MetaTensor* scatter_index, + MetaTensor* scatter_index_rev, + MetaTensor* expert_offset, + MetaTensor* expert_nums_local); + +void MoeGateDispatchPermuteInferMeta(const MetaTensor& x, + const MetaTensor& gate_logits, + const MetaTensor& corr_bias, + int64_t k, + int64_t capacity, + int64_t world_size, + MetaTensor* y, + MetaTensor* combine_weights, + MetaTensor* scatter_index, + MetaTensor* expert_offset, + MetaTensor* expert_id); + void MovingAverageAbsMaxScaleInferMeta(const MetaTensor& x, const MetaTensor& in_accum, const MetaTensor& in_state, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index abf1823d67c86e..8f5f177f1c3b97 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1365,6 +1365,36 @@ void ExpandInferMeta(const MetaTensor& x, #undef EXPAND_MAX_RANK_SUPPORTED } +void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + MetaTensor* expert_id_out) { + auto expert_id_dims = expert_id.dims(); + PADDLE_ENFORCE_EQ( + expert_id_dims.size(), + 2, + common::errors::InvalidArgument( + "The input expert_id's dimensions size should be 2. But received " + "expert_id's dimensions size=[%d], expert_id's dimensions=[%s].", + expert_id_dims.size(), + expert_id_dims)); + PADDLE_ENFORCE_EQ( + expert_id.dtype() == DataType::INT32 || + expert_id.dtype() == DataType::INT64, + true, + common::errors::InvalidArgument( + "The dtype of expert_id should be INT32 or INT64. 
But received" + "dtype=%s.", + DataTypeToString(expert_id.dtype()))); + + int64_t seqlen = expert_id_dims[0]; + int64_t k = expert_id_dims[1]; + expert_id_out->set_dims(common::make_ddim({seqlen, k})); + expert_id_out->set_dtype(expert_id.dtype()); +} + void FakeChannelWiseQuantizeAbsMaxInferMeta(const MetaTensor& x, int bit_length, int round_type, @@ -6164,6 +6194,49 @@ void ArrayPopInferMeta(const MetaTensor& array, out->set_dtype(array.dtype()); } +void BuildSrcRankAndLocalExpertIdInferMeta( + const MetaTensor& expert_num_global_tensor, + const std::vector& expert_num_global, + int64_t num_local_experts, + MetaTensor* src_rank, + MetaTensor* local_expert_id) { + int64_t token_num = + std::accumulate(expert_num_global.begin(), expert_num_global.end(), 0); + + PADDLE_ENFORCE_EQ( + expert_num_global_tensor.dtype(), + phi::DataType::INT64, + errors::InvalidArgument( + "The input expert_num_global_tensor type should be INT64")); + + src_rank->set_dims({token_num}); + src_rank->set_dtype(DataType::INT32); + + local_expert_id->set_dims({token_num}); + local_expert_id->set_dtype(DataType::INT32); +} + +void IntBincountInferMeta(const MetaTensor& x, + int64_t low, + int64_t high, + int64_t dtype, + MetaTensor* out) { + PADDLE_ENFORCE_EQ( + x.dims().size(), + 1, + errors::InvalidArgument( + "The input 'x' of int_bincount must be a 1-D Tensor, but got %u-D.", + x.dims().size())); + PADDLE_ENFORCE_GT( + high, + low, + errors::InvalidArgument("Attr high (%d) must be > low (%d).", high, low)); + int64_t bin_count = high - low + 1; + + out->set_dims(phi::make_ddim({bin_count})); + out->set_dtype(x.dtype()); +} + } // namespace phi PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 6e9454e9fdac9d..e6c16debb0a7ee 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -248,6 +248,13 @@ void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, MetaTensor* out); +void ExpandModalityExpertIdInferMeta(const MetaTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + MetaTensor* expert_id_out); + void FakeChannelWiseQuantizeAbsMaxInferMeta(const MetaTensor& x, int bit_length, int round_type, @@ -998,4 +1005,17 @@ void ArrayPopInferMeta(const MetaTensor& array, MetaTensor* out, MetaConfig config = MetaConfig()); +void BuildSrcRankAndLocalExpertIdInferMeta( + const MetaTensor& expert_num_global_tensor, + const std::vector& expert_num_global, + int64_t num_local_experts, + MetaTensor* src_rank, + MetaTensor* local_expert_id); + +void IntBincountInferMeta(const MetaTensor& x, + int64_t low, + int64_t high, + int64_t dtype, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 2f45770291bd58..37bd657297ec13 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -49,6 +49,29 @@ if(APPLE OR WIN32) list(REMOVE_ITEM kernel_cu "sparse/gpu/conv_kernel_igemm.cu") endif() +# New Op only supported by CUDA>=12.0 and Linux +if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0)) + OR APPLE + OR WIN32) + list( + REMOVE_ITEM + kernel_gpu + "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" + "gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu" + "gpu/moe_gate_dispatch_permute_grad_kernel.cu" + "gpu/moe_gate_dispatch_permute_kernel.cu" + "gpu/expand_modality_expert_id_kernel.cu" + "gpu/moe_combine_kernel.cu" + 
"gpu/moe_combine_grad_kernel.cu" + "gpu/cal_aux_loss_kernel.cu" + "gpu/cal_aux_loss_grad_kernel.cu" + "gpu/build_src_rank_and_local_expert_id_kernel.cu" + "gpu/moe_gate_dispatch_kernel.cu" + "gpu/moe_gate_dispatch_grad_kernel.cu" + "gpu/int_bincount.cu" + "gpu/layer_norm_cuda_kernel.cu") +endif() + if(NOT WITH_DGC) list(REMOVE_ITEM kernel_gpu "gpu/dgc_kernel.cu") endif() @@ -228,7 +251,19 @@ if(WITH_ROCM) list( REMOVE_ITEM kernel_gpu + "gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu" + "gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu" + "gpu/moe_gate_dispatch_permute_grad_kernel.cu" + "gpu/moe_gate_dispatch_permute_kernel.cu" + "gpu/expand_modality_expert_id_kernel.cu" + "gpu/moe_combine_kernel.cu" + "gpu/moe_combine_grad_kernel.cu" "gpu/affine_grid_grad_kernel.cu" + "gpu/cal_aux_loss_kernel.cu" + "gpu/cal_aux_loss_grad_kernel.cu" + "gpu/build_src_rank_and_local_expert_id_kernel.cu" + "gpu/moe_gate_dispatch_kernel.cu" + "gpu/moe_gate_dispatch_grad_kernel.cu" "gpu/apply_per_channel_scale_kernel.cu" "gpu/calc_reduced_attn_kernel.cu" "gpu/eigvalsh_kernel.cu" @@ -236,7 +271,9 @@ if(WITH_ROCM) "gpu/matrix_rank_kernel.cu" "gpu/matrix_rank_tol_kernel.cu" "gpu/svd_kernel.cu" - "gpu/cuda_gemm_kernel.cu") + "gpu/cuda_gemm_kernel.cu" + "gpu/int_bincount.cu" + "gpu/layer_norm_cuda_kernel.cu") endif() # Remove AP kernel when CINN is not enabled. diff --git a/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h b/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h new file mode 100644 index 00000000000000..8fd7f13e6f6649 --- /dev/null +++ b/paddle/phi/kernels/build_src_rank_and_local_expert_id_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void BuildSrcRankAndLocalExpertIdKernel( + const Context& dev_ctx, + const DenseTensor& expert_num_global_tensor, + const std::vector& expert_num_global, + int64_t num_local_experts, + DenseTensor* src_rank, + DenseTensor* local_expert_id); +} // namespace phi diff --git a/paddle/phi/kernels/cal_aux_loss_grad_kernel.h b/paddle/phi/kernels/cal_aux_loss_grad_kernel.h new file mode 100644 index 00000000000000..3b1cc9dfe66e3a --- /dev/null +++ b/paddle/phi/kernels/cal_aux_loss_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CalAuxLossGradKernel(const Context& dev_ctx, + const DenseTensor& gate_prob, + const DenseTensor& seqlen_float, + const DenseTensor& ce, + const DenseTensor& l_aux_loss_grad, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + DenseTensor* gate_prob_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/cal_aux_loss_kernel.h b/paddle/phi/kernels/cal_aux_loss_kernel.h new file mode 100644 index 00000000000000..3a73a1a376cab7 --- /dev/null +++ b/paddle/phi/kernels/cal_aux_loss_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CalAuxLossKernel(const Context& dev_ctx, + const DenseTensor& gate_prob, + const DenseTensor& dispatch_mask, + const paddle::optional& tokens_mask, + const paddle::optional& dispatch_tokens_mask, + int64_t num_experts, + bool use_group, + int64_t moe_k, + float clip_min, + DenseTensor* l_aux_loss, + DenseTensor* seqlen_float, + DenseTensor* ce); + +} // namespace phi diff --git a/paddle/phi/kernels/expand_modality_expert_id_kernel.h b/paddle/phi/kernels/expand_modality_expert_id_kernel.h new file mode 100644 index 00000000000000..1d0d308d33fb3f --- /dev/null +++ b/paddle/phi/kernels/expand_modality_expert_id_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
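+// Declares expand_modality_expert_id: remaps per-token expert ids that are
+// local to one modality into the global expert-id layout used by the
+// dispatcher; see the GPU kernel for the exact index arithmetic.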
+ +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ExpandModalityExpertIDKernel(const Context& dev_ctx, + const DenseTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + DenseTensor* expert_id_out); + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index f14b2af8c72609..e5361b836e3c81 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -102,7 +102,7 @@ __device__ __forceinline__ float exp_func(float a) { template <> __device__ __forceinline__ half exp_func(half a) { -#if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +#if defined(__HIPCC__) || (__CUDA_ARCH__ > 600) return hexp(a); #else return FromFloat(expf(ToFloat(a))); @@ -144,7 +144,7 @@ struct KeyValuePair { const half2 a2 = __halves2half2(key, value); const half2 b2 = __halves2half2(a.key, a.value); #ifdef PADDLE_WITH_CUDA -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +#if (__CUDA_ARCH__ > 600) const half2 res = __hadd2(a2, b2); #else float a2_1 = __low2float(a2); diff --git a/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu new file mode 100644 index 00000000000000..26837ada694e2d --- /dev/null +++ b/paddle/phi/kernels/gpu/build_src_rank_and_local_expert_id_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
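+// For every dispatched token, derives the source rank (global_expert_id /
+// num_local_experts) and the expert id local to that rank (global_expert_id %
+// num_local_experts) from the per-expert token counts; one thread walks the
+// token range of one global expert.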
+ +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__global__ void build_srcrank_and_local_expert_id_kernel( + T* src_rank, + T* local_expert_id, + const U* expert_num, + int64_t total_num, + int64_t num_total_experts, + int64_t num_local_experts) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_total_experts) return; + int64_t start = 0; + int64_t end = 0; + for (int64_t i = 0; i < num_total_experts; ++i) { + end += expert_num[i]; + if (i == tid) { + break; + } + start += expert_num[i]; + } + for (int64_t i = start; i != end; ++i) { + src_rank[i] = static_cast(tid / num_local_experts); + local_expert_id[i] = static_cast(tid % num_local_experts); + } +} + +template +void build_srcrank_and_local_expert_id(T* src_rank, + T* local_expert_id, + const U* expert_num, + int64_t total_num, + int64_t num_total_experts, + int64_t num_local_experts, + cudaStream_t stream) { + int64_t threads_per_block = 32; + int64_t blocks = + (num_total_experts + threads_per_block - 1) / threads_per_block; + build_srcrank_and_local_expert_id_kernel + <<>>(src_rank, + local_expert_id, + expert_num, + total_num, + num_total_experts, + num_local_experts); +} + +template +void BuildSrcRankAndLocalExpertIdKernel( + const Context& dev_ctx, + const DenseTensor& expert_num_global_tensor, + const std::vector& expert_num_global, + int64_t num_local_experts, + DenseTensor* src_rank, + DenseTensor* local_expert_id) { + int64_t token_num = + std::accumulate(expert_num_global.begin(), expert_num_global.end(), 0); + + const int64_t* expert_num_global_tensor_data = + expert_num_global_tensor.data(); + + // Hard coded as ernie-core did. + int* src_rank_data = dev_ctx.template Alloc(src_rank); + int* local_expert_id_data = dev_ctx.template Alloc(local_expert_id); + + build_srcrank_and_local_expert_id(src_rank_data, + local_expert_id_data, + expert_num_global_tensor_data, + token_num, + expert_num_global.size(), + num_local_experts, + dev_ctx.stream()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(build_src_rank_and_local_expert_id, + GPU, + ALL_LAYOUT, + phi::BuildSrcRankAndLocalExpertIdKernel, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu new file mode 100644 index 00000000000000..f0d9951e3654c8 --- /dev/null +++ b/paddle/phi/kernels/gpu/cal_aux_loss_grad_kernel.cu @@ -0,0 +1,117 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
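+// Backward of cal_aux_loss. Since l_aux = num_experts * sum_e ce[e] * me[e]
+// with me[e] = sum_s gate_prob[s, e] / seqlen_float, the gradient is
+// gate_prob_grad[s, e] = l_aux_grad * num_experts * ce[e] / seqlen_float
+// (additionally divided by moe_k when use_group is set). The kernel is indexed
+// with one block per token row and one thread per expert.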
+ +#pragma once + +#include "paddle/phi/kernels/cal_aux_loss_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" + +namespace phi { + +template +__global__ void cal_aux_loss_grad_kernel(const T* out_grad, + const T* gate_prob, + const int64_t row_gate_prob, + const int64_t col_gate_prob, + const T* seqlen_float, + const T* ce, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + T* gate_prob_grad) { + T ce_val = ce[threadIdx.x]; + T l_aux_grad = *out_grad; + if (use_group) { + l_aux_grad = l_aux_grad / static_cast(moe_k); + } + l_aux_grad *= static_cast(num_experts); + + gate_prob_grad[blockIdx.x * col_gate_prob + threadIdx.x] = + (ce_val * l_aux_grad) / (*seqlen_float); +} + +template +void cal_aux_loss_grad(const T* out_grad, + const T* gate_prob, + const int64_t row_gate_prob, /*seq_len*/ + const int64_t col_gate_prob, /*expert_num*/ + const T* seqlen_float, + const T* ce, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + T* gate_prob_grad, + cudaStream_t stream) { + cal_aux_loss_grad_kernel + <<>>(out_grad, + gate_prob, + row_gate_prob, + col_gate_prob, + seqlen_float, + ce, + num_experts, + use_group, + moe_k, + gate_prob_grad); +} + +template +void CalAuxLossGradKernel(const Context& dev_ctx, + const DenseTensor& gate_prob, + const DenseTensor& seqlen_float, + const DenseTensor& ce, + const DenseTensor& l_aux_loss_grad, + const int64_t num_experts, + const bool use_group, + const int64_t moe_k, + DenseTensor* gate_prob_grad) { + auto gate_prob_dims = gate_prob.dims(); + + const T* l_aux_loss_grad_data = l_aux_loss_grad.data(); + const T* gate_prob_data = gate_prob.data(); + const T* seqlen_float_data = seqlen_float.data(); + const T* ce_data = ce.data(); + + int64_t row_gate_prob = gate_prob_dims[0]; + int64_t col_gate_prob = gate_prob_dims[1]; + + T* gate_prob_grad_data = dev_ctx.template Alloc(gate_prob_grad); + + cal_aux_loss_grad(l_aux_loss_grad_data, + gate_prob_data, + row_gate_prob, + col_gate_prob, + seqlen_float_data, + ce_data, + num_experts, + use_group, + moe_k, + gate_prob_grad_data, + dev_ctx.stream()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cal_aux_loss_grad, + GPU, + ALL_LAYOUT, + phi::CalAuxLossGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu new file mode 100644 index 00000000000000..21cbda4fe0303c --- /dev/null +++ b/paddle/phi/kernels/gpu/cal_aux_loss_kernel.cu @@ -0,0 +1,274 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
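+// Forward auxiliary (load-balancing) loss. With ce[e] = (#tokens dispatched to
+// expert e) / seqlen_float and me[e] = mean gate probability of expert e, a
+// single thread block computes l_aux = num_experts * sum_e ce[e] * me[e]
+// (divided by moe_k when use_group is set, and rescaled when token masks are
+// given); seqlen_float and ce are also written out for the backward pass.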
+
+#pragma once
+
+#include "paddle/phi/kernels/cal_aux_loss_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
+
+namespace phi {
+
+template <typename T>
+__global__ void cal_aux_loss_kernel(
+    const T* gate_prob,           /*[s, e]*/
+    const int64_t row_gate_prob,  /*seq_len*/
+    const int64_t col_gate_prob,  /*expert_num*/
+    const int64_t* dispatch_mask, /*[s, e] or [e]*/
+    const int64_t row_dispatch_mask,
+    const int64_t col_dispatch_mask,
+    const T* tokens_mask, /*[s]*/
+    const bool* dispatch_tokens_mask,
+    const int64_t dispatch_tokens_mask_len, /*global_seq_len*/
+    const int64_t num_experts,
+    const bool use_group,
+    const int64_t moe_k,
+    const float clip_min,
+    T* l_aux_loss, /*output*/
+    T* seqlen_float,
+    T* ce) {
+  extern __shared__ int64_t aux_loss_shared[];
+  static __shared__ float shared_float[1];
+
+  float scale_val = 1.f;
+
+  // Compute seqlen_float
+  float seqlen_float_f = 0.f;
+  if (dispatch_tokens_mask) {
+    float local_seqlen_float_f = 0.f;
+    int64_t num_k = (dispatch_tokens_mask_len + blockDim.x - 1) / blockDim.x;
+    for (int64_t k = 0; k < num_k; ++k) {
+      if (k * blockDim.x + threadIdx.x >= dispatch_tokens_mask_len) continue;
+      bool mask = dispatch_tokens_mask[k * blockDim.x + threadIdx.x];
+      local_seqlen_float_f += static_cast<float>(mask);
+    }
+    seqlen_float_f =
+        phi::funcs::BlockReduceSum<float>(local_seqlen_float_f, 0xFFFFFFFF);
+
+    // Compute scale_val
+    if (tokens_mask && row_gate_prob != dispatch_tokens_mask_len) {
+      float sum_tokens_mask = 0.f;
+      float local_sum_tokens_mask = 0.f;
+      int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x;
+      for (int64_t k = 0; k < num_k; ++k) {
+        if (k * blockDim.x + threadIdx.x >= row_gate_prob) continue;
+        T mask = tokens_mask[k * blockDim.x + threadIdx.x];
+        local_sum_tokens_mask += static_cast<float>(mask);
+      }
+      sum_tokens_mask =
+          phi::funcs::BlockReduceSum<float>(local_sum_tokens_mask, 0xFFFFFFFF);
+      if (threadIdx.x == 0) {
+        shared_float[0] = seqlen_float_f / max(sum_tokens_mask, clip_min);
+      }
+      __syncthreads();
+      scale_val = shared_float[0];
+    }
+
+  } else if (tokens_mask) {
+    float local_seqlen_float_f = 0.f;
+    int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x;
+    for (int64_t k = 0; k < num_k; ++k) {
+      if (k * blockDim.x + threadIdx.x >= row_gate_prob) continue;
+      T mask = tokens_mask[k * blockDim.x + threadIdx.x];
+      local_seqlen_float_f += static_cast<float>(mask);
+    }
+    seqlen_float_f =
+        phi::funcs::BlockReduceSum<float>(local_seqlen_float_f, 0xFFFFFFFF);
+  } else {
+    seqlen_float_f = static_cast<float>(row_gate_prob) /
+                     static_cast<float>(num_experts) *
+                     static_cast<float>(col_gate_prob);
+  }
+
+  if (threadIdx.x == 0) {
+    shared_float[0] = max(seqlen_float_f, clip_min);
+  }
+  __syncthreads();
+  seqlen_float_f = shared_float[0];
+
+  __syncthreads();
+  // Process dispatch_mask
+  if (col_dispatch_mask > 1) {
+    int64_t num_k = (row_dispatch_mask + blockDim.x - 1) / blockDim.x;
+
+    for (int64_t e = 0; e < col_dispatch_mask; e++) {
+      int64_t local_sum_val = 0;
+      for (int64_t k = 0; k < num_k; ++k) {
+        int64_t mask_val = 0;
+        if (k * blockDim.x + threadIdx.x < row_dispatch_mask) {
+          mask_val = static_cast<int64_t>(
+              dispatch_mask[(k * blockDim.x + threadIdx.x) * col_dispatch_mask +
+                            e]);
+        }
+        local_sum_val += mask_val;
+      }
+      int64_t sum_val =
+          phi::funcs::BlockReduceSum<int64_t>(local_sum_val, 0xFFFFFFFF);
+      if (threadIdx.x == 0) {
+        aux_loss_shared[e] = sum_val;
+      }
+    }
+  } else {
+    if (threadIdx.x < row_dispatch_mask) {
+      aux_loss_shared[threadIdx.x] =
+          static_cast<int64_t>(dispatch_mask[threadIdx.x]);
+    }
+  }
+
+  // Compute me and l_aux
+  float l_aux = 0.f;
+  int64_t num_k = (row_gate_prob + blockDim.x - 1) / blockDim.x;
+  for (int64_t e = 0; e < col_gate_prob; e++) {
+    float local_sum_val = 0.f;
+    for (int64_t k = 0; k < num_k; ++k) {
+      float gate_prob_val = 0.f;
+      if (k * blockDim.x + threadIdx.x < row_gate_prob) {
+        gate_prob_val = static_cast<float>(
+            gate_prob[(k * blockDim.x + threadIdx.x) * col_gate_prob + e]);
+      }
+      local_sum_val += gate_prob_val;
+    }
+    float sum_val =
+        phi::funcs::BlockReduceSum<float>(local_sum_val, 0xFFFFFFFF);
+    if (threadIdx.x == 0) {
+      float ce_val = static_cast<float>(aux_loss_shared[e]) / seqlen_float_f;
+      float me_val = sum_val / seqlen_float_f;
+      l_aux += ce_val * me_val * static_cast<float>(num_experts);
+      ce[e] = static_cast<T>(ce_val);
+    }
+  }
+
+  if (threadIdx.x == 0) {
+    if (use_group) {
+      l_aux /= static_cast<float>(moe_k);
+    }
+    l_aux = l_aux * scale_val;
+    *l_aux_loss = static_cast<T>(l_aux);
+    *seqlen_float = static_cast<T>(seqlen_float_f);
+  }
+}
+
+template <typename T>
+void cal_aux_loss(const T* gate_prob,
+                  const int64_t row_gate_prob, /*seq_len*/
+                  const int64_t col_gate_prob, /*expert_num*/
+                  const int64_t* dispatch_mask,
+                  const int64_t row_dispatch_mask,
+                  const int64_t col_dispatch_mask,
+                  const T* tokens_mask,
+                  const bool* dispatch_tokens_mask,
+                  const int64_t dispatch_tokens_mask_len, /*global_seq_len*/
+                  const int64_t num_experts, /*global_num_experts*/
+                  const bool use_group,
+                  const int64_t moe_k,
+                  const float clip_min,
+                  T* l_aux_loss, /*output*/
+                  T* seqlen_float,
+                  T* ce,
+                  cudaStream_t stream) {
+  int64_t threads = 1024;
+  threads = std::min(row_gate_prob, threads);
+  cal_aux_loss_kernel<T>
+      <<<1, threads, col_gate_prob * sizeof(int64_t), stream>>>(
+          gate_prob,
+          row_gate_prob,
+          col_gate_prob,
+          dispatch_mask,
+          row_dispatch_mask,
+          col_dispatch_mask,
+          tokens_mask,
+          dispatch_tokens_mask,
+          dispatch_tokens_mask_len,
+          num_experts,
+          use_group,
+          moe_k,
+          clip_min,
+          l_aux_loss,
+          seqlen_float,
+          ce);
+}
+
+template <typename T, typename Context>
+void CalAuxLossKernel(const Context& dev_ctx,
+                      const DenseTensor& gate_prob,
+                      const DenseTensor& dispatch_mask,
+                      const paddle::optional<DenseTensor>& tokens_mask,
+                      const paddle::optional<DenseTensor>& dispatch_tokens_mask,
+                      int64_t num_experts,
+                      bool use_group,
+                      int64_t moe_k,
+                      float clip_min,
+                      DenseTensor* l_aux_loss,
+                      DenseTensor* seqlen_float,
+                      DenseTensor* ce) {
+  auto gate_prob_dims = gate_prob.dims();
+  auto dispatch_mask_dims = dispatch_mask.dims();
+
+  int64_t dispatch_tokens_mask_len = 0;
+  auto dispatch_tokens_mask_ptr = dispatch_tokens_mask.get_ptr();
+  if (dispatch_tokens_mask) {
+    const auto mask_dims = dispatch_tokens_mask_ptr->dims();
+    const auto dim_size = mask_dims.size();
+    const bool is_not_zero_size = (dim_size > 0);
+    if (is_not_zero_size) {
+      dispatch_tokens_mask_len = dispatch_tokens_mask_ptr->dims()[0];
+    } else {
+      dispatch_tokens_mask_len = 0;
+    }
+  }
+
+  /*
+  T* l_aux_loss_data = dev_ctx.template Alloc<T>(l_aux_loss);
+  T* seqlen_float_data = dev_ctx.template Alloc<T>(seqlen_float);
+  T* ce_data = dev_ctx.template Alloc<T>(ce);
+  */
+  dev_ctx.template Alloc<T>(l_aux_loss);
+  dev_ctx.template Alloc<T>(seqlen_float);
+  dev_ctx.template Alloc<T>(ce);
+
+  cal_aux_loss(gate_prob.data<T>(),
+               gate_prob_dims[0],
+               gate_prob_dims[1],
+               dispatch_mask.data<int64_t>(),
+               dispatch_mask_dims[0],
+               dispatch_mask_dims.size() > 1 ? dispatch_mask_dims[1]
+                                             : static_cast<int64_t>(1),
+               tokens_mask ? tokens_mask.get_ptr()->data<T>() : nullptr,
+               dispatch_tokens_mask
+                   ?
dispatch_tokens_mask.get_ptr()->data() + : nullptr, + dispatch_tokens_mask_len, + num_experts, + use_group, + moe_k, + clip_min, + l_aux_loss->data(), + seqlen_float->data(), + ce->data(), + dev_ctx.stream()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cal_aux_loss, + GPU, + ALL_LAYOUT, + phi::CalAuxLossKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu b/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu new file mode 100644 index 00000000000000..b9f9aa7674c0bd --- /dev/null +++ b/paddle/phi/kernels/gpu/expand_modality_expert_id_kernel.cu @@ -0,0 +1,85 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/expand_modality_expert_id_kernel.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void expand_modality_expert_id(const T* expert_id, + T* expert_id_out, + int64_t seqlen, + int64_t k, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + cudaStream_t stream) { + thrust::transform( + thrust::cuda::par.on(stream), + thrust::device_pointer_cast(expert_id), + thrust::device_pointer_cast(expert_id) + seqlen * k, + thrust::counting_iterator(0), + thrust::device_pointer_cast(expert_id_out), + [k, + num_expert_per_modality, + group_size, + modality_offset, + is_group_expert] __device__(T e, T idx) { + if (is_group_expert) { + e += idx % k * group_size; + } + if (num_expert_per_modality <= 0) return static_cast(e); + T rank = e / num_expert_per_modality; + T expert_id_in_rank = e % num_expert_per_modality; + return static_cast(rank * (num_expert_per_modality * + 2) // HRAD code: only support 2 modality + + expert_id_in_rank + + modality_offset * num_expert_per_modality); + }); +} + +template +void ExpandModalityExpertIDKernel(const Context& dev_ctx, + const DenseTensor& expert_id, + int64_t num_expert_per_modality, + int64_t group_size, + int64_t modality_offset, + bool is_group_expert, + DenseTensor* expert_id_out) { + dev_ctx.template Alloc(expert_id_out); + auto expert_id_shape = expert_id.dims(); + int64_t seqlen = expert_id_shape[0]; + int64_t k = expert_id_shape[1]; + expand_modality_expert_id(expert_id.data(), + expert_id_out->data(), + seqlen, + k, + num_expert_per_modality, + group_size, + modality_offset, + is_group_expert, + dev_ctx.stream()); +} +} // namespace phi + +PD_REGISTER_KERNEL(expand_modality_expert_id, + GPU, + ALL_LAYOUT, + phi::ExpandModalityExpertIDKernel, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/int_bincount.cu b/paddle/phi/kernels/gpu/int_bincount.cu new file mode 100644 index 00000000000000..733a0647bd0381 --- /dev/null +++ b/paddle/phi/kernels/gpu/int_bincount.cu @@ -0,0 +1,122 @@ +// NOLINT +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #include "paddle/extension.h" +#include "paddle/phi/kernels/int_bincount.h" // NOLINT +#include +#include +#include "cub/device/device_histogram.cuh" +#include "paddle/common/flags.h" +#include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/empty_kernel.h" // NOLINT + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" + +COMMON_DECLARE_bool(enable_pir_api); + +namespace phi { +static phi::DataType TransToDataType(int64_t dtype) { + if (FLAGS_enable_pir_api) { + return static_cast(dtype); + } else { + return phi::TransToPhiDataType(dtype); + } +} + +std::vector> IntBincountInferShape( + std::vector x_shape, + int64_t min_value, + int64_t max_value, + int64_t out_dtype) { + return {{max_value - min_value}}; +} + +std::vector IntBincountInferDType(phi::DataType x_dtype, + int64_t min_value, + int64_t max_value, + int64_t out_dtype) { + return {TransToDataType(out_dtype)}; +} + +template +void IntBincountImpl( + const Context &ctx, const T *x, int64_t n, T min_v, T max_v, BinsT *bins) { + DenseTensor workspace; + void *workspace_ptr = nullptr; + size_t workspace_size = 0; +#pragma unroll + for (int i = 0; i < 2; ++i) { + if (workspace_size > 0) { + workspace = + phi::Empty(ctx, {static_cast(workspace_size)}); + workspace_ptr = workspace.data(); + } + auto err = cub::DeviceHistogram::HistogramEven(workspace_ptr, + workspace_size, + x, + bins, + max_v - min_v + 1, + min_v, + max_v, + n, + ctx.stream()); + PD_CHECK( + err == cudaSuccess, "HistogramEven error: %s", cudaGetErrorString(err)); + } +} + +// T is x's input type and out_dtype is in args +template +void IntBincount(const Context &ctx, + const DenseTensor &x, + int64_t low, + int64_t high, + int64_t out_dtype, + DenseTensor *out) { + PD_CHECK(low < high); + int64_t bins_width = high - low; + PD_CHECK(bins_width + 1 < std::numeric_limits::max()); + + auto bins_dtype = TransToDataType(out_dtype); + + // auto x_dytpe = x.dtype(); + auto low_v = static_cast(low); + auto high_v = static_cast(high); + PD_CHECK(static_cast(low_v) == low); + PD_CHECK(static_cast(high_v) == high); + const auto *x_data = x.data(); + int64_t n = x.numel(); + if (bins_dtype == phi::DataType::INT32) { + ctx.template Alloc(out); + uint32_t *out_ptr = static_cast(out->data()); + IntBincountImpl( + ctx, x_data, n, low_v, high_v, out_ptr); + } else if (bins_dtype == phi::DataType::INT64) { + using ULLI = unsigned long long int; // NOLINT + ctx.template Alloc(out); + static_assert(sizeof(int64_t) == sizeof(ULLI)); + // WARNING: unsafe type cast used in original impl. 
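+    // The int64_t output buffer is handed to cub::DeviceHistogram as
+    // unsigned long long int*; the static_assert above checks that both
+    // types have the same size, and the counts written are non-negative.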
+ ULLI *out_ptr = static_cast(out->data()); + IntBincountImpl(ctx, x_data, n, low_v, high_v, out_ptr); + } else { + PD_THROW("Only support INT32 and INT64, but got %s", bins_dtype); + } +} +} // namespace phi + +PD_REGISTER_KERNEL( + int_bincount, GPU, ALL_LAYOUT, phi::IntBincount, int64_t, int) {} diff --git a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu new file mode 100644 index 00000000000000..93e831dedf6410 --- /dev/null +++ b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.cu @@ -0,0 +1,88 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include "paddle/common/exception.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/empty_kernel.h" // NOLINT + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/layer_norm_cuda_kernel.h" // NOLINT + +namespace phi { +// #define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor") + +static void GetRowsCols(const std::vector &shape, + int *p_rows, + int *p_cols) { + int rows = 1; + for (int i = 0; i + 1 < shape.size(); ++i) { + rows *= shape[i]; + } + int cols = shape[shape.size() - 1]; + *p_rows = rows; + *p_cols = cols; +} + +template +void RMSLnFwd(const Context &ctx, + const DenseTensor &x, + const DenseTensor &scale, + float epsilon, + DenseTensor *y, + DenseTensor *invvar) { + const auto &scale_shape = scale.dims(); + const auto &x_shape = x.dims(); + PD_CHECK(scale_shape.size() == 1); + PD_CHECK(scale_shape[0] == x_shape[x_shape.size() - 1]); + + int rows, cols; + rows = x_shape[0]; + cols = x_shape[1]; + // GetRowsCols(x_shape, &rows, &cols); + + *y = phi::EmptyLike(ctx, x); + *invvar = phi::Empty(ctx, {rows}); + + cuda_rms_norm(ctx, x, scale, rows, cols, epsilon, y, invvar); +} + +template +void RMSLnBwd(const Context &ctx, + const DenseTensor &x, + const DenseTensor &scale, + const DenseTensor &invvar, + const DenseTensor &y_grad, + float epsilon, + DenseTensor *x_grad, + DenseTensor *scale_grad) { + int rows, cols; + const auto &x_shape = x.dims(); + rows = x_shape[0]; + cols = x_shape[1]; + ctx.template Alloc(x_grad); + ctx.template Alloc(scale_grad); + cuda_rms_norm_gradient( + ctx, x, scale, invvar, y_grad, rows, cols, epsilon, x_grad, scale_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + fused_rms_norm_ext, GPU, ALL_LAYOUT, phi::RMSLnFwd, float, double) {} + +PD_REGISTER_KERNEL( + fused_rms_norm_ext_grad, GPU, ALL_LAYOUT, phi::RMSLnBwd, float, double) {} diff --git a/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.h b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.h new file mode 100644 index 00000000000000..2e96debcacf3d1 --- /dev/null +++ b/paddle/phi/kernels/gpu/layer_norm_cuda_kernel.h @@ -0,0 +1,1083 @@ +// NOLINT +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/common/exception.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +#include // NOLINT +#include // NOLINT + +namespace phi { +#define DEFAULT_THROW(NAME, TYPE) \ + default: \ + do { \ + PD_THROW(#NAME, " not implemented for '", TYPE, "'"); \ + } while (0); \ + break + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch (TYPEIN) { \ + case float: { \ + using scalar_t_in = float; \ + switch (TYPEOUT) { \ + case float: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + DEFAULT_THROW(NAME, TYPEOUT); \ + } \ + break; \ + } \ + DEFAULT_THROW(NAME, TYPEIN); \ + } + +#define WARP_SIZE 32 + +template +__device__ __forceinline__ T WARP_SHFL_XOR(T value, + int laneMask, + int width = WARP_SIZE, + unsigned int mask = 0xffffffff) { + return __shfl_xor_sync(mask, value, laneMask, width); +} + +template +__device__ __forceinline__ T WARP_SHFL(T value, + int srcLane, + int width = WARP_SIZE, + unsigned int mask = 0xffffffff) { + return __shfl_sync(mask, value, srcLane, width); +} + +template +__device__ void cuWelfordOnlineSum(const U curr, + U& mu, // NOLINT + U& sigma2, // NOLINT + U& count) { // NOLINT + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, + const U sigma2B, + const U countB, + U& mu, // NOLINT + U& sigma2, // NOLINT + U& count) { // NOLINT + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template +__device__ void cuRMSOnlineSum(const U curr, U& sigma2) { // NOLINT + sigma2 = sigma2 + curr * curr; +} + +template +__device__ void cuChanRMSOnlineSum(const U sigma2B, U& sigma2) { // NOLINT + sigma2 = sigma2 + sigma2B; +} + +template +__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, // NOLINT + U& sigma2, // NOLINT + U* buf, + bool rms_only) { + // Assumptions: + // 1) blockDim.x == WARP_SIZE + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
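+ //
+ // The merge step (cuChanOnlineSum above) is Chan's parallel variance update:
+ // for two partials (mu_A, M2_A, n_A) and (mu_B, M2_B, n_B),
+ //   n  = n_A + n_B,   delta = mu_B - mu_A,
+ //   mu = (n_A * mu_A + n_B * mu_B) / n,
+ //   M2 = M2_A + M2_B + delta^2 * n_A * n_B / n,
+ // where sigma2 holds M2 and is divided by n2 only once at the very end. In
+ // the rms_only path the mean/count bookkeeping is skipped entirely and only
+ // the running sum of squares is reduced (cuRMSOnlineSum / cuChanRMSOnlineSum).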
+ // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U sigma2B = WARP_SHFL(sigma2, srcLaneB); + if (!rms_only) { + U muB = WARP_SHFL(mu, srcLaneB); + U countB = WARP_SHFL(count, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; // NOLINT + U* ibuf = (U*)(ubuf + blockDim.y); // NOLINT + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + if (!rms_only) { + ubuf[2 * wrt_y] = mu; + ibuf[wrt_y] = count; + } + ubuf[2 * wrt_y + 1] = sigma2; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U sigma2B = ubuf[2 * threadIdx.y + 1]; + if (!rms_only) { + U muB = ubuf[2 * threadIdx.y]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { + mu = WARP_SHFL(mu, 0); + } + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, // NOLINT + float& sigma2, // NOLINT + float* buf, + bool rms_only) { + // Assumptions: + // 1) blockDim.x == WARP_SIZE + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
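+ //
+ // This float16 specialization differs from the generic template only in how
+ // it streams the input: after a scalar fix-up for a misaligned first element,
+ // it loads __half2 pairs (8 values per thread per iteration) and widens them
+ // with __half22float2 before feeding the same Welford / RMS accumulators and
+ // the same warp/block reductions as the generic path.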
+ // + // compute variance and mean over n2 + float count = 0.0f; + mu = float(0); // NOLINT + sigma2 = float(0); // NOLINT + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const auto* lvals = vals + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { // NOLINT + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + float2 curr = __half22float2(*((__half2*)(lvals + l + k))); // NOLINT + if (!rms_only) { + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr.x, sigma2); + cuRMSOnlineSum(curr.y, sigma2); + } + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + if (!rms_only) { + cuWelfordOnlineSum(curr, mu, sigma2, count); + } else { + cuRMSOnlineSum(curr, sigma2); + } + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float sigma2B = WARP_SHFL(sigma2, srcLaneB); + if (!rms_only) { + float muB = WARP_SHFL(mu, srcLaneB); + float countB = WARP_SHFL(count, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* ubuf = (float*)buf; // NOLINT + float* ibuf = (float*)(ubuf + blockDim.y); // NOLINT + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y + 1] = sigma2; + if (!rms_only) { + ubuf[2 * wrt_y] = mu; + ibuf[wrt_y] = count; + } + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float sigma2B = ubuf[2 * threadIdx.y + 1]; + if (!rms_only) { + float muB = ubuf[2 * threadIdx.y]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + ubuf[0] = mu; + } + ubuf[1] = sigma2; + } + __syncthreads(); + if (!rms_only) { + mu = ubuf[0]; + } + sigma2 = ubuf[1] / float(n2); // NOLINT + // don't care about final value of count, we know count == n2 + } else { + if (!rms_only) { + mu = WARP_SHFL(mu, 0); + } + sigma2 = WARP_SHFL(sigma2 / float(n2), 0); // NOLINT + } + } +} + +template +__inline__ __device__ U rsqrt(U v) { + return U(1) / sqrt(v); +} +template <> +__inline__ __device__ float rsqrt(float v) { + return rsqrtf(v); +} +template <> +__inline__ __device__ double rsqrt(double v) { + return rsqrt(v); +} + +namespace { // NOLINT +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. 
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ float* getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} // namespace + +template +__device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, + bool rms_only) { + // Assumptions: + // 1) blockDim.x == WARP_SIZE + // 2) Tensors are contiguous + // + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only); + const T* lvals = vals + i1 * n2; + V* ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && (beta != NULL || rms_only)) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = + static_cast(static_cast(gamma[i]) * c_invvar * (curr - mu) + + static_cast(beta[i])); + } else { + ovals[i] = static_cast(static_cast(gamma[i]) * c_invvar * curr); + } + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + if (!rms_only) { + ovals[i] = static_cast(c_invvar * (curr - mu)); + } else { + ovals[i] = static_cast(c_invvar * curr); + } + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (!rms_only) { + mean[i1] = mu; + } + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template +__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta) { + cuApplyLayerNorm_( + output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false); +} + +template +__global__ void cuApplyRMSNorm(V* __restrict__ output_vals, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma) { + cuApplyLayerNorm_( + output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); +} + +template +__device__ void cuLoadWriteStridedInputs(const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + bool rms_only) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = + curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + 
warp_buf2[write_idx] = curr_dout * (curr_input)*curr_invvar; + } + } else { + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (!rms_only) { + warp_buf1[write_idx] = U(0); + } + warp_buf2[write_idx] = U(0); + } + } +} + +template +__device__ void cuLoadAddStridedInputs(const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + bool rms_only) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean; + if (!rms_only) { + curr_mean = mean[i1]; + } + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + if (!rms_only) { + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf2[write_idx] += curr_dout * (curr_input)*curr_invvar; + } + } + } + } +} + +template +__global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta, + bool rms_only) { + const int numsegs_n1 = + (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = + (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1; + const int row_stride = blockDim.x + 1; + const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); + const int thr_load_row_off = + (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * + // blockDim.y + (blockDim.y - + // 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; // NOLINT + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg, + thr_load_row_off, + thr_load_col_off, + i2_off, + row_stride, + warp_buf1, + warp_buf2, + input, + dout, + i1_end, + n2, + mean, + invvar, + rms_only); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; + i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs(i1_block, + thr_load_row_off, + thr_load_col_off, + i2_off, + row_stride, + warp_buf1, + warp_buf2, + input, + dout, + i1_end, + n2, + mean, + invvar, + rms_only); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k * blockDim.y; + int idx1 = row1 * row_stride + threadIdx.x; + if (!rms_only) { + acc1 += warp_buf1[idx1]; + } + acc2 += warp_buf2[idx1]; + } + + if (!rms_only) { + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + } + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + if (!rms_only) { + warp_buf1[idx1] += warp_buf1[idx2]; + } + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + if (!rms_only) { + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + } + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta, + bool rms_only) { + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + if (!rms_only) { + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y / 2; offset 
>= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + if (!rms_only) { + buf[write_idx + nbsize3] = sum_beta; + } + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + if (!rms_only) { + sum_beta += buf[read_idx + nbsize3]; + } + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + if (!rms_only) { + grad_beta[i2] = sum_beta; + } + } + } +} + +template +__global__ void cuComputeGradInput(const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input, + bool rms_only) { + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + U c_mean; + if (!rms_only) { + c_mean = mean[i1]; + } + const U c_invvar = invvar[i1]; + const T* k_input = input + i1 * n2; + const V* k_dout = dout + i1 * n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + const U gamma_tmp = static_cast(gamma[l + k]); + if (!rms_only) { + sum_loss1 += c_loss * gamma_tmp; + sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * gamma_tmp * (c_h)*c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + const U gamma_tmp = static_cast(gamma[l]); + if (!rms_only) { + sum_loss1 += c_loss * gamma_tmp; + sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * gamma_tmp * (c_h)*c_invvar; + } + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h)*c_invvar; + } + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + if (!rms_only) { + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } else { + sum_loss2 += c_loss * (c_h)*c_invvar; + } + } + } + // intra-warp reductions + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + if (!rms_only) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + } + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + if (!rms_only) { + buf[2 * wrt_i] = sum_loss1; + } + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + if (!rms_only) { + sum_loss1 += buf[2 * read_i]; + 
} + sum_loss2 += buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + if (!rms_only) { + buf[2 * threadIdx.x] = sum_loss1; + } + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + if (!rms_only) { + sum_loss1 = buf[2 * threadIdx.x]; + } + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * static_cast(gamma[l]); + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h)*c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + if (!rms_only) { + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } else { + f_grad_input -= (c_h)*c_invvar * sum_loss2; + } + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + +static cudaDeviceProp GetDevicePropImpl() { + int device = -1; + PD_CHECK(cudaGetDevice(&device) == cudaSuccess); + cudaDeviceProp prop; + PD_CHECK(cudaGetDeviceProperties(&prop, device) == cudaSuccess); + return prop; +} + +static cudaDeviceProp* GetDeviceProp() { + static auto prop = GetDevicePropImpl(); + return ∝ +} + +template +void HostApplyLayerNorm(V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta, + cudaStream_t stream) { + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyLayerNorm<<>>( + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); +} + +template +void HostApplyRMSNorm(V* output, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + cudaStream_t stream) { + // auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32, 4, 1); + // const uint64_t maxGridY = + // at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? 
threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyRMSNorm<<>>( + output, invvar, input, n1, n2, U(epsilon), gamma); +} + +// template +// void cuda_layer_norm(const Context& ctx, +// const DenseTensor& x, +// const DenseTensor& scale, +// const DenseTensor& bias, +// int rows, +// int cols, +// float epsilon, +// DenseTensor* y, +// DenseTensor* mean, +// DenseTensor* invvar) { +// DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( +// x.dtype(), +// y->dtype(), +// "cuda_layer_norm_kernel", +// HostApplyLayerNorm(y->data(), +// mean->data(), +// invvar->data(), +// const_cast(x.data()), +// rows, +// cols, +// epsilon, +// const_cast(scale.data()), +// const_cast(bias.data()), +// ctx.stream())); +// } + +template +void cuda_rms_norm(const Context& ctx, + const DenseTensor& x, + const DenseTensor& scale, + int rows, + int cols, + float epsilon, + DenseTensor* y, + DenseTensor* invvar) { + HostApplyRMSNorm(y->data(), + invvar->data(), + const_cast(x.data()), + rows, + cols, + epsilon, + const_cast(scale.data()), + ctx.stream()); +} + +template +void HostRMSNormGradient(const Context& ctx, + const V* dout, + const U* invvar, + const DenseTensor& input, + int n1, + int n2, + const V* gamma, + double epsilon, + T* grad_input, + V* grad_gamma, + cudaStream_t stream) { + if (gamma != NULL) { + const int part_size = 16; + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = + 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + auto place = input.place(); + DenseTensor part_grad_gamma = + phi::Empty(ctx, {part_size, n2}); + cuComputePartGradGammaBeta<<>>( + dout, + input.data(), + n1, + n2, + invvar, // unused + invvar, + U(epsilon), + part_grad_gamma.data(), + part_grad_gamma.data(), /* unused */ + true); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.data(), + part_grad_gamma.data(), /* unused */ + part_size, + n1, + n2, + grad_gamma, + grad_gamma, /* unused */ + true); + } + + // compute grad_input + const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<>>( + dout, + input.data(), + n1, + n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + grad_input, + true); +} + +template +void cuda_rms_norm_gradient(const Context& ctx, + const DenseTensor& x, + const DenseTensor& scale, + const DenseTensor& invvar, + const DenseTensor& dy, + int rows, + int cols, + float epsilon, + DenseTensor* grad_x, + DenseTensor* grad_scale) { + HostRMSNormGradient(ctx, + dy.data(), + invvar.data(), + x, + rows, + cols, + scale.data(), + epsilon, + grad_x->data(), + grad_scale->data(), + ctx.stream()); +} + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu new file mode 100644 index 00000000000000..d7f603746ad39d --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu @@ -0,0 +1,177 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/moe_combine_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" +namespace phi { + +template +__global__ void combine_moe_bwd_kernel(const T* x, + const T* combine_weights, + const int* scatter_index, + const T* grad_y, + T* grad_x, + T* grad_combine_weights_helper, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + const int64_t n) { + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + int64_t row_i = i / hidden_size; + int64_t slice_i = i - row_i * hidden_size; + const int* scatter_index_start = scatter_index + row_i * k; + const T grad_y_i = *(grad_y + i); + // y [ row_i, slice_i] + // combine [row_i, k, slice_i] + int64_t weight_base = row_i * k * hidden_size + slice_i; + + T* grad_cw_ptr = + grad_combine_weights_helper + weight_base; // stride hidden_size + for (int64_t ki = 0; ki < k; ki++) { + // get combine_weights i + int64_t ele_index = + static_cast(*(scatter_index_start + ki)) * hidden_size + + slice_i; + const T* w_ptr = combine_weights + row_i * k + ki; + const T* x_ptr = x + ele_index; + if ((*w_ptr) != T(0)) { + *(grad_x + ele_index) = grad_y_i * (*w_ptr); + } + *(grad_cw_ptr + ki * hidden_size) = grad_y_i * (*x_ptr); + } + } +} + +template +void combine_moe_bwd_kernelLauncher(const T* x, + const T* combine_weights, + const int* scatter_index, + const T* grad_y, + T* grad_x, + T* grad_combine_weights_helper, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + cudaStream_t stream) { + // y is [seqlen, hidden_size] + // for kk in k: + // y[i][j] += x[scatter_index[i][kk]][j] * combine_weights[i][kk] + + const int64_t n = hidden_size * seqlen; + + const int64_t threads = 1024; + const int64_t blocks = (n + threads - 1) / threads; + + combine_moe_bwd_kernel + <<>>(x, + combine_weights, + scatter_index, + grad_y, + grad_x, + grad_combine_weights_helper, + k, + seqlen, + hidden_size, + n); +} + +template +void apply_moe_combine_bwd(const T* x, + const T* combine_weights, + const int* scatter_index, + const T* grad_y, + T* grad_x, + T* grad_combine_weights_helper, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + cudaStream_t stream) { + combine_moe_bwd_kernelLauncher(x, + combine_weights, + scatter_index, + grad_y, + grad_x, + grad_combine_weights_helper, + k, + seqlen, + hidden_size, + stream); +} + +template +void moe_combine_bwd(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& grad_y, + const DenseTensor* grad_x, + const DenseTensor* grad_combine_weights_helper, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size) { + apply_moe_combine_bwd( + x.data(), + combine_weights.data(), + scatter_index.data(), + grad_y.data(), + const_cast(grad_x->data()), + 
const_cast(grad_combine_weights_helper->data()), + k, + seqlen, + hidden_size, + dev_ctx.stream()); +} +template +void MoeCombineGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& grad_y, + DenseTensor* grad_x, + DenseTensor* grad_combine_weights_helper) { + dev_ctx.template Alloc(grad_x); + dev_ctx.template Alloc(grad_combine_weights_helper); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())), + 0, + grad_combine_weights_helper); + auto x_shape = x.dims(); + auto combine_weights_shape = combine_weights.dims(); + moe_combine_bwd(dev_ctx, + x, + combine_weights, + scatter_index, + grad_y, + grad_x, + grad_combine_weights_helper, + combine_weights_shape[1], // k + combine_weights_shape[0], // seqlen + x_shape[1]); // hidden_size +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_combine_grad, + GPU, + ALL_LAYOUT, + phi::MoeCombineGradKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/moe_combine_kernel.cu b/paddle/phi/kernels/gpu/moe_combine_kernel.cu new file mode 100644 index 00000000000000..0c670f530f21c2 --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_combine_kernel.cu @@ -0,0 +1,131 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
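+//
+// moe_combine re-assembles the per-expert rows back into token order. With
+// combine_weights of shape [seqlen, k] and int32 scatter_index of shape
+// [seqlen, k], the kernel below computes, for every token i and hidden unit j:
+//   y[i][j] = sum_{kk < k} combine_weights[i][kk] * x[scatter_index[i][kk]][j]
+// The grad kernel in moe_combine_grad_kernel.cu is the mirror image: it routes
+// grad_y back to grad_x through scatter_index (skipping entries whose combine
+// weight is zero) and accumulates grad_y * x into the [seqlen, k, hidden]
+// helper, which is presumably reduced over the hidden dimension elsewhere to
+// obtain the gradient of combine_weights.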
+ +#include "paddle/phi/kernels/moe_combine_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" + +namespace phi { + +template +__global__ void combine_moe_kernel(const T* x, + const T* combine_weights, + const int* scatter_index, + T* y, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + const int64_t n) { + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + int64_t row_i = i / hidden_size; + int64_t slice_i = i - row_i * hidden_size; + const int* scatter_index_start = scatter_index + row_i * k; + T* dest_ptr = y + i; + for (int ki = 0; ki < k; ki++) { + // get combine_weights i + const T* w_ptr = combine_weights + row_i * k + ki; + const T* x_ptr = + x + static_cast(*(scatter_index_start + ki)) * hidden_size + + slice_i; + *(dest_ptr) += (*w_ptr) * (*x_ptr); + } + } +} + +template +void combine_moe_kernelLauncher(const T* x, + const T* combine_weights, + const int* scatter_index, + T* y, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + cudaStream_t stream) { + // y is [seqlen, hidden_size] + // for kk in k: + // y[i][j] += x[scatter_index[i][kk]][j] * combine_weights[i][kk] + const int64_t n = hidden_size * seqlen; + + const int64_t threads = 1024; + const int64_t blocks = (n + threads - 1) / threads; + + combine_moe_kernel<<>>( + x, combine_weights, scatter_index, y, k, seqlen, hidden_size, n); +} + +template +void apply_moe_combine_fwd(const T* x, + const T* combine_weights, + const int* scatter_index, + T* y, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size, + cudaStream_t stream) { + combine_moe_kernelLauncher( + x, combine_weights, scatter_index, y, k, seqlen, hidden_size, stream); +} + +template +void moe_combine_fwd(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& y, + const int64_t k, + const int64_t seqlen, + const int64_t hidden_size) { + apply_moe_combine_fwd(x.data(), + combine_weights.data(), + scatter_index.data(), + const_cast(y.data()), + k, + seqlen, + hidden_size, + dev_ctx.stream()); +} + +template +void MoeCombineKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + DenseTensor* y) { + dev_ctx.template Alloc(y); // T cannot support phi::dtype::float8 very + // well, maybe replaced with x.dtype(); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); + auto combine_weights_shape = combine_weights.dims(); + auto x_shape = x.dims(); + moe_combine_fwd(dev_ctx, + x, + combine_weights, + scatter_index, + *y, + combine_weights_shape[1], // k + combine_weights_shape[0], // seqlen + x_shape[1]); // hidden_size +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_combine, + GPU, + ALL_LAYOUT, + phi::MoeCombineKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu new file mode 100644 index 00000000000000..7612d36435880d --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_grad_kernel.cu @@ -0,0 +1,165 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/contiguous_kernel.h" +#include "paddle/phi/kernels/moe_fuse_bwd_op.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +void apply_moe_dispatch_bwd(const T* y_grad, + const float* combine_weights, // [s, k] + const int* scatter_index, // [s, k] + const float* combine_weights_grad, + const int* expert_id, // [s, k] + float* gate_logits_grad, + T* x_grad, + int64_t num_rows, + int64_t k, + int64_t dim, + int64_t num_experts, + int64_t capacity, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts, + cudaStream_t stream) { + gather_with_mask_launcher(y_grad, + scatter_index, + combine_weights, + x_grad, + num_rows, + k, + dim, + -1, + stream, + use_all2all_permute, + world_size, + num_local_experts, + capacity); + + topk_grad_with_mask_launcher(combine_weights_grad, + expert_id, + combine_weights, + gate_logits_grad, + num_rows, + k, + num_experts, + stream); +} + +template +void moe_dispatch_bwd(const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [s, k] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_grad, // [s, k] + const DenseTensor& x_grad, + const DenseTensor& gate_logits_grad, + int64_t capacity, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts) { + int64_t num_rows = combine_weights.dims()[0]; + int64_t k = combine_weights.dims()[1]; +#ifdef MOE_OPS_AUTO + int64_t hidden_size = y_grad.dims()[2]; +#else + int64_t hidden_size = y_grad.dims()[1]; +#endif + int64_t num_experts = gate_logits_grad.dims()[1]; + + apply_moe_dispatch_bwd(y_grad.data(), + combine_weights.data(), + scatter_index.data(), + combine_weights_grad.data(), + expert_id.data(), + const_cast(gate_logits_grad.data()), + const_cast(x_grad.data()), + num_rows, + k, + hidden_size, + num_experts, + capacity, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); +} + +template +void MoeGateDispatchGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& expert_id, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_grad, + const int64_t k, + const int64_t capacity, + const bool use_pad, + DenseTensor* x_grad, + DenseTensor* gate_logits_grad) { + auto y_grad_dims = y_grad.dims(); + auto scatter_index_dims = scatter_index.dims(); + +#ifdef MOE_OPS_AUTO + // y_grad shape is [num_experts, capacity, h] + int64_t num_experts = y_grad_dims[0]; + int64_t hidden_size = y_grad_dims[2]; +#else + int64_t num_experts = y_grad_dims[0] / capacity; + int64_t hidden_size = y_grad_dims[1]; +#endif + int64_t num_rows = scatter_index_dims[1]; + + const std::vector axis = {1, 0}; + + DenseTensor t_scatter_index; + phi::Transpose(dev_ctx, scatter_index, axis, &t_scatter_index); 
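+ // scatter_index arrives as [k, s] (see the parameter comments above), while
+ // apply_moe_dispatch_bwd indexes it as [s, k]; the transpose above plus the
+ // contiguous copy below produce the row-major [s, k] view that the gather
+ // launcher expects.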
+ DenseTensor t_scatter_index_; + phi::ContiguousKernel( + dev_ctx, t_scatter_index, &t_scatter_index_); + const DenseTensor t_scatter_index__ = t_scatter_index_; + + dev_ctx.template Alloc(x_grad); + dev_ctx.template Alloc(gate_logits_grad); + + moe_dispatch_bwd(dev_ctx, + combine_weights, + t_scatter_index__, + expert_id, + y_grad, + combine_weights_grad, + *x_grad, + *gate_logits_grad, + capacity); +} + +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_grad, + GPU, + ALL_LAYOUT, + phi::MoeGateDispatchGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu new file mode 100644 index 00000000000000..23bbca3cd8614a --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_kernel.cu @@ -0,0 +1,378 @@ +// NOLINT +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/moe_gate_dispatch_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_fuse_op.h" +namespace phi { + +// -------- getWorkspaceSize -------- // +namespace { +template +size_t getWorkspaceSize(const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int k, + // const int max_seq_len, + phi::CubKeyValueSorter &sorter // NOLINT +) { + // const int buf_size = AlignTo16(k * num_rows * hidden_size); + // const int interbuf_size = AlignTo16(k * num_rows * inter_size); + // const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + int num_softmax_outs = 0; + + // softmax output, permuted_rows and permuted_experts have moved to outside of + // moe kernel, allocate them in Encoder or Decoder before invoking FfnLayer + // forward. 
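+ // Workspace layout (matching the pointer arithmetic in apply_moe_dispatch_fwd
+ // below): four int buffers of AlignTo16(k * num_rows) elements each
+ // (source_rows_, permuted_rows_, permuted_experts_, expert_id_), followed by
+ // the cub sorting scratch returned by sorter.getWorkspaceSize().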
+ size_t total_ws_bytes = + 4 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + // total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + // total_ws_bytes += padded_experts * sizeof(int64_t); // Hold + // total_rows_before_expert_ // expert_cnt total_ws_bytes += num_softmax_outs + // * sizeof(KeyT); const int bytes_for_fc1_result = interbuf_size * + // sizeof(KeyT); + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + // sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity + // // 用所有 bit 做排序,会降低些许性能,但是防止越界 + total_ws_bytes += sorter_ws_size_bytes; // intermediate (fc1) output + cub + // sorting workspace + // std::cout<<"sorter_ws_size_bytes = "< +void apply_moe_dispatch_fwd(const Context &dev_ctx, + const T *x, + const float *gate_logits, + const float *corr_bias, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + T *y, + float *combine_weights, + int *scatter_index, + int64_t *expert_offset, + int *expert_id, + bool use_pad, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts, + cudaStream_t stream) { + phi::CubKeyValueSorter sorter(stream); + // phi::funcs::SetConstant zero; + // zero(ctx, &finished_tensor, false); + + DenseTensor xpanded_source_row_to_expanded_dest_row_tensor = + phi::Empty(dev_ctx, IntArray({num_rows, k})); + // int* expanded_source_row_to_expanded_dest_row = + // expanded_source_row_to_expanded_dest_row_tensor.data(); + + // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, + // paddle::DataType::FLOAT32, place); float* expert_scales_float = + // expert_scales_tensor_float.data(); + + // paddle::Tensor expert_for_source_row_tensor = paddle::empty({num_rows, k}, + // paddle::DataType::INT32, place); int* expert_for_source_row = + // expert_for_source_row_tensor.data(); + DenseTensor active_cnt_tensor = + phi::Empty(dev_ctx, IntArray({1})); + + int64_t bytes = getWorkspaceSize(num_rows, + hidden_size, // hidden-size=0 + 0, // inter-size=0 + num_experts, + k, + sorter); + + DenseTensor ws_ptr_tensor = + phi::Empty(dev_ctx, IntArray({bytes})); + int8_t *ws_ptr = ws_ptr_tensor.data(); + + // Pointers + int *source_rows_; + int *permuted_rows_; + int *permuted_experts_; + int *expert_id_; + + // T* permuted_data_; + float *softmax_out_; + // int64_t* total_rows_before_expert_; + T *fc1_result_; + + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + // const int buf_size = AlignTo16(k * num_rows * hidden_size); + // const int interbuf_size = AlignTo16(k * num_rows * 0); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + + source_rows_ = reinterpret_cast(ws_ptr); + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + expert_id_ = permuted_experts_ + num_moe_inputs; + + // permuted_data_ = reinterpret_cast(expert_id_ + num_moe_inputs); + // total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + + // buf_size); + + // only use one number + // num_active = reinterpret_cast(permuted_experts_ + + // num_moe_inputs); + + fc1_result_ = reinterpret_cast(expert_id_ + num_moe_inputs); + softmax_out_ = nullptr; + +#ifdef DEBUG_MOE_OP + // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits + // before_topk")); print_to_screen1(finished, 2, 16, std::string("finished + // before_topk")); +#endif + + 
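+ // The rest of this function implements the dispatch pipeline:
+ //   1) topk_gating_softmax_kernelLauncher: per-token top-k over gate_logits,
+ //      producing combine_weights, expert_id and the source row indices;
+ //   2) modify_expert_id / unmodify_expert_id (use_pad only): temporarily
+ //      re-encode expert ids so a token's 1st/2nd/... choice stays
+ //      distinguishable through the sort;
+ //   3) sorter.run: cub key-value radix sort grouping the k * num_rows entries
+ //      by expert id;
+ //   4) compute_total_rows_before_expert: per-expert prefix offsets written to
+ //      expert_offset;
+ //   5) initialize_moe_routing_*_kernelLauncher: scatter (optionally padded to
+ //      capacity, or all2all-permuted) the activations into the per-expert
+ //      layout and record scatter_index.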
topk_gating_softmax_kernelLauncher(gate_logits, + corr_bias, + combine_weights, // output + softmax_out_, // no use + expert_id, // output + source_rows_, // output + num_rows, + num_experts, + k, + stream); + +#ifdef DEBUG_MOE_OP + // phi::CastKernel(ctx, expert_scales_tensor_float, + // expert_scales_tensor.dtype(), &expert_scales_tensor); + print_to_screen1( + combine_weights, 8, 16, std::string("expert_scales_float after topk")); + print_to_screen1( + expert_id, 8, 16, std::string("expert-id before permute")); + print_to_screen1( + source_rows_, 8, 16, std::string("desc->src idx before permute")); +#endif + // modify expert-id according to k + if (use_pad) // 为了区分 k=1 选择和 k=2 选择,修改 expert-id + modify_expert_id_launcher( + expert_id, expert_id_, k, num_rows, num_experts, stream); + + // calc expert-size + /* + if (!use_pad) + cal_expert_size_and_filter_launcher(expert_id, + k * num_rows, + num_experts, + capacity, + stream); + */ +#ifdef DEBUG_MOE_OP + print_to_screen1( + expert_id, 8, 16, std::string("expert-id after modified")); +#endif + sorter.run( + fc1_result_, + sorter_ws_size_bytes, + use_pad ? expert_id_ : expert_id, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + source_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + if (use_pad) + unmodify_expert_id_launcher( + permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1( + permuted_experts_, 8, 16, std::string("expert-id after permute")); + print_to_screen1( + permuted_rows_, 8, 16, std::string("dest->src idx after permute")); +#endif + + compute_total_rows_before_expert( + permuted_experts_, k * num_rows, num_experts, expert_offset, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1(expert_offset, 8, 16, std::string("expert_offset")); + int64_t num_active_host_v2; + cudaMemcpy(&num_active_host_v2, + expert_offset + num_experts - 1, + sizeof(int64_t), + cudaMemcpyDeviceToHost); + std::cerr << "[DEBUG] num_active v2: " << num_active_host_v2 << std::endl; + print_to_screen1(permuted_experts_, + 8, + num_active_host_v2 + 2, + std::string("expert-id after permute")); + // print_to_screen1(permuted_experts_, 4096, 8192, + // std::string("expert-id after permute")); +#endif + + if (!use_all2all_permute) { + initialize_moe_routing_kernelLauncher(x, + y, + permuted_rows_, + scatter_index, + permuted_experts_, + expert_offset, + combine_weights, + static_cast(num_rows), + static_cast(hidden_size), + static_cast(k), + capacity, + use_pad, + stream); + } else { + PD_CHECK(num_experts > 0); + PD_CHECK(world_size > 0); + initialize_moe_routing_permute_kernelLauncher(x, + y, + permuted_rows_, + scatter_index, + permuted_experts_, + expert_offset, + combine_weights, + static_cast(num_rows), + static_cast(hidden_size), + static_cast(k), + capacity, + world_size, + num_local_experts, + stream); + } + + // turn expert_offset_ptr into experts_num + // auto expert_offset_ptr = thrust::device_pointer_cast(expert_offset); + // thrust::adjacent_difference( + // expert_offset_ptr, expert_offset_ptr + num_experts, expert_offset_ptr + // ); +#ifdef DEBUG_MOE_OP + print_to_screen1( + scatter_index, 8, 16, std::string("scatter_index after pad")); +#endif + // cudaMemcpy(scatter_index, permuted_rows_, sizeof(int64_t) * k * num_rows, + // cudaMemcpyDeviceToDevice); cudaMemcpy(combine_weights, expert_scales_float, + // sizeof(float) * k * num_rows, cudaMemcpyDeviceToDevice); + return; +} + 
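+// Schematic round trip of the dispatch/combine pair added by this PR
+// (host-side pseudo-code, for orientation only; not the exact C++ API):
+//
+//   moe_gate_dispatch(x, gate_logits, corr_bias, k, capacity, use_pad,
+//                     &y, &combine_weights, &scatter_index,
+//                     &expert_offset, &expert_id);
+//   // ... run each expert's FFN on its slice of y ...
+//   moe_combine(expert_out, combine_weights, scatter_index, &out);
+//
+// The *_grad kernels in this directory invert these two steps individually.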
+template +void moe_dispatch_fwd(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &gate_logits, + const paddle::optional &corr_bias, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + const DenseTensor &y, + const DenseTensor &combine_weights, + const DenseTensor &scatter_index, + const DenseTensor &expert_offset, + const DenseTensor &expert_id, + bool use_pad, + int64_t use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1) { + apply_moe_dispatch_fwd( + dev_ctx, + x.data(), + gate_logits.data(), + corr_bias ? corr_bias.get_ptr()->data() : nullptr, + num_rows, + num_experts, + hidden_size, + capacity, + k, + const_cast(y.data()), + const_cast(combine_weights.data()), + const_cast(scatter_index.data()), + const_cast(expert_offset.data()), + const_cast(expert_id.data()), + use_pad, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); +} + +template +void MoeGradDispatchKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &gate_logits, + const paddle::optional &corr_bias, + const int64_t k, + const int64_t capacity, + const bool use_pad, + DenseTensor *y, + DenseTensor *combine_weights, + DenseTensor *scatter_index, + DenseTensor *expert_offset, + DenseTensor *expert_id) { + dev_ctx.template Alloc(expert_id); + dev_ctx.template Alloc(expert_offset); + dev_ctx.template Alloc(scatter_index); + dev_ctx.template Alloc(combine_weights); + dev_ctx.template Alloc(y); + + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); + auto x_dims = x.dims(); + auto gate_logits_dims = gate_logits.dims(); + + const int64_t num_rows = x_dims[0]; + const int64_t hidden_size = x_dims[1]; + const int64_t num_experts = gate_logits_dims[1]; + + moe_dispatch_fwd(dev_ctx, + x, + gate_logits, + corr_bias, + num_rows, + num_experts, + hidden_size, + capacity, + k, + *y, + *combine_weights, + *scatter_index, + *expert_offset, + *expert_id, + use_pad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch, + GPU, + ALL_LAYOUT, + phi::MoeGradDispatchKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu new file mode 100644 index 00000000000000..48d07e2bff02ac --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_grad_kernel.cu @@ -0,0 +1,152 @@ +// NOLINT +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
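+//
+// Variant of moe_gate_dispatch_grad for the all2all-permuted layout: the
+// kernel below takes num_local_experts from y_grad.dims()[0] and the hidden
+// size from y_grad's last dimension, then calls the shared gather / topk-grad
+// launchers with use_all2all_permute = true so rows are gathered from the
+// permuted per-local-expert layout rather than the flat
+// [num_experts * capacity, hidden] layout used by the non-permuted op.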
+ +#include "paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/contiguous_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/moe_fuse_bwd_op.h" +#include "paddle/phi/kernels/transpose_kernel.h" +namespace phi { + +template +void apply_moe_dispatch_bwd(const T* y_grad, + const float* combine_weights, // [s, k] + const int* scatter_index, // [s, k] + const float* combine_weights_grad, + const int* expert_id, // [s, k] + float* gate_logits_grad, + T* x_grad, + int64_t num_rows, + int64_t k, + int64_t dim, + int64_t num_experts, + int64_t capacity, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts, + cudaStream_t stream) { + gather_with_mask_launcher(y_grad, + scatter_index, + combine_weights, + x_grad, + num_rows, + k, + dim, + -1, + stream, + use_all2all_permute, + world_size, + num_local_experts, + capacity); + + topk_grad_with_mask_launcher(combine_weights_grad, + expert_id, + combine_weights, + gate_logits_grad, + num_rows, + k, + num_experts, + stream); +} + +template +void moe_dispatch_bwd(const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [s, k] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_grad, // [s, k] + DenseTensor& x_grad, // NOLINT + DenseTensor& gate_logits_grad, // NOLINT + int64_t capacity, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1) { + auto combine_weights_dims = combine_weights.dims(); + int64_t num_rows = combine_weights_dims[0]; + int64_t k = combine_weights_dims[1]; + auto y_grad_dims = y_grad.dims(); + int64_t hidden_size = y_grad_dims[y_grad_dims.size() - 1]; + int64_t num_experts = gate_logits_grad.dims()[1]; + + apply_moe_dispatch_bwd(y_grad.data(), + combine_weights.data(), + scatter_index.data(), + combine_weights_grad.data(), + expert_id.data(), + gate_logits_grad.data(), + x_grad.data(), + num_rows, + k, + hidden_size, + num_experts, + capacity, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); +} + +template +void MoeGateDispatchGradKernel( + const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [num_local_experts, num_experts * capacity + // // num_local_experts, h] + const DenseTensor& y_grad, // [s, k] + const DenseTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor* x_grad, + DenseTensor* gate_logits_grad) { + int64_t num_local_experts = y_grad.dims()[0]; + auto scatter_index_dims = scatter_index.dims(); + + DenseTensor t_scatter_index; + phi::Transpose( + dev_ctx, scatter_index, {1, 0}, &t_scatter_index); + DenseTensor t_scatter_index_; + phi::ContiguousKernel( + dev_ctx, t_scatter_index, &t_scatter_index_); + + dev_ctx.template Alloc(x_grad); + dev_ctx.template Alloc(gate_logits_grad); + moe_dispatch_bwd(dev_ctx, + combine_weights, + t_scatter_index_, + expert_id, + y_grad, + combine_weights_grad, + *x_grad, + *gate_logits_grad, + capacity, + true, /*use_all2all_permute*/ + world_size, + num_local_experts); +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_permute_grad, + GPU, + ALL_LAYOUT, + phi::MoeGateDispatchGradKernel, + float, + double, + phi::dtype::float16, + 
phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu new file mode 100644 index 00000000000000..8d35c763770695 --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_gate_dispatch_permute_kernel.cu @@ -0,0 +1,378 @@ +// NOLINT +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_fuse_op.h" +namespace phi { + +namespace { +// -------- getWorkspaceSize -------- // +template +size_t getWorkspaceSize(const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int k, + // const int max_seq_len, + phi::CubKeyValueSorter &sorter // NOLINT +) { + // const int buf_size = AlignTo16(k * num_rows * hidden_size); + // const int interbuf_size = AlignTo16(k * num_rows * inter_size); + // const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + int num_softmax_outs = 0; + + // softmax output, permuted_rows and permuted_experts have moved to outside of + // moe kernel, allocate them in Encoder or Decoder before invoking FfnLayer + // forward. 
+ size_t total_ws_bytes = + 4 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + // total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + // total_ws_bytes += padded_experts * sizeof(int64_t); // Hold + // total_rows_before_expert_ // expert_cnt total_ws_bytes += num_softmax_outs + // * sizeof(KeyT); const int bytes_for_fc1_result = interbuf_size * + // sizeof(KeyT); + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + // sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity + // // 用所有 bit 做排序,会降低些许性能,但是防止越界 + total_ws_bytes += sorter_ws_size_bytes; // intermediate (fc1) output + cub + // sorting workspace + // std::cout<<"sorter_ws_size_bytes = "< +void apply_moe_dispatch_fwd(const Context &dev_ctx, + const T *x, + const float *gate_logits, + const float *corr_bias, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + T *y, + float *combine_weights, + int *scatter_index, + int64_t *expert_offset, + int *expert_id, + bool use_pad, + bool use_all2all_permute, + int64_t world_size, + int64_t num_local_experts, + cudaStream_t stream) { + phi::CubKeyValueSorter sorter(stream); + // phi::funcs::SetConstant zero; + // zero(ctx, &finished_tensor, false); + + DenseTensor xpanded_source_row_to_expanded_dest_row_tensor = + phi::Empty(dev_ctx, IntArray({num_rows, k})); + // int* expanded_source_row_to_expanded_dest_row = + // expanded_source_row_to_expanded_dest_row_tensor.data(); + + // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, + // paddle::DataType::FLOAT32, place); float* expert_scales_float = + // expert_scales_tensor_float.data(); + + // paddle::Tensor expert_for_source_row_tensor = paddle::empty({num_rows, k}, + // paddle::DataType::INT32, place); int* expert_for_source_row = + // expert_for_source_row_tensor.data(); + DenseTensor active_cnt_tensor = + phi::Empty(dev_ctx, IntArray({1})); + + int64_t bytes = getWorkspaceSize(num_rows, + hidden_size, // hidden-size=0 + 0, // inter-size=0 + num_experts, + k, + sorter); + + DenseTensor ws_ptr_tensor = + phi::Empty(dev_ctx, IntArray({bytes})); + int8_t *ws_ptr = ws_ptr_tensor.data(); + + // Pointers + int *source_rows_; + int *permuted_rows_; + int *permuted_experts_; + int *expert_id_; + + // T* permuted_data_; + float *softmax_out_; + // int64_t* total_rows_before_expert_; + T *fc1_result_; + + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + // const int buf_size = AlignTo16(k * num_rows * hidden_size); + // const int interbuf_size = AlignTo16(k * num_rows * 0); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + + source_rows_ = reinterpret_cast(ws_ptr); + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + expert_id_ = permuted_experts_ + num_moe_inputs; + + // permuted_data_ = reinterpret_cast(expert_id_ + num_moe_inputs); + // total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + + // buf_size); + + // only use one number + // num_active = reinterpret_cast(permuted_experts_ + + // num_moe_inputs); + + fc1_result_ = reinterpret_cast(expert_id_ + num_moe_inputs); + softmax_out_ = nullptr; + +#ifdef DEBUG_MOE_OP + // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits + // before_topk")); print_to_screen1(finished, 2, 16, std::string("finished + // before_topk")); +#endif + + 
topk_gating_softmax_kernelLauncher(gate_logits, + corr_bias, + combine_weights, // output + softmax_out_, // no use + expert_id, // output + source_rows_, // output + num_rows, + num_experts, + k, + stream); + +#ifdef DEBUG_MOE_OP + // phi::CastKernel(ctx, expert_scales_tensor_float, + // expert_scales_tensor.dtype(), &expert_scales_tensor); + print_to_screen1( + combine_weights, 8, 16, std::string("expert_scales_float after topk")); + print_to_screen1( + expert_id, 8, 16, std::string("expert-id before permute")); + print_to_screen1( + source_rows_, 8, 16, std::string("desc->src idx before permute")); +#endif + // modify expert-id according to k + if (use_pad) // encode the pick index into expert-id so the k=1 and k=2 selections stay distinguishable + modify_expert_id_launcher( + expert_id, expert_id_, k, num_rows, num_experts, stream); + + // calc expert-size + /* + if (!use_pad) + cal_expert_size_and_filter_launcher(expert_id, + k * num_rows, + num_experts, + capacity, + stream); + */ +#ifdef DEBUG_MOE_OP + print_to_screen1( + expert_id, 8, 16, std::string("expert-id after modified")); +#endif + sorter.run( + fc1_result_, + sorter_ws_size_bytes, + use_pad ? expert_id_ : expert_id, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + source_rows_, // value in + permuted_rows_, // value out // [num_row, k]: index of this row in the original activation + k * num_rows, // num_rows + false, + stream); + + if (use_pad) + unmodify_expert_id_launcher( + permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1( + permuted_experts_, 8, 16, std::string("expert-id after permute")); + print_to_screen1( + permuted_rows_, 8, 16, std::string("dest->src idx after permute")); +#endif + + compute_total_rows_before_expert( + permuted_experts_, k * num_rows, num_experts, expert_offset, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1(expert_offset, 8, 16, std::string("expert_offset")); + int64_t num_active_host_v2; + cudaMemcpy(&num_active_host_v2, + expert_offset + num_experts - 1, + sizeof(int64_t), + cudaMemcpyDeviceToHost); + std::cerr << "[DEBUG] num_active v2: " << num_active_host_v2 << std::endl; + print_to_screen1(permuted_experts_, + 8, + num_active_host_v2 + 2, + std::string("expert-id after permute")); + // print_to_screen1(permuted_experts_, 4096, 8192, + // std::string("expert-id after permute")); +#endif + + if (!use_all2all_permute) { + initialize_moe_routing_kernelLauncher(x, + y, + permuted_rows_, + scatter_index, + permuted_experts_, + expert_offset, + combine_weights, + static_cast<int>(num_rows), + static_cast<int>(hidden_size), + static_cast<int>(k), + capacity, + use_pad, + stream); + } else { + PD_CHECK(num_experts > 0); + PD_CHECK(world_size > 0); + initialize_moe_routing_permute_kernelLauncher(x, + y, + permuted_rows_, + scatter_index, + permuted_experts_, + expert_offset, + combine_weights, + static_cast<int>(num_rows), + static_cast<int>(hidden_size), + static_cast<int>(k), + capacity, + world_size, + num_local_experts, + stream); + } + + // turn expert_offset_ptr into experts_num + // auto expert_offset_ptr = thrust::device_pointer_cast(expert_offset); + // thrust::adjacent_difference( + // expert_offset_ptr, expert_offset_ptr + num_experts, expert_offset_ptr + // ); +#ifdef DEBUG_MOE_OP + print_to_screen1( + scatter_index, 8, 16, std::string("scatter_index after pad")); +#endif + // cudaMemcpy(scatter_index, permuted_rows_, sizeof(int64_t) * k * num_rows, + // cudaMemcpyDeviceToDevice); cudaMemcpy(combine_weights, expert_scales_float, + // sizeof(float) * k * num_rows, cudaMemcpyDeviceToDevice); + return; +} + 
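Compared with the plain padded layout written by initialize_moe_routing_kernelLauncher (rows grouped by global expert), the permute variant above groups rows by local expert first and then by rank, which is what lets the later all2all be chunked per local expert and overlapped with the expert GEMM (see the layout comment above initialize_moe_routing_permute_kernel in moe_fuse_op.h). The standalone sketch below only illustrates the two row mappings; it is not this PR's API and its names are made up for the example.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t world_size = 2, num_local_experts = 2, capacity = 2;
  const int64_t num_experts = world_size * num_local_experts;

  for (int64_t iexpert = 0; iexpert < num_experts; ++iexpert) {
    for (int64_t row_in_expert = 0; row_in_expert < capacity; ++row_in_expert) {
      // Plain padded layout: [num_experts, capacity, hidden].
      const int64_t plain_row = iexpert * capacity + row_in_expert;
      // Permuted layout: [num_local_experts, world_size, capacity, hidden].
      const int64_t irank = iexpert / num_local_experts;
      const int64_t local_iexpert = iexpert % num_local_experts;
      const int64_t permuted_row = local_iexpert * (world_size * capacity) +
                                   irank * capacity + row_in_expert;
      std::printf("expert %lld, slot %lld: plain row %lld -> permuted row %lld\n",
                  static_cast<long long>(iexpert),
                  static_cast<long long>(row_in_expert),
                  static_cast<long long>(plain_row),
                  static_cast<long long>(permuted_row));
    }
  }
  return 0;
}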
+template +void moe_dispatch_fwd(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &gate_logits, + const paddle::optional &corr_bias, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + const DenseTensor &y, + const DenseTensor &combine_weights, + const DenseTensor &scatter_index, + const DenseTensor &expert_offset, + const DenseTensor &expert_id, + bool use_pad, + int64_t use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1) { + apply_moe_dispatch_fwd( + dev_ctx, + x.data(), + gate_logits.data(), + corr_bias ? corr_bias.get_ptr()->data() : nullptr, + num_rows, + num_experts, + hidden_size, + capacity, + k, + const_cast(y.data()), + const_cast(combine_weights.data()), + const_cast(scatter_index.data()), + const_cast(expert_offset.data()), + const_cast(expert_id.data()), + use_pad, + use_all2all_permute, + world_size, + num_local_experts, + dev_ctx.stream()); +} + +template +void MoEDispatchPermuteKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &gate_logits, + const paddle::optional &corr_bias, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor *y, + DenseTensor *combine_weights, + DenseTensor *scatter_index, + DenseTensor *expert_offset, + DenseTensor *expert_id) { + dev_ctx.template Alloc(expert_id); + dev_ctx.template Alloc(expert_offset); + dev_ctx.template Alloc(scatter_index); + dev_ctx.template Alloc(combine_weights); + dev_ctx.template Alloc(y); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); + const auto &x_shape = x.dims(); + const auto &gate_logits_shape = gate_logits.dims(); + int64_t num_rows = x_shape[0]; + int64_t hidden_size = x_shape[1]; + int64_t num_experts = gate_logits_shape[1]; + int64_t num_local_experts = num_experts / world_size; + moe_dispatch_fwd(dev_ctx, + x, + gate_logits, + corr_bias, + num_rows, + num_experts, + hidden_size, + capacity, + k, + *y, + *combine_weights, + *scatter_index, + *expert_offset, + *expert_id, + true, /*use_pad*/ + true, /*use_all2all_permute*/ + world_size, + num_local_experts); +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_permute, + GPU, + ALL_LAYOUT, + phi::MoEDispatchPermuteKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu new file mode 100644 index 00000000000000..f1e1ec0d752bcc --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu @@ -0,0 +1,148 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
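The backward kernel added in this file recovers x_grad with a masked gather over the dispatched rows and passes the combine-weights gradient straight through, but only where the forward combine weight stayed positive; picks dropped at dispatch time (capacity overflow, or experts outside the local range) get zero gradient. A minimal CPU sketch of that pass-through rule, with hypothetical names used purely for illustration:

#include <cstdio>
#include <vector>

int main() {
  // Flattened [s * k] views; a zero combine weight marks a dropped pick.
  std::vector<float> combine_weights = {0.6f, 0.4f, 0.0f, 1.0f};
  std::vector<float> out_grad = {0.1f, 0.2f, 0.3f, 0.4f};
  std::vector<float> in_grad(combine_weights.size());

  for (size_t i = 0; i < combine_weights.size(); ++i)
    in_grad[i] = combine_weights[i] > 0.0f ? out_grad[i] : 0.0f;  // mask only, no rescaling

  for (float g : in_grad) std::printf("%.1f ", g);
  std::printf("\n");
  return 0;
}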
+ +#pragma once +#include "paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h" +#include +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/contiguous_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_fuse_bwd_op.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +void apply_moe_dispatch_bwd(const T* y_grad, + const float* combine_weights, // [s, k] + const int* scatter_index, // [s, k] + const float* combine_weights_out_grad, + float* combine_weights_in_grad, + T* x_grad, + int64_t num_rows, + int64_t k, + int64_t dim, + int64_t num_experts, + int64_t num_active, + cudaStream_t stream) { + gather_with_mask_launcher(y_grad, + scatter_index, + combine_weights, + x_grad, + num_rows, + k, + dim, + num_active, + stream); + auto out_grad_ptr = thrust::device_pointer_cast(combine_weights_out_grad); + auto in_grad_ptr = thrust::device_pointer_cast(combine_weights_in_grad); + auto combine_weight_ptr = thrust::device_pointer_cast(combine_weights); + thrust::transform(thrust::cuda::par.on(stream), + out_grad_ptr, + out_grad_ptr + num_rows * k, + combine_weight_ptr, + in_grad_ptr, + [] __device__(float g, float w) { + return w > static_cast(0) ? g + : static_cast(0); + }); + // topk_grad_with_mask_launcher(combine_weights_grad, + // expert_id, + // combine_weights, + // gate_logtis_grad, + // num_rows, k, num_experts, stream); +} + +template +void moe_dispatch_bwd(const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_out_grad, // [s, k] + DenseTensor* x_grad, + DenseTensor* combine_weights_in_grad, + int64_t num_experts) { + int64_t num_rows = combine_weights.dims()[0]; + int64_t k = combine_weights.dims()[1]; + int64_t hidden_size = y_grad.dims()[1]; + int64_t num_active = y_grad.dims()[0]; + + apply_moe_dispatch_bwd(y_grad.data(), + combine_weights.data(), + scatter_index.data(), + combine_weights_out_grad.data(), + combine_weights_in_grad->data(), + x_grad->data(), + num_rows, + k, + hidden_size, + num_experts, + num_active, + dev_ctx.stream()); +} + +template +void MoeGateDispatchPartialNoSoftMaxTopkGradKernel( + const Context& dev_ctx, + const DenseTensor& combine_weights_out, + const DenseTensor& scatter_index, + const DenseTensor& scatter_index_rev, + const DenseTensor& expert_offset, + const DenseTensor& expert_offset_local, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + DenseTensor* x_grad, + DenseTensor* combine_weights_grad) { + dev_ctx.template Alloc(x_grad); + dev_ctx.template Alloc(combine_weights_grad); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(combine_weights_grad->dims())), + 0, + combine_weights_grad); + DenseTensor t_scatter_index; + phi::Transpose( + dev_ctx, scatter_index, {1, 0}, &t_scatter_index); + DenseTensor t_scatter_index_out; + phi::ContiguousKernel( + dev_ctx, t_scatter_index, &t_scatter_index_out); + t_scatter_index = t_scatter_index_out; + int64_t num_experts = expert_offset.dims()[0]; + moe_dispatch_bwd(dev_ctx, + combine_weights_out, + t_scatter_index, + y_grad, + combine_weights_out_grad, + x_grad, + combine_weights_grad, + num_experts); +} +} // 
namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_partial_nosoftmaxtopk_grad, + GPU, + ALL_LAYOUT, + phi::MoeGateDispatchPartialNoSoftMaxTopkGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu new file mode 100644 index 00000000000000..d870fb85bc166b --- /dev/null +++ b/paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu @@ -0,0 +1,608 @@ +// NOLINT +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + +/*This code is copied from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include "paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/moe_fuse_op.h" +#include "paddle/phi/kernels/moe_kernel_impl.h" +#include "paddle/phi/kernels/slice_kernel.h" + +namespace phi { + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// already defined need to revise! +// static inline size_t AlignTo16(const size_t &input){ +// static constexpr int ALIGNMENT = 16; +// return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +// } +namespace { +// -------- getWorkspaceSize -------- // +template +size_t getWorkspaceSize(const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int capacity, + const int k, + // const int max_seq_len, + bool use_pad, + phi::CubKeyValueSorter &sorter) { // NOLINT + // const int buf_size = AlignTo16(k * num_rows * hidden_size); + const int interbuf_size = AlignTo16(k * num_rows * inter_size); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + const int num_dispatched_size = AlignTo16(num_experts * capacity); + int num_softmax_outs = 0; + + // softmax output, permuted_rows and permuted_experts have moved to outside of + // moe kernel, allocate them in Encoder or Decoder before invoking FfnLayer + // forward. 
+ size_t total_ws_bytes = + 4 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += 2 * num_dispatched_size * sizeof(int); + total_ws_bytes += + padded_experts * + sizeof(int64_t); // Hold total_rows_before_expert_ // expert_cnt + // total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + total_ws_bytes += num_softmax_outs * sizeof(KeyT); + const int bytes_for_fc1_result = interbuf_size * sizeof(KeyT); + const int sorter_ws_size_bytes = + std::max(AlignTo16(sorter.getWorkspaceSize(k * num_rows)), + AlignTo16(sorter.getWorkspaceSize(capacity))); + // sorter.update_num_experts(num_experts+1); // +1 for filter out of capacity + // // 用所有 bit 做排序,会降低些许性能,但是防止越界 + int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int remaining_bytes = + AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + // std::cout<<"num_softmax_outs --"<< num_softmax_outs << std::endl; + total_ws_bytes += + bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub + // sorting workspace + // std::cout<<"buf_size --"<< buf_size<<" "< +void apply_moe_dispatch_fwd( + const Context &dev_ctx, + const DenseTensor &x, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + thrust::host_vector &expert_offset_host, // NOLINT + DenseTensor *y, + float *combine_weights, + int *scatter_index, + int *scatter_index_rev, + int64_t *expert_offset_global, + int64_t *expert_nums_local, + int *expert_id, + bool use_pad, + cudaStream_t stream) { + phi::CubKeyValueSorter sorter(stream); + // paddle::Tensor expanded_source_row_to_expanded_dest_row_tensor = + // paddle::empty({num_rows, k}, paddle::DataType::INT32, place); + // int* expanded_source_row_to_expanded_dest_row = + // expanded_source_row_to_expanded_dest_row_tensor.data(); + + // paddle::Tensor expert_scales_tensor_float = paddle::empty({num_rows, k}, + // paddle::DataType::FLOAT32, place); float* expert_scales_float = + // expert_scales_tensor_float.data(); + + // paddle::Tensor expert_for_source_row_tensor = paddle::empty({num_rows, k}, + // paddle::DataType::INT32, place); int* expert_for_source_row = + // expert_for_source_row_tensor.data(); paddle::Tensor active_cnt_tensor + // = paddle::empty({1}, paddle::DataType::INT32, place); + + int64_t bytes = getWorkspaceSize(num_rows, + hidden_size, // hidden-size=0 + 0, // inter-size=0 + num_experts, + capacity, + k, + use_pad, + sorter); + + DenseTensor ws_ptr_tensor = phi::Empty(dev_ctx, {bytes}); + int8_t *ws_ptr = ws_ptr_tensor.data(); + + phi::memory_utils::ThrustAllocator allocator(dev_ctx.GetPlace(), + dev_ctx.stream()); + + // Pointers + int *source_rows_; + int *permuted_rows_; + int *permuted_experts_; + int *expert_id_; + int *source_rows_for_seqsort_; + int *source_rows_for_seqsort_out_; + int *source_pos_for_seqsort_; + int *source_pos_for_seqsort_out_; + int64_t *expert_offset_; // local-expert-offset + + char *sorter_ws_; + // T* permuted_data_; + float *softmax_out_; + // int64_t* total_rows_before_expert_; + T *fc1_result_; + + const int sorter_ws_size_bytes = + AlignTo16(sorter.getWorkspaceSize(k * num_rows)); + const int sorter_ws_size_bytes_seqsort = + AlignTo16(sorter.getWorkspaceSize(capacity)); + + const int buf_size = AlignTo16(k * num_rows * hidden_size); + // const int interbuf_size 
= AlignTo16(k * num_rows * 0); + const int padded_experts = AlignTo16(num_experts); + const int num_moe_inputs = AlignTo16(k * num_rows); + const int num_dispatched_size = AlignTo16(num_experts * capacity); + + // 4:ints [k*row] + source_rows_ = reinterpret_cast(ws_ptr); + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + expert_id_ = permuted_experts_ + num_moe_inputs; + // 4:ints: [E*C] + source_rows_for_seqsort_ = expert_id_ + num_moe_inputs; + source_rows_for_seqsort_out_ = source_rows_for_seqsort_ + num_dispatched_size; + // 1:ints: [E] + expert_offset_ = reinterpret_cast(source_rows_for_seqsort_out_ + + num_dispatched_size); + // permuted_data_ = reinterpret_cast(expert_offset_ + padded_experts); + // total_rows_before_expert_ = reinterpret_cast(permuted_experts_ + + // buf_size); + + // only use one number + // num_active = reinterpret_cast(permuted_experts_ + + // num_moe_inputs); + fc1_result_ = reinterpret_cast(expert_offset_ + padded_experts); + // fc1_result_ = reinterpret_cast(permuted_data_ + buf_size); + +#ifdef DEBUG_MOE_OP + // print_to_screen1(gate_logits, 8, 16, std::string("gate_logits + // before_topk")); print_to_screen1(finished, 2, 16, std::string("finished + // before_topk")); +#endif + + thrust::transform(thrust::cuda::par.on(stream), + thrust::device_pointer_cast(source_rows_), + thrust::device_pointer_cast(source_rows_) + num_rows * k, + thrust::counting_iterator(0), + thrust::device_pointer_cast(source_rows_), + [num_rows, k] __device__(int i, int cnt) { + int k_idx = cnt % k; + int block_row = cnt / k; + return k_idx * num_rows + block_row; + }); + +#ifdef DEBUG_MOE_OP + // phi::CastKernel(ctx, expert_scales_tensor_float, + // expert_scales_tensor.dtype(), &expert_scales_tensor); + print_to_screen1( + combine_weights, 8, 16, std::string("expert_scales_float after topk")); + print_to_screen1( + expert_id, 8, 16, std::string("expert-id before permute")); + print_to_screen1( + source_rows_, 8, 16, std::string("desc->src idx before permute")); +#endif + + // compute global expert offset, **not** consider capacity + // 必须在 modify_and_mask_expert_id_launcher 之前算出**全局 expert-offset** + + compute_global_expert_offset(expert_id, + expert_id_, // buffer + expert_offset_global, + num_rows * k, + num_experts, + capacity, + stream, + allocator); + + // modify expert-id according to k + modify_and_mask_expert_id_launcher(expert_id, + expert_id_, + k, + num_rows, + static_cast(num_experts), + static_cast(expert_start_index), + static_cast(expert_end_index), + stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1( + expert_id_, 8, 16, std::string("expert-id after modified 22")); +#endif + sorter.run( + fc1_result_, + sorter_ws_size_bytes, + expert_id_, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + source_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + unmodify_expert_id_launcher( + permuted_experts_, permuted_experts_, k, num_rows, num_experts, stream); + +#ifdef DEBUG_MOE_OP + print_to_screen1( + permuted_experts_, 8, 16, std::string("expert-id after permute")); + print_to_screen1( + permuted_rows_, 8, 16, std::string("dest->src idx after permute")); +#endif + + compute_local_expert_offset(permuted_experts_, + expert_offset_, + expert_nums_local, + num_rows * k, + num_experts, + capacity, + stream, + allocator); + + CUDACHECK(cudaMemcpyAsync(expert_offset_host.data(), + expert_offset_, + num_experts * 
sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + +#ifdef DEBUG_MOE_OP + std::cerr << "[DEBUG] num_active v2: " << expert_offset_host.back() + << std::endl; + print_to_screen1( + expert_offset_global, 8, 16, std::string("expert_offset global")); + print_to_screen1(expert_offset_, 8, 16, std::string("expert_offset local")); + print_to_screen1(permuted_experts_, + 8, + 16, + std::string("expert-id after permute")); + // print_to_screen1(permuted_experts_, 4096, 8192, + // std::string("expert-id after permute")); +#endif + + // calc expert-size + // 不 use-pad 的情况下,在此处标记截断位置。之后需要再 sort 一遍把截断 id + // 放到句尾 + if (!use_pad) { // 2sort + cal_expert_size_and_filter_launcher(permuted_experts_, + expert_offset_, + expert_offset_host.back(), + num_experts, + capacity, + expert_start_index, + expert_end_index, + reverse_token_drop, + stream); + // 2sort + sorter.run( + fc1_result_, + sorter_ws_size_bytes, + permuted_experts_, // key in + permuted_experts_, // key out // [num_row, k]: expert-id + permuted_rows_, // value in + permuted_rows_, // value out //[num_row, k]: id在原 activation 中的位置 + k * num_rows, // num_rows + false, + stream); + + compute_local_expert_offset(permuted_experts_, + expert_offset_, + expert_nums_local, + num_rows * k, + num_experts, + capacity, + stream, + allocator); + + CUDACHECK(cudaMemcpyAsync(expert_offset_host.data(), + expert_offset_, + num_experts * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + +#ifdef DEBUG_MOE_OP + std::cerr << "[DEBUG](after 2sort) num_active v2: " + << expert_offset_host.back() << std::endl; + print_to_screen1( + expert_id_, 8, 16, std::string(" permuted_experts")); + print_to_screen1(permuted_experts_, + 8, + 16, + std::string(" permuted_experts")); + print_to_screen1( + permuted_rows_, 8, 16, std::string(" dest->src idx")); +#endif + } + + thrust::fill( + thrust::cuda::par.on(stream), + thrust::device_ptr(scatter_index_rev), + thrust::device_ptr(scatter_index_rev) + num_experts * capacity, + num_rows); + build_seqsort_kv_pairs_kernel_launcher( + scatter_index_rev, // padded_to_unpermuted_input + source_rows_for_seqsort_, // seqsort-value + permuted_rows_, + // scatter_index, // 对截断位置置0 + permuted_experts_, + expert_offset_, + combine_weights, // 对截断位置置0 + static_cast(num_rows), + static_cast(k), + expert_offset_host.back(), // num_active + capacity, + expert_start_index, // expert start index + use_pad, + stream); + +#ifdef DEBUG_MOE_OP + + // print_to_screen1(scatter_index, 8, 16, std::string("scatter_index + // after build_seqsort_kv_pairs_kernel_launcher")); + print_to_screen1(source_rows_for_seqsort_, + 8, + 16, + std::string("source_rows_for_seqsort_ after " + "build_seqsort_kv_pairs_kernel_launcher")); + print_to_screen1( + scatter_index_rev, + 8, + 16, + std::string( + "scatter_index_rev after build_seqsort_kv_pairs_kernel_launcher")); +#endif + if (use_pad) { + for (auto iexpert = 0; iexpert != expert_end_index - expert_start_index; + ++iexpert) { + sorter.run(fc1_result_, + sorter_ws_size_bytes_seqsort, + scatter_index_rev + (iexpert * capacity), // key in + scatter_index_rev + (iexpert * capacity), // key out + source_rows_for_seqsort_ + (iexpert * capacity), // value in + source_rows_for_seqsort_ + + (iexpert * capacity), // value out //[num_row, k]: id在原 + // activation 中的位置 + capacity, // num_rows + false, + stream); + } + } else { + auto sort_iter = thrust::make_zip_iterator(thrust::make_tuple( + 
thrust::device_pointer_cast(permuted_experts_), // key1 + thrust::device_pointer_cast(scatter_index_rev), // key2 + thrust::device_pointer_cast(source_rows_for_seqsort_))); + thrust::stable_sort(thrust::cuda::par.on(stream), + sort_iter, + sort_iter + expert_offset_host.back(), + [] __device__(auto lhs, auto rhs) { + if (thrust::get<0>(lhs) < thrust::get<0>(rhs)) + return true; + else if (thrust::get<0>(lhs) > thrust::get<0>(rhs)) + return false; + else + return thrust::get<1>(lhs) < thrust::get<1>(rhs); + }); + } + if (use_pad) { + int64_t num_experts_diff = expert_end_index - expert_start_index; + y->Resize({num_experts_diff * capacity, x.dims()[1]}); + dev_ctx.template Alloc(y); + } else { + y->Resize({expert_offset_host.back(), x.dims()[1]}); + dev_ctx.template Alloc(y); + } + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y); + copy_unpermuted_to_permuted_kernelLauncher( + x.data(), + y->data(), // out + scatter_index_rev, // padded_out_to_unpermuted_input + source_rows_for_seqsort_, // padded_out_to_expanded_input + scatter_index, // out + use_pad ? (expert_end_index - expert_start_index) * capacity + : expert_offset_host.back(), // num_active + num_rows, + k, + hidden_size, + stream); + // cudaDeviceSynchronize(); //debug + // turn expert_offset_ptr into experts_num + return; +} + +template +void moe_dispatch_fwd( + const Context &dev_ctx, + const DenseTensor &x, + int64_t num_rows, + int64_t num_experts, + int64_t hidden_size, + int64_t capacity, + int64_t k, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + thrust::host_vector &expert_offset_host, // NOLINT + DenseTensor *y, + const DenseTensor &combine_weights, + const DenseTensor &scatter_index, + const DenseTensor &scatter_index_rev, + const DenseTensor &expert_offset, + const DenseTensor &expert_nums_local, + const DenseTensor &expert_id, + bool use_pad) { + apply_moe_dispatch_fwd( + dev_ctx, + x, + num_rows, + num_experts, + hidden_size, + capacity, + k, + expert_start_index, + expert_end_index, + reverse_token_drop, + expert_offset_host, + y, + const_cast(combine_weights.data()), + const_cast(scatter_index.data()), + const_cast(scatter_index_rev.data()), + const_cast(expert_offset.data()), + const_cast(expert_nums_local.data()), + const_cast(expert_id.data()), + use_pad, + dev_ctx.stream()); +} + +template +void MoeGateDispatchPartialNoSoftMaxTopkKernel( + const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &combine_weights, + const DenseTensor &expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + DenseTensor *y, + DenseTensor *combine_weights_out, + DenseTensor *scatter_index, + DenseTensor *scatter_index_rev, + DenseTensor *expert_offset, + DenseTensor *expert_nums_local) { + dev_ctx.template Alloc(scatter_index); + dev_ctx.template Alloc(scatter_index_rev); + dev_ctx.template Alloc(expert_offset); + dev_ctx.template Alloc(expert_nums_local); + dev_ctx.template Alloc(combine_weights_out); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(scatter_index->dims())), + 0, + scatter_index); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(scatter_index_rev->dims())), + 0, + scatter_index_rev); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(expert_offset->dims())), + 0, + expert_offset); + phi::Full( + dev_ctx, + phi::IntArray(common::vectorize(expert_nums_local->dims())), + 0, + expert_nums_local); + phi::Full( 
+ dev_ctx, + phi::IntArray(common::vectorize(combine_weights_out->dims())), + 0, + combine_weights_out); + phi::Copy( + dev_ctx, combine_weights, dev_ctx.GetPlace(), false, combine_weights_out); + const auto &x_shape = x.dims(); + int64_t num_rows = x_shape[0]; + int64_t hidden_size = x_shape[1]; + thrust::host_vector expert_offset_host(num_experts); + int64_t num_experts_diff = expert_end_index - expert_start_index; + moe_dispatch_fwd(dev_ctx, + x, + num_rows, + num_experts, + hidden_size, + capacity, + k, + expert_start_index, + expert_end_index, + reverse_token_drop, + expert_offset_host, + y, + *combine_weights_out, + *scatter_index, + *scatter_index_rev, + *expert_offset, // global-offset + *expert_nums_local, + expert_id, + use_pad); + if (use_pad) { + // scatter_index_rev = scatter_index_rev.slice(0, num_experts_diff * + // capacity); + *scatter_index_rev = phi::Slice( + dev_ctx, *scatter_index_rev, {0}, {0}, {num_experts_diff * capacity}); + } else { + if (expert_offset_host.back() > 0) { + // scatter_index_rev = scatter_index_rev.slice(0, + // expert_offset_host.back()); + *scatter_index_rev = phi::Slice( + dev_ctx, *scatter_index_rev, {0}, {0}, {expert_offset_host.back()}); + } else { + *y = phi::Empty(dev_ctx, {1, x_shape[1]}); + *scatter_index_rev = + phi::Empty(dev_ctx, {}); // special treatment + } + } +} +} // namespace phi + +PD_REGISTER_KERNEL(moe_gate_dispatch_partial_nosoftmaxtopk, + GPU, + ALL_LAYOUT, + phi::MoeGateDispatchPartialNoSoftMaxTopkKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/int_bincount.h b/paddle/phi/kernels/int_bincount.h new file mode 100644 index 00000000000000..29dfb582a14211 --- /dev/null +++ b/paddle/phi/kernels/int_bincount.h @@ -0,0 +1,28 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void IntBincount(const Context& ctx, + const DenseTensor& x, + int64_t low, + int64_t high, + int64_t out_dtype, + DenseTensor* out); +} diff --git a/paddle/phi/kernels/moe_combine_grad_kernel.h b/paddle/phi/kernels/moe_combine_grad_kernel.h new file mode 100644 index 00000000000000..7468d9e944ce34 --- /dev/null +++ b/paddle/phi/kernels/moe_combine_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoeCombineGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& grad_y, + DenseTensor* grad_x, + DenseTensor* grad_combine_weights_helper); +} // namespace phi diff --git a/paddle/phi/kernels/moe_combine_kernel.h b/paddle/phi/kernels/moe_combine_kernel.h new file mode 100644 index 00000000000000..8057833db0f604 --- /dev/null +++ b/paddle/phi/kernels/moe_combine_kernel.h @@ -0,0 +1,25 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoeCombineKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/moe_fuse_bwd_op.h b/paddle/phi/kernels/moe_fuse_bwd_op.h new file mode 100644 index 00000000000000..e1a008baecf225 --- /dev/null +++ b/paddle/phi/kernels/moe_fuse_bwd_op.h @@ -0,0 +1,317 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
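The header added below supplies the backward launchers shared by the dispatch grad kernels (gather_with_mask_launcher and topk_grad_with_mask_launcher). Its gating gradient is a masked scatter from the per-pick gradients of shape [s, k] back to the dense gate layout [s, e]; picks whose combine weight was zeroed in the forward pass receive no gradient. The CPU reference below is a sketch under assumed shapes with made-up names, not the header's API:

#include <algorithm>
#include <cstdio>
#include <vector>

void topk_grad_with_mask_ref(const std::vector<float>& dy,       // [s, k]
                             const std::vector<int>& topk_idx,   // [s, k]
                             const std::vector<float>& weights,  // [s, k]
                             std::vector<float>& dx,             // [s, e]
                             int s, int k, int e) {
  std::fill(dx.begin(), dx.end(), 0.0f);
  for (int i = 0; i < s; ++i)
    for (int j = 0; j < k; ++j)
      if (weights[i * k + j] > 0.0f)
        dx[i * e + topk_idx[i * k + j]] = dy[i * k + j];
}

int main() {
  const int s = 2, k = 2, e = 4;
  std::vector<float> dy = {0.1f, 0.2f, 0.3f, 0.4f};
  std::vector<int> idx = {0, 3, 1, 2};
  std::vector<float> w = {0.7f, 0.0f, 0.6f, 0.4f};  // row 0's second pick was dropped
  std::vector<float> dx(s * e);
  topk_grad_with_mask_ref(dy, idx, w, dx, s, k, e);
  for (int i = 0; i < s; ++i) {
    for (int j = 0; j < e; ++j) std::printf("%.1f ", dx[i * e + j]);
    std::printf("\n");
  }
  return 0;
}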
+ +#pragma once +#ifdef PADDLE_WITH_CUDA +#include "paddle/common/exception.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/moe_kernel_impl.h" + +namespace phi { + +template +__global__ void gather_with_mask_permute_kernel( + const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num, + int64_t capacity, + int64_t world_size, + int64_t num_local_experts) { + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = + reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; + + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; + idx < N; + idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = + min(static_cast(blockDim.x), N - shared_idx_begin); + + for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin * k + i >= num_rows * k) { + break; + } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); + + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active) { + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f) { + int64_t remaining_after_irank = id % (num_local_experts * capacity); + + int64_t irank = id / (num_local_experts * capacity); + int64_t local_iexpert = remaining_after_irank / capacity; + int64_t row_in_expert = remaining_after_irank % capacity; + int64_t permuted_id = local_iexpert * (world_size * capacity) + + irank * capacity + row_in_expert; + int64_t in_offset = permuted_id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; + } + } + } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } +} + +template +__global__ void gather_with_mask_kernel( + const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s, d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t N, + int64_t num_active, // skip > num_active pos is num_active specified + int64_t s_shared_num) { + extern __shared__ char shared[]; + int* scatter_index_shared = reinterpret_cast(shared); + float* combine_weights_shared = + reinterpret_cast(shared + s_shared_num * k * sizeof(int)); + int64_t shared_idx_begin = blockIdx.x * blockDim.x * vec_size; + + for (int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size; + idx < N; + idx += blockDim.x * gridDim.x * vec_size) { + int64_t si = idx / dim; + int64_t di_begin = idx % dim; + int64_t si_shared_begin = shared_idx_begin / dim; + int64_t shared_stride = + min(static_cast(blockDim.x), N - shared_idx_begin); + + for (int64_t i = threadIdx.x; i < k * s_shared_num; i += shared_stride) { + if (si_shared_begin 
* k + i >= num_rows * k) { + break; + } + scatter_index_shared[i] = scatter_index[si_shared_begin * k + i]; + combine_weights_shared[i] = combine_weights[si_shared_begin * k + i]; + } + __syncthreads(); + + phi::AlignedVector in_vec; + phi::AlignedVector out_vec; + for (int ii = 0; ii < vec_size; ++ii) { + out_vec[ii] = static_cast(0); + } + + for (int64_t i = 0; i < k; ++i) { + int64_t scatter_offset = (si - si_shared_begin) * k + i; + int id = scatter_index_shared[scatter_offset]; + if (num_active >= 0 && id >= num_active) { + continue; + } + if (combine_weights_shared[scatter_offset] > 0.f) { + int64_t in_offset = id * dim + di_begin; + phi::Load(dy + in_offset, &in_vec); + for (int64_t j = 0; j < vec_size; ++j) { + out_vec[j] += in_vec[j]; + } + } + } + phi::Store(out_vec, dx + idx); + shared_idx_begin += blockDim.x * gridDim.x * vec_size; + } +} + +template +inline T DivUp(T a, T b) { + return (a + b - 1) / b; +} + +inline int64_t max_shared_s_num(int64_t num_rows, + int64_t dim, + int64_t threads, + int64_t vec_size) { + if ((threads * vec_size) % dim == 0) { + return min(num_rows, threads * vec_size / dim); + } else { + int64_t max_res = DivUp(threads * 4, dim); + for (int64_t idx = 0; idx < num_rows * dim; idx += vec_size * threads) { + int64_t si_start = idx / dim; + int64_t si_end = min(num_rows * dim, idx + vec_size * threads - 1) / dim; + max_res = max(max_res, (si_end - si_start + 1)); + } + return min(num_rows, max_res); + } +} + +template +void gather_with_mask_launcher(const T* dy, // [s*k, d] + const int* scatter_index, // [s, k] + const float* combine_weights, // [s, k] + T* dx, // [s,k,d] + int64_t num_rows, // s + int64_t k, // k + int64_t dim, // d + int64_t num_active, + cudaStream_t stream, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1, + int64_t capacity = -1) { + int numel = num_rows * dim; + + int64_t threads = 512; + if (dim % 4 == 0) { + int64_t blocks = DivUp(DivUp(numel, 4), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 4); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + + if (!use_all2all_permute) { + gather_with_mask_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + PD_CHECK(world_size > 0 && num_local_experts > 0 && capacity > 0); + gather_with_mask_permute_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } else { + int64_t blocks = DivUp(DivUp(numel, 1), threads); + int64_t s_shared_num = max_shared_s_num(num_rows, dim, threads, 1); + size_t shared_size = k * s_shared_num * (sizeof(int) + sizeof(float)); + +#ifdef DEBUG_MOE_OP + std::cerr + << "[DEBUG-BWD] gather_with_mask without vectorized, s_shared_num=" + << s_shared_num << ", block=" << blocks << std::endl; +#endif + + if (!use_all2all_permute) { + gather_with_mask_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num); + } else { + gather_with_mask_permute_kernel + <<>>(dy, + scatter_index, + combine_weights, + dx, + num_rows, + k, + dim, + numel, + num_active, + s_shared_num, + capacity, + world_size, + num_local_experts); + } + } +} + +template +__global__ void topk_grad_with_mask(const T* dy, // [s, k] + const int* topk_idx, // [s, k] + const T* combine_weights, // [s, k] + T* dx, // [s, e] + int64_t num_rows, // s 
+ int64_t k, // k + int64_t num_experts // e +) { + // init dx to zero + for (int i = blockIdx.x; i < num_rows; i += gridDim.x) { + int base_grad = i * num_experts; + for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { + dx[base_grad + j] = static_cast(0); + } + __syncthreads(); + int base_index = i * k; + for (int j = threadIdx.x; j < k; j += blockDim.x) { + int64_t idx = topk_idx[base_index + j]; + if (combine_weights[base_index + j] > static_cast(0)) { + dx[base_grad + idx] = dy[base_index + j]; + } + } + } +} + +// y=zero_part(topk(x)) 的反向过程 +// x: [s,e] +// dy: [s,k] +// X: [s, e] -(topk)-> Y:[s, k] - (越界设置为0)-> combine_weights: [s, k] +template +void topk_grad_with_mask_launcher(const T* dy, // [s, k] + const int* topk_idx, // [s, k] + const T* combine_weights, // [s, k] + T* dx, // [s, e] + int64_t num_rows, // s + int64_t k, // k + int64_t num_experts, // e + cudaStream_t stream) { + int blocks = num_rows; + int threads = 1024; + + topk_grad_with_mask<<>>( + dy, topk_idx, combine_weights, dx, num_rows, k, num_experts); +} + +} // namespace phi +#endif diff --git a/paddle/phi/kernels/moe_fuse_op.h b/paddle/phi/kernels/moe_fuse_op.h new file mode 100644 index 00000000000000..80d51844b49efc --- /dev/null +++ b/paddle/phi/kernels/moe_fuse_op.h @@ -0,0 +1,819 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifdef PADDLE_WITH_CUDA +#include // 包含常用的 thrust 算法 +#include +#include +#include +#include +#include "paddle/common/enforce.h" +#include "paddle/common/exception.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/moe_kernel_impl.h" + +namespace phi { + +template +__launch_bounds__(TPB) __global__ + void moe_top_k(const T* inputs_after_softmax, + const T* bias, // bias could be nullptr if not used + T* output, + int* indices, + int* source_rows, + const int num_experts, + const int k) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = bias ? 
inputs_after_softmax[idx] + bias[expert] + : inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = + bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] + : result_kvp.value; + indices[idx] = result_kvp.key; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +template +void topk_gating_softmax_kernelLauncher(const T* input, + const T* bias, + T* output, + T* softmax, // no use + int* indices, + int* source_row, + const int num_rows, + const int num_experts, + const int k, + cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + static constexpr int TPB = 256; + moe_top_k<<>>( + input, bias, output, indices, source_row, num_experts, k); +} + +template +__global__ void modify_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) return; + int ik = idx % k; + int irow = idx / k; + // const T mask = (~0) >> (8*sizeof(T)-ik); // 最后 ik 位为 1 其他位为 0 + int mask = ik; // k => 2(11) + // printf("before: idx=%d, expert-id:%d, ik=%d\n", idx, expert_id[idx], ik); + int offset = log2(k) + 1; + expert_id_out[idx] = (expert_id[idx] << offset) | mask; + // printf("after: idx=%d, expert-id:%d, ik=%d\n", idx, expert_id_out[idx], + // ik); +} + +template +void modify_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts, + const cudaStream_t& stream) { + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + modify_expert_id<<>>( + expert_id, expert_id_out, k, num_rows, num_experts); +} + +template +__global__ void unmodify_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) return; + int ik = idx % k; + int irow = idx / k; + int offset = log2(k) + 1; + expert_id_out[idx] = (expert_id[idx] >> offset); +} + +template +void unmodify_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int64_t num_experts, + const cudaStream_t& stream) { + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + unmodify_expert_id<<>>( + expert_id, expert_id_out, k, num_rows, num_experts); +} + +template +__device__ inline int find_total_elts_leq_target(const T* sorted_indices, + const int arr_length, + const int target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +template +__global__ void compute_total_rows_before_expert_kernel( + const T* sorted_experts, + const int sorted_experts_len, + const int64_t num_experts, + int64_t* total_rows_before_expert) { + // First, compute the global tid. 
We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) return; + + // This should construct the last index where each expert occurs. + total_rows_before_expert[expert] = + find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); + // total_rows_before_expert[0] = 0; + // total_rows_before_expert[1] = 1; + // if (sorted_experts_len > 3) { + // for (int i=0; i<35;i++){ + // total_rows_before_expert[i] = i; + // } + // } +} + +template +void compute_total_rows_before_expert(const T* sorted_indices, + const int total_indices, + const int64_t num_experts, + int64_t* total_rows_before_expert, + const cudaStream_t& stream) { + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sorted_indices, total_indices, num_experts, total_rows_before_expert); +} + +template +__global__ void initialize_moe_routing_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + bool use_pad) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 
0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + if (row_in_expert >= capacity) { + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + 0; // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; // unset combine-weight + } + return; + } + int64_t num_padded = 0; + if (threadIdx.x == 0) { + // printf("going through: capacity=%lld, num_active=%lld, row=[%d->%d], + // row-in-expert %lld\n", + // capacity, + // num_active, + // expanded_dest_row, expanded_source_row, + // row_in_expert + // ); + if (use_pad) num_padded = iexpert * capacity - offset; + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row + num_padded; + } + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr; + if (use_pad) { + dest_row_ptr = + permuted_output + iexpert * capacity * cols + row_in_expert * cols; + } else { + dest_row_ptr = permuted_output + expanded_dest_row * cols; + } + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &dest_row_ptr[tid]); + } +} + +template +void initialize_moe_routing_kernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + bool use_pad, + cudaStream_t stream) { + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + constexpr int max_pack_size = 16 / sizeof(T); + if (cols % max_pack_size == 0) { + initialize_moe_routing_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + use_pad); + } else { + initialize_moe_routing_kernel<<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + use_pad); + } +} + +/** + * 原逻辑的output: + * R0E0 + * R0E1 + * R1E0 + * R1E1 + * + * 我们想对all2all和专家gemm做overlap, 所以需要将all2all拆成流水线, + * 为了便于后续计算, 此kernel的output: R0E0 R1E0 R0E1 R1E1 + */ +template +__global__ void initialize_moe_routing_permute_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + const int64_t world_size, + const int64_t num_local_experts) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. 
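+  // Illustrative sketch of the permuted layout produced below (comment only;
+  // it assumes nothing beyond the index math in this kernel): a row routed to
+  // global expert e, with r = e / num_local_experts and
+  // le = e % num_local_experts, is written starting at
+  //   permuted_output + le * world_size * capacity * cols
+  //                   + r  * capacity * cols
+  //                   + row_in_expert * cols.
+  // With world_size = 2 and num_local_experts = 2 the row blocks are ordered
+  // [le0/r0 | le0/r1 | le1/r0 | le1/r1] -- rank varies fastest inside each
+  // local-expert block -- which is the "R0E0 R1E0 R0E1 R1E1" ordering noted
+  // above and lets each local expert's all-to-all step overlap with the
+  // expert GEMMs.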
+#pragma unroll + for (int i = 0; i < LoopSize; i++) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int expanded_dest_row = blockIdx.x + i * gridDim.x; + const int expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + if (row_in_expert >= capacity) { + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + 0; // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; // unset combine-weight + } + continue; + } + int64_t num_padded = 0; + if (threadIdx.x == 0) { + num_padded = iexpert * capacity - offset; + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row + num_padded; + } + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr; + + const int64_t irank = iexpert / num_local_experts; + const int64_t local_iexpert = iexpert % num_local_experts; + dest_row_ptr = permuted_output + + local_iexpert * world_size * capacity * cols + + irank * capacity * cols + row_in_expert * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +template +void initialize_moe_routing_permute_kernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int cols, + const int k, + const int64_t capacity, + const int64_t world_size, + const int64_t num_local_experts, + cudaStream_t stream) { + const int loop_size = 2; + const int blocks = (num_rows * k) / loop_size; + assert((num_rows * k) % loop_size == 0); + const int threads = std::min(cols, 1024); + constexpr int max_pack_size = 16 / sizeof(T); + if (cols % max_pack_size == 0) { + initialize_moe_routing_permute_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + world_size, + num_local_experts); + } else { + initialize_moe_routing_permute_kernel + <<>>( + unpermuted_input, + permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + cols, + k, + capacity, + world_size, + num_local_experts); + } +} + +// moe_ops_partial_nosoftmaxtopk utils + +template +void compute_global_expert_offset( + const T* expert_id, // [len] + T* sort_buffer, // [len] + int64_t* expert_offset, // [num_experts] + const int64_t len, + const int64_t num_experts, + const int64_t capacity, + const cudaStream_t& stream, + const phi::memory_utils::ThrustAllocator& allocator) { + auto ptr = thrust::device_pointer_cast(expert_id); + auto outptr = thrust::device_pointer_cast(sort_buffer); + auto offsetptr = thrust::device_pointer_cast(expert_offset); + const auto& exec_policy = 
thrust::cuda::par(allocator).on(stream); + thrust::copy(exec_policy, ptr, ptr + len, outptr); + thrust::sort(exec_policy, outptr, outptr + len); + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sort_buffer, len, num_experts, expert_offset); + thrust::adjacent_difference( + exec_policy, offsetptr, offsetptr + num_experts, offsetptr); + // thrust::transform(offsetptr, + // offsetptr + num_experts, + // thrust::constant_iterator(capacity), + // offsetptr, + // thrust::minimum() + // ); +} + +template +__global__ void modify_and_mask_expert_id(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int num_experts, + const int expert_start_index, + const int expert_end_index) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= k * num_rows) return; + int ik = idx % k; + int irow = idx / k; + // const T mask = (~0) >> (8*sizeof(T)-ik); // 最后 ik 位为 1 其他位为 0 + int mask = ik; // k => 2(11) + // printf("before: idx=%d, expert-id:%d, ik=%d, s=%d, e=%d\n", idx, + // expert_id[idx], ik, expert_start_index, expert_end_index); + int offset = log2(k) + 1; + if (expert_id[idx] < expert_start_index || + expert_id[idx] >= expert_end_index) { + expert_id_out[idx] = (num_experts << offset); // -1 means + } else { + expert_id_out[idx] = (expert_id[idx] << offset) | mask; + } + // printf("after: idx=%d, expert-id:%d, ik=%d\n", idx, expert_id_out[idx], + // ik); +} + +template +void modify_and_mask_expert_id_launcher(const T* expert_id, + T* expert_id_out, + const int k, + const int num_rows, + const int num_experts, + const int expert_start_index, + const int expert_end_index, + const cudaStream_t& stream) { + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + + modify_and_mask_expert_id + <<>>(expert_id, + expert_id_out, + k, + num_rows, + num_experts, + expert_start_index, + expert_end_index); +} + +template +void compute_local_expert_offset( + const T* sorted_expert_id, // [len] + int64_t* expert_offset, // [num_experts] + int64_t* expert_num, + const int64_t len, + const int64_t num_experts, + const int64_t capacity, + const cudaStream_t& stream, + const phi::memory_utils::ThrustAllocator& allocator) { + auto offset_ptr = thrust::device_pointer_cast(expert_offset); + auto expert_num_ptr = thrust::device_pointer_cast(expert_num); + const auto& exec_policy = thrust::cuda::par(allocator).on(stream); + thrust::fill( + exec_policy, offset_ptr, offset_ptr + num_experts, static_cast(0)); + + const int threads = std::min(static_cast(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>( + sorted_expert_id, len, num_experts, expert_offset); + // 不考虑 capacity 影响 + thrust::adjacent_difference( + exec_policy, offset_ptr, offset_ptr + num_experts, expert_num_ptr); +} + +template +__global__ void cal_expert_size_and_filter(T* expert_id, + const int64_t* expert_offset, + int64_t len, + int64_t num_experts, + int64_t capacity, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse) { + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= len) return; + int64_t off = reverse ? 
expert_offset[expert_end_index - 1] : 0; + if (reverse) { + for (int64_t i = expert_end_index - 1; i >= expert_start_index; --i) { + if (idx >= expert_offset[i]) break; + off = expert_offset[i]; + } + } else { + for (int64_t i = expert_start_index; i != expert_end_index; ++i) { + if (idx < expert_offset[i]) break; + off = expert_offset[i]; + } + } + if (reverse) { + if (((off - 1) - idx) >= capacity) { + expert_id[idx] = num_experts; + } + } else { + if ((idx - off) >= capacity) { + expert_id[idx] = num_experts; + } + } +} + +template +void cal_expert_size_and_filter_launcher(T* expert_id, + const int64_t* expert_offset, + int64_t len, + int64_t num_experts, + int64_t capacity, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse, + const cudaStream_t& stream) { + if (len <= 0) return; + const int64_t threads = std::min(static_cast(1024), len); + const int64_t blocks = (len + threads - 1) / threads; + cal_expert_size_and_filter + <<>>(expert_id, + expert_offset, + len, + num_experts, + capacity, + expert_start_index, + expert_end_index, + reverse); +} + +template +__global__ void build_seqsort_kv_pairs_kernel( + T* seqsort_key, + T* seqsort_value, + const int* expanded_dest_row_to_expanded_source_row, + // int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int k, + const int64_t num_active, + const int64_t capacity, + int64_t expert_start_index, + bool use_pad) { + const int expanded_dest_row = blockIdx.x * blockDim.x + threadIdx.x; + if (expanded_dest_row >= num_rows * k) { + return; + } + const int expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + const int64_t iexpert = permuted_experts[expanded_dest_row]; + const int64_t offset = iexpert == 0 ? 0 : (expert_offset[iexpert - 1]); + const int64_t row_in_expert = expanded_dest_row - offset; + // printf("DEBUG %d=>%d, num_active=%lld, offset=%lld, cap=%lld \n", + // expanded_dest_row, expanded_source_row, num_active, row_in_expert, + // capacity); 从此以后不会发生截断,后续的 seqsort 也不会截断。 + // printf("expanded_dest_row:%d row_in_expert:%lld capacity:%lld + // num_active:%lld\n", expanded_dest_row, row_in_expert, capacity, + // num_active); + if ((use_pad && row_in_expert >= capacity) || + expanded_dest_row >= num_active) { + // expanded_source_row_to_expanded_dest_row[expanded_source_row] = 0; // + // unset scatter-idx + auto ik = expanded_source_row / num_rows; + auto isent = expanded_source_row % num_rows; // transpose + combine_weights[isent * k + ik] = 0.f; // unset combine-weight + return; + } + + // auto num_padded = use_pad ? 
(iexpert - expert_start_index) * capacity - + // offset : 0; expanded_source_row_to_expanded_dest_row[expanded_source_row] = + // expanded_dest_row + num_padded; + + // Duplicate and permute rows + T source_row = expanded_source_row % num_rows; + + if (use_pad) { + // printf("inner print: k=%d num_row=%d before minus %d\n", k, num_rows, + // source_row); + seqsort_key[(iexpert - expert_start_index) * capacity + row_in_expert] = + source_row; // 为保证 padding 位置(0)在最后, 所以对 pos-id + // 取减去其最大值 + seqsort_value[(iexpert - expert_start_index) * capacity + row_in_expert] = + expanded_source_row; + } else { + seqsort_key[expanded_dest_row] = source_row; + seqsort_value[expanded_dest_row] = expanded_source_row; + } +} + +template +void build_seqsort_kv_pairs_kernel_launcher( + T* seqsort_key, // 实现初始化为 num-rows,保证 sort 到最后 + T* seqsort_value, + const int* expanded_dest_row_to_expanded_source_row, + // int* expanded_source_row_to_expanded_dest_row, + const int* permuted_experts, + const int64_t* expert_offset, + float* combine_weights, // output + const int num_rows, + const int k, + const int64_t num_active, // -1 expert pos + const int64_t capacity, + const int64_t expert_start_index, + bool use_pad, + cudaStream_t stream) { + int max = 1024; + const int threads = std::min(max, num_rows * k); + const int blocks = (num_rows * k + threads - 1) / threads; + build_seqsort_kv_pairs_kernel<<>>( + seqsort_key, + seqsort_value, + expanded_dest_row_to_expanded_source_row, + // expanded_source_row_to_expanded_dest_row, + permuted_experts, + expert_offset, + combine_weights, + num_rows, + k, + num_active, + capacity, + expert_start_index, + use_pad); +} + +template +__global__ void copy_unpermuted_to_permuted_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* padded_out_to_unpermuted_input, + const int* padded_out_to_expanded_input, + int* expanded_input_to_padded_out, + const int64_t padded_len, + const int64_t num_rows, + const int64_t k, + const int64_t cols) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + const int padded_dest_row = blockIdx.x; + if (padded_out_to_unpermuted_input[padded_dest_row] == num_rows) { + // padded_out_to_unpermuted_input[padded_dest_row] = -1; + return; // padded place + } + const int source_row = padded_out_to_unpermuted_input[padded_dest_row]; + const int source_row_expanded = padded_out_to_expanded_input[padded_dest_row]; + if (threadIdx.x == 0) { + expanded_input_to_padded_out[source_row_expanded] = padded_dest_row; + } + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* padded_dest_row_ptr = permuted_output + padded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + phi::Load(&source_row_ptr[tid], &src_vec); + phi::Store(src_vec, &padded_dest_row_ptr[tid]); + } + PADDLE_ENFORCE( + (padded_dest_row < padded_len) && (source_row_expanded < num_rows * k), + "The index is out of bounds, " + "origin_input[%d] -> distributed_input:[%d], should < [%ld],[%ld] \n", + source_row_expanded, + padded_dest_row, + num_rows * k, + padded_len); + + // for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + // padded_dest_row_ptr[tid] = source_row_ptr[tid]; // copy + // } +} + +template +void copy_unpermuted_to_permuted_kernelLauncher( + const T* unpermuted_input, + T* permuted_output, + const int* padded_out_to_unpermuted_input, + const int* padded_out_to_expanded_input, + int* expanded_input_to_padded_out, + const int64_t padded_len, + const int64_t num_rows, // 
unpermuted_input_len + const int64_t k, + const int64_t num_cols, + cudaStream_t stream) { + auto blocks = padded_len; + auto threads = std::min(num_cols, static_cast(1024)); + constexpr int64_t max_pack_size = 16 / sizeof(T); + if (num_cols % max_pack_size == 0) { + copy_unpermuted_to_permuted_kernel + <<>>(unpermuted_input, + permuted_output, + padded_out_to_unpermuted_input, + padded_out_to_expanded_input, + expanded_input_to_padded_out, + padded_len, + num_rows, + k, + num_cols); + } else { + copy_unpermuted_to_permuted_kernel + <<>>(unpermuted_input, + permuted_output, + padded_out_to_unpermuted_input, + padded_out_to_expanded_input, + expanded_input_to_padded_out, + padded_len, + num_rows, + k, + num_cols); + } +} +} // namespace phi +#endif diff --git a/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h new file mode 100644 index 00000000000000..6c3b4d6d6f241c --- /dev/null +++ b/paddle/phi/kernels/moe_gate_dispatch_grad_kernel.h @@ -0,0 +1,46 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MoeGateDispatchGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& expert_id, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_grad, + const int64_t k, + const int64_t capacity, + const bool use_pad, + DenseTensor* x_grad, + DenseTensor* gate_logits_grad); + +template +void moe_dispatch_bwd(const Context& dev_ctx, + const DenseTensor& combine_weights, // [s, k] + const DenseTensor& scatter_index, // [k, s] + const DenseTensor& expert_id, // [s, k] + const DenseTensor& y_grad, // [num_experts * capacity, h] + const DenseTensor& combine_weights_grad, // [s, k] + const DenseTensor& x_grad, + const DenseTensor& gate_logits_grad, + int64_t capacity, + bool use_all2all_permute = false, + int64_t world_size = -1, + int64_t num_local_experts = -1); +} // namespace phi diff --git a/paddle/phi/kernels/moe_gate_dispatch_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_kernel.h new file mode 100644 index 00000000000000..b17d0387c7a750 --- /dev/null +++ b/paddle/phi/kernels/moe_gate_dispatch_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
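+// Shape sketch for the dispatch kernel declared below (illustrative only,
+// inferred from the gradient kernel's shape comments above; s = number of
+// rows/tokens, h = hidden size):
+//   x:               [s, h]
+//   gate_logits:     [s, num_experts]
+//   y:               [num_experts * capacity, h]  (when use_pad is true)
+//   combine_weights: [s, k]
+//   scatter_index:   [k, s]
+//   expert_offset:   [num_experts]
+//   expert_id:       [s, k]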
+#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void MoeGradDispatchKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& gate_logits, + const paddle::optional& corr_bias, + const int64_t k, + const int64_t capacity, + const bool use_pad, + DenseTensor* y, + DenseTensor* combine_weights, + DenseTensor* scatter_index, + DenseTensor* expert_offset, + DenseTensor* expert_id); + +} // namespace phi diff --git a/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h new file mode 100644 index 00000000000000..f8cd9bee0d6083 --- /dev/null +++ b/paddle/phi/kernels/moe_gate_dispatch_permute_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoeGateDispatchGradKernel(const Context& dev_ctx, + const DenseTensor& combine_weights, + const DenseTensor& scatter_index, + const DenseTensor& expert_id, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_grad, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor* x_grad, + DenseTensor* gate_logits_grad); +} // namespace phi diff --git a/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h b/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h new file mode 100644 index 00000000000000..1d6c1f5fed0b33 --- /dev/null +++ b/paddle/phi/kernels/moe_gate_dispatch_permute_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoEDispatchPermuteKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& gate_logits, + const paddle::optional& corr_bias, + int64_t k, + int64_t capacity, + int64_t world_size, + DenseTensor* y, + DenseTensor* combine_weights, + DenseTensor* scatter_index, + DenseTensor* expert_offset, + DenseTensor* expert_id); +} // namespace phi diff --git a/paddle/phi/kernels/moe_kernel_impl.h b/paddle/phi/kernels/moe_kernel_impl.h new file mode 100644 index 00000000000000..68b84efc9fdfcb --- /dev/null +++ b/paddle/phi/kernels/moe_kernel_impl.h @@ -0,0 +1,649 @@ +// NOLINT +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#ifdef PADDLE_WITH_CUDA +#include +#include +#include +#include +#include +#include "cub/cub.cuh" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" +namespace phi { + +static const float HALF_FLT_MAX = 65504.F; +static const float HALF_FLT_MIN = -65504.F; +static inline size_t AlignTo16(const size_t& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter { + public: + inline CubKeyValueSorter(); + + inline CubKeyValueSorter(cudaStream_t stream = 0); // NOLINT + + inline explicit CubKeyValueSorter(const int num_experts); + + inline void update_num_experts(const int num_experts); + + inline size_t getWorkspaceSize(const size_t num_key_value_pairs, + bool descending = false); + + template + inline void run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream); + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; + cudaStream_t stream_; +}; + +// ===== CUB Sorting things ===== +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(cudaStream_t stream) + : num_experts_(0), num_bits_(sizeof(int) * 8), stream_(stream) {} + +CubKeyValueSorter::CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts), + num_bits_(static_cast(log2(num_experts)) + 1) {} + +void CubKeyValueSorter::update_num_experts(const int num_experts) { + num_experts_ = num_experts; + num_bits_ = static_cast(log2(num_experts)) + + 3; // 额外增加 3 位用于标记 topk的位置 +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, + bool descending) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + 32, + stream_); + } else { + cub::DeviceRadixSort::SortPairs(NULL, + required_storage, + null_int, + null_int, + null_int, + null_int, + num_key_value_pairs, + 0, + num_bits_, + stream_); + } + return required_storage; +} + +template +inline void CubKeyValueSorter::run(void* workspace, + const size_t workspace_size, + const KeyT* keys_in, + KeyT* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + std::stringstream err_ss; + err_ss << "[Error][CubKeyValueSorter::run]\n"; + err_ss + << "Error. 
The allocated workspace is too small to run this problem.\n"; + err_ss << "Expected workspace size of at least " << expected_ws_size + << " but got problem size " << workspace_size << "\n"; + throw std::runtime_error(err_ss.str()); + } + if (descending) { + cub::DeviceRadixSort::SortPairsDescending(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + 32, + stream); + } else { + cub::DeviceRadixSort::SortPairs(workspace, + actual_ws_size, + keys_in, + keys_out, + values_in, + values_out, + num_key_value_pairs, + 0, + num_bits_, + stream); + } +} + +template <> +inline void CubKeyValueSorter::run(void* workspace, + const size_t workspace_size, + const __nv_bfloat16* keys_in, + __nv_bfloat16* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + bool descending, + cudaStream_t stream) {} + +// CubKeyValueSorter sorter_(stream); + +// -------- initialize_expert_choice_route_kernel -------- // +template +__global__ void initialize_expert_choice_route_kernel( + int* expert_for_source_row, + int* source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* total_rows_before_expert, + T* attr_mask, + const int cols, + const int k, + const int batch_size) { + int start = cols * blockIdx.x; + + for (int i = threadIdx.x; i < cols; i += blockDim.x) { + expert_for_source_row[start + i] = blockIdx.x; + source_row[start + i] = start + i; + expanded_source_row_to_expanded_dest_row[start + i] = -1; + attr_mask[start + i] = (T)1.0f; + } + if (threadIdx.x == 0) { + total_rows_before_expert[blockIdx.x] = batch_size * k * (blockIdx.x + 1); + } +} + +// -------- softmax_kernel -------- // +template +__global__ void softmax_kernel_v4( + T* qk_buf_, + const T* qk_buf_src, // shape [batch_size, seq_len] + const T* attr_mask, // shape [batch_size, seq_len] + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + float data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + qk_offset = + ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y) * seq_len + blockDim.x * i + threadIdx.x; + + float qk = static_cast(qk_buf_src[qk_offset]); + float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + data[i] = qk + mask_val; + local_max = fmax(local_max, data[i]); + } + + float max_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); + + float local_sum = 0; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + data[i] = __expf(data[i] - s_max); + local_sum += data[i]; + } + float sum_val = + blockDim.x <= 32 + ? 
phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + qk_offset = + ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; + qk_buf_[qk_offset] = (T)(data[i] * s_mean); + } +#endif +} + +template +__global__ void softmax_kernel_v4_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + using T2 = half2; + T2* qk_buf_half2 = reinterpret_cast(qk_buf_); + const T2* attr_mask_half2 = (const T2*)attr_mask; + + T2 data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + + threadIdx.x; + int mask_offset = blockIdx.y * (seq_len / 2) + blockDim.x * i + threadIdx.x; + + T2 qk = qk_buf_half2[qk_offset]; + T2 mask_val = __ldg(&attr_mask_half2[mask_offset]); + mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), + __float2half2_rn(-10000.0f)); + + data[i] = __hadd2(qk, mask_val); + + local_max = fmax( + local_max, + fmax(static_cast(data[i].x), static_cast(data[i].y))); + } + + float max_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); + + float local_sum = 0; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max))); + local_sum += static_cast(data[i].x + data[i].y); + } + + float sum_val = + blockDim.x <= 32 + ? 
phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); + + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + + threadIdx.x; + qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean)); + } +#endif +} + +template +__global__ void softmax_kernel_v5_half2(T* qk_buf_, + const T* attr_mask, + const int batch_size, + const int seq_len) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + using T2 = half2; + T2* qk_buf_half2 = reinterpret_cast(qk_buf_); + const T2* attr_mask_half2 = (const T2*)attr_mask; + + T2 data[NUM][ITEMS_PER_THREAD]; + + int qk_offset[NUM]; + + __shared__ float s_sum[NUM], s_max[NUM]; + float local_max[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_max[j] = -1e20f; + } + + const int MAX_NUM = min((1 + gridDim.x - 1) / gridDim.x, NUM); + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + int mask_offset[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + mask_offset[j] = (blockIdx.y + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + } + + T2 mask_val[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]); + } + + T2 qk[NUM]; +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk[j] = qk_buf_half2[qk_offset[j]]; + } +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]), + __float2half2_rn(-10000.0f)); + } +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = __hadd2(qk[j], mask_val[j]); + local_max[j] = fmax(local_max[j], + fmax(static_cast(data[j][i].x), + static_cast(data[j][i].y))); + } + } + if (blockDim.x <= 32) { + phi::funcs::WarpReduceMaxV2(local_max); + } else { + phi::funcs::BlockReduceMaxV2(local_max); + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_max[j] = local_max[j]; + } + } + __syncthreads(); + float local_sum[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_sum[j] = {0.f}; + } + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j]))); + } + +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + local_sum[j] += static_cast(data[j][i].x + data[j][i].y); + } + } + + if (blockDim.x <= 32) { + phi::funcs::WarpReduceSumV2(local_sum); + + } else { + phi::funcs::BlockReduceSumV2(local_sum); + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); + } + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + } + +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + qk_buf_half2[qk_offset[j]] = + __hmul2(data[j][i], __float2half2_rn(s_sum[j])); + } + } +#endif +} + +// -------- transpose_kernel 
-------- // +template +__global__ void transposeAxis01( + T* out, T* in, const int dim0, const int dim1, const int dim2) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < dim0 * dim1 * dim2) { + const int input_dim2_index = index % dim2; + index = (index - input_dim2_index) / dim2; + const int input_dim1_index = index % dim1; + index = (index - input_dim1_index) / dim1; + const int input_dim0_index = index % dim0; + + out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + + input_dim2_index] = in[input_dim0_index * dim1 * dim2 + + input_dim1_index * dim2 + input_dim2_index]; + } +} + +// -------- padding_kernel -------- // +template +__global__ void paddingKernel(T* output1, + int* output2, + const T* input1, + const int* input2, + const int* input_lengths, + const int num_tokens, + const int batch_size, + const int max_seq_len, + const int num_experts) { + const bool IS_FP32 = std::is_same::value; + const T MIN_T_VAL = (!IS_FP32) ? (T)HALF_FLT_MIN : (T)FLT_MIN; + int offset1 = blockIdx.x * num_tokens; + int offset2 = blockIdx.x * batch_size * max_seq_len; + for (int i = 0; i < batch_size; i++) { + const T* in1_ptr = input1 + offset1; + const int* in2_ptr = input2 + offset1; + int input_length = input_lengths[i]; + offset1 += input_length; + + T* out1_ptr = output1 + offset2; + int* out2_ptr = output2 + offset2; + offset2 += max_seq_len; + + for (int j = threadIdx.x; j < max_seq_len; j += max_seq_len) { + if (j < input_length) { + out1_ptr[j] = in1_ptr[j]; + out2_ptr[j] = in2_ptr[j]; + } else { + out1_ptr[j] = MIN_T_VAL; + out2_ptr[j] = 0; + } + } + } +} + +// -------- general_topk_pair_sort_kernel -------- // +template +__global__ void general_topk_pair_sort(T* out_keys, + int* out_values, + T* in_keys, + int* in_values) { + typedef cub::BlockRadixSort + BlockRadixSort; + typedef cub:: + BlockLoad + BlockLoadKey; + typedef cub:: + BlockLoad + BlockLoadValue; + typedef cub:: + BlockStore + BlockStoreKey; + typedef cub::BlockStore + BlockStoreValue; + + __shared__ union { + typename BlockRadixSort::TempStorage sort; + typename BlockLoadKey::TempStorage loadkey; + typename BlockLoadValue::TempStorage loadvalue; + typename BlockStoreKey::TempStorage storekey; + typename BlockStoreValue::TempStorage storevalue; + } temp_storage; + + int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD; + + T thread_keys[ITEMS_PER_THREAD]; + int thread_values[ITEMS_PER_THREAD]; + BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys); + BlockLoadValue(temp_storage.loadvalue) + .Load(in_values + block_offset, thread_values); + __syncthreads(); + + BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values); + __syncthreads(); + + BlockStoreKey(temp_storage.storekey) + .Store(out_keys + block_offset, thread_keys); + BlockStoreValue(temp_storage.storevalue) + .Store(out_values + block_offset, thread_values); +} + +// -------- finalize_moe_routing_kernel -------- // +template +__global__ void finalize_moe_routing_kernel( + const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* skip, + const T* bias, + const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int cols, + const int k, + bool ec_route) { + const int original_row = blockIdx.x; + const int num_rows = gridDim.x; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + const T* skip_row_ptr = skip + original_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { 
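+    // What the loop body below computes per output column `tid` (sketch):
+    //   out[row][tid] = skip[row][tid]
+    //       + sum over k_idx of scales[(row, k_idx)] *
+    //         (expanded_permuted_rows[dest(row, k_idx)][tid] + bias[expert][tid])
+    // where dest() is expanded_source_row_to_expanded_dest_row and entries
+    // mapped to -1 under ec_route are skipped.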
+ T thread_output = skip_row_ptr[tid]; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + if (ec_route && expanded_permuted_row == -1) continue; + const int64_t k_offset = + ec_route ? expanded_original_row : original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T* expanded_permuted_rows_row_ptr = + expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = ec_route ? k_idx : expert_for_source_row[k_offset]; + const T* bias_ptr = bias + expert_idx * cols; + + thread_output = + thread_output + + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + } + reduced_row_ptr[tid] = thread_output; + } +} + +// -------- initialize_moe_routing_kernel -------- // +template +__global__ void initialize_moe_routing_kernel( + const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int num_rows, + const int active_rows, + const int cols, + const int k, + const int max_seq_len, + bool ec_route) { + // using LoadT = phi::AlignedVector; + // LoadT src_vec; + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = + ec_route ? expanded_dest_row_to_expanded_source_row[expanded_dest_row / + k * max_seq_len + + expanded_dest_row % k] + : expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + expanded_dest_row; + } + + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + dest_row_ptr[tid] = source_row_ptr[tid]; + // phi::Load(&source_row_ptr[tid], &src_vec); + // phi::Store(src_vec, &dest_row_ptr[tid]); + } + } +} + +} // namespace phi + +#endif diff --git a/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h new file mode 100644 index 00000000000000..c5cdcbfe6f4443 --- /dev/null +++ b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_grad_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoeGateDispatchPartialNoSoftMaxTopkGradKernel( + const Context& dev_ctx, + const DenseTensor& combine_weights_out, + const DenseTensor& scatter_index, + const DenseTensor& scatter_index_rev, + const DenseTensor& expert_offset, + const DenseTensor& expert_offset_local, + const DenseTensor& y_grad, + const DenseTensor& combine_weights_out_grad, + int64_t k, + int64_t capacity, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + DenseTensor* x_grad, + DenseTensor* combine_weights_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h new file mode 100644 index 00000000000000..0ecf6afda63e91 --- /dev/null +++ b/paddle/phi/kernels/moe_ops_partial_nosoftmaxtopk_kernel.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void MoeGateDispatchPartialNoSoftMaxTopkKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& combine_weights, + const DenseTensor& expert_id, + int64_t k, + int64_t capacity, + int64_t num_experts, + bool use_pad, + int64_t expert_start_index, + int64_t expert_end_index, + bool reverse_token_drop, + DenseTensor* y, + DenseTensor* combine_weights_out, + DenseTensor* scatter_index, + DenseTensor* scatter_index_rev, + DenseTensor* expert_offset, + DenseTensor* expert_nums_local); + +} // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index d96f1231fdc472..3a237a78c29b61 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -361,6 +361,15 @@ param: [softmax, label, loss_grad, ignore_index, rank, nranks] inplace : (softmax -> logits_grad) +- backward_op : cal_aux_loss_grad + forward : cal_aux_loss (Tensor gate_prob, Tensor dispatch_mask, Tensor tokens_mask, Tensor dispatch_tokens_mask, int64_t num_experts, bool use_group, int64_t moe_k, float clip_min) -> Tensor(l_aux_loss), Tensor(seqlen_float), Tensor(ce) + args : ( Tensor gate_prob, Tensor seqlen_float, Tensor ce, Tensor l_aux_loss_grad, int64_t num_experts, bool use_group, int64_t moe_k) + output : Tensor(gate_prob_grad) + infer_meta : + func : CalAuxLossGradInferMeta + kernel : + func : cal_aux_loss_grad + - backward_op : cast_grad forward : cast (Tensor x, DataType dtype) -> Tensor(out) args : (Tensor x, Tensor out_grad) @@ -2262,6 +2271,45 @@ kernel : func : mode_grad +- backward_op : moe_combine_grad + forward : moe_combine (Tensor x, Tensor combine_weights, Tensor scatter_index) -> Tensor(y) + args : (Tensor x, Tensor combine_weights, Tensor scatter_index, Tensor y_grad) + output : Tensor(x_grad), Tensor(combine_weights_grad) + infer_meta : + func : MoeCombineGradInferMeta + kernel : + func : moe_combine_grad + +- 
backward_op : moe_gate_dispatch_grad + forward : moe_gate_dispatch (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, bool use_pad) + output : Tensor(x_grad), Tensor(gate_logits_grad) + infer_meta : + func : MoeGateDispatchGradInferMeta + kernel : + func : moe_gate_dispatch_grad + data_type : y_grad + +- backward_op : moe_gate_dispatch_partial_nosoftmaxtopk_grad + forward : moe_gate_dispatch_partial_nosoftmaxtopk (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop) -> Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local) + args : (Tensor combine_weights_out, Tensor scatter_index, Tensor scatter_index_rev, Tensor expert_offset, Tensor expert_nums_local, Tensor y_grad, Tensor combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t expert_start_index, int64_t expert_end_index) + output : Tensor(x_grad), Tensor(combine_weights_grad) + infer_meta : + func : MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta + kernel : + func : moe_gate_dispatch_partial_nosoftmaxtopk_grad + data_type : y_grad + +- backward_op : moe_gate_dispatch_permute_grad + forward : moe_gate_dispatch_permute (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, int64_t world_size) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, int64_t world_size) + output : Tensor(x_grad), Tensor(gate_logits_grad) + infer_meta : + func : MoeGateDispatchPermuteGradInferMeta + kernel : + func : moe_gate_dispatch_permute_grad + data_type : y_grad + - backward_op : mp_allreduce_sum_grad forward : mp_allreduce_sum(Tensor x, int ring_id = 0) -> Tensor(out) args : (Tensor out_grad, int ring_id = 0) @@ -3823,6 +3871,15 @@ func: check_model_nan_inf data_type: out_grad +- backward_op: fused_rms_norm_ext_grad + forward: fused_rms_norm_ext (Tensor x, Tensor scale, float epsilon) -> Tensor(y), Tensor(invvar) + args: (Tensor x, Tensor scale,Tensor invvar, Tensor y_grad, float epsilon) + output: Tensor(x_grad), Tensor(scale_grad) + infer_meta: + func: FusedRMSNormGradInferMeta + kernel: + func: fused_rms_norm_ext_grad + - backward_op: im2sequence_grad forward: im2sequence (Tensor x, Tensor y, int[] kernels, int[] strides = {1, 1}, int[] paddings = {0, 0, 0, 0}, int[] out_stride = {1, 1}) -> Tensor (out) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml old mode 100755 new mode 100644 index 87d83595de2445..ec2ec08e1d2390 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -820,6 +820,15 @@ backward: broadcast_tensors_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : build_src_rank_and_local_expert_id + args : (Tensor expert_num_global_tensor, int64_t[] expert_num_global, int64_t num_local_experts) + output : Tensor(vector), Tensor(local_expert_id) + infer_meta : + func : BuildSrcRankAndLocalExpertIdInferMeta + kernel : + func : 
build_src_rank_and_local_expert_id + data_type : expert_num_global_tensor + - op : c_allreduce_sum args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel) output : Tensor(out) @@ -884,6 +893,17 @@ func : c_split param: [x, rank, nranks, use_model_parallel] +- op : cal_aux_loss + args : (Tensor gate_prob, Tensor dispatch_mask, Tensor tokens_mask, Tensor dispatch_tokens_mask, int64_t num_experts, bool use_group, int64_t moe_k, float clip_min) + output : Tensor(l_aux_loss), Tensor(seqlen_float), Tensor(ce) + infer_meta : + func : CalAuxLossInferMeta + kernel : + func : cal_aux_loss + data_type : gate_prob + optional: tokens_mask, dispatch_tokens_mask + backward : cal_aux_loss_grad + - op : calc_reduced_attn_scores args : (Tensor q, Tensor k, Tensor softmax_lse) output : Tensor(reduced_scores) @@ -1793,6 +1813,15 @@ backward : expand_as_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : expand_modality_expert_id + args : (Tensor expert_id, int64_t num_expert_per_modality, int64_t group_size, int64_t modality_offset, bool is_group_expert) + output : Tensor(expert_id_out) + infer_meta : + func : ExpandModalityExpertIdInferMeta + kernel : + func : expand_modality_expert_id + data_type : expert_id + - op : expm1 args : (Tensor x) output : Tensor(out) @@ -3598,6 +3627,49 @@ backward : mode_grad interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface +- op : moe_combine + args : (Tensor x, Tensor combine_weights, Tensor scatter_index) + output : Tensor(y) + infer_meta : + func : MoeCombineInferMeta + kernel : + func : moe_combine + data_type : x + backward : moe_combine_grad + +- op : moe_gate_dispatch + args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) + output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + infer_meta : + func : MoeGateDispatchInferMeta + kernel : + func : moe_gate_dispatch + data_type : x + optional : corr_bias + backward : moe_gate_dispatch_grad + +- op : moe_gate_dispatch_partial_nosoftmaxtopk + args : (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop) + output : Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local) + infer_meta : + func : MoeGateDispatchPartialNoSoftmaxTopKInferMeta + kernel : + func : moe_gate_dispatch_partial_nosoftmaxtopk + data_type : x + # inplace : (combine_weights -> combine_weights_out) + backward : moe_gate_dispatch_partial_nosoftmaxtopk_grad + +- op : moe_gate_dispatch_permute + args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, int64_t world_size) + output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id) + infer_meta : + func : MoeGateDispatchPermuteInferMeta + kernel : + func : moe_gate_dispatch_permute + data_type : x + optional : corr_bias + backward : moe_gate_dispatch_permute_grad + - op : momentum_ args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0f, bool multi_precision = false, float rescale_grad = 1.0f) output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) @@ -5667,6 
+5739,25 @@ interfaces : paddle::dialect::InferSymbolicShapeInterface traits : paddle::dialect::ForwardOnlyTrait +- op: fused_rms_norm_ext + args: (Tensor x, Tensor scale, float epsilon) + output: Tensor(y), Tensor(invvar) + infer_meta: + func: FusedRMSNormInferMeta + kernel: + func: fused_rms_norm_ext + data_type: x + backward: fused_rms_norm_ext_grad + +- op: int_bincount + args: (Tensor x, int64_t low, int64_t high, int64_t dtype) + output: Tensor(out) + infer_meta: + func: IntBincountInferMeta + kernel: + func: int_bincount + data_type: x + - op: number_count args: (Tensor numbers, int upper_range) output: Tensor(out) diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index aec7625145d348..05ec5f17620df3 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,15 +16,22 @@ from .blha_get_max_len import blha_get_max_len from .block_multihead_attention import ( block_multihead_attention, - block_multihead_attention_xpu, # noqa: F401 + block_multihead_attention_xpu, ) + +# from .moe_gate_dispatch_permute import moe_gate_dispatch_permute +from .build_src_rank_and_local_expert_id import ( + build_src_rank_and_local_expert_id, +) +from .cal_aux_loss import cal_aux_loss +from .expand_modality_expert_id import expand_modality_expert_id from .fused_bias_act import fused_bias_act from .fused_dot_product_attention import ( - cudnn_flash_attention, # noqa: F401 - fused_dot_product_attention, # noqa: F401 + cudnn_flash_attention, + fused_dot_product_attention, ) from .fused_dropout_add import fused_dropout_add -from .fused_gate_attention import fused_gate_attention # noqa: F401 +from .fused_gate_attention import fused_gate_attention from .fused_layer_norm import fused_layer_norm from .fused_matmul_bias import ( fused_linear, @@ -31,6 +39,7 @@ fused_matmul_bias, ) from .fused_rms_norm import fused_rms_norm +from .fused_rms_norm_ext import fused_rms_norm_ext from .fused_rotary_position_embedding import fused_rotary_position_embedding from .fused_transformer import ( fused_bias_dropout_residual_layer_norm, @@ -38,7 +47,14 @@ fused_multi_head_attention, fused_multi_transformer, ) +from .int_bincount import int_bincount from .masked_multihead_attention import masked_multihead_attention +from .moe_combine import moe_combine +from .moe_gate_dispatch import moe_gate_dispatch +from .moe_gate_dispatch_partial_nosoftmaxtopk import ( + moe_gate_dispatch_partial_nosoftmaxtopk, +) +from .moe_gate_dispatch_permute import moe_gate_dispatch_permute from .swiglu import swiglu from .variable_length_memory_efficient_attention import ( variable_length_memory_efficient_attention, @@ -62,4 +78,13 @@ "blha_get_max_len", "block_multihead_attention", "swiglu", + "moe_combine", + "expand_modality_expert_id", + "cal_aux_loss", + "build_src_rank_and_local_expert_id", + "int_bincount", + "fused_rms_norm_ext", + "moe_gate_dispatch", + "moe_gate_dispatch_permute", + "moe_gate_dispatch_partial_nosoftmaxtopk", ] diff --git a/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py b/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py new file mode 100644 index 00000000000000..25195c236b629c --- /dev/null +++ b/python/paddle/incubate/nn/functional/build_src_rank_and_local_expert_id.py @@ -0,0 +1,67 @@ 
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle +from paddle import _C_ops + +# from ....framework import LayerHelper, in_dynamic_or_pir_mode +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + +if TYPE_CHECKING: + from paddle import Tensor + + +def build_src_rank_and_local_expert_id( + expert_num_global_tensor: Tensor, + expert_num_global: list, + num_local_experts: int, + name: str | None = None, +) -> Tensor: + """ + Args: + expert_num_global_tensor: + expert_num_global: + num_local_experts: + + Returns: + """ + if in_dynamic_or_pir_mode(): + return _C_ops.build_src_rank_and_local_expert_id( + expert_num_global_tensor, expert_num_global, num_local_experts + ) + + helper = LayerHelper('expert_num_global_tensor', **locals()) + vector = helper.create_variable_for_type_inference(dtype=paddle.int32) + local_expert_id = helper.create_variable_for_type_inference( + dtype=paddle.int32 + ) + + inputs = {'expert_num_global_tensor': expert_num_global_tensor} + attrs = { + 'expert_num_global': expert_num_global, + 'num_local_experts': num_local_experts, + } + outputs = {'vector': vector, 'local_expert_id': local_expert_id} + helper.append_op( + type='build_src_rank_and_local_expert_id', + inputs=inputs, + attrs=attrs, + outputs=outputs, + ) + return vector, local_expert_id diff --git a/python/paddle/incubate/nn/functional/cal_aux_loss.py b/python/paddle/incubate/nn/functional/cal_aux_loss.py new file mode 100644 index 00000000000000..56676fa18b7d28 --- /dev/null +++ b/python/paddle/incubate/nn/functional/cal_aux_loss.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
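+# Minimal usage sketch for the `cal_aux_loss` wrapper defined below; the
+# argument values are illustrative assumptions, not defaults:
+#
+#     l_aux_loss, seqlen_float, ce = cal_aux_loss(
+#         gate_prob, dispatch_mask, tokens_mask, dispatch_tokens_mask,
+#         num_experts=64, use_group=False, moe_k=8, clip_min=1e-6)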
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from paddle import _C_ops
+
+# from ....framework import LayerHelper, in_dynamic_or_pir_mode
+from paddle.base.framework import in_dynamic_or_pir_mode
+from paddle.base.layer_helper import LayerHelper
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+def cal_aux_loss(
+    gate_prob: Tensor,
+    dispatch_mask: Tensor,
+    tokens_mask: Tensor,
+    dispatch_tokens_mask: Tensor,
+    num_experts: int,
+    use_group: bool,
+    moe_k: int,
+    clip_min: float,
+    name: str | None = None,
+) -> tuple[Tensor, Tensor, Tensor]:
+    """
+    Compute the MoE auxiliary (load-balancing) loss.
+
+    Args:
+        gate_prob (Tensor): Router probabilities of shape [seq, num_experts].
+        dispatch_mask (Tensor): Number of tokens dispatched to each expert.
+        tokens_mask (Tensor): Mask selecting the tokens of the current modality.
+        dispatch_tokens_mask (Tensor): All-gathered version of tokens_mask.
+        num_experts (int): Total number of experts.
+        use_group (bool): Whether group-wise experts are used; if True the loss
+            is divided by moe_k.
+        moe_k (int): Number of experts selected per token.
+        clip_min (float): Lower bound used to clip the token-count denominator.
+
+    Returns:
+        Tuple of three Tensors: the auxiliary loss l_aux_loss, the token-count
+        denominator seqlen_float, and the per-expert dispatch ratio ce.
+    """
+    if in_dynamic_or_pir_mode():
+        return _C_ops.cal_aux_loss(
+            gate_prob,
+            dispatch_mask,
+            tokens_mask,
+            dispatch_tokens_mask,
+            num_experts,
+            use_group,
+            moe_k,
+            clip_min,
+        )
+
+    helper = LayerHelper('cal_aux_loss', **locals())
+    l_aux_loss = helper.create_variable_for_type_inference(
+        dtype=gate_prob.dtype
+    )
+    seqlen_float = helper.create_variable_for_type_inference(
+        dtype=gate_prob.dtype
+    )
+    ce = helper.create_variable_for_type_inference(dtype=gate_prob.dtype)
+
+    inputs = {
+        'gate_prob': gate_prob,
+        'dispatch_mask': dispatch_mask,
+        'tokens_mask': tokens_mask,
+        'dispatch_tokens_mask': dispatch_tokens_mask,
+    }
+    attrs = {
+        'num_experts': num_experts,
+        'use_group': use_group,
+        'moe_k': moe_k,
+        'clip_min': clip_min,
+    }
+    outputs = {'l_aux_loss': l_aux_loss, 'seqlen_float': seqlen_float, 'ce': ce}
+    helper.append_op(
+        type='cal_aux_loss', inputs=inputs, attrs=attrs, outputs=outputs
+    )
+    return l_aux_loss, seqlen_float, ce
diff --git a/python/paddle/incubate/nn/functional/expand_modality_expert_id.py b/python/paddle/incubate/nn/functional/expand_modality_expert_id.py
new file mode 100644
index 00000000000000..e91a02ef795783
--- /dev/null
+++ b/python/paddle/incubate/nn/functional/expand_modality_expert_id.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
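As a quick reference for reviewers and test authors: the fused cal_aux_loss kernel mirrors the pure-paddle fallback cal_aux_loss_func that this PR also adds in test/legacy_test/ernie_utils/top2_gate.py. The sketch below condenses that fallback for the simple case without token masks (the masked and all-gathered variants use the same formula with a different token-count denominator); it is a reference sketch, not the kernel itself.

import paddle


def cal_aux_loss_reference(
    gate_prob, dispatch_mask, num_experts, use_group, moe_k, clip_min=1e-6
):
    # Number of routed tokens (numel / num_experts), clipped to avoid dividing by zero.
    seqlen_float = paddle.clip(
        gate_prob.numel().astype(gate_prob.dtype) / num_experts, min=clip_min
    )
    # ce: fraction of tokens dispatched to each expert; me: mean router prob per expert.
    ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float
    me = paddle.sum(gate_prob, axis=0) / seqlen_float
    l_aux = paddle.sum(me * ce) * num_experts
    if use_group:
        l_aux = l_aux / moe_k
    return l_aux, seqlen_float, ce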
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from paddle import _C_ops
+
+# from ....framework import LayerHelper, in_dynamic_or_pir_mode
+from paddle.base.framework import in_dynamic_or_pir_mode
+from paddle.base.layer_helper import LayerHelper
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+def expand_modality_expert_id(
+    expert_id: Tensor,
+    num_expert_per_modality: int,
+    group_size: int,
+    modality_offset: int,
+    is_group_expert: bool,
+    name: str | None = None,
+) -> Tensor:
+    """
+    Expand per-group / per-modality expert ids into global expert ids.
+
+    Args:
+        expert_id (Tensor): Expert ids selected for each token.
+        num_expert_per_modality (int): Number of experts per rank belonging to one
+            modality; pass 0 to skip the cross-modality expansion.
+        group_size (int): Size of one expert group when is_group_expert is True.
+        modality_offset (int): Index of the modality the ids belong to (e.g. 0 for
+            text experts, 1 for multimodal experts).
+        is_group_expert (bool): Whether the ids were selected per expert group and
+            therefore need to be mapped back to global expert ids.
+
+    Returns:
+        Tensor: The expanded expert ids, with the same shape as expert_id.
+    """
+    if in_dynamic_or_pir_mode():
+        return _C_ops.expand_modality_expert_id(
+            expert_id,
+            num_expert_per_modality,
+            group_size,
+            modality_offset,
+            is_group_expert,
+        )
+    helper = LayerHelper('expand_modality_expert_id', **locals())
+    expert_id_out = helper.create_variable_for_type_inference(
+        dtype=expert_id.dtype
+    )
+    inputs = {'expert_id': expert_id}
+    attrs = {
+        'num_expert_per_modality': num_expert_per_modality,
+        'group_size': group_size,
+        'modality_offset': modality_offset,
+        'is_group_expert': is_group_expert,
+    }
+    helper.append_op(
+        type='expand_modality_expert_id',
+        inputs=inputs,
+        attrs=attrs,
+        outputs={'expert_id_out': expert_id_out},
+    )
+    return expert_id_out
diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py
new file mode 100644
index 00000000000000..dd3cb392793e46
--- /dev/null
+++ b/python/paddle/incubate/nn/functional/fused_rms_norm_ext.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import _C_ops
+from paddle.base.data_feeder import convert_dtype
+from paddle.base.framework import in_dynamic_or_pir_mode
+from paddle.base.layer_helper import LayerHelper
+
+
+def fused_rms_norm_ext(x, scale, epsilon=1e-5, name=None):
+    """
+    Applies RMS normalization over the last dimension of the input tensor using a fused CUDA kernel.
+    Args:
+        x (Tensor): Input tensor of shape [rows, cols] or higher dimensions (flattened to 2D).
+        scale (Tensor): Scale tensor of shape [cols].
+        epsilon (float): Small constant added for numerical stability.
+        name (str, optional): Name of the operator.
+    Returns:
+        y (Tensor): Normalized tensor of the same shape as x.
+        invvar (Tensor): Tensor of shape [rows], the inverse root-mean-square statistic used to normalize each row.
+ """ + if in_dynamic_or_pir_mode(): + return _C_ops.fused_rms_norm_ext(x, scale, epsilon) + helper = LayerHelper('fused_rms_norm_ext', **locals()) + dtype = convert_dtype(x.dtype) + y = helper.create_variable_for_type_inference(dtype) + invvar = helper.create_variable_for_type_inference('float32') + + inputs = {'x': x, 'scale': scale} + + helper.append_op( + type='fused_rms_norm_ext', + inputs=inputs, + outputs={'y': y, 'invvar': invvar}, + attrs={'epsilon': epsilon}, + ) + return y, invvar diff --git a/python/paddle/incubate/nn/functional/int_bincount.py b/python/paddle/incubate/nn/functional/int_bincount.py new file mode 100644 index 00000000000000..9e444ae5992a30 --- /dev/null +++ b/python/paddle/incubate/nn/functional/int_bincount.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import _C_ops +from paddle.base.data_feeder import convert_dtype +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + + +def int_bincount(x, low, high, dtype=None, name=None): + if in_dynamic_or_pir_mode(): + return _C_ops.int_bincount(x, low, high, dtype) + + helper = LayerHelper("int_bincount", **locals()) + out_dtype = dtype if dtype is not None else x.dtype + y = helper.create_variable_for_type_inference(dtype=out_dtype) + dtype_attr = convert_dtype(out_dtype) + + helper.append_op( + type="int_bincount", + inputs={"x": x}, + outputs={"y": y}, + attrs={ + "low": low, + "high": high, + "dtype": dtype_attr, + }, + ) + return y diff --git a/python/paddle/incubate/nn/functional/moe_combine.py b/python/paddle/incubate/nn/functional/moe_combine.py new file mode 100644 index 00000000000000..e9e23915ce0a5e --- /dev/null +++ b/python/paddle/incubate/nn/functional/moe_combine.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from paddle import _C_ops + +# from ....framework import LayerHelper, in_dynamic_or_pir_mode +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + +if TYPE_CHECKING: + from paddle import Tensor + + +def moe_combine( + x: Tensor, + combine_weights: Tensor, + scatter_index: Tensor, + name: str | None = None, +) -> Tensor: + """ + Args: + x: Input tensor [seq, dim] + combine_weights: Combination weights [s, k] + scatter_index: Scatter indices [k, s] dtype=int32 + + Returns: + Output Combined output [s, dim] + """ + if in_dynamic_or_pir_mode(): + return _C_ops.moe_combine(x, combine_weights, scatter_index) + helper = LayerHelper('moe_combine', **locals()) + y = helper.create_variable_for_type_inference(dtype=x.dtype) + inputs = { + 'x': x, + 'combine_weights': combine_weights, + 'scatter_index': scatter_index, + } + helper.append_op(type='moe_combine', inputs=inputs, outputs={'y': y}) + return y diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch.py new file mode 100644 index 00000000000000..2e50f6c1698f22 --- /dev/null +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
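One inconsistency worth unifying: the moe_combine docstring above describes scatter_index as [k, s], while the GateCombine helper added later in this PR (test/legacy_test/ernie_utils/moe_layer_uneven.py) documents and passes a [s, k] index. Under the [s, k] convention, a pure-paddle sketch of the combine the op is expected to perform is:

import paddle
import paddle.nn.functional as F


def moe_combine_reference(x, combine_weights, scatter_index):
    # x: [seq * k, dim] dispatched expert outputs, combine_weights: [seq, k],
    # scatter_index: [seq, k] integer rows of x (layout assumed, see note above).
    gathered = F.embedding(scatter_index, x)  # [seq, k, dim]
    w = combine_weights.unsqueeze(1).astype(gathered.dtype)  # [seq, 1, k]
    return paddle.matmul(w, gathered).squeeze(1)  # [seq, dim]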
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import paddle
+from paddle import _C_ops
+
+# from ....framework import LayerHelper, in_dynamic_or_pir_mode
+from paddle.base.framework import in_dynamic_or_pir_mode
+from paddle.base.layer_helper import LayerHelper
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+def moe_gate_dispatch(
+    x: Tensor,
+    gate_logits: Tensor,
+    corr_bias: Tensor,
+    k: int,
+    capacity: int,
+    use_pad: bool,
+    name: str | None = None,
+) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+    """
+    Select the top-k experts for each token and rearrange the tokens by expert id.
+
+    Args:
+        x (Tensor): Input activations of shape [seq, hidden].
+        gate_logits (Tensor): Gate scores of shape [seq, num_experts].
+        corr_bias (Tensor, optional): Correction bias added to the gate scores for
+            aux-loss-free load balancing; may be None.
+        k (int): Number of experts selected per token.
+        capacity (int): Maximum number of tokens each expert can receive.
+        use_pad (bool): Whether the dispatched output is padded to
+            num_experts * capacity rows.
+
+    Returns:
+        y (Tensor): Dispatched activations, sorted by expert id.
+        combine_weights (Tensor): [seq, k] weights used to combine expert outputs.
+        scatter_index (Tensor): [k, seq] position of every (token, choice) pair in y.
+        expert_offset (Tensor): [num_experts] offsets of each expert's segment in y.
+        expert_id (Tensor): Expert ids selected for each token.
+    """
+    if in_dynamic_or_pir_mode():
+        return _C_ops.moe_gate_dispatch(
+            x, gate_logits, corr_bias, k, capacity, use_pad
+        )
+
+    helper = LayerHelper('moe_gate_dispatch', **locals())
+    y = helper.create_variable_for_type_inference(dtype=x.dtype)
+    combine_weights = helper.create_variable_for_type_inference(
+        dtype=paddle.float32
+    )
+    scatter_index = helper.create_variable_for_type_inference(
+        dtype=paddle.int32
+    )
+    expert_offset = helper.create_variable_for_type_inference(
+        dtype=paddle.int64
+    )
+    expert_id = helper.create_variable_for_type_inference(dtype=paddle.int32)
+
+    inputs = {
+        'x': x,
+        'gate_logits': gate_logits,
+        'corr_bias': corr_bias,
+    }
+    attrs = {
+        'k': k,
+        'capacity': capacity,
+        'use_pad': use_pad,
+    }
+    outputs = {
+        'y': y,
+        'combine_weights': combine_weights,
+        'scatter_index': scatter_index,
+        'expert_offset': expert_offset,
+        'expert_id': expert_id,
+    }
+    helper.append_op(
+        type='moe_gate_dispatch',
+        inputs=inputs,
+        attrs=attrs,
+        outputs=outputs,
+    )
+    return y, combine_weights, scatter_index, expert_offset, expert_id
diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py
new file mode 100644
index 00000000000000..ef591637fb2502
--- /dev/null
+++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_partial_nosoftmaxtopk.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
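A minimal dynamic-graph call sketch for moe_gate_dispatch, following how the GateDispatch PyLayer added later in this PR (test/legacy_test/ernie_utils/moe_layer_uneven.py) invokes it; corr_bias may be passed as None, and the concrete sizes below are illustrative assumptions. The moe_gate_dispatch_permute variant further below follows the same calling pattern, with world_size taking the place of use_pad.

import paddle
import paddle.nn.functional as F
from paddle.incubate.nn.functional import moe_gate_dispatch

num_tokens, hidden, num_experts, k, capacity = 8, 16, 4, 2, 8
x = paddle.randn([num_tokens, hidden])
gate_prob = F.softmax(paddle.randn([num_tokens, num_experts]), axis=-1)

y, combine_weights, scatter_index, expert_offset, expert_id = moe_gate_dispatch(
    x, gate_prob, None, k=k, capacity=capacity, use_pad=True
)
# y:               tokens rearranged by expert id (padded to num_experts * capacity when use_pad=True)
# combine_weights: [num_tokens, k] routing weights
# scatter_index:   [k, num_tokens] position of each (token, choice) pair inside y
# expert_offset:   [num_experts] boundaries of each expert's slice of y
# expert_id:       expert ids selected for each token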
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from paddle import _C_ops +from paddle.base.framework import in_dynamic_or_pir_mode +from paddle.base.layer_helper import LayerHelper + +if TYPE_CHECKING: + from paddle import Tensor + + +def moe_gate_dispatch_partial_nosoftmaxtopk( + x: Tensor, + combine_weights: Tensor, + expert_id: Tensor, + k: int, + capacity: int, + num_experts: int, + use_pad: bool, + expert_start_index: int, + expert_end_index: int, + reverse_token_drop: bool, + name: str | None = None, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + if in_dynamic_or_pir_mode(): + return _C_ops.moe_gate_dispatch_partial_nosoftmaxtopk( + x, + combine_weights, + expert_id, + k, + capacity, + num_experts, + use_pad, + expert_start_index, + expert_end_index, + reverse_token_drop, + ) + helper = LayerHelper("moe_gate_dispatch_partial_nosoftmaxtopk", **locals()) + y = helper.create_variable_for_type_inference(dtype=x.dtype) + combine_weights_out = helper.create_variable_for_type_inference( + dtype=combine_weights.dtype + ) + scatter_index = helper.create_variable_for_type_inference(dtype='int32') + scatter_index_rev = helper.create_variable_for_type_inference(dtype='int32') + expert_offset = helper.create_variable_for_type_inference(dtype='int64') + expert_nums_local = helper.create_variable_for_type_inference(dtype='int64') + inputs = { + "x": x, + "combine_weights": combine_weights, + "expert_id": expert_id, + } + outputs = { + "y": y, + "combine_weights_out": combine_weights_out, + "scatter_index": scatter_index, + "scatter_index_rev": scatter_index_rev, + "expert_offset": expert_offset, + "expert_nums_local": expert_nums_local, + } + attrs = { + "k": k, + "capacity": capacity, + "num_experts": num_experts, + "use_pad": use_pad, + "expert_start_index": expert_start_index, + "expert_end_index": expert_end_index, + "reverse_token_drop": reverse_token_drop, + } + helper.append_op( + type="moe_gate_dispatch_partial_nosoftmaxtopk", + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) + return ( + y, + combine_weights_out, + scatter_index, + scatter_index_rev, + expert_offset, + expert_nums_local, + ) diff --git a/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py new file mode 100644 index 00000000000000..9721590f1443f0 --- /dev/null +++ b/python/paddle/incubate/nn/functional/moe_gate_dispatch_permute.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
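Unlike the other wrappers, moe_gate_dispatch_partial_nosoftmaxtopk ships without a docstring. Judging from its signature, it dispatches pre-computed (combine_weights, expert_id) pairs for only the experts in [expert_start_index, expert_end_index), applying no softmax or top-k inside the op. The sketch below is illustrative only; the dtypes and shapes are assumptions that should be checked against the kernel and its unit test.

import paddle
from paddle.incubate.nn.functional import (
    moe_gate_dispatch_partial_nosoftmaxtopk,
)

num_tokens, hidden, num_experts, k, capacity = 8, 16, 4, 2, 8
x = paddle.randn([num_tokens, hidden])
combine_weights = paddle.rand([num_tokens, k])  # layout assumed
expert_id = paddle.randint(0, num_experts, [num_tokens, k]).astype('int32')  # dtype assumed

(
    y,
    combine_weights_out,
    scatter_index,
    scatter_index_rev,
    expert_offset,
    expert_nums_local,
) = moe_gate_dispatch_partial_nosoftmaxtopk(
    x,
    combine_weights,
    expert_id,
    k=k,
    capacity=capacity,
    num_experts=num_experts,
    use_pad=True,
    expert_start_index=0,
    expert_end_index=num_experts,
    reverse_token_drop=False,
)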
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from paddle import _C_ops
+from paddle.base.framework import in_dynamic_or_pir_mode
+from paddle.base.layer_helper import LayerHelper
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+def moe_gate_dispatch_permute(
+    x: Tensor,
+    gate_logits: Tensor,
+    corr_bias: Tensor,
+    k: int,
+    capacity: int,
+    world_size: int,
+    name: str | None = None,
+) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+    """
+    Dispatch and permute for Mixture of Experts (MoE).
+
+    Args:
+        x: Input tensor [batch_size, seq_len, hidden_dim].
+        gate_logits: Gate logits for choosing experts [batch_size, seq_len, num_experts].
+        corr_bias: Optional correction bias to adjust gate logits.
+        k: Top-k experts to be selected.
+        capacity: The maximum number of tokens an expert can handle.
+        world_size: Number of distributed processes.
+        name: Optional name for the operation.
+
+    Returns:
+        Tuple of Tensors containing:
+        - y: Output tensor after dispatch and permute.
+        - combine_weights: Weights for combining experts' outputs.
+        - scatter_index: Indices for scattering inputs to experts.
+        - expert_offset: Offset indices for each expert.
+        - expert_id: IDs of selected experts for each position.
+    """
+    if in_dynamic_or_pir_mode():
+        return _C_ops.moe_gate_dispatch_permute(
+            x, gate_logits, corr_bias, k, capacity, world_size
+        )
+
+    helper = LayerHelper('moe_gate_dispatch_permute', **locals())
+    y = helper.create_variable_for_type_inference(dtype=x.dtype)
+    combine_weights = helper.create_variable_for_type_inference(
+        dtype='float32'
+    )
+    scatter_index = helper.create_variable_for_type_inference(dtype='int32')
+    expert_offset = helper.create_variable_for_type_inference(dtype='int32')
+    expert_id = helper.create_variable_for_type_inference(dtype='int32')
+
+    inputs = {
+        'x': x,
+        'gate_logits': gate_logits,
+        'corr_bias': corr_bias,
+    }
+    attrs = {'k': k, 'capacity': capacity, 'world_size': world_size}
+    outputs = {
+        'y': y,
+        'combine_weights': combine_weights,
+        'scatter_index': scatter_index,
+        'expert_offset': expert_offset,
+        'expert_id': expert_id,
+    }
+
+    helper.append_op(
+        type='moe_gate_dispatch_permute',
+        inputs=inputs,
+        outputs=outputs,
+        attrs=attrs,
+    )
+    return y, combine_weights, scatter_index, expert_offset, expert_id
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 3282d661b1f6e5..5e8669479aa082 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -479,6 +479,26 @@ if(NOT WITH_GPU
   list(REMOVE_ITEM TEST_OPS test_sparse_conv_igemm_op)
 endif()
 
+# The new MoE op tests need a Linux GPU build with CUDA >= 12.0 and are not
+# built for CUDA_ARCH_NAME == Volta, so skip them otherwise.
+if(NOT WITH_GPU
+   OR WIN32
+   OR APPLE
+   OR (${CUDA_ARCH_NAME} STREQUAL "Volta")
+   OR ((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0)))
+  list(
+    REMOVE_ITEM
+    TEST_OPS
+    test_incubate_build_src_rank_and_local_expert_id
+    test_incubate_expand_modality_expert_id
+    test_incubate_fused_loss
+    test_incubate_fused_rmsnorm_ext
+    test_incubate_int_bincount
+    test_incubate_moe_combine
+    test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk
+    test_incubate_moe_gate_dispatch_w_permute_bwd
+    test_incubate_moe_gate_dispatch_w_permute)
+endif()
+
 if(NOT WITH_CUDNN_FRONTEND)
   list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_relu_conv_bn_op)
   list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_add_relu_op)
diff --git a/test/legacy_test/ernie_utils/moe_all_gather_layer.py b/test/legacy_test/ernie_utils/moe_all_gather_layer.py
new file mode 100644 index 00000000000000..53186898788748 --- /dev/null +++ b/test/legacy_test/ernie_utils/moe_all_gather_layer.py @@ -0,0 +1,274 @@ +# ruff: noqa: FA100 +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +@author: kebo +@contact: kebo01@baidu.com + +@version: 1.0 +@file: moe_layer_all_gather.py +@time: 2024/09/21 15:11:10 +@Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved + +这一行开始写关于本文件的说明与解释 + + +""" + +from __future__ import annotations + +import contextlib +import logging + +import paddle +from paddle import nn +from paddle.incubate.nn.functional import expand_modality_expert_id + +from .moe_layer import MOELayer + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from paddle.distributed.communication.group import Group + + +def profile(_): + """dumy profile""" + return contextlib.nullcontext() + + +logger = logging.getLogger(__name__) + +if False: + try: + from paddle_xpu_nn import moe_gate_dispatch as xpu_moe_gate_dispatch + except ImportError: + xpu_moe_gate_dispatch = None + logger.warning("`xpu moe dispatch` not found") +else: + pass + + +class MOEAllGatherLayer(MOELayer): + """_summary_ + + Args: + MOELayer (_type_): _description_ + """ + + def __init__( + self, + gate: nn.Layer, + experts: list[nn.Layer], + layer_idx, + shared_experts: list[nn.Layer] | None = None, + dense_experts: list[nn.Layer] | None = None, # no use + group: Group = None, + recompute=False, + enable_logging: bool = False, + k=2, + enable_bpr: bool = False, + all_to_all_dropout=0, + group_experts=False, + moe_statics=None, + ): + + super().__init__( + gate, + experts, + layer_idx, + shared_experts, + group, + recompute, + enable_logging, + k, + enable_bpr, + all_to_all_dropout, + group_experts, + moe_statics, + ) + + +class MOEAllGatherLayerV2(MOEAllGatherLayer): + """_summary_ + + Args: + MOELayer (_type_): _description_ + """ + + def __init__( + self, + gate: nn.Layer, + experts: list[nn.Layer], + layer_idx, + shared_experts: list[nn.Layer] | None = None, + dense_experts: list[nn.Layer] | None = None, + group: Group = None, + recompute=False, + enable_logging: bool = False, + k=2, + enable_bpr: bool = False, + enable_reverse_token_drop=False, + all_to_all_dropout=0, + group_experts=False, + use_expert_out_alltoall=True, # + use_expert_alltoall_overlap=False, + use_padding=True, + dense_token_type=3, # considered as dense tokens (no moe) + moe_statics=None, + ): + super().__init__( + gate, + experts, + layer_idx, + shared_experts, + dense_experts, + group, + recompute, + enable_logging, + k, + enable_bpr, + all_to_all_dropout, + group_experts, + moe_statics, + ) + self.enable_reverse_token_drop = enable_reverse_token_drop + self.is_allgather_moe_layer = True + # assert self.gate.config.sequence_parallel + world_size = self.gate.config.moe_world_size 
+ self.use_padding = use_padding + + # 全局 gate gather + self.send_rank = None + self.local_expert_id = None + self.dense_token_type = dense_token_type + self.dense_experts = dense_experts + self.capacity_tensor = None + self.use_expert_out_alltoall = use_expert_out_alltoall + self.use_expert_alltoall_overlap = use_expert_alltoall_overlap + logger.info( + f"using MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, " + f"use_padding={use_padding}, use_expert_alltoall_overlap={use_expert_alltoall_overlap} " + f"enable_reverse_token_drop={self.enable_reverse_token_drop}" + ) + self.two = paddle.to_tensor(2, dtype=paddle.float32) + self.zero = paddle.to_tensor(0, dtype=paddle.float32) + + def fused_gate_logits_process_fused( + self, gate_logits_lm, gate_logits_mm, token_type_ids + ): + """process gatelogits w/ moe utils""" + # top_k = 1 if isinstance(self.gate, SinkHornGateFused) else self.k + top_k = self.k + num_expert_per_rank_per_modality = ( + gate_logits_lm.shape[-1] // self.config.moe_world_size + ) + group_size = gate_logits_lm.shape[-1] // top_k + if self.group_experts: + assert not self.use_correction_bias + gate_logits_lm = gate_logits_lm.reshape( + [gate_logits_lm.shape[0], top_k, -1] + ) + prob_lm = self.gate.act(gate_logits_lm) + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=1, axis=-1) + weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) + group_size = gate_logits_lm.shape[-1] + expert_id_lm = expert_id_lm.squeeze(-1) + else: + prob_lm = self.gate.act(gate_logits_lm) + if self.use_correction_bias: + prob_lm_ = ( + prob_lm + + self.moe_statics.e_score_correction_bias[0].detach() + ) + else: + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, axis=-1) + + if self.use_correction_bias: + batch_idx = ( + paddle.arange(prob_lm_.shape[0]) + .unsqueeze(-1) + .expand_as(expert_id_lm) + ) + weight_lm = prob_lm[batch_idx, expert_id_lm] # use correct bias + + # num_expert_per_modality == 0 时只执行 group-expert expand,不执行 multimodal-expand + expert_id_lm = expand_modality_expert_id( + expert_id_lm, + num_expert_per_modality=( + num_expert_per_rank_per_modality + if (token_type_ids is not None and gate_logits_mm is not None) + else 0 + ), + group_size=group_size, + modality_offset=0, + is_group_expert=self.group_experts, + ) + expert_id_lm = expert_id_lm.reshape(weight_lm.shape) + lm_weight_and_expert_id = paddle.concat( + [weight_lm, expert_id_lm.astype("float32")], -1 + ) + if token_type_ids is None or gate_logits_mm is None: + return ( + lm_weight_and_expert_id, + prob_lm.reshape([prob_lm.shape[0], -1]), + None, + ) + + prob_mm = self.gate.act(gate_logits_mm) + if self.use_correction_bias: + prob_mm_ = ( + prob_mm + self.moe_statics.e_score_correction_bias[1].detach() + ) + else: + prob_mm_ = prob_mm + weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, axis=-1) + if self.use_correction_bias: + batch_idx = ( + paddle.arange(prob_lm_.shape[0]) + .unsqueeze(-1) + .expand_as(expert_id_lm) + ) + weight_mm = prob_mm[batch_idx, expert_id_mm] # use correct bias + + expert_id_mm = expand_modality_expert_id( + expert_id_mm, + num_expert_per_modality=num_expert_per_rank_per_modality, + group_size=group_size, + modality_offset=1, + is_group_expert=False, + ) + expert_id_mm = expert_id_mm.reshape(weight_mm.shape) + mm_weight_and_expert_id = paddle.concat( + [weight_mm, expert_id_mm.astype("float32")], -1 + ) + weight_and_expert = paddle.where( + (token_type_ids == 0).unsqueeze(-1), + lm_weight_and_expert_id, + mm_weight_and_expert_id, + ) + 
return ( + weight_and_expert, + prob_lm.reshape([prob_lm.shape[0], -1]), + prob_mm, + ) diff --git a/test/legacy_test/ernie_utils/moe_layer.py b/test/legacy_test/ernie_utils/moe_layer.py new file mode 100644 index 00000000000000..b5fbf11791d090 --- /dev/null +++ b/test/legacy_test/ernie_utils/moe_layer.py @@ -0,0 +1,245 @@ +# ruff: noqa: FA100 +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""_summary_ + +Returns: + _type_: _description_ +""" +from __future__ import annotations + +import logging +from collections import namedtuple +from typing import TYPE_CHECKING + +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.distributed import fleet + +if TYPE_CHECKING: + from paddle.distributed.communication.group import Group +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 +try: + import moe_router_loss_ops +except ImportError: + moe_router_loss_ops = None +try: + from paddle.distributed import in_auto_parallel_align_mode +except: + + def in_auto_parallel_align_mode(): + """ + hack for paddlenlp develop branch. + """ + return False + + +try: + from bincount_ops import int_bincount +except ImportError: + int_bincount = None + +logger = logging.getLogger(__name__) + +try: + import moe_ops +except ImportError: + moe_ops = None + logger.warning( + "`moe-ops` not found, run " + "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + ], +) + + +class MOELayer(nn.Layer): + """MOELayer module which implements MixtureOfExperts as described in Gshard_. + :: + + gate = Top2Gate(model_dim, num_experts) + + moe = MOELayer(gate, expert) + output = moe(input) + l_aux = moe.l_aux + + .. 
Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + gate (paddle.nn.Layer): + gate network + expert (paddle.nn.LayerList): + expert network, LayerList 长度是 per_device 上的 expert 数。 + group (paddle.ProgressGroup) + recompute: 启用MOE内recomupte + Returns: + output + combine_weight + router-loss + """ + + def __init__( + self, + gate: nn.Layer, + experts: list[nn.Layer], + layer_idx, + shared_experts: list[nn.Layer] | None = None, + group: Group = None, + recompute=False, + enable_logging: bool = False, + k=2, + enable_bpr: bool = False, + all_to_all_dropout=0, + group_experts=False, + moe_statics=None, + ): + """ + 初始化MoE层。 + + Args: + gate (nn.Layer): 智能门控层,用于选择需要使用的专家。 + experts (List[nn.Layer]): 需要使用的专家列表。 + layer_idx (int): 当前MoE层的索引。 + group (Group): 分布式通信组。默认值为None。 + recompute (bool): 是否在每个训练迭代中重新计算MoE输出。默认值为False。 + """ + super().__init__() + self.gate = gate + self.layer_idx = layer_idx + self.recompute = recompute + logger.info(f"using moe recompute={recompute}") + for p in self.gate.parameters(): + p.is_gate = True + if isinstance(experts, nn.LayerList): + self.experts = experts + else: + logger.info(f"using fused experts, type={type(experts)}") + self.experts = experts + self.shared_experts = shared_experts + + self.group = group + self.k = k + self.all_to_all_dropout = all_to_all_dropout + self.enable_logging = enable_logging + self.use_correction_bias = moe_statics is not None + self.moe_statics = moe_statics + if self.use_correction_bias: + logger.info( + f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}" + ) + assert self.gate.config.moe_use_aux_free + + self.is_mp_moe = ( + hasattr(fleet.fleet, "_hcg") + and group + is fleet.get_hybrid_communicate_group().get_model_parallel_group() + ) + is_dummy_moe = dist.get_world_size(group) == 1 + + for p in experts.parameters(): + p.expert = not (self.is_mp_moe or is_dummy_moe) # type: ignore + p.no_sync = not (self.is_mp_moe or is_dummy_moe) + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe: + p.is_distributed = True + + self.world_size = dist.get_world_size(self.group) + # assert self.world_size > 1, f'moe-group not found, world_size {self.world_size}' + self.rank = dist.get_rank(self.group) + if self.world_size < 1: + self.world_size = 1 + if self.rank < 0: + self.rank = 0 + + self.num_local_experts = len(self.experts) + self.dispatch_by_task = ( + hasattr(self.gate, "dispatch_by_task") + and self.gate.dispatch_by_task + ) + + if self.dispatch_by_task: + assert 0, "no supported, checkout earylier code" + assert self.num_local_experts == 1 + + ''' dummy skip + if enable_bpr: + logger.info("using BPR") + prepost_process_buffer = {} + self.input_preprocess = partial( + bpr_preprocess, buffer=prepost_process_buffer + ) + self.output_postprocess = partial( + bpr_postprocess, buffer=prepost_process_buffer + ) + else: + self.input_preprocess = self.output_postprocess = None + ''' + self.input_preprocess = self.output_postprocess = None + self.group_experts = group_experts + self.config = self.gate.config + self.zero = paddle.to_tensor(0, dtype=paddle.float32) + + self._rr_moe_gate_dispatch = None + self._rr_moe_combine = None + ''' dummy skip + if self.config.use_recompute and self.config.skip_recompute_ops.get( + "moe_gate_dispatch", False + ): + self._rr_moe_gate_dispatch = RefinedRcomputeMoEGateDispatch() + if self.config.use_recompute and self.config.skip_recompute_ops.get( + "moe_combine", False + ): + self._rr_moe_combine = RefinedRcomputeMoECombine() + ''' + + +def 
fuse_logging(gate_logits, combine_weights, token_type_ids): + """fuse_logging""" + with paddle.no_grad(): + gate_expert_per_token_type_0, gate_expert_per_token_type_1 = None, None + gate_experts_per_token = None + ce = moe_router_loss_ops.cal_cross_entropy_info(gate_logits).mean(0) + if token_type_ids is not None: + ( + gate_expert_per_token_type_0, + gate_expert_per_token_type_1, + gate_experts_per_token, + ) = moe_router_loss_ops.cal_gate_experts_per_token_info( + combine_weights, token_type_ids + ) + else: + gate_experts_per_token = paddle.count_nonzero(combine_weights) / ( + gate_logits.shape[0] + ) + + return ( + gate_expert_per_token_type_0, + gate_expert_per_token_type_1, + gate_experts_per_token, + ce, + ) diff --git a/test/legacy_test/ernie_utils/moe_layer_uneven.py b/test/legacy_test/ernie_utils/moe_layer_uneven.py new file mode 100644 index 00000000000000..4bdf42c377d75c --- /dev/null +++ b/test/legacy_test/ernie_utils/moe_layer_uneven.py @@ -0,0 +1,292 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +moe +""" + +import inspect +import logging +from collections import namedtuple + +import paddle +from paddle import _C_ops +from paddle.autograd import PyLayer + +# from ernie_core.models.moe.moe_layer import _AllToAll +from paddle.incubate.nn.functional import moe_gate_dispatch +from paddle.nn import functional as F + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 + +logger = logging.getLogger(__name__) + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + ], +) + + +if False: + try: + from paddle_xpu_nn import ( + moe_combine as xpu_moe_combine, + moe_combine_bwd as xpu_moe_combine_bwd, + ) + except ImportError: + xpu_moe_combine = None + xpu_moe_combine_bwd = None + logger.warning("`xpu moe combine` not found") +else: + try: + from paddle.incubate.nn.functional import moe_combine + except ImportError: + moe_combine = None + logger.warning( + "`moe-combine` not found, run " + "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) + + +def average_grad(x, y, dy, eps=1e-12): + """ + TODO: fuse 这坨 shit + y=x/x.sum(-1, keepdim=True) 的反向过程 + """ + s, k = x.shape + xsum = x.sum(axis=-1, keepdim=True) # [s,1] + maskpos = (xsum == 0.0).expand_as(x) + + xsum_square = xsum.square() # [s,1] + left = paddle.triu( + paddle.tril((1 / xsum).unsqueeze(-1).expand([s, k, k])) + ) # aka diag-emb [s,k,k] + right = (-x / xsum_square).unsqueeze(-1).expand([s, k, k]) + dydx = left + right + dx = paddle.matmul(dy.unsqueeze(-2).cast(dydx.dtype), dydx).squeeze( + -2 + ) # [s,1,k] @[s,k,k] -> [s,1,k] + dx = paddle.where(maskpos, paddle.zeros_like(dx), dx) + return dx + + +mask = paddle.to_tensor( + [ + [1, -1], + [-1, 1], + ] +).unsqueeze(0) + + +def average_grad_bi(x, y, dy, eps=1e-12): + """ + y=x/x.sum(-1, keepdim=True) + k=2 下面的反向过程,精度会更准一些: + dx1 = (y2 *dy1 - 
y2*dy2)/(y1+y2)**2 + dx2 = (y1 *dy2 - y1*dy1)/(y1+y2)**2 + """ + s, k = x.shape + assert k == 2, k + xsum = paddle.clip(x.sum(axis=-1, keepdim=True), min=eps) # [s,1] + dydx = ( + x.flip(axis=1).unsqueeze(-2).tile([1, 2, 1]) + * mask.cast(x.dtype) + / xsum.square().unsqueeze(-1) + ) + dx = paddle.matmul(dy.unsqueeze(-2).cast(dydx.dtype), dydx).squeeze( + -2 + ) # [s,1,k] @[s,k,k] -> [s,1,k] + return dx + + +def topk_grad(x, dy, indices): + """ + TODO: fuse 这坨 shit + y=gather(topk(x)) 的反向过程 + x: [s,e] + dy: [s,k] + """ + s, e = x.shape + _, k = dy.shape + dx = paddle.scatter_nd( + paddle.stack( + [ + paddle.arange(s).repeat_interleave(k).cast(indices.dtype), + indices.reshape([-1]), + ], + -1, + ), + dy.reshape([-1]), + shape=[s, e], + ) # [s,k] -> [s,e] + return dx # dx 保持高精度 + + +class GateDispatch(PyLayer): + """doc""" + + @staticmethod + def forward(ctx, x, gate_prob, k, capacity, use_pad, eps=1e-12): + """ + 对`gate_prob` 进行 softmax 并根据结果选取 topk 路由expert。 最后根据 expert 号对 `x` 进行重排。 + Args: + x: [s, d] 输入的 activateion + gate_prob: [s, e] + k: int + capacity: int #no use + Returns: + y: [s*k, d] 将所有 `x` 根据其路由的 `expert-id` 升序的排序,融合到 s 维度。 + 当截断发生时 s 会比输入 s 小。 + combine_weights: [s, k], float: 每个 token 第 k 选择的 expert 的权重。 + 当截断发生时 s 会比输入 s 小。 + scatter_index: [k, s] : 每个 token 第 k 次选择对应到 `y` 中的位置。 + expert_offset: [e]: `y`中每个 expert-id 的分割位置。 + expert_id: [s] `x` 中激活的 expert 号 + """ + ctx.k = k + ctx.eps = eps + ctx.capacity = capacity + ctx.gate_prob = gate_prob + if "corr_bias" in inspect.signature(moe_gate_dispatch).parameters: + compat_args = (None,) + else: + compat_args = () + y, combine_weights, scatter_index, expert_offset, expert_id = ( + moe_gate_dispatch( + x, + gate_prob, + *compat_args, + k=k, + capacity=capacity, + use_pad=use_pad, + ) + ) + ctx.combine_weights = combine_weights + scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] + ctx.scatter_index = scatter_index + ctx.expert_id = expert_id + num_experts = gate_prob.shape[-1] + + ctx.num_experts = num_experts + ctx.seqlen = gate_prob.shape[0] + + return y, combine_weights, scatter_index, expert_offset, expert_id + + @staticmethod + def backward(ctx, dy, dw, *_): + """ + TODO: 这坨代码可以 fuse 一手。 + 关于 softmax 对 logits 的导数,参考: + https://stats.stackexchange.com/questions/215521/ + how-to-find-derivative-of-softmax-function-for-the-purpose-of-gradient-descent/328095#328095 + """ + s, k = ctx.combine_weights.shape + grad = F.embedding(ctx.scatter_index, dy) # [s, k,d] + mask = (ctx.combine_weights > 0.0).astype(grad.dtype) # [s,k] + dx = paddle.matmul(mask.unsqueeze(1), grad).squeeze( + 1 + ) # [s,1,k] @ [s,k,d] -> [s,1,d] + if ctx.gate_prob.stop_gradient: + return dx, None + + combine_weights_unnorm = ctx.combine_weights + dw = dw.astype(combine_weights_unnorm.dtype) + d_prob = topk_grad(ctx.gate_prob, dw, ctx.expert_id) + return dx, d_prob + + +class GateCombine(PyLayer): + """GateCombine""" + + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + """ + Input: + x: [seqlen * k, hidden_size] + combine_weights: [seqlen, k] + scatter_index: [seqlen, k] + Output: + y: [seqlen, hidden_size] + """ + ctx.x = x + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + if False: + assert xpu_moe_combine is not None + return xpu_moe_combine(x, combine_weights, scatter_index) + else: + assert moe_combine is not None + ret = moe_combine(x, combine_weights, scatter_index) + return ret + + @staticmethod + def backward(ctx, grad_y, *_): + """ + Input: + grad_y: [seqlen, hidden_size] + 
combine_weights: [seqlen, k] + scatter_index: [seqlen, k] + Output: + grad_x: [seqlen * k, hidden_size] + grad_combine_weight: [seqlen, k] + + """ + + if False: + assert xpu_moe_combine_bwd is not None + grad_x, grad_combine_weight_helper = xpu_moe_combine_bwd( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + else: + assert moe_combine is not None + grad_x, grad_combine_weight_helper = _C_ops.moe_combine_grad( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim] + # reduce the hidden shape + # TODO: implement reduce in cuda ops + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return ( + grad_x, + grad_combine_weight.reshape(ctx.combine_weights.shape), + None, + ) + # return grad_x, grad_combine_weight_helper + + +def combining(x, combine_weights, scatter_index, hard_gate=False): + """ + Args: + x: Tensor[seq, dim] + combine_weights: [s, k] + scatter_index: ** [k, s] ** + + Returns: + y: Tensor[s, dim] + """ + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + return x_gatherd.squeeze(-2) + ret = GateCombine.apply(x, combine_weights, scatter_index) + ret.stop_gradient = False + return ret diff --git a/test/legacy_test/ernie_utils/top2_gate.py b/test/legacy_test/ernie_utils/top2_gate.py new file mode 100644 index 00000000000000..8ab34b5f04c19b --- /dev/null +++ b/test/legacy_test/ernie_utils/top2_gate.py @@ -0,0 +1,1035 @@ +# ruff: noqa: FA100 +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
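For future readers of GateCombine above: the tensors produced by _C_ops.moe_combine_grad, together with the .sum(-1) reduction in backward, correspond to the plain derivatives of the weighted combine. With i(s, k) denoting the row of x selected by scatter_index for token s and choice k (notation assumed):

y_s = \sum_k w_{s,k}\, x_{i(s,k)}
\quad\Rightarrow\quad
\frac{\partial L}{\partial x_{i(s,k)}} \mathrel{+}= w_{s,k}\,\frac{\partial L}{\partial y_s},
\qquad
\frac{\partial L}{\partial w_{s,k}} = \left\langle \frac{\partial L}{\partial y_s},\; x_{i(s,k)} \right\rangle

where the helper returned by the grad kernel is the elementwise product whose last-axis sum gives the combine-weight gradient.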
+ +""" +top2gate +""" + +from __future__ import annotations + +import logging +from functools import partial + +import numpy as np + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.incubate.nn.functional import cal_aux_loss +from paddle.utils import unique_name + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} # 没有erniebot的环境下无法打印 debug 量 +try: + import moe_router_loss_ops +except ImportError: + moe_router_loss_ops = None + +try: + from custom_setup_ops import matmul_bwd +except ImportError: + matmul_bwd = None + +try: + from bincount_ops import int_bincount +except ImportError: + int_bincount = None + +logger = logging.getLogger(__name__) + + +class CalAuxLossFunctor(paddle.autograd.PyLayer): + """CalAuxLossFunctor""" + + @staticmethod + def forward( + ctx, + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min=1e-6, + ): + """forward""" + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + loss, seqlen_float, ce = cal_aux_loss( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min, + ) + ''' + ctx.save_for_backward(gate_prob, seqlen_float, ce) + ctx.num_experts = num_experts + ctx.use_group = use_group + ctx.moe_k = moe_k + ''' + return loss + + @staticmethod + def backward(ctx, out_grad): + """backward""" + ''' + gate_prob, seqlen_float, ce = ctx.saved_tensor() + num_experts = ctx.num_experts + use_group = ctx.use_group + moe_k = ctx.moe_k + from paddle import _C_ops + return _C_ops.cal_aux_loss_grad( + out_grad, gate_prob, seqlen_float, ce, num_experts, use_group, moe_k + ) + ''' + + +def cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + global_aux_loss=False, + rank=None, + group=None, +): + """cal_aux_loss_func""" + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + + scale = None + if dispatch_tokens_mask is not None: + seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() + if ( + tokens_mask is not None + and gate_prob.shape[0] != dispatch_tokens_mask.shape[0] + ): + scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) + elif tokens_mask is not None: + seqlen_float = tokens_mask.sum() + else: + seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts + seqlen_float = paddle.clip(seqlen_float, min=1e-6) + if len(dispatch_mask.shape) == 2: + dispatch_mask = dispatch_mask.sum(0) + ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float + me = paddle.sum(gate_prob, axis=0) / seqlen_float + # me = paddle.mean(gate_prob, axis=0) + # ce = paddle.mean(dispatch_mask.cast("float32"), axis=0) + if global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=group) + dist.all_gather(ce_list, ce, group=group) + me_list[rank] = me + ce_list[rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + + l_aux = paddle.sum(me * ce) * num_experts + if use_group: + l_aux = l_aux / moe_k + if scale is not None: + # 前向用局部me, 反向用全局me + l_aux = l_aux + (scale - 1) * l_aux.detach() + return l_aux + + +def masked_fill(x, mask, value): + """ + 将输入的Tensor中根据mask进行掩盖,并用value值替换。 + + Args: + x 
(Tensor): 输入的Tensor。 + mask (Tensor): 用于掩盖的布尔Tensor,其形状应与x相同。 + value (Union[float, int]): 需要替换的值。 + + Returns: + Tensor: 返回一个新的Tensor,其形状与x相同,并且根据mask和value进行掩盖和替换。 + + """ + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +@paddle.no_grad() +def compute_optimal_transport( + M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10 +): + """ + Computes the optimal transport matrix and Slinkhorn distance using the + Sinkhorn-Knopp algorithm + + Inputs: + - M : cost matrix (n x m) + - r : vector of marginals (n, ) + - c : vector of marginals (m, ) + - lam : strength of the entropic regularization + - epsilon : convergence parameter + + Outputs: + - P : optimal transport matrix (n x m) + - dist : Sinkhorn distance + """ + n, _ = M.shape + # P = (- lam * M).exp() + # P /= P.sum() + P = F.softmax(-M / lam) + u = paddle.zeros(n, "float32") + # normalize this matrix + for _ in range(max_iters): + if (u - P.sum(1)).abs().max() < epsilon: + break + u = P.sum(1) + P *= (r / (u + 1e-8)).reshape((-1, 1)) + P *= (c / (P.sum(0) + 1e-8)).reshape((1, -1)) + P = paddle.where(~P.isnan(), P, paddle.zeros_like(P)) + return P, _ + + +def cast_if_needed(x, dtype): + """ + cast_if_needed + """ + return x.cast(dtype) if x.dtype != dtype else x + + +class FusedGateDetachMatmul(paddle.autograd.PyLayer): + """ + FusedGateDetachMatmul + """ + + @staticmethod + def forward(ctx, x, w): + """ + forward + """ + ctx.dtype = paddle.float32 + ctx.save_for_backward(x, w) + return F.linear( + cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype) + ) + + @staticmethod + def backward(ctx, y_grad): + """ + backward + """ + x, w = ctx.saved_tensor() + assert ctx.dtype == y_grad.dtype, "dtype not match" + x_g, w_g = matmul_bwd( + cast_if_needed(x, ctx.dtype), + cast_if_needed(w, ctx.dtype), + y_grad, + False, + False, + ) + return cast_if_needed(x_g, x.dtype), cast_if_needed(w_g, w.dtype) + + +def gate_detach_matmul(x, weight, use_fuse): + """ + gate_detach_matmul + """ + if use_fuse: + return FusedGateDetachMatmul.apply(x, weight) + else: + x = cast_if_needed(x, paddle.float32) + return F.linear(x, weight) + + +class Top2Gate(nn.Layer): + """Gate module which implements Top2Gating as described in Gshard_. + :: + + gate = Top2Gate(model_dim, num_experts) + l_aux, combine_weights, dispatch_mask = gate(input) + + .. 
Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + model_dim (int): + size of model embedding dimension + num_experts (ints): + number of experts in model + """ + + def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: + """ + 初始化 MoE 层,包含参数初始化和一些其他功能。 + + Args: + layer_idx (int): 当前层的索引号。 + group: 分组名称。 + + Returns: + None: 不返回任何内容。 + """ + super().__init__() + if False: + try: + from paddle_xpu.layers.nn import xpu_matmul + + self.xpu_matmul = xpu_matmul() + except ImportError: + self.xpu_matmul = None + + self.config = config + self.fuse_gate_detach_matmul = config.fuse_gate_detach_matmul + if self.fuse_gate_detach_matmul: + assert matmul_bwd is not None, "matmul_bwd is not supported" + + self.model_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.num_experts_tensor = ( + sum(config.moe_num_experts) + if config.multimodel_experts + else config.moe_num_experts + ) # paddle.to_tensor(config.moe_num_experts, dtype="float32").sum() + + self.cap = config.moe_capacity + self.group = group + + self.layer_idx = layer_idx + self.global_aux_loss = config.global_aux_loss + if self.global_aux_loss: + self.rank = dist.get_rank(self.group) + + self.sinkhorn_2gate = config.sinkhorn_2gate + self.sinkhorn_temp = config.sinkhorn_temp + self.use_token_type_bias = config.moe_use_token_type_bias + self.use_correction_bias = config.moe_use_aux_free + + if config.moe_gate_act == "softmax": + self.act = partial(F.softmax, axis=-1) # [S,E] + elif config.moe_gate_act == "sigmoid": + self.act = F.sigmoid + else: + raise ValueError(f"{config.moe_gate_act} is not supported.") + self.no_jitter = True + self.expert_drop = False + self.eye_matrix = None + self.eye_matrix_size = None + self.enable_logging = config.moe_logging + self.norm_gate_logits = config.moe_norm_gate_logits + self.one = paddle.ones([], dtype="float32") + + self.moe_aux_loss_lambda = paddle.to_tensor( + config.moe_aux_loss_lambda, dtype="float32" + ) + self.moe_z_loss_lambda = paddle.to_tensor( + config.moe_z_loss_lambda, dtype="float32" + ) + self.moe_orthogonal_loss_lambda = paddle.to_tensor( + config.moe_orthogonal_loss_lambda, dtype="float32" + ) + if self.moe_aux_loss_lambda.ndim == 0: + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) + if self.moe_z_loss_lambda.ndim == 0: + self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) + if self.moe_orthogonal_loss_lambda.ndim == 0: + self.moe_orthogonal_loss_lambda = ( + self.moe_orthogonal_loss_lambda.unsqueeze(0) + ) + + self.experts_type_ids = None + if config.moe_orthogonal_loss_lambda: + if hasattr(fleet.fleet, "_user_defined_strategy"): + strategy = fleet.fleet._user_defined_strategy + sharding_configs = strategy.hybrid_configs["sharding_configs"] + pp_config = strategy.hybrid_configs["pp_configs"] + assert ( + not sharding_configs.comm_overlap + and not pp_config.sharding_comm_overlap + ), "orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" + + self.eps = paddle.to_tensor([1e-12], dtype="float32") + if config.multimodel_experts: + if config.moe_use_hard_gate: + self.num_experts_list = [] + self.experts_type_mask = [] + # hard-gate + group_experts 需要对gate_logits不同部分分开计算 + experts_ids = paddle.zeros( + [sum(self.num_experts)], dtype="int64" + ).reshape([config.moe_world_size, -1]) + offset = 0 + for i, expert_num in enumerate(self.num_experts): + experts_ids[ + :, offset : offset + expert_num // config.moe_world_size + ] = i + offset += expert_num // config.moe_world_size + 
self.experts_type_ids = experts_ids.reshape([-1]) + logger.info( + f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" + ) + for i, expert_num in enumerate(self.num_experts): + self.experts_type_mask.append( + self.experts_type_ids == i, + ) + self.num_experts_list.append(expert_num) + else: + # 非group_experts, 依赖token_type_bias实现hard-gate能力。 + assert ( + not config.moe_group_experts + ), "group_experts must use hard_gate when multimodel_experts is True" + else: + self.num_experts_list = [self.num_experts] + if gate_weight is not None: + self.weight = gate_weight + assert ( + not self.config.moe_use_token_type_bias + ), "gate_weights is from outside, token_type_bias can't be used" + logger.info("moe use gate_weight from outside") + # 强制在amp下任使用fp32精度 + self._cast_to_low_precision = False # 兼容develop分支paddle + self._cast_to_low_precision = False + else: + self._create_gate_parameter() + logger.info( + f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " + f"use_token_type_bias:{self.use_token_type_bias} gate_act:{config.moe_gate_act} " + f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}" + ) + + def _create_gate_parameter(self): + """ + 创建参数权重。 + + Args: + None + + Returns: + weight (Parameter): 创建的参数权重。 + + """ + if self.config.multimodel_experts: + # support setting lambda for each expert group + self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_orthogonal_loss_lambda = ( + self.moe_orthogonal_loss_lambda.expand(len(self.num_experts)) + ) + + for i, num_experts in enumerate(self.num_experts): + if i == 1: + with paddle.utils.unique_name.guard( + f"mm_gate_{self.layer_idx}_" + ): + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), + ) + else: + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), + ) + p.expert_type = f"expert_type_{i}" + self.add_parameter( + ( + "weight" if i == 0 else f"weight_{i}" + ), # 为了对齐原 state-dict,第一个 gate-weight 不改名. + p, + ) + else: + self.weight = self.create_parameter( + shape=[self.model_dim, self.num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), # 特殊处理,有利于热启 dense-ckpt + ) + logger.info(f"moe-Gate, {self.weight}") + + if self.use_token_type_bias: + if self.config.multimodel_experts: + assert ( + not self.config.moe_use_hard_gate + ), "multimodel_experts with hard_gate is not support token_type_bias." 
+ num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + bias_type_num = ( + len(self.num_experts) if self.config.multimodel_experts else 1 + ) + self.bias = self.create_parameter( + shape=[bias_type_num, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate_bias"), + initializer=paddle.nn.initializer.Assign( + np.zeros([bias_type_num, num_experts]) + ), + ), # 特殊处理,有利于热启 dense-ckpt + ) + logger.info(f"using token type bias, bias: {self.bias},") + # 强制在amp下任使用fp32精度 + self._cast_to_low_precision = False # 兼容develop分支paddle + self._cast_to_low_precision = False + + def get_gate_weight(self, transform_weight): + """ + 在`multimodel_experts` 的情况下,将多个 weights merge 成一个整体 + transform_weight: bool, 按照 local-expert id 将 多模态 weight 交叠 + """ + if not self.config.multimodel_experts: + return self.weight + if not transform_weight: + return paddle.concat( + [ + getattr(self, "weight" if i == 0 else f"weight_{i}") + for i in range(len(self.num_experts)) + ], + -1, + ) + weight = paddle.zeros( + [ + self.model_dim, + self.config.moe_world_size, + sum(self.num_experts) // self.config.moe_world_size, + ], + dtype="float32", + ) + offset = 0 + for i, num_experts in enumerate(self.num_experts): + weight[ + :, + :, + offset : offset + num_experts // self.config.moe_world_size, + ] = getattr(self, "weight" if i == 0 else f"weight_{i}").reshape( + [self.model_dim, self.config.moe_world_size, -1] + ) + offset += num_experts // self.config.moe_world_size + weight = weight.reshape([self.model_dim, -1]) + + return weight + + def forward( + self, + input: Tensor, + token_type_ids: Tensor = None, + transform_weight: bool = True, # [seq] + correction_bias: Tensor = None, # [seq] + ) -> tuple[Tensor, Tensor, Tensor]: # type: ignore + """ + Args: + input: paddle.Tensor[Seq, Dim], hidden-states of layer + token_type_ids: paddle.Tensor[Seqw], token_type_ids of input + transform_weight: bool, when using multimodal experts, perform `self.get_gate_weight` if specified + Returns: + paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights + paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask + Tuple[paddle.Tensor]: `GateOutput` + """ + num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + orig_dtype = input.dtype + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + if False: + assert not self.fuse_gate_detach_matmul, "not supported on XPU" + input_32 = input.cast("float32") + logits = self.xpu_matmul( + input_32, + weight, + training=self.training, + ) + else: + logits = gate_detach_matmul( + input, weight, self.fuse_gate_detach_matmul + ) + + if self.use_token_type_bias: + assert token_type_ids is not None + bias = self.bias[token_type_ids] # [seq] + # logger.info(f"adding bias: {bias}") + logits = logits + bias + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + l_aux, + l_zloss, + ) = self.top2_gating(logits, correction_bias=correction_bias) + orthogonal_loss = self._cal_orthogonal_loss() + router_loss = ( + l_aux * self.moe_aux_loss_lambda + + l_zloss * self.moe_z_loss_lambda + + orthogonal_loss * self.moe_orthogonal_loss_lambda + ) + router_loss.stop_gradient = False + + combine_weights = combine_weights.cast(orig_dtype) + return ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + router_loss, + logits, + ) + + def get_capacity(self, num_tokens, cap_factor=None): + """ + return capacity + """ + 
num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + if cap_factor is not None: + cap = cap_factor + else: + if self.training: + cap = self.cap[0] + elif num_tokens < num_experts: # seqlen < num_expert + cap = self.cap[2] + else: + cap = self.cap[1] + # capacity = 2S/E + capacity = int(cap * num_tokens // num_experts) + assert ( + capacity > 0 + ), f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" + return capacity + + def top2_gating(self, logits, cap=None, correction_bias=None): + """ + Args: + logits: 形状为[batch, vocab_size]的logits,用于计算top2 gate。 + cap[Optional]: capacity-factor, if none, read from config + correction_bias[Optional]: used for aux-free router + + Returns: + tuple: + - capacity: 每个token可分发的最大数量。 + - dispatch_masks: 用于dispatching的mask。第一个元素是第一类token的mask;第二个元素是第二类token的mask。 + - combine_weights:用于combining的权重。第一个元素是第一类token的权重;第二个元素是第二类token的权重。 + - scatter_indexes: 用于scattering的索引。第一个元素是第一类token的索引;第二个元素是第二类token的索引。 + - loss_aux: aux loss。 + - loss_z: z loss。 + """ + # logger.info(f'gate-input: {logits}') + l_zloss = self._cal_z_loss(logits) + gates = self.act(logits) + + # gates has shape of SE + assert logits.ndim == 2, logits.shape + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + # capacity = 2S/E + capacity = self.get_capacity(logits.shape[0], cap) + + # Create a mask for 1st's expert per token + score_for_argmax = ( + gates + correction_bias.unsqueeze(0) + if correction_bias is not None + else gates + ) + indices1_s = paddle.argmax(score_for_argmax, axis=1) + mask1 = F.one_hot(indices1_s, num_classes=num_experts).cast( + paddle.int64 + ) # [0,1] + + l_aux = self._cal_aux_loss( + gates, mask1.sum(axis=0), self.num_experts_tensor + ) + # Create a mask for 2nd's expert per token using Gumbel-max trick + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + if self.training and not self.no_jitter: + gumbels = ( + -paddle.empty_like( + logits, + ) + .exponential_() + .log() + ) # ~Gumbel(0,1) + logits_w_noise = logits + gumbels + else: + logits_w_noise = logits + + logits_except1 = masked_fill( + logits_w_noise, mask1.cast(paddle.bool), float("-inf") + ) + score_for_argmax = ( + self.act(logits_except1) + correction_bias.unsqueeze(0) + if correction_bias is not None + else logits_except1 + ) + indices2_s_original = paddle.argmax(score_for_argmax, axis=1) + + if self.training and self.sinkhorn_2gate: + r = paddle.ones(num_tokens, "float32") / num_tokens + # c = paddle.ones(num_experts, "float32") / num_experts + # 非均匀c + c = capacity - mask1.cast("float32").sum(0) + c = paddle.maximum(c, paddle.zeros_like(c)) + c /= c.sum() + + pi, _ = compute_optimal_transport( + -logits_except1.cast("float32").detach(), + r, + c, + lam=self.sinkhorn_temp, + ) + pi = masked_fill(pi, mask1.cast(paddle.bool), float("-inf")) + indices2_s = paddle.argmax(pi, axis=1) + else: + indices2_s = indices2_s_original + + mask2 = F.one_hot(indices2_s, num_classes=self.num_experts).cast( + paddle.int64 + ) + + # Compute locations in capacity buffer + locations1 = ( + paddle.cumsum(mask1, axis=0) - 1 + ) # [0,1,1,0,1,0,0] -> [0,0,0,0,1,1,1,] + locations2 = paddle.cumsum(mask2, axis=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + # Remove locations outside capacity from mask + mask1 *= (locations1 < capacity).cast(paddle.int64) # [0,1,1,0,0,0,0] + mask2 *= (locations2 < capacity).cast(paddle.int64) + + # Store the capacity 
location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = (gates * mask1_float).sum(axis=-1) + gates2_s = (gates * mask2_float).sum(axis=-1) + # logger.info(f'gates1_s:{gates1_s} gates2_s:{gates2_s} logits:{logits}') + + if self.norm_gate_logits: + denom_s = gates1_s + gates2_s # [0.2, 0.3] + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=1e-6) + gates1_s /= denom_s + gates2_s /= denom_s + if self.training and self.expert_drop: + # log.debug(gates2_s) + gates2_s = paddle.where( + 2 * gates2_s < paddle.rand_like(gates2_s), + paddle.zeros_like(gates2_s), + gates2_s, + ) + + # Calculate combine_weights and dispatch_mask + gates1 = gates1_s.unsqueeze(1) * mask1_float + gates2 = gates2_s.unsqueeze(1) * mask2_float + + expert1_index = paddle.argmax(gates1, -1) + combine1_weight = paddle.max(gates1, -1, keepdim=True) + scatter1_index = expert1_index * capacity + locations1_s + scatter1_index = scatter1_index.cast("int64") + dispatch1_mask = combine1_weight.cast(paddle.bool).detach() + + expert2_index = paddle.argmax(gates2, -1) + combine2_weight = paddle.max(gates2, -1, keepdim=True) + scatter2_index = expert2_index * capacity + locations2_s + scatter2_index = scatter2_index.cast("int64") + dispatch2_mask = combine2_weight.cast(paddle.bool).detach() + # logger.info(f'expert-id: {expert1_index} vs {expert2_index}, mask:{mask1_float} vs {mask2_float}') + + return ( + capacity, + paddle.concat((dispatch1_mask, dispatch2_mask), 1), + paddle.concat((combine1_weight, combine2_weight), 1), + paddle.stack((scatter1_index, scatter2_index), 1), + l_aux, + l_zloss, + ) + + def _cal_aux_loss( + self, + gate_prob, + dispatch_mask, + num_experts=None, + use_group=None, + tokens_mask=None, + dispatch_tokens_mask=None, + ): + """ + 计算辅助损失 + + Args: + gate_prob (paddle.Tensor[local_seq, num_experts]): + dispatch_mask (paddle.Tensor[num_experts]): 每个 expert 被分配的 token 数(不考虑 token drop) + tokens_mask (paddle.Tensor[Seq]): 每个 MP 内 token-type-id + dispatch_tokens_mask (paddle.Tensor): AllGather 后的`tokens_mask` + Returns: + paddle.Tensor: 辅助损失值。 + + """ + if self.act is F.sigmoid: + gate_prob = gate_prob / gate_prob.sum(-1, keepdim=True) + + if self.use_correction_bias: + if tokens_mask is not None: + gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] + if gate_prob_this_modality.shape[0]: + _, top_idx = gate_prob_this_modality.topk( + k=self.config.moe_k, axis=-1 + ) + if int_bincount is not None: + dispatch_mask = int_bincount( + top_idx, 0, gate_prob.shape[-1], paddle.int64 + ) + else: + mask = paddle.zeros_like( + gate_prob_this_modality + ).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + dispatch_mask = paddle.sum( + mask.cast(paddle.int64), axis=0 + ) + else: + dispatch_mask = paddle.zeros( + gate_prob.shape[-1], dtype="int64" + ) + dist.stream.all_reduce( + dispatch_mask, + group=self.group, + use_calc_stream=True, + ) + else: + _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) + if int_bincount is not None: + dispatch_mask = int_bincount( + top_idx, 0, gate_prob.shape[-1], paddle.int64 + ) + else: + mask = paddle.zeros_like(gate_prob).put_along_axis( + top_idx, paddle.to_tensor(1.0), axis=1 + ) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + + if num_experts is None: + num_experts = self.num_experts_tensor + if use_group is None: + use_group = 
self.config.moe_group_experts + + if ( + moe_router_loss_ops is not None + and (tokens_mask is None or len(tokens_mask.shape) == 1) + and ( + tokens_mask is None + or tokens_mask.shape[0] == gate_prob.shape[0] + ) + and (gate_prob.shape[0] >= gate_prob.shape[1]) + and (not self.global_aux_loss) + and (gate_prob.dtype == paddle.float32) + ): + return CalAuxLossFunctor.apply( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + clip_min=1e-6, + ) + else: + return cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + self.global_aux_loss, + self.rank if self.global_aux_loss else None, + self.group if self.global_aux_loss else None, + ) + + +class TopKGateFused(Top2Gate): + """doc""" + + def forward( + self, + input: Tensor, + token_type_ids=None, + transform_weight=True, + ) -> tuple[Tensor, Tensor, Tensor]: # type: ignore + """ + Args: + input: paddle.Tensor, hidden-states of layer + token_type_ids: paddle.Tensor[Seqw], token_type_ids of input + transform_weight: bool, when using multimodal experts, perform `self.get_gate_weight` if specified + Returns: + paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights + paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask + Tuple[paddle.Tensor]: `GateOutput` + """ + capacity = self.get_capacity(input.shape[0]) + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + if False: + assert not self.fuse_gate_detach_matmul, "not supported on XPU" + input_32 = input.cast("float32") + logits = self.xpu_matmul( + input_32, + weight, + training=self.training, + ) + else: + logits = gate_detach_matmul( + input, weight, self.fuse_gate_detach_matmul + ) + if self.use_token_type_bias: + assert token_type_ids is not None + assert ( + token_type_ids.max() < self.bias.shape[0] + ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" + bias = self.bias[token_type_ids] # [seq] + logits = logits + bias + orthogonal_loss = None + # 正交 loss 拿到 moe-layer 里去计算 + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + + return logits, capacity, router_loss + + +class DeepEPTop2Gate(TopKGateFused): + """DeepEPTop2Gate""" + + def forward( + self, + input, + transform_weight=True, + global_gate_mask=None, + input_ids=None, + ): + """forward""" + + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + logits = gate_detach_matmul( + input, weight, self.fuse_gate_detach_matmul + ) + + if global_gate_mask is not None: + logits = logits + global_gate_mask + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + return logits, router_loss + + def _cal_aux_loss(self, gates, dispatch_mask, input_ids=None): + """ + Calculate auxiliary loss + + Args: + gates (paddle.Tensor): Represents the output probability of each expert. + The shape is [seq_len, num_experts] + dispatch_mask: (paddle.Tensor): Represents the number of tokens for each expert. + The shape is [num_experts] + topk_indices: + Returns: + paddle.Tensor: The value of auxiliary loss. 
+ + """ + assert ( + len(gates.shape) == 2 + ), "gates.shape must be [sequence_length, num_experts]" + if input_ids is not None: + # has_padding = (input_ids == 0).any() + assert ( + input_ids.shape[0] == gates.shape[0] + ), f"check input_ids shape {input_ids.shape}" + valid_mask = (input_ids != 0).astype(paddle.float32) + seqlen_float = valid_mask.sum().item() + gates = gates * valid_mask.unsqueeze(-1) + else: + seqlen_float = float(gates.shape[0]) + me = paddle.sum(gates, axis=0) / seqlen_float + ce = dispatch_mask.astype(gates.dtype).detach() / seqlen_float + + if self.global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=self.group) + dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + if seqlen_float == 0: + return paddle.to_tensor(0.0) + aux_loss = paddle.sum(me * ce) * float(self.num_experts) + return aux_loss + + def _cal_z_loss(self, logits) -> paddle.Tensor: + """ + Calculate the z loss. + + Args: + logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. + + Returns: + paddle.Tensor: The z loss value. + """ + l_zloss = paddle.logsumexp(logits, axis=1).square().mean() + return l_zloss + + def _cal_orthogonal_loss(self) -> paddle.Tensor: + """Gate weight orthogonal loss. + + Returns: + Paddle.Tensor: orthogonal loss + """ + weight = F.normalize(self.weight, axis=0) + orthogonal_loss = paddle.mean( + paddle.square( + paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts) + ) + ) + return orthogonal_loss diff --git a/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py b/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py new file mode 100644 index 00000000000000..e0a35a3f85233d --- /dev/null +++ b/test/legacy_test/test_incubate_build_src_rank_and_local_expert_id.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import unittest + +import numpy as np + +import paddle +from paddle.incubate.nn.functional import build_src_rank_and_local_expert_id + +logger = logging.getLogger(__name__) + + +class TestFusedCalculateAuxLoss(unittest.TestCase): + def test_build_src_rank_and_local_expert_id(self): + def orig_func(expert_num_global_list, num_local_experts): + send_rank_cpu = np.concatenate( # TOO SLOW!!! 
break every thing + [ + np.full([j], i // num_local_experts, dtype="int32") + for i, j in enumerate(expert_num_global_list) + ], + 0, + ) + local_expert_id_cpu = np.concatenate( + [ + np.full([j], i % num_local_experts, dtype="int32") + for i, j in enumerate(expert_num_global_list) + ], + 0, + ) + send_rank = paddle.to_tensor(send_rank_cpu) + local_expert_id = paddle.to_tensor(local_expert_id_cpu) + return send_rank, local_expert_id + + def fused_func( + expert_num_global_tensor, expert_num_global, num_local_experts + ): + return build_src_rank_and_local_expert_id( + expert_num_global_tensor, expert_num_global, num_local_experts + ) + + expert_num_global = np.random.randint( + 0, 512, size=[12 * 8], dtype="int32" + ) + expert_num_global_tensor = paddle.to_tensor( + expert_num_global, dtype="int64" + ) + + s1, l1 = orig_func(expert_num_global, 12) + s2, l2 = fused_func(expert_num_global_tensor, expert_num_global, 12) + assert ((s1 - s2) == 0).all(), (s1, s2) + assert ((l1 - l2) == 0).all(), (l1, l2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_incubate_expand_modality_expert_id.py b/test/legacy_test/test_incubate_expand_modality_expert_id.py new file mode 100644 index 00000000000000..9f1d41e49697fe --- /dev/null +++ b/test/legacy_test/test_incubate_expand_modality_expert_id.py @@ -0,0 +1,181 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
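The reference `orig_func` above spells out what the fused kernel is expected to produce: one (destination rank, local expert id) pair per routed token slot, derived from the per-expert token counts. A minimal NumPy-only restatement of that mapping with a worked example (the helper name is illustrative, not part of the API):

import numpy as np


def build_src_rank_and_local_expert_id_ref(expert_num_global_list, num_local_experts):
    # For each global expert i, emit count(i) entries: the expert lives on
    # rank i // num_local_experts and is the (i % num_local_experts)-th
    # expert on that rank. Illustrative reference only, not the fused kernel.
    send_rank, local_expert_id = [], []
    for i, count in enumerate(expert_num_global_list):
        send_rank.extend([i // num_local_experts] * count)
        local_expert_id.extend([i % num_local_experts] * count)
    return (
        np.array(send_rank, dtype="int32"),
        np.array(local_expert_id, dtype="int32"),
    )


# Worked example: 4 global experts, 2 local experts per rank,
# token counts per expert = [1, 2, 0, 3].
send, local = build_src_rank_and_local_expert_id_ref([1, 2, 0, 3], 2)
assert send.tolist() == [0, 0, 0, 1, 1, 1]
assert local.tolist() == [0, 1, 1, 1, 1, 1]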
+ +import unittest +from collections import namedtuple +from functools import partial + +from ernie_utils.moe_all_gather_layer import MOEAllGatherLayerV2 + +import paddle +import paddle.nn.functional as F +from paddle.incubate.nn.functional import expand_modality_expert_id + + +def fused_gate_logits_process_ref( + self, gate_logits_lm, gate_logits_mm, token_type_ids +): + """process gatelogits""" + top_k = self.k + num_expert_per_rank_per_modality = ( + gate_logits_lm.shape[-1] // self.config.moe_world_size + ) + + @paddle.no_grad() + def shift_ids(ids, modality_offset): + # 现在认为所以模态的 expert 数都一样 + rank = ids // num_expert_per_rank_per_modality + expert_id_in_rank = ids % num_expert_per_rank_per_modality + return ( + rank * (num_expert_per_rank_per_modality * 2) + + expert_id_in_rank + + modality_offset * num_expert_per_rank_per_modality + ) + + if self.group_experts: + gate_logits_lm = gate_logits_lm.reshape( + [gate_logits_lm.shape[0], top_k, -1] + ) + prob_lm = self.gate.act(gate_logits_lm) + weight_lm, expert_id_lm = prob_lm.topk(k=1, axis=-1) + weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) + expert_id_lm = expert_id_lm.reshape([gate_logits_lm.shape[0], -1]) + group_size = gate_logits_lm.shape[-1] + scale = paddle.arange(0, top_k * group_size, group_size).unsqueeze(0) + expert_id_lm = expert_id_lm + scale + else: + prob_lm = self.gate.act(gate_logits_lm) + weight_lm, expert_id_lm = prob_lm.topk(k=top_k, axis=-1) + if token_type_ids is not None: + expert_id_lm = shift_ids(expert_id_lm, 0) + expert_id_lm.stop_gradient = True + lm_weight_and_expert_id = paddle.concat( + [weight_lm, expert_id_lm.astype("float32")], -1 + ) + if token_type_ids is None: + return ( + lm_weight_and_expert_id, + prob_lm.reshape([prob_lm.shape[0], -1]), + None, + ) + + prob_mm = self.gate.act(gate_logits_mm) + weight_mm, expert_id_mm = prob_mm.topk(k=top_k, axis=-1) + + expert_id_mm = shift_ids(expert_id_mm, 1) + expert_id_mm.stop_gradient = True + + mm_weight_and_expert_id = paddle.concat( + [weight_mm, expert_id_mm.astype("float32")], -1 + ) + + token_type_ids_float = token_type_ids[:, None].astype("float32") + weight_and_expert = ( + (1 - token_type_ids_float) * lm_weight_and_expert_id + + token_type_ids_float * mm_weight_and_expert_id + ) + return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm + + +def test_expand_modality_expert_id(): + def expand_id_one( + expert_id, + num_expert_per_modality, + k, + group_size, + modality_offset, + is_group_expert, + ): + orig_shape = expert_id.shape + expert_id = expert_id.reshape([-1]) + xid = paddle.arange(len(expert_id)) + if is_group_expert: + eid = xid % k + expert_id += eid * group_size + + rank = expert_id // num_expert_per_modality + expert_id_in_rank = expert_id % num_expert_per_modality + ret = ( + rank * (num_expert_per_modality * 2) + + expert_id_in_rank + + modality_offset * num_expert_per_modality + ) + return ret.reshape(orig_shape) + + S, E, k = 100, 24, 3 + expert_id_mm = paddle.randint(0, 12, shape=[S, k]) + num_expert_per_rank_per_modality = E // 2 // 4 + group_size = E // 2 // k + print( + f"num_expert_per_rank_per_modality: {num_expert_per_rank_per_modality}" + ) + fused = expand_modality_expert_id( + expert_id_mm, num_expert_per_rank_per_modality, group_size, 1, True + ) + + nonfused = expand_id_one( + expert_id_mm, num_expert_per_rank_per_modality, k, group_size, 1, True + ) + # num_expert_per_rank_per_modality, group_size + assert (fused == nonfused).all().item() + + Config = namedtuple("Config", 
["moe_world_size"]) + Self = namedtuple( + "Self", + [ + "config", + "k", + "gate", + "group_experts", + "moe_statics", + "use_correction_bias", + ], + ) + Gate = namedtuple("Gate", ["act"]) + fake_gate = Gate(act=partial(F.softmax, axis=-1)) + fake_self = Self( + config=Config( + moe_world_size=8, + ), + k=k, + gate=fake_gate, + moe_statics=None, + use_correction_bias=False, + group_experts=True, + ) + + fake_logits = paddle.randn([S, E]) + fake_logits_mm = paddle.randn([S, E]) + token_type_ids = paddle.randint(0, 2, shape=[S]) + w_and_e, prob_lm, prob_mm = ( + MOEAllGatherLayerV2.fused_gate_logits_process_fused( + fake_self, fake_logits, fake_logits_mm, None + ) + ) + w_and_e_ref, prob_lm_ref, prob_mm_ref = fused_gate_logits_process_ref( + fake_self, fake_logits, fake_logits_mm, None + ) + assert (prob_lm == prob_lm_ref).all().item() + assert (w_and_e == w_and_e_ref).all().item() + w, e = w_and_e_ref.chunk(2, axis=-1) + + +class Test_expand_modality_expert_id_API(unittest.TestCase): + def test_dygraph(self): + test_expand_modality_expert_id() + + +if __name__ == "__main__": + + unittest.main() diff --git a/test/legacy_test/test_incubate_fused_loss.py b/test/legacy_test/test_incubate_fused_loss.py new file mode 100644 index 00000000000000..e6fe14a2d295f2 --- /dev/null +++ b/test/legacy_test/test_incubate_fused_loss.py @@ -0,0 +1,204 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import unittest + +import numpy as np +from ernie_utils.top2_gate import ( + cal_aux_loss_func, +) + +import paddle +import paddle.nn.functional as F +from paddle.incubate.nn.functional import cal_aux_loss + +logger = logging.getLogger(__name__) + + +class TestFusedCalculateAuxLoss(unittest.TestCase): + def setUp(self): + paddle.seed(42) + self.atol = 1e-10 + self.rtol = 1e-5 + + def run_and_check( + self, + gate_prob, + dispatch_mask, + tokens_mask=None, + dispatch_tokens_mask=None, + num_experts=48, + moe_k=6, + use_group=False, + ): + dispatch_mask_for_ref = dispatch_mask.detach() + dispatch_mask_for_test = dispatch_mask.detach() + input_for_ref = gate_prob.detach() + input_for_test = gate_prob.detach() + input_for_ref.stop_gradient = False + input_for_test.stop_gradient = False + + loss_ref = cal_aux_loss_func( + input_for_ref, + dispatch_mask_for_ref, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + ) + loss, _, _ = cal_aux_loss( + input_for_test, + dispatch_mask_for_test, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + 1e-6, + ) + loss_ref.backward() + loss.backward() + + np.testing.assert_equal(loss.shape, loss_ref.shape) + np.testing.assert_equal(loss.dtype, loss_ref.dtype) + np.testing.assert_equal( + input_for_ref.grad.shape, input_for_test.grad.shape + ) + np.testing.assert_equal( + input_for_ref.grad.dtype, input_for_test.grad.dtype + ) + np.testing.assert_allclose( + loss.astype("float32").numpy(), + loss_ref.astype("float32").numpy(), + atol=self.atol, + rtol=self.rtol, + ) + np.testing.assert_allclose( + input_for_test.grad.astype("float32").numpy(), + input_for_ref.grad.astype("float32").numpy(), + atol=self.atol, + rtol=self.rtol, + ) + + def run_single_case( + self, + seq_len, + expert_num=48, + g_num_experts=96, + moe_k=6, + ): + for use_group in [True, False]: + for use_tokens_mask in [True, False]: + for use_dispatch_tokens_mask in [True, False]: + paddle.seed(48) + gate_prob = paddle.randn([seq_len, expert_num]) + dispatch_mask = paddle.randint( + 0, seq_len, [expert_num] + ).astype("int64") + tokens_mask = ( + paddle.randint(0, 1, [seq_len]).astype(gate_prob.dtype) + if use_tokens_mask + else None + ) + dispatch_tokens_mask = ( + paddle.randint(0, 1, [seq_len * 2]).astype("bool") + if use_dispatch_tokens_mask + else None + ) + self.run_and_check( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + g_num_experts, + moe_k, + use_group, + ) + + def test_trivial_cases(self): + self.run_single_case(seq_len=1, expert_num=1) + self.run_single_case(seq_len=3, expert_num=2) + self.run_single_case(seq_len=13, expert_num=3) + self.run_single_case(seq_len=1024, expert_num=6) + self.run_single_case(seq_len=2048, expert_num=6) + self.run_single_case(seq_len=3005, expert_num=48) + self.run_single_case(seq_len=3005, expert_num=96) + self.run_single_case(seq_len=4096, expert_num=48) + self.run_single_case(seq_len=4096, expert_num=15) + self.run_single_case(seq_len=4096, expert_num=92) + self.run_single_case(seq_len=6000, expert_num=92) + self.run_single_case(seq_len=8192, expert_num=48) + self.run_single_case(seq_len=8192, expert_num=96) + self.run_single_case(seq_len=8477, expert_num=48) + self.run_single_case(seq_len=16 * 1024, expert_num=48) + self.run_single_case(seq_len=32 * 1024, expert_num=96) + self.run_single_case(seq_len=48 * 1024, expert_num=48) + self.run_single_case(seq_len=100 * 1024, expert_num=48) + self.run_single_case(seq_len=128 * 1024, expert_num=96) + 
self.run_single_case(seq_len=128 * 1024 + 478, expert_num=48) + self.run_single_case(seq_len=256 * 1024, expert_num=48) + self.run_single_case(seq_len=512 * 1024, expert_num=128) + + def run_special_case( + self, global_seq_len, seq_len, global_expert_num, expert_num, moe_k + ): + for use_group in [True, False]: + paddle.seed(48) + seq_len = 4096 + expert_num = 48 + gate_prob = F.softmax(paddle.randn([seq_len, expert_num]), axis=-1) + dispatch_mask = paddle.randint( + 0, seq_len, [seq_len, expert_num] + ).astype("int64") + tokens_mask = paddle.randint(0, 1, [seq_len]).astype( + gate_prob.dtype + ) + dispatch_tokens_mask = paddle.randint( + 0, 1, [global_seq_len] + ).astype("bool") + self.run_and_check( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + global_expert_num, + moe_k, + use_group, + ) + + def test_special_cases(self): + self.run_special_case(123, 156, 4, 8, 2) + self.run_special_case(123, 123 * 2, 4, 8, 2) + self.run_special_case(128, 128, 4, 8, 2) + self.run_special_case(1024, 4096, 4, 8, 2) + self.run_special_case(2048, 4096, 4, 8, 2) + self.run_special_case(2048, 9648, 4, 16, 2) + self.run_special_case(4096, 7546, 4, 8, 2) + self.run_special_case(4096, 4096 * 2, 4, 8, 2) + self.run_special_case(4096, 4096 * 2, 48, 48 * 2, 6) + self.run_special_case(5001, 5555, 48, 48 * 2, 6) + self.run_special_case(4096, 4096 * 8, 48, 48 * 8, 2) + self.run_special_case(4565, 4565 * 8, 47, 47 * 8, 4) + self.run_special_case(8192, 12288, 47, 47 * 8, 4) + self.run_special_case(8192, 8192 * 8, 48, 48 * 6, 16) + self.run_special_case(8192, 8192 * 16, 48, 48 * 16, 32) + self.run_special_case(8192, 8192 * 16, 123, 123 * 16, 111) + self.run_special_case(10580, 10580 * 16, 52, 52 * 16, 78) + self.run_special_case(512 * 1024, 1024 * 1024, 123, 123 * 16, 111) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_incubate_fused_rmsnorm_ext.py b/test/legacy_test/test_incubate_fused_rmsnorm_ext.py new file mode 100644 index 00000000000000..89dc90aeb2f18f --- /dev/null +++ b/test/legacy_test/test_incubate_fused_rmsnorm_ext.py @@ -0,0 +1,129 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
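For reference, the quantity compared by `cal_aux_loss` and `cal_aux_loss_func` above is the usual load-balancing auxiliary loss, the same form computed in `_cal_aux_loss` of `DeepEPTop2Gate` earlier in this diff. A minimal sketch that ignores the token masks, expert groups, and global all-gather handled by the full implementations:

import paddle
import paddle.nn.functional as F


def simple_aux_loss(gate_prob, dispatch_mask, num_experts):
    # gate_prob:     [seq_len, num_experts] router probabilities
    # dispatch_mask: [num_experts] number of tokens routed to each expert
    # Simplified sketch; the fused op additionally handles masks and groups.
    seqlen = float(gate_prob.shape[0])
    me = paddle.sum(gate_prob, axis=0) / seqlen          # mean probability per expert
    ce = dispatch_mask.astype(gate_prob.dtype) / seqlen  # dispatch fraction per expert
    return paddle.sum(me * ce) * float(num_experts)


gate_prob = F.softmax(paddle.randn([16, 4]), axis=-1)
dispatch_mask = paddle.to_tensor([4, 4, 4, 4], dtype="float32")
loss = simple_aux_loss(gate_prob, dispatch_mask, num_experts=4)  # close to 1.0 when balanced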
+ +import unittest + +import numpy as np + +import paddle +from paddle.incubate.nn.functional import fused_rms_norm_ext + +# 假设 fused_rms_norm_ext 已经被导入 +# from your_module import fused_rms_norm_ext + + +class TestFusedRMSNorm(unittest.TestCase): + def setUp(self): + # 设置随机种子以确保结果可复现 + paddle.seed(2023) + np.random.seed(2023) + + def rms_norm_reference(self, x, scale, bias=None, epsilon=1e-5): + """ + 使用 Paddle 原生操作实现 RMS Normalization 作为参考 + """ + # 计算均方根 + variance = paddle.mean(paddle.square(x), axis=-1, keepdim=True) + # 计算 RMS + rms = paddle.sqrt(variance + epsilon) + # 归一化 + y = x / rms + # 应用缩放 + y = y * scale.reshape([1, -1]) + # 应用偏置(如果有) + if bias is not None: + y = y + bias.reshape([1, -1]) + + # 返回归一化后的张量、均值(RMS Norm 中为0)和逆标准差 + return y, (1.0 / rms).squeeze(-1) + + def test_2d_input(self): + # 测试 2D 输入 + rows, cols = 32, 64 + x = paddle.randn([rows, cols]) + scale = paddle.randn([cols]) + + # 使用我们的实现 + y_fused, invvar_fused = fused_rms_norm_ext(x, scale) + + # 使用参考实现 + y_ref, invvar_ref = self.rms_norm_reference(x, scale) + + # 验证结果 + np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 + ) + + def test_without_bias(self): + # 测试没有偏置的情况 + rows, cols = 32, 64 + x = paddle.randn([rows, cols]) + scale = paddle.randn([cols]) + + # 使用我们的实现 + y_fused, invvar_fused = fused_rms_norm_ext(x, scale) + + # 使用参考实现 + y_ref, invvar_ref = self.rms_norm_reference(x, scale) + + # 验证结果 + np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 + ) + + def test_backward(self): + # 测试反向传播 + rows, cols = 16, 32 + x = paddle.randn([rows, cols], dtype='float32') + x.stop_gradient = False + scale = paddle.randn([cols], dtype='float32') + scale.stop_gradient = False + + # 前向传播 + y_fused, invvar = fused_rms_norm_ext(x, scale) + + # 计算损失并反向传播 + loss = paddle.mean(y_fused) + loss.backward() + + # 获取梯度 + x_grad_fused = x.grad.clone() + scale_grad_fused = scale.grad.clone() + + # 重置梯度 + x.clear_gradient() + scale.clear_gradient() + + # 使用参考实现 + y_ref, invvar_ref = self.rms_norm_reference(x, scale) + loss_ref = paddle.mean(y_ref) + loss_ref.backward() + + # 获取参考梯度 + x_grad_ref = x.grad + scale_grad_ref = scale.grad + + # 验证梯度 + np.testing.assert_allclose( + x_grad_fused, x_grad_ref, rtol=1e-4, atol=1e-4 + ) + np.testing.assert_allclose( + scale_grad_fused, scale_grad_ref, rtol=1e-4, atol=1e-4 + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_incubate_int_bincount.py b/test/legacy_test/test_incubate_int_bincount.py new file mode 100644 index 00000000000000..46f43cf791c35b --- /dev/null +++ b/test/legacy_test/test_incubate_int_bincount.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle +from paddle.incubate.nn.functional import int_bincount + + +class TestIntBincount(unittest.TestCase): + def setUp(self): + paddle.set_device('gpu') + + def test_basic(self): + x = paddle.to_tensor([1, 2, 3, 1, 2, 3], dtype=paddle.int32) + out = int_bincount(x, low=1, high=4, dtype=paddle.int32) + expected = np.array([2, 2, 2, 0]) + np.testing.assert_array_equal(out.numpy(), expected) + + def test_empty_input(self): + x = paddle.to_tensor([], dtype=paddle.int32) + out = int_bincount(x, low=0, high=10, dtype=paddle.int32) + self.assertEqual(out.shape, [11]) + self.assertEqual(out.sum().item(), 0) + + def test_different_dtypes(self): + x = paddle.to_tensor([1, 3, 5, 3, 1], dtype=paddle.int64) + out = int_bincount(x, low=1, high=6, dtype=paddle.int64) + expected = np.array([2, 0, 2, 0, 1, 0]) + np.testing.assert_array_equal(out.numpy(), expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_incubate_moe_combine.py b/test/legacy_test/test_incubate_moe_combine.py new file mode 100644 index 00000000000000..2c765e13671230 --- /dev/null +++ b/test/legacy_test/test_incubate_moe_combine.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
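Judging from the expected outputs above, `int_bincount(x, low, high, dtype)` counts how often each integer in the closed range `[low, high]` occurs, yielding `high - low + 1` bins. A NumPy equivalent consistent with these cases (an assumption drawn from the tests, not from the kernel's documentation):

import numpy as np


def int_bincount_ref(x, low, high):
    # Count occurrences of each integer in [low, high]; inputs are assumed
    # to already lie in that range, as in the tests above. Illustrative only.
    x = np.asarray(x, dtype="int64")
    return np.bincount(x - low, minlength=high - low + 1)


assert int_bincount_ref([1, 2, 3, 1, 2, 3], low=1, high=4).tolist() == [2, 2, 2, 0]
assert int_bincount_ref([1, 3, 5, 3, 1], low=1, high=6).tolist() == [2, 0, 2, 0, 1, 0]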
+ +import os +import random +import unittest + +import numpy as np +from ernie_utils.moe_layer_uneven import GateCombine + +import paddle +import paddle.nn.functional as F +from paddle.incubate.nn.functional import moe_combine + +os.environ["FLAGS_flash_attn_version"] = "v1" +os.environ["FLAGS_cudnn_deterministic"] = "1" +os.environ["FLAGS_embedding_deterministic"] = "1" + + +def combining(x, combine_weights, scatter_index, hard_gate=False): + """ + Args: + x: Tensor[seq, dim] + combine_weights: [seq, k] + scatter_index: ** [seq, k] ** + + Returns: + y: Tensor[s, dim] + """ + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + if hard_gate: + return x_gatherd.squeeze(-2) + # logger.info(f'combinning: {combine_weights}') + y = (combine_weights.unsqueeze(-1) * x_gatherd).sum(1) + # y = paddle.matmul(combine_weights.unsqueeze(1), x_gatherd).squeeze() # [s,1,k] @ [s,k,dim] -> [s,1,dim] + return y + + +def baseline_result( + x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy +): + """baseline_result""" + scatter_index = paddle.to_tensor(scatter_index_numpy) + x = paddle.to_tensor(x_numpy).cast("float32") + x.stop_gradient = False + + combine_weights = paddle.to_tensor(combine_weights_numpy).cast("float32") + combine_weights.stop_gradient = False + + scatter_index = paddle.to_tensor(scatter_index_numpy) + grad = paddle.to_tensor(grad_numpy).cast("float32") + + y = combining(x, combine_weights, scatter_index) + paddle.autograd.backward([y], [grad], True) + return [x.grad, combine_weights.grad, y] + + +def test_moe_combine( + x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy +): + """baseline_result""" + x = paddle.to_tensor(x_numpy).cast("float32") + x.stop_gradient = False + + combine_weights = paddle.to_tensor(combine_weights_numpy).cast("float32") + combine_weights.stop_gradient = False + + scatter_index = paddle.to_tensor(scatter_index_numpy).cast("int32") + grad = paddle.to_tensor(grad_numpy).cast("float32") + + y = GateCombine.apply(x, combine_weights, scatter_index) + paddle.autograd.backward([y], [grad], True) + # grad.backward() + return [x.grad, combine_weights.grad, y] + + +def gen_test_case(S, K, Dim, capacity_factor, seed=1234): + """gen_test_case""" + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + x_numpy = np.random.rand(int(S * capacity_factor), Dim).astype(np.float32) + combine_weights_numpy = np.random.rand(S, K).astype(np.float32) + scatter_index_numpy = np.random.permutation(max(x_numpy.shape[0], S * K))[ + : S * K + ].astype("int64") + scatter_index_numpy = scatter_index_numpy.reshape([S, K]) + + combine_weights_numpy[scatter_index_numpy >= x_numpy.shape[0]] = 0 + scatter_index_numpy[scatter_index_numpy >= x_numpy.shape[0]] = 0 + grad_numpy = np.random.randn(S, Dim).astype(np.float32) + return x_numpy, combine_weights_numpy, scatter_index_numpy, grad_numpy + + +def testing(test_case): + """testing""" + [bl_x_grad, bl_combine_weights_grad, bl_y] = baseline_result(*test_case) + [fused_x_grad, fused_combine_weights_grad, fused_y] = test_moe_combine( + *test_case + ) + np.testing.assert_allclose( + fused_y.astype("float32").numpy(), + bl_y.astype("float32").numpy(), + err_msg="fwd precision not pass", + rtol=1e-6, + ) + np.testing.assert_allclose( + fused_x_grad.astype("float32").numpy(), + bl_x_grad.astype("float32").numpy(), + rtol=1e-6, + err_msg="bwd grad precision not pass", + ) + np.testing.assert_allclose( + fused_combine_weights_grad.astype("float32").numpy(), + bl_combine_weights_grad.astype("float32").numpy(), + rtol=1e-6, 
+    )
+
+
+class TestFused(unittest.TestCase):
+    @unittest.skipIf(moe_combine is None, "test_moe_combine not installed")
+    def test_cap_lt_2(
+        self,
+    ):
+        """
+        Check numerical parity between the fused op and the reference implementation.
+
+        Args:
+            None.
+
+        Returns:
+            NoneType: returns None when the test passes; raises an AssertionError otherwise.
+
+        """
+        testing(gen_test_case(S=1024, K=2, Dim=4096, capacity_factor=1.8))
+
+    @unittest.skipIf(moe_combine is None, "test_moe_combine not installed")
+    def test_cap_eq_2(
+        self,
+    ):
+        """
+        Check numerical parity between the fused op and the reference implementation.
+
+        Args:
+            None.
+
+        Returns:
+            NoneType: returns None when the test passes; raises an AssertionError otherwise.
+
+        """
+        testing(gen_test_case(S=1024, K=2, Dim=4096, capacity_factor=2))
+
+    @unittest.skipIf(moe_combine is None, "test_moe_combine not installed")
+    def test_cap_gt_2(
+        self,
+    ):
+        """
+        Check numerical parity between the fused op and the reference implementation.
+
+        Args:
+            None.
+
+        Returns:
+            NoneType: returns None when the test passes; raises an AssertionError otherwise.
+
+        """
+        testing(gen_test_case(S=1024, K=2, Dim=4096, capacity_factor=2.2))
+
+    @unittest.skipIf(moe_combine is None, "test_moe_combine not installed")
+    def test_k_gt_2(
+        self,
+    ):
+        """
+        Check numerical parity between the fused op and the reference implementation.
+
+        Args:
+            None.
+
+        Returns:
+            NoneType: returns None when the test passes; raises an AssertionError otherwise.
+
+        """
+        testing(gen_test_case(S=1024, K=8, Dim=4096, capacity_factor=2))
+
+
+if __name__ == "__main__":
+
+    unittest.main()
diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py
new file mode 100644
index 00000000000000..0a19402605211d
--- /dev/null
+++ b/test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py
@@ -0,0 +1,222 @@
+# ruff: noqa: C419
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
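The `combining` reference in the moe_combine tests above reduces to the contraction y[s] = sum_k combine_weights[s, k] * x[scatter_index[s, k]], which is what `moe_combine` is checked against. A small worked example of that contraction in plain NumPy (shapes chosen only for illustration):

import numpy as np

x = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])    # [num_slots, dim]
combine_weights = np.array([[0.6, 0.4], [1.0, 0.0]])  # [seq, k]
scatter_index = np.array([[0, 2], [1, 1]])            # [seq, k]

# y[s] = sum_k combine_weights[s, k] * x[scatter_index[s, k]]
y = (combine_weights[..., None] * x[scatter_index]).sum(axis=1)
assert np.allclose(y, [[1.8, 1.8], [2.0, 2.0]])  # 0.6*1 + 0.4*3 = 1.8; 1.0*2 = 2.0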
+ +import unittest + +import paddle +from paddle.incubate.nn.functional import ( + moe_gate_dispatch, + moe_gate_dispatch_partial_nosoftmaxtopk, +) + + +def test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op(): + + s, d, e = 4, 100, 8 + k, cap = 4, 3 + local_expert_num = 2 + + # x = paddle.randn([s, d]) + # gate_logits = paddle.randn([s, e]) + x = paddle.arange(1, s + 1).unsqueeze(-1).expand([s, d]).astype("bfloat16") + x_ = x.clone().detach() + + t = ( + (paddle.arange(0, e)).unsqueeze(0) + + paddle.arange(0, -s, -1).unsqueeze(-1) + ) % e + gate_logits = (1 / (t + 1)).astype("float32") + # gate_logits = F.softmax(paddle.randn([s,e]),-1).astype('float32') + gate_logits_ = gate_logits.clone().detach() + s = x.shape[0] + d = x.shape[1] + e = gate_logits.shape[1] + x.stop_gradient = False + x_.stop_gradient = False + gate_logits.stop_gradient = False + gate_logits_.stop_gradient = False + print(f"gate_logits:{gate_logits}") + + def check_ascend(index_rev, chunks): + for idx in index_rev.split(chunks.tolist()): + if len(idx) > 2: + assert (paddle.diff(idx) >= 0).all(), (index_rev,) + + ys, comm, scatter_idx = [], [], [] + for ilocal_expert in range(0, e, local_expert_num): + combine_weihgts, expert_id = gate_logits.topk(k=k, axis=1) + ( + y, + combine_weihgts, + scatter_index, + scatter_index_rev, + expert_offset, + expert_num_local, + ) = moe_gate_dispatch_partial_nosoftmaxtopk( + x, + combine_weihgts, + expert_id.astype("int32"), + k=k, + capacity=cap, + num_experts=gate_logits.shape[-1], + use_pad=False, + expert_start_index=ilocal_expert, + expert_end_index=ilocal_expert + local_expert_num, # k # cap + reverse_token_drop=False, + ) + check_ascend(scatter_index_rev, expert_num_local) + print(f"y:{y.mean(-1)}") + print(f"combine_weihgts:{combine_weihgts}") + print(f"expert_num_local:{expert_num_local}") + print(f"scatter_index:{scatter_index.transpose([1,0])}") + print(f"scatter_index_rev:{scatter_index_rev}") + + ys.append(y) + comm.append(combine_weihgts) + scatter_idx.append(scatter_index) + + comm_sum = paddle.stack(comm).sum(0) + ys_sum = paddle.concat(ys) + + y_, combine_weihgts_, scatter_index_, expert_offset_, expert_id_ = ( + moe_gate_dispatch( + x_, + gate_logits_, + None, + k=k, + capacity=cap, + use_pad=True, # k # cap + ) + ) + valid_y = y_.sum(-1) > 0.0 + y_2 = y_[valid_y].squeeze() + + print( + f""" + y: {ys_sum.astype("float32").mean(axis=-1)} + y_: {y_2.astype("float32").mean(axis=-1)} + + comm-weight: {comm_sum} + comm-weight_: {combine_weihgts_} + + expert_id:{expert_id} + scatter_index:{scatter_index} + scatter_index_rev: {scatter_index_rev} + expert_num_global:{expert_offset} + expert_num_local:{expert_num_local} + """ + ) + + print("<<< begin backward>>>") + + assert combine_weihgts_.shape == combine_weihgts.shape, ( + combine_weihgts_.shape, + combine_weihgts.shape, + ) + + dysum, dcombine_weights_sum = paddle.ones_like(ys_sum), paddle.randn( + comm_sum.shape + ).astype(comm_sum.dtype) + dy_, dcombine_weights_ = paddle.ones_like(y_), paddle.ones_like( + combine_weihgts_ + ) + dy_[~valid_y] = 0 + + y_shapes = [len(y) for y in ys] + for dyy, yy, commm in zip( + paddle.split(dysum, y_shapes), + ys, + comm, + ): + print(f"dyy:{dyy.shape}, {yy.shape} {commm.shape}") + paddle.autograd.backward([yy, commm], [dyy, dcombine_weights_sum]) + print(x.grad.astype("float32").mean(axis=-1)) + print(f"bwd original:{y_.shape} {dy_.shape}") + paddle.autograd.backward([y_, combine_weihgts_], [dy_, dcombine_weights_]) + + print(x_.grad.astype("float32").mean(axis=-1)) + + print( + f""" + 
x: {x.grad.astype('float32').mean(axis=-1)} + x_: {x_.grad.astype('float32').mean(axis=-1)} + """ + ) + + +def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop(): + + S, E, D = 3, 4, 3 + k = 2 + capacity = 2 + x = (paddle.arange(S) + 1).unsqueeze(-1).expand([S, D]).astype("bfloat16") + cw = paddle.randn([S, k]) + eid = paddle.to_tensor( + [[0, 1], [0, 1], [0, 2]], dtype="int32" + ) # 1 # 2 # 3 + ( + y, + cw_, + idx, + idx_rev, + num_ex_global, + num_ex_local, + ) = moe_gate_dispatch_partial_nosoftmaxtopk( + x, cw, eid, k, capacity, E, False, 0, 2, reverse_token_drop=True + ) + + y0, y1 = y.split([i for i in num_ex_local.tolist() if i > 0]) + assert y0[:, 0].astype("int32").tolist() == [2, 3], y0[:, 0] + assert y1[:, 0].astype("int32").tolist() == [1, 2] + + +def test_moe_ops_partial_nosoftmax_topk_empty_output(): + + S, E, D = 3, 4, 3 + k = 2 + capacity = 2 + x = (paddle.arange(S) + 1).unsqueeze(-1).expand([S, D]).astype("bfloat16") + cw = paddle.randn([S, k]) + paddle.device.synchronize() + eid = paddle.to_tensor( + [[0, 1], [0, 1], [0, 2]], dtype="int32" + ) # 1 # 2 # 3 + ( + y, + cw_, + idx, + idx_rev, + num_ex_global, + num_ex_local, + ) = moe_gate_dispatch_partial_nosoftmaxtopk( + x, cw, eid, k, capacity, E, False, 3, 4, reverse_token_drop=True + ) + assert all([i == 0 for i in num_ex_local.tolist()]), num_ex_local + + +class TestAddition(unittest.TestCase): + + def test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op(self): + test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op() + + def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop(self): + test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop() + + def test_moe_ops_partial_nosoftmax_topk_empty_output(self): + test_moe_ops_partial_nosoftmax_topk_empty_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py new file mode 100644 index 00000000000000..56d9ddd397a776 --- /dev/null +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute.py @@ -0,0 +1,206 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
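Both dispatch variants exercised above enforce a per-expert capacity, and the drop rule follows the cumsum bookkeeping used in `top2_gating` earlier in this diff: a token keeps its slot only while fewer than `capacity` earlier tokens were assigned to the same expert. A compact NumPy sketch of that rule for a one-hot assignment mask (illustrative only):

import numpy as np


def keep_within_capacity(mask, capacity):
    # mask: [seq, num_experts] one-hot (0/1) expert assignment.
    # A token keeps its slot only if fewer than `capacity` earlier tokens
    # were already assigned to the same expert. Sketch, not the fused kernel.
    locations = np.cumsum(mask, axis=0) - 1  # running slot index per expert
    kept = mask * (locations < capacity)
    return kept, locations * kept            # kept mask, buffer slot per kept token


mask = np.array([[1, 0], [1, 0], [1, 0], [0, 1]])  # three tokens want expert 0
kept, slots = keep_within_capacity(mask, capacity=2)
assert kept[:, 0].tolist() == [1, 1, 0, 0]         # the third token is dropped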
+ +import os +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.incubate.nn.functional import ( + moe_gate_dispatch, + moe_gate_dispatch_permute, +) + +os.environ["FLAGS_flash_attn_version"] = "v1" +os.environ["FLAGS_cudnn_deterministic"] = "1" +os.environ["FLAGS_embedding_deterministic"] = "1" + + +class TestFused(unittest.TestCase): + + def test_moe_ops(self): + """ + test `moe-ops` w/ bias + """ + S, E, D = 8192, 64, 128 + k = 4 + x = paddle.randn([S, D], dtype="bfloat16") + gate_logits = paddle.randn([S, E], dtype="float32") + x_ = x.clone() + gate_logits_ = gate_logits.clone() + x.stop_gradient = True + x_.stop_gradient = True + gate_logits.stop_gradient = True + gate_logits_.stop_gradient = True + bias = paddle.zeros([E], dtype="float32") + cap = 512 + + y, combine_weihgts, scatter_index, expert_offset_, expert_id_ = ( + moe_gate_dispatch( + x, + gate_logits, + None, + k=k, + capacity=cap, + use_pad=True, # k # cap + ) + ) + + y_, combine_weihgts_, scatter_index_, expert_offset_, expert_id_ = ( + moe_gate_dispatch( + x_, + gate_logits_, + bias + 1, # +1也不会破坏路由结果 + k=k, + capacity=cap, + use_pad=True, # k # cap + ) + ) + bias_unbalanced = bias.clone() + bias_unbalanced[0] += 1 + ( + y__, + combine_weihgts__, + scatter_index__, + expert_offset__, + expert_id__, + ) = moe_gate_dispatch( + x_, + gate_logits_, + bias_unbalanced, + k=k, + capacity=cap, + use_pad=True, # k # cap + ) + np.testing.assert_equal( + y.astype("float32").numpy(), + y_.astype("float32").numpy(), + err_msg="incubate w bias not match", + ) + # bias 不影响 prob 概率 + np.testing.assert_equal( + combine_weihgts.astype("float32").numpy(), + combine_weihgts_.astype("float32").numpy(), + err_msg="incubate w bias not match", + ) + np.testing.assert_( + ( + y.astype("float32").numpy(0) != y__.astype("float32").numpy() + ).any(), + ) + + +class TestDispatchPermute(unittest.TestCase): + def get_detached_input(self, input, prob): + ret_input = input.detach() + ret_prob = prob.detach() + ret_input.stop_gradient = input.stop_gradient + ret_prob.stop_gradient = prob.stop_gradient + return ret_input, ret_prob + + def get_stage_input_list(self, x, world_size, stage): + print(world_size, stage, x.shape) + x = x.reshape([world_size * stage, -1, x.shape[-1]]) + stage_input_list = [] + x_list = paddle.split(x, num_or_sections=(world_size * stage), axis=0) + for stage_id in range(stage): + stage_input_list.append( + paddle.unsqueeze( + paddle.concat(x_list[stage_id::stage], axis=0), axis=0 + ) + ) + stage_input_list = paddle.concat(stage_input_list, axis=0) + return stage_input_list + + def test_moe_permute_ops(self): + paddle.seed(2025) + + test_cases = [ + (8, 4, 2), + (64, 16, 32), + (1024, 1024, 1024), + (8, 2, 4), + (4096, 4096, 4096), + ] + cases = list(zip(*test_cases)) + for _, case in enumerate(cases): + world_size, num_experts, num_tokens, k, hidden_size = case + capacity = num_tokens // k + stages = num_experts // world_size + + input = paddle.randn([num_tokens, hidden_size], dtype="float32") + prob_logits = paddle.randn( + [num_tokens, num_experts], dtype="float32" + ) + prob = F.softmax(prob_logits, axis=-1) + input.stop_gradient = False + prob.stop_gradient = False + + compat_args = (None,) + + ref_input, ref_prob = self.get_detached_input(input, prob) + ( + ref_dispatched_input, + ref_combine_weights_unnorm, + ref_scatter_index, + ref_dispatch_mask, + _, + ) = moe_gate_dispatch( + ref_input, + ref_prob, + *compat_args, + k=k, + capacity=capacity, + use_pad=True, + ) + + 
ref_stage_input_list = self.get_stage_input_list( + ref_dispatched_input, world_size, stages + ) + + test_input, test_prob = self.get_detached_input(input, prob) + ( + test_dispatched_input, + test_combine_weights_unnorm, + test_scatter_index, + test_dispatch_mask, + _, + ) = moe_gate_dispatch_permute( + test_input, + test_prob, + *compat_args, + k=k, + capacity=capacity, + world_size=world_size, + ) + + np.testing.assert_equal( + test_dispatched_input.shape, + ref_stage_input_list.shape, + err_msg="moe_permute_ops not match", + ) + np.testing.assert_equal( + test_dispatched_input._md5sum(), + ref_stage_input_list._md5sum(), + err_msg="moe_permute_ops not match", + ) + + +if __name__ == "__main__": + + unittest.main() diff --git a/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py new file mode 100644 index 00000000000000..bf03ffa20d4c1d --- /dev/null +++ b/test/legacy_test/test_incubate_moe_gate_dispatch_w_permute_bwd.py @@ -0,0 +1,175 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.incubate.nn.functional import ( + moe_gate_dispatch, + moe_gate_dispatch_permute, +) + +batch_size = 4 +hidden_size = 2 +k = 16 +capacity = 2 +num_experts = 16 + +world_size = 2 + + +class TestLayer(paddle.nn.Layer): + def forward(self, x, gate_prob, k, capacity): + y, combine_weights, scatter_index, expert_offset, expert_id = ( + moe_gate_dispatch(x, gate_prob, None, k, capacity, True) + ) + return y, combine_weights, scatter_index, expert_offset, expert_id + + +class TestLayerPermute(paddle.nn.Layer): + def forward(self, x, gate_prob, k, capacity): + y, combine_weights, scatter_index, expert_offset, expert_id = ( + moe_gate_dispatch_permute( + x, gate_prob, None, k, capacity, world_size=world_size + ) + ) + return y, combine_weights, scatter_index, expert_offset, expert_id + + +def check_backward_correctness(layer_cls): + paddle.seed(1024) + + dtype = "bfloat16" + layer = layer_cls() + input = paddle.randn([batch_size, hidden_size]) + + gate_weight = paddle.randn([hidden_size, num_experts]) + logits = paddle.matmul(input, gate_weight) + gate_prob = F.softmax(logits, axis=-1) + print(f"gate_prob: {gate_prob}") + + input = paddle.cast(input, "bfloat16") + input.stop_gradient = False + gate_prob.stop_gradient = False + + output, combine_weights, scatter_index, expert_offset, expert_id = layer( + input, gate_prob, k, capacity + ) + + print(f"output: {output}") + print(f"combine_weights: {combine_weights}") + print(f"scatter_index: {scatter_index}") + print(f"expert_offset: {expert_offset}") + print(f"expert_id: {expert_id}") + + # output_g = paddle.randn(output.shape).astype(output.dtype) + # combine_weights_g = paddle.randn(combine_weights.shape).astype(combine_weights.dtype) + output_g = paddle.ones_like(output) + combine_weights_g = 
paddle.ones_like(combine_weights) + print(f"output_g: {output_g}") + print(f"combine_weights_g: {combine_weights_g}") + + paddle.autograd.backward( + tensors=[output, combine_weights], + grad_tensors=[output_g, combine_weights_g], + ) + # 数值估算 + epsilon = 0.005 + input_numpy = input.detach().astype("float32").numpy() + num_grad = paddle.zeros_like(input) + flattened = num_grad.reshape([-1]) + + for i in range(input.numel()): + input_pos = input_numpy.copy() + input_neg = input_numpy.copy() + input_pos.flat[i] += epsilon + input_neg.flat[i] -= epsilon + + output_pos, _, _, _, _ = layer( + paddle.to_tensor(input_pos), gate_prob, k, capacity + ) + output_neg, _, _, _, _ = layer( + paddle.to_tensor(input_neg), gate_prob, k, capacity + ) + + ''' + flattened[i] = (output_pos.astype("float32").numpy() - output_neg.astype("float32").numpy()).sum() / ( + 2 * epsilon + ) + ''' + grad_value = (output_pos - output_neg).sum() / (2 * epsilon) + flattened[i] = grad_value + + flattened = flattened.reshape(input.shape) + + print(f"input gradient: {input.grad}") + print(f"numerical gradient: {flattened}") + np.testing.assert_allclose( + input.grad.astype("float32").numpy(), + flattened.astype("float32").numpy(), + rtol=1e-5, + atol=0, + ) + + # 数值估算 gate_prob + epsilon = 0.0005 + gate_prob_numpy = gate_prob.detach().astype("float32").numpy() + num_grad = paddle.zeros_like(gate_prob) + flattened = num_grad.reshape([-1]) + + for i in range(gate_prob.numel()): + input_pos = gate_prob_numpy.copy() + input_neg = gate_prob_numpy.copy() + input_pos.flat[i] += epsilon + input_neg.flat[i] -= epsilon + + _, output_pos, _, _, _ = layer( + input, paddle.to_tensor(input_pos), k, capacity + ) + _, output_neg, _, _, _ = layer( + input, paddle.to_tensor(input_neg), k, capacity + ) + + grad_value = paddle.to_tensor( + (output_pos.numpy() - output_neg.numpy()).sum() / (2 * epsilon) + ) + flattened[i] = grad_value + + flattened = flattened.reshape(gate_prob.shape) + + print(f"gate_prob gradient: {gate_prob.grad}") + print(f"numerical gradient: {flattened}") + np.testing.assert_allclose( + gate_prob.grad.astype("float32").numpy(), + flattened.astype("float32").numpy(), + rtol=1e-4, + atol=0, + ) + + +class TestFused(unittest.TestCase): + def test_moe_backward(self): + check_backward_correctness(TestLayer) + + def test_moe_permute_backward(self): + check_backward_correctness(TestLayerPermute) + + +if __name__ == "__main__": + unittest.main()
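The backward tests above validate analytic gradients with central finite differences; the same idea in a self-contained helper (the function name and the quadratic sanity check are illustrative, not part of the test suite):

import numpy as np

import paddle


def numerical_grad(fn, x, epsilon=1e-3):
    # Central-difference estimate of d(sum(fn(x)))/dx, element by element,
    # mirroring the estimation loop in check_backward_correctness above.
    x_np = x.detach().astype("float32").numpy()
    grad = np.zeros_like(x_np)
    for i in range(x_np.size):
        pos, neg = x_np.copy(), x_np.copy()
        pos.flat[i] += epsilon
        neg.flat[i] -= epsilon
        f_pos = fn(paddle.to_tensor(pos)).astype("float32").numpy().sum()
        f_neg = fn(paddle.to_tensor(neg)).astype("float32").numpy().sum()
        grad.flat[i] = (f_pos - f_neg) / (2 * epsilon)
    return grad


# Sanity check on a function with a known gradient: d(sum(x**2))/dx = 2x.
x = paddle.to_tensor([[1.0, -2.0], [0.5, 3.0]])
np.testing.assert_allclose(
    numerical_grad(lambda t: t**2, x), 2 * x.numpy(), rtol=1e-3
)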