@@ -13,11 +13,110 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
-#include "paddle/fluid/operators/optimizers/momentum_op.h"
 
 namespace paddle {
 namespace operators {
 
+class LarsMomentumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
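+    // Param/Grad/Velocity/LearningRate are duplicable inputs (one op call can
+    // update a whole list of parameters), hence HasInputs/GetInputsDim below.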
+    PADDLE_ENFORCE_EQ(ctx->HasInputs("Param"), true,
+                      platform::errors::NotFound(
+                          "Inputs(Param) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasInputs("Grad"), true,
+                      platform::errors::NotFound(
+                          "Inputs(Grad) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInputs("Velocity"), true,
+        platform::errors::NotFound(
+            "Inputs(Velocity) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInputs("LearningRate"), true,
+        platform::errors::NotFound(
+            "Inputs(LearningRate) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputsVarType("Param").front(),
+        framework::proto::VarType::LOD_TENSOR,
+        platform::errors::InvalidArgument(
+            "The input var's type should be LoDTensor, but the received is %s",
+            ctx->GetInputsVarType("Param").front()));
+
+    PADDLE_ENFORCE_EQ(ctx->HasOutputs("ParamOut"), true,
+                      platform::errors::NotFound(
+                          "Output(ParamOut) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutputs("VelocityOut"), true,
+        platform::errors::NotFound(
+            "Output(VelocityOut) of LarsMomentum should not be null."));
+
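+    // Every learning rate must be an initialized scalar, one per parameter.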
+    auto lr_dims = ctx->GetInputsDim("LearningRate");
+    for (size_t i = 0; i < lr_dims.size(); ++i) {
+      PADDLE_ENFORCE_NE(framework::product(lr_dims[i]), 0,
+                        platform::errors::InvalidArgument(
+                            "Maybe the Input variable LearningRate has not "
+                            "been initialized. You may need to confirm "
+                            "whether exe.run(startup_program) is put "
+                            "after optimizer.minimize function."));
+      PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1,
+                        platform::errors::InvalidArgument(
+                            "Learning_rate should be a scalar. But received "
+                            "LearningRate's dim [%s]",
+                            framework::product(lr_dims[i])));
+    }
+
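+    // Param, Grad and Velocity are parallel lists: equal length, and matching
+    // per-tensor shapes when the gradient is a dense LoDTensor.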
+    auto param_dim = ctx->GetInputsDim("Param");
+    auto grad_dim = ctx->GetInputsDim("Grad");
+    auto velocity_dim = ctx->GetInputsDim("Velocity");
+    PADDLE_ENFORCE_EQ(
+        param_dim.size(), grad_dim.size(),
+        platform::errors::InvalidArgument(
+            "Param and Grad inputs of LarsMomentumOp should have the same "
+            "quantity. But number of Param is [%d] and Grad is [%d].",
+            param_dim.size(), grad_dim.size()));
+    PADDLE_ENFORCE_EQ(
+        param_dim.size(), velocity_dim.size(),
+        platform::errors::InvalidArgument(
+            "Param and Velocity inputs of LarsMomentumOp should have the same "
+            "quantity. But number of Param is [%d] and Velocity is [%d].",
+            param_dim.size(), velocity_dim.size()));
+
+    if (ctx->GetInputsVarType("Grad")[0] ==
+        framework::proto::VarType::LOD_TENSOR) {
+      for (size_t i = 0; i < param_dim.size(); ++i) {
+        PADDLE_ENFORCE_EQ(
+            param_dim[i], grad_dim[i],
+            platform::errors::InvalidArgument(
+                "Param and Grad inputs of LarsMomentumOp should have the same "
+                "dimension. But received Param's dim [%s] and Grad's dim [%s].",
+                param_dim[i], grad_dim[i]));
+        PADDLE_ENFORCE_EQ(
+            param_dim[i], velocity_dim[i],
+            platform::errors::InvalidArgument(
+                "Param and Velocity of LarsMomentumOp should have the same "
+                "dimension. But received Param's dim [%s] and Velocity [%s].",
+                param_dim[i], velocity_dim[i]));
+      }
+    }
+
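+    // Outputs mirror the parameter shapes. MasterParamOut is only present
+    // when FP16 parameters keep FP32 master copies (multi-precision mode).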
+    ctx->SetOutputsDim("ParamOut", param_dim);
+    ctx->SetOutputsDim("VelocityOut", param_dim);
+    if (ctx->HasOutputs("MasterParamOut")) {
+      ctx->SetOutputsDim("MasterParamOut", param_dim);
+    }
+  }
+
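+  // The kernel is picked by Param's data type on the current place.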
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
 class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -104,7 +203,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(
-    lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
+    lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     ops::LarsMomentumOpVarTypeInference);
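
For readers skimming the diff: the new LarsMomentumOp only does shape and type inference; the actual arithmetic lives in the device kernels. As a rough reference, here is a minimal scalar C++ sketch of the LARS-momentum update this op family performs, assuming the standard LARS rule (local learning rate scaled by ||param|| / (||grad|| + weight_decay * ||param||)). LarsMomentumStep and its parameter names are illustrative only, not Paddle API.

#include <cmath>
#include <vector>

// Illustrative only: one LARS-momentum step for a single dense parameter.
// local_lr = lr * lars_coeff * ||p|| / (||g|| + lars_weight_decay * ||p||)
// v = mu * v + local_lr * (g + lars_weight_decay * p);  p = p - v
void LarsMomentumStep(std::vector<float>* p, std::vector<float>* v,
                      const std::vector<float>& g, float lr, float mu,
                      float lars_coeff, float lars_weight_decay) {
  float p_norm = 0.f, g_norm = 0.f;
  for (size_t i = 0; i < p->size(); ++i) {
    p_norm += (*p)[i] * (*p)[i];
    g_norm += g[i] * g[i];
  }
  p_norm = std::sqrt(p_norm);
  g_norm = std::sqrt(g_norm);
  const float local_lr =
      lr * lars_coeff * p_norm / (g_norm + lars_weight_decay * p_norm);
  for (size_t i = 0; i < p->size(); ++i) {
    (*v)[i] = mu * (*v)[i] + local_lr * (g[i] + lars_weight_decay * (*p)[i]);
    (*p)[i] -= (*v)[i];
  }
}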