OP move task from ernie-core to framework #72957


Merged — 79 commits, merged Jun 4, 2025.

Commits (changes from all commits):
77b7b8a
init
pesionzhao May 20, 2025
492ac04
insert moe_combine
pesionzhao May 21, 2025
22e3643
init
pesionzhao May 21, 2025
bebaf44
update yaml
pesionzhao May 21, 2025
df96e81
update python API
pesionzhao May 21, 2025
4c93cdd
Merge branch 'develop' into zps_insert
pesionzhao May 21, 2025
9c7cf25
delete useless header file
pesionzhao May 21, 2025
81efa89
remove supported by DCU
pesionzhao May 21, 2025
f694787
merge from zps_insert
pesionzhao May 21, 2025
56ce7d2
add expand_modality_expert_id kernel
pesionzhao May 21, 2025
e2e6ac7
reorder the new code and refine OP type
pesionzhao May 21, 2025
ddf247c
add unit test
pesionzhao May 22, 2025
ad82cc8
add cal_aux_loss_op and build_src_rank_and_local_expert_id_op
feixi21 May 22, 2025
8ade3ac
moegatedispatch init
pesionzhao May 23, 2025
b8c9636
insert moegatedispatch
pesionzhao May 23, 2025
e5bfdc9
remove DCU support
pesionzhao May 23, 2025
81d9fbc
fix-bugs
feixi21 May 22, 2025
4c04429
fix log2 in windows maybe
pesionzhao May 23, 2025
fb784e3
update header file format
pesionzhao May 23, 2025
d25d23e
fix-bugs
feixi21 May 23, 2025
47f010d
delete op test for pass CI
pesionzhao May 23, 2025
80bd65f
add cmath header
pesionzhao May 23, 2025
9e889aa
tmp
pesionzhao May 23, 2025
c78a3cb
pass int_bincount
May 23, 2025
19085da
add moe_dispatch_bwd
feixi21 May 23, 2025
aca0c5d
add moe_gate_dispatch
feixi21 May 23, 2025
3e4e392
fix-bugs
feixi21 May 23, 2025
0242f9c
fix optional Tensor
feixi21 May 25, 2025
08a93f8
Merge commit 'refs/pull/72835/head (moe_combine and expand_modality_e…
A-nnonymous May 26, 2025
9b09cff
update cal_aux_loss_kernel
feixi21 May 26, 2025
3b3c98c
Finished moe_combine & expand_modality_expert_id integrate and optests.
A-nnonymous May 26, 2025
9e84e5b
add python interface
feixi21 May 26, 2025
09ea5b4
Add Li Zhou's stage 1 PR
A-nnonymous May 26, 2025
6b81588
nosoftmax forward has finished
pesionzhao May 26, 2025
829bad3
finishi fused_rms_norm fwd
May 26, 2025
4f5ede6
finish rms_norm bwd
May 26, 2025
8797589
finish rms norm bwd
May 26, 2025
b321097
add optional in ops.yaml
feixi21 May 27, 2025
3e406ca
nosoftmax bwd has finished
pesionzhao May 27, 2025
9105be0
update python api
pesionzhao May 27, 2025
0f09636
Verified cal_aux_loss op and bwd.
A-nnonymous May 27, 2025
a12713b
Verified build_src_rank_and_local_expert_id
A-nnonymous May 27, 2025
aef7e62
Merged Huaijin Zheng's PR#72909
A-nnonymous May 27, 2025
8aa9b33
gate_dispatch_permute has finished
pesionzhao May 27, 2025
cdece27
finished all work
pesionzhao May 27, 2025
7e529c3
Verified fused_rms_norm_ext(with bwd) and int_bincount.
A-nnonymous May 28, 2025
0d16972
Merge commit 'refs/pull/72835/head(peisen's stage2)' of https://githu…
A-nnonymous May 28, 2025
f53449f
Fix conflict
A-nnonymous May 28, 2025
5d2e65d
Add stage2 fwd and bwd optests.
A-nnonymous May 28, 2025
b0dc90c
Clean print
A-nnonymous May 28, 2025
9cd02e8
Fix conflict, move some headers.
A-nnonymous May 28, 2025
e8e494e
sync with dev
A-nnonymous May 28, 2025
f3d7320
Add incubate port.
A-nnonymous May 28, 2025
c76a357
fix miscs
A-nnonymous May 28, 2025
d4a3472
Fix module issue
A-nnonymous May 28, 2025
36a450f
Add missing yamls
A-nnonymous May 28, 2025
89402c3
Fix stale package problems
A-nnonymous May 28, 2025
422850e
fix moe_combine bug.
A-nnonymous May 28, 2025
4c6ab13
Fix miscs
A-nnonymous May 28, 2025
81aac42
Align with original initializations.
A-nnonymous May 28, 2025
82acb48
fix typos and pre-commit warnings
A-nnonymous May 28, 2025
bba0636
Fix miscs
A-nnonymous May 28, 2025
fcc4f81
try to pass CI
pesionzhao May 29, 2025
29e0b02
format header file
pesionzhao May 29, 2025
c5cab08
remove win32 supported
pesionzhao May 29, 2025
6a1a318
check OP type
pesionzhao May 30, 2025
d290b3b
remove optest for WIN & APPLE
pesionzhao May 30, 2025
31d6465
fix bug for (int32_t and int)
pesionzhao May 30, 2025
0b990d6
rename fused_rms_norm op
pesionzhao May 30, 2025
2fe0ef3
select op test env not for Volta
pesionzhao May 30, 2025
e8a30df
fix openblas mistake
pesionzhao May 30, 2025
4fd6186
CMake code format
pesionzhao Jun 3, 2025
ebd2244
fix bugs in CPU
pesionzhao Jun 3, 2025
10f6058
CodeStyle format
pesionzhao Jun 3, 2025
0637a02
fix bugs in CPU
pesionzhao Jun 3, 2025
e4ecf9b
fix bugs in CPU
pesionzhao Jun 3, 2025
54dda45
skip some op when CUDA<12.0
pesionzhao Jun 3, 2025
0d1b3d0
skip op when CUDA<12.0
pesionzhao Jun 3, 2025
8e0817a
fix bugs in CPU
pesionzhao Jun 4, 2025
185 changes: 185 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -1217,6 +1217,106 @@ void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
}
}

void MoeCombineGradInferMeta(const MetaTensor& x,
const MetaTensor& combine_weights,
const MetaTensor& scatter_index,
const MetaTensor& y,
MetaTensor* grad_x,
MetaTensor* grad_combine_weights_helper) {
auto x_dim = x.dims();
auto combine_weights_shape = combine_weights.dims();
PADDLE_ENFORCE_EQ(
x_dim.size(),
2,
      errors::InvalidArgument("The input X should have 2 dimensions, "
                              "but received X's dimension = %d.",
                              x_dim.size()));
PADDLE_ENFORCE_EQ(
(scatter_index.dtype() == phi::DataType::INT32),
true,
      errors::InvalidArgument("The input scatter_index type should be int32, "
                              "but received scatter_index type = %s.",
                              scatter_index.dtype()));
grad_x->set_dims(common::make_ddim({x_dim[0], x_dim[1]}));
grad_x->set_dtype(x.dtype());
grad_combine_weights_helper->set_dims(common::make_ddim(
{combine_weights_shape[0], combine_weights_shape[1], x_dim[1]}));
grad_combine_weights_helper->set_dtype(x.dtype());
}
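The infer-meta above only propagates shapes and dtypes: `grad_x` keeps `x`'s 2-D shape, and the combine-weights helper gains a trailing hidden dimension. A small Python mirror of that shape arithmetic (hypothetical helper name; it restates the C++ logic above, not a Paddle API):

```python
def moe_combine_grad_infer_shapes(x_shape, combine_weights_shape):
    # Mirrors MoeCombineGradInferMeta: grad_x keeps x's shape; the helper
    # for combine_weights gets an extra trailing hidden dimension.
    assert len(x_shape) == 2, "The input X should have 2 dimensions"
    grad_x_shape = [x_shape[0], x_shape[1]]
    grad_cw_helper_shape = [combine_weights_shape[0],
                            combine_weights_shape[1],
                            x_shape[1]]
    return grad_x_shape, grad_cw_helper_shape

# e.g. 8 dispatched tokens with hidden size 16 and top-k = 2 routing
gx, gw = moe_combine_grad_infer_shapes([8, 16], [8, 2])
```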

void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(
const MetaTensor& combine_weights_out,
const MetaTensor& scatter_index,
const MetaTensor& scatter_index_rev,
const MetaTensor& expert_offset,
const MetaTensor& expert_offset_local,
const MetaTensor& y_grad,
const MetaTensor& combine_weights_out_grad,
int64_t k,
int64_t capacity,
bool use_pad,
int64_t expert_start_index,
int64_t expert_end_index,
MetaTensor* x_grad,
MetaTensor* combine_weights_grad) {
int64_t num_experts = expert_offset.dims()[0];
int64_t hidden_size = y_grad.dims()[1];
int64_t num_rows = scatter_index.dims()[1];
PADDLE_ENFORCE_GT(num_experts,
0,
common::errors::InvalidArgument(
"Input num_experts should be greater than 0"));
PADDLE_ENFORCE_EQ((expert_offset.dtype() == phi::DataType::INT64),
true,
common::errors::InvalidArgument(
"Input expert_offset type should be int64"));
if (use_pad) {
PADDLE_ENFORCE_GE(num_experts,
y_grad.dims()[0] / capacity,
common::errors::InvalidArgument(
"Number of experts should be greater than or equal "
"to y_grad.dims()[0]/capacity"));
} else {
PADDLE_ENFORCE_GT(y_grad.dims()[0],
0,
common::errors::InvalidArgument(
"Input y_grad.dims()[0] should be greater than 0"));
}
combine_weights_grad->set_dims(combine_weights_out_grad.dims());
combine_weights_grad->set_dtype(phi::DataType::FLOAT32);
x_grad->set_dims({num_rows, hidden_size});
x_grad->set_dtype(y_grad.dtype());
}

void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights,
const MetaTensor& scatter_index,
const MetaTensor& expert_id,
const MetaTensor& y_grad,
const MetaTensor& combine_weights_grad,
int64_t k,
int64_t capacity,
int64_t world_size,
MetaTensor* x_grad,
MetaTensor* gate_logits_grad) {
auto y_grad_dims = y_grad.dims();
PADDLE_ENFORCE_EQ(
y_grad_dims[1],
world_size,
common::errors::InvalidArgument(
"The second dimension of y_grad should be equal to world_size, but "
"received y_grad_dims[1] = %d, world_size = %d",
y_grad_dims[1],
world_size));
int64_t num_local_experts = y_grad_dims[0];
int64_t num_experts = world_size * num_local_experts;
int64_t hidden_size = y_grad_dims[y_grad_dims.size() - 1];
int64_t num_rows = scatter_index.dims()[1];
x_grad->set_dims({num_rows, hidden_size});
x_grad->set_dtype(y_grad.dtype());
gate_logits_grad->set_dims({num_rows, num_experts});
gate_logits_grad->set_dtype(phi::DataType::FLOAT32);
}
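The permute-variant gradient derives every output shape from `y_grad` and `scatter_index`. A Python sketch of that derivation (hypothetical helper name; assumes, per the C++ above, that `y_grad` is laid out as `[num_local_experts, world_size, ..., hidden]` and `scatter_index` as `[k, num_rows]`):

```python
def moe_gate_dispatch_permute_grad_shapes(y_grad_shape, scatter_index_shape,
                                          world_size):
    # Mirrors MoeGateDispatchPermuteGradInferMeta's shape logic.
    assert y_grad_shape[1] == world_size, \
        "second dim of y_grad must equal world_size"
    num_local_experts = y_grad_shape[0]
    num_experts = world_size * num_local_experts
    hidden_size = y_grad_shape[-1]
    num_rows = scatter_index_shape[1]        # scatter_index is [k, num_rows]
    x_grad_shape = [num_rows, hidden_size]
    gate_logits_grad_shape = [num_rows, num_experts]   # always float32
    return x_grad_shape, gate_logits_grad_shape

# e.g. 4 local experts, world_size 2, capacity 8, hidden 64, 32 rows
xg, gl = moe_gate_dispatch_permute_grad_shapes([4, 2, 8, 64], [2, 32], 2)
```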

void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad) {
@@ -1887,4 +1987,89 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
value_grad->share_lod(values);
}
}

void CalAuxLossGradInferMeta(const MetaTensor& gate_prob,
const MetaTensor& seqlen_float,
const MetaTensor& ce,
const MetaTensor& l_aux_loss_grad,
const int64_t num_experts,
const bool use_group,
const int64_t moe_k,
MetaTensor* gate_prob_grad) {
auto gate_prob_dims = gate_prob.dims();

  PADDLE_ENFORCE_EQ(
      gate_prob.dtype(),
      l_aux_loss_grad.dtype(),
      errors::InvalidArgument(
          "The input l_aux_loss_grad dtype should be equal to gate_prob dtype"));

  gate_prob_grad->set_dims(gate_prob_dims);
gate_prob_grad->set_dtype(gate_prob.dtype());
}

void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights,
const MetaTensor& scatter_index,
const MetaTensor& expert_id,
const MetaTensor& y_grad,
const MetaTensor& combine_weights_grad,
const int64_t k,
const int64_t capacity,
const bool use_pad,
MetaTensor* x_grad,
MetaTensor* gate_logits_grad) {
auto combine_weights_dims = combine_weights.dims();
auto scatter_index_dims = scatter_index.dims();
auto expert_id_dims = expert_id.dims();
auto y_grad_dims = y_grad.dims();
auto combine_weights_grad_dims = combine_weights_grad.dims();

PADDLE_ENFORCE_EQ(combine_weights_dims.size(),
2,
errors::InvalidArgument(
"Input combine_weights should have 2 dimensions"));

PADDLE_ENFORCE_EQ(
scatter_index_dims.size(),
2,
errors::InvalidArgument("Input scatter_index should have 2 dimensions"));

PADDLE_ENFORCE_EQ(
expert_id_dims.size(),
2,
errors::InvalidArgument("Input expert_id should have 2 dimensions"));

PADDLE_ENFORCE_EQ(
y_grad_dims.size(),
2,
errors::InvalidArgument("Input y_grad should have 2 dimensions"));

PADDLE_ENFORCE_EQ(combine_weights_grad_dims.size(),
2,
errors::InvalidArgument(
"Input combine_weights_grad should have 2 dimensions"));

int64_t num_experts = y_grad_dims[0] / capacity;
int64_t hidden_size = y_grad_dims[1];

int64_t num_rows = scatter_index_dims[1];

gate_logits_grad->set_dims(common::make_ddim({num_rows, num_experts}));
gate_logits_grad->set_dtype(phi::DataType::FLOAT32);

x_grad->set_dims(common::make_ddim({num_rows, hidden_size}));
x_grad->set_dtype(y_grad.dtype());
}

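Unlike the permute variant, this gradient recovers the expert count by dividing `y_grad`'s row count by `capacity`. A Python mirror of the shape logic (hypothetical helper name, restating the C++ above):

```python
def moe_gate_dispatch_grad_shapes(y_grad_shape, scatter_index_shape, capacity):
    # Mirrors MoeGateDispatchGradInferMeta: y_grad is [num_experts * capacity,
    # hidden], so experts are recovered by dividing rows by capacity.
    num_experts = y_grad_shape[0] // capacity
    hidden_size = y_grad_shape[1]
    num_rows = scatter_index_shape[1]
    x_grad_shape = [num_rows, hidden_size]
    gate_logits_grad_shape = [num_rows, num_experts]   # always float32
    return x_grad_shape, gate_logits_grad_shape

# e.g. 8 experts * capacity 8 = 64 rows of y_grad, hidden 128, 32 tokens
xg, gl = moe_gate_dispatch_grad_shapes([64, 128], [2, 32], capacity=8)
```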
void FusedRMSNormGradInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& invvar,
const MetaTensor& dy,
float epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad) {
x_grad->set_dims(x.dims());
x_grad->set_dtype(x.dtype());
scale_grad->set_dims(scale.dims());
scale_grad->set_dtype(scale.dtype());
}
} // namespace phi
65 changes: 65 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -462,6 +462,44 @@ void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
MetaTensor* value_grad,
MetaTensor* bias_grad);

void MoeCombineGradInferMeta(const MetaTensor& x,
const MetaTensor& combine_weights,
const MetaTensor& scatter_index,
                             const MetaTensor& y,
MetaTensor* grad_x,
MetaTensor* grad_combine_weights_helper);
// Tensor combine_weights_out, Tensor scatter_index, Tensor scatter_index_rev,
// Tensor expert_offset, Tensor expert_offset_local, Tensor y_grad, Tensor
// combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t
// expert_start_index, int64_t expert_end_index)
// output : Tensor(x_grad), Tensor(combine_weights_grad)
void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(
const MetaTensor& combine_weights_out,
const MetaTensor& scatter_index,
const MetaTensor& scatter_index_rev,
const MetaTensor& expert_offset,
const MetaTensor& expert_offset_local,
const MetaTensor& y_grad,
const MetaTensor& combine_weights_out_grad,
int64_t k,
int64_t capacity,
bool use_pad,
int64_t expert_start_index,
int64_t expert_end_index,
MetaTensor* x_grad,
MetaTensor* combine_weights_grad);

void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights,
const MetaTensor& scatter_index,
const MetaTensor& expert_id,
const MetaTensor& y_grad,
const MetaTensor& combine_weights_grad,
int64_t k,
int64_t capacity,
int64_t world_size,
MetaTensor* x_grad,
MetaTensor* gate_logits_grad);

void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad);
@@ -680,4 +718,31 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
MetaTensor* x_grad,
MetaTensor* value_grad);

void CalAuxLossGradInferMeta(const MetaTensor& gate_prob,
const MetaTensor& seqlen_float,
const MetaTensor& ce,
const MetaTensor& l_aux_loss_grad,
const int64_t num_experts,
const bool use_group,
const int64_t moe_k,
MetaTensor* gate_prob_grad);

void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights,
const MetaTensor& scatter_index,
const MetaTensor& expert_id,
const MetaTensor& y_grad,
const MetaTensor& combine_weights_grad,
const int64_t k,
const int64_t capacity,
const bool use_pad,
MetaTensor* x_grad,
MetaTensor* gate_logits_grad);

void FusedRMSNormGradInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& invvar,
const MetaTensor& dy,
float epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad);
} // namespace phi
14 changes: 14 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -4592,6 +4592,20 @@ void WeightDequantizeInferMeta(const MetaTensor& x,
out->set_dtype(scale.dtype());
}

void FusedRMSNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
float epsilon,
MetaTensor* y,
MetaTensor* invvar) {
// Y: same shape, dtype, layout as X
y->set_dims(x.dims());
y->set_dtype(x.dtype());
  // invvar: 1-D float32, length = x.dims()[0] (rows)
int64_t rows = x.dims()[0];
invvar->set_dims(DDim({rows}));
invvar->set_dtype(DataType::FLOAT32);
}
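`FusedRMSNormInferMeta` only fixes shapes and dtypes; the actual kernel lives in the PR's CUDA sources, which are not shown here. As a hedged reference for what an RMSNorm forward with an `invvar` side output computes (a sketch assuming a 2-D `[rows, hidden]` input, consistent with `invvar`'s length of `x.dims()[0]`):

```python
import numpy as np

def fused_rms_norm_reference(x, scale, epsilon=1e-6):
    # NumPy sketch of an RMSNorm forward: normalize each row by its RMS,
    # then apply the per-column scale. invvar = 1/sqrt(mean(x^2) + eps).
    mean_sq = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    invvar = 1.0 / np.sqrt(mean_sq + epsilon)       # [rows, 1], float32
    y = (x * invvar * scale).astype(x.dtype)        # same shape/dtype as x
    return y, invvar.squeeze(-1)                    # invvar: 1-D, length rows

x = np.ones((4, 8), dtype=np.float32)
scale = np.ones(8, dtype=np.float32)
y, invvar = fused_rms_norm_reference(x, scale)
```

The shapes match the infer-meta above: `y` mirrors `x`, and `invvar` is a float32 vector of length `x.dims()[0]`.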

} // namespace phi

PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -790,5 +790,10 @@ void WeightDequantizeInferMeta(const MetaTensor& x,
const std::string& algo,
const int32_t group_size,
MetaTensor* out);

void FusedRMSNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
float epsilon,
MetaTensor* y,
MetaTensor* invvar);

} // namespace phi