#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/elementwise_add_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
+#include "paddle/phi/kernels/expand_grad_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/xpu/elementwise.h"
@@ -50,6 +55,94 @@ void MultiplyGradKernel(const Context& dev_ctx,
  XPUElementwiseGrad<T, XPUType>(dev_ctx, x, y, dout, axis, dx, dy, f, true);
}

+#ifdef PADDLE_WITH_XPU_FFT
+template <>
+void MultiplyGradKernel<phi::dtype::complex<float>, XPUContext>(
+    const XPUContext& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& y,
+    const DenseTensor& dout,
+    int axis,
+    DenseTensor* dx,
+    DenseTensor* dy) {
+  using T = phi::dtype::complex<float>;
+  funcs::ElementwiseGradPreProcess(dout, dx);
+  // The current complex-number implementation operates on separate real and
+  // imaginary parts, which introduces redundant operations and a performance
+  // penalty; future iterations should optimize this.
+  DenseTensor dout_real = Real<T, XPUContext>(dev_ctx, dout);
+  DenseTensor dout_imag = Imag<T, XPUContext>(dev_ctx, dout);
+
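+  // dL/dx = dout * conj(y), computed componentwise:
+  // Re(dx) = Re(dout)*Re(y) + Im(dout)*Im(y),
+  // Im(dx) = Im(dout)*Re(y) - Re(dout)*Im(y).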
+  if (dx) {
+    DenseTensor y_real = Real<T, XPUContext>(dev_ctx, y);
+    DenseTensor y_imag = Imag<T, XPUContext>(dev_ctx, y);
+    DenseTensor dx_real = Add<float, XPUContext>(
+        dev_ctx,
+        Multiply<float, XPUContext>(dev_ctx, dout_real, y_real),
+        Multiply<float, XPUContext>(dev_ctx, dout_imag, y_imag));
+    DenseTensor dx_imag = Subtract<float, XPUContext>(
+        dev_ctx,
+        Multiply<float, XPUContext>(dev_ctx, dout_imag, y_real),
+        Multiply<float, XPUContext>(dev_ctx, dout_real, y_imag));
+    dev_ctx.template Alloc<T>(dx);
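+    // If x and dout share a shape, recombine the parts directly; otherwise x
+    // was broadcast, so reduce each part back to x's shape via
+    // ExpandGradKernel before forming the complex gradient.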
+    if (x.dims() == dout.dims()) {
+      phi::ComplexKernel<float>(dev_ctx, dx_real, dx_imag, dx);
+    } else {
+      DenseTensor dx_real_expanded, dx_imag_expanded;
+      dx_real_expanded.Resize(dx->dims());
+      dx_imag_expanded.Resize(dx->dims());
+      ExpandGradKernel<float, XPUContext>(
+          dev_ctx,
+          x,
+          dx_real,
+          phi::IntArray(phi::vectorize(x.dims())),
+          &dx_real_expanded);
+      ExpandGradKernel<float, XPUContext>(
+          dev_ctx,
+          x,
+          dx_imag,
+          phi::IntArray(phi::vectorize(x.dims())),
+          &dx_imag_expanded);
+      phi::ComplexKernel<float>(
+          dev_ctx, dx_real_expanded, dx_imag_expanded, dx);
+    }
+  }
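+  // Symmetrically, dL/dy = dout * conj(x), computed on the real and imaginary
+  // parts and reduced back to y's shape when y was broadcast.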
+  if (dy) {
+    DenseTensor x_real = Real<T, XPUContext>(dev_ctx, x);
+    DenseTensor x_imag = Imag<T, XPUContext>(dev_ctx, x);
+    DenseTensor dy_real = Add<float, XPUContext>(
+        dev_ctx,
+        Multiply<float, XPUContext>(dev_ctx, dout_real, x_real),
+        Multiply<float, XPUContext>(dev_ctx, dout_imag, x_imag));
+    DenseTensor dy_imag = Subtract<float, XPUContext>(
+        dev_ctx,
+        Multiply<float, XPUContext>(dev_ctx, dout_imag, x_real),
+        Multiply<float, XPUContext>(dev_ctx, dout_real, x_imag));
+    dev_ctx.template Alloc<T>(dy);
+    if (y.dims() == dout.dims()) {
+      phi::ComplexKernel<float>(dev_ctx, dy_real, dy_imag, dy);
+    } else {
+      DenseTensor dy_real_expanded, dy_imag_expanded;
+      dy_real_expanded.Resize(dy->dims());
+      dy_imag_expanded.Resize(dy->dims());
+      ExpandGradKernel<float, XPUContext>(
+          dev_ctx,
+          y,
+          dy_real,
+          phi::IntArray(phi::vectorize(y.dims())),
+          &dy_real_expanded);
+      ExpandGradKernel<float, XPUContext>(
+          dev_ctx,
+          y,
+          dy_imag,
+          phi::IntArray(phi::vectorize(y.dims())),
+          &dy_imag_expanded);
+      phi::ComplexKernel<float>(
+          dev_ctx, dy_real_expanded, dy_imag_expanded, dy);
+    }
+  }
+}
+#endif
}  // namespace phi

PD_REGISTER_KERNEL(multiply_grad,
@@ -58,4 +151,8 @@ PD_REGISTER_KERNEL(multiply_grad,
                   phi::MultiplyGradKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
-                   float) {}
+#ifdef PADDLE_WITH_XPU_FFT
+                   phi::dtype::complex<float>,
+#endif
+                   float) {
+}