Skip to content

Commit 41e77d4

Browse files
Optimize the op of c_softmax_with_cross_entropy (#71461)
1 parent be3d908 commit 41e77d4

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

paddle/phi/kernels/gpu/c_softmax_with_cross_entropy_kernel.cu

+25 -21
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,18 @@
1515
#include "paddle/phi/backends/gpu/gpu_context.h"
1616
#include "paddle/phi/core/kernel_registry.h"
1717
#include "paddle/phi/core/platform/collective_helper.h"
18+
#include "paddle/phi/kernels/activation_kernel.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1820
#include "paddle/phi/kernels/funcs/axis_utils.h"
21+
#include "paddle/phi/kernels/funcs/broadcast_function.h"
1922
#include "paddle/phi/kernels/funcs/cross_entropy.h"
2023
#include "paddle/phi/kernels/funcs/eigen/common.h"
24+
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
2125
#include "paddle/phi/kernels/funcs/math.h"
2226
#include "paddle/phi/kernels/funcs/math_function.h"
2327
#include "paddle/phi/kernels/funcs/softmax.h"
2428
#include "paddle/phi/kernels/funcs/softmax_impl.h"
29+
#include "paddle/phi/kernels/reduce_max_kernel.h"
2530
#include "paddle/phi/kernels/reduce_sum_kernel.h"
2631
#include "paddle/utils/string/string_helper.h"
2732

@@ -213,36 +218,33 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
213218
softmax_2d.ShareDataWith(*softmax).Resize({N, D});
214219
loss_2d.ShareDataWith(*loss).Resize({N, 1});
215220

216-
auto eigen_logits = phi::funcs::EigenMatrix<T>::From(logits_2d);
217-
auto eigen_softmax = phi::funcs::EigenMatrix<T>::From(softmax_2d);
218-
219221
// step 1, obtain logit_max
220222
phi::DenseTensor logits_max;
221223
logits_max.Resize({N, 1});
222224
dev_ctx.template Alloc<T>(&logits_max);
223225

224-
auto eigen_logits_max = phi::funcs::EigenMatrix<T>::From(logits_max);
225-
Eigen::DSizes<int, 1> along_axis(1);
226-
eigen_logits_max.device(*dev_ctx.eigen_device()) =
227-
eigen_logits.maximum(along_axis);
226+
phi::MaxKernel<T, phi::GPUContext>(
227+
dev_ctx, logits_2d, {-1}, true, &logits_max);
228228

229229
comm_ctx->AllReduce(&logits_max, logits_max, ncclMax, stream);
230230

231231
// step 2, obtain logit - logit_max
232-
Eigen::DSizes<int, 2> batch_by_one(N, 1);
233-
Eigen::DSizes<int, 2> one_by_class(1, D);
234-
235-
eigen_softmax.device(*dev_ctx.eigen_device()) =
236-
(eigen_logits -
237-
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class));
232+
std::vector<const phi::DenseTensor*> inputs = {&logits_2d, &logits_max};
233+
std::vector<phi::DenseTensor*> outputs = {&softmax_2d};
234+
phi::funcs::BroadcastKernel<T>(
235+
dev_ctx, inputs, &outputs, phi::funcs::SubtractFunctor<T>());
238236

239237
// step 3, obtain predict target
240238
phi::DenseTensor predicted_logits;
241239
predicted_logits.Resize({N, 1});
242240
dev_ctx.template Alloc<T>(&predicted_logits);
243241

244-
auto t = phi::EigenVector<T>::Flatten(predicted_logits);
245-
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
242+
phi::FullKernel<T, phi::GPUContext>(
243+
dev_ctx,
244+
common::vectorize(predicted_logits.dims()),
245+
0,
246+
predicted_logits.dtype(),
247+
&predicted_logits);
246248

247249
const int64_t start_index = rank * D;
248250
const int64_t end_index = start_index + D;
@@ -309,7 +311,7 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
309311
comm_ctx->AllReduce(&predicted_logits, predicted_logits, ncclSum, stream);
310312

311313
// step 4, obtain exp(logit)
312-
eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp();
314+
phi::ExpKernel<T, phi::GPUContext>(dev_ctx, softmax_2d, &softmax_2d);
313315

314316
// step 5, obtain sum_exp_logits
315317
phi::DenseTensor sum_exp_logits;
@@ -362,11 +364,13 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
362364
}
363365
}
364366

365-
auto eigen_sum_exp_logits =
366-
phi::funcs::EigenMatrix<T>::From(sum_exp_logits);
367-
eigen_softmax.device(*dev_ctx.eigen_device()) =
368-
(eigen_softmax *
369-
eigen_sum_exp_logits.inverse().broadcast(one_by_class));
367+
phi::ReciprocalKernel<T, phi::GPUContext>(
368+
dev_ctx, sum_exp_logits, &sum_exp_logits);
369+
370+
inputs = std::vector<const phi::DenseTensor*>{&softmax_2d, &sum_exp_logits};
371+
outputs = std::vector<phi::DenseTensor*>{&softmax_2d};
372+
phi::funcs::BroadcastKernel<T>(
373+
dev_ctx, inputs, &outputs, phi::funcs::MultiplyFunctor<T>());
370374
#endif
371375
}
372376
};

0 commit comments

Comments
 (0)