Skip to content

Commit d2777ec

Browse files
authored
[Accuracy diff No.38-39] Fix accuracy diff for nextafter API (#72965)
1 parent 6cc3784 commit d2777ec

File tree

9 files changed

+127 additions, -167 deletions

paddle/phi/kernels/cpu/elementwise_kernel.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,23 @@ void CopySignKernel(const Context& dev_ctx,
106106
}
107107
}
108108

109+
// CPU nextafter kernel.
// Allocates `out` with element type T through `dev_ctx`, then forwards `x`,
// `y` and `out` to funcs::ElementwiseCompute with a functor chosen from the
// relative ranks of the two inputs.
template <typename T, typename Context>
110+
void NextafterKernel(const Context& dev_ctx,
111+
const DenseTensor& x,
112+
const DenseTensor& y,
113+
DenseTensor* out) {
114+
dev_ctx.template Alloc<T>(out);
115+
auto x_dims = x.dims();
116+
auto y_dims = y.dims();
117+
// x has at least as many dimensions as y: use the functor that evaluates
// nextafter(x, y) directly.
if (x_dims.size() >= y_dims.size()) {
118+
funcs::ElementwiseCompute<funcs::NextafterFunctor<T>, T>(
119+
dev_ctx, x, y, funcs ::NextafterFunctor<T>(), out);
120+
} else {
121+
// y has more dimensions than x: use the functor whose body swaps its two
// arguments. NOTE(review): presumably ElementwiseCompute hands the operands
// to the functor in swapped order in this case -- confirm against
// ElementwiseCompute.
funcs::ElementwiseCompute<funcs::InverseNextafterFunctor<T>, T>(
122+
dev_ctx, x, y, funcs::InverseNextafterFunctor<T>(), out);
123+
}
124+
}
125+
109126
} // namespace phi
110127

111128
using complex64 = ::phi::dtype::complex<float>;
@@ -193,3 +210,5 @@ PD_REGISTER_KERNEL(copysign,
193210
double,
194211
phi::dtype::float16,
195212
phi::dtype::bfloat16) {}
213+
// Registers phi::NextafterKernel under the name `nextafter` for the CPU
// backend, ALL_LAYOUT, with float and double as the listed element types.
PD_REGISTER_KERNEL(
214+
nextafter, CPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {}

paddle/phi/kernels/cpu/nextafter_kernel.cc

Lines changed: 0 additions & 22 deletions
This file was deleted.

paddle/phi/kernels/elementwise_kernel.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ void CopySignKernel(const Context& dev_ctx,
7373
const DenseTensor& y,
7474
DenseTensor* out);
7575

76+
// Declaration of the nextafter kernel.
//   dev_ctx: device context of type Context.
//   x, y:    input tensors.
//   out:     output tensor, written by the kernel.
template <typename T, typename Context>
77+
void NextafterKernel(const Context& dev_ctx,
78+
const DenseTensor& x,
79+
const DenseTensor& y,
80+
DenseTensor* out);
81+
7682
template <typename T, typename Context>
7783
DenseTensor Maximum(const Context& dev_ctx,
7884
const DenseTensor& x,

paddle/phi/kernels/funcs/elementwise_functor.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,5 +1174,56 @@ struct InverseCopySignFunctor {
11741174
}
11751175
};
11761176

1177+
// Primary template of the nextafter functor.
// Converts both operands to float, calls std::nextafter(x, y) on the float
// values, and casts the result back to T.
template <typename T, typename Enable = void>
1178+
struct NextafterFunctor {
1179+
inline HOSTDEVICE T operator()(const T x, const T y) const {
1180+
return static_cast<T>(
1181+
std::nextafter(static_cast<float>(x), static_cast<float>(y)));
1182+
}
1183+
};
1184+
1185+
// Specialization enabled when T is double: calls std::nextafter(x, y) on the
// operands directly, without any intermediate conversion.
template <typename T>
1186+
struct NextafterFunctor<
1187+
T,
1188+
typename std::enable_if_t<std::is_same<T, double>::value>> {
1189+
inline HOSTDEVICE T operator()(const T x, const T y) const {
1190+
return std::nextafter(x, y);
1191+
}
1192+
};
1193+
1194+
// Specialization enabled when T is an integral type: converts both operands
// to double and returns the double result of std::nextafter(x, y).
// Note that the return type is double, not T.
template <typename T>
1195+
struct NextafterFunctor<T,
1196+
typename std::enable_if_t<std::is_integral<T>::value>> {
1197+
inline HOSTDEVICE double operator()(const T x, const T y) const {
1198+
return std::nextafter(static_cast<double>(x), static_cast<double>(y));
1199+
}
1200+
};
1201+
1202+
// Primary template of the argument-swapped nextafter functor.
// Converts both operands to float and calls std::nextafter(y, x) -- i.e. the
// second parameter is the start value and the first is the direction -- then
// casts the result back to T.
template <typename T, typename Enable = void>
1203+
struct InverseNextafterFunctor {
1204+
inline HOSTDEVICE T operator()(const T x, const T y) const {
1205+
return static_cast<T>(
1206+
std::nextafter(static_cast<float>(y), static_cast<float>(x)));
1207+
}
1208+
};
1209+
1210+
// Specialization enabled when T is double: calls std::nextafter(y, x) on the
// operands directly, with the two arguments swapped relative to the call
// signature.
template <typename T>
1211+
struct InverseNextafterFunctor<
1212+
T,
1213+
typename std::enable_if_t<std::is_same<T, double>::value>> {
1214+
inline HOSTDEVICE T operator()(const T x, const T y) const {
1215+
return std::nextafter(y, x);
1216+
}
1217+
};
1218+
1219+
// Specialization enabled when T is an integral type: converts both operands
// to double and returns the double result of std::nextafter(y, x) (arguments
// swapped). Note that the return type is double, not T.
template <typename T>
1220+
struct InverseNextafterFunctor<
1221+
T,
1222+
typename std::enable_if_t<std::is_integral<T>::value>> {
1223+
inline HOSTDEVICE double operator()(const T x, const T y) const {
1224+
return std::nextafter(static_cast<double>(y), static_cast<double>(x));
1225+
}
1226+
};
1227+
11771228
} // namespace funcs
11781229
} // namespace phi

