@@ -41,6 +41,12 @@ static std::unordered_set<std::string> InplaceOpSet = {
41
41
" floor" , " reciprocal" , " relu6" , " soft_relu" , " hard_sigmoid" ,
42
42
};
43
43
44
+ /* The following operators can be used to process SelectedRows, because the
45
+ * output of each of these operators for a zero input is zero too.
46
+ */
47
+ static std::unordered_set<std::string> CanBeUsedBySelectedRows = {
48
+ " abs" , " abs_grad" , " square" , " square_grad" , " sqrt" , " sqrt_grad" };
49
+
44
50
/* Reports whether the operator type named `op` is listed in InplaceOpSet.
 *
 * op: operator type name to look up.
 * Returns true when InplaceOpSet contains `op`, false otherwise.
 */
static bool IsInplace(const std::string& op) {
  // Bind by const reference so the name is not copied on every lookup, and
  // compare explicitly instead of relying on size_t -> bool conversion.
  return InplaceOpSet.count(op) != 0;
}
45
51
46
52
template <typename DeviceContext, typename Functor>
@@ -50,16 +56,38 @@ class ActivationKernel
50
56
using T = typename Functor::ELEMENT_TYPE;
51
57
52
58
void Compute (const framework::ExecutionContext& context) const override {
53
- auto & X = detail::Ref (context.Input <framework::Tensor>(" X" ),
54
- " Cannot get input tensor X, variable name = %s" ,
55
- context.op ().Input (" X" ));
56
-
57
- auto & Out = detail::Ref (context.Output <framework::Tensor>(" Out" ),
58
- " Cannot get output tensor Out, variable name = %s" ,
59
- context.op ().Output (" Out" ));
60
- Out.mutable_data <T>(context.GetPlace ());
59
+ auto x_var = context.InputVar (" X" );
60
+ auto out_var = context.OutputVar (" Out" );
61
+ PADDLE_ENFORCE (x_var != nullptr ,
62
+ " Cannot get input Variable X, variable name = %s" ,
63
+ context.op ().Input (" X" ));
64
+ PADDLE_ENFORCE (out_var != nullptr ,
65
+ " Cannot get output Variable Out, variable name = %s" ,
66
+ context.op ().Output (" Out" ));
67
+
68
+ framework::Tensor X, *Out;
69
+
70
+ if (CanBeUsedBySelectedRows.count (context.op ().Type ())) {
71
+ X = detail::Ref (
72
+ paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar (*x_var),
73
+ " Cannot get input Tensor X, variable name = %s" ,
74
+ context.op ().Input (" X" ));
75
+ Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar (
76
+ out_var);
77
+ } else {
78
+ X = detail::Ref (context.Input <framework::Tensor>(" X" ),
79
+ " Cannot get input Tensor X, variable name = %s" ,
80
+ context.op ().Input (" X" ));
81
+ Out = context.Output <framework::Tensor>(" Out" );
82
+ }
83
+
84
+ PADDLE_ENFORCE (Out != nullptr ,
85
+ " Cannot get output tensor Out, variable name = %s" ,
86
+ context.op ().Output (" Out" ));
87
+
88
+ Out->mutable_data <T>(context.GetPlace ());
61
89
auto x = framework::EigenVector<T>::Flatten (X);
62
- auto out = framework::EigenVector<T>::Flatten (Out);
90
+ auto out = framework::EigenVector<T>::Flatten (* Out);
63
91
auto * place =
64
92
context.template device_context <DeviceContext>().eigen_device ();
65
93
Functor functor;
@@ -78,14 +106,54 @@ class ActivationGradKernel
78
106
public:
79
107
using T = typename Functor::ELEMENT_TYPE;
80
108
void Compute (const framework::ExecutionContext& context) const override {
81
- auto * Out = context.Input <framework::Tensor>(" Out" );
82
- auto * dOut =
83
- context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
84
- auto * dX = context.Output <framework::Tensor>(framework::GradVarName (" X" ));
109
+ auto out_var = context.InputVar (" Out" );
110
+ auto out_grad_var = context.InputVar (framework::GradVarName (" Out" ));
111
+ auto x_grad_var = context.OutputVar (framework::GradVarName (" X" ));
112
+ PADDLE_ENFORCE (out_var != nullptr ,
113
+ " Cannot get input Variable Out, variable name = %s" ,
114
+ context.op ().Input (" Out" ));
115
+ PADDLE_ENFORCE (out_grad_var != nullptr ,
116
+ " Cannot get input Variable %s, variable name = %s" ,
117
+ framework::GradVarName (" Out" ),
118
+ context.op ().Input (framework::GradVarName (" Out" )));
119
+ PADDLE_ENFORCE (x_grad_var != nullptr ,
120
+ " Cannot get output Variable %s, variable name = %s" ,
121
+ framework::GradVarName (" X" ),
122
+ context.op ().Output (framework::GradVarName (" X" )));
123
+
124
+ framework::Tensor Out, dOut, *dX;
125
+ if (CanBeUsedBySelectedRows.count (context.op ().Type ())) {
126
+ Out = detail::Ref (
127
+ paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar (*out_var),
128
+ " Cannot get input Tensor Out, variable name = %s" ,
129
+ context.op ().Input (" Out" ));
130
+ dOut =
131
+ detail::Ref (paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar (
132
+ *out_grad_var),
133
+ " Cannot get input Tensor %s, variable name = %s" ,
134
+ framework::GradVarName (" Out" ),
135
+ context.op ().Input (framework::GradVarName (" Out" )));
136
+ dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar (
137
+ x_grad_var);
138
+ } else {
139
+ Out = detail::Ref (context.Input <framework::Tensor>(" Out" ),
140
+ " Cannot get input Tensor Out, variable name = %s" ,
141
+ context.op ().Input (" Out" ));
142
+ dOut = detail::Ref (
143
+ context.Input <framework::Tensor>(framework::GradVarName (" Out" )),
144
+ " Cannot get input Tensor %s, variable name = %s" ,
145
+ framework::GradVarName (" Out" ),
146
+ context.op ().Input (framework::GradVarName (" Out" )));
147
+ dX = context.Output <framework::Tensor>(framework::GradVarName (" X" ));
148
+ }
149
+ PADDLE_ENFORCE (dX != nullptr ,
150
+ " Cannot get output tensor %s, variable name = %s" ,
151
+ framework::GradVarName (" X" ),
152
+ context.op ().Output (framework::GradVarName (" X" )));
85
153
dX->mutable_data <T>(context.GetPlace ());
86
154
87
- auto dout = framework::EigenVector<T>::Flatten (* dOut);
88
- auto out = framework::EigenVector<T>::Flatten (* Out);
155
+ auto dout = framework::EigenVector<T>::Flatten (dOut);
156
+ auto out = framework::EigenVector<T>::Flatten (Out);
89
157
auto dx = framework::EigenVector<T>::Flatten (*dX);
90
158
auto * place =
91
159
context.template device_context <DeviceContext>().eigen_device ();
@@ -96,8 +164,19 @@ class ActivationGradKernel
96
164
}
97
165
bool inplace = functor.Inplace ();
98
166
if (!inplace) {
99
- auto * X = context.Input <framework::Tensor>(" X" );
100
- auto x = framework::EigenVector<T>::Flatten (*X);
167
+ auto x_var = context.InputVar (" X" );
168
+ PADDLE_ENFORCE (x_var != nullptr ,
169
+ " Cannot get input tensor X, variable name = %s" ,
170
+ context.op ().Input (" X" ));
171
+ framework::Tensor X;
172
+ if (CanBeUsedBySelectedRows.count (context.op ().Type ())) {
173
+ X = detail::Ref (
174
+ paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar (*x_var));
175
+ } else {
176
+ X = detail::Ref (context.Input <framework::Tensor>(" X" ));
177
+ }
178
+
179
+ auto x = framework::EigenVector<T>::Flatten (X);
101
180
functor (*place, x, out, dout, dx);
102
181
} else {
103
182
VLOG (10 ) << " Inplace activation " ;
0 commit comments