@@ -131,22 +131,28 @@ T Erfinv(T x) {
131
131
}
132
132
}
133
133
134
// Restricts `val` to the interval [min, max].
//
// Returns `min` when `val < min`, `max` when `val > max`, and `val` itself
// otherwise. The lower bound is tested first, so if a caller passes
// `min > max` any value below `min` still yields `min`. A value for which
// both comparisons are false (for example a floating-point NaN) is returned
// unchanged. The bounds are not validated here.
template <typename T>
T clamp(T val, T min, T max) {
  if (val < min) {
    return min;
  }
  return val > max ? max : val;
}
138
+
134
139
template <typename T>
135
140
struct TruncatedNormal {
136
- T mean, std;
141
+ T mean, std, a, b ;
137
142
T a_normal_cdf;
138
143
T b_normal_cdf;
139
- TruncatedNormal (T mean, T std) : mean(mean), std(std) {
144
+ TruncatedNormal (T mean, T std, T a, T b ) : mean(mean), std(std), a(a), b(b ) {
140
145
auto normal_cdf = [](T x) {
141
146
return (1.0 + std::erf (x / std::sqrt (2.0 ))) / 2.0 ;
142
147
};
143
- a_normal_cdf = normal_cdf (- 2.0 );
144
- b_normal_cdf = normal_cdf (2.0 );
148
+ a_normal_cdf = normal_cdf ((a - mean) / std );
149
+ b_normal_cdf = normal_cdf ((b - mean) / std );
145
150
}
146
151
147
152
T operator ()(T value) const {
148
153
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
149
- return std::sqrt (2.0 ) * Erfinv (2 * p - 1 ) * std + mean;
154
+ T ret = std::sqrt (2.0 ) * Erfinv (2 * p - 1 ) * std + mean;
155
+ return clamp (ret, a, b);
150
156
}
151
157
};
152
158
@@ -156,6 +162,8 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
156
162
float mean,
157
163
float std,
158
164
int seed,
165
+ float a,
166
+ float b,
159
167
phi::DataType dtype,
160
168
phi::DenseTensor* out) {
161
169
dev_ctx.template Alloc <T>(out);
@@ -167,7 +175,7 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
167
175
168
176
std::uniform_real_distribution<T> dist (std::numeric_limits<float >::min (),
169
177
1.0 );
170
- TruncatedNormal<T> truncated_normal (mean, std);
178
+ TruncatedNormal<T> truncated_normal (mean, std, a, b );
171
179
int64_t size = out->numel ();
172
180
173
181
std::shared_ptr<std::mt19937_64> engine;
0 commit comments