Skip to content

Commit 4b32bc2

Browse files
committed
add test
1 parent 8337334 commit 4b32bc2

File tree

3 files changed

+133
-11
lines changed

3 files changed

+133
-11
lines changed

paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,12 +1444,13 @@ struct PowGradDX<phi::dtype::complex<T>> {
14441444
phi::dtype::complex<T> out,
14451445
phi::dtype::complex<T> dout) const {
14461446
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
1447-
return dout * y * pow(x, y - phi::dtype::complex<T>(1, 0));
1447+
return conj(dout * y * pow(x, y - phi::dtype::complex<T>(1, 0)));
14481448
#else
1449-
return dout * y *
1450-
static_cast<phi::dtype::complex<T>>(std::pow(
1451-
static_cast<std::complex<T>>(x),
1452-
static_cast<std::complex<T>>(y - phi::dtype::complex<T>(1, 0))));
1449+
return conj(
1450+
dout * y *
1451+
static_cast<phi::dtype::complex<T>>(std::pow(
1452+
static_cast<std::complex<T>>(x),
1453+
static_cast<std::complex<T>>(y - phi::dtype::complex<T>(1, 0)))));
14531454
#endif
14541455
}
14551456
};
@@ -1462,12 +1463,12 @@ struct PowGradDY<phi::dtype::complex<T>> {
14621463
phi::dtype::complex<T> out,
14631464
phi::dtype::complex<T> dout) const {
14641465
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
1465-
return dout * log(x) * pow(x, y);
1466+
return conj(dout * log(x) * pow(x, y));
14661467
#else
1467-
return dout * static_cast<phi::dtype::complex<T>>(
1468-
std::log(static_cast<std::complex<T>>(x)) *
1469-
std::pow(static_cast<std::complex<T>>(x),
1470-
static_cast<std::complex<T>>(y)));
1468+
return conj(dout * static_cast<phi::dtype::complex<T>>(
1469+
std::log(static_cast<std::complex<T>>(x)) *
1470+
std::pow(static_cast<std::complex<T>>(x),
1471+
static_cast<std::complex<T>>(y))));
14711472
#endif
14721473
}
14731474
};

