
Commit ff92b6b

Merge pull request #12531 from tensor-tang/refine/op/gru
Refine gru cpu forward
2 parents d4d8f83 + c588c64 commit ff92b6b

File tree

6 files changed: +373 −87 lines changed

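
For reference, the forward step that the GRU kernels touched by this commit evaluate is the standard GRU recurrence, sketched below in LaTeX. This is an editorial summary, not part of the diff: the exact gate ordering, bias handling, and sign conventions are those of the `activation` / `gate_activation` attributes and of paddle/fluid/operators/math/detail/gru_kernel.h, and the interpolation in the last line may be written with u_t and (1 - u_t) swapped depending on convention.

% x_u, x_r, x_c: input projections already accumulated into BatchGate
% W_u, W_r: gate weights, W_c: state weight, \odot: elementwise product
\begin{aligned}
u_t &= \sigma_g\left(x_u + h_{t-1} W_u\right) \\
r_t &= \sigma_g\left(x_r + h_{t-1} W_r\right) \\
\tilde{h}_t &= \sigma_n\left(x_c + (r_t \odot h_{t-1})\, W_c\right) \\
h_t &= u_t \odot \tilde{h}_t + (1 - u_t) \odot h_{t-1}
\end{aligned}

Here \sigma_g is the gate activation and \sigma_n the candidate (node) activation.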

paddle/fluid/operators/gru_op.cc

Lines changed: 159 additions & 3 deletions

@@ -14,6 +14,11 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/gru_op.h"
 #include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
+
+DECLARE_int32(paddle_num_threads);
 
 namespace paddle {
 namespace operators {
@@ -211,16 +216,167 @@ class GRUGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
+class GRUCPUKernel : public framework::OpKernel<T> {
+ public:
+  void BatchCompute(const framework::ExecutionContext& context) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* input = context.Input<LoDTensor>("Input");
+    auto* h0 = context.Input<Tensor>("H0");
+    auto* weight = context.Input<Tensor>("Weight");
+    const T* weight_data = weight->data<T>();
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
+    batch_gate->mutable_data<T>(context.GetPlace());
+    auto* batch_reset_hidden_prev =
+        context.Output<LoDTensor>("BatchResetHiddenPrev");
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
+    batch_hidden->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<LoDTensor>("Hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    auto hidden_dims = hidden->dims();
+
+    bool is_reverse = context.Attr<bool>("is_reverse");
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
+
+    if (bias) {
+      math::RowwiseAdd<DeviceContext, T> add_bias;
+      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
+    }
+
+    int frame_size = hidden_dims[1];
+    math::GRUMetaValue<T> gru_value;
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
+        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+    Tensor ordered_h0;
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
+    if (h0) {
+      // Since the batch computing for GRU reorders the input sequences
+      // according to their length. The initialized cell state also needs
+      // to reorder.
+      ReorderInitState<DeviceContext, T>(
+          context.template device_context<DeviceContext>(), *h0, order,
+          &ordered_h0, true);
+      gru_value.prev_out_value = ordered_h0.data<T>();
+    } else {
+      gru_value.prev_out_value = nullptr;
+    }
+    auto batch_starts = batch_gate->lod()[0];
+    size_t seq_len = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
+
+#ifdef PADDLE_WITH_MKLML
+    // use MKL packed to speedup GEMM
+    if (FLAGS_paddle_num_threads >= 4) {
+      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                       frame_size * 2 /*width of weight*/,
+                                       frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_gate);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                     packed_gate);
+      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                        frame_size /*width of weight*/,
+                                        frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_state);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                     frame_size, T(1.0), gru_value.state_weight, frame_size,
+                     packed_state);
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
+              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
+              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
+        }
+
+        math::detail::forward_reset_output(
+            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_gate);
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+              gru_value.reset_output_value, frame_size, packed_state,
+              frame_size, T(1), gru_value.gate_value + frame_size * 2,
+              frame_size * 3);
+        }
+
+        math::detail::forward_final_output(
+            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_node);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+
+      blas.GEMM_FREE(packed_gate);
+      blas.GEMM_FREE(packed_state);
+    } else {
+#endif
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        math::GRUUnitFunctor<DeviceContext, T>::compute(
+            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+            active_gate);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
+#ifdef PADDLE_WITH_MKLML
+    }
+#endif
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batch_hidden->set_lod(batch_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden);
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    BatchCompute(context);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(gru_grad, ops::GRUGradOp);
-REGISTER_OP_CPU_KERNEL(
-    gru, ops::GRUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
+                       ops::GRUCPUKernel<double>);
 REGISTER_OP_CPU_KERNEL(
     gru_grad, ops::GRUGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GRUGradKernel<paddle::platform::CPUDeviceContext, double>);
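
The MKLML branch above drives MKL's packed GEMM through PaddlePaddle's Blas wrapper (GEMM_ALLOC, GEMM_PACK, GEMM_COMPUTE, GEMM_FREE). Purely as an illustration of that idiom, and not as code from this commit, the standalone C++ sketch below calls the underlying cblas_sgemm_* routines directly; the matrix sizes, loop count, and fill values are made-up placeholders.

// Minimal sketch of MKL packed GEMM: pack the constant weight matrix B once,
// then reuse the packed form for every time step's small GEMM, as the GRU
// loop does per step. Build against MKL (e.g. -lmkl_rt).
#include <mkl.h>
#include <vector>

int main() {
  const MKL_INT m = 4;   // rows of A / C   (think: current batch size)
  const MKL_INT k = 8;   // cols of A, rows of B (think: frame_size)
  const MKL_INT n = 16;  // cols of B / C   (think: frame_size * 2)

  std::vector<float> A(m * k, 1.0f);  // activations, change every step
  std::vector<float> B(k * n, 0.5f);  // weights, constant across steps
  std::vector<float> C(m * n, 0.0f);  // accumulator (beta = 1 adds into it)

  // Pack B once into MKL's internal layout optimized for repeated reuse.
  float* Bp = cblas_sgemm_alloc(CblasBMatrix, m, n, k);
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, m, n, k,
                   /*alpha=*/1.0f, B.data(), /*ldb=*/n, Bp);

  // Reuse the packed weights for many GEMMs: C += A * Bp at each step.
  for (int step = 0; step < 10; ++step) {
    cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, m, n, k,
                        A.data(), /*lda=*/k, Bp, /*ldb=*/n,
                        /*beta=*/1.0f, C.data(), /*ldc=*/n);
  }

  cblas_sgemm_free(Bp);
  return 0;
}

Packing pays off in this setting because the weight matrices are reused at every time step while the activations change, so the internal reordering cost is paid once; the kernel above only takes this path when built with PADDLE_WITH_MKLML and FLAGS_paddle_num_threads >= 4, and otherwise falls back to math::GRUUnitFunctor.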

paddle/fluid/operators/gru_op.cu.cc

Lines changed: 90 additions & 0 deletions

@@ -14,6 +14,96 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/gru_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class GRUKernel : public framework::OpKernel<T> {
+ public:
+  void BatchCompute(const framework::ExecutionContext& context) const {
+    auto* input = context.Input<LoDTensor>("Input");
+    auto* h0 = context.Input<Tensor>("H0");
+    auto* weight = context.Input<Tensor>("Weight");
+    const T* weight_data = weight->data<T>();
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
+    batch_gate->mutable_data<T>(context.GetPlace());
+    auto* batch_reset_hidden_prev =
+        context.Output<LoDTensor>("BatchResetHiddenPrev");
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
+    batch_hidden->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<LoDTensor>("Hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    auto hidden_dims = hidden->dims();
+
+    bool is_reverse = context.Attr<bool>("is_reverse");
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
+
+    if (bias) {
+      math::RowwiseAdd<DeviceContext, T> add_bias;
+      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
+    }
+
+    int frame_size = hidden_dims[1];
+    math::GRUMetaValue<T> gru_value;
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
+        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+    Tensor ordered_h0;
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
+    if (h0) {
+      // Since the batch computing for GRU reorders the input sequences
+      // according to their length. The initialized cell state also needs
+      // to reorder.
+      ReorderInitState<DeviceContext, T>(
+          context.template device_context<DeviceContext>(), *h0, order,
+          &ordered_h0, true);
+      gru_value.prev_out_value = ordered_h0.data<T>();
+    } else {
+      gru_value.prev_out_value = nullptr;
+    }
+    auto batch_starts = batch_gate->lod()[0];
+    size_t num_batch = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
+    for (size_t n = 0; n < num_batch; n++) {
+      int bstart = static_cast<int>(batch_starts[n]);
+      int bend = static_cast<int>(batch_starts[n + 1]);
+      int cur_batch_size = bend - bstart;
+
+      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
+      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+      gru_value.output_value = hidden_t.data<T>();
+      gru_value.gate_value = gate_t.data<T>();
+      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+      math::GRUUnitFunctor<DeviceContext, T>::compute(
+          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+          active_gate);
+      gru_value.prev_out_value = gru_value.output_value;
+    }
+
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batch_hidden->set_lod(batch_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden);
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    BatchCompute(context);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     gru, ops::GRUKernel<paddle::platform::CUDADeviceContext, float>,
paddle/fluid/operators/gru_op.h

Lines changed: 0 additions & 84 deletions

@@ -37,90 +37,6 @@ inline void ReorderInitState(const DeviceContext& ctx,
   row_shuffle(ctx, src, index_lod, dst, indexed_src);
 }
 
-template <typename DeviceContext, typename T>
-class GRUKernel : public framework::OpKernel<T> {
- public:
-  void BatchCompute(const framework::ExecutionContext& context) const {
-    auto* input = context.Input<LoDTensor>("Input");
-    auto* h0 = context.Input<Tensor>("H0");
-    auto* weight = context.Input<Tensor>("Weight");
-    const T* weight_data = weight->data<T>();
-    auto* bias = context.Input<Tensor>("Bias");
-    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
-    batch_gate->mutable_data<T>(context.GetPlace());
-    auto* batch_reset_hidden_prev =
-        context.Output<LoDTensor>("BatchResetHiddenPrev");
-    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
-    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
-    batch_hidden->mutable_data<T>(context.GetPlace());
-    auto* hidden = context.Output<LoDTensor>("Hidden");
-    hidden->mutable_data<T>(context.GetPlace());
-
-    auto hidden_dims = hidden->dims();
-
-    bool is_reverse = context.Attr<bool>("is_reverse");
-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
-
-    if (bias) {
-      math::RowwiseAdd<DeviceContext, T> add_bias;
-      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
-    }
-
-    int frame_size = hidden_dims[1];
-    math::GRUMetaValue<T> gru_value;
-    gru_value.gate_weight = const_cast<T*>(weight_data);
-    gru_value.state_weight =
-        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
-    Tensor ordered_h0;
-
-    framework::Vector<size_t> order(batch_gate->lod()[2]);
-
-    if (h0) {
-      // Since the batch computing for GRU reorders the input sequences
-      // according to their length. The initialized cell state also needs
-      // to reorder.
-      ReorderInitState<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), *h0, order,
-          &ordered_h0, true);
-      gru_value.prev_out_value = ordered_h0.data<T>();
-    } else {
-      gru_value.prev_out_value = nullptr;
-    }
-    auto batch_starts = batch_gate->lod()[0];
-    size_t num_batch = batch_starts.size() - 1;
-    auto active_node = math::detail::GetActivationType(
-        context.Attr<std::string>("activation"));
-    auto active_gate = math::detail::GetActivationType(
-        context.Attr<std::string>("gate_activation"));
-    for (size_t n = 0; n < num_batch; n++) {
-      int bstart = static_cast<int>(batch_starts[n]);
-      int bend = static_cast<int>(batch_starts[n + 1]);
-      int cur_batch_size = bend - bstart;
-
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
-      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
-      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-      gru_value.output_value = hidden_t.data<T>();
-      gru_value.gate_value = gate_t.data<T>();
-      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
-      math::GRUUnitFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-          active_gate);
-      gru_value.prev_out_value = gru_value.output_value;
-    }
-
-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
-    batch_hidden->set_lod(batch_gate->lod());
-    to_seq(dev_ctx, *batch_hidden, hidden);
-  }
-
-  void Compute(const framework::ExecutionContext& context) const override {
-    BatchCompute(context);
-  }
-};
-
 template <typename DeviceContext, typename T>
 class GRUGradKernel : public framework::OpKernel<T> {
  public: