Commit c1c2633
Support backward of backward for Relu and add a new gradient checker by comparing theoretical and numerical Jacobian. (PaddlePaddle#16862)
* Support backward of backward and add a new gradient checker.
* Rename decorators.py to decorator_helper.py, since Python on Windows CI already ships a decorators package.

1. Add ReluDoubleGradMaker when registering relu_grad.
2. Add a new gradient checker that compares the theoretical and numerical Jacobian; double gradients are checked with double_grad_check (sketched below).
1 parent 63d9fe3 commit c1c2633

13 files changed, +643 -23 lines

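The new checker's core idea, per the commit message, is to compare a theoretical Jacobian (assembled from the registered gradient ops) against a numerical Jacobian obtained by finite differences, and double_grad_check applies the same comparison to the gradient of the gradient. Below is a minimal NumPy sketch of that comparison for relu only; the helper names, the central-difference step eps, and the tolerances are illustrative assumptions, not the checker's actual API.

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def relu_grad(x, dy):
    # theoretical ReluGrad: dx = dy if y > 0 else 0
    return dy * (x > 0.0)

def relu_grad_grad(x, ddx):
    # theoretical ReluGradGrad (this commit): ddy = ddx if y > 0 else 0
    return ddx * (x > 0.0)

def numerical_jacobian(f, x0, eps=1e-4):
    # central differences, one input element at a time
    y0 = f(x0)
    jac = np.zeros((x0.size, y0.size))
    for i in range(x0.size):
        xp, xm = x0.copy(), x0.copy()
        xp[i] += eps
        xm[i] -= eps
        jac[i] = (f(xp) - f(xm)).ravel() / (2.0 * eps)
    return jac

x = np.array([-0.7, -0.2, 0.3, 0.9])  # keep points away from the kink at 0

# First order: numerical Jacobian of relu vs. relu_grad applied to unit cotangents
# (relu is element-wise, so its Jacobian is diagonal and rows of J and J^T coincide).
num_jac = numerical_jacobian(relu, x)
theo_jac = np.stack([relu_grad(x, e) for e in np.eye(x.size)])
assert np.allclose(num_jac, theo_jac, atol=1e-5)

# Second order: for fixed x, dx = relu_grad(x, dy) is linear in the incoming gradient dy,
# and its Jacobian w.r.t. dy should match relu_grad_grad (ddy = ddx * (y > 0));
# this is roughly the comparison that double_grad_check automates.
num_jac2 = numerical_jacobian(lambda dy: relu_grad(x, dy), np.ones_like(x))
theo_jac2 = np.stack([relu_grad_grad(x, e) for e in np.eye(x.size)])
assert np.allclose(num_jac2, theo_jac2, atol=1e-5)

print("theoretical and numerical Jacobians match")

The test points are kept away from x = 0 because a finite-difference derivative is not meaningful at relu's kink.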
paddle/fluid/operators/activation_op.cc (+67)
@@ -597,10 +597,57 @@ REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
 REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
 REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

+class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    if (ctx->HasOutput("DOut")) {
+      ctx->ShareDim("Out", "DOut");
+      ctx->ShareLoD("Out", "DOut");
+    }
+    if (ctx->HasOutput("DDOut")) {
+      ctx->ShareDim("Out", "DDOut");
+      ctx->ShareLoD("Out", "DDOut");
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "Out");
+  }
+};
+
+//
+// ReluGrad: dx = dy if y >= 0 else 0
+// ReluGradGrad: ddy = ddx if y >= 0 else 0
+//
+class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
+ public:
+  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
+    auto* op = new ::paddle::framework::OpDesc();
+    op->SetType("relu_grad_grad");
+    // input1: Out
+    op->SetInput("Out", Input("Out"));
+    // X@GRAD@GRAD: ddx
+    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
+    op->SetAttrMap(Attrs());
+    // Out@GRAD@GRAD: ddy
+    op->SetOutput("DOut", InputGrad("Out"));
+    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
+    return std::unique_ptr<::paddle::framework::OpDesc>(op);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 #define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
   REGISTER_OPERATOR(                                                        \
@@ -632,3 +679,23 @@ namespace ops = paddle::operators;

 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
+
+REGISTER_OPERATOR(
+    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
+    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
+    paddle::framework::SingleOpInplaceInToOut);
+REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
+                  paddle::framework::SingleOpInplaceInToOut,
+                  ops::ReluDoubleGradMaker);
+REGISTER_OPERATOR(relu_grad_grad, ops::ActivationOpDoubleGrad);
+
+REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
+
+REGISTER_OP_CPU_KERNEL(
+    relu_grad_grad,
+    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
+                                    ops::ReluGradGradFunctor<float>>,
+    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
+                                    ops::ReluGradGradFunctor<double>>,
+    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
+                                    ops::ReluGradGradFunctor<plat::float16>>);

paddle/fluid/operators/activation_op.cu (+11)
@@ -32,3 +32,14 @@ namespace plat = paddle::platform;
                                   ops::grad_functor<plat::float16>>);

 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);
+
+REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
+
+REGISTER_OP_CUDA_KERNEL(
+    relu_grad_grad,
+    ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                                    ops::ReluGradGradFunctor<float>>,
+    ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                                    ops::ReluGradGradFunctor<double>>,
+    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
+                                    ops::ReluGradGradFunctor<plat::float16>>);

paddle/fluid/operators/activation_op.h (+120 -1)
@@ -1198,14 +1198,133 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };

+/*
+ * in arguments: x, out, ddx
+ * out arguments: ddout, dout, dx
+ */
+template <ActBwdOpFwdDeps kDepValue>
+inline void ExtractActivationDoubleGradTensor(
+    const framework::ExecutionContext& ctx, const framework::Tensor** X,
+    const framework::Tensor** Out, const framework::Tensor** ddX,
+    framework::Tensor** dX, framework::Tensor** dOut,
+    framework::Tensor** ddOut) {
+  auto out_var = ctx.InputVar("Out");
+  auto ddx_var = ctx.InputVar("DDX");
+  auto ddo_var = ctx.OutputVar("DDOut");
+  auto do_var = ctx.OutputVar("DOut");
+  PADDLE_ENFORCE(out_var != nullptr,
+                 "Cannot get input Variable Out, variable name = %s",
+                 ctx.op().Input("Out"));
+  PADDLE_ENFORCE(ddx_var != nullptr,
+                 "Cannot get input Variable %s, variable name = %s", "DDX",
+                 ctx.op().Input("DDX"));
+  if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
+    *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
+    *ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*ddx_var);
+    if (ddo_var) {
+      *ddOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
+          ddo_var);
+    }
+    if (do_var) {
+      *dOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
+          do_var);
+    }
+  } else {
+    *Out = ctx.Input<framework::Tensor>("Out");
+    *ddX = ctx.Input<framework::Tensor>("DDX");
+    if (ddo_var) {
+      *ddOut = ctx.Output<framework::Tensor>("DDOut");
+    }
+    if (do_var) {
+      *dOut = ctx.Output<framework::Tensor>("DOut");
+    }
+  }
+  PADDLE_ENFORCE(*ddX != nullptr,
+                 "Cannot get output tensor %s, variable name = %s", "DDX",
+                 ctx.op().Output("DDX"));
+
+  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+    auto x_var = ctx.InputVar("X");
+    PADDLE_ENFORCE(x_var != nullptr,
+                   "Cannot get input tensor X, variable name = %s",
+                   ctx.op().Input("X"));
+    auto dx_var = ctx.OutputVar("DX");
+    if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
+      *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
+      if (dx_var) {
+        *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
+            dx_var);
+      }
+    } else {
+      *X = ctx.Input<framework::Tensor>("X");
+      if (dx_var) {
+        *dX = ctx.Output<framework::Tensor>("DX");
+      }
+    }
+  } else {
+    VLOG(10) << " Inplace activation of Op : " << ctx.op().Type();
+    *X = *ddX;
+  }
+}
+
+template <typename DeviceContext, typename Functor>
+class ActivationDoubleGradKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  using T = typename Functor::ELEMENT_TYPE;
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::Tensor *X, *Out, *ddX;
+    X = Out = ddX = nullptr;
+    framework::Tensor *ddOut, *dOut, *dX;
+    ddOut = dOut = dX = nullptr;
+
+    ExtractActivationDoubleGradTensor<Functor::FwdDeps()>(ctx, &X, &Out, &ddX,
+                                                          &dX, &dOut, &ddOut);
+
+    if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
+    if (dOut) dOut->mutable_data<T>(ctx.GetPlace());
+    if (dX) dX->mutable_data<T>(Out->dims(), ctx.GetPlace());
+
+    auto& place = ctx.template device_context<DeviceContext>();
+
+    Functor functor;
+    auto attrs = functor.GetAttrs();
+    for (auto& attr : attrs) {
+      *attr.second = ctx.Attr<float>(attr.first);
+    }
+    functor(place, X, Out, ddX, ddOut, dOut, dX);
+  }
+};
+
+template <typename T>
+struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device>
+  void operator()(const Device& dev, const framework::Tensor* X,
+                  const framework::Tensor* Out, const framework::Tensor* ddX,
+                  framework::Tensor* ddOut, framework::Tensor* dOut,
+                  framework::Tensor* dX) const {
+    auto* d = dev.eigen_device();
+    auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
+    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
+    if (ddOut) {
+      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
+    }
+    if (dOut) {
+      auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
+      dout.device(*d) = dout.constant(static_cast<T>(0));
+    }
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+};
+
 }  // namespace operators
 }  // namespace paddle

 #define FOR_EACH_ACTIVATION_OP(__macro)                                      \
   __macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor);             \
   __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
   __macro(exp, Exp, ExpFunctor, ExpGradFunctor);                             \
-  __macro(relu, Relu, ReluFunctor, ReluGradFunctor);                         \
   __macro(gelu, Gelu, GeluFunctor, GeluGradFunctor);                         \
   __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor);                         \
   __macro(atan, Atan, AtanFunctor, AtanGradFunctor);                         \
python/paddle/fluid/backward.py (+1 -1)
@@ -611,7 +611,7 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
     if inputs:
         for op in op_path:
             for name in op.desc.input_arg_names():
-                if name not in input_names:
+                if name not in input_names and block.vars[name].stop_gradient:
                    no_grad_set.add(name)

     return op_path

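The one-line change in _find_op_path_ narrows which variables are forced into no_grad_set: previously any input on the op path that was not among the requested inputs was excluded from gradient computation; now only variables marked stop_gradient are, presumably so that intermediate variables stay differentiable when the gradient graph itself is differentiated again. A minimal mock of the new condition, using a hypothetical Var stand-in instead of Paddle's block and variable objects:

from collections import namedtuple

# Hypothetical stand-ins for block.vars[name] and op.desc.input_arg_names();
# only the filtering condition from the diff is reproduced here.
Var = namedtuple("Var", ["stop_gradient"])

block_vars = {
    "x": Var(stop_gradient=False),      # requested input
    "w": Var(stop_gradient=False),      # parameter/intermediate: no longer forced into no_grad_set
    "label": Var(stop_gradient=True),   # data that should never receive a gradient
}
input_names = {"x"}
op_input_arg_names = ["x", "w", "label"]

no_grad_set = set()
for name in op_input_arg_names:
    # old condition: name not in input_names
    # new condition (this commit): additionally require stop_gradient
    if name not in input_names and block_vars[name].stop_gradient:
        no_grad_set.add(name)

print(no_grad_set)  # {'label'}; under the old condition it would have been {'w', 'label'}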
python/paddle/fluid/tests/unittests/CMakeLists.txt (+1 -1)
@@ -29,7 +29,7 @@ list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com
 list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957

 list(REMOVE_ITEM TEST_OPS op_test) # op_test is a helper python file, not a test
-list(REMOVE_ITEM TEST_OPS decorators) # decorators is a helper python file, not a test
+list(REMOVE_ITEM TEST_OPS decorator_helper) # decorator_helper is a helper python file, not a test
 if(APPLE)
     if(NOT WITH_DISTRIBUTE)
         list(REMOVE_ITEM TEST_OPS test_desc_clone)
