Skip to content

Commit ec5e4df

Browse files
authored
[XPU]Support complex for multiply (#72982)
* add multiply complex * add 0size * fix test
1 parent 213e076 commit ec5e4df

File tree

4 files changed

+345
-4
lines changed

4 files changed

+345
-4
lines changed

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,17 @@ XPUOpMap& get_kl3_ops() {
460460
{"elementwise_mul_grad",
461461
XPUKernelSet({phi::DataType::FLOAT32,
462462
phi::DataType::FLOAT16,
463+
#ifdef PADDLE_WITH_XPU_FFT
464+
phi::DataType::COMPLEX64,
465+
#endif
463466
phi::DataType::BFLOAT16})},
464467
{"elementwise_mul",
465468
XPUKernelSet({phi::DataType::FLOAT32,
466469
phi::DataType::FLOAT16,
467470
phi::DataType::BFLOAT16,
471+
#ifdef PADDLE_WITH_XPU_FFT
472+
phi::DataType::COMPLEX64,
473+
#endif
468474
phi::DataType::INT32,
469475
phi::DataType::INT64})},
470476
{"elementwise_pow",

paddle/phi/kernels/xpu/elementwise_multiply_grad_kernel.cc

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919

2020
#include "paddle/phi/backends/xpu/xpu_context.h"
2121
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/complex_kernel.h"
23+
#include "paddle/phi/kernels/elementwise_add_kernel.h"
24+
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
25+
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
26+
#include "paddle/phi/kernels/expand_grad_kernel.h"
27+
#include "paddle/phi/kernels/full_kernel.h"
2228
#include "paddle/phi/kernels/funcs/elementwise_base.h"
2329
#include "paddle/phi/kernels/xpu/elementwise.h"
2430

@@ -33,6 +39,25 @@ void MultiplyGradKernel(const Context& dev_ctx,
3339
DenseTensor* dx,
3440
DenseTensor* dy) {
3541
using XPUType = typename XPUTypeTrait<T>::Type;
42+
if (dout.numel() == 0) {
43+
if (dx) {
44+
if (dx->numel() == 0) {
45+
dev_ctx.template Alloc<T>(dx);
46+
} else {
47+
phi::Full<T, Context>(
48+
dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
49+
}
50+
}
51+
if (dy) {
52+
if (dy->numel() == 0) {
53+
dev_ctx.template Alloc<T>(dy);
54+
} else {
55+
phi::Full<T, Context>(
56+
dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
57+
}
58+
}
59+
return;
60+
}
3661
funcs::ElementwiseGradPreProcess(dout, dx);
3762
auto f = [](xpu::Context* ctx,
3863
const XPUType* x,
@@ -50,6 +75,113 @@ void MultiplyGradKernel(const Context& dev_ctx,
5075
XPUElementwiseGrad<T, XPUType>(dev_ctx, x, y, dout, axis, dx, dy, f, true);
5176
}
5277

78+
#ifdef PADDLE_WITH_XPU_FFT
79+
template <>
80+
void MultiplyGradKernel<phi::dtype::complex<float>, XPUContext>(
81+
const XPUContext& dev_ctx,
82+
const DenseTensor& x,
83+
const DenseTensor& y,
84+
const DenseTensor& dout,
85+
int axis,
86+
DenseTensor* dx,
87+
DenseTensor* dy) {
88+
using T = phi::dtype::complex<float>;
89+
if (dout.numel() == 0) {
90+
if (dx) {
91+
if (dx->numel() == 0) {
92+
dev_ctx.template Alloc<T>(dx);
93+
} else {
94+
phi::Full<T, XPUContext>(
95+
dev_ctx, phi::IntArray(common::vectorize(dx->dims())), T(0), dx);
96+
}
97+
}
98+
if (dy) {
99+
if (dy->numel() == 0) {
100+
dev_ctx.template Alloc<T>(dy);
101+
} else {
102+
phi::Full<T, XPUContext>(
103+
dev_ctx, phi::IntArray(common::vectorize(dy->dims())), T(0), dy);
104+
}
105+
}
106+
return;
107+
}
108+
funcs::ElementwiseGradPreProcess(dout, dx);
109+
// The current complex number implementation uses separate real/imaginary
110+
// parts, resulting in redundant operations and performance
111+
// penalties. Optimization should address this in future iterations.
112+
DenseTensor dout_real = Real<T, XPUContext>(dev_ctx, dout);
113+
DenseTensor dout_imag = Imag<T, XPUContext>(dev_ctx, dout);
114+
115+
if (dx) {
116+
DenseTensor y_real = Real<T, XPUContext>(dev_ctx, y);
117+
DenseTensor y_imag = Imag<T, XPUContext>(dev_ctx, y);
118+
DenseTensor dx_real = Add<float, XPUContext>(
119+
dev_ctx,
120+
Multiply<float, XPUContext>(dev_ctx, dout_real, y_real),
121+
Multiply<float, XPUContext>(dev_ctx, dout_imag, y_imag));
122+
DenseTensor dx_imag = Subtract<float, XPUContext>(
123+
dev_ctx,
124+
Multiply<float, XPUContext>(dev_ctx, dout_imag, y_real),
125+
Multiply<float, XPUContext>(dev_ctx, dout_real, y_imag));
126+
dev_ctx.template Alloc<T>(dx);
127+
if (x.dims() == dout.dims()) {
128+
phi::ComplexKernel<float>(dev_ctx, dx_real, dx_imag, dx);
129+
} else {
130+
DenseTensor dx_real_expanded, dx_imag_expanded;
131+
dx_real_expanded.Resize(dx->dims());
132+
dx_imag_expanded.Resize(dx->dims());
133+
ExpandGradKernel<float, XPUContext>(
134+
dev_ctx,
135+
x,
136+
dx_real,
137+
phi::IntArray(phi::vectorize(x.dims())),
138+
&dx_real_expanded);
139+
ExpandGradKernel<float, XPUContext>(
140+
dev_ctx,
141+
x,
142+
dx_imag,
143+
phi::IntArray(phi::vectorize(x.dims())),
144+
&dx_imag_expanded);
145+
phi::ComplexKernel<float>(
146+
dev_ctx, dx_real_expanded, dx_imag_expanded, dx);
147+
}
148+
}
149+
if (dy) {
150+
DenseTensor x_real = Real<T, XPUContext>(dev_ctx, x);
151+
DenseTensor x_imag = Imag<T, XPUContext>(dev_ctx, x);
152+
DenseTensor dy_real = Add<float, XPUContext>(
153+
dev_ctx,
154+
Multiply<float, XPUContext>(dev_ctx, dout_real, x_real),
155+
Multiply<float, XPUContext>(dev_ctx, dout_imag, x_imag));
156+
DenseTensor dy_imag = Subtract<float, XPUContext>(
157+
dev_ctx,
158+
Multiply<float, XPUContext>(dev_ctx, dout_imag, x_real),
159+
Multiply<float, XPUContext>(dev_ctx, dout_real, x_imag));
160+
dev_ctx.template Alloc<T>(dy);
161+
if (y.dims() == dout.dims()) {
162+
phi::ComplexKernel<float>(dev_ctx, dy_real, dy_imag, dy);
163+
} else {
164+
DenseTensor dy_real_expanded, dy_imag_expanded;
165+
dy_real_expanded.Resize(dy->dims());
166+
dy_imag_expanded.Resize(dy->dims());
167+
ExpandGradKernel<float, XPUContext>(
168+
dev_ctx,
169+
y,
170+
dy_real,
171+
phi::IntArray(phi::vectorize(y.dims())),
172+
&dy_real_expanded);
173+
ExpandGradKernel<float, XPUContext>(
174+
dev_ctx,
175+
y,
176+
dy_imag,
177+
phi::IntArray(phi::vectorize(y.dims())),
178+
&dy_imag_expanded);
179+
phi::ComplexKernel<float>(
180+
dev_ctx, dy_real_expanded, dy_imag_expanded, dy);
181+
}
182+
}
183+
}
184+
#endif
53185
} // namespace phi
54186

55187
PD_REGISTER_KERNEL(multiply_grad,
@@ -58,4 +190,8 @@ PD_REGISTER_KERNEL(multiply_grad,
58190
phi::MultiplyGradKernel,
59191
phi::dtype::float16,
60192
phi::dtype::bfloat16,
61-
float) {}
193+
#ifdef PADDLE_WITH_XPU_FFT
194+
phi::dtype::complex<float>,
195+
#endif
196+
float) {
197+
}

paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
#include "paddle/phi/backends/xpu/xpu_context.h"
2121
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/complex_kernel.h"
23+
#include "paddle/phi/kernels/elementwise_add_kernel.h"
24+
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
2225
#include "paddle/phi/kernels/funcs/elementwise_base.h"
2326
#include "paddle/phi/kernels/xpu/elementwise.h"
2427

@@ -30,6 +33,10 @@ void MultiplyKernel(const Context& dev_ctx,
3033
const DenseTensor& y,
3134
DenseTensor* out) {
3235
using XPUType = typename XPUTypeTrait<T>::Type;
36+
if (out->numel() == 0) {
37+
dev_ctx.template Alloc<T>(out);
38+
return;
39+
}
3340
auto f = [](xpu::Context* ctx,
3441
const XPUType* x,
3542
const XPUType* y,
@@ -42,6 +49,37 @@ void MultiplyKernel(const Context& dev_ctx,
4249
XPUElementwise<T, XPUType>(dev_ctx, x, y, -1, out, f);
4350
}
4451

52+
#ifdef PADDLE_WITH_XPU_FFT
53+
template <>
54+
void MultiplyKernel<phi::dtype::complex<float>, XPUContext>(
55+
const XPUContext& dev_ctx,
56+
const DenseTensor& x,
57+
const DenseTensor& y,
58+
DenseTensor* out) {
59+
using T = phi::dtype::complex<float>;
60+
if (out->numel() == 0) {
61+
dev_ctx.template Alloc<T>(out);
62+
return;
63+
}
64+
// The current complex number implementation uses separate real/imaginary
65+
// parts, resulting in redundant operations and performance
66+
// penalties. Optimization should address this in future iterations.
67+
const DenseTensor x_real = Real<T, XPUContext>(dev_ctx, x);
68+
const DenseTensor x_imag = Imag<T, XPUContext>(dev_ctx, x);
69+
const DenseTensor y_real = Real<T, XPUContext>(dev_ctx, y);
70+
const DenseTensor y_imag = Imag<T, XPUContext>(dev_ctx, y);
71+
DenseTensor real_out = Subtract<float, XPUContext>(
72+
dev_ctx,
73+
Multiply<float, XPUContext>(dev_ctx, x_real, y_real),
74+
Multiply<float, XPUContext>(dev_ctx, x_imag, y_imag));
75+
DenseTensor imag_out = Add<float, XPUContext>(
76+
dev_ctx,
77+
Multiply<float, XPUContext>(dev_ctx, x_real, y_imag),
78+
Multiply<float, XPUContext>(dev_ctx, x_imag, y_real));
79+
phi::ComplexKernel<float>(dev_ctx, real_out, imag_out, out);
80+
}
81+
#endif
82+
4583
} // namespace phi
4684

4785
PD_REGISTER_KERNEL(multiply,
@@ -50,6 +88,10 @@ PD_REGISTER_KERNEL(multiply,
5088
phi::MultiplyKernel,
5189
phi::dtype::float16,
5290
phi::dtype::bfloat16,
91+
#ifdef PADDLE_WITH_XPU_FFT
92+
phi::dtype::complex<float>,
93+
#endif
5394
float,
5495
int,
55-
int64_t) {}
96+
int64_t) {
97+
}

0 commit comments

Comments
 (0)