
Commit 9fcbe2c

Adding las_op_momentum infer_shape
1 parent a8010c9 commit 9fcbe2c

File tree

3 files changed (+132, -20 lines)


paddle/fluid/framework/operator.cc

+12
@@ -1562,6 +1562,9 @@ void OperatorWithKernel::ParseInputDataType(
   proto::VarType::Type default_data_type =
       static_cast<proto::VarType::Type>(-1);
   const std::vector<Variable*> vars = ctx.MultiInputVar(name);
+  if (vars.size() == 161) {
+    std::cout << "vars.size(): " << vars.size() << std::endl;
+  }
   for (size_t i = 0; i < vars.size(); ++i) {
     const Variable* var = vars[i];
     if (var != nullptr) {
@@ -1588,6 +1591,15 @@ void OperatorWithKernel::ParseInputDataType(
                               "not initialized.",
                               Type(), name, ctx.InputNames(name).at(i)));
         proto::VarType::Type tmp = t->type();
+
+        int a = static_cast<int>(default_data_type);
+        int b = static_cast<int>(*data_type);
+        int c = static_cast<int>(tmp);
+        std::cout << i << "th op." << std::endl;
+        std::cout << "default_data_type :" << a << std::endl;
+        std::cout << "data_type :" << b << std::endl;
+        std::cout << "tmp_type :" << c << std::endl;
+
         PADDLE_ENFORCE(
             tmp == *data_type || *data_type == default_data_type,
             platform::errors::InvalidArgument(
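
For context on what the temporary prints above are inspecting: ParseInputDataType walks every tensor bound to one (possibly duplicable) input slot and enforces that they all report the same data type, starting from a sentinel "unset" value. Below is a minimal standalone sketch of that rule, with illustrative names (ParseSlotDataType, kUnset) that are not Paddle API.

// Minimal standalone sketch (not Paddle code) of the dtype-consistency check
// performed in ParseInputDataType: every tensor feeding one input slot must
// report the same type, starting from a sentinel "unset" value (-1).
#include <iostream>
#include <stdexcept>
#include <vector>

int ParseSlotDataType(const std::vector<int>& dtypes) {
  const int kUnset = -1;   // mirrors default_data_type
  int slot_type = kUnset;  // mirrors *data_type
  for (size_t i = 0; i < dtypes.size(); ++i) {
    int tmp = dtypes[i];
    // Accept the first observed type; reject any later mismatch.
    if (!(tmp == slot_type || slot_type == kUnset)) {
      throw std::runtime_error("tensors of one input slot hold different data types");
    }
    slot_type = tmp;
  }
  return slot_type;
}

int main() {
  std::cout << ParseSlotDataType({5, 5, 5}) << std::endl;  // prints 5
  return 0;
}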

paddle/fluid/operators/optimizers/lars_momentum_op.cc

+101 -2
@@ -13,11 +13,110 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
-#include "paddle/fluid/operators/optimizers/momentum_op.h"

 namespace paddle {
 namespace operators {

+class LarsMomentumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInputs("Param"), true,
+                      platform::errors::NotFound(
+                          "Inputs(param) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasInputs("Grad"), true,
+                      platform::errors::NotFound(
+                          "Input(grad) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInputs("Velocity"), true,
+        platform::errors::NotFound(
+            "Inputs(velocity) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInputs("LearningRate"), true,
+        platform::errors::NotFound(
+            "Input(LearningRate) of LarsMomentum should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputsVarType("Param").front(),
+        framework::proto::VarType::LOD_TENSOR,
+        platform::errors::InvalidArgument(
+            "The input var's type should be LoDTensor, but the received is %s",
+            ctx->GetInputsVarType("Param").front()));
+
+    PADDLE_ENFORCE_EQ(ctx->HasOutputs("ParamOut"), true,
+                      platform::errors::NotFound(
+                          "Output(ParamOut) of Momentum should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutputs("VelocityOut"), true,
+        platform::errors::NotFound(
+            "Output(VelocityOut) of Momentum should not be null."));
+
+    auto lr_dims = ctx->GetInputsDim("LearningRate");
+    for (size_t i = 0; i < lr_dims.size(); ++i) {
+      PADDLE_ENFORCE_NE(framework::product(lr_dims[i]), 0,
+                        platform::errors::InvalidArgument(
+                            "Maybe the Input variable LearningRate has not "
+                            "been initialized. You may need to confirm "
+                            "whether exe.run(startup_program) is put "
+                            "after optimizer.minimize function."));
+      PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1,
+                        platform::errors::InvalidArgument(
+                            "Learning_rate should be a scalar. But Received "
+                            "LearningRate's dim [%s]",
+                            framework::product(lr_dims[i])));
+    }
+
+    auto param_dim = ctx->GetInputsDim("Param");
+    auto grad_dim = ctx->GetInputsDim("Grad");
+    auto velocity_dim = ctx->GetInputsDim("Velocity");
+    PADDLE_ENFORCE_EQ(
+        param_dim.size(), grad_dim.size(),
+        platform::errors::InvalidArgument(
+            "Param and Grad input of LarsMomentumOp should have the same "
+            "quantity. But number of Param is [%d] and Grad is [%d].",
+            param_dim.size(), grad_dim.size()));
+    PADDLE_ENFORCE_EQ(
+        param_dim.size(), velocity_dim.size(),
+        platform::errors::InvalidArgument(
+            "Param and Velocity input of LarsMomentumOp should have the same "
+            "quantity. But number of Param is [%d] and Velocity is [%d].",
+            param_dim.size(), velocity_dim.size()));
+
+    if (ctx->GetInputsVarType("Grad")[0] ==
+        framework::proto::VarType::LOD_TENSOR) {
+      for (size_t i = 0; i < param_dim.size(); ++i) {
+        PADDLE_ENFORCE_EQ(
+            param_dim[i], grad_dim[i],
+            platform::errors::InvalidArgument(
+                "Param and Grad input of MomentumOp should have the same "
+                "dimension. But received Param's dim [%s] and Grad's dim [%s].",
+                param_dim[i], grad_dim[i]));
+        PADDLE_ENFORCE_EQ(
+            param_dim[i], velocity_dim[i],
+            platform::errors::InvalidArgument(
+                "Param and Velocity of MomentumOp should have the same "
+                "dimension. But received Param's dim [%s] and Velocity [%s].",
+                param_dim[i], velocity_dim[i]));
+      }
+    }
+
+    ctx->SetOutputsDim("ParamOut", param_dim);
+    ctx->SetOutputsDim("VelocityOut", param_dim);
+    if (ctx->HasOutputs("MasterParamOut")) {
+      ctx->SetOutputsDim("MasterParamOut", param_dim);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
 class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -104,7 +203,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(
-    lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
+    lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     ops::LarsMomentumOpVarTypeInference);
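
The new InferShape treats Param, Grad and Velocity as lists of tensors (GetInputsDim returns one dim entry per tensor), so it checks both that the lists have equal length and that each (param, grad, velocity) triple agrees on shape before propagating param_dim to the outputs. Below is a minimal standalone sketch of that list-wise validation, using illustrative names (CheckLarsMomentumDims, Dims) rather than Paddle types.

// Minimal standalone sketch (not Paddle API) of the list-wise shape check the
// new InferShape performs: list lengths must match, and so must the shape of
// every corresponding tensor triple.
#include <cassert>
#include <cstdint>
#include <vector>

using Dims = std::vector<int64_t>;

void CheckLarsMomentumDims(const std::vector<Dims>& param,
                           const std::vector<Dims>& grad,
                           const std::vector<Dims>& velocity) {
  // Same number of tensors in each slot ...
  assert(param.size() == grad.size());
  assert(param.size() == velocity.size());
  // ... and matching shape for each (param, grad, velocity) triple.
  for (size_t i = 0; i < param.size(); ++i) {
    assert(param[i] == grad[i]);
    assert(param[i] == velocity[i]);
  }
}

int main() {
  // Two parameters, each with matching grad and velocity shapes: passes.
  CheckLarsMomentumDims({{128, 256}, {10}}, {{128, 256}, {10}},
                        {{128, 256}, {10}});
  return 0;
}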

paddle/fluid/operators/optimizers/lars_momentum_op.cu

+19 -18
@@ -413,28 +413,29 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
           reinterpret_cast<void*>(MomentumLarsKernel<T, MT>), grid_real,
           LARS_BLOCK_SIZE, cuda_param, 0, cuda_ctx.stream());
     } else {
-      auto param = ctx.Input<framework::LoDTensor>("Param");
-      auto grad = ctx.Input<framework::LoDTensor>("Grad");
-      auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
-      auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
-      auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
-      auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
-
-      auto* p = param->data<T>();
-      auto* g = grad->data<T>();
-      auto* v = velocity->data<MT>();
-      auto* lr = learning_rate->data<MT>();
-      auto* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-      auto* v_out = velocity_out->mutable_data<MT>(ctx.GetPlace());
+      auto param = ctx.MultiInput<framework::LoDTensor>("Param");
+      auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
+      auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
+      auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
+      auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
+      auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
+
+      auto* p = param[0]->data<T>();
+      auto* g = grad[0]->data<T>();
+      auto* v = velocity[0]->data<MT>();
+      auto* lr = learning_rate[0]->data<MT>();
+      auto* p_out = param_out[0]->mutable_data<T>(ctx.GetPlace());
+      auto* v_out = velocity_out[0]->mutable_data<MT>(ctx.GetPlace());
       const MT* master_p = nullptr;
       MT* master_p_out = nullptr;
       if (multi_precision) {
-        auto master_param = ctx.Input<framework::Tensor>("MasterParam");
-        auto master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
-        master_p = master_param->data<MT>();
-        master_p_out = master_param_out->mutable_data<MT>(ctx.GetPlace());
+        auto master_param = ctx.MultiInput<framework::Tensor>("MasterParam");
+        auto master_param_out =
+            ctx.MultiOutput<framework::Tensor>("MasterParamOut");
+        master_p = master_param[0]->data<MT>();
+        master_p_out = master_param_out[0]->mutable_data<MT>(ctx.GetPlace());
       }
-      int64_t numel = param->numel();
+      int64_t numel = param[0]->numel();
       cudaOccupancyMaxActiveBlocksPerMultiprocessor(
           &num_blocks_per_sm, MomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
           sizeof(MT) << 1);
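
The kernel-side change mirrors the operator definition: Input/Output return a single tensor, while MultiInput/MultiOutput return a vector of tensor pointers, and this branch keeps the old single-tensor behaviour by reading only element [0] of each list. A minimal sketch of that access-pattern difference, with a stand-in FakeTensor type instead of framework::LoDTensor:

// Minimal sketch (not Paddle code): moving a slot from "one tensor" to
// "a list of tensors" while this code path still consumes only entry 0.
#include <cstdint>
#include <vector>

struct FakeTensor {  // stand-in for framework::LoDTensor
  std::vector<float> buf;
  int64_t numel() const { return static_cast<int64_t>(buf.size()); }
};

// Before: one tensor per slot.
int64_t NumelSingle(const FakeTensor& param) { return param.numel(); }

// After: a vector of tensors per slot; index 0 preserves the old behaviour.
int64_t NumelMulti(const std::vector<const FakeTensor*>& param) {
  return param[0]->numel();
}

int main() {
  FakeTensor t{std::vector<float>(16, 0.f)};
  return NumelSingle(t) == NumelMulti({&t}) ? 0 : 1;  // exits 0 on agreement
}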
