@@ -35,9 +35,10 @@ def piecewise_rational_quadratic_transform(
35
35
inverse = False ,
36
36
tails = None ,
37
37
tail_bound = 1.0 ,
38
- min_bin_width = DEFAULT_MIN_BIN_WIDTH ,
39
- min_bin_height = DEFAULT_MIN_BIN_HEIGHT ,
40
- min_derivative = DEFAULT_MIN_DERIVATIVE , ):
38
+ # for dygraph-to-static
39
+ min_bin_width = 1e-3 ,
40
+ min_bin_height = 1e-3 ,
41
+ min_derivative = 1e-3 , ):
41
42
if tails is None :
42
43
spline_fn = rational_quadratic_spline
43
44
spline_kwargs = {}
@@ -74,23 +75,27 @@ def unconstrained_rational_quadratic_spline(
74
75
inverse = False ,
75
76
tails = "linear" ,
76
77
tail_bound = 1.0 ,
77
- min_bin_width = DEFAULT_MIN_BIN_WIDTH ,
78
- min_bin_height = DEFAULT_MIN_BIN_HEIGHT ,
79
- min_derivative = DEFAULT_MIN_DERIVATIVE , ):
78
+ # for dygraph-to-static
79
+ min_bin_width = 1e-3 ,
80
+ min_bin_height = 1e-3 ,
81
+ min_derivative = 1e-3 , ):
80
82
inside_interval_mask = (inputs >= - tail_bound ) & (inputs <= tail_bound )
81
83
outside_interval_mask = ~ inside_interval_mask
82
-
83
- outputs = paddle .zeros (paddle .shape (inputs ))
84
- logabsdet = paddle .zeros (paddle .shape (inputs ))
84
+ # for dygraph to static
85
+ # 这里用 paddle.shape(x) 然后调用 zeros 会得到一个全 -1 shape 的 var
86
+ # 如果用 x.shape 的话可以保留确定的维度
87
+ outputs = paddle .zeros (inputs .shape )
88
+ logabsdet = paddle .zeros (inputs .shape )
85
89
if tails == "linear" :
86
90
unnormalized_derivatives = F .pad (
87
91
unnormalized_derivatives ,
88
92
pad = [0 ] * (len (unnormalized_derivatives .shape ) - 1 ) * 2 + [1 , 1 ])
89
93
constant = np .log (np .exp (1 - min_derivative ) - 1 )
90
94
unnormalized_derivatives [..., 0 ] = constant
91
95
unnormalized_derivatives [..., - 1 ] = constant
92
-
93
- outputs [outside_interval_mask ] = inputs [outside_interval_mask ]
96
+ # for dygraph to static
97
+ tmp = inputs [outside_interval_mask ]
98
+ outputs [outside_interval_mask ] = tmp
94
99
logabsdet [outside_interval_mask ] = 0
95
100
else :
96
101
raise RuntimeError ("{} tails are not implemented." .format (tails ))
@@ -130,18 +135,20 @@ def rational_quadratic_spline(
130
135
right = 1.0 ,
131
136
bottom = 0.0 ,
132
137
top = 1.0 ,
133
- min_bin_width = DEFAULT_MIN_BIN_WIDTH ,
134
- min_bin_height = DEFAULT_MIN_BIN_HEIGHT ,
135
- min_derivative = DEFAULT_MIN_DERIVATIVE , ):
136
- if paddle .min (inputs ) < left or paddle .max (inputs ) > right :
137
- raise ValueError ("Input to a transform is not within its domain" )
138
+ # for dygraph-to-static
139
+ min_bin_width = 1e-3 ,
140
+ min_bin_height = 1e-3 ,
141
+ min_derivative = 1e-3 , ):
142
+ # for dygraph to static
143
+ # if paddle.min(inputs) < left or paddle.max(inputs) > right:
144
+ # raise ValueError("Input to a transform is not within its domain")
138
145
139
146
num_bins = unnormalized_widths .shape [- 1 ]
140
-
141
- if min_bin_width * num_bins > 1.0 :
142
- raise ValueError ("Minimal bin width too large for the number of bins" )
143
- if min_bin_height * num_bins > 1.0 :
144
- raise ValueError ("Minimal bin height too large for the number of bins" )
147
+ # for dygraph to static
148
+ # if min_bin_width * num_bins > 1.0:
149
+ # raise ValueError("Minimal bin width too large for the number of bins")
150
+ # if min_bin_height * num_bins > 1.0:
151
+ # raise ValueError("Minimal bin height too large for the number of bins")
145
152
146
153
widths = F .softmax (unnormalized_widths , axis = - 1 )
147
154
widths = min_bin_width + (1 - min_bin_width * num_bins ) * widths
0 commit comments