Skip to content

Commit 308e758

Browse files
A-nnonymous, pesionzhao, feixi21, zhenghuaijin
authored
OP move task from ernie-core to framework (#72957)
* init * insert moe_combine * init * update yaml * update python API * delete useless header file * remove supported by DCU * add expand_modality_expert_id kernel * reorder the new code and refine OP type * add unit test * add cal_aux_loss_op and build_src_rank_and_local_expert_id_op * moegatedispatch init * insert moegatedispatch * remove DCU support * fix-bugs fix-bugs fix-bugs * fix log2 in windows maybe * update header file format * fix-bugs * delete op test for pass CI * add cmath header * tmp * pass int_bincount * add moe_dispatch_bwd * add moe_gate_dispatch * fix-bugs * fix optional Tensor * update cal_aux_loss_kernel * Finished moe_combine & expand_modality_expert_id integrate and optests. * add python interface * nosoftmax forward has finished * finishi fused_rms_norm fwd * finish rms_norm bwd * finish rms norm bwd * add optional in ops.yaml * nosoftmax bwd has finished * update python api * Verified cal_aux_loss op and bwd. * Verified build_src_rank_and_local_expert_id * gate_dispatch_permute has finished * Verified fused_rms_norm_ext(with bwd) and int_bincount. * Add stage2 fwd and bwd optests. * Clean print * Fix conflict, move some headers. * sync with dev * Add incubate port. * fix miscs * Fix module issue * Add missing yamls * Fix stale package problems * fix moe_combine bug. * Fix miscs * Align with original initializations. * fix typos and pre-commit warnings * Fix miscs * try to pass CI * format header file * remove win32 supported * check OP type * remove optest for WIN & APPLE * fix bug for (int32_t and int) * rename fused_rms_norm op * select op test env not for Volta * fix openblas mistake * CMake code format * fix bugs in CPU * CodeStyle format * fix bugs in CPU * fix bugs in CPU * skip some op when CUDA<12.0 * skip op when CUDA<12.0 * fix bugs in CPU --------- Co-authored-by: pesionzhao <pesionzhao@gmail.com> Co-authored-by: feixi21 <1802550529@qq.com> Co-authored-by: zhenghuaijin <zhenghuaijin@baidu.com>
1 parent 6fbf44a commit 308e758

File tree

69 files changed

+11235
-7
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+11235
-7
lines changed

paddle/phi/infermeta/backward.cc

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,106 @@ void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
12171217
}
12181218
}
12191219

