Commit 471971a

Add TopK Op Grad CPU&GPU Kernel test=develop (#22628) (#22656)
* Add TopK Op Grad CPU&GPU Kernel test=develop
* Add TopK Op Grad, modify grad op maker test=develop
* Add TopK Op Grad, modify grad op maker test=develop
* Add TopK Op Grad, modify PADDLE_ENFORCE test=develop
* Add TopK Op Grad, modify PADDLE_THROW test=develop
* Add TopK Op Grad, modify unittest test=develop
* fix ngraph top k op unittest test=develop
1 parent 9e80551 commit 471971a
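
The new kernels implement the usual gradient rule for top-k: Out keeps only the k largest entries along the last axis of X, so the incoming gradient Out@GRAD is scattered back to the positions saved in Indices, and every other entry of X@GRAD is zero. A minimal NumPy sketch of that rule (illustrative only, not part of the commit):

    import numpy as np

    def topk_forward(x, k):
        # descending order along the last axis, like the top_k op
        indices = np.argsort(-x, axis=-1)[..., :k]
        out = np.take_along_axis(x, indices, axis=-1)
        return out, indices

    def topk_backward(x_shape, indices, out_grad):
        # zero-fill, then scatter: mirrors the memset + loop in TopkGradKernel
        x_grad = np.zeros(x_shape, dtype=out_grad.dtype)
        np.put_along_axis(x_grad, indices, out_grad, axis=-1)
        return x_grad

    x = np.random.rand(100, 84)
    out, idx = topk_forward(x, k=3)
    dx = topk_backward(x.shape, idx, np.ones_like(out))
    assert dx.sum() == out.size  # exactly one nonzero per kept entry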

File tree: 5 files changed, +171 -116 lines

- paddle/fluid/operators/top_k_op.cc (+59, -12)
- paddle/fluid/operators/top_k_op.cu (+76, -7)
- paddle/fluid/operators/top_k_op.h (+30)
- python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py (+1, -1)
- python/paddle/fluid/tests/unittests/test_top_k_op.py (+5, -96)

paddle/fluid/operators/top_k_op.cc (+59, -12)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/top_k_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
@@ -42,11 +43,6 @@ class TopkOp : public framework::OperatorWithKernel {
 
     framework::DDim dims = input_dims;
     dims[dims.size() - 1] = k;
-    // If has K as tensor, set k=-1 as not know real size at this time.
-    if (ctx->HasInput("K")) {
-      dims[dims.size() - 1] = -1;
-    }
-
     ctx->SetOutputDim("Out", dims);
     ctx->SetOutputDim("Indices", dims);
     ctx->ShareLoD("X", "Out");
@@ -89,16 +85,67 @@ For matrices, this operator computes the top k entries in each row. )DOC");
   }
 };
 
+class TopkOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::InvalidArgument("Input(X) should be not null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Indices"), true,
+        platform::errors::InvalidArgument("Input(Indices) should be not null"));
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
+                      platform::errors::InvalidArgument(
+                          "Grad Input(Out) should be not null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput(framework::GradVarName("X")), true,
+        platform::errors::InvalidArgument("Grad Output(X) should be not null"));
+
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
+};
+
+template <typename T>
+class TopkGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType("top_k_grad");
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Indices", this->Output("Indices"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(
-    top_k, ops::TopkOp, ops::TopkOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
+                  ops::TopkGradOpMaker<paddle::framework::OpDesc>,
+                  ops::TopkGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(top_k_grad, ops::TopkOpGrad);
+
 REGISTER_OP_CPU_KERNEL(top_k,
                        ops::TopkKernel<paddle::platform::CPUPlace, float>,
-                       ops::TopkKernel<paddle::platform::CPUPlace, double>,
-                       ops::TopkKernel<paddle::platform::CPUPlace, int>,
-                       ops::TopkKernel<paddle::platform::CPUPlace, int64_t>);
+                       ops::TopkKernel<paddle::platform::CPUPlace, double>);
+
+REGISTER_OP_CPU_KERNEL(top_k_grad,
                        ops::TopkGradKernel<paddle::platform::CPUPlace, float>,
                        ops::TopkGradKernel<paddle::platform::CPUPlace, double>);
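
TopkGradOpMaker above is what lets the framework build the backward op automatically: top_k_grad consumes the forward input X, the forward output Indices, and the incoming gradient of Out, and produces the gradient of X. A hypothetical rendering of the op description it emits, as a Python literal (names are illustrative, not the framework's internal format):

    # each SetInput/SetOutput call above contributes one entry
    grad_op_desc = {
        "type": "top_k_grad",
        "inputs": {
            "Out@GRAD": ["out@GRAD"],   # this->OutputGrad("Out")
            "X": ["x"],                 # this->Input("X")
            "Indices": ["indices"],     # this->Output("Indices")
        },
        "outputs": {"X@GRAD": ["x@GRAD"]},  # this->InputGrad("X")
    }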

paddle/fluid/operators/top_k_op.cu (+76, -7)
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <cstdio>
 #include "cub/cub.cuh"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/float16.h"
-
 // set cub base traits in order to handle float16
 namespace cub {
 template <>
@@ -300,6 +300,20 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
   }
 }
 
+template <typename T, int MaxLength, int BlockSize>
+__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
+                           size_t rows, size_t cols, size_t k) {
+  for (size_t i = 0; i < rows; ++i) {
+    for (size_t j = 0; j < cols; ++j) {
+      x_grad[i * cols + j] = 0;
+    }
+    for (size_t j = 0; j < k; ++j) {
+      size_t idx = indices[i * k + j];
+      x_grad[i * cols + idx] = out_grad[i * k + j];
+    }
+  }
+}
+
 inline static int GetDesiredBlockDim(int dim) {
   if (dim > 128) {
     return 256;
@@ -478,7 +492,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
   FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
   FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
 
-template <typename T>
+template <typename DeviceContext, typename T>
 class TopkOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -540,15 +554,70 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(context.GetPlace()), true,
+        platform::errors::InvalidArgument("It must use CUDAPlace."));
+    auto* x = context.Input<Tensor>("X");
+    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* indices = context.Input<Tensor>("Indices");
+    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+
+    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
+    const T* out_grad_data = out_grad->data<T>();
+    const int64_t* indices_data = indices->data<int64_t>();
+    size_t k = indices->dims()[indices->dims().size() - 1];
+
+    framework::DDim xdims = x->dims();
+    const size_t row =
+        framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
+    const size_t col = xdims[xdims.size() - 1];
+    const auto& dev_ctx = context.cuda_device_context();
+
+    const int kMaxHeight = 2048;
+    int gridx = row < kMaxHeight ? row : kMaxHeight;
+    switch (GetDesiredBlockDim(col)) {
+      FIXED_BLOCK_DIM(
+          AssignGrad<T, 5,
+                     kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
+              x_grad_data, indices_data, out_grad_data, row, col, k));
+      default:
+        PADDLE_THROW(
+            platform::errors::Unavailable("Error occurs when Assign Grad."));
+    }
+  }
+};
 #undef FIXED_BLOCK_DIM_BASE
 #undef FIXED_BLOCK_DIM
 
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP_CUDA_KERNEL(
-    top_k, paddle::operators::TopkOpCUDAKernel<float>,
-    paddle::operators::TopkOpCUDAKernel<double>,
-    paddle::operators::TopkOpCUDAKernel<int>,
-    paddle::operators::TopkOpCUDAKernel<int64_t>,
-    paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>);
+    top_k,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        float>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        double>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        int>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        int64_t>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                        paddle::platform::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    top_k_grad,
+    paddle::operators::TopkOpGradCUDAKernel<
+        paddle::platform::CUDADeviceContext, float>,
+    paddle::operators::TopkOpGradCUDAKernel<
+        paddle::platform::CUDADeviceContext, double>,
+    paddle::operators::TopkOpGradCUDAKernel<
+        paddle::platform::CUDADeviceContext, int>,
+    paddle::operators::TopkOpGradCUDAKernel<
        paddle::platform::CUDADeviceContext, int64_t>,
+    paddle::operators::TopkOpGradCUDAKernel<
+        paddle::platform::CUDADeviceContext, paddle::platform::float16>);
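
On the launch shape used for AssignGrad above: the grid is capped at kMaxHeight = 2048 blocks and the block size comes from GetDesiredBlockDim applied to the width of the last dimension (col). Only the `dim > 128 -> 256` branch of that helper is visible in this diff; the smaller power-of-two fallbacks in the sketch below are assumed from the FIXED_BLOCK_DIM_BASE(64/32, ...) macro expansions:

    # a Python restatement of the launch-shape selection, with the noted assumptions
    def desired_block_dim(cols):
        if cols > 128:
            return 256   # visible in the diff
        if cols > 64:
            return 128   # assumed
        if cols > 32:
            return 64    # assumed
        return 32        # assumed

    def assign_grad_launch_shape(rows, cols, k_max_height=2048):
        # gridx = row < kMaxHeight ? row : kMaxHeight in the kernel above
        return min(rows, k_max_height), desired_block_dim(cols)

    print(assign_grad_launch_shape(100, 84))  # -> (100, 128)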

paddle/fluid/operators/top_k_op.h (+30)
@@ -94,5 +94,35 @@ class TopkKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class TopkGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* indices = context.Input<Tensor>("Indices");
+    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+
+    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
+    const T* out_grad_data = out_grad->data<T>();
+    const int64_t* indices_data = indices->data<int64_t>();
+    size_t k = indices->dims()[indices->dims().size() - 1];
+
+    framework::DDim xdims = x->dims();
+    const size_t row =
+        framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
+    const size_t col = xdims[xdims.size() - 1];
+
+    memset(x_grad_data, 0, row * col * sizeof(T));
+
+    for (size_t i = 0; i < row; ++i) {
+      for (size_t j = 0; j < k; ++j) {
+        size_t idx = indices_data[i * k + j];
+        x_grad_data[i * col + idx] = out_grad_data[i * k + j];
+      }
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py (+1, -1)
@@ -15,7 +15,7 @@
 
 import unittest, sys
 sys.path.append("../")
-from test_top_k_op import TestTopkOp, TestTopkOp3d, TestTopkOp2, TestTopkOp3, TestTopkOp4
+from test_top_k_op import TestTopkOp
 
 if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/test_top_k_op.py (+5, -96)
@@ -17,14 +17,15 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
 
 
 class TestTopkOp(OpTest):
     def setUp(self):
         self.variable_k = False
         self.set_args()
         self.op_type = "top_k"
-        self.dtype = np.float32
+        self.dtype = np.float64
         self.init_dtype()
 
         k = self.top_k
@@ -49,106 +50,14 @@ def init_dtype(self):
         pass
 
     def set_args(self):
-        self.row = 32
+        self.row = 100
         self.top_k = 1
 
     def test_check_output(self):
         self.check_output()
 
-
-class TestTopkOpFp16(TestTopkOp):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-
-class TestTopkOp3d(OpTest):
-    def setUp(self):
-        self.op_type = "top_k"
-        k = 1
-        input = np.random.random((32, 2, 84)).astype("float32")
-        input_flat_2d = input.reshape(64, 84)
-        output = np.ndarray((64, k))
-        indices = np.ndarray((64, k)).astype("int64")
-
-        self.inputs = {'X': input}
-        self.attrs = {'k': k}
-
-        for rowid in range(64):
-            row = input_flat_2d[rowid]
-            output[rowid] = np.sort(row)[::-1][:k]
-            indices[rowid] = row.argsort()[::-1][:k]
-
-        self.outputs = {
-            'Out': output.reshape((32, 2, k)),
-            'Indices': indices.reshape((32, 2, k))
-        }
-
-    def test_check_output(self):
-        self.check_output()
-
-
-class TestTopkOp1(OpTest):
-    def setUp(self):
-        self.op_type = "top_k"
-        k = 2
-        m = 2056
-        input = np.random.random(m).astype("float32")
-        output = np.ndarray(k)
-        indices = np.ndarray(k).astype("int64")
-
-        self.inputs = {'X': input}
-        self.attrs = {'k': k}
-
-        row = input
-        output = -np.sort(-row)[:k]
-        indices = (-row).argsort()[:k]
-
-        self.outputs = {'Out': output, 'Indices': indices}
-
-    def test_check_output(self):
-        self.check_output()
-
-
-class TestTopkOp2(OpTest):
-    def setUp(self):
-        self.op_type = "top_k"
-        k = 1
-        m = 2056
-        input = np.random.random((m, 84)).astype("float32")
-        output = np.ndarray((m, k))
-        indices = np.ndarray((m, k)).astype("int64")
-
-        self.inputs = {'X': input}
-        self.attrs = {'k': k}
-
-        for rowid in range(m):
-            row = input[rowid]
-            output[rowid] = -np.sort(-row)[:k]
-            indices[rowid] = (-row).argsort()[:k]
-
-        self.outputs = {'Out': output, 'Indices': indices}
-
-    def test_check_output(self):
-        self.check_output()
-
-
-class TestTopkOp3(TestTopkOp):
-    def set_args(self):
-        self.row = 2056
-        self.top_k = 3
-
-
-class TestTopkOp4(TestTopkOp):
-    def set_args(self):
-        self.row = 40000
-        self.top_k = 1
-
-
-class TestTopkOp5(TestTopkOp):
-    def set_args(self):
-        self.row = 40000
-        self.top_k = 3
-        self.variable_k = True
+    def test_check_grad(self):
+        self.check_grad(set(['X']), 'Out')
 
 
 if __name__ == "__main__":
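
With the CPU and CUDA grad kernels registered, a top-k node is no longer a dead end for backpropagation. A rough end-to-end sketch, assuming the fluid 1.x graph-mode API of this release line (this snippet is not part of the commit's tests):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.data(name='x', shape=[None, 100], dtype='float64')
    x.stop_gradient = False          # fluid.data defaults to stop_gradient=True
    values, indices = fluid.layers.topk(x, k=1)
    loss = fluid.layers.reduce_sum(values)
    fluid.backward.append_backward(loss)  # inserts top_k_grad into the program

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    x_grad, = exe.run(feed={'x': np.random.rand(4, 100)},
                      fetch_list=['x@GRAD'])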
