@@ -19,6 +19,12 @@
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/elementwise_add_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
+#include "paddle/phi/kernels/expand_grad_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/xpu/elementwise.h"
@@ -33,6 +39,25 @@ void MultiplyGradKernel(const Context& dev_ctx,
                         DenseTensor* dx,
                         DenseTensor* dy) {
   using XPUType = typename XPUTypeTrait<T>::Type;
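+  // Zero-size fast path: an empty dout propagates nothing, so allocate
+  // empty outputs as-is and zero-fill gradients with non-empty shapes.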
+  if (dout.numel() == 0) {
+    if (dx) {
+      if (dx->numel() == 0) {
+        dev_ctx.template Alloc<T>(dx);
+      } else {
+        phi::Full<T, Context>(
+            dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+      }
+    }
+    if (dy) {
+      if (dy->numel() == 0) {
+        dev_ctx.template Alloc<T>(dy);
+      } else {
+        phi::Full<T, Context>(
+            dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
+      }
+    }
+    return;
+  }
   funcs::ElementwiseGradPreProcess(dout, dx);
   auto f = [](xpu::Context* ctx,
               const XPUType* x,
@@ -50,6 +75,113 @@ void MultiplyGradKernel(const Context& dev_ctx,
   XPUElementwiseGrad<T, XPUType>(dev_ctx, x, y, dout, axis, dx, dy, f, true);
 }

+#ifdef PADDLE_WITH_XPU_FFT
+template <>
+void MultiplyGradKernel<phi::dtype::complex<float>, XPUContext>(
+    const XPUContext& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& y,
+    const DenseTensor& dout,
+    int axis,
+    DenseTensor* dx,
+    DenseTensor* dy) {
+  using T = phi::dtype::complex<float>;
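+  // Same zero-size fast path as the generic kernel, with an explicit
+  // complex zero for the fill value.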
+  if (dout.numel() == 0) {
+    if (dx) {
+      if (dx->numel() == 0) {
+        dev_ctx.template Alloc<T>(dx);
+      } else {
+        phi::Full<T, XPUContext>(
+            dev_ctx, phi::IntArray(common::vectorize(dx->dims())), T(0), dx);
+      }
+    }
+    if (dy) {
+      if (dy->numel() == 0) {
+        dev_ctx.template Alloc<T>(dy);
+      } else {
+        phi::Full<T, XPUContext>(
+            dev_ctx, phi::IntArray(common::vectorize(dy->dims())), T(0), dy);
+      }
+    }
+    return;
+  }
+  funcs::ElementwiseGradPreProcess(dout, dx);
+  // The current complex implementation operates on separate real/imaginary
+  // parts, resulting in redundant operations and performance penalties.
+  // Future iterations should optimize this.
+  DenseTensor dout_real = Real<T, XPUContext>(dev_ctx, dout);
+  DenseTensor dout_imag = Imag<T, XPUContext>(dev_ctx, dout);
+
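+  // The formulas below compute dx = dout * conj(y) and dy = dout * conj(x),
+  // the gradients of out = x * y, expanded into real/imaginary parts:
+  //   real(dout * conj(y)) = dout_real * y_real + dout_imag * y_imag
+  //   imag(dout * conj(y)) = dout_imag * y_real - dout_real * y_imag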
+  if (dx) {
+    DenseTensor y_real = Real<T, XPUContext>(dev_ctx, y);
+    DenseTensor y_imag = Imag<T, XPUContext>(dev_ctx, y);
+    DenseTensor dx_real = Add<float, XPUContext>(
+        dev_ctx,
+        Multiply<float, XPUContext>(dev_ctx, dout_real, y_real),
+        Multiply<float, XPUContext>(dev_ctx, dout_imag, y_imag));
+    DenseTensor dx_imag = Subtract<float, XPUContext>(
+        dev_ctx,
+        Multiply<float, XPUContext>(dev_ctx, dout_imag, y_real),
+        Multiply<float, XPUContext>(dev_ctx, dout_real, y_imag));
+    dev_ctx.template Alloc<T>(dx);
+    if (x.dims() == dout.dims()) {
+      phi::ComplexKernel<float>(dev_ctx, dx_real, dx_imag, dx);
+    } else {
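+      // Broadcast case: dx has x's shape, so reduce the dout-shaped
+      // real/imag gradients back to x's shape with ExpandGradKernel
+      // before recombining them into a complex tensor.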
+      DenseTensor dx_real_expanded, dx_imag_expanded;
+      dx_real_expanded.Resize(dx->dims());
+      dx_imag_expanded.Resize(dx->dims());
+      ExpandGradKernel<float, XPUContext>(
+          dev_ctx,
+          x,
+          dx_real,
+          phi::IntArray(phi::vectorize(x.dims())),
+          &dx_real_expanded);
+      ExpandGradKernel<float, XPUContext>(
+          dev_ctx,
+          x,
+          dx_imag,
+          phi::IntArray(phi::vectorize(x.dims())),
+          &dx_imag_expanded);
+      phi::ComplexKernel<float>(
+          dev_ctx, dx_real_expanded, dx_imag_expanded, dx);
+    }
+  }
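+  // Mirror of the dx path: dy = dout * conj(x), reduced to y's shape
+  // when y was broadcast.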
+  if (dy) {
+    DenseTensor x_real = Real<T, XPUContext>(dev_ctx, x);
+    DenseTensor x_imag = Imag<T, XPUContext>(dev_ctx, x);
+    DenseTensor dy_real = Add<float, XPUContext>(
+        dev_ctx,
+        Multiply<float, XPUContext>(dev_ctx, dout_real, x_real),
+        Multiply<float, XPUContext>(dev_ctx, dout_imag, x_imag));
+    DenseTensor dy_imag = Subtract<float, XPUContext>(
+        dev_ctx,
+        Multiply<float, XPUContext>(dev_ctx, dout_imag, x_real),
+        Multiply<float, XPUContext>(dev_ctx, dout_real, x_imag));
+    dev_ctx.template Alloc<T>(dy);
+    if (y.dims() == dout.dims()) {
+      phi::ComplexKernel<float>(dev_ctx, dy_real, dy_imag, dy);
+    } else {
+      DenseTensor dy_real_expanded, dy_imag_expanded;
+      dy_real_expanded.Resize(dy->dims());
+      dy_imag_expanded.Resize(dy->dims());
+      ExpandGradKernel<float, XPUContext>(
+          dev_ctx,
+          y,
+          dy_real,
+          phi::IntArray(phi::vectorize(y.dims())),
+          &dy_real_expanded);
+      ExpandGradKernel<float, XPUContext>(
+          dev_ctx,
+          y,
+          dy_imag,
+          phi::IntArray(phi::vectorize(y.dims())),
+          &dy_imag_expanded);
+      phi::ComplexKernel<float>(
+          dev_ctx, dy_real_expanded, dy_imag_expanded, dy);
+    }
+  }
+}
+#endif
 }  // namespace phi

 PD_REGISTER_KERNEL(multiply_grad,
@@ -58,4 +190,8 @@ PD_REGISTER_KERNEL(multiply_grad,
                    phi::MultiplyGradKernel,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
-                   float) {}
+#ifdef PADDLE_WITH_XPU_FFT
+                   phi::dtype::complex<float>,
+#endif
+                   float) {
+}
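For reference, a minimal standalone sketch (not part of the patch) checking the real/imaginary decomposition the complex specialization relies on; it uses `std::complex<float>` in place of `phi::dtype::complex<float>`, and the sample values are arbitrary.

```cpp
#include <cassert>
#include <cmath>
#include <complex>

// Checks that the kernel's part-wise formulas match dout * conj(y),
// the gradient of out = x * y with respect to x.
int main() {
  std::complex<float> y(2.0f, -3.0f);    // arbitrary sample operand
  std::complex<float> dout(0.5f, 1.5f);  // arbitrary upstream gradient

  // Reference gradient via complex arithmetic.
  std::complex<float> dx_ref = dout * std::conj(y);

  // Kernel-style computation on separate real/imaginary parts.
  float dx_real = dout.real() * y.real() + dout.imag() * y.imag();
  float dx_imag = dout.imag() * y.real() - dout.real() * y.imag();

  assert(std::abs(dx_ref.real() - dx_real) < 1e-6f);
  assert(std::abs(dx_ref.imag() - dx_imag) < 1e-6f);
  return 0;
}
```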