python/paddle/tensor/math.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def pow(x: Tensor, y: float | Tensor, name: str | None = None) -> Tensor:
531531
532532
533533
Args:
534-
x (Tensor): An N-D Tensor, the data type is bfloat16, float16, float32, float64, int32 or int64.
534+
x (Tensor): An N-D Tensor, the data type is bfloat16, float16, float32, float64, int32, int64, complex64 or complex128.
535535
y (float|int|Tensor): If it is an N-D Tensor, its data type should be the same as `x`.
536536
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
537537
@@ -638,6 +638,7 @@ def _elementwise_op(helper):
638638
"elementwise_mul",
639639
"elementwise_div",
640640
"elementwise_max",
641+
"elementwise_pow",
641642
]
642643
if original_op_type in bf16_and_complex_supported_ops:
643644
data_type = [

test/legacy_test/test_elementwise_pow_op.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,126 @@ def test_grad(self):
249249
np.testing.assert_array_equal(y.gradient(), self.grad_y)
250250

251251

252+
@unittest.skipIf(
    core.is_compiled_with_xpu(),
    "Skip XPU for complex dtype is not fully supported",
)
class TestElementwisePowComplexOp(OpTest):
    """elementwise_pow with a complex64/128 base and a real exponent (1-D case).

    Subclasses override ``setUp`` to vary the shapes/dtypes of X and Y.
    """

    def setUp(self):
        # Metadata the OpTest machinery needs to locate the op and its
        # python/prim entry points.
        self.op_type = "elementwise_pow"
        self.python_api = paddle.pow
        self.public_python_api = paddle.pow
        self.prim_op_type = "prim"

        bases = np.asarray([1 + 2j, 3 + 4j, 5 + 6j])
        exponents = np.asarray([2.0, 3.0, 4.0])
        self.inputs = {'X': bases, 'Y': exponents}
        # Reference result computed with NumPy's complex power.
        self.outputs = {'Out': np.power(bases, exponents)}

    def _get_places(self):
        # Always test on CPU; add a GPU place when this build has CUDA.
        places = [base.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(base.CUDAPlace(0))
        return places

    def test_check_output(self):
        # With explicit attrs we fall back to the legacy static-graph check.
        if hasattr(self, 'attrs'):
            self.check_output(check_dygraph=False)
            return
        self.check_output(check_pir=True, check_symbol_infer=False)

    def test_check_grad_normal(self):
        if hasattr(self, 'attrs'):
            self.check_grad(['X', 'Y'], 'Out', check_dygraph=False)
            return
        self.check_grad(
            ['X', 'Y'],
            'Out',
            check_pir=True,
        )
290+
291+
292+
@unittest.skipIf(
    core.is_compiled_with_xpu(),
    "Skip XPU for complex dtype is not fully supported",
)
class TestElementwisePowComplexOp1(TestElementwisePowComplexOp):
    """complex ** real with random 2-D (3, 4) operands."""

    def setUp(self):
        # DRY fix: the parent already sets op_type / python_api /
        # public_python_api / prim_op_type — reuse it instead of repeating
        # the four metadata lines, then override only the test data.
        super().setUp()
        real_part = np.random.uniform(-5, 5, size=(3, 4))
        imag_part = np.random.uniform(-5, 5, size=(3, 4))
        self.inputs = {
            'X': real_part + 1j * imag_part,
            # Exponent kept in [1, 5) so the power stays well-conditioned.
            'Y': np.random.uniform(1, 5, size=(3, 4)),
        }
        self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
310+
311+
312+
@unittest.skipIf(
    core.is_compiled_with_xpu(),
    "Skip XPU for complex dtype is not fully supported",
)
class TestElementwisePowComplexOp2(TestElementwisePowComplexOp):
    """complex ** real with larger random 2-D (20, 50) operands."""

    def setUp(self):
        # DRY fix: inherit the op/prim metadata from the parent setUp and
        # only replace the data — the four metadata assignments were
        # duplicated verbatim in every subclass.
        super().setUp()
        real_part = np.random.uniform(-5, 5, size=(20, 50))
        imag_part = np.random.uniform(-5, 5, size=(20, 50))
        self.inputs = {
            'X': real_part + 1j * imag_part,
            # Exponent kept in [1, 5) so the power stays well-conditioned.
            'Y': np.random.uniform(1, 5, size=(20, 50)),
        }
        self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
330+
331+
332+
@unittest.skipIf(
    core.is_compiled_with_xpu(),
    "Skip XPU for complex dtype is not fully supported",
)
class TestElementwisePowComplexOp3(TestElementwisePowComplexOp):
    """complex ** real with random 3-D (3, 5, 3) operands."""

    def setUp(self):
        # DRY fix: inherit the op/prim metadata from the parent setUp and
        # only replace the data — the four metadata assignments were
        # duplicated verbatim in every subclass.
        super().setUp()
        real_part = np.random.uniform(-5, 5, size=(3, 5, 3))
        imag_part = np.random.uniform(-5, 5, size=(3, 5, 3))
        self.inputs = {
            'X': real_part + 1j * imag_part,
            # Exponent kept in [1, 5) so the power stays well-conditioned.
            'Y': np.random.uniform(1, 5, size=(3, 5, 3)),
        }
        self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
350+
351+
352+
@unittest.skipIf(
    core.is_compiled_with_xpu(),
    "Skip XPU for complex dtype is not fully supported",
)
class TestElementwisePowComplexOp4(TestElementwisePowComplexOp):
    """complex ** complex: both base and exponent are 3-D complex tensors."""

    def setUp(self):
        # DRY fix: inherit the op/prim metadata from the parent setUp and
        # only replace the data — the four metadata assignments were
        # duplicated verbatim in every subclass.
        super().setUp()
        real_part = np.random.uniform(-5, 5, size=(3, 5, 3))
        imag_part = np.random.uniform(-5, 5, size=(3, 5, 3))
        # X and Y hold the same values but are two distinct arrays,
        # matching the original construction (the expression is evaluated
        # once per key).
        self.inputs = {
            'X': real_part + 1j * imag_part,
            'Y': real_part + 1j * imag_part,
        }
        self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
370+
371+
252372
class TestElementwisePowOpFP16(OpTest):
253373
def setUp(self):
254374
self.op_type = "elementwise_pow"

0 commit comments

Comments
 (0)