Commit 36ffc8b

Cherry pick seq2seq api from #19820 (#20555)

* Add seq2seq api related code (#19820)
* fix expand bug (#20340)
* fix expand bug test=develop
* fix style test=develop
* fix style test=develop
* fix style test=develop
* fix style test=develop
* Fix the assign data check (#20564)
* Fix the assign data check. test=develop
* Fix test_assign_op.py. test=develop
* Update API.spec

1 parent: c8de728

28 files changed: 2,545 additions and 91 deletions

paddle/fluid/API.spec (+46, -12)

Large diffs are not rendered by default.

paddle/fluid/operators/assign_op.cc (+4, -2)

@@ -154,10 +154,12 @@ REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
                   ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
                                ops::AssignKernel, int, ops::AssignKernel,
-                               int64_t, ops::AssignKernel);
+                               int64_t, ops::AssignKernel, bool,
+                               ops::AssignKernel);
 
 #ifdef PADDLE_WITH_CUDA
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
                                 ops::AssignKernel, int, ops::AssignKernel,
-                                int64_t, ops::AssignKernel);
+                                int64_t, ops::AssignKernel, bool,
+                                ops::AssignKernel);
 #endif

paddle/fluid/operators/expand_op.cc (+4, -1)

@@ -228,8 +228,11 @@ REGISTER_OP_CPU_KERNEL(
     expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ExpandKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ExpandKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ExpandKernel<paddle::platform::CPUDeviceContext, bool>);
 REGISTER_OP_CPU_KERNEL(
     expand_grad,
     ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/expand_op.cu (+4, -1)

@@ -18,8 +18,11 @@ REGISTER_OP_CUDA_KERNEL(
     expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ExpandKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, bool>);
 REGISTER_OP_CUDA_KERNEL(
     expand_grad,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int64_t>);

paddle/fluid/operators/fill_constant_batch_size_like_op.cc (+8, -1)

@@ -38,6 +38,11 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
         .SetDefault(framework::proto::VarType::FP32);
     AddAttr<float>("value", "default 0. The value to be filled")
         .SetDefault(0.0f);
+    AddAttr<bool>("force_cpu",
+                  "(bool, default false) Force fill output variable to cpu "
+                  "memory. Otherwise, fill output variable to the running "
+                  "device")
+        .SetDefault(false);
     AddComment(R"DOC(
 This function creates a tensor of specified *shape*, *dtype* and batch size,
 and initializes this with a constant supplied in *value*. The batch size is
@@ -65,4 +70,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
                                            int>,
     ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
-                                           int64_t>);
+                                           int64_t>,
+    ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
+                                           bool>);

paddle/fluid/operators/fill_constant_batch_size_like_op.cu.cc (+3, -1)

@@ -25,4 +25,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
                                            int>,
     ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
-                                           int64_t>);
+                                           int64_t>,
+    ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
+                                           bool>);

paddle/fluid/operators/fill_constant_batch_size_like_op.h (+14, -5)

@@ -23,6 +23,11 @@ template <typename DeviceContext, typename T>
 class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto data_type =
+        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
+    auto value = ctx.Attr<float>("value");
+    auto force_cpu = ctx.Attr<bool>("force_cpu");
+
     auto* out = ctx.Output<framework::Tensor>("Out");
     auto* in = ctx.Input<framework::LoDTensor>("Input");
     if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
@@ -32,12 +37,16 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
       odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
       out->mutable_data<T>(odims, ctx.GetPlace());
     }
-    out->mutable_data<T>(ctx.GetPlace());
-    auto value = ctx.Attr<float>("value");
 
-    math::SetConstant<DeviceContext, T> setter;
-    setter(ctx.template device_context<DeviceContext>(), out,
-           static_cast<T>(value));
+    if (force_cpu) {
+      out->mutable_data(platform::CPUPlace(), data_type);
+    } else {
+      out->mutable_data(ctx.GetPlace(), data_type);
+    }
+
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& dev_ctx = *pool.Get(ctx.GetPlace());
+    math::set_constant(dev_ctx, out, value);
   }
 };

paddle/fluid/operators/fill_constant_op.cu.cc (+1)

@@ -19,4 +19,5 @@ REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>,
                         ops::FillConstantKernel<double>,
                         ops::FillConstantKernel<int64_t>,
                         ops::FillConstantKernel<int>,
+                        ops::FillConstantKernel<bool>,
                         ops::FillConstantKernel<paddle::platform::float16>);

paddle/fluid/operators/gather_nd_op.cc (+8, -3)

@@ -60,8 +60,13 @@ class GatherNdOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+    auto* x = ctx.Input<Tensor>("X");
+    const auto& x_type = x->type();
+    return framework::OpKernelType(
+        x_type,
+        x_type == framework::proto::VarType::BOOL
+            ? x->place()  // to be consistent with compare and logical ops
+            : ctx.device_context().GetPlace());
   }
 };
 
@@ -173,7 +178,7 @@ REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
 REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
                        ops::GatherNdOpKernel<double>,
                        ops::GatherNdOpKernel<int64_t>,
-                       ops::GatherNdOpKernel<int>,
+                       ops::GatherNdOpKernel<int>, ops::GatherNdOpKernel<bool>,
                        ops::GatherNdOpKernel<uint8_t>);
 
 REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel<float>,
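
With a bool kernel registered (and the kernel place taken from the input when the dtype is BOOL, matching the compare and logical ops), gather_nd can now index boolean tensors such as the per-beam finished mask tracked during decoding. A minimal illustrative sketch of the Python side, assuming the fluid.layers.gather_nd and fluid.layers.data wrappers of Fluid 1.6; this snippet is not part of the diff:

import paddle.fluid as fluid

# Boolean tensor (e.g. per-beam `finished` flags) and integer indices into it.
finished = fluid.layers.data(name='finished', shape=[4, 3], dtype='bool',
                             append_batch_size=False)
index = fluid.layers.data(name='index', shape=[2, 2], dtype='int32',
                          append_batch_size=False)

# Before this commit gather_nd had no bool kernel; with the registration above
# the op can gather from boolean inputs directly.
picked = fluid.layers.gather_nd(finished, index)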

paddle/fluid/operators/gather_nd_op.cu (+1)

@@ -95,6 +95,7 @@ REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
                         ops::GatherNdOpCUDAKernel<CUDA, double>,
                         ops::GatherNdOpCUDAKernel<CUDA, int64_t>,
                         ops::GatherNdOpCUDAKernel<CUDA, int>,
+                        ops::GatherNdOpCUDAKernel<CUDA, bool>,
                         ops::GatherNdOpCUDAKernel<CUDA, plat::float16>);
 
 REGISTER_OP_CUDA_KERNEL(gather_nd_grad,
paddle/fluid/operators/gather_tree_op.cc (new file, +78 lines)

/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_tree_op.h"

namespace paddle {
namespace operators {

class GatherTreeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Ids"),
                   "Input(Ids) of GatherTreeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Parents"),
                   "Input(Parents) of GatherTreeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of GatherTreeOp should not be null.");

    auto ids_dims = ctx->GetInputDim("Ids");
    auto parents_dims = ctx->GetInputDim("Parents");
    PADDLE_ENFORCE(ids_dims == parents_dims,
                   "The shape of Input(Parents) must be same with the shape of "
                   "Input(Ids).");
    ctx->SetOutputDim("Out", ids_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("Ids")->type(),
                                   ctx.device_context());
  }
};

class GatherTreeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Ids",
             "The Tensor with shape [length, batch_size, beam_size] containing "
             "the selected ids of all time steps.");
    AddInput("Parents",
             "The Tensor has the same shape as Ids and contains the parents "
             "corresponding to selected ids when searching among beams.");
    AddOutput(
        "Out",
        "A Tensor with shape [length, batch_size, beam_size] containing the "
        "full sequences. The sequences is collected by backtracing from the "
        "last time step of Ids.");
    AddComment(R"DOC(
GatherTree Operator.

Backtrace from the last time step and generate the full sequences by collecting beam search
selected ids.

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker);
REGISTER_OP_CPU_KERNEL(gather_tree, ops::GatherTreeOpKernel<int32_t>,
                       ops::GatherTreeOpKernel<int64_t>);
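
The Python side of the seq2seq API (among the 28 changed files, not rendered on this page) exposes this operator as a layer. A minimal usage sketch, assuming the wrapper is fluid.layers.gather_tree(ids, parents) as in the Fluid 1.6 release; purely for illustration:

import paddle.fluid as fluid

# ids / parents: [max_length, batch_size, beam_size] int64 tensors recorded
# while stepping a beam search (chosen token ids and their parent beam index).
ids = fluid.layers.data(name='ids', shape=[5, 2, 4], dtype='int64',
                        append_batch_size=False)
parents = fluid.layers.data(name='parents', shape=[5, 2, 4], dtype='int64',
                            append_batch_size=False)

# Backtrace from the last time step and assemble the full beams; the output
# has the same shape as `ids`, as enforced by GatherTreeOp::InferShape above.
final_sequences = fluid.layers.gather_tree(ids, parents)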
paddle/fluid/operators/gather_tree_op.cu (new file, +80 lines)

/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather_tree_op.h"

namespace paddle {
namespace operators {

#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

template <typename T>
__global__ void GatherTree(const T *ids_data, const T *parents_data,
                           T *out_data, const int64_t max_length,
                           const int64_t batch_size, const int64_t beam_size) {
  CUDA_1D_KERNEL_LOOP(i, batch_size * beam_size) {
    int batch = i / beam_size;
    int beam = i % beam_size;
    auto idx =
        (max_length - 1) * batch_size * beam_size + batch * beam_size + beam;
    out_data[idx] = ids_data[idx];
    auto parent = parents_data[idx];
    for (int step = max_length - 2; step >= 0; step--) {
      idx = step * batch_size * beam_size + batch * beam_size;
      out_data[idx + beam] = ids_data[idx + parent];
      parent = parents_data[idx + parent];
    }
  }
}

template <typename T>
class GatherTreeOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *ids = ctx.Input<Tensor>("Ids");
    auto *parents = ctx.Input<Tensor>("Parents");
    auto *out = ctx.Output<Tensor>("Out");

    const auto *ids_data = ids->data<T>();
    const auto *parents_data = parents->data<T>();
    auto *out_data = out->mutable_data<T>(ctx.GetPlace());

    auto &ids_dims = ids->dims();
    int64_t max_length = ids_dims[0];
    int64_t batch_size = ids_dims[1];
    int64_t beam_size = ids_dims[2];

    auto &dev_ctx = ctx.cuda_device_context();

    const int block = 512;
    int max_threads =
        std::min(static_cast<int64_t>(dev_ctx.GetMaxPhysicalThreadCount()),
                 batch_size * beam_size);
    const int grid = std::max(max_threads / block, 1);
    GatherTree<<<grid, block>>>(ids_data, parents_data, out_data, max_length,
                                batch_size, beam_size);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(gather_tree, ops::GatherTreeOpCUDAKernel<int32_t>,
                        ops::GatherTreeOpCUDAKernel<int64_t>);
paddle/fluid/operators/gather_tree_op.h (new file, +58 lines)

/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class GatherTreeOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *ids = ctx.Input<Tensor>("Ids");
    auto *parents = ctx.Input<Tensor>("Parents");
    auto *out = ctx.Output<Tensor>("Out");

    const auto *ids_data = ids->data<T>();
    const auto *parents_data = parents->data<T>();
    auto *out_data = out->mutable_data<T>(ctx.GetPlace());

    auto &ids_dims = ids->dims();
    auto max_length = ids_dims[0];
    auto batch_size = ids_dims[1];
    auto beam_size = ids_dims[2];

    for (int batch = 0; batch < batch_size; batch++) {
      for (int beam = 0; beam < beam_size; beam++) {
        auto idx = (max_length - 1) * batch_size * beam_size +
                   batch * beam_size + beam;
        out_data[idx] = ids_data[idx];
        auto parent = parents_data[idx];
        for (int step = max_length - 2; step >= 0; step--) {
          idx = step * batch_size * beam_size + batch * beam_size;
          out_data[idx + beam] = ids_data[idx + parent];
          parent = parents_data[idx + parent];
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
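
For readability, the backtrace performed by GatherTreeOpKernel::Compute above can be restated in NumPy. This sketch only mirrors the loop structure of the kernel and is not part of the commit:

import numpy as np

def gather_tree_ref(ids, parents):
    """NumPy restatement of the GatherTree backtrace.

    ids, parents: integer arrays of shape [max_length, batch_size, beam_size].
    Returns an array of the same shape holding the backtraced sequences.
    """
    max_length, batch_size, beam_size = ids.shape
    out = np.zeros_like(ids)
    for batch in range(batch_size):
        for beam in range(beam_size):
            # Start from the last time step, then follow parent pointers back.
            out[max_length - 1, batch, beam] = ids[max_length - 1, batch, beam]
            parent = parents[max_length - 1, batch, beam]
            for step in range(max_length - 2, -1, -1):
                out[step, batch, beam] = ids[step, batch, parent]
                parent = parents[step, batch, parent]
    return out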

paddle/fluid/operators/reduce_ops/reduce_all_op.cc (+3, -1)

@@ -14,7 +14,9 @@
 
 #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
 
-REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all);
+// kernel's device type is decided by input tensor place, to be consistent with
+// compare and logical ops
+REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all, UseInputPlace);
 REGISTER_OP_CPU_KERNEL(reduce_all,
                        ops::ReduceKernel<paddle::platform::CPUDeviceContext,
                                          bool, ops::AllFunctor>);

paddle/fluid/operators/reduce_ops/reduce_any_op.cc (+3, -1)

@@ -14,7 +14,9 @@
 
 #include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
 
-REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any);
+// kernel's device type is decided by input tensor place, to be consistent with
+// compare and logical ops
+REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any, UseInputPlace);
 REGISTER_OP_CPU_KERNEL(reduce_any,
                        ops::ReduceKernel<paddle::platform::CPUDeviceContext,
                                          bool, ops::AnyFunctor>);
