Skip to content

Commit ed8e4df

Browse files
committed
add multiply complex
1 parent f1b794d commit ed8e4df

File tree

3 files changed

+139
-2
lines changed

3 files changed

+139
-2
lines changed

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,17 @@ XPUOpMap& get_kl3_ops() {
460460
{"elementwise_mul_grad",
461461
XPUKernelSet({phi::DataType::FLOAT32,
462462
phi::DataType::FLOAT16,
463+
#ifdef PADDLE_WITH_XPU_FFT
464+
phi::DataType::COMPLEX64,
465+
#endif
463466
phi::DataType::BFLOAT16})},
464467
{"elementwise_mul",
465468
XPUKernelSet({phi::DataType::FLOAT32,
466469
phi::DataType::FLOAT16,
467470
phi::DataType::BFLOAT16,
471+
#ifdef PADDLE_WITH_XPU_FFT
472+
phi::DataType::COMPLEX64,
473+
#endif
468474
phi::DataType::INT32,
469475
phi::DataType::INT64})},
470476
{"elementwise_pow",

paddle/phi/kernels/xpu/elementwise_multiply_grad_kernel.cc

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919

2020
#include "paddle/phi/backends/xpu/xpu_context.h"
2121
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/complex_kernel.h"
23+
#include "paddle/phi/kernels/elementwise_add_kernel.h"
24+
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
25+
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
26+
#include "paddle/phi/kernels/expand_grad_kernel.h"
2227
#include "paddle/phi/kernels/funcs/elementwise_base.h"
2328
#include "paddle/phi/kernels/xpu/elementwise.h"
2429

@@ -50,6 +55,94 @@ void MultiplyGradKernel(const Context& dev_ctx,
5055
XPUElementwiseGrad<T, XPUType>(dev_ctx, x, y, dout, axis, dx, dy, f, true);
5156
}
5257

58+
#ifdef PADDLE_WITH_XPU_FFT
59+
template <>
60+
void MultiplyGradKernel<phi::dtype::complex<float>, XPUContext>(
61+
const XPUContext& dev_ctx,
62+
const DenseTensor& x,
63+
const DenseTensor& y,
64+
const DenseTensor& dout,
65+
int axis,
66+
DenseTensor* dx,
67+
DenseTensor* dy) {
68+
using T = phi::dtype::complex<float>;
69+
funcs::ElementwiseGradPreProcess(dout, dx);
70+
// The current complex number implementation uses separate real/imaginary
71+
// parts,resulting in redundant operations and performance
72+
// penalties.Optimization should address this in future iterations.
73+
DenseTensor dout_real = Real<T, XPUContext>(dev_ctx, dout);
74+
DenseTensor dout_imag = Imag<T, XPUContext>(dev_ctx, dout);
75+
76+
if (dx) {
77+
DenseTensor y_real = Real<T, XPUContext>(dev_ctx, y);
78+
DenseTensor y_imag = Imag<T, XPUContext>(dev_ctx, y);
79+
DenseTensor dx_real = Add<float, XPUContext>(
80+
dev_ctx,
81+
Multiply<float, XPUContext>(dev_ctx, dout_real, y_real),
82+
Multiply<float, XPUContext>(dev_ctx, dout_imag, y_imag));
83+
DenseTensor dx_imag = Subtract<float, XPUContext>(
84+
dev_ctx,
85+
Multiply<float, XPUContext>(dev_ctx, dout_imag, y_real),
86+
Multiply<float, XPUContext>(dev_ctx, dout_real, y_imag));
87+
dev_ctx.template Alloc<T>(dx);
88+
if (x.dims() == dout.dims()) {
89+
phi::ComplexKernel<float>(dev_ctx, dx_real, dx_imag, dx);
90+
} else {
91+
DenseTensor dx_real_expanded, dx_imag_expanded;
92+
dx_real_expanded.Resize(dx->dims());
93+
dx_imag_expanded.Resize(dx->dims());
94+
ExpandGradKernel<float, XPUContext>(
95+
dev_ctx,
96+
x,
97+
dx_real,
98+
phi::IntArray(phi::vectorize(x.dims())),
99+
&dx_real_expanded);
100+
ExpandGradKernel<float, XPUContext>(
101+
dev_ctx,
102+
x,
103+
dx_imag,
104+
phi::IntArray(phi::vectorize(x.dims())),
105+
&dx_imag_expanded);
106+
phi::ComplexKernel<float>(
107+
dev_ctx, dx_real_expanded, dx_imag_expanded, dx);
108+
}
109+
}
110+
if (dy) {
111+
DenseTensor x_real = Real<T, XPUContext>(dev_ctx, x);
112+
DenseTensor x_imag = Imag<T, XPUContext>(dev_ctx, x);
113+
DenseTensor dy_real = Add<float, XPUContext>(
114+
dev_ctx,
115+
Multiply<float, XPUContext>(dev_ctx, dout_real, x_real),
116+
Multiply<float, XPUContext>(dev_ctx, dout_imag, x_imag));
117+
DenseTensor dy_imag = Subtract<float, XPUContext>(
118+
dev_ctx,
119+
Multiply<float, XPUContext>(dev_ctx, dout_imag, x_real),
120+
Multiply<float, XPUContext>(dev_ctx, dout_real, x_imag));
121+
dev_ctx.template Alloc<T>(dy);
122+
if (y.dims() == dout.dims()) {
123+
phi::ComplexKernel<float>(dev_ctx, dy_real, dy_imag, dy);
124+
} else {
125+
DenseTensor dy_real_expanded, dy_imag_expanded;
126+
dy_real_expanded.Resize(dy->dims());
127+
dy_imag_expanded.Resize(dy->dims());
128+
ExpandGradKernel<float, XPUContext>(
129+
dev_ctx,
130+
y,
131+
dy_real,
132+
phi::IntArray(phi::vectorize(y.dims())),
133+
&dy_real_expanded);
134+
ExpandGradKernel<float, XPUContext>(
135+
dev_ctx,
136+
y,
137+
dy_imag,
138+
phi::IntArray(phi::vectorize(y.dims())),
139+
&dy_imag_expanded);
140+
phi::ComplexKernel<float>(
141+
dev_ctx, dy_real_expanded, dy_imag_expanded, dy);
142+
}
143+
}
144+
}
145+
#endif
53146
} // namespace phi
54147

55148
PD_REGISTER_KERNEL(multiply_grad,
@@ -58,4 +151,8 @@ PD_REGISTER_KERNEL(multiply_grad,
58151
phi::MultiplyGradKernel,
59152
phi::dtype::float16,
60153
phi::dtype::bfloat16,
61-
float) {}
154+
#ifdef PADDLE_WITH_XPU_FFT
155+
phi::dtype::complex<float>,
156+
#endif
157+
float) {
158+
}

paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
#include "paddle/phi/backends/xpu/xpu_context.h"
2121
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/complex_kernel.h"
23+
#include "paddle/phi/kernels/elementwise_add_kernel.h"
24+
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
2225
#include "paddle/phi/kernels/funcs/elementwise_base.h"
2326
#include "paddle/phi/kernels/xpu/elementwise.h"
2427

@@ -42,6 +45,33 @@ void MultiplyKernel(const Context& dev_ctx,
4245
XPUElementwise<T, XPUType>(dev_ctx, x, y, -1, out, f);
4346
}
4447

48+
#ifdef PADDLE_WITH_XPU_FFT
49+
template <>
50+
void MultiplyKernel<phi::dtype::complex<float>, XPUContext>(
51+
const XPUContext& dev_ctx,
52+
const DenseTensor& x,
53+
const DenseTensor& y,
54+
DenseTensor* out) {
55+
using T = phi::dtype::complex<float>;
56+
// The current complex number implementation uses separate real/imaginary
57+
// parts,resulting in redundant operations and performance
58+
// penalties.Optimization should address this in future iterations.
59+
const DenseTensor x_real = Real<T, XPUContext>(dev_ctx, x);
60+
const DenseTensor x_imag = Imag<T, XPUContext>(dev_ctx, x);
61+
const DenseTensor y_real = Real<T, XPUContext>(dev_ctx, y);
62+
const DenseTensor y_imag = Imag<T, XPUContext>(dev_ctx, y);
63+
DenseTensor real_out = Subtract<float, XPUContext>(
64+
dev_ctx,
65+
Multiply<float, XPUContext>(dev_ctx, x_real, y_real),
66+
Multiply<float, XPUContext>(dev_ctx, x_imag, y_imag));
67+
DenseTensor imag_out = Add<float, XPUContext>(
68+
dev_ctx,
69+
Multiply<float, XPUContext>(dev_ctx, x_real, y_imag),
70+
Multiply<float, XPUContext>(dev_ctx, x_imag, y_real));
71+
phi::ComplexKernel<float>(dev_ctx, real_out, imag_out, out);
72+
}
73+
#endif
74+
4575
} // namespace phi
4676

4777
PD_REGISTER_KERNEL(multiply,
@@ -50,6 +80,10 @@ PD_REGISTER_KERNEL(multiply,
5080
phi::MultiplyKernel,
5181
phi::dtype::float16,
5282
phi::dtype::bfloat16,
83+
#ifdef PADDLE_WITH_XPU_FFT
84+
phi::dtype::complex<float>,
85+
#endif
5386
float,
5487
int,
55-
int64_t) {}
88+
int64_t) {
89+
}

0 commit comments

Comments
 (0)