
Commit 422850e

fix moe_combine bug.
1 parent 89402c3 · commit 422850e

File tree

6 files changed: +10 −11 lines


paddle/phi/kernels/gpu/moe_combine_grad_kernel.cu

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/moe_combine_grad_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
 namespace phi {
 
 template <typename T>
@@ -129,6 +130,8 @@ void MoeCombineGradKernel(const Context& dev_ctx,
                           DenseTensor* grad_combine_weights_helper) {
   dev_ctx.template Alloc<T>(grad_x);
   dev_ctx.template Alloc<T>(grad_combine_weights_helper);
+  phi::Full<T, Context>(dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);
+  phi::Full<T, Context>(dev_ctx, phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())), 0, grad_combine_weights_helper);
   auto x_shape = x.dims();
   auto combine_weights_shape = combine_weights.dims();
   moe_combine_bwd<T, Context>(dev_ctx,
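
Alloc only reserves device memory; it does not clear it. The two added phi::Full calls make the zero fill explicit, which matters if the backward kernel accumulates into grad_x rather than overwriting every element; rows that belong to dropped tokens, for example, may never be touched at all. A toy numpy sketch of that failure mode (the shapes and the accumulation loop are illustrative assumptions, not Paddle's actual kernel):

import numpy as np

def combine_bwd(grad_y, w, scatter_index, grad_x):
    # backward of y[t] = sum_j w[t, j] * x[scatter_index[t, j]]:
    # each referenced row of grad_x receives one weighted copy of grad_y
    for t in range(w.shape[0]):
        for j in range(w.shape[1]):
            grad_x[scatter_index[t, j]] += w[t, j] * grad_y[t]
    return grad_x

s, k, d, rows = 3, 2, 4, 8  # rows 6 and 7 are never referenced ("dropped")
grad_y = np.random.rand(s, d).astype("float32")
w = np.random.rand(s, k).astype("float32")
scatter_index = np.arange(s * k).reshape(s, k)

zeroed = combine_bwd(grad_y, w, scatter_index, np.zeros((rows, d), "float32"))
stale = combine_bwd(grad_y, w, scatter_index, np.full((rows, d), np.nan, "float32"))
print(np.isnan(zeroed).any(), np.isnan(stale).any())  # False True: untouched rows stay stale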

paddle/phi/kernels/gpu/moe_combine_kernel.cu

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "paddle/phi/kernels/moe_combine_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 
 namespace phi {
 
@@ -91,6 +92,7 @@ void moe_combine_fwd(const Context& dev_ctx,
                      DenseTensor* y) {
   dev_ctx.template Alloc<T>(y);  // T cannot support phi::dtype::float8 very
                                  // well, maybe replaced with x.dtype();
+  phi::Full<T, Context>(dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
   auto combine_weights_shape = combine_weights.dims();
   auto x_shape = x.dims();
   moe_combine_fwd<T, Context>(dev_ctx,
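
The forward output gets the same treatment. If the kernel accumulates the k weighted expert rows into y in global memory, the first add reads whatever Alloc left behind. Again a toy numpy stand-in under the same assumptions, not Paddle's kernel:

import numpy as np

def combine_fwd(x, w, scatter_index, y):
    # y[t] += w[t, j] * x[scatter_index[t, j]] for each of the token's k rows
    for t in range(w.shape[0]):
        for j in range(w.shape[1]):
            y[t] += w[t, j] * x[scatter_index[t, j]]
    return y

s, k, d = 4, 2, 8
x = np.random.rand(s * k, d).astype("float32")
w = np.random.rand(s, k).astype("float32")
scatter_index = np.arange(s * k).reshape(s, k)

clean = combine_fwd(x, w, scatter_index, np.zeros((s, d), "float32"))
dirty = combine_fwd(x, w, scatter_index, np.full((s, d), 123.0, "float32"))
print(np.allclose(clean, dirty))  # False: the stale 123s leak into every row of y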

paddle/phi/kernels/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu

Lines changed: 0 additions & 1 deletion
@@ -110,7 +110,6 @@ void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(const Context& dev_ctx,
                                                    int64_t expert_end_index,
                                                    DenseTensor* x_grad,
                                                    DenseTensor* combine_weights_grad){
-  printf("MoeGateDispatchPartialNoSoftMaxTopkGradKernel begin\n");
   dev_ctx.template Alloc<T>(x_grad);
   dev_ctx.template Alloc<float>(combine_weights_grad);
   // DenseTensor t_scatter_index;

test/legacy_test/ernie_utils/moe_layer_uneven.py

Lines changed: 3 additions & 5 deletions
@@ -218,7 +218,6 @@ def forward(ctx, x, combine_weights, scatter_index):
 
     @staticmethod
     def backward(ctx, grad_y, *_):
-        '''
         """
         Input:
             grad_y: [seqlen, hidden_size]
@@ -243,10 +242,9 @@ def backward(ctx, grad_y, *_):
         # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim]
         # reduce the hidden shape
         # TODO: implement reduce in cuda ops
-        #grad_combine_weight = grad_combine_weight_helper.sum(-1)
-        #return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None
-        return grad_x, grad_combine_weight_helper
-        '''
+        grad_combine_weight = grad_combine_weight_helper.sum(-1)
+        return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None
+        #return grad_x, grad_combine_weight_helper
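
The re-enabled branch turns the helper back into the combine-weights gradient. Per the comment above it, grad_combine_weight_helper has the same shape as grad_x ([seqlen * K, dim]), so each weight's gradient is a dot product over the hidden dimension, recovered here by sum(-1) plus a reshape. A numpy sketch of that shape algebra (the helper construction is an assumption read off the comments):

import numpy as np

s, k, d = 4, 2, 8
grad_y = np.random.rand(s, d).astype("float32")
x = np.random.rand(s * k, d).astype("float32")
scatter_index = np.arange(s * k).reshape(s, k)

# helper row i pairs token i // k with slot i % k: grad_y[token] * x[row]
helper = grad_y.repeat(k, axis=0) * x[scatter_index.reshape(-1)]  # [s * k, d]

# what the patch computes: reduce the hidden dim, reshape to combine_weights.shape
grad_w = helper.sum(-1).reshape(s, k)

# the same quantity directly: d y[t] / d w[t, j] = <grad_y[t], x[scatter[t, j]]>
print(np.allclose(grad_w, np.einsum("td,tjd->tj", grad_y, x[scatter_index])))  # True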

test/legacy_test/test_incubate_moe_combine.py

Lines changed: 2 additions & 2 deletions
@@ -66,8 +66,8 @@ def test_moe_combine(x_numpy, combine_weights_numpy, scatter_index_numpy, grad_n
     grad = paddle.to_tensor(grad_numpy).cast("float32")
 
     y = GateCombine.apply(x, combine_weights, scatter_index)
-    #paddle.autograd.backward([y], [grad], True)
-    grad.backward()
+    paddle.autograd.backward([y], [grad], True)
+    #grad.backward()
     return [x.grad, combine_weights.grad, y]
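
The swap in the test is about where the autograd graph lives: grad is a constant built from numpy, so grad.backward() differentiates grad's own empty graph and never reaches x, while paddle.autograd.backward([y], [grad], True) seeds dy with grad and back-propagates through y. A standalone toy version (the two-line graph stands in for GateCombine.apply):

import paddle

x = paddle.randn([4, 8])
x.stop_gradient = False
y = (x * 2).sum(-1)         # stand-in for GateCombine.apply(...)
seed = paddle.ones_like(y)  # plays the role of `grad`: a constant with no graph

paddle.autograd.backward([y], [seed], True)  # seeds dy, reaches x through y
print(x.grad.shape)  # [4, 8]; backward through `seed` alone never updates x.grad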

test/legacy_test/test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk.py

Lines changed: 0 additions & 3 deletions
@@ -12,7 +12,6 @@
 
 
 def test_moe_dispatch_partial_nosoftmaxtopk_nonepad_op():
-    import moe_ops_partial_nosoftmaxtopk
 
     s, d, e = 4, 100, 8
     k, cap = 4, 3
@@ -137,7 +136,6 @@ def check_ascend(index_rev, chunks):
 
 
 def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop():
-    import moe_ops_partial_nosoftmaxtopk
 
     S, E, D = 3, 4, 3
     k = 2
@@ -162,7 +160,6 @@ def test_moe_ops_partial_nosoftmaxtopk_w_reverse_token_drop():
 
 
 def test_moe_ops_partial_nosoftmax_topk_empty_output():
-    import moe_ops_partial_nosoftmaxtopk
 
     S, E, D = 3, 4, 3
     k = 2

0 commit comments
