Skip to content

Commit 2ed814b

Browse files
authored
[NPU] Fix truncated_gaussian_random (PaddlePaddle#1289)
1 parent bec4df1 commit 2ed814b

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

backends/npu/kernels/truncated_gaussian_random_kernel.cc

+14-6
Original file line numberDiff line numberDiff line change
@@ -131,22 +131,28 @@ T Erfinv(T x) {
131131
}
132132
}
133133

134+
template <typename T>
135+
T clamp(T val, T min, T max) {
136+
return val < min ? min : (val > max ? max : val);
137+
}
138+
134139
template <typename T>
135140
struct TruncatedNormal {
136-
T mean, std;
141+
T mean, std, a, b;
137142
T a_normal_cdf;
138143
T b_normal_cdf;
139-
TruncatedNormal(T mean, T std) : mean(mean), std(std) {
144+
TruncatedNormal(T mean, T std, T a, T b) : mean(mean), std(std), a(a), b(b) {
140145
auto normal_cdf = [](T x) {
141146
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
142147
};
143-
a_normal_cdf = normal_cdf(-2.0);
144-
b_normal_cdf = normal_cdf(2.0);
148+
a_normal_cdf = normal_cdf((a - mean) / std);
149+
b_normal_cdf = normal_cdf((b - mean) / std);
145150
}
146151

147152
T operator()(T value) const {
148153
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
149-
return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
154+
T ret = std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
155+
return clamp(ret, a, b);
150156
}
151157
};
152158

@@ -156,6 +162,8 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
156162
float mean,
157163
float std,
158164
int seed,
165+
float a,
166+
float b,
159167
phi::DataType dtype,
160168
phi::DenseTensor* out) {
161169
dev_ctx.template Alloc<T>(out);
@@ -167,7 +175,7 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
167175

168176
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
169177
1.0);
170-
TruncatedNormal<T> truncated_normal(mean, std);
178+
TruncatedNormal<T> truncated_normal(mean, std, a, b);
171179
int64_t size = out->numel();
172180

173181
std::shared_ptr<std::mt19937_64> engine;

backends/npu/tests/unittests/test_truncated_gaussian_random_op_npu.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _test(self, run_npu=True):
4040
weight_attr = paddle.framework.ParamAttr(
4141
name="linear_weight",
4242
initializer=paddle.nn.initializer.TruncatedNormal(
43-
mean=0.0, std=2.0
43+
mean=0.0, std=2.0, a=-2.0, b=2.0
4444
),
4545
)
4646
linear = paddle.nn.Linear(

0 commit comments

Comments
 (0)