@@ -1198,14 +1198,133 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
1198
1198
static constexpr ActBwdOpFwdDeps FwdDeps () { return kDepX ; }
1199
1199
};
1200
1200
1201
+ /*
1202
+ * in arguments: x, out, ddx
1203
+ * out arguments: ddout, dout, dx
1204
+ */
1205
+ template <ActBwdOpFwdDeps kDepValue >
1206
+ inline void ExtractActivationDoubleGradTensor (
1207
+ const framework::ExecutionContext& ctx, const framework::Tensor** X,
1208
+ const framework::Tensor** Out, const framework::Tensor** ddX,
1209
+ framework::Tensor** dX, framework::Tensor** dOut,
1210
+ framework::Tensor** ddOut) {
1211
+ auto out_var = ctx.InputVar (" Out" );
1212
+ auto ddx_var = ctx.InputVar (" DDX" );
1213
+ auto ddo_var = ctx.OutputVar (" DDOut" );
1214
+ auto do_var = ctx.OutputVar (" DOut" );
1215
+ PADDLE_ENFORCE (out_var != nullptr ,
1216
+ " Cannot get input Variable Out, variable name = %s" ,
1217
+ ctx.op ().Input (" Out" ));
1218
+ PADDLE_ENFORCE (ddx_var != nullptr ,
1219
+ " Cannot get input Variable %s, variable name = %s" , " DDX" ,
1220
+ ctx.op ().Input (" DDX" ));
1221
+ if (CanBeUsedBySelectedRows.count (ctx.op ().Type ())) {
1222
+ *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar (*out_var);
1223
+ *ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar (*ddx_var);
1224
+ if (ddo_var) {
1225
+ *ddOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar (
1226
+ ddo_var);
1227
+ }
1228
+ if (do_var) {
1229
+ *dOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar (
1230
+ do_var);
1231
+ }
1232
+ } else {
1233
+ *Out = ctx.Input <framework::Tensor>(" Out" );
1234
+ *ddX = ctx.Input <framework::Tensor>(" DDX" );
1235
+ if (ddo_var) {
1236
+ *ddOut = ctx.Output <framework::Tensor>(" DDOut" );
1237
+ }
1238
+ if (do_var) {
1239
+ *dOut = ctx.Output <framework::Tensor>(" DOut" );
1240
+ }
1241
+ }
1242
+ PADDLE_ENFORCE (*ddX != nullptr ,
1243
+ " Cannot get output tensor %s, variable name = %s" , " DDX" ,
1244
+ ctx.op ().Output (" DDX" ));
1245
+
1246
+ if (static_cast <int >(kDepValue ) & static_cast <int >(kDepX )) {
1247
+ auto x_var = ctx.InputVar (" X" );
1248
+ PADDLE_ENFORCE (x_var != nullptr ,
1249
+ " Cannot get input tensor X, variable name = %s" ,
1250
+ ctx.op ().Input (" X" ));
1251
+ auto dx_var = ctx.OutputVar (" DX" );
1252
+ if (CanBeUsedBySelectedRows.count (ctx.op ().Type ())) {
1253
+ *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar (*x_var);
1254
+ if (dx_var) {
1255
+ *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar (
1256
+ dx_var);
1257
+ }
1258
+ } else {
1259
+ *X = ctx.Input <framework::Tensor>(" X" );
1260
+ if (dx_var) {
1261
+ *dX = ctx.Output <framework::Tensor>(" DX" );
1262
+ }
1263
+ }
1264
+ } else {
1265
+ VLOG (10 ) << " Inplace activation of Op : " << ctx.op ().Type ();
1266
+ *X = *ddX;
1267
+ }
1268
+ }
1269
+
1270
+ template <typename DeviceContext, typename Functor>
1271
+ class ActivationDoubleGradKernel
1272
+ : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
1273
+ public:
1274
+ using T = typename Functor::ELEMENT_TYPE;
1275
+ void Compute (const framework::ExecutionContext& ctx) const override {
1276
+ const framework::Tensor *X, *Out, *ddX;
1277
+ X = Out = ddX = nullptr ;
1278
+ framework::Tensor *ddOut, *dOut, *dX;
1279
+ ddOut = dOut = dX = nullptr ;
1280
+
1281
+ ExtractActivationDoubleGradTensor<Functor::FwdDeps ()>(ctx, &X, &Out, &ddX,
1282
+ &dX, &dOut, &ddOut);
1283
+
1284
+ if (ddOut) ddOut->mutable_data <T>(ctx.GetPlace ());
1285
+ if (dOut) dOut->mutable_data <T>(ctx.GetPlace ());
1286
+ if (dX) dX->mutable_data <T>(Out->dims (), ctx.GetPlace ());
1287
+
1288
+ auto & place = ctx.template device_context <DeviceContext>();
1289
+
1290
+ Functor functor;
1291
+ auto attrs = functor.GetAttrs ();
1292
+ for (auto & attr : attrs) {
1293
+ *attr.second = ctx.Attr <float >(attr.first );
1294
+ }
1295
+ functor (place, X, Out, ddX, ddOut, dOut, dX);
1296
+ }
1297
+ };
1298
+
1299
+ template <typename T>
1300
+ struct ReluGradGradFunctor : public BaseActivationFunctor <T> {
1301
+ template <typename Device>
1302
+ void operator ()(const Device& dev, const framework::Tensor* X,
1303
+ const framework::Tensor* Out, const framework::Tensor* ddX,
1304
+ framework::Tensor* ddOut, framework::Tensor* dOut,
1305
+ framework::Tensor* dX) const {
1306
+ auto * d = dev.eigen_device ();
1307
+ auto ddx = framework::EigenVector<T>::Flatten (detail::Ref (ddX));
1308
+ auto out = framework::EigenVector<T>::Flatten (detail::Ref (Out));
1309
+ if (ddOut) {
1310
+ auto ddout = framework::EigenVector<T>::Flatten (detail::Ref (ddOut));
1311
+ ddout.device (*d) = ddx * (out > static_cast <T>(0 )).template cast <T>();
1312
+ }
1313
+ if (dOut) {
1314
+ auto dout = framework::EigenVector<T>::Flatten (detail::Ref (dOut));
1315
+ dout.device (*d) = dout.constant (static_cast <T>(0 ));
1316
+ }
1317
+ }
1318
+ static constexpr ActBwdOpFwdDeps FwdDeps () { return kDepOut ; }
1319
+ };
1320
+
1201
1321
} // namespace operators
1202
1322
} // namespace paddle
1203
1323
1204
1324
#define FOR_EACH_ACTIVATION_OP (__macro ) \
1205
1325
__macro (sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
1206
1326
__macro (logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
1207
1327
__macro (exp, Exp, ExpFunctor, ExpGradFunctor); \
1208
- __macro (relu, Relu, ReluFunctor, ReluGradFunctor); \
1209
1328
__macro (gelu, Gelu, GeluFunctor, GeluGradFunctor); \
1210
1329
__macro (tanh, Tanh, TanhFunctor, TanhGradFunctor); \
1211
1330
__macro (atan, Atan, AtanFunctor, AtanGradFunctor); \
0 commit comments