16
16
17
17
namespace custom_kernel {
18
18
19
+ template <typename T, typename Context>
20
+ void CastKernel (const Context& dev_ctx,
21
+ const phi::DenseTensor& x,
22
+ phi::DenseTensorMeta::DataType dtype,
23
+ phi::DenseTensor* out);
24
+
19
25
template <typename T, typename Context>
20
26
void EqualRawKernel (const Context& dev_ctx,
21
- const phi::DenseTensor& x,
22
- const phi::DenseTensor& y,
23
- int axis,
24
- phi::DenseTensor* out) {
25
- dev_ctx.template Alloc <bool >(out);
26
- const auto & runner = NpuOpRunner (" Equal" , {x, y}, {*out}, {});
27
+ const phi::DenseTensor& x,
28
+ const phi::DenseTensor& y,
29
+ int axis,
30
+ phi::DenseTensor* out) {
27
31
auto stream = dev_ctx.stream ();
32
+ dev_ctx.template Alloc <bool >(out);
33
+
34
+ phi::DenseTensor transformed_x (x), transformed_y;
35
+ if (x.dtype () != y.dtype ()) {
36
+ phi::DenseTensorMeta meta = {x.dtype (), y.dims ()};
37
+ transformed_y.set_meta (meta);
38
+ custom_kernel::CastKernel<T, Context>(
39
+ dev_ctx, y, x.dtype (), &transformed_y);
40
+ } else {
41
+ transformed_y = y;
42
+ }
43
+
44
+ const auto & runner =
45
+ NpuOpRunner (" Equal" , {transformed_x, transformed_y}, {*out}, {});
28
46
runner.Run (stream);
29
47
}
30
48
@@ -38,10 +56,10 @@ void EqualKernel(const Context& dev_ctx,
38
56
39
57
template <typename T, typename Context>
40
58
void NotEqualRawKernel (const Context& dev_ctx,
41
- const phi::DenseTensor& x,
42
- const phi::DenseTensor& y,
43
- int axis,
44
- phi::DenseTensor* out) {
59
+ const phi::DenseTensor& x,
60
+ const phi::DenseTensor& y,
61
+ int axis,
62
+ phi::DenseTensor* out) {
45
63
dev_ctx.template Alloc <bool >(out);
46
64
const auto & runner = NpuOpRunner (" NotEqual" , {x, y}, {*out}, {});
47
65
auto stream = dev_ctx.stream ();
@@ -50,19 +68,18 @@ void NotEqualRawKernel(const Context& dev_ctx,
50
68
51
69
template <typename T, typename Context>
52
70
void NotEqualKernel (const Context& dev_ctx,
53
- const phi::DenseTensor& x,
54
- const phi::DenseTensor& y,
55
- phi::DenseTensor* out) {
71
+ const phi::DenseTensor& x,
72
+ const phi::DenseTensor& y,
73
+ phi::DenseTensor* out) {
56
74
custom_kernel::NotEqualRawKernel<T, Context>(dev_ctx, x, y, -1 , out);
57
75
}
58
76
59
-
60
77
template <typename T, typename Context>
61
78
void LessEqualRawKernel (const Context& dev_ctx,
62
- const phi::DenseTensor& x,
63
- const phi::DenseTensor& y,
64
- int axis,
65
- phi::DenseTensor* out) {
79
+ const phi::DenseTensor& x,
80
+ const phi::DenseTensor& y,
81
+ int axis,
82
+ phi::DenseTensor* out) {
66
83
dev_ctx.template Alloc <bool >(out);
67
84
auto stream = dev_ctx.stream ();
68
85
@@ -72,18 +89,18 @@ void LessEqualRawKernel(const Context& dev_ctx,
72
89
73
90
template <typename T, typename Context>
74
91
void LessEqualKernel (const Context& dev_ctx,
75
- const phi::DenseTensor& x,
76
- const phi::DenseTensor& y,
77
- phi::DenseTensor* out) {
92
+ const phi::DenseTensor& x,
93
+ const phi::DenseTensor& y,
94
+ phi::DenseTensor* out) {
78
95
custom_kernel::LessEqualRawKernel<T, Context>(dev_ctx, x, y, -1 , out);
79
96
}
80
97
81
98
template <typename T, typename Context>
82
99
void LessThanRawKernel (const Context& dev_ctx,
83
- const phi::DenseTensor& x,
84
- const phi::DenseTensor& y,
85
- int axis,
86
- phi::DenseTensor* out) {
100
+ const phi::DenseTensor& x,
101
+ const phi::DenseTensor& y,
102
+ int axis,
103
+ phi::DenseTensor* out) {
87
104
dev_ctx.template Alloc <bool >(out);
88
105
const auto & runner = NpuOpRunner (" Less" , {x, y}, {*out}, {});
89
106
auto stream = dev_ctx.stream ();
@@ -92,18 +109,18 @@ void LessThanRawKernel(const Context& dev_ctx,
92
109
93
110
template <typename T, typename Context>
94
111
void LessThanKernel (const Context& dev_ctx,
95
- const phi::DenseTensor& x,
96
- const phi::DenseTensor& y,
97
- phi::DenseTensor* out) {
112
+ const phi::DenseTensor& x,
113
+ const phi::DenseTensor& y,
114
+ phi::DenseTensor* out) {
98
115
custom_kernel::LessThanRawKernel<T, Context>(dev_ctx, x, y, -1 , out);
99
116
}
100
117
101
118
template <typename T, typename Context>
102
119
void GreaterEqualRawKernel (const Context& dev_ctx,
103
- const phi::DenseTensor& x,
104
- const phi::DenseTensor& y,
105
- int axis,
106
- phi::DenseTensor* out) {
120
+ const phi::DenseTensor& x,
121
+ const phi::DenseTensor& y,
122
+ int axis,
123
+ phi::DenseTensor* out) {
107
124
dev_ctx.template Alloc <bool >(out);
108
125
const auto & runner = NpuOpRunner (" GreaterEqual" , {x, y}, {*out}, {});
109
126
auto stream = dev_ctx.stream ();
@@ -112,18 +129,18 @@ void GreaterEqualRawKernel(const Context& dev_ctx,
112
129
113
130
template <typename T, typename Context>
114
131
void GreaterEqualKernel (const Context& dev_ctx,
115
- const phi::DenseTensor& x,
116
- const phi::DenseTensor& y,
117
- phi::DenseTensor* out) {
132
+ const phi::DenseTensor& x,
133
+ const phi::DenseTensor& y,
134
+ phi::DenseTensor* out) {
118
135
custom_kernel::GreaterEqualRawKernel<T, Context>(dev_ctx, x, y, -1 , out);
119
136
}
120
137
121
138
template <typename T, typename Context>
122
139
void GreaterThanRawKernel (const Context& dev_ctx,
123
- const phi::DenseTensor& x,
124
- const phi::DenseTensor& y,
125
- int axis,
126
- phi::DenseTensor* out) {
140
+ const phi::DenseTensor& x,
141
+ const phi::DenseTensor& y,
142
+ int axis,
143
+ phi::DenseTensor* out) {
127
144
dev_ctx.template Alloc <bool >(out);
128
145
const auto & runner = NpuOpRunner (" Greater" , {x, y}, {*out}, {});
129
146
auto stream = dev_ctx.stream ();
@@ -132,9 +149,9 @@ void GreaterThanRawKernel(const Context& dev_ctx,
132
149
133
150
template <typename T, typename Context>
134
151
void GreaterThanKernel (const Context& dev_ctx,
135
- const phi::DenseTensor& x,
136
- const phi::DenseTensor& y,
137
- phi::DenseTensor* out) {
152
+ const phi::DenseTensor& x,
153
+ const phi::DenseTensor& y,
154
+ phi::DenseTensor* out) {
138
155
custom_kernel::GreaterThanRawKernel<T, Context>(dev_ctx, x, y, -1 , out);
139
156
}
140
157
@@ -176,7 +193,6 @@ PD_REGISTER_PLUGIN_KERNEL(not_equal,
176
193
phi::dtype::float16,
177
194
double ) {}
178
195
179
-
180
196
PD_REGISTER_PLUGIN_KERNEL (not_equal_raw,
181
197
npu,
182
198
ALL_LAYOUT,
@@ -282,4 +298,4 @@ PD_REGISTER_PLUGIN_KERNEL(greater_than_raw,
282
298
int64_t ,
283
299
float ,
284
300
phi::dtype::float16,
285
- double ) {}
301
+ double ) {}
0 commit comments