Skip to content

Commit 3bbe340

Browse files
committed
fix python
1 parent 17cb8cc commit 3bbe340

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

python/paddle/tensor/linalg.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2853,7 +2853,7 @@ def det(x: Tensor, name: str | None = None) -> Tensor:
28532853
return out
28542854

28552855

2856-
def slogdet(x: Tensor, name: str | None = None) -> Tensor:
2856+
def slogdet(x: Tensor, name: str | None = None) -> tuple[Tensor, Tensor]:
28572857
"""
28582858
28592859
Calculates the sign and natural logarithm of the absolute value of a square matrix's or batches square matrices' determinant.
@@ -2874,21 +2874,26 @@ def slogdet(x: Tensor, name: str | None = None) -> Tensor:
28742874
developers. Details: :ref:`api_guide_Name`. Default is None.
28752875
28762876
Returns:
2877-
y (Tensor), A tensor containing the sign of the determinant and the natural logarithm
2878-
of the absolute value of determinant, respectively. The output shape is :math:`(2, *)`,
2879-
where math:`*` is one or more batch dimensions of the input `x`.
2877+
tuple(Tensor, Tensor): A tuple containing two Tensors: (sign, logabsdet).
2878+
The first Tensor represents the signs of the determinants and the second Tensor
2879+
represents the natural logarithms of the absolute values of the determinants.
2880+
Each output Tensor has a shape of :math:`(*)`, where :math:`*` matches the
2881+
batch dimensions of the input `x`.
28802882
28812883
Examples:
28822884
.. code-block:: python
28832885
28842886
>>> import paddle
28852887
>>> paddle.seed(2023)
28862888
>>> x = paddle.randn([3, 3, 3])
2887-
>>> A = paddle.linalg.slogdet(x)
2888-
>>> print(A)
2889-
Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
2890-
[[-1. , 1. , 1. ],
2891-
[ 0.25681755, -0.25061053, -0.10809582]])
2889+
>>> sign_values, logabsdet_values = paddle.linalg.slogdet(x) # Updated example
2890+
>>> print("Sign:", sign_values)
2891+
>>> print("LogAbsDet:", logabsdet_values)
2892+
# Expected output would show two separate tensors
2893+
# Sign: Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
2894+
# [-1., 1., 1.])
2895+
# LogAbsDet: Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
2896+
# [0.25681755, -0.25061053, -0.10809582])
28922897
28932898
"""
28942899
if in_dynamic_or_pir_mode():
@@ -2912,14 +2917,25 @@ def slogdet(x: Tensor, name: str | None = None) -> Tensor:
29122917
f"but received {input_shape[-2]} by {input_shape[-1]} matrix.\n"
29132918
)
29142919
helper = LayerHelper('slogdeterminant', **locals())
2915-
out = helper.create_variable_for_type_inference(dtype=x.dtype)
2920+
sign_dtype = x.dtype
2921+
if x.dtype == paddle.complex64:
2922+
logabsdet_dtype = paddle.float32
2923+
elif x.dtype == paddle.complex128:
2924+
logabsdet_dtype = paddle.float64
2925+
else:
2926+
logabsdet_dtype = x.dtype
2927+
2928+
sign_out = helper.create_variable_for_type_inference(dtype=sign_dtype)
2929+
logabsdet_out = helper.create_variable_for_type_inference(
2930+
dtype=logabsdet_dtype
2931+
)
29162932

29172933
helper.append_op(
29182934
type='slogdeterminant',
29192935
inputs={'Input': [x]},
2920-
outputs={'Out': [out]},
2936+
outputs={'Sign': [sign_out], 'Logabsdet': [logabsdet_out]},
29212937
)
2922-
return out
2938+
return sign_out, logabsdet_out
29232939

29242940

29252941
def svd(

0 commit comments

Comments
 (0)