@@ -55,14 +55,6 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
55
55
};
56
56
57
57
class NormOp : public framework ::OperatorWithKernel {
58
- protected:
59
- framework::OpKernelType GetKernelType (
60
- const framework::ExecutionContext& ctx) const override {
61
- return framework::OpKernelType (
62
- framework::ToDataType (ctx.Input <framework::Tensor>(" X" )->type ()),
63
- ctx.device_context ());
64
- }
65
-
66
58
public:
67
59
using framework::OperatorWithKernel::OperatorWithKernel;
68
60
void InferShape (framework::InferShapeContext* ctx) const override {
@@ -80,14 +72,6 @@ class NormOp : public framework::OperatorWithKernel {
80
72
};
81
73
82
74
class NormOpGrad : public framework ::OperatorWithKernel {
83
- protected:
84
- framework::OpKernelType GetKernelType (
85
- const framework::ExecutionContext& ctx) const override {
86
- return framework::OpKernelType (
87
- framework::ToDataType (ctx.Input <framework::Tensor>(" X" )->type ()),
88
- ctx.device_context ());
89
- }
90
-
91
75
public:
92
76
using framework::OperatorWithKernel::OperatorWithKernel;
93
77
void InferShape (framework::InferShapeContext* ctx) const override {
@@ -105,7 +89,7 @@ REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker<float>, norm_grad,
105
89
ops::NormOpGrad);
106
90
REGISTER_OP_CPU_KERNEL (
107
91
norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float >,
108
- ops::NormKernel<paddle::platform::CPUDeviceContext, double >);
92
+ ops::NormKernel<paddle::platform::CPUDeviceContext, double , float >);
109
93
REGISTER_OP_CPU_KERNEL (
110
94
norm_grad, ops::NormGradKernel<paddle::platform::CPUDeviceContext, float >,
111
- ops::NormGradKernel<paddle::platform::CPUDeviceContext, double >);
95
+ ops::NormGradKernel<paddle::platform::CPUDeviceContext, double , float >);
0 commit comments