OP move task from ernie-core to framework #72957


Merged: 79 commits merged into PaddlePaddle:develop on Jun 4, 2025

Conversation

@A-nnonymous (Contributor) commented May 27, 2025:

PR Category

User Experience

PR Types

Improvements

Description

OP migration task from ernie-core to the framework. Locally cherry-picked from migration PRs #72835, #72875, and #72909 by @pesionzhao, @feixi21, and @zhenghuaijin; verified and merged by @A-nnonymous.

Ops ready list (see the usage sketch below):

  • paddle.incubate.nn.functional.cal_aux_loss (and grad).
  • paddle.incubate.nn.functional.moe_combine (and grad).
  • paddle.incubate.nn.functional.expand_modality_expert_id.
  • paddle.incubate.nn.functional.build_src_rank_and_local_expert_id.
  • paddle.incubate.nn.functional.int_bincount.
  • paddle.incubate.nn.functional.fused_rms_norm_ext.
  • paddle.incubate.nn.functional.moe_gate_dispatch (and grad).
  • paddle.incubate.nn.functional.moe_gate_dispatch_permute (and grad).
  • paddle.incubate.nn.functional.moe_gate_dispatch_partial_nosoftmaxtopk (and grad).

pcard-91067
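
For orientation, a minimal usage sketch of one of the migrated ops; the signature of fused_rms_norm_ext is taken from the diff reviewed below, while the shapes and the single captured output are illustrative assumptions:

import paddle
from paddle.incubate.nn.functional import fused_rms_norm_ext

# Illustrative shapes only; epsilon matches the default shown in the diff.
x = paddle.randn([4, 128], dtype="float32")       # [tokens, hidden]
scale = paddle.ones([128], dtype="float32")       # per-channel scale
out = fused_rms_norm_ext(x, scale, epsilon=1e-5)  # the op may also return auxiliary tensors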

@A-nnonymous (Contributor Author):

/re-run all-failed

@phlrain self-requested a review June 4, 2025 06:10
@SigureMo (Member) left a comment:

The related type hints can be fixed in the next PR.

from paddle.base.layer_helper import LayerHelper


def fused_rms_norm_ext(x, scale, epsilon=1e-5, name=None):
Member:

Please add type hints here.
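
For example, a possible annotated signature; the parameter types follow the current defaults, and the return type is an assumption to be checked against the op's outputs:

from __future__ import annotations

from paddle import Tensor


def fused_rms_norm_ext(
    x: Tensor,
    scale: Tensor,
    epsilon: float = 1e-5,
    name: str | None = None,
) -> Tensor: ...  # return type assumed; verify against the kernel's outputs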

Contributor Author:

OK, got it. We will improve this in the next phase; in this phase we are not touching its code.


def build_src_rank_and_local_expert_id(
expert_num_global_tensor: Tensor,
expert_num_global: list,
Member:

Suggested change
expert_num_global: list,
expert_num_global: list[Xxx],

The element type of the generic needs to be filled in here — something like list[Tensor] or list[int].
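
For instance, if these are per-expert token counts, the annotation might read (hypothetical; the actual element type needs confirming):

expert_num_global: list[int],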

Contributor Author:

OK, got it.

from paddle.base.layer_helper import LayerHelper


def int_bincount(x, low, high, dtype=None, name=None):
Member:

Type hints need to be added.
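
A possible annotated signature, in the same spirit as the suggestions above; the dtype parameter's type and the return type are assumptions:

from __future__ import annotations

import paddle
from paddle import Tensor


def int_bincount(
    x: Tensor,
    low: int,
    high: int,
    dtype: paddle.dtype | None = None,  # assumed; the op may also accept strings
    name: str | None = None,
) -> Tensor: ...  # return type assumed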

Contributor Author:

OK, got it.

@wanghuancoder (Contributor) left a comment:

LGTM

@XiaoguangHu01 (Contributor) left a comment:

LGTM

@sneaxiy (Collaborator) left a comment:

Should polish the code ASAP.

@A-nnonymous (Contributor Author):

/re-run all-failed

@@ -0,0 +1,25 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2023 -> 2025

Comment on lines +15 to +36
#pragma once
#ifdef PADDLE_WITH_CUDA
#include "paddle/common/exception.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/moe_kernel_impl.h"

namespace phi {

template <typename T, int64_t vec_size>
__global__ void gather_with_mask_permute_kernel(
const T* dy, // [s*k, d]
const int* scatter_index, // [s, k]
const float* combine_weights, // [s, k]
T* dx, // [s, d]
int64_t num_rows, // s
int64_t k, // k
int64_t dim, // d
int64_t N,
int64_t num_active, // when specified, positions > num_active are skipped
int64_t s_shared_num,
int64_t capacity,
int64_t world_size,
Contributor:

Put the CUDA implementation under the gpu directory (i.e., paddle/phi/kernels/gpu/).

Contributor:

Same as above.

@@ -0,0 +1,31 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2023 -> 2025

Contributor:

Move this to the gpu directory.

@@ -0,0 +1,649 @@
// NOLINT
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2022 -> 2025

@time: 2024/09/21 15:11:10
@Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved

Write the description and explanation of this file starting from this line
Contributor:

What is this?

)
weight_lm = prob_lm[batch_idx, expert_id_lm] # use correct bias

# when num_expert_per_modality == 0, only perform group-expert expand, not multimodal-expand
Contributor:

Switch to English comments.

Comment on lines +122 to +129
Initialize the MoE layer.

Args:
    gate (nn.Layer): gating layer used to select which experts to use.
    experts (List[nn.Layer]): list of experts to use.
    layer_idx (int): index of the current MoE layer.
    group (Group): distributed communication group. Defaults to None.
    recompute (bool): whether to recompute the MoE output in each training iteration. Defaults to False.
Contributor:

Switch to English.

Comment on lines +150 to +165
"""
对`gate_prob` 进行 softmax 并根据结果选取 topk 路由expert。 最后根据 expert 号对 `x` 进行重排。
Args:
x: [s, d] 输入的 activateion
gate_prob: [s, e]
k: int
capacity: int #no use
Returns:
y: [s*k, d] 将所有 `x` 根据其路由的 `expert-id` 升序的排序,融合到 s 维度。
当截断发生时 s 会比输入 s 小。
combine_weights: [s, k], float: 每个 token 第 k 选择的 expert 的权重。
当截断发生时 s 会比输入 s 小。
scatter_index: [k, s] : 每个 token 第 k 次选择对应到 `y` 中的位置。
expert_offset: [e]: `y`中每个 expert-id 的分割位置。
expert_id: [s] `x` 中激活的 expert 号
"""
Contributor:

Same as above.
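
For readers, a minimal dense sketch of the routing semantics described in the docstring above, ignoring capacity truncation and the scatter_index/expert_offset outputs; gate_dispatch_reference is an illustrative helper, not the fused op:

import paddle
import paddle.nn.functional as F


def gate_dispatch_reference(x, gate_prob, k):
    # x: [s, d], gate_prob: [s, e]
    prob = F.softmax(gate_prob, axis=-1)                          # [s, e]
    combine_weights, expert_id = paddle.topk(prob, k=k, axis=-1)  # both [s, k]
    order = paddle.argsort(expert_id.flatten())                   # ascending expert id
    y = x[order // k]                                             # [s*k, d], rows grouped by expert
    return y, combine_weights, expert_id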

@zyfncg (Contributor) left a comment:

The related issues will be addressed in a follow-up PR once conclusions are reached.

@A-nnonymous (Contributor Author):

/re-run all-failed

@A-nnonymous closed this Jun 4, 2025
@A-nnonymous reopened this Jun 4, 2025
@PaddlePaddle locked as off-topic and limited conversation to collaborators Jun 4, 2025
@PaddlePaddle unlocked this conversation Jun 4, 2025
@phlrain merged commit 308e758 into PaddlePaddle:develop Jun 4, 2025
135 of 183 checks passed
10 participants