Commit c08dbdc

[NPU] add nll_loss kernel (#1319)
1 parent a7f19ce commit c08dbdc

File tree: 2 files changed, +538 −0 lines changed
+308
@@ -0,0 +1,308 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kernels/funcs/npu_funcs.h"
#include "kernels/funcs/npu_op_runner.h"
#include "kernels/funcs/slice_utils.h"

namespace custom_kernel {

template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
                const phi::DenseTensor& x,
                phi::DataType dtype,
                phi::DenseTensor* out);
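// NOTE: CastKernel is only forward-declared here; its definition lives
// elsewhere in the plugin. It backs the double -> float32 fallback used
// below, since both kernels hand the aclnn NLLLoss ops float32 tensors.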

template <typename T, typename Context>
void NllLossRawKernel(const Context& dev_ctx,
                      const phi::DenseTensor& x,
                      const phi::DenseTensor& labels,
                      const paddle::optional<phi::DenseTensor>& weight,
                      int64_t ignore_index,
                      const std::string& reduction,
                      phi::DenseTensor* out,
                      phi::DenseTensor* total_weight) {
  auto x_dims = x.dims();
  phi::Scalar weight_default = 1.0;
  int64_t reduction_int = 1;
  if (reduction == "none") {
    reduction_int = 0;
  } else if (reduction == "sum") {
    reduction_int = 2;
  }
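  // reduction_int follows the aclnn reduction convention
  // (0 = none, 1 = mean, 2 = sum); "mean" is the default.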

  phi::DenseTensor weight_tensor;
  auto weight_size = phi::make_ddim({x.dims()[1]});
  if (weight.get_ptr() == nullptr) {
    weight_tensor.ResizeAndAllocate(weight_size);
    dev_ctx.template Alloc<float>(&weight_tensor);
    EXEC_NPU_CMD(
        aclnnInplaceFillScalar, dev_ctx, weight_tensor, weight_default);
  } else {
    weight_tensor = *weight.get_ptr();
  }
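  // Without an explicit weight, a per-class weight vector of length C
  // (x.dims()[1]) is allocated and filled with ones on device.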

  bool need_resize = false;
  if (x_dims.size() == 4 && total_weight->dims().size() == 0) {
    total_weight->Resize(phi::make_ddim({1}));
    need_resize = true;
  }
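  // aclnnNLLLoss2d appears to require a 1-element total_weight, so a 0-d
  // total_weight is temporarily viewed as shape {1} and restored to {}
  // after the op runs.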
  dev_ctx.template Alloc<T>(out);
  dev_ctx.template Alloc<T>(total_weight);

  if (x.dtype() == phi::DataType::FLOAT32) {
    if (x_dims.size() == 2) {
      EXEC_NPU_CMD(aclnnNLLLoss,
                   dev_ctx,
                   x,
                   labels,
                   weight_tensor,
                   reduction_int,
                   ignore_index,
                   *out,
                   *total_weight);
    } else if (x_dims.size() == 4) {
      EXEC_NPU_CMD(aclnnNLLLoss2d,
                   dev_ctx,
                   x,
                   labels,
                   weight_tensor,
                   reduction_int,
                   ignore_index,
                   *out,
                   *total_weight);
    }

    if (need_resize) {
      total_weight->Resize(phi::make_ddim({}));
    }
  } else {
    // data trans: double to float32
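    // The aclnn path runs in float32 only: double inputs are cast down,
    // the loss is computed, and the results are cast back to the output
    // dtype (at some precision cost relative to native double).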
    phi::DenseTensor x_cast, weight_tensor_cast, out_cast, total_weight_cast;
    phi::DenseTensorMeta x_cast_meta;
    phi::DenseTensorMeta weight_tensor_cast_meta;
    phi::DenseTensorMeta out_cast_meta;
    phi::DenseTensorMeta total_weight_cast_meta;

    x_cast_meta = {phi::DataType::FLOAT32, x.dims()};
    weight_tensor_cast_meta = {phi::DataType::FLOAT32, weight_tensor.dims()};
    out_cast_meta = {phi::DataType::FLOAT32, out->dims()};
    total_weight_cast_meta = {phi::DataType::FLOAT32, total_weight->dims()};

    x_cast.set_meta(x_cast_meta);
    weight_tensor_cast.set_meta(weight_tensor_cast_meta);
    out_cast.set_meta(out_cast_meta);
    total_weight_cast.set_meta(total_weight_cast_meta);

    dev_ctx.template Alloc<float>(&out_cast);
    dev_ctx.template Alloc<float>(&total_weight_cast);
    custom_kernel::CastKernel<T, Context>(
        dev_ctx, x, phi::DataType::FLOAT32, &x_cast);
    custom_kernel::CastKernel<T, Context>(
        dev_ctx, weight_tensor, phi::DataType::FLOAT32, &weight_tensor_cast);

    if (x_dims.size() == 2) {
      EXEC_NPU_CMD(aclnnNLLLoss,
                   dev_ctx,
                   x_cast,
                   labels,
                   weight_tensor_cast,
                   reduction_int,
                   ignore_index,
                   out_cast,
                   total_weight_cast);
    } else if (x_dims.size() == 4) {
      EXEC_NPU_CMD(aclnnNLLLoss2d,
                   dev_ctx,
                   x_cast,
                   labels,
                   weight_tensor_cast,
                   reduction_int,
                   ignore_index,
                   out_cast,
                   total_weight_cast);
    }

    custom_kernel::CastKernel<T, Context>(dev_ctx, out_cast, out->dtype(), out);
    custom_kernel::CastKernel<T, Context>(
        dev_ctx, total_weight_cast, total_weight->dtype(), total_weight);

    if (need_resize) {
      total_weight->Resize(phi::make_ddim({}));
    }
  }
}

template <typename T, typename Context>
void NllLossGradKernel(const Context& dev_ctx,
                       const phi::DenseTensor& x,
                       const phi::DenseTensor& labels,
                       const paddle::optional<phi::DenseTensor>& weight,
                       const phi::DenseTensor& total_weight,
                       const phi::DenseTensor& d_out,
                       int64_t ignore_index,
                       const std::string& reduction,
                       phi::DenseTensor* dx) {
  auto x_dims = x.dims();
  phi::Scalar weight_default = 1.0;
  int64_t reduction_int = 1;
  if (reduction == "none") {
    reduction_int = 0;
  } else if (reduction == "sum") {
    reduction_int = 2;
  }

  phi::DenseTensor weight_tensor;
  auto weight_size = phi::make_ddim({x.dims()[1]});
  if (weight.get_ptr() == nullptr) {
    weight_tensor.ResizeAndAllocate(weight_size);
    dev_ctx.template Alloc<float>(&weight_tensor);
    EXEC_NPU_CMD(
        aclnnInplaceFillScalar, dev_ctx, weight_tensor, weight_default);
  } else {
    weight_tensor = *weight.get_ptr();
  }
  dev_ctx.template Alloc<T>(dx);

  phi::DenseTensor total_weight_new;
  if (x_dims.size() == 4) {
    phi::DenseTensorMeta total_weight_new_meta = {phi::DataType::FLOAT32,
                                                  phi::make_ddim({1})};
    total_weight_new.set_meta(total_weight_new_meta);
    TensorCopy(dev_ctx, total_weight, true, &total_weight_new);
    total_weight_new.Resize(phi::make_ddim({1}));
  }
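  // The incoming total_weight is const, so the 4-D path deep-copies it
  // into total_weight_new and views the copy as shape {1} for
  // aclnnNLLLoss2dBackward.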

  if (x.dtype() == phi::DataType::FLOAT32) {
    if (x_dims.size() == 2) {
      EXEC_NPU_CMD(aclnnNLLLossBackward,
                   dev_ctx,
                   d_out,
                   x,
                   labels,
                   weight_tensor,
                   reduction_int,
                   ignore_index,
                   total_weight,
                   *dx);
    } else if (x_dims.size() == 4) {
      if (d_out.dims().size() == 0) {
        phi::DenseTensor d_out_new;
        phi::DenseTensorMeta d_out_new_meta = {phi::DataType::FLOAT32,
                                               phi::make_ddim({1})};
        d_out_new.set_meta(d_out_new_meta);
        TensorCopy(dev_ctx, d_out, true, &d_out_new);
        d_out_new.Resize(phi::make_ddim({1}));

        EXEC_NPU_CMD(aclnnNLLLoss2dBackward,
                     dev_ctx,
                     d_out_new,
                     x,
                     labels,
                     weight_tensor,
                     reduction_int,
                     ignore_index,
                     total_weight_new,
                     *dx);
      } else {
        EXEC_NPU_CMD(aclnnNLLLoss2dBackward,
                     dev_ctx,
                     d_out,
                     x,
                     labels,
                     weight_tensor,
                     reduction_int,
                     ignore_index,
                     total_weight_new,
                     *dx);
      }
    }
  } else {
    // data trans: double to float32
    phi::DenseTensor d_out_cast, x_cast, weight_tensor_cast, total_weight_cast,
        dx_cast;
    phi::DenseTensorMeta d_out_cast_meta;
    phi::DenseTensorMeta x_cast_meta;
    phi::DenseTensorMeta weight_tensor_cast_meta;
    phi::DenseTensorMeta total_weight_cast_meta;
    phi::DenseTensorMeta dx_cast_meta;

    d_out_cast_meta = {phi::DataType::FLOAT32, d_out.dims()};
    x_cast_meta = {phi::DataType::FLOAT32, x.dims()};
    weight_tensor_cast_meta = {phi::DataType::FLOAT32, weight_tensor.dims()};
    total_weight_cast_meta = {phi::DataType::FLOAT32, total_weight.dims()};
    dx_cast_meta = {phi::DataType::FLOAT32, dx->dims()};

    d_out_cast.set_meta(d_out_cast_meta);
    x_cast.set_meta(x_cast_meta);
    weight_tensor_cast.set_meta(weight_tensor_cast_meta);
    total_weight_cast.set_meta(total_weight_cast_meta);
    dx_cast.set_meta(dx_cast_meta);

    dev_ctx.template Alloc<float>(&dx_cast);
    custom_kernel::CastKernel<T, Context>(
        dev_ctx, d_out, phi::DataType::FLOAT32, &d_out_cast);
    custom_kernel::CastKernel<T, Context>(
        dev_ctx, x, phi::DataType::FLOAT32, &x_cast);
    custom_kernel::CastKernel<T, Context>(
        dev_ctx, weight_tensor, phi::DataType::FLOAT32, &weight_tensor_cast);
    custom_kernel::CastKernel<T, Context>(
        dev_ctx, total_weight, phi::DataType::FLOAT32, &total_weight_cast);

    if (x_dims.size() == 4 && total_weight_cast.dims().size() == 0) {
      total_weight_cast.Resize(phi::make_ddim({1}));
    }

    if (x_dims.size() == 4 && d_out_cast.dims().size() == 0) {
      d_out_cast.Resize(phi::make_ddim({1}));
    }

    if (x_dims.size() == 2) {
      EXEC_NPU_CMD(aclnnNLLLossBackward,
                   dev_ctx,
                   d_out_cast,
                   x_cast,
                   labels,
                   weight_tensor_cast,
                   reduction_int,
                   ignore_index,
                   total_weight_cast,
                   dx_cast);
    } else if (x_dims.size() == 4) {
      EXEC_NPU_CMD(aclnnNLLLoss2dBackward,
                   dev_ctx,
                   d_out_cast,
                   x_cast,
                   labels,
                   weight_tensor_cast,
                   reduction_int,
                   ignore_index,
                   total_weight_cast,
                   dx_cast);
    }

    custom_kernel::CastKernel<T, Context>(dev_ctx, dx_cast, dx->dtype(), dx);
  }
}
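
// For reference, the standard NLL loss backward that the aclnn backward
// ops are expected to implement: for each sample i whose label
// c = labels[i] is not ignore_index,
//   dx[i][c] = -weight[c] * g_i,
// where g_i is d_out[i] for reduction "none", d_out for "sum", and
// d_out / total_weight for "mean"; all other entries of dx are zero.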
}  // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(
    nll_loss, npu, ALL_LAYOUT, custom_kernel::NllLossRawKernel, float, double) {
}

PD_REGISTER_PLUGIN_KERNEL(nll_loss_grad,
                          npu,
                          ALL_LAYOUT,
                          custom_kernel::NllLossGradKernel,
                          float,
                          double) {}
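
For orientation, here is a minimal CPU sketch of the forward computation these kernels offload, following the standard NLL loss definition. The function and names below are illustrative, not part of this commit; it assumes row-major [N, C] log-probabilities, which is the 2-D case handled by aclnnNLLLoss above.

// nll_loss_reference: illustrative CPU version of the 2-D forward path.
// x: [N, C] log-probabilities (row-major), labels: [N], weight: [C].
#include <cstdint>
#include <string>
#include <vector>

std::vector<float> nll_loss_reference(const std::vector<float>& x,
                                      const std::vector<int64_t>& labels,
                                      const std::vector<float>& weight,
                                      int64_t N,
                                      int64_t C,
                                      int64_t ignore_index,
                                      const std::string& reduction,
                                      float* total_weight_out) {
  std::vector<float> per_sample(N, 0.0f);
  float sum = 0.0f;
  float total_weight = 0.0f;
  for (int64_t i = 0; i < N; ++i) {
    const int64_t c = labels[i];
    if (c == ignore_index) continue;  // ignored targets contribute nothing
    const float w = weight[c];
    per_sample[i] = -w * x[i * C + c];
    sum += per_sample[i];
    total_weight += w;
  }
  *total_weight_out = total_weight;
  if (reduction == "none") return per_sample;  // per-sample losses
  if (reduction == "sum") return {sum};
  // "mean": normalize by the summed weights of the non-ignored targets,
  // not by N.
  return {total_weight > 0.0f ? sum / total_weight : 0.0f};
}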
