Commit 05170c6
[Paddle Tensor No.10] Add Tensor.__complex__ (#69257)

* add tensor complex
* update jit
* update convert_operators
* update test
* update test
* add test
1 parent 44c1872 commit 05170c6
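In short, built-in complex() now works on single-element dygraph Tensors, while static graph modes deliberately raise TypeError. A minimal sketch of the new dygraph behavior (illustrative session, assuming a build that includes this commit):

    >>> import paddle
    >>> x = paddle.to_tensor([3.0 + 4.0j])
    >>> complex(x)   # dispatches to the new Tensor.__complex__
    (3+4j)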

File tree: 12 files changed, +151 -1 lines

python/paddle/base/dygraph/math_op_patch.py (+11)

@@ -110,6 +110,16 @@ def _neg_(var: Tensor) -> Tensor:
 def _abs_(var: Tensor) -> Tensor:
     return var.abs()
 
+def _complex_(var: Tensor) -> complex:
+    numel = np.prod(var.shape)
+    assert (
+        numel == 1
+    ), "only one element variable can be converted to complex."
+    assert var._is_initialized(), "variable's tensor is not initialized"
+    if not var.is_complex():
+        var = var.astype('complex64')
+    return complex(np.array(var))
+
 def _float_(var: Tensor) -> float:
     numel = np.prod(var.shape)
     assert (
@@ -203,6 +213,7 @@ def _mT_(var: Tensor) -> Tensor:
 eager_methods = [
     ('__neg__', _neg_),
     ('__abs__', _abs_),
+    ('__complex__', _complex_),
     ('__float__', _float_),
     ('__long__', _long_),
     ('__int__', _int_),
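Two consequences of _complex_ above, sketched with illustrative values: only single-element Tensors convert, and real dtypes are promoted to complex64 before conversion.

    >>> complex(paddle.to_tensor(2.5))           # float32 -> complex64 -> Python complex
    (2.5+0j)
    >>> complex(paddle.to_tensor([1.0, 2.0]))    # numel != 1
    AssertionError: only one element variable can be converted to complex.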

python/paddle/base/layers/math_op_patch.py (+8)

@@ -749,6 +749,13 @@ def _float_(self):
         "2. If you want to run it in full graph mode, you need use Variable directly, and do not use float(Variable)."
     )
 
+def _complex_(self):
+    raise TypeError(
+        "complex(Variable) is not supported in static graph mode. If you are using @to_static, you can try this:\n"
+        "1. If you want to get the value of Variable, you can switch to non-fullgraph mode by setting @to_static(full_graph=True).\n"
+        "2. If you want to run it in full graph mode, you need use Variable directly, and do not use complex(Variable)."
+    )
+
 def values(var):
     block = current_block(var)
     out = create_new_tmp_var(block, var.dtype)
@@ -889,6 +896,7 @@ def to_dense(var):
     ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
     ('__float__', _float_),
     ('__int__', _int_),
+    ('__complex__', _complex_),
     ('values', values),
     ('indices', indices),
     ('to_dense', to_dense),
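Under the legacy static graph IR, complex() on a Variable therefore fails fast with guidance. A rough sketch of the failure mode (when the legacy IR is active; with PIR, the analogous _complex_ in python/paddle/pir/math_op_patch.py below raises instead):

    >>> paddle.enable_static()
    >>> a = paddle.static.data(name='a', shape=[1], dtype='float32')
    >>> complex(a)
    TypeError: complex(Variable) is not supported in static graph mode. ...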

python/paddle/jit/dy2static/convert_operators.py (+3)

@@ -759,18 +759,21 @@ def convert_var_dtype(var, dtype):
             'bool',
             'int',
             'float',
+            'complex',
         ], f"The casted target dtype is {dtype}, which is not supported in type casting."
         cast_map = {
             'bool': 'bool',
             'int': 'int32',
             'float': 'float32',
+            'complex': 'complex64',
         }
         return paddle.cast(var, dtype=cast_map[dtype])
     else:
         assert dtype in [
             'bool',
             'int',
             'float',
+            'complex',
        ], f"The casted target dtype is {dtype}, which is not supported in type casting."
         return eval(dtype)(var)
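Combined with the CastTransformer change in the next file, this lets complex(x) inside a @to_static(full_graph=True) function lower to an in-graph cast when x is a Tensor, and fall back to the Python builtin otherwise. A rough sketch of the Tensor path (mirroring TestComplexCast further down):

    >>> @paddle.jit.to_static(full_graph=True)
    ... def f(x):
    ...     return complex(x)   # dy2static rewrites this to convert_var_dtype(x, 'complex')
    >>> f(paddle.ones([2, 2])).dtype
    paddle.complex64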

python/paddle/jit/dy2static/transformers/cast_transformer.py (+1 -1)

@@ -27,7 +27,7 @@ class CastTransformer(BaseTransformer):
 
     def __init__(self, root):
         self.root = root
-        self._castable_type = {'bool', 'int', 'float'}
+        self._castable_type = {'bool', 'int', 'float', 'complex'}
 
     def transform(self):
         self.visit(self.root)

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py (+6)

@@ -825,6 +825,12 @@ def __float__(self) -> float:
     def float(self):
         return ConstantVariable(float(self), self.graph, DummyTracker([self]))
 
+    def __complex__(self) -> complex:
+        return complex(self.get_py_value())
+
+    def complex(self):
+        return ConstantVariable(complex(self), self.graph, DummyTracker([self]))
+
     @property
     def out_var_name(self):
         return f"{self.graph.OUT_VAR_PREFIX}{self.var_name}"

python/paddle/jit/sot/utils/magic_methods.py (+1)

@@ -85,6 +85,7 @@
     float: "__float__",
     len: "__len__",
     int: "__int__",
+    complex: "__complex__",
 }
 # TODO(SigureMo): support any, all, sum

python/paddle/pir/math_op_patch.py (+25)

@@ -595,6 +595,30 @@ def _bool_(self):
         """
         raise TypeError(textwrap.dedent(error_msg))
 
+    def _complex_(self):
+        error_msg = """\
+            complex(Tensor) is not supported in static graph mode. Because it's value is not available during the static mode.
+            It's usually triggered by the logging implicitly, for example:
+                >>> logging.info("The value of x is: {complex(x)}")
+                                                      ^ `x` is Tensor, `complex(x)` triggers complex(Tensor)
+
+            There are two common workarounds available:
+            If you are logging Tensor values, then consider logging only at dynamic graphs, for example:
+
+                Modify the following code
+                >>> logging.info("The value of x is: {complex(x)}")
+                to
+                >>> if paddle.in_dynamic_mode():
+                ...     logging.info("The value of x is: {complex(x)}")
+
+            If you need to convert the Tensor type, for example:
+                Modify the following code
+                >>> x = complex(x)
+                to
+                >>> x = x.astype("complex64")
+            """
+        raise TypeError(textwrap.dedent(error_msg))
+
     def clone(self):
         """
         Returns a new static Value, which is the clone of the original static
@@ -1143,6 +1167,7 @@ def register_hook(self, hook):
         ('__float__', _float_),
         ('__int__', _int_),
         ('__bool__', _bool_),
+        ('__complex__', _complex_),
     ]
 
     global _already_patch_value
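Of the two workarounds in the message, astype keeps the conversion in-graph. A quick sketch in PIR static mode (assuming paddle.enable_static()):

    >>> paddle.enable_static()
    >>> x = paddle.static.data(name='x', shape=[2], dtype='float32')
    >>> y = x.astype("complex64")   # stays a static Value; complex(x) here would raise the TypeError above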

python/paddle/tensor/tensor.prototype.pyi (+1)

@@ -187,6 +187,7 @@ class AbstractTensor:
     def __int__(self) -> int: ...
     def __long__(self) -> float: ...
     def __nonzero__(self) -> bool: ...
+    def __complex__(self) -> complex: ...
 
     # emulating container types
     def __getitem__(
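With the stub extended, static type checkers should accept complex() on a Tensor as returning a Python complex. A minimal sketch (hypothetical snippet, not part of the commit):

    >>> x: paddle.Tensor = paddle.to_tensor([1.0 + 2.0j])
    >>> c: complex = complex(x)   # consistent with tensor.prototype.pyi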

test/dygraph_to_static/test_cast.py (+71)

@@ -18,6 +18,7 @@
 from dygraph_to_static_utils import (
     Dy2StTestBase,
     test_ast_only,
+    test_pir_only,
 )
 
 import paddle
@@ -58,6 +59,17 @@ def test_mix_cast(x):
     return x
 
 
+def test_complex_cast(x):
+    x = paddle.to_tensor(x)
+    x = complex(x)
+    return x
+
+
+def test_not_var_complex_cast(x):
+    x = complex(x)
+    return x
+
+
 class TestCastBase(Dy2StTestBase):
     def setUp(self):
         self.place = (
@@ -190,5 +202,64 @@ def test_cast_result(self):
         )
 
 
+@unittest.skipIf(
+    paddle.core.is_compiled_with_xpu(),
+    "xpu does not support complex cast temporarily",
+)
+class TestComplexCast(TestCastBase):
+    def prepare(self):
+        self.input_shape = (8, 16)
+        self.input_dtype = 'float32'
+        self.input = (
+            np.random.binomial(2, 0.5, size=np.prod(self.input_shape))
+            .reshape(self.input_shape)
+            .astype(self.input_dtype)
+        )
+        self.cast_dtype = 'complex64'
+
+    def set_func(self):
+        self.func = paddle.jit.to_static(full_graph=True)(test_complex_cast)
+
+    @test_pir_only
+    def test_cast_result(self):
+        self.set_func()
+        res = self.do_test().numpy()
+        self.assertTrue(
+            res.dtype == self.cast_dtype,
+            msg=f'The target dtype is {self.cast_dtype}, but the casted dtype is {res.dtype}.',
+        )
+        ref_val = self.input.astype(self.cast_dtype)
+        np.testing.assert_allclose(
+            res,
+            ref_val,
+            rtol=1e-05,
+            err_msg=f'The casted value is {res}.\nThe correct value is {ref_val}.',
+        )
+
+
+class TestNotVarComplexCast(TestCastBase):
+    def prepare(self):
+        self.input = 3.14
+        self.cast_dtype = 'complex'
+
+    def set_func(self):
+        self.func = paddle.jit.to_static(full_graph=True)(
+            test_not_var_complex_cast
+        )
+
+    @test_ast_only
+    def test_cast_result(self):
+        self.set_func()
+        res = self.do_test()
+        self.assertTrue(
+            type(res) == complex, msg='The casted dtype is not complex.'
+        )
+        ref_val = complex(self.input)
+        self.assertTrue(
+            res == ref_val,
+            msg=f'The casted value is {res}.\nThe correct value is {ref_val}.',
+        )
+
+
 if __name__ == '__main__':
     unittest.main()

test/legacy_test/test_math_op_patch.py (+2)

@@ -405,6 +405,8 @@ def test_builtin_type_conversion(self):
             int(a)
         with self.assertRaises(TypeError):
             float(a)
+        with self.assertRaises(TypeError):
+            complex(a)
 
 
 class TestDygraphMathOpPatches(unittest.TestCase):

test/legacy_test/test_math_op_patch_pir.py (+14)

@@ -734,6 +734,20 @@ def test_builtin_type_conversion(self):
                 float(x)
             with self.assertRaises(TypeError):
                 bool(x)
+            with self.assertRaises(TypeError):
+                complex(x)
+
+    def test_builtin_type_conversion_old_ir(self):
+        with paddle.pir_utils.DygraphOldIrGuard():
+            _, _, program_guard = new_program()
+            with program_guard:
+                x = paddle.static.data(name='x', shape=[], dtype="float32")
+                with self.assertRaises(TypeError):
+                    int(x)
+                with self.assertRaises(TypeError):
+                    float(x)
+                with self.assertRaises(TypeError):
+                    complex(x)
 
     def test_math_exists(self):
         with paddle.pir_utils.IrGuard():

test/legacy_test/test_math_op_patch_var_base.py (+8)

@@ -522,6 +522,14 @@ def test_float_int_long(self):
         self.assertTrue(int(a) == 999424)
         self.assertTrue(int(a) == 999424)
 
+    def test_complex(self):
+        with base.dygraph.guard():
+            a = paddle.to_tensor(np.array([100.1 + 99.9j]))
+            self.assertTrue(complex(a) == (100.1 + 99.9j))
+
+            a = paddle.to_tensor(1000000.0, dtype='bfloat16')
+            self.assertTrue(complex(a) == (999424 + 0j))
+
     def test_len(self):
         a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
         with base.dygraph.guard():

(The 999424 expectation is a bfloat16 rounding artifact: with only 8 mantissa bits, 1000000.0 is not representable and rounds to 999424 before conversion, matching the int() assertions in the surrounding context.)
