Skip to content

Commit 6c687c1

Browse files
fix type check
1 parent 7ac279d commit 6c687c1

File tree

2 files changed

+54
-32
lines changed

2 files changed

+54
-32
lines changed

python/paddle/distribution/exponential.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import numbers
1615

1716
import numpy as np
1817

1918
import paddle
2019
from paddle import distribution
21-
from paddle.base import framework
20+
from paddle.base.data_feeder import check_type, convert_dtype
21+
from paddle.base.framework import Variable
2222
from paddle.distribution import exponential_family
23+
from paddle.framework import in_dynamic_mode
2324

2425

2526
class Exponential(exponential_family.ExponentialFamily):
@@ -37,7 +38,7 @@ class Exponential(exponential_family.ExponentialFamily):
3738
* :math:`rate = \theta`: is the rate parameter.
3839
3940
Args:
40-
rate (int|float|Tensor): Rate parameter. The value of rate must be positive.
41+
rate (float|Tensor): Rate parameter. The value of rate must be positive.
4142
4243
Example:
4344
.. code-block:: python
@@ -56,16 +57,25 @@ class Exponential(exponential_family.ExponentialFamily):
5657
"""
5758

5859
def __init__(self, rate):
59-
if not isinstance(
60-
rate, (numbers.Real, paddle.Tensor, framework.Variable)
61-
):
62-
raise TypeError(
63-
f"Expected type of rate is scalar or tensor, but got {type(rate)}"
60+
if not in_dynamic_mode():
61+
check_type(
62+
rate,
63+
'rate',
64+
(float, Variable),
65+
'Exponential',
6466
)
6567

66-
if isinstance(rate, numbers.Real):
67-
rate = paddle.full(shape=(), fill_value=rate, dtype=paddle.float32)
68-
self.rate = rate
68+
# Get/convert rate to tensor.
69+
if self._validate_args(rate):
70+
self.rate = rate
71+
self.dtype = convert_dtype(rate.dtype)
72+
else:
73+
[self.rate] = self._to_tensor(rate)
74+
self.dtype = paddle.get_default_dtype()
75+
76+
if not paddle.all(self.rate > 0):
77+
raise ValueError("The arg of `rate` must be positive.")
78+
6979
super().__init__(self.rate.shape)
7080

7181
@property

python/paddle/distribution/gamma.py

+33-21
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import numbers
1615

1716
import paddle
1817
from paddle import distribution
19-
from paddle.base import framework
18+
from paddle.base.data_feeder import check_type, convert_dtype
19+
from paddle.base.framework import Variable
2020
from paddle.distribution import exponential_family
21+
from paddle.framework import in_dynamic_mode
2122

2223

2324
class Gamma(exponential_family.ExponentialFamily):
@@ -33,11 +34,11 @@ class Gamma(exponential_family.ExponentialFamily):
3334
\Gamma(\alpha)=\int_{0}^{\infty} x^{\alpha-1} e^{-x} \mathrm{~d} x, (\alpha>0)
3435
3536
Args:
36-
concentration (int|float|Tensor): Concentration parameter. It supports broadcast semantics.
37+
concentration (float|Tensor): Concentration parameter. It supports broadcast semantics.
3738
The value of concentration must be positive. When the parameter is a tensor,
3839
it represents multiple independent distribution with
3940
a batch_shape(refer to ``Distribution`` ).
40-
rate (int|float|Tensor): Rate parameter. It supports broadcast semantics.
41+
rate (float|Tensor): Rate parameter. It supports broadcast semantics.
4142
The value of rate must be positive. When the parameter is tensor,
4243
it represent multiple independent distribution with
4344
a batch_shape(refer to ``Distribution`` ).
@@ -71,29 +72,40 @@ class Gamma(exponential_family.ExponentialFamily):
7172
"""
7273

7374
def __init__(self, concentration, rate):
74-
if not isinstance(
75-
concentration, (numbers.Real, paddle.Tensor, framework.Variable)
76-
):
77-
raise TypeError(
78-
f"Expected type of concentration is scalar or tensor, but got {type(concentration)}"
75+
if not in_dynamic_mode():
76+
check_type(
77+
concentration,
78+
'concentration',
79+
(float, Variable),
80+
'Gamma',
81+
)
82+
check_type(
83+
rate,
84+
'rate',
85+
(float, Variable),
86+
'Gamma',
7987
)
8088

81-
if not isinstance(
82-
rate, (numbers.Real, paddle.Tensor, framework.Variable)
83-
):
84-
raise TypeError(
85-
f"Expected type of rate is scalar or tensor, but got {type(rate)}"
89+
# Get/convert concentration/rate to tensor.
90+
if self._validate_args(concentration, rate):
91+
self.concentration = concentration
92+
self.rate = rate
93+
self.dtype = convert_dtype(concentration.dtype)
94+
else:
95+
[self.concentration, self.rate] = self._to_tensor(
96+
concentration, rate
8697
)
98+
self.dtype = paddle.get_default_dtype()
8799

88-
if isinstance(concentration, numbers.Real):
89-
concentration = paddle.full(shape=[], fill_value=concentration)
100+
if not paddle.all(self.concentration > 0):
101+
raise ValueError("The arg of `concentration` must be positive.")
90102

91-
if isinstance(rate, numbers.Real):
92-
rate = paddle.full(shape=[], fill_value=rate)
103+
if not paddle.all(self.rate > 0):
104+
raise ValueError("The arg of `rate` must be positive.")
93105

94-
self.concentration, self.rate = paddle.broadcast_tensors(
95-
[concentration, rate]
96-
)
106+
# self.concentration, self.rate = paddle.broadcast_tensors(
107+
# [concentration, rate]
108+
# )
97109
super().__init__(self.concentration.shape)
98110

99111
@property

0 commit comments

Comments
 (0)