Skip to content

Commit db089f7

Browse files
committed
ci test
1 parent b769bf8 commit db089f7

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

test/legacy_test/test_elementwise_mul_op.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,7 @@ def init_input_output(self):
596596
self.x = np.random.random((2, 3, 4, 5)).astype(
597597
self.dtype
598598
) + 1j * np.random.random((2, 3, 4, 5)).astype(self.dtype)
599-
self.y = np.random.random((2, 3, 4, 5)).astype(
600-
self.dtype
601-
) + 1j * np.random.random((2, 3, 4, 5)).astype(self.dtype)
599+
self.y = np.random.random((2, 3, 4, 5)).astype(self.dtype)
602600
self.out = self.x * self.y
603601

604602
def test_check_output(self):

test/legacy_test/test_elementwise_pow_op.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def test_check_grad_normal(self):
283283
self.check_grad(['X', 'Y'], 'Out', check_dygraph=False)
284284
else:
285285
self.check_grad(
286-
['X'],
286+
['X', 'Y'],
287287
'Out',
288288
check_pir=True,
289289
)
@@ -369,6 +369,28 @@ def setUp(self):
369369
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
370370

371371

372+
@unittest.skipIf(
373+
core.is_compiled_with_xpu(),
374+
"Skip XPU for complex dtype is not fully supported",
375+
)
376+
class TestElementwisePowComplexOp5(TestElementwisePowComplexOp):
377+
def setUp(self):
378+
self.op_type = "elementwise_pow"
379+
self.python_api = paddle.pow
380+
self.public_python_api = paddle.pow
381+
self.prim_op_type = "prim"
382+
383+
x_real_part = np.random.uniform(-5, 5, size=(5, 3))
384+
x_imag_part = np.random.uniform(-5, 5, size=(5, 3))
385+
y_real_part = np.random.uniform(-5, 5, size=(3, 5, 3))
386+
y_imag_part = np.random.uniform(-5, 5, size=(3, 5, 3))
387+
self.inputs = {
388+
'X': x_real_part + 1j * x_imag_part,
389+
'Y': y_real_part + 1j * y_imag_part,
390+
}
391+
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
392+
393+
372394
class TestElementwisePowOpFP16(OpTest):
373395
def setUp(self):
374396
self.op_type = "elementwise_pow"

0 commit comments

Comments
 (0)