12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import numbers
16
15
17
16
import paddle
18
17
from paddle import distribution
19
- from paddle .base import framework
18
+ from paddle .base .data_feeder import check_type , convert_dtype
19
+ from paddle .base .framework import Variable
20
20
from paddle .distribution import exponential_family
21
+ from paddle .framework import in_dynamic_mode
21
22
22
23
23
24
class Gamma (exponential_family .ExponentialFamily ):
@@ -33,11 +34,11 @@ class Gamma(exponential_family.ExponentialFamily):
33
34
\Gamma(\alpha)=\int_{0}^{\infty} x^{\alpha-1} e^{-x} \mathrm{~d} x, (\alpha>0)
34
35
35
36
Args:
36
- concentration (int| float|Tensor): Concentration parameter. It supports broadcast semantics.
37
+ concentration (float|Tensor): Concentration parameter. It supports broadcast semantics.
37
38
The value of concentration must be positive. When the parameter is a tensor,
38
39
it represents multiple independent distribution with
39
40
a batch_shape(refer to ``Distribution`` ).
40
- rate (int| float|Tensor): Rate parameter. It supports broadcast semantics.
41
+ rate (float|Tensor): Rate parameter. It supports broadcast semantics.
41
42
The value of rate must be positive. When the parameter is tensor,
42
43
it represent multiple independent distribution with
43
44
a batch_shape(refer to ``Distribution`` ).
@@ -71,29 +72,40 @@ class Gamma(exponential_family.ExponentialFamily):
71
72
"""
72
73
73
74
def __init__ (self , concentration , rate ):
74
- if not isinstance (
75
- concentration , (numbers .Real , paddle .Tensor , framework .Variable )
76
- ):
77
- raise TypeError (
78
- f"Expected type of concentration is scalar or tensor, but got { type (concentration )} "
75
+ if not in_dynamic_mode ():
76
+ check_type (
77
+ concentration ,
78
+ 'concentration' ,
79
+ (float , Variable ),
80
+ 'Gamma' ,
81
+ )
82
+ check_type (
83
+ rate ,
84
+ 'rate' ,
85
+ (float , Variable ),
86
+ 'Gamma' ,
79
87
)
80
88
81
- if not isinstance (
82
- rate , (numbers .Real , paddle .Tensor , framework .Variable )
83
- ):
84
- raise TypeError (
85
- f"Expected type of rate is scalar or tensor, but got { type (rate )} "
89
+ # Get/convert concentration/rate to tensor.
90
+ if self ._validate_args (concentration , rate ):
91
+ self .concentration = concentration
92
+ self .rate = rate
93
+ self .dtype = convert_dtype (concentration .dtype )
94
+ else :
95
+ [self .concentration , self .rate ] = self ._to_tensor (
96
+ concentration , rate
86
97
)
98
+ self .dtype = paddle .get_default_dtype ()
87
99
88
- if isinstance ( concentration , numbers . Real ):
89
- concentration = paddle . full ( shape = [], fill_value = concentration )
100
+ if not paddle . all ( self . concentration > 0 ):
101
+ raise ValueError ( "The arg of ` concentration` must be positive." )
90
102
91
- if isinstance ( rate , numbers . Real ):
92
- rate = paddle . full ( shape = [], fill_value = rate )
103
+ if not paddle . all ( self . rate > 0 ):
104
+ raise ValueError ( "The arg of ` rate` must be positive." )
93
105
94
- self .concentration , self .rate = paddle .broadcast_tensors (
95
- [concentration , rate ]
96
- )
106
+ # self.concentration, self.rate = paddle.broadcast_tensors(
107
+ # [concentration, rate]
108
+ # )
97
109
super ().__init__ (self .concentration .shape )
98
110
99
111
@property
0 commit comments