Skip to content

Commit b817edd

Browse files
authored
Merge branch 'PaddlePaddle:develop' into dropout_opt_clean_BcInT
2 parents bc156bd + 05499c7 commit b817edd

File tree

16 files changed

+640
-58
lines changed

16 files changed

+640
-58
lines changed

cmake/external/cinn.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if(NOT CINN_GIT_TAG)
2020
set(CINN_GIT_TAG develop)
2121
endif()
2222

23-
message(STATUS "CINN version: " ${CINN_GIT_TAG})
23+
message(STATUS "CINN version: " ${CINN_GIT_TAG})
2424

2525
# TODO(zhhsplendid): CINN has lots of warnings during early development.
2626
# They will be treated as errors under paddle. We set no-error now and we will

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@
6767
prim_white_list = [
6868
"matmul_double_grad",
6969
"tanh_double_grad",
70+
"add_double_grad",
71+
"multiply_double_grad",
72+
"subtract_double_grad",
7073
]
7174

7275
# dict of special api that forward api's output will affect bacward api's output

paddle/fluid/operators/elementwise/elementwise_add_op.cc

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,42 @@ class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker<T> {
9999
}
100100
};
101101

102+
class ElementwiseAddCompositeDoubleGradOpMaker
103+
: public prim::CompositeGradOpMakerBase {
104+
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
105+
106+
public:
107+
void Apply() override {
108+
// get input
109+
paddle::Tensor y = this->GetSingleForwardInput("Y");
110+
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
111+
paddle::optional<paddle::Tensor> ddx =
112+
this->GetOptionalSingleOutputGrad(framework::GradVarName("X"));
113+
paddle::optional<paddle::Tensor> ddy =
114+
this->GetOptionalSingleOutputGrad(framework::GradVarName("Y"));
115+
// get output
116+
paddle::Tensor grad_out_grad_t =
117+
this->GetSingleInputGrad(framework::GradVarName("Out"));
118+
119+
// get attr
120+
int axis = static_cast<int>(this->Attr<int>("axis"));
121+
PADDLE_ENFORCE_EQ(
122+
axis,
123+
-1,
124+
phi::errors::InvalidArgument("We only support axis = -1 in composite "
125+
"add_doubel_grad but we got: ",
126+
axis));
127+
128+
paddle::Tensor* grad_out_grad = this->GetOutputPtr(&grad_out_grad_t);
129+
std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t);
130+
131+
VLOG(6) << "Runing add_double_grad composite func";
132+
prim::add_double_grad<prim::DescTensor>(
133+
y, out_grad, ddx, ddy, axis, grad_out_grad);
134+
this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name);
135+
}
136+
};
137+
102138
template <typename T>
103139
class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker<T> {
104140
public:
@@ -139,7 +175,8 @@ REGISTER_OPERATOR(
139175
ops::ElementwiseGradOpInplaceInferer,
140176
ops::ElementwiseGradNoBufVarsInferer,
141177
ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>,
142-
ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>);
178+
ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>,
179+
ops::ElementwiseAddCompositeDoubleGradOpMaker);
143180

144181
REGISTER_OPERATOR(
145182
elementwise_add_grad_grad,

paddle/fluid/operators/elementwise/elementwise_mul_op.cc

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,56 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
118118
}
119119
};
120120

121+
class ElementwiseMulCompositeDoubleGradOpMaker
122+
: public prim::CompositeGradOpMakerBase {
123+
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
124+
125+
public:
126+
void Apply() override {
127+
// get input
128+
paddle::Tensor x = this->GetSingleForwardInput("X");
129+
paddle::Tensor y = this->GetSingleForwardInput("Y");
130+
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
131+
paddle::optional<paddle::Tensor> ddx =
132+
this->GetOptionalSingleOutputGrad(framework::GradVarName("X"));
133+
paddle::optional<paddle::Tensor> ddy =
134+
this->GetOptionalSingleOutputGrad(framework::GradVarName("Y"));
135+
136+
// get attr
137+
int axis = static_cast<int>(this->Attr<int>("axis"));
138+
PADDLE_ENFORCE_EQ(
139+
axis,
140+
-1,
141+
phi::errors::InvalidArgument("We only support axis = -1 in composite "
142+
"add_doubel_grad but we got: ",
143+
axis));
144+
145+
// get output
146+
paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
147+
paddle::Tensor y_grad_t = this->GetSingleInputGrad("Y");
148+
paddle::Tensor grad_out_grad_t =
149+
this->GetSingleInputGrad(framework::GradVarName("Out"));
150+
151+
// get output ptr
152+
paddle::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
153+
paddle::Tensor* y_grad = this->GetOutputPtr(&y_grad_t);
154+
paddle::Tensor* grad_out_grad = this->GetOutputPtr(&grad_out_grad_t);
155+
// get output orginal name
156+
std::string x_grad_name = this->GetOutputName(x_grad_t);
157+
std::string y_grad_name = this->GetOutputName(y_grad_t);
158+
std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t);
159+
160+
VLOG(6) << "Runing multiply_double_grad composite func";
161+
prim::multiply_double_grad<prim::DescTensor>(
162+
x, y, out_grad, ddx, ddy, axis, x_grad, y_grad, grad_out_grad);
163+
164+
// recover output name
165+
this->RecoverOutputName(x_grad_t, x_grad_name);
166+
this->RecoverOutputName(y_grad_t, y_grad_name);
167+
this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name);
168+
}
169+
};
170+
121171
template <typename T>
122172
class ElementwiseMulTripleGradMaker : public framework::SingleGradOpMaker<T> {
123173
public:
@@ -162,7 +212,8 @@ REGISTER_OPERATOR(
162212
elementwise_mul_grad,
163213
ops::ElementwiseOpGrad,
164214
ops::ElementwiseMulDoubleGradMaker<paddle::framework::OpDesc>,
165-
ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>);
215+
ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>,
216+
ops::ElementwiseMulCompositeDoubleGradOpMaker);
166217

167218
REGISTER_OPERATOR(
168219
elementwise_mul_grad_grad,

paddle/fluid/operators/elementwise/elementwise_sub_op.cc

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,42 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
102102
}
103103
};
104104

105+
class ElementwiseSubCompositeDoubleGradOpMaker
106+
: public prim::CompositeGradOpMakerBase {
107+
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
108+
109+
public:
110+
void Apply() override {
111+
// get input
112+
paddle::Tensor y = this->GetSingleForwardInput("Y");
113+
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
114+
paddle::optional<paddle::Tensor> ddx =
115+
this->GetOptionalSingleOutputGrad(framework::GradVarName("X"));
116+
paddle::optional<paddle::Tensor> ddy =
117+
this->GetOptionalSingleOutputGrad(framework::GradVarName("Y"));
118+
// get output
119+
paddle::Tensor grad_out_grad_t =
120+
this->GetSingleInputGrad(framework::GradVarName("Out"));
121+
122+
// get attr
123+
int axis = static_cast<int>(this->Attr<int>("axis"));
124+
PADDLE_ENFORCE_EQ(
125+
axis,
126+
-1,
127+
phi::errors::InvalidArgument("We only support axis = -1 in composite "
128+
"subtract_doubel_grad but we got: ",
129+
axis));
130+
131+
paddle::Tensor* grad_out_grad = this->GetOutputPtr(&grad_out_grad_t);
132+
std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t);
133+
134+
VLOG(6) << "Runing subtract_double_grad composite func";
135+
prim::subtract_double_grad<prim::DescTensor>(
136+
y, out_grad, ddx, ddy, axis, grad_out_grad);
137+
this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name);
138+
}
139+
};
140+
105141
} // namespace operators
106142
} // namespace paddle
107143

@@ -124,7 +160,9 @@ REGISTER_OPERATOR(
124160
ops::ElementwiseGradOpInplaceInferer,
125161
ops::ElementwiseGradNoBufVarsInferer,
126162
ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>,
127-
ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>);
163+
ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>,
164+
ops::ElementwiseSubCompositeDoubleGradOpMaker);
165+
128166
REGISTER_OPERATOR(elementwise_sub_grad_grad,
129167
ops::ElementwiseOpDoubleGradWithoutDXDY,
130168
ops::ElementwiseDoubleGradOpInplaceInferer,

paddle/fluid/prim/api/api.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
- bitwise_not
1313
- bitwise_or
1414
- bitwise_xor
15-
- unsqueeze
1615
- exp
1716
- scale
1817
- matmul

0 commit comments

Comments
 (0)