Commit b02a42c

Modify compare logical inplace (#56888)
* fix error
* fix compare
* fix
* fix
* remove fluid
* fix inplace test
* fix and separate inplace impl
1 parent 5422a44 commit b02a42c

File tree: 6 files changed, +208 -113 lines changed

paddle/phi/kernels/cpu/compare_kernel.cc

Lines changed: 42 additions & 28 deletions
@@ -30,22 +30,34 @@ inline void CompareKernelImpl(const Context& ctx,
                               const DenseTensor& y,
                               int axis,
                               DenseTensor* out) {
-  if (!out->IsSharedWith(x)) {
-    ctx.template Alloc<bool>(out);
-    if (x.dims().size() >= y.dims().size()) {
-      funcs::ElementwiseCompute<Functor, T, bool>(
-          ctx, x, y, Functor(), out, axis);
-    } else {
-      funcs::ElementwiseCompute<InverseFunctor, T, bool>(
-          ctx, x, y, InverseFunctor(), out, axis);
-    }
+  ctx.template Alloc<bool>(out);
+  if (x.dims().size() >= y.dims().size()) {
+    funcs::ElementwiseCompute<Functor, T, bool>(
+        ctx, x, y, Functor(), out, axis);
   } else {
-    if (x.dims().size() >= y.dims().size()) {
-      funcs::ElementwiseCompute<Functor, T, T>(ctx, x, y, Functor(), out, axis);
-    } else {
-      funcs::ElementwiseCompute<InverseFunctor, T, T>(
-          ctx, x, y, InverseFunctor(), out, axis);
-    }
+    funcs::ElementwiseCompute<InverseFunctor, T, bool>(
+        ctx, x, y, InverseFunctor(), out, axis);
+  }
+}
+
+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out) {
+  auto x_origin = x;
+  out->set_type(phi::DataType::BOOL);
+  ctx.template Alloc<bool>(out);
+  if (x_origin.dims().size() >= y.dims().size()) {
+    funcs::ElementwiseCompute<Functor, T, bool>(
+        ctx, x_origin, y, Functor(), out, axis);
+  } else {
+    funcs::ElementwiseCompute<InverseFunctor, T, bool>(
+        ctx, x_origin, y, InverseFunctor(), out, axis);
   }
 }

@@ -92,19 +104,21 @@ PD_REGISTER_KERNEL(equal_all,
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }

-#define PD_REGISTER_COMPARE_KERNEL(name, func) \
-  PD_REGISTER_KERNEL(name,                     \
-                     CPU,                      \
-                     ALL_LAYOUT,               \
-                     phi::func##Kernel,        \
-                     bool,                     \
-                     int16_t,                  \
-                     int,                      \
-                     int64_t,                  \
-                     float,                    \
-                     double,                   \
-                     phi::dtype::float16,      \
-                     phi::dtype::bfloat16) {}
+#define PD_REGISTER_COMPARE_KERNEL(name, func)              \
+  PD_REGISTER_KERNEL(name,                                  \
+                     CPU,                                   \
+                     ALL_LAYOUT,                            \
+                     phi::func##Kernel,                     \
+                     bool,                                  \
+                     int16_t,                               \
+                     int,                                   \
+                     int64_t,                               \
+                     float,                                 \
+                     double,                                \
+                     phi::dtype::float16,                   \
+                     phi::dtype::bfloat16) {                \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);   \
+  }
 PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
 PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
 PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
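
The new InplaceCompareKernelImpl relies on the copy "auto x_origin = x;": the copied DenseTensor handle keeps the original input allocation alive while the shared output is re-typed to BOOL and re-allocated, so the comparison still reads valid input data even though out aliases x. A minimal standalone sketch of that pattern (plain C++; Tensor and LessThanInplace are hypothetical stand-ins, not phi code):

// Standalone illustration only: reference-counted storage plays the role of
// the DenseTensor allocation shared between x and out.
#include <cstdio>
#include <memory>
#include <vector>

struct Tensor {
  std::shared_ptr<std::vector<float>> f;  // float storage (input dtype)
  std::shared_ptr<std::vector<bool>> b;   // bool storage (result dtype)
};

// Computes out = (out < y) element-wise, i.e. the output aliases the input.
void LessThanInplace(Tensor* out, const Tensor& y) {
  Tensor x_origin = *out;  // copy the handle; the float buffer stays alive
  out->f.reset();          // the aliased tensor is re-typed ...
  out->b = std::make_shared<std::vector<bool>>(x_origin.f->size());  // ... and re-allocated as bool
  for (std::size_t i = 0; i < x_origin.f->size(); ++i) {
    (*out->b)[i] = (*x_origin.f)[i] < (*y.f)[i];  // read old float data, write bool result
  }
}  // x_origin goes out of scope; only now is the float buffer released

int main() {
  Tensor x{std::make_shared<std::vector<float>>(std::vector<float>{1.f, 3.f, 5.f}), nullptr};
  Tensor y{std::make_shared<std::vector<float>>(std::vector<float>{2.f, 2.f, 2.f}), nullptr};
  LessThanInplace(&x, y);
  for (bool v : *x.b) std::printf("%d ", static_cast<int>(v));  // prints: 1 0 0
  std::printf("\n");
  return 0;
}

Without the saved handle, re-typing and re-allocating the shared storage before the compute could leave the kernel reading freed input memory.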

paddle/phi/kernels/cpu/logical_kernel.cc

Lines changed: 46 additions & 21 deletions
@@ -24,20 +24,40 @@
 
 namespace phi {
 
-#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                  \
-  template <typename T, typename Context>                                   \
-  void Logical##type##Kernel(const Context& dev_ctx,                        \
-                             const DenseTensor& x,                          \
-                             const DenseTensor& y,                          \
-                             DenseTensor* out) {                            \
-    funcs::Logical##type##Functor<T> binary_func;                           \
-    if (out->IsSharedWith(x)) {                                             \
-      funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, T>(    \
-          dev_ctx, x, y, binary_func, out);                                 \
-    } else {                                                                \
-      funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>( \
-          dev_ctx, x, y, binary_func, out);                                 \
-    }                                                                       \
+template <typename T, typename Context, typename Functor>
+void LogicalKernelImpl(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       DenseTensor* out) {
+  Functor binary_func;
+  funcs::ElementwiseCompute<Functor, T, bool>(dev_ctx, x, y, binary_func, out);
+}
+
+template <typename T, typename Context, typename Functor>
+void InplaceLogicalKernelImpl(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& y,
+                              DenseTensor* out) {
+  Functor binary_func;
+  auto x_origin = x;
+  out->set_type(phi::DataType::BOOL);
+  funcs::ElementwiseCompute<Functor, T, bool>(
+      dev_ctx, x_origin, y, binary_func, out);
+}
+
+#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                    \
+  template <typename T, typename Context>                                     \
+  void Logical##type##Kernel(const Context& dev_ctx,                          \
+                             const DenseTensor& x,                            \
+                             const DenseTensor& y,                            \
+                             DenseTensor* out) {                              \
+    if (out->IsSharedWith(x)) {                                               \
+      InplaceLogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>( \
+          dev_ctx, x, y, out);                                                \
+    } else {                                                                  \
+      LogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>(        \
+          dev_ctx, x, y, out);                                                \
+    }                                                                         \
   }
 
 DEFINE_LOGICAL_BINARY_KERNEL(And)

@@ -52,15 +72,18 @@ void LogicalNotKernel(const Context& dev_ctx,
   funcs::LogicalNotFunctor<T> unary_func;
 
   phi::Transform<Context> trans;
-  if (!out->IsSharedWith(x)) {
+  if (out->IsSharedWith(x)) {
+    auto x_origin = x;
+    out->set_type(phi::DataType::BOOL);
     auto* out_ptr = dev_ctx.template Alloc<bool>(out);
-    trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
-  } else {
     trans(dev_ctx,
-          x.data<T>(),
-          x.data<T>() + x.numel(),
-          reinterpret_cast<T*>(out->data()),
+          x_origin.data<T>(),
+          x_origin.data<T>() + x_origin.numel(),
+          out_ptr,
           unary_func);
+  } else {
+    auto* out_ptr = dev_ctx.template Alloc<bool>(out);
+    trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
   }
 }

@@ -79,7 +102,9 @@ void LogicalNotKernel(const Context& dev_ctx,
                         int8_t,                        \
                         phi::dtype::complex<float>,    \
                         phi::dtype::complex<double>,   \
-                        int16_t) {}
+                        int16_t) {                                        \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);                 \
+  }
 
 REGISTER_LOGICAL_CPU_KERNEL(logical_and, And)
 REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or)
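
For reference, with the new dispatch in place, DEFINE_LOGICAL_BINARY_KERNEL(And) on CPU expands to roughly the following (formatting simplified; LogicalOr and LogicalXor follow the same shape):

template <typename T, typename Context>
void LogicalAndKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      DenseTensor* out) {
  if (out->IsSharedWith(x)) {  // in-place call: out shares storage with x
    InplaceLogicalKernelImpl<T, Context, funcs::LogicalAndFunctor<T>>(
        dev_ctx, x, y, out);
  } else {                     // out-of-place call
    LogicalKernelImpl<T, Context, funcs::LogicalAndFunctor<T>>(
        dev_ctx, x, y, out);
  }
}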

paddle/phi/kernels/impl/compare_kernel_impl.h

Lines changed: 23 additions & 8 deletions
@@ -30,20 +30,35 @@ inline void CompareKernelImpl(const Context& ctx,
                               int axis,
                               DenseTensor* out);
 
+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out);
+
 template <typename T, typename Context, typename Functor>
 inline void CompareAllKernelImpl(const Context& ctx,
                                  const DenseTensor& x,
                                  const DenseTensor& y,
                                  DenseTensor* out);
 
-#define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor)       \
-  template <typename T, typename Context>                           \
-  void name##Kernel(const Context& ctx,                             \
-                    const DenseTensor& x,                           \
-                    const DenseTensor& y,                           \
-                    DenseTensor* out) {                             \
-    CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>(  \
-        ctx, x, y, -1, out);                                        \
+#define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor)               \
+  template <typename T, typename Context>                                   \
+  void name##Kernel(const Context& ctx,                                     \
+                    const DenseTensor& x,                                   \
+                    const DenseTensor& y,                                   \
+                    DenseTensor* out) {                                     \
+    if (out->IsSharedWith(x)) {                                             \
+      InplaceCompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>( \
+          ctx, x, y, -1, out);                                              \
+    } else {                                                                \
+      CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>(        \
+          ctx, x, y, -1, out);                                              \
+    }                                                                       \
   }
 
 DEFINE_COMPARE_KERNEL(LessThan,

paddle/phi/kernels/kps/compare_kernel.cu

Lines changed: 34 additions & 20 deletions
@@ -52,16 +52,27 @@ inline void CompareKernelImpl(const Context& ctx,
                               const DenseTensor& y,
                               int axis,
                               DenseTensor* out) {
-  if (!out->IsSharedWith(x)) {
-    ctx.template Alloc<bool>(out);
-  }
+  ctx.template Alloc<bool>(out);
   std::vector<const DenseTensor*> ins{&x, &y};
   std::vector<DenseTensor*> outs{out};
-  if (!out->IsSharedWith(x)) {
-    funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
-  } else {
-    funcs::BroadcastKernel<T>(ctx, ins, &outs, Functor(), axis);
-  }
+  funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
+}
+
+template <typename T,
+          typename Context,
+          typename Functor,
+          typename InverseFunctor>
+inline void InplaceCompareKernelImpl(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     int axis,
+                                     DenseTensor* out) {
+  auto x_origin = x;
+  ctx.template Alloc<bool>(out);
+  out->set_type(phi::DataType::BOOL);
+  std::vector<const DenseTensor*> ins{&x_origin, &y};
+  std::vector<DenseTensor*> outs{out};
+  funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
 }
 
 #ifndef PADDLE_WITH_XPU_KP

@@ -134,18 +145,21 @@ PD_REGISTER_KERNEL(equal_all,
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
 
-#define PD_REGISTER_COMPARE_KERNEL(name, func) \
-  PD_REGISTER_KERNEL(name,                     \
-                     KPS,                      \
-                     ALL_LAYOUT,               \
-                     phi::func##Kernel,        \
-                     bool,                     \
-                     int16_t,                  \
-                     int,                      \
-                     int64_t,                  \
-                     float,                    \
-                     double,                   \
-                     phi::dtype::float16) {}
+#define PD_REGISTER_COMPARE_KERNEL(name, func)              \
+  PD_REGISTER_KERNEL(name,                                  \
+                     KPS,                                   \
+                     ALL_LAYOUT,                            \
+                     phi::func##Kernel,                     \
+                     bool,                                  \
+                     int16_t,                               \
+                     int,                                   \
+                     int64_t,                               \
+                     float,                                 \
+                     double,                                \
+                     phi::dtype::float16,                   \
+                     phi::dtype::bfloat16) {                \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);   \
+  }
 
 PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
 PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
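
Expanding one macro level makes the registration change concrete: each comparison op on KPS is now also registered for bfloat16 and declares a BOOL output dtype, matching the kernels above. For example, PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) becomes approximately:

// Approximate one-level expansion of the macro call above;
// PD_REGISTER_KERNEL itself expands further inside phi.
PD_REGISTER_KERNEL(less_than,
                   KPS,
                   ALL_LAYOUT,
                   phi::LessThanKernel,
                   bool,
                   int16_t,
                   int,
                   int64_t,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}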

paddle/phi/kernels/kps/logical_kernel.cu

Lines changed: 52 additions & 25 deletions
@@ -25,24 +25,45 @@
 
 namespace phi {
 
-#define DEFINE_LOGICAL_BINARY_KERNEL(type)                            \
-  template <typename T, typename Context>                             \
-  void Logical##type##Kernel(const Context& dev_ctx,                  \
-                             const DenseTensor& x,                    \
-                             const DenseTensor& y,                    \
-                             DenseTensor* out) {                      \
-    if (!out->IsSharedWith(x)) {                                      \
-      dev_ctx.template Alloc<bool>(out);                              \
-    }                                                                 \
-                                                                      \
-    funcs::Logical##type##Functor<T> binary_func;                     \
-    std::vector<const DenseTensor*> ins = {&x, &y};                   \
-    std::vector<DenseTensor*> outs = {out};                           \
-    if (!out->IsSharedWith(x)) {                                      \
-      funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func); \
-    } else {                                                          \
-      funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, binary_func);    \
-    }                                                                 \
+template <typename T, typename Context, typename Functor>
+void LogicalKernelImpl(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       DenseTensor* out) {
+  dev_ctx.template Alloc<bool>(out);
+  Functor binary_func;
+  std::vector<const DenseTensor*> ins = {&x, &y};
+  std::vector<DenseTensor*> outs = {out};
+  funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func);
+}
+
+template <typename T, typename Context, typename Functor>
+void InplaceLogicalKernelImpl(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& y,
+                              DenseTensor* out) {
+  auto x_origin = x;
+  dev_ctx.template Alloc<bool>(out);
+  out->set_type(phi::DataType::BOOL);
+  Functor binary_func;
+  std::vector<const DenseTensor*> ins = {&x_origin, &y};
+  std::vector<DenseTensor*> outs = {out};
+  funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func);
+}
+
+#define DEFINE_LOGICAL_BINARY_KERNEL(type)                                    \
+  template <typename T, typename Context>                                     \
+  void Logical##type##Kernel(const Context& dev_ctx,                          \
+                             const DenseTensor& x,                            \
+                             const DenseTensor& y,                            \
+                             DenseTensor* out) {                              \
+    if (out->IsSharedWith(x)) {                                               \
+      InplaceLogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>( \
+          dev_ctx, x, y, out);                                                \
+    } else {                                                                  \
+      LogicalKernelImpl<T, Context, funcs::Logical##type##Functor<T>>(        \
+          dev_ctx, x, y, out);                                                \
+    }                                                                         \
   }
 
 DEFINE_LOGICAL_BINARY_KERNEL(And)

@@ -56,14 +77,18 @@ void LogicalNotKernel(const Context& dev_ctx,
                       DenseTensor* out) {
   if (!out->IsSharedWith(x)) {
     dev_ctx.template Alloc<bool>(out);
-  }
-  funcs::LogicalNotFunctor<T> unary_func;
-  std::vector<const DenseTensor*> ins = {&x};
-  std::vector<DenseTensor*> outs = {out};
-  if (!out->IsSharedWith(x)) {
+    funcs::LogicalNotFunctor<T> unary_func;
+    std::vector<const DenseTensor*> ins = {&x};
+    std::vector<DenseTensor*> outs = {out};
     funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
   } else {
-    funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, unary_func);
+    auto x_origin = x;
+    out->set_type(phi::DataType::BOOL);
+    dev_ctx.template Alloc<bool>(out);
+    funcs::LogicalNotFunctor<T> unary_func;
+    std::vector<const DenseTensor*> ins = {&x_origin};
+    std::vector<DenseTensor*> outs = {out};
+    funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
   }
 }

@@ -99,7 +124,9 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {
                       int8_t,                        \
                       phi::dtype::complex<float>,    \
                       phi::dtype::complex<double>,   \
-                      int16_t) {}
+                      int16_t) {                                          \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);                 \
+  }
 
 REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And)
 REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or)
