
Commit 1c5dc6e

[Paddle Tensor Standardization, Phase 2] paddle.round supports int and complex (#72239)
* add int and complex
* fix
* fix ci
* fix ci
* fix ci
* fix
* fix ci
* fix ci
1 parent 381e176 commit 1c5dc6e

7 files changed, +559 −17 lines

paddle/phi/kernels/cpu/activation_grad_kernel.cc (+11 −1)
@@ -478,13 +478,23 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(log_double_grad,
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
                                                 HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(celu_double_grad,
                                           CeluDoubleGradKernel)

+PD_REGISTER_KERNEL(round_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RoundGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
+
 PD_REGISTER_KERNEL(pow_grad,
                    CPU,
                    ALL_LAYOUT,
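Extending round_grad to these dtypes is cheap because round is piecewise constant, so its gradient is zero everywhere (the shared zero-gradient functor noted further down in activation_functor.h covers ceil, floor, and round). A minimal sketch of the expected behavior through the Python API, assuming a Paddle build that includes this commit:

import paddle

# round is piecewise constant, so its gradient is identically zero
x = paddle.to_tensor([0.4, 1.5, -2.6], stop_gradient=False)
y = paddle.round(x)
y.sum().backward()
print(x.grad)  # expected: [0., 0., 0.]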

paddle/phi/kernels/cpu/activation_kernel.cc (+11 −1)
@@ -253,11 +253,21 @@ PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
 PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)

+PD_REGISTER_KERNEL(round,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RoundKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
+
 PD_REGISTER_KERNEL(exp,
                    CPU,
                    ALL_LAYOUT,
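With this registration, paddle.round on CPU should accept integer and complex inputs directly; the functor specializations added in activation_functor.h below make integers an identity and round complex values componentwise. A hedged usage sketch, again assuming a build with this commit:

import paddle

xi = paddle.to_tensor([1, 2, 3], dtype='int64')
print(paddle.round(xi))  # expected: [1, 2, 3] (identity for integral dtypes)

xc = paddle.to_tensor([1.4 + 2.6j], dtype='complex64')
print(paddle.round(xc))  # expected: [(1+3j)], real and imaginary parts rounded independently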

paddle/phi/kernels/funcs/activation_functor.h (+141 −10)
@@ -2999,7 +2999,7 @@ struct FloorFunctor : public BaseActivationFunctor<T> {
 };

 // round(x) = [x]
-template <typename T>
+template <typename T, typename Enable = void>
 struct RoundFunctor : public BaseActivationFunctor<T> {
   int decimals;
@@ -3010,13 +3010,85 @@ struct RoundFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
     if (decimals == 0) {
-      out.device(d) = x.round();
+      out.device(d) = x.unaryExpr([](const T& val) {
+        return (std::isnan(val) || std::isinf(val)) ? val : std::rint(val);
+      });
+    } else if (decimals > 0) {
+      auto ten_pow_decimals = static_cast<T>(std::pow(10, decimals));
+      out.device(d) = x.unaryExpr([ten_pow_decimals](const T& val) {
+        return (std::isnan(val) || std::isinf(val))
+                   ? val
+                   : std::rint(val * ten_pow_decimals) / ten_pow_decimals;
+      });
+    } else {
+      auto ten_pow_decimals = static_cast<T>(std::pow(10, -decimals));
+      out.device(d) = x.unaryExpr([ten_pow_decimals](const T& val) {
+        return (std::isnan(val) || std::isinf(val))
+                   ? val
+                   : std::rint(val / ten_pow_decimals) * ten_pow_decimals;
+      });
+    }
+  }
+};
+
+template <typename T>
+struct RoundFunctor<T, std::enable_if_t<std::is_integral_v<T>>>
+    : public BaseActivationFunctor<T> {
+  int decimals;
+
+  std::vector<std::pair<const char*, int*>> GetAttrs() {
+    return {{"decimals", &decimals}};
+  }
+
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x;
+  }
+};
+
+template <typename T>
+struct RoundFunctor<phi::dtype::complex<T>>
+    : public BaseActivationFunctor<phi::dtype::complex<T>> {
+  int decimals;
+
+  std::vector<std::pair<const char*, int*>> GetAttrs() {
+    return {{"decimals", &decimals}};
+  }
+
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    using ComplexT = phi::dtype::complex<T>;
+
+    if (decimals == 0) {
+      out.device(d) = x.unaryExpr([](const ComplexT& c) {
+        T real = std::isnan(c.real) || std::isinf(c.real) ? c.real
+                                                          : std::rint(c.real);
+        T imag = std::isnan(c.imag) || std::isinf(c.imag) ? c.imag
+                                                          : std::rint(c.imag);
+        return ComplexT(real, imag);
+      });
     } else if (decimals > 0) {
       auto ten_pow_decimals = static_cast<T>(std::pow(10, decimals));
-      out.device(d) = (x * ten_pow_decimals).round() / ten_pow_decimals;
+      out.device(d) = x.unaryExpr([ten_pow_decimals](const ComplexT& c) {
+        T real = std::isnan(c.real) || std::isinf(c.real)
+                     ? c.real
+                     : std::rint(c.real * ten_pow_decimals) / ten_pow_decimals;
+        T imag = std::isnan(c.imag) || std::isinf(c.imag)
+                     ? c.imag
+                     : std::rint(c.imag * ten_pow_decimals) / ten_pow_decimals;
+        return ComplexT(real, imag);
+      });
     } else {
       auto ten_pow_decimals = static_cast<T>(std::pow(10, -decimals));
-      out.device(d) = (x / ten_pow_decimals).round() * ten_pow_decimals;
+      out.device(d) = x.unaryExpr([ten_pow_decimals](const ComplexT& c) {
+        T real = std::isnan(c.real) || std::isinf(c.real)
+                     ? c.real
+                     : std::rint(c.real / ten_pow_decimals) * ten_pow_decimals;
+        T imag = std::isnan(c.imag) || std::isinf(c.imag)
+                     ? c.imag
+                     : std::rint(c.imag / ten_pow_decimals) * ten_pow_decimals;
+        return ComplexT(real, imag);
+      });
     }
   }
 };
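Two behavioral points are worth calling out in the rewritten functor. First, std::rint (in the default rounding mode) rounds halfway cases to the nearest even value, whereas Eigen's .round() rounds them away from zero. Second, non-zero decimals are implemented by scaling: multiply by 10^decimals, round, divide back, with the mirror image for negative decimals. Python's built-in round uses the same half-to-even rule, so the scalar logic can be sketched as:

import math

def round_like_functor(val: float, decimals: int = 0) -> float:
    # Mirrors RoundFunctor: NaN/Inf pass through; otherwise scale,
    # round half-to-even (as std::rint does), and unscale.
    if math.isnan(val) or math.isinf(val):
        return val
    if decimals == 0:
        return float(round(val))
    elif decimals > 0:
        scale = 10.0 ** decimals
        return round(val * scale) / scale
    else:
        scale = 10.0 ** (-decimals)
        return round(val / scale) * scale

print(round_like_functor(0.5), round_like_functor(1.5), round_like_functor(2.5))
# half-to-even: 0.0 2.0 2.0
print(round_like_functor(1.25, 1))     # 1.2 (12.5 rounds to the even 12)
print(round_like_functor(1250.0, -2))  # 1200.0 (12.5 rounds to the even 12)

One small divergence in this sketch: Python's round(-0.5) returns an integer 0, while std::rint(-0.5) yields -0.0, which is exactly why the updated docstring in ops.py below prints -0.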
@@ -5318,7 +5390,7 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> {
   }
 };

-template <typename T>
+template <typename T, typename Enable = void>
 struct CudaRoundFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   int decimals;
@@ -5330,20 +5402,79 @@ struct CudaRoundFunctor : public BaseActivationFunctor<T> {
   __device__ __forceinline__ T operator()(const T arg_x) const {
     MPType x = static_cast<MPType>(arg_x);

+    if (isnan(x) || isinf(x)) return arg_x;
     if (decimals == 0) {
-      return static_cast<T>(round(x));
+      return static_cast<T>(std::rint(x));
     } else if (decimals > 0) {
-      float ten_pow_decimals = powf(10., decimals);
-      return static_cast<T>(round(x * static_cast<MPType>(ten_pow_decimals)) /
+      MPType ten_pow_decimals =
+          pow(static_cast<MPType>(10), static_cast<MPType>(decimals));
+      return static_cast<T>(rint(x * static_cast<MPType>(ten_pow_decimals)) /
                             ten_pow_decimals);
     } else {
-      float ten_pow_decimals = powf(10., -decimals);
-      return static_cast<T>(round(x / static_cast<MPType>(ten_pow_decimals)) *
+      MPType ten_pow_decimals =
+          pow(static_cast<MPType>(10), static_cast<MPType>(-decimals));
+      return static_cast<T>(rint(x / static_cast<MPType>(ten_pow_decimals)) *
                             ten_pow_decimals);
     }
   }
 };

+template <typename T>
+struct CudaRoundFunctor<T, std::enable_if_t<std::is_integral_v<T>>>
+    : public BaseActivationFunctor<T> {
+  int decimals;
+
+  std::vector<std::pair<const char*, int*>> GetAttrs() {
+    return {{"decimals", &decimals}};
+  }
+  // round(x) = x for integral inputs
+  __device__ __forceinline__ T operator()(const T arg_x) const { return arg_x; }
+};
+
+template <typename T>
+struct CudaRoundFunctor<phi::dtype::complex<T>>
+    : public BaseActivationFunctor<phi::dtype::complex<T>> {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  int decimals;
+
+  std::vector<std::pair<const char*, int*>> GetAttrs() {
+    return {{"decimals", &decimals}};
+  }
+
+  __device__ __forceinline__ phi::dtype::complex<T> operator()(
+      const phi::dtype::complex<T> arg_x) const {
+    MPType real_part = static_cast<MPType>(arg_x.real);
+    MPType imag_part = static_cast<MPType>(arg_x.imag);
+    bool real_special = isnan(real_part) || isinf(real_part);
+    bool imag_special = isnan(imag_part) || isinf(imag_part);
+    MPType real, imag;
+
+    if (decimals == 0) {
+      real = real_special ? real_part : rint(real_part);
+      imag = imag_special ? imag_part : rint(imag_part);
+    } else if (decimals > 0) {
+      MPType ten_pow_decimals =
+          pow(static_cast<MPType>(10), static_cast<MPType>(decimals));
+      real = real_special
+                 ? real_part
+                 : rint(real_part * ten_pow_decimals) / ten_pow_decimals;
+      imag = imag_special
+                 ? imag_part
+                 : rint(imag_part * ten_pow_decimals) / ten_pow_decimals;
+    } else {
+      MPType ten_pow_decimals =
+          pow(static_cast<MPType>(10), static_cast<MPType>(-decimals));
+      real = real_special
+                 ? real_part
+                 : rint(real_part / ten_pow_decimals) * ten_pow_decimals;
+      imag = imag_special
+                 ? imag_part
+                 : rint(imag_part / ten_pow_decimals) * ten_pow_decimals;
+    }
+    return phi::dtype::complex<T>(static_cast<T>(real), static_cast<T>(imag));
+  }
+};
+
 // GradFunctor for ceil, floor and round
 template <typename T>
 struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
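Both the CPU and CUDA paths now return NaN and ±Inf inputs unchanged instead of feeding them through the scale/round/unscale pipeline. A quick property check through the Python API, assuming a Paddle build with these kernels:

import math
import paddle

x = paddle.to_tensor([float('nan'), float('inf'), float('-inf'), 2.5])
out = paddle.round(x)
print(out)  # expected: [nan, inf, -inf, 2.], with 2.5 rounding half-to-even
assert math.isnan(float(out[0]))
assert math.isinf(float(out[1])) and float(out[1]) > 0
assert math.isinf(float(out[2])) and float(out[2]) < 0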

paddle/phi/kernels/gpu/activation_grad_kernel.cu (+12 −1)
@@ -547,12 +547,23 @@ PD_REGISTER_KERNEL(log_double_grad,
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
                                                 HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_double_grad, CeluDoubleGradKernel)

+PD_REGISTER_KERNEL(round_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RoundGradKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(pow_grad,
                    GPU,
                    ALL_LAYOUT,

paddle/phi/kernels/gpu/activation_kernel.cu (+12 −1)
@@ -336,13 +336,24 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
 PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(selu, SeluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel)

+PD_REGISTER_KERNEL(round,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RoundKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(log,
                    GPU,
                    ALL_LAYOUT,

python/paddle/tensor/ops.py (+15 −3)
@@ -810,7 +810,7 @@ def round(x: Tensor, decimals: int = 0, name: str | None = None) -> Tensor:
             out.data = [1., -1., 3., 1.]

     Args:
-        x (Tensor): Input of Round operator, an N-D Tensor, with data type bfloat16, float32, float64 or float16.
+        x (Tensor): Input of Round operator, an N-D Tensor, with data type bfloat16, int32, int64, float32, float64, float16, complex64 or complex128.
         decimals(int): Rounded decimal place (default: 0).
         name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

@@ -826,13 +826,25 @@ def round(x: Tensor, decimals: int = 0, name: str | None = None) -> Tensor:
             >>> out = paddle.round(x)
             >>> print(out)
             Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
-            [-1., -0., 1., 2.])
+            [-0., -0., 1., 2.])
     """
     if in_dynamic_or_pir_mode():
         return _C_ops.round(x, decimals)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'round'
+            x,
+            'x',
+            [
+                'float16',
+                'uint16',
+                'int32',
+                'int64',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
+            'round',
         )
         helper = LayerHelper('round', **locals())
         attrs = {
0 commit comments