Commit 81d5ba2

[NPU]fix equal kernel (PaddlePaddle#228)
* [NPU]fix equal kernel
* remove extra code
* call cast kernel in equal kernel
1 parent 93bfe1d commit 81d5ba2

2 files changed: +74 −44 lines changed


backends/npu/kernels/compare_kernel.cc

Lines changed: 60 additions & 44 deletions
@@ -16,15 +16,33 @@
 
 namespace custom_kernel {
 
+template <typename T, typename Context>
+void CastKernel(const Context& dev_ctx,
+                const phi::DenseTensor& x,
+                phi::DenseTensorMeta::DataType dtype,
+                phi::DenseTensor* out);
+
 template <typename T, typename Context>
 void EqualRawKernel(const Context& dev_ctx,
-                    const phi::DenseTensor& x,
-                    const phi::DenseTensor& y,
-                    int axis,
-                    phi::DenseTensor* out) {
-  dev_ctx.template Alloc<bool>(out);
-  const auto& runner = NpuOpRunner("Equal", {x, y}, {*out}, {});
+                    const phi::DenseTensor& x,
+                    const phi::DenseTensor& y,
+                    int axis,
+                    phi::DenseTensor* out) {
   auto stream = dev_ctx.stream();
+  dev_ctx.template Alloc<bool>(out);
+
+  phi::DenseTensor transformed_x(x), transformed_y;
+  if (x.dtype() != y.dtype()) {
+    phi::DenseTensorMeta meta = {x.dtype(), y.dims()};
+    transformed_y.set_meta(meta);
+    custom_kernel::CastKernel<T, Context>(
+        dev_ctx, y, x.dtype(), &transformed_y);
+  } else {
+    transformed_y = y;
+  }
+
+  const auto& runner =
+      NpuOpRunner("Equal", {transformed_x, transformed_y}, {*out}, {});
   runner.Run(stream);
 }
 
@@ -38,10 +56,10 @@ void EqualKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void NotEqualRawKernel(const Context& dev_ctx,
-                       const phi::DenseTensor& x,
-                       const phi::DenseTensor& y,
-                       int axis,
-                       phi::DenseTensor* out) {
+                       const phi::DenseTensor& x,
+                       const phi::DenseTensor& y,
+                       int axis,
+                       phi::DenseTensor* out) {
   dev_ctx.template Alloc<bool>(out);
   const auto& runner = NpuOpRunner("NotEqual", {x, y}, {*out}, {});
   auto stream = dev_ctx.stream();
@@ -50,19 +68,18 @@ void NotEqualRawKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void NotEqualKernel(const Context& dev_ctx,
-                    const phi::DenseTensor& x,
-                    const phi::DenseTensor& y,
-                    phi::DenseTensor* out) {
+                    const phi::DenseTensor& x,
+                    const phi::DenseTensor& y,
+                    phi::DenseTensor* out) {
   custom_kernel::NotEqualRawKernel<T, Context>(dev_ctx, x, y, -1, out);
 }
 
-
 template <typename T, typename Context>
 void LessEqualRawKernel(const Context& dev_ctx,
-                        const phi::DenseTensor& x,
-                        const phi::DenseTensor& y,
-                        int axis,
-                        phi::DenseTensor* out) {
+                        const phi::DenseTensor& x,
+                        const phi::DenseTensor& y,
+                        int axis,
+                        phi::DenseTensor* out) {
   dev_ctx.template Alloc<bool>(out);
   auto stream = dev_ctx.stream();
 
@@ -72,18 +89,18 @@ void LessEqualRawKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void LessEqualKernel(const Context& dev_ctx,
-                     const phi::DenseTensor& x,
-                     const phi::DenseTensor& y,
-                     phi::DenseTensor* out) {
+                     const phi::DenseTensor& x,
+                     const phi::DenseTensor& y,
+                     phi::DenseTensor* out) {
   custom_kernel::LessEqualRawKernel<T, Context>(dev_ctx, x, y, -1, out);
 }
 
 template <typename T, typename Context>
 void LessThanRawKernel(const Context& dev_ctx,
-                       const phi::DenseTensor& x,
-                       const phi::DenseTensor& y,
-                       int axis,
-                       phi::DenseTensor* out) {
+                       const phi::DenseTensor& x,
+                       const phi::DenseTensor& y,
+                       int axis,
+                       phi::DenseTensor* out) {
   dev_ctx.template Alloc<bool>(out);
   const auto& runner = NpuOpRunner("Less", {x, y}, {*out}, {});
   auto stream = dev_ctx.stream();
@@ -92,18 +109,18 @@ void LessThanRawKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void LessThanKernel(const Context& dev_ctx,
-                    const phi::DenseTensor& x,
-                    const phi::DenseTensor& y,
-                    phi::DenseTensor* out) {
+                    const phi::DenseTensor& x,
+                    const phi::DenseTensor& y,
+                    phi::DenseTensor* out) {
   custom_kernel::LessThanRawKernel<T, Context>(dev_ctx, x, y, -1, out);
 }
 
 template <typename T, typename Context>
 void GreaterEqualRawKernel(const Context& dev_ctx,
-                           const phi::DenseTensor& x,
-                           const phi::DenseTensor& y,
-                           int axis,
-                           phi::DenseTensor* out) {
+                           const phi::DenseTensor& x,
+                           const phi::DenseTensor& y,
+                           int axis,
+                           phi::DenseTensor* out) {
   dev_ctx.template Alloc<bool>(out);
   const auto& runner = NpuOpRunner("GreaterEqual", {x, y}, {*out}, {});
   auto stream = dev_ctx.stream();
@@ -112,18 +129,18 @@ void GreaterEqualRawKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void GreaterEqualKernel(const Context& dev_ctx,
-                        const phi::DenseTensor& x,
-                        const phi::DenseTensor& y,
-                        phi::DenseTensor* out) {
+                        const phi::DenseTensor& x,
+                        const phi::DenseTensor& y,
+                        phi::DenseTensor* out) {
   custom_kernel::GreaterEqualRawKernel<T, Context>(dev_ctx, x, y, -1, out);
 }
 
 template <typename T, typename Context>
 void GreaterThanRawKernel(const Context& dev_ctx,
-                          const phi::DenseTensor& x,
-                          const phi::DenseTensor& y,
-                          int axis,
-                          phi::DenseTensor* out) {
+                          const phi::DenseTensor& x,
+                          const phi::DenseTensor& y,
+                          int axis,
+                          phi::DenseTensor* out) {
   dev_ctx.template Alloc<bool>(out);
   const auto& runner = NpuOpRunner("Greater", {x, y}, {*out}, {});
   auto stream = dev_ctx.stream();
@@ -132,9 +149,9 @@ void GreaterThanRawKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void GreaterThanKernel(const Context& dev_ctx,
-                       const phi::DenseTensor& x,
-                       const phi::DenseTensor& y,
-                       phi::DenseTensor* out) {
+                       const phi::DenseTensor& x,
+                       const phi::DenseTensor& y,
+                       phi::DenseTensor* out) {
   custom_kernel::GreaterThanRawKernel<T, Context>(dev_ctx, x, y, -1, out);
 }
 
@@ -176,7 +193,6 @@ PD_REGISTER_PLUGIN_KERNEL(not_equal,
                           phi::dtype::float16,
                           double) {}
 
-
 PD_REGISTER_PLUGIN_KERNEL(not_equal_raw,
                           npu,
                           ALL_LAYOUT,
@@ -282,4 +298,4 @@ PD_REGISTER_PLUGIN_KERNEL(greater_than_raw,
                           int64_t,
                           float,
                           phi::dtype::float16,
-                          double) {}
+                          double) {}
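
In short: EqualRawKernel no longer assumes both operands share a dtype. When they differ, it builds transformed_y with x's dtype, routes y through CastKernel (forward-declared at the top of the file so the definition can live in another translation unit), and only then invokes the ascend Equal op. Below is a minimal NumPy sketch of that control flow, for illustration only; the array values and dtypes are made up, and the real code runs on phi::DenseTensor through the NPU runtime.

# Sketch of the promote-then-compare control flow in EqualRawKernel,
# with NumPy standing in for phi::DenseTensor and the NPU "Equal" op.
import numpy as np

def equal_with_cast(x, y):
    # Mirror of the kernel: if dtypes differ, cast y to x's dtype
    # (the CastKernel call), then compare elementwise.
    if x.dtype != y.dtype:
        y = y.astype(x.dtype)
    return x == y

x = np.array([1.0, 2.5, 3.0], dtype=np.float32)
y = np.array([1, 2, 3], dtype=np.int32)
print(equal_with_cast(x, y))  # [ True False  True]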

backends/npu/tests/unittests/test_compare_op_npu.py

Lines changed: 14 additions & 0 deletions
@@ -71,6 +71,20 @@ def test_dynamic_api(self):
             out = op(x, y)
             self.assertEqual((out.numpy() == real_result).all(), True)
 
+        def test_dynamic_api_different_type(self):
+            if op_type != 'equal':
+                return
+            paddle.disable_static()
+            paddle.set_device('npu:0')
+            x = np.random.random(size=(10, 7)).astype(typename)
+            y = np.random.random(size=(10, 7)).astype('int32')
+            real_result = callback(x, y)
+            x = paddle.to_tensor(x, dtype=typename)
+            y = paddle.to_tensor(y, dtype='float32')
+            op = eval("paddle.%s" % (self.op_type))
+            out = op(x, y)
+            self.assertEqual((out.numpy() == real_result).all(), True)
+
         @unittest.skipIf(typename == 'float16', "float16 is not supported now")
         def test_broadcast_api_1(self):
             paddle.enable_static()