1220+
void MoeCombineGradInferMeta(const MetaTensor& x,
1221+
const MetaTensor& combine_weights,
1222+
const MetaTensor& scatter_index,
1223+
const MetaTensor& y,
1224+
MetaTensor* grad_x,
1225+
MetaTensor* grad_combine_weights_helper) {
1226+
auto x_dim = x.dims();
1227+
auto combine_weights_shape = combine_weights.dims();
1228+
PADDLE_ENFORCE_EQ(
1229+
x_dim.size(),
1230+
2,
1231+
errors::InvalidArgument("The input X should have 2 dimensions"
1232+
"But received X's dimension = %d",
1233+
x_dim.size()));
1234+
PADDLE_ENFORCE_EQ(
1235+
(scatter_index.dtype() == phi::DataType::INT32),
1236+
true,
1237+
errors::InvalidArgument("The input scatter_index type should be int32"
1238+
"But received scatter_index type = %s",
1239+
scatter_index.dtype()));
1240+
grad_x->set_dims(common::make_ddim({x_dim[0], x_dim[1]}));
1241+
grad_x->set_dtype(x.dtype());
1242+
grad_combine_weights_helper->set_dims(common::make_ddim(
1243+
{combine_weights_shape[0], combine_weights_shape[1], x_dim[1]}));
1244+
grad_combine_weights_helper->set_dtype(x.dtype());
1245+
}
1246+
1247+
void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(
1248+
const MetaTensor& combine_weights_out,
1249+
const MetaTensor& scatter_index,
1250+
const MetaTensor& scatter_index_rev,
1251+
const MetaTensor& expert_offset,
1252+
const MetaTensor& expert_offset_local,
1253+
const MetaTensor& y_grad,
1254+
const MetaTensor& combine_weights_out_grad,
1255+
int64_t k,
1256+
int64_t capacity,
1257+
bool use_pad,
1258+
int64_t expert_start_index,
1259+
int64_t expert_end_index,
1260+
MetaTensor* x_grad,
1261+
MetaTensor* combine_weights_grad) {
1262+
int64_t num_experts = expert_offset.dims()[0];
1263+
int64_t hidden_size = y_grad.dims()[1];
1264+
int64_t num_rows = scatter_index.dims()[1];
1265+
PADDLE_ENFORCE_GT(num_experts,
1266+
0,
1267+
common::errors::InvalidArgument(
1268+
"Input num_experts should be greater than 0"));
1269+
PADDLE_ENFORCE_EQ((expert_offset.dtype() == phi::DataType::INT64),
1270+
true,
1271+
common::errors::InvalidArgument(
1272+
"Input expert_offset type should be int64"));
1273+
if (use_pad) {
1274+
PADDLE_ENFORCE_GE(num_experts,
1275+
y_grad.dims()[0] / capacity,
1276+
common::errors::InvalidArgument(
1277+
"Number of experts should be greater than or equal "
1278+
"to y_grad.dims()[0]/capacity"));
1279+
} else {
1280+
PADDLE_ENFORCE_GT(y_grad.dims()[0],
1281+
0,
1282+
common::errors::InvalidArgument(
1283+
"Input y_grad.dims()[0] should be greater than 0"));
1284+
}
1285+
combine_weights_grad->set_dims(combine_weights_out_grad.dims());
1286+
combine_weights_grad->set_dtype(phi::DataType::FLOAT32);
1287+
x_grad->set_dims({num_rows, hidden_size});
1288+
x_grad->set_dtype(y_grad.dtype());
1289+
}
1290+
1291+
void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights,
1292+
const MetaTensor& scatter_index,
1293+
const MetaTensor& expert_id,
1294+
const MetaTensor& y_grad,
1295+
const MetaTensor& combine_weights_grad,
1296+
int64_t k,
1297+
int64_t capacity,
1298+
int64_t world_size,
1299+
MetaTensor* x_grad,
1300+
MetaTensor* gate_logits_grad) {
1301+
auto y_grad_dims = y_grad.dims();
1302+
PADDLE_ENFORCE_EQ(
1303+
y_grad_dims[1],
1304+
world_size,
1305+
common::errors::InvalidArgument(
1306+
"The second dimension of y_grad should be equal to world_size, but "
1307+
"received y_grad_dims[1] = %d, world_size = %d",
1308+
y_grad_dims[1],
1309+
world_size));
1310+
int64_t num_local_experts = y_grad_dims[0];
1311+
int64_t num_experts = world_size * num_local_experts;
1312+
int64_t hidden_size = y_grad_dims[y_grad_dims.size() - 1];
1313+
int64_t num_rows = scatter_index.dims()[1];
1314+
x_grad->set_dims({num_rows, hidden_size});
1315+
x_grad->set_dtype(y_grad.dtype());
1316+
gate_logits_grad->set_dims({num_rows, num_experts});
1317+
gate_logits_grad->set_dtype(phi::DataType::FLOAT32);
1318+
}
1319+
12201320
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
12211321
const MetaTensor& out_grad,
12221322
std::vector<MetaTensor*> x_grad) {
@@ -1887,4 +1987,89 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
18871987
value_grad->share_lod(values);
18881988
}
18891989
}
1990+
1991+
void CalAuxLossGradInferMeta(const MetaTensor& gate_prob,
1992+
const MetaTensor& seqlen_float,
1993+
const MetaTensor& ce,
1994+
const MetaTensor& l_aux_loss_grad,
1995+
const int64_t num_experts,
1996+
const bool use_group,
1997+
const int64_t moe_k,
1998+
MetaTensor* gate_prob_grad) {
1999+
auto gate_prob_dims = gate_prob.dims();
2000+
2001+
PADDLE_ENFORCE_EQ(
2002+
gate_prob.dtype(),
2003+
l_aux_loss_grad.dtype(),
2004+
errors::InvalidArgument(
2005+
"The input out_grad type should be equal to gate_prob type"));
2006+
2007+
gate_prob_grad->set_dims({gate_prob_dims});
2008+
gate_prob_grad->set_dtype(gate_prob.dtype());
2009+
}
2010+
2011+
void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights,
2012+
const MetaTensor& scatter_index,
2013+
const MetaTensor& expert_id,
2014+
const MetaTensor& y_grad,
2015+
const MetaTensor& combine_weights_grad,
2016+
const int64_t k,
2017+
const int64_t capacity,
2018+
const bool use_pad,
2019+
MetaTensor* x_grad,
2020+
MetaTensor* gate_logits_grad) {
2021+
auto combine_weights_dims = combine_weights.dims();
2022+
auto scatter_index_dims = scatter_index.dims();
2023+
auto expert_id_dims = expert_id.dims();
2024+
auto y_grad_dims = y_grad.dims();
2025+
auto combine_weights_grad_dims = combine_weights_grad.dims();
2026+
2027+
PADDLE_ENFORCE_EQ(combine_weights_dims.size(),
2028+
2,
2029+
errors::InvalidArgument(
2030+
"Input combine_weights should have 2 dimensions"));
2031+
2032+
PADDLE_ENFORCE_EQ(
2033+
scatter_index_dims.size(),
2034+
2,
2035+
errors::InvalidArgument("Input scatter_index should have 2 dimensions"));
2036+
2037+
PADDLE_ENFORCE_EQ(
2038+
expert_id_dims.size(),
2039+
2,
2040+
errors::InvalidArgument("Input expert_id should have 2 dimensions"));
2041+
2042+
PADDLE_ENFORCE_EQ(
2043+
y_grad_dims.size(),
2044+
2,
2045+
errors::InvalidArgument("Input y_grad should have 2 dimensions"));
2046+
2047+
PADDLE_ENFORCE_EQ(combine_weights_grad_dims.size(),
2048+
2,
2049+
errors::InvalidArgument(
2050+
"Input combine_weights_grad should have 2 dimensions"));
2051+
2052+
int64_t num_experts = y_grad_dims[0] / capacity;
2053+
int64_t hidden_size = y_grad_dims[1];
2054+
2055+
int64_t num_rows = scatter_index_dims[1];
2056+
2057+
gate_logits_grad->set_dims(common::make_ddim({num_rows, num_experts}));
2058+
gate_logits_grad->set_dtype(phi::DataType::FLOAT32);
2059+
2060+
x_grad->set_dims(common::make_ddim({num_rows, hidden_size}));
2061+
x_grad->set_dtype(y_grad.dtype());
2062+
}
2063+
void FusedRMSNormGradInferMeta(const MetaTensor& x,
2064+
const MetaTensor& scale,
2065+
const MetaTensor& invvar,
2066+
const MetaTensor& dy,
2067+
float epsilon,
2068+
MetaTensor* x_grad,
2069+
MetaTensor* scale_grad) {
2070+
x_grad->set_dims(x.dims());
2071+
x_grad->set_dtype(x.dtype());
2072+
scale_grad->set_dims(scale.dims());
2073+
scale_grad->set_dtype(scale.dtype());
2074+
}
18902075
} // namespace phi

