@@ -14,6 +14,11 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/gru_op.h"
 #include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
+
+DECLARE_int32(paddle_num_threads);
 
 namespace paddle {
 namespace operators {
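A note on the new DECLARE_int32(paddle_num_threads) line: this is the usual gflags split between defining a flag in one translation unit and declaring it everywhere else, which is what lets the kernel below branch on FLAGS_paddle_num_threads. The following is a minimal, hypothetical sketch of that pattern, not Paddle's actual files; the flag's real definition and default live elsewhere in the codebase.

// Illustrative gflags sketch (assumption: default of 1; not Paddle's real files).
#include <gflags/gflags.h>

DEFINE_int32(paddle_num_threads, 1,
             "Assumed default; the real definition lives in another file.");

// Any other translation unit would only declare it to see the same variable:
//   DECLARE_int32(paddle_num_threads);
bool UseMklPackedPath() {
  // The same kind of check gru_op.cc performs before taking the packed-GEMM path.
  return FLAGS_paddle_num_threads >= 4;
}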
@@ -211,16 +216,167 @@ class GRUGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
+class GRUCPUKernel : public framework::OpKernel<T> {
+ public:
+  void BatchCompute(const framework::ExecutionContext& context) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* input = context.Input<LoDTensor>("Input");
+    auto* h0 = context.Input<Tensor>("H0");
+    auto* weight = context.Input<Tensor>("Weight");
+    const T* weight_data = weight->data<T>();
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
+    batch_gate->mutable_data<T>(context.GetPlace());
+    auto* batch_reset_hidden_prev =
+        context.Output<LoDTensor>("BatchResetHiddenPrev");
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
+    batch_hidden->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<LoDTensor>("Hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    auto hidden_dims = hidden->dims();
+
+    bool is_reverse = context.Attr<bool>("is_reverse");
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
+
+    if (bias) {
+      math::RowwiseAdd<DeviceContext, T> add_bias;
+      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
+    }
+
+    int frame_size = hidden_dims[1];
+    math::GRUMetaValue<T> gru_value;
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
+        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+    Tensor ordered_h0;
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
+    if (h0) {
+      // Since the batch computation of GRU reorders the input
+      // sequences by length, the initial hidden state also needs
+      // to be reordered.
+      ReorderInitState<DeviceContext, T>(
+          context.template device_context<DeviceContext>(), *h0, order,
+          &ordered_h0, true);
+      gru_value.prev_out_value = ordered_h0.data<T>();
+    } else {
+      gru_value.prev_out_value = nullptr;
+    }
+    auto batch_starts = batch_gate->lod()[0];
+    size_t seq_len = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
+
+#ifdef PADDLE_WITH_MKLML
+    // use the MKL packed format to speed up GEMM
+    if (FLAGS_paddle_num_threads >= 4) {
+      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                       frame_size * 2 /*width of weight*/,
+                                       frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_gate);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                     packed_gate);
+      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                        frame_size /*width of weight*/,
+                                        frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_state);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                     frame_size, T(1.0), gru_value.state_weight, frame_size,
+                     packed_state);
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
+              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
+              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
+        }
+
+        math::detail::forward_reset_output(
+            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_gate);
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+              gru_value.reset_output_value, frame_size, packed_state,
+              frame_size, T(1), gru_value.gate_value + frame_size * 2,
+              frame_size * 3);
+        }
+
+        math::detail::forward_final_output(
+            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_node);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+
+      blas.GEMM_FREE(packed_gate);
+      blas.GEMM_FREE(packed_state);
+    } else {
+#endif
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        math::GRUUnitFunctor<DeviceContext, T>::compute(
+            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+            active_gate);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+#ifdef PADDLE_WITH_MKLML
+    }
+#endif
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batch_hidden->set_lod(batch_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden);
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    BatchCompute(context);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(gru_grad, ops::GRUGradOp);
-REGISTER_OP_CPU_KERNEL(
-    gru, ops::GRUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
+                       ops::GRUCPUKernel<double>);
 REGISTER_OP_CPU_KERNEL(
     gru_grad, ops::GRUGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GRUGradKernel<paddle::platform::CPUDeviceContext, double>);
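For anyone sanity-checking the blocked arithmetic in the kernel above, here is a self-contained scalar sketch of one GRU time step. It is an illustration of the math, not Paddle API. Assumptions are hedged in the comments: BatchGate is pre-filled with the input projection of width 3 * frame_size, Weight stores the D x 2D gate weights followed by the D x D candidate ("state") weights (as the state_weight offset in the kernel suggests), the gate order is [update | reset | candidate], activations are sigmoid/tanh, and the update-gate interpolation convention should be verified against math/detail/gru_kernel.h.

// Scalar reference for one GRU step (illustrative sketch, not Paddle code).
#include <cmath>
#include <vector>

// gate:   3 * D entries, pre-filled with x_t * W_x (+ bias), assumed layout [u | r | c].
// weight: D x 2D gate weights, then D x D state weights (row-major).
// h_prev: previous hidden state, D entries. Returns h_t.
std::vector<float> GruStepReference(std::vector<float> gate,
                                    const std::vector<float>& weight,
                                    const std::vector<float>& h_prev, int D) {
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };

  // First GEMM_COMPUTE in the packed path: gate[0 : 2D] += h_prev * gate_weight.
  for (int j = 0; j < 2 * D; ++j)
    for (int k = 0; k < D; ++k) gate[j] += h_prev[k] * weight[k * 2 * D + j];

  std::vector<float> u(D), r(D), reset_out(D), h(D);
  for (int i = 0; i < D; ++i) {
    u[i] = sigmoid(gate[i]);          // update gate
    r[i] = sigmoid(gate[D + i]);      // reset gate
    reset_out[i] = r[i] * h_prev[i];  // what forward_reset_output stores in BatchResetHiddenPrev
  }

  // Second GEMM_COMPUTE in the packed path: gate[2D : 3D] += reset_out * state_weight.
  const float* state_w = weight.data() + 2 * D * D;
  for (int j = 0; j < D; ++j)
    for (int k = 0; k < D; ++k) gate[2 * D + j] += reset_out[k] * state_w[k * D + j];

  for (int i = 0; i < D; ++i) {
    float c = std::tanh(gate[2 * D + i]);         // candidate state (forward_final_output)
    h[i] = (1.0f - u[i]) * h_prev[i] + u[i] * c;  // assumed interpolation convention
  }
  return h;
}

The packed MKL branch performs the same arithmetic in blocked form: GEMM_ALLOC/GEMM_PACK prepack the two weight matrices once per call, the two GEMM_COMPUTE calls replace the inner loops above for a whole batch of rows (beta = 1 accumulates into the pre-filled gates), and GEMM_FREE releases the packed buffers. When MKLML is unavailable or FLAGS_paddle_num_threads is small, math::GRUUnitFunctor performs the equivalent unpacked computation.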