paddle/phi/kernels/gpu/nextafter_kernel.cu

Lines changed: 0 additions & 22 deletions
This file was deleted.

paddle/phi/kernels/impl/nextafter_kernel_impl.h

Lines changed: 0 additions & 93 deletions
This file was deleted.

paddle/phi/kernels/kps/elementwise_kernel.cu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,18 @@ void CopySignKernel(const Context& dev_ctx,
196196
dev_ctx, inputs, &outputs, funcs::CopySignFunctor<T>());
197197
}
198198

199+
// nextafter kernel defined in kps/elementwise_kernel.cu.
// Collects pointers to `x` and `y` as inputs and `out` as the single output,
// allocates `out` with element type T, and calls funcs::BroadcastKernel<T>
// with funcs::NextafterFunctor<T>. Unlike the CPU variant shown on this page,
// no rank comparison or argument-swapped functor is used here.
template <typename T, typename Context>
200+
void NextafterKernel(const Context& dev_ctx,
201+
const DenseTensor& x,
202+
const DenseTensor& y,
203+
DenseTensor* out) {
204+
std::vector<const DenseTensor*> inputs = {&x, &y};
205+
std::vector<DenseTensor*> outputs = {out};
206+
dev_ctx.template Alloc<T>(out);
207+
funcs::BroadcastKernel<T>(
208+
dev_ctx, inputs, &outputs, funcs::NextafterFunctor<T>());
209+
}
210+
199211
} // namespace phi
200212

201213
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -269,6 +281,8 @@ PD_REGISTER_KERNEL(copysign,
269281
double,
270282
phi::dtype::float16,
271283
phi::dtype::bfloat16) {}
284+
// Registers phi::NextafterKernel under the name `nextafter` for the GPU
// backend, ALL_LAYOUT, with float and double as the listed element types.
PD_REGISTER_KERNEL(
285+
nextafter, GPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {}
272286

273287
#endif
274288

paddle/phi/kernels/nextafter_kernel.h

Lines changed: 0 additions & 28 deletions
This file was deleted.

test/legacy_test/test_nextafter_op.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ def setUp(self):
9595
self.op_type = "nextafter"
9696
self.python_api = paddle.nextafter
9797
self.init_dtype()
98+
self.init_shape()
9899

99-
x = np.array([1, 2]).astype(self.dtype)
100-
y = np.array([2, 1]).astype(self.dtype)
100+
x = np.random.rand(*self.x_shape).astype(self.dtype)
101+
y = np.random.rand(*self.y_shape).astype(self.dtype)
101102
out = np.nextafter(x, y)
102103
self.inputs = {'x': x, 'y': y}
103104
self.outputs = {'out': out}
@@ -108,11 +109,45 @@ def test_check_output(self):
108109
# Hook called from setUp: selects the element type of the test inputs.
# The base test case uses float64; subclasses override this method.
def init_dtype(self):
109110
self.dtype = np.float64
110111

112+
# Hook called from setUp: selects the shapes of the random inputs x and y.
# The base test case uses two 1-D inputs of length 2; subclasses override this
# method to use inputs of different ranks.
def init_shape(self):
113+
self.x_shape = (2,)
114+
self.y_shape = (2,)
115+
111116

112117
# Same test as TestNextafterOP, but with float32 inputs.
class TestNextafterOPFP32(TestNextafterOP):
113118
def init_dtype(self):
114119
self.dtype = np.float32
115120

116121

122+
# float32 case where x has fewer dimensions than y: x is (5,) and y is
# (2, 3, 4, 5).
class TestNextafterOPFP32Case1(TestNextafterOP):
123+
def init_dtype(self):
124+
self.dtype = np.float32
125+
126+
def init_shape(self):
127+
self.x_shape = (5,)
128+
self.y_shape = (2, 3, 4, 5)
130+
131+
# float32 case where x has more dimensions than y: x is (2, 3, 4, 5) and y is
# (1,).
class TestNextafterOPFP32Case2(TestNextafterOP):
132+
def init_dtype(self):
133+
self.dtype = np.float32
134+
135+
def init_shape(self):
136+
self.x_shape = (2, 3, 4, 5)
137+
self.y_shape = (1,)
138+
139+
140+
# Case where x has fewer dimensions than y: x is (5,) and y is (2, 3, 4, 5).
# The dtype is inherited from TestNextafterOP.
class TestNextafterOPCase1(TestNextafterOP):
141+
def init_shape(self):
142+
self.x_shape = (5,)
143+
self.y_shape = (2, 3, 4, 5)
144+
145+
146+
# Case where x has more dimensions than y: x is (2, 3, 4, 5) and y is (1,).
# The dtype is inherited from TestNextafterOP.
class TestNextafterOPCase2(TestNextafterOP):
147+
def init_shape(self):
148+
self.x_shape = (2, 3, 4, 5)
149+
self.y_shape = (1,)
150+
151+
117152
if __name__ == "__main__":
118153
unittest.main()

0 commit comments

Comments
 (0)