paddle/phi/infermeta/backward.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,44 @@ void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
462462
MetaTensor* value_grad,
463463
MetaTensor* bias_grad);
464464

465+
// Infer meta for moe_combine_grad.
// NOTE(review): the definition in backward.cc names the fourth parameter
// "y" rather than "grad_y" — the two should agree; confirm which is intended.
void MoeCombineGradInferMeta(const MetaTensor& x,
                             const MetaTensor& combine_weights,
                             const MetaTensor& scatter_index,
                             const MetaTensor& grad_y,
                             MetaTensor* grad_x,
                             MetaTensor* grad_combine_weights_helper);
// Infer meta for moe_gate_dispatch_partial_nosoftmaxtopk_grad.
// Inputs : combine_weights_out, scatter_index, scatter_index_rev,
//          expert_offset, expert_offset_local, y_grad,
//          combine_weights_out_grad
// Attrs  : k, capacity, use_pad, expert_start_index, expert_end_index
// Outputs: x_grad, combine_weights_grad
void MoeGateDispatchPartialNoSoftmaxTopkGradInferMeta(
    const MetaTensor& combine_weights_out,
    const MetaTensor& scatter_index,
    const MetaTensor& scatter_index_rev,
    const MetaTensor& expert_offset,
    const MetaTensor& expert_offset_local,
    const MetaTensor& y_grad,
    const MetaTensor& combine_weights_out_grad,
    int64_t k,
    int64_t capacity,
    bool use_pad,
    int64_t expert_start_index,
    int64_t expert_end_index,
    MetaTensor* x_grad,
    MetaTensor* combine_weights_grad);

