Skip to content

Commit 36a450f

Browse files
committed
Add missing yamls
1 parent d4a3472 commit 36a450f

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

paddle/phi/ops/yaml/backward.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2280,6 +2280,16 @@
22802280
kernel :
22812281
func : moe_combine_grad
22822282

2283+
# Backward-op entry declaring moe_gate_dispatch_grad, the gradient op paired with moe_gate_dispatch.
# NOTE(review): this text is a rendered diff; the bare number lines (e.g. "2284+") are diff
# gutter line numbers, and YAML indentation has been stripped. `func` / `data_type` are
# presumably nested under `infer_meta` / `kernel` in the real file; confirm there.
- backward_op : moe_gate_dispatch_grad
2284+
# Forward signature this gradient is tied to; it mirrors the args/output of the moe_gate_dispatch op entry.
forward : moe_gate_dispatch (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id)
2285+
# Inputs: three forward outputs (combine_weights, scatter_index, expert_id), the *_grad tensors
# for forward outputs y and combine_weights, and the forward attributes k, capacity, use_pad.
args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, bool use_pad)
2286+
# Gradients are declared only for x and gate_logits; no gradient output is listed for corr_bias.
output : Tensor(x_grad), Tensor(gate_logits_grad)
2287+
infer_meta :
2288+
func : MoeGateDispatchGradInferMeta
2289+
kernel :
2290+
func : moe_gate_dispatch_grad
2291+
# Presumably selects the kernel dtype from y_grad; confirm against Paddle's op YAML schema.
data_type : y_grad
2292+
22832293
- backward_op : moe_gate_dispatch_partial_nosoftmaxtopk_grad
22842294
forward : moe_gate_dispatch_partial_nosoftmaxtopk (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop) -> Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local)
22852295
args : (Tensor combine_weights_out, Tensor scatter_index, Tensor scatter_index_rev, Tensor expert_offset, Tensor expert_nums_local, Tensor y_grad, Tensor combine_weights_out_grad, int64_t k, int64_t capacity, bool use_pad, int64_t expert_start_index, int64_t expert_end_index)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3637,6 +3637,17 @@
36373637
data_type : x
36383638
backward : moe_combine_grad
36393639

3640+
# Forward-op entry declaring moe_gate_dispatch.
# NOTE(review): this text is a rendered diff; the bare number lines (e.g. "3641+") are diff
# gutter line numbers, and YAML indentation has been stripped. `func` / `data_type` are
# presumably nested under `infer_meta` / `kernel` in the real file; confirm there.
- op : moe_gate_dispatch
3641+
# Three tensor inputs (x, gate_logits, corr_bias) followed by attributes k, capacity, use_pad.
args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad)
3642+
# Five tensor outputs; combine_weights, scatter_index and expert_id are also listed as
# inputs of the moe_gate_dispatch_grad backward entry.
output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id)
3643+
infer_meta :
3644+
func : MoeGateDispatchInferMeta
3645+
kernel :
3646+
func : moe_gate_dispatch
3647+
# Presumably selects the kernel dtype from x; confirm against Paddle's op YAML schema.
data_type : x
3648+
# corr_bias is marked as an optional input.
optional : corr_bias
3649+
# Names the backward op for this entry (declared in backward.yaml).
backward : moe_gate_dispatch_grad
3650+
36403651
- op : moe_gate_dispatch_partial_nosoftmaxtopk
36413652
args : (Tensor x, Tensor combine_weights, Tensor expert_id, int64_t k, int64_t capacity, int64_t num_experts, bool use_pad, int64_t expert_start_index, int64_t expert_end_index, bool reverse_token_drop)
36423653
output : Tensor(y), Tensor(combine_weights_out), Tensor(scatter_index), Tensor(scatter_index_rev), Tensor(expert_offset), Tensor(expert_nums_local)

0 commit comments

Comments
 (0)