|
18 | 18 | from dygraph_to_static_utils import (
|
19 | 19 | Dy2StTestBase,
|
20 | 20 | test_ast_only,
|
| 21 | + test_pir_only, |
21 | 22 | )
|
22 | 23 |
|
23 | 24 | import paddle
|
@@ -58,6 +59,17 @@ def test_mix_cast(x):
|
58 | 59 | return x
|
59 | 60 |
|
60 | 61 |
|
| 62 | +def test_complex_cast(x): |
| 63 | + x = paddle.to_tensor(x) |
| 64 | + x = complex(x) |
| 65 | + return x |
| 66 | + |
| 67 | + |
| 68 | +def test_not_var_complex_cast(x): |
| 69 | + x = complex(x) |
| 70 | + return x |
| 71 | + |
| 72 | + |
61 | 73 | class TestCastBase(Dy2StTestBase):
|
62 | 74 | def setUp(self):
|
63 | 75 | self.place = (
|
@@ -190,5 +202,64 @@ def test_cast_result(self):
|
190 | 202 | )
|
191 | 203 |
|
192 | 204 |
|
| 205 | +@unittest.skipIf( |
| 206 | + paddle.core.is_compiled_with_xpu(), |
| 207 | + "xpu does not support complex cast temporarily", |
| 208 | +) |
| 209 | +class TestComplexCast(TestCastBase): |
| 210 | + def prepare(self): |
| 211 | + self.input_shape = (8, 16) |
| 212 | + self.input_dtype = 'float32' |
| 213 | + self.input = ( |
| 214 | + np.random.binomial(2, 0.5, size=np.prod(self.input_shape)) |
| 215 | + .reshape(self.input_shape) |
| 216 | + .astype(self.input_dtype) |
| 217 | + ) |
| 218 | + self.cast_dtype = 'complex64' |
| 219 | + |
| 220 | + def set_func(self): |
| 221 | + self.func = paddle.jit.to_static(full_graph=True)(test_complex_cast) |
| 222 | + |
| 223 | + @test_pir_only |
| 224 | + def test_cast_result(self): |
| 225 | + self.set_func() |
| 226 | + res = self.do_test().numpy() |
| 227 | + self.assertTrue( |
| 228 | + res.dtype == self.cast_dtype, |
| 229 | + msg=f'The target dtype is {self.cast_dtype}, but the casted dtype is {res.dtype}.', |
| 230 | + ) |
| 231 | + ref_val = self.input.astype(self.cast_dtype) |
| 232 | + np.testing.assert_allclose( |
| 233 | + res, |
| 234 | + ref_val, |
| 235 | + rtol=1e-05, |
| 236 | + err_msg=f'The casted value is {res}.\nThe correct value is {ref_val}.', |
| 237 | + ) |
| 238 | + |
| 239 | + |
| 240 | +class TestNotVarComplexCast(TestCastBase): |
| 241 | + def prepare(self): |
| 242 | + self.input = 3.14 |
| 243 | + self.cast_dtype = 'complex' |
| 244 | + |
| 245 | + def set_func(self): |
| 246 | + self.func = paddle.jit.to_static(full_graph=True)( |
| 247 | + test_not_var_complex_cast |
| 248 | + ) |
| 249 | + |
| 250 | + @test_ast_only |
| 251 | + def test_cast_result(self): |
| 252 | + self.set_func() |
| 253 | + res = self.do_test() |
| 254 | + self.assertTrue( |
| 255 | + type(res) == complex, msg='The casted dtype is not complex.' |
| 256 | + ) |
| 257 | + ref_val = complex(self.input) |
| 258 | + self.assertTrue( |
| 259 | + res == ref_val, |
| 260 | + msg=f'The casted value is {res}.\nThe correct value is {ref_val}.', |
| 261 | + ) |
| 262 | + |
| 263 | + |
193 | 264 | if __name__ == '__main__':
|
194 | 265 | unittest.main()
|
0 commit comments