Skip to content

Commit e12d1a1

Browse files
author
sweetsky0901
committed
for esp data type
1 parent 77b0bf4 commit e12d1a1

File tree

2 files changed

+4
-20
lines changed

2 files changed

+4
-20
lines changed

paddle/operators/norm_op.cc

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,6 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
5555
};
5656

5757
class NormOp : public framework::OperatorWithKernel {
58-
protected:
59-
framework::OpKernelType GetKernelType(
60-
const framework::ExecutionContext& ctx) const override {
61-
return framework::OpKernelType(
62-
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
63-
ctx.device_context());
64-
}
65-
6658
public:
6759
using framework::OperatorWithKernel::OperatorWithKernel;
6860
void InferShape(framework::InferShapeContext* ctx) const override {
@@ -80,14 +72,6 @@ class NormOp : public framework::OperatorWithKernel {
8072
};
8173

8274
class NormOpGrad : public framework::OperatorWithKernel {
83-
protected:
84-
framework::OpKernelType GetKernelType(
85-
const framework::ExecutionContext& ctx) const override {
86-
return framework::OpKernelType(
87-
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
88-
ctx.device_context());
89-
}
90-
9175
public:
9276
using framework::OperatorWithKernel::OperatorWithKernel;
9377
void InferShape(framework::InferShapeContext* ctx) const override {
@@ -105,7 +89,7 @@ REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker<float>, norm_grad,
10589
ops::NormOpGrad);
10690
REGISTER_OP_CPU_KERNEL(
10791
norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>,
108-
ops::NormKernel<paddle::platform::CPUDeviceContext, double>);
92+
ops::NormKernel<paddle::platform::CPUDeviceContext, double, float>);
10993
REGISTER_OP_CPU_KERNEL(
11094
norm_grad, ops::NormGradKernel<paddle::platform::CPUDeviceContext, float>,
111-
ops::NormGradKernel<paddle::platform::CPUDeviceContext, double>);
95+
ops::NormGradKernel<paddle::platform::CPUDeviceContext, double, float>);

paddle/operators/norm_op.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License. */
1818
namespace ops = paddle::operators;
1919
REGISTER_OP_CUDA_KERNEL(
2020
norm, ops::NormKernel<paddle::platform::CUDADeviceContext, float>,
21-
ops::NormKernel<paddle::platform::CUDADeviceContext, double>);
21+
ops::NormKernel<paddle::platform::CUDADeviceContext, double, float>);
2222
REGISTER_OP_CUDA_KERNEL(
2323
norm_grad, ops::NormGradKernel<paddle::platform::CUDADeviceContext, float>,
24-
ops::NormGradKernel<paddle::platform::CUDADeviceContext, double>);
24+
ops::NormGradKernel<paddle::platform::CUDADeviceContext, double, float>);

0 commit comments

Comments
 (0)