// Infer meta for moe_gate_dispatch_permute_grad.
void MoeGateDispatchPermuteGradInferMeta(const MetaTensor& combine_weights,
                                         const MetaTensor& scatter_index,
                                         const MetaTensor& expert_id,
                                         const MetaTensor& y_grad,
                                         const MetaTensor& combine_weights_grad,
                                         int64_t k,
                                         int64_t capacity,
                                         int64_t world_size,
                                         MetaTensor* x_grad,
                                         MetaTensor* gate_logits_grad);
502+
465503
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
466504
const MetaTensor& out_grad,
467505
std::vector<MetaTensor*> x_grad);
@@ -680,4 +718,31 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
680718
MetaTensor* x_grad,
681719
MetaTensor* value_grad);
682720

721+
// Infer meta for cal_aux_loss_grad; gate_prob_grad mirrors gate_prob.
void CalAuxLossGradInferMeta(const MetaTensor& gate_prob,
                             const MetaTensor& seqlen_float,
                             const MetaTensor& ce,
                             const MetaTensor& l_aux_loss_grad,
                             const int64_t num_experts,
                             const bool use_group,
                             const int64_t moe_k,
                             MetaTensor* gate_prob_grad);

// Infer meta for moe_gate_dispatch_grad.
void MoeGateDispatchGradInferMeta(const MetaTensor& combine_weights,
                                  const MetaTensor& scatter_index,
                                  const MetaTensor& expert_id,
                                  const MetaTensor& y_grad,
                                  const MetaTensor& combine_weights_grad,
                                  const int64_t k,
                                  const int64_t capacity,
                                  const bool use_pad,
                                  MetaTensor* x_grad,
                                  MetaTensor* gate_logits_grad);

// Infer meta for fused_rms_norm_grad; x_grad mirrors x, scale_grad mirrors
// scale.
void FusedRMSNormGradInferMeta(const MetaTensor& x,
                               const MetaTensor& scale,
                               const MetaTensor& invvar,
                               const MetaTensor& dy,
                               float epsilon,
                               MetaTensor* x_grad,
                               MetaTensor* scale_grad);
683748
} // namespace phi

paddle/phi/infermeta/binary.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4598,6 +4598,20 @@ void WeightDequantizeInferMeta(const MetaTensor& x,
45984598
out->set_dtype(scale.dtype());
45994599
}
46004600

4601+
void FusedRMSNormInferMeta(const MetaTensor& x,
                           const MetaTensor& scale,
                           float epsilon,
                           MetaTensor* y,
                           MetaTensor* invvar) {
  // Y: same shape, dtype, layout as X
  y->set_dims(x.dims());
  y->set_dtype(x.dtype());
  // invvar: 1-D of length x.dims()[0], always float32 (RMSNorm keeps only
  // the inverse variance; there is no mean output).
  // NOTE(review): for inputs of rank > 2 this uses only the leading dim,
  // not the flattened row count — confirm against the kernel's expectation.
  int64_t rows = x.dims()[0];
  invvar->set_dims(DDim({rows}));
  invvar->set_dtype(DataType::FLOAT32);
}
4614+
46014615
} // namespace phi
46024616

46034617
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);

paddle/phi/infermeta/binary.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,5 +790,10 @@ void WeightDequantizeInferMeta(const MetaTensor& x,
790790
const std::string& algo,
791791
const int32_t group_size,
792792
MetaTensor* out);
793+
// Infer meta for fused_rms_norm: y mirrors x; invvar is a 1-D float32
// tensor of length x.dims()[0].
void FusedRMSNormInferMeta(const MetaTensor& x,
                           const MetaTensor& scale,
                           float epsilon,
                           MetaTensor* y,
                           MetaTensor* invvar);
793798

794799
} // namespace phi

0 commit comments

Comments
 (0)