@@ -2999,7 +2999,7 @@ struct FloorFunctor : public BaseActivationFunctor<T> {
2999
2999
};
3000
3000
3001
3001
// round(x) = [x]
3002
- template <typename T>
3002
+ template <typename T, typename Enable = void >
3003
3003
struct RoundFunctor : public BaseActivationFunctor <T> {
3004
3004
int decimals;
3005
3005
@@ -3010,13 +3010,85 @@ struct RoundFunctor : public BaseActivationFunctor<T> {
3010
3010
template <typename Device, typename X, typename Out>
3011
3011
void operator ()(Device d, X x, Out out) const {
3012
3012
if (decimals == 0 ) {
3013
- out.device (d) = x.round ();
3013
+ out.device (d) = x.unaryExpr ([](const T& val) {
3014
+ return (std::isnan (val) || std::isinf (val)) ? val : std::rint (val);
3015
+ });
3016
+ } else if (decimals > 0 ) {
3017
+ auto ten_pow_decimals = static_cast <T>(std::pow (10 , decimals));
3018
+ out.device (d) = x.unaryExpr ([ten_pow_decimals](const T& val) {
3019
+ return (std::isnan (val) || std::isinf (val))
3020
+ ? val
3021
+ : std::rint (val * ten_pow_decimals) / ten_pow_decimals;
3022
+ });
3023
+ } else {
3024
+ auto ten_pow_decimals = static_cast <T>(std::pow (10 , -decimals));
3025
+ out.device (d) = x.unaryExpr ([ten_pow_decimals](const T& val) {
3026
+ return (std::isnan (val) || std::isinf (val))
3027
+ ? val
3028
+ : std::rint (val / ten_pow_decimals) * ten_pow_decimals;
3029
+ });
3030
+ }
3031
+ }
3032
+ };
3033
+
3034
+ template <typename T>
3035
+ struct RoundFunctor <T, std::enable_if_t <std::is_integral_v<T>>>
3036
+ : public BaseActivationFunctor<T> {
3037
+ int decimals;
3038
+
3039
+ std::vector<std::pair<const char *, int *>> GetAttrs () {
3040
+ return {{" decimals" , &decimals}};
3041
+ }
3042
+
3043
+ template <typename Device, typename X, typename Out>
3044
+ void operator ()(Device d, X x, Out out) const {
3045
+ out.device (d) = x;
3046
+ }
3047
+ };
3048
+
3049
+ template <typename T>
3050
+ struct RoundFunctor <phi::dtype::complex<T>>
3051
+ : public BaseActivationFunctor<phi::dtype::complex<T>> {
3052
+ int decimals;
3053
+
3054
+ std::vector<std::pair<const char *, int *>> GetAttrs () {
3055
+ return {{" decimals" , &decimals}};
3056
+ }
3057
+
3058
+ template <typename Device, typename X, typename Out>
3059
+ void operator ()(Device d, X x, Out out) const {
3060
+ using ComplexT = phi::dtype::complex<T>;
3061
+
3062
+ if (decimals == 0 ) {
3063
+ out.device (d) = x.unaryExpr ([](const ComplexT& c) {
3064
+ T real = std::isnan (c.real ) || std::isinf (c.real ) ? c.real
3065
+ : std::rint (c.real );
3066
+ T imag = std::isnan (c.imag ) || std::isinf (c.imag ) ? c.imag
3067
+ : std::rint (c.imag );
3068
+ return ComplexT (real, imag);
3069
+ });
3014
3070
} else if (decimals > 0 ) {
3015
3071
auto ten_pow_decimals = static_cast <T>(std::pow (10 , decimals));
3016
- out.device (d) = (x * ten_pow_decimals).round () / ten_pow_decimals;
3072
+ out.device (d) = x.unaryExpr ([ten_pow_decimals](const ComplexT& c) {
3073
+ T real = std::isnan (c.real ) || std::isinf (c.real )
3074
+ ? c.real
3075
+ : std::rint (c.real * ten_pow_decimals) / ten_pow_decimals;
3076
+ T imag = std::isnan (c.imag ) || std::isinf (c.imag )
3077
+ ? c.imag
3078
+ : std::rint (c.imag * ten_pow_decimals) / ten_pow_decimals;
3079
+ return ComplexT (real, imag);
3080
+ });
3017
3081
} else {
3018
3082
auto ten_pow_decimals = static_cast <T>(std::pow (10 , -decimals));
3019
- out.device (d) = (x / ten_pow_decimals).round () * ten_pow_decimals;
3083
+ out.device (d) = x.unaryExpr ([ten_pow_decimals](const ComplexT& c) {
3084
+ T real = std::isnan (c.real ) || std::isinf (c.real )
3085
+ ? c.real
3086
+ : std::rint (c.real / ten_pow_decimals) * ten_pow_decimals;
3087
+ T imag = std::isnan (c.imag ) || std::isinf (c.imag )
3088
+ ? c.imag
3089
+ : std::rint (c.imag / ten_pow_decimals) * ten_pow_decimals;
3090
+ return ComplexT (real, imag);
3091
+ });
3020
3092
}
3021
3093
}
3022
3094
};
@@ -5318,7 +5390,7 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> {
5318
5390
}
5319
5391
};
5320
5392
5321
- template <typename T>
5393
+ template <typename T, typename Enable = void >
5322
5394
struct CudaRoundFunctor : public BaseActivationFunctor <T> {
5323
5395
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
5324
5396
int decimals;
@@ -5330,20 +5402,79 @@ struct CudaRoundFunctor : public BaseActivationFunctor<T> {
5330
5402
// Rounds arg_x to `decimals` decimal places. The arithmetic is carried out in
// MPType and the result is cast back to T.
//   decimals == 0 : rint(x)
//   decimals  > 0 : rint(x * 10^decimals) / 10^decimals
//   decimals  < 0 : rint(x / 10^-decimals) * 10^-decimals
__device__ __forceinline__ T operator()(const T arg_x) const {
  const MPType x = static_cast<MPType>(arg_x);

  // NaN and +/-Inf are returned as-is, before any scaling is attempted.
  if (isnan(x) || isinf(x)) return arg_x;

  if (decimals == 0) {
    return static_cast<T>(rint(x));
  }

  if (decimals > 0) {
    const MPType scale =
        pow(static_cast<MPType>(10), static_cast<MPType>(decimals));
    return static_cast<T>(rint(x * scale) / scale);
  }

  // decimals < 0: round to a multiple of 10^(-decimals).
  const MPType scale =
      pow(static_cast<MPType>(10), static_cast<MPType>(-decimals));
  return static_cast<T>(rint(x / scale) * scale);
}
5345
5420
};
5346
5421
5422
+ template <typename T>
5423
+ struct CudaRoundFunctor <T, std::enable_if_t <std::is_integral_v<T>>>
5424
+ : public BaseActivationFunctor<T> {
5425
+ int decimals;
5426
+
5427
+ std::vector<std::pair<const char *, int *>> GetAttrs () {
5428
+ return {{" decimals" , &decimals}};
5429
+ }
5430
+ // round(x) = round(x)
5431
+ __device__ __forceinline__ T operator ()(const T arg_x) const { return arg_x; }
5432
+ };
5433
+
5434
+ template <typename T>
5435
+ struct CudaRoundFunctor <phi::dtype::complex<T>>
5436
+ : public BaseActivationFunctor<phi::dtype::complex<T>> {
5437
+ using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
5438
+ int decimals;
5439
+
5440
+ std::vector<std::pair<const char *, int *>> GetAttrs () {
5441
+ return {{" decimals" , &decimals}};
5442
+ }
5443
+
5444
+ __device__ __forceinline__ phi::dtype::complex<T> operator ()(
5445
+ const phi::dtype::complex<T> arg_x) const {
5446
+ MPType real_part = static_cast <MPType>(arg_x.real );
5447
+ MPType imag_part = static_cast <MPType>(arg_x.imag );
5448
+ bool real_special = isnan (real_part) || isinf (real_part);
5449
+ bool imag_special = isnan (imag_part) || isinf (imag_part);
5450
+ MPType real, imag;
5451
+
5452
+ if (decimals == 0 ) {
5453
+ real = real_special ? real_part : rint (real_part);
5454
+ imag = imag_special ? imag_part : rint (imag_part);
5455
+ } else if (decimals > 0 ) {
5456
+ MPType ten_pow_decimals =
5457
+ pow (static_cast <MPType>(10 ), static_cast <MPType>(decimals));
5458
+ real = real_special
5459
+ ? real_part
5460
+ : rint (real_part * ten_pow_decimals) / ten_pow_decimals;
5461
+ imag = imag_special
5462
+ ? imag_part
5463
+ : rint (imag_part * ten_pow_decimals) / ten_pow_decimals;
5464
+ } else {
5465
+ MPType ten_pow_decimals =
5466
+ pow (static_cast <MPType>(10 ), static_cast <MPType>(-decimals));
5467
+ real = real_special
5468
+ ? real_part
5469
+ : rint (real_part / ten_pow_decimals) * ten_pow_decimals;
5470
+ imag = imag_special
5471
+ ? imag_part
5472
+ : rint (imag_part / ten_pow_decimals) * ten_pow_decimals;
5473
+ }
5474
+ return phi::dtype::complex<T>(static_cast <T>(real), static_cast <T>(imag));
5475
+ }
5476
+ };
5477
+
5347
5478
// GradFunctor for ceil, floor and round
5348
5479
template <typename T>
5349
5480
struct CudaZeroGradFunctor : public BaseActivationFunctor <T> {
0 commit comments