@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <cstdio>
 #include "cub/cub.cuh"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/float16.h"
-
 // set cub base traits in order to handle float16
 namespace cub {
 template <>
@@ -300,6 +300,20 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
   }
 }
 
+template <typename T, int MaxLength, int BlockSize>
+__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
+                           size_t rows, size_t cols, size_t k) {
+  for (size_t i = 0; i < rows; ++i) {
+    for (size_t j = 0; j < cols; ++j) {
+      x_grad[i * cols + j] = 0;
+    }
+    for (size_t j = 0; j < k; ++j) {
+      size_t idx = indices[i * k + j];
+      x_grad[i * cols + idx] = out_grad[i * k + j];
+    }
+  }
+}
+
 inline static int GetDesiredBlockDim(int dim) {
   if (dim > 128) {
     return 256;
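A note on the kernel added above: AssignGrad zeroes each row of x_grad and then scatters the k incoming gradient values back to the columns recorded in indices. As written, its body never consults threadIdx or blockIdx, so every thread of the <<<gridx, kBlockDim>>> launch redundantly repeats the full rows-by-cols walk. For illustration only, here is a minimal standalone sketch of the same scatter with rows spread across blocks and columns across threads; the kernel name TopKGradScatter and the tiny driver are hypothetical, not part of this PR.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical parallel variant: each block takes a strided subset of rows,
// and the threads of a block split that row's columns between them.
template <typename T>
__global__ void TopKGradScatter(T* x_grad, const int64_t* indices,
                                const T* out_grad, size_t rows, size_t cols,
                                size_t k) {
  for (size_t i = blockIdx.x; i < rows; i += gridDim.x) {
    // Zero this row of x_grad.
    for (size_t j = threadIdx.x; j < cols; j += blockDim.x) {
      x_grad[i * cols + j] = static_cast<T>(0);
    }
    // A row is owned by one block, so a block-level barrier suffices to
    // order the zero-fill before the scatter.
    __syncthreads();
    // Scatter the k gradient values back to their original columns.
    for (size_t j = threadIdx.x; j < k; j += blockDim.x) {
      size_t idx = static_cast<size_t>(indices[i * k + j]);
      x_grad[i * cols + idx] = out_grad[i * k + j];
    }
  }
}

int main() {
  // One row, 5 columns, k = 2: gradients 10 and 20 return to columns 3 and 1.
  const size_t rows = 1, cols = 5, k = 2;
  const float h_out_grad[] = {10.f, 20.f};
  const int64_t h_indices[] = {3, 1};
  float* d_x_grad;
  float* d_out_grad;
  int64_t* d_indices;
  cudaMalloc(&d_x_grad, rows * cols * sizeof(float));
  cudaMalloc(&d_out_grad, k * sizeof(float));
  cudaMalloc(&d_indices, k * sizeof(int64_t));
  cudaMemcpy(d_out_grad, h_out_grad, k * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_indices, h_indices, k * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  TopKGradScatter<float><<<1, 32>>>(d_x_grad, d_indices, d_out_grad, rows,
                                    cols, k);
  float h_x_grad[cols];
  cudaMemcpy(h_x_grad, d_x_grad, sizeof(h_x_grad), cudaMemcpyDeviceToHost);
  for (size_t j = 0; j < cols; ++j) printf("%.1f ", h_x_grad[j]);
  printf("\n");  // expected: 0.0 20.0 0.0 10.0 0.0
  cudaFree(d_x_grad);
  cudaFree(d_out_grad);
  cudaFree(d_indices);
  return 0;
}

The PR's every-thread-does-everything version needs no barrier, since each thread's own zero-fill precedes its own scatter; the sketch trades that simplicity for one pass of parallel work per row.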
@@ -478,7 +492,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
   FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
   FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
 
-template <typename T>
+template <typename DeviceContext, typename T>
 class TopkOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
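The FIXED_BLOCK_DIM_BASE / FIXED_BLOCK_DIM macros visible in this hunk's context turn a runtime dimension into a compile-time block size by expanding to switch cases that each bind a constexpr kBlockDim. Below is a self-contained sketch of that dispatch pattern, reconstructed under assumed macro definitions (the originals live earlier in this file and may differ in detail):

#include <cstdio>

// Assumed shape of the dispatch macros: each case binds the chosen block
// size to a compile-time constant kBlockDim, which the wrapped statement
// (normally a kernel launch) can pass as a template argument.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr int kBlockDim = (dim);   \
    __VA_ARGS__;                       \
  } break

#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

template <int BlockSize>
void Launch() {
  printf("launching with BlockSize = %d\n", BlockSize);
}

int main() {
  int dim = 128;  // runtime value, e.g. from GetDesiredBlockDim(col)
  switch (dim) {
    FIXED_BLOCK_DIM(Launch<kBlockDim>());
    default:
      printf("unsupported block dim %d\n", dim);
  }
  return 0;
}

kBlockDim only exists inside each macro expansion, which is why the launch sites in this file can name it inside FIXED_BLOCK_DIM(...).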
@@ -540,15 +554,70 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(context.GetPlace()), true,
+        platform::errors::InvalidArgument("It must use CUDAPlace."));
+    auto* x = context.Input<Tensor>("X");
+    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* indices = context.Input<Tensor>("Indices");
+    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+
+    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
+    const T* out_grad_data = out_grad->data<T>();
+    const int64_t* indices_data = indices->data<int64_t>();
+    size_t k = indices->dims()[indices->dims().size() - 1];
+
+    framework::DDim xdims = x->dims();
+    const size_t row =
+        framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
+    const size_t col = xdims[xdims.size() - 1];
+    const auto& dev_ctx = context.cuda_device_context();
+
+    const int kMaxHeight = 2048;
+    int gridx = row < kMaxHeight ? row : kMaxHeight;
+    switch (GetDesiredBlockDim(col)) {
+      FIXED_BLOCK_DIM(
+          AssignGrad<T, 5,
+                     kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
+              x_grad_data, indices_data, out_grad_data, row, col, k));
+      default:
+        PADDLE_THROW(
+            platform::errors::Unavailable("Error occurs when Assign Grad."));
+    }
+  }
+};
+
 #undef FIXED_BLOCK_DIM_BASE
 #undef FIXED_BLOCK_DIM
 
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP_CUDA_KERNEL(
-    top_k, paddle::operators::TopkOpCUDAKernel<float>,
-    paddle::operators::TopkOpCUDAKernel<double>,
-    paddle::operators::TopkOpCUDAKernel<int>,
-    paddle::operators::TopkOpCUDAKernel<int64_t>,
-    paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>);
+    top_k,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        float>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        double>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        int>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        int64_t>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        paddle::platform::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    top_k_grad,
+    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
+                                            float>,
+    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
+                                            double>,
+    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
+                                            int>,
+    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
+                                            int64_t>,
+    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
+                                            paddle::platform::float16>);
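The REGISTER_OP_CUDA_KERNEL calls above bind the op type names top_k and top_k_grad to one kernel instantiation per element type; at run time the framework selects the instantiation matching the tensor's dtype. Here is a toy model of that mechanism with entirely hypothetical names (nothing below is Paddle API; the real macro also records the device context, which this sketch omits):

#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <typeindex>
#include <utility>

// Hypothetical miniature registry: op name -> element type -> kernel.
static std::map<std::string, std::map<std::type_index, std::function<void()>>>
    g_kernels;

template <typename T>
void RegisterKernel(const std::string& op, std::function<void()> fn) {
  g_kernels[op][std::type_index(typeid(T))] = std::move(fn);
}

template <typename T>
void Run(const std::string& op) {
  // Dispatch on the element type, as the framework does on tensor dtype.
  g_kernels.at(op).at(std::type_index(typeid(T)))();
}

int main() {
  RegisterKernel<float>("top_k_grad", [] { printf("float grad kernel\n"); });
  RegisterKernel<double>("top_k_grad", [] { printf("double grad kernel\n"); });
  Run<float>("top_k_grad");   // picks the float instantiation
  Run<double>("top_k_grad");  // picks the double instantiation
  return 0;
}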