
paddle.distribution.StudentT improve input data type and fix returned dimension [usability improvement] #68895

Merged · 1 commit · Oct 30, 2024
41 changes: 41 additions & 0 deletions python/paddle/distribution/distribution.py
@@ -306,3 +306,44 @@ def _logits_to_probs(
            if is_binary
            else paddle.nn.functional.softmax(logits, axis=-1)
        )

    def _broadcast_all(
        self, *args: TensorLike | NestedNumbericSequence
    ) -> tuple[Tensor, ...]:
        r"""
        Broadcast all of the input args against each other. Scalar args are
        upcast to Tensors with the same data type as the first Tensor found in
        `args`; if all args are scalars, they are upcast to Tensors with
        Paddle's default data type.

        Args:
            *args (float|list|tuple|numpy.ndarray|Tensor): The inputs to broadcast.

        Returns:
            The broadcasted Tensors of args.
        """
        for arg in args:
            if not isinstance(
                arg,
                (float, list, tuple, np.ndarray, Variable, paddle.pir.Value),
            ):
                raise TypeError(
                    f"Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {type(arg)}"
                )
        if not all(
            isinstance(arg, (Variable, paddle.pir.Value)) for arg in args
        ):
            dtype = paddle.get_default_dtype()
            for arg in args:
                if isinstance(arg, (Variable, paddle.pir.Value)):
                    dtype = arg.dtype
                    break
            new_args = [
                (
                    arg
                    if isinstance(arg, (Variable, paddle.pir.Value))
                    else paddle.to_tensor(arg, dtype=dtype)
                )
                for arg in args
            ]
            return paddle.broadcast_tensors(new_args)
        return paddle.broadcast_tensors(args)
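
For readers skimming the diff, here is a small standalone sketch of the promotion-and-broadcast rule that `_broadcast_all` implements. This is illustrative code written for this note, not part of the PR; it assumes dynamic mode (so Tensors are `paddle.Tensor`) and that `paddle.broadcast_tensors` accepts 0-D inputs, which is the behavior the rest of this PR relies on.

import numpy as np
import paddle

# Same rule as _broadcast_all: take the dtype of the first Tensor argument
# (falling back to the default dtype when no Tensor is passed), convert the
# non-Tensor arguments, then broadcast everything to a common shape.
args = (10.0, paddle.to_tensor(np.zeros((2, 1))), [1.0, 2.0, 3.0])

dtype = paddle.get_default_dtype()
for a in args:
    if isinstance(a, paddle.Tensor):
        dtype = a.dtype
        break

tensors = [
    a if isinstance(a, paddle.Tensor) else paddle.to_tensor(a, dtype=dtype)
    for a in args
]
df, loc, scale = paddle.broadcast_tensors(tensors)
print(df.shape, loc.shape, scale.shape)  # expected: [2, 3] for all three
print(df.dtype)                          # expected: paddle.float64, taken from the second arg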
21 changes: 3 additions & 18 deletions python/paddle/distribution/student_t.py
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING
 
 import paddle
-from paddle.base.data_feeder import check_type, convert_dtype
+from paddle.base.data_feeder import check_type
 from paddle.base.framework import Variable
 from paddle.distribution import Gamma, distribution
 from paddle.framework import in_dynamic_mode
@@ -135,18 +135,7 @@ def __init__(
         )
 
         self.name = name if name is not None else 'StudentT'
-        self.dtype = paddle.get_default_dtype()
-
-        if self._validate_args(df, loc, scale):
-            self.df = df
-            self.loc = loc
-            self.scale = scale
-            self.df, self.loc, self.scale = paddle.broadcast_tensors(
-                [self.df, self.loc, self.scale]
-            )
-            self.dtype = convert_dtype(df.dtype)
-        else:
-            self.df, self.loc, self.scale = self._to_tensor(df, loc, scale)
+        self.df, self.loc, self.scale = self._broadcast_all(df, loc, scale)
 
         if not self._check_nonnegative(self.df):
             raise ValueError(
@@ -157,10 +146,6 @@
                 'Every element of input parameter `scale` should be nonnegative.'
             )
 
-        if self.df.shape == []:
-            self.df = self.df.reshape([1])
Review comment from zhwesky2010 (Contributor), Oct 30, 2024:

@NKNaN Please also audit the other distribution APIs for this kind of 0-D vs. 1-D issue, and check whether there are more places that unreasonably turn a 0-D tensor into a 1-D one. This is a typical example of such an unreasonable operation:

    if x.shape == []:
        x = x.reshape([1])

Check whether the other distributions have the same problem. 0-D Tensors are supported in Paddle and being rolled out across the framework; if there is incorrect usage like this, fixing it as an incompatible upgrade is acceptable.

Reply from NKNaN (Contributor, Author): OK.

-            self.loc = self.loc.reshape([1])
-            self.scale = self.scale.reshape([1])
         batch_shape = self.df.shape
         super().__init__(batch_shape)
         self._chi2 = Gamma(0.5 * self.df, paddle.full_like(self.df, 0.5))
@@ -222,7 +207,7 @@ def sample(self, shape: Sequence[int] = ()) -> Tensor:
             raise TypeError('sample shape must be Sequence object.')
 
         output_shape = self._extend_shape(shape)
-        z = paddle.cast(paddle.normal(shape=output_shape), self.dtype)
+        z = paddle.normal(shape=output_shape)
         chi2 = self._chi2.sample(shape)
         x = z * paddle.rsqrt(chi2 / self.df)
         return self.loc + self.scale * x
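
Taken together with the `_broadcast_all` helper, the student_t.py changes let `df`, `loc` and `scale` be passed as plain floats, lists, numpy arrays or Tensors, and they stop padding 0-D parameters to shape [1]. Below is a hedged sketch of the expected behavior after this PR; the shapes in the comments follow from the broadcasting rules above and are not output captured from CI.

import paddle
from paddle.distribution import StudentT

# All-scalar construction: parameters stay 0-D, so the batch shape is empty
# instead of being padded to [1] as before this PR.
d0 = StudentT(df=10.0, loc=0.0, scale=1.0)
print(d0.batch_shape)        # expected: empty (0-D) batch shape
print(d0.sample().shape)     # expected: [] for the default empty sample shape

# Mixed float / Tensor / list construction: everything is broadcast together.
d1 = StudentT(
    df=10.0,
    loc=paddle.to_tensor([[0.0], [1.0]]),   # shape [2, 1]
    scale=[1.0, 2.0, 3.0],                  # broadcasts against loc to [2, 3]
)
print(d1.batch_shape)        # expected: (2, 3)
print(d1.sample([4]).shape)  # expected: [4, 2, 3]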
66 changes: 66 additions & 0 deletions test/distribution/test_distribution_student_t.py
@@ -187,6 +187,47 @@ def _np_entropy(self):
        return scipy.stats.t.entropy(df, loc, scale)


@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
    (parameterize.TEST_CASE_NAME, 'df', 'loc', 'scale'),
    [
        (
            'float-tensor',
            10.0,
            paddle.to_tensor(1.0),
            2.0,
        ),
        (
            'float-tensor1',
            10.0,
            parameterize.xrand((2, 3), dtype='float32', min=1, max=10),
            2.0,
        ),
        (
            'float-tensor2',
            parameterize.xrand((2, 1), dtype='float64', min=4, max=30),
            parameterize.xrand((2, 3), dtype='float64', min=1, max=10),
            2.0,
        ),
        (
            'float-tensor3',
            parameterize.xrand((2, 1), dtype='float64', min=4, max=30),
            1.0,
            parameterize.xrand((2, 1), dtype='float64', min=0.1, max=3),
        ),
        (
            'float-tensor4',
            5.0,
            parameterize.xrand((2, 1), dtype='float32', min=-1, max=-10),
            parameterize.xrand((2, 3), dtype='float32', min=0.1, max=3),
        ),
    ],
)
class TestStudentT2(TestStudentT):
    def setUp(self):
        self._dist = StudentT(self.df, self.loc, self.scale)


@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
    (parameterize.TEST_CASE_NAME, 'df', 'loc', 'scale', 'value'),
@@ -247,6 +288,31 @@ def test_log_prob(self):
        )


@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
    (parameterize.TEST_CASE_NAME, 'df', 'loc', 'scale', 'value'),
    [
        (
            'float-tensor1',
            10.0,
            parameterize.xrand((2, 1), dtype='float32', min=-10, max=10),
            1.0,
            np.array(3.3).astype("float32"),
        ),
        (
            'float-tensor2',
            parameterize.xrand((2, 1), dtype='float64', min=4, max=30),
            1.0,
            parameterize.xrand((2, 1), dtype='float64', min=0.1, max=5),
            parameterize.xrand((2, 4), dtype='float64', min=-10, max=10),
        ),
    ],
)
class TestStudentTProbs2(TestStudentTProbs):
    def setUp(self):
        self._dist = StudentT(self.df, self.loc, self.scale)


@parameterize.place(config.DEVICES)
@parameterize_cls([TEST_CASE_NAME], ['StudentTTestError'])
class StudentTTestError(unittest.TestCase):
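
As context for the new parameterizations above, here is a hypothetical standalone version of the kind of check TestStudentTProbs2 inherits from TestStudentTProbs; the actual assertions and tolerances live in the parent class and are not shown in this diff, and scipy.stats.t is the same reference the file already uses for entropy.

import numpy as np
import paddle
import scipy.stats
from paddle.distribution import StudentT

# Parameters with shape (2, 1) broadcast against values of shape (2, 4),
# mirroring the 'float-tensor2' case above.
df = paddle.to_tensor(np.random.uniform(4, 30, size=(2, 1)))      # float64
loc = 1.0
scale = paddle.to_tensor(np.random.uniform(0.1, 5, size=(2, 1)))  # float64
value = np.random.uniform(-10, 10, size=(2, 4))

dist = StudentT(df, loc, scale)
expected = scipy.stats.t.logpdf(value, df.numpy(), loc=loc, scale=scale.numpy())
np.testing.assert_allclose(
    dist.log_prob(paddle.to_tensor(value)).numpy(), expected, rtol=1e-6
)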