Commit 6a10e60

[CustomOP Optional] CustomOP supports optional vector<Tensor> input (#51973)
Parent: 5754aae

File tree

5 files changed: 343 additions, 100 deletions

paddle/fluid/framework/custom_operator.cc

Lines changed: 7 additions & 1 deletion
@@ -174,14 +174,20 @@ static void RunKernelFunc(
         custom_t.set_impl(std::make_shared<phi::DenseTensor>(*x));
         custom_vec_in.emplace_back(custom_t);
       }
-    } else {  // optional inputs, `custom_vec_in` is empty
+    } else {  // optional inputs.
      PADDLE_ENFORCE(
          detail::IsOptionalVar(in_name),
          phi::errors::NotFound("Your custom operator's KernelFunc cannot "
                                "find input parameter `%s`",
                                in_name));
      VLOG(3) << "Custom Operator: KernelFunc's vector input " << in_name
              << " is optional dtype with None input";
+      // NOTE(HongyuJia): In dygraph mode, we can not distinguish Tensor and
+      // vector<Tensor> when user inputs None, so dygraph mode appends one
+      // un-initialized Tensor to CustomOpKernelContext. To be compatible with
+      // dygraph mode, `custom_vec_in` also emplace_back one un-initialized
+      // tensor here.
+      custom_vec_in.emplace_back(paddle::Tensor());
     }
     kernel_ctx.EmplaceBackInputs(std::move(custom_vec_in));
   } else {  // inputs Tensor
paddle/fluid/pybind/pybind.cc

Lines changed: 3 additions & 1 deletion
@@ -1060,7 +1060,9 @@ PYBIND11_MODULE(libpaddle, m) {
            if (PyList_Check(obj) || PyTuple_Check(obj)) {
              self.EmplaceBackInputs(
                  std::move(CastPyArg2VectorOfTensor(obj, 1)));
-            } else if (obj == Py_None) {  // check optional Tensor
+            } else if (obj == Py_None) {
+              // Check optional Tensor, use one un-initialized tensor to
+              // indicate both Tensor and vector<Tensor> inputs
              self.EmplaceBackInput(std::move(paddle::Tensor()));
            } else {
              self.EmplaceBackInput(std::move(CastPyArg2Tensor(obj, 1)));
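Both hunks above encode the same convention: when the user passes None for an optional input, one default-constructed (un-initialized) paddle::Tensor is recorded in CustomOpKernelContext, because at this boundary a None intended as Tensor cannot be told apart from a None intended as vector<Tensor>. A minimal sketch of the probe the dispatch layer relies on, using only paddle::Tensor's default constructor and is_initialized() as they appear in this commit (the include assumes a custom-op build where paddle/extension.h is available):

#include "paddle/extension.h"

void placeholder_probe() {
  // None from Python becomes exactly one un-initialized Tensor in the context.
  paddle::Tensor placeholder;  // same as paddle::Tensor()
  // Kernel-side helpers read "un-initialized" as "optional input was absent".
  bool optional_absent = !placeholder.is_initialized();
  (void)optional_absent;  // true for a None input
}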

paddle/phi/api/ext/op_meta_info.h

Lines changed: 69 additions & 2 deletions
@@ -241,6 +241,26 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
     }
   };
 
+  template <typename... Tail>
+  struct ComputeCallHelper<const paddle::optional<std::vector<paddle::Tensor>>&,
+                           Tail...> {
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
+      auto& range = ctx->InputRangeAt(in_idx);
+      auto arg = ctx->InputsBetween(range.first, range.second);
+      if (arg.empty() || !arg[0].is_initialized()) {
+        ComputeCallHelper<Tail...>::
+            template Compute<in_idx + 1, attr_idx, out_idx>(
+                ctx, pargs..., paddle::none);
+      } else {
+        ComputeCallHelper<
+            Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
+                                                                      pargs...,
+                                                                      arg);
+      }
+    }
+  };
+
   PD_SPECIALIZE_ComputeCallHelper(bool);
   PD_SPECIALIZE_ComputeCallHelper(int);
   PD_SPECIALIZE_ComputeCallHelper(float);
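This specialization follows the recursive ComputeCallHelper pattern used throughout op_meta_info.h: each specialization consumes one parameter type of the user kernel, materializes the matching value from CustomOpKernelContext (here, paddle::none or the tensor vector), appends it to the accumulated argument pack, and recurses on the remaining types. A stripped-down, standalone sketch of that recursion, with simplified names and no Paddle dependencies (not the actual implementation):

#include <iostream>
#include <utility>

// Base case: every parameter has been materialized; invoke the "kernel".
template <typename... Args>
struct Helper {
  template <typename... Prev>
  static void Compute(Prev&&... prev) {
    std::cout << "invoking kernel with " << sizeof...(prev) << " args\n";
  }
};

// Recursive case: peel one parameter type, materialize it, recurse.
template <typename Head, typename... Tail>
struct Helper<Head, Tail...> {
  template <typename... Prev>
  static void Compute(Prev&&... prev) {
    Head value{};  // in Paddle this is fetched from CustomOpKernelContext
    Helper<Tail...>::Compute(std::forward<Prev>(prev)..., value);
  }
};

int main() {
  Helper<int, double, bool>::Compute();  // prints: invoking kernel with 3 args
  return 0;
}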
@@ -486,6 +506,33 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
     }
   };
 
+  template <typename... Tail>
+  struct InferShapeCallHelper<
+      const paddle::optional<std::vector<std::vector<int64_t>>>&,
+      Tail...> {
+    template <int in_idx,
+              int vec_in_idx,
+              int attr_idx,
+              typename... PreviousArgs>
+    static Return InferShape(
+        const std::vector<std::vector<int64_t>>& input_shapes,
+        const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
+        const std::vector<paddle::any>& attrs,
+        const PreviousArgs&... pargs) {
+      const std::vector<std::vector<int64_t>>& arg =
+          vec_input_shapes[vec_in_idx];
+      if (arg.empty()) {
+        return InferShapeCallHelper<Tail...>::
+            template InferShape<in_idx, vec_in_idx + 1, attr_idx>(
+                input_shapes, vec_input_shapes, attrs, pargs..., paddle::none);
+      } else {
+        return InferShapeCallHelper<Tail...>::
+            template InferShape<in_idx, vec_in_idx + 1, attr_idx>(
+                input_shapes, vec_input_shapes, attrs, pargs..., arg);
+      }
+    }
+  };
+
   // NOTE(chenweihang): Used to be compatible with the 2.0.1 released
   // interface, and will be deprecated in the future
   PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(std::vector<int64_t>);
@@ -593,8 +640,7 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
   PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(const std::vector<DataType>&);
 
   template <typename... Tail>
-  struct InferDtypeCallHelper<const paddle::optional<paddle::DataType>&,
-                              Tail...> {
+  struct InferDtypeCallHelper<const paddle::optional<DataType>&, Tail...> {
     template <int in_idx, int vec_in_idx, typename... PreviousArgs>
     static Return InferDtype(
         const std::vector<DataType>& input_dtypes,
@@ -613,6 +659,27 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
     }
   };
 
+  template <typename... Tail>
+  struct InferDtypeCallHelper<const paddle::optional<std::vector<DataType>>&,
+                              Tail...> {
+    template <int in_idx, int vec_in_idx, typename... PreviousArgs>
+    static Return InferDtype(
+        const std::vector<DataType>& input_dtypes,
+        const std::vector<std::vector<DataType>>& vec_input_dtypes,
+        const PreviousArgs&... pargs) {
+      const std::vector<DataType>& arg = vec_input_dtypes[vec_in_idx];
+      if (arg.empty()) {
+        return InferDtypeCallHelper<Tail...>::
+            template InferDtype<in_idx, vec_in_idx + 1>(
+                input_dtypes, vec_input_dtypes, pargs..., paddle::none);
+      } else {
+        return InferDtypeCallHelper<Tail...>::
+            template InferDtype<in_idx, vec_in_idx + 1>(
+                input_dtypes, vec_input_dtypes, pargs..., arg);
+      }
+    }
+  };
+
   // NOTE(chenweihang): Used to be compatible with the 2.0.1 released
   // interface, and will be deprecated in the future
   PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(DataType);
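All three new helpers resolve an optional the same way: an absent input (empty slot or un-initialized placeholder) becomes paddle::none, a present one is forwarded unchanged, and the user-side function receives a paddle::optional it can test and dereference. A small sketch of that contract as user code sees it; paddle::optional mirrors std::optional/boost::optional semantics, and the header path here is an assumption based on Paddle's source layout:

#include <cassert>
#include <cstdint>
#include <vector>

#include "paddle/utils/optional.h"  // assumed location of paddle::optional

int main() {
  // Absent optional input: the helpers pass paddle::none.
  paddle::optional<std::vector<int64_t>> maybe_shape = paddle::none;
  assert(!maybe_shape);  // operator bool is false, as tested by `if (y_shape)`
  // Present optional input: the helpers forward the value unchanged.
  maybe_shape = std::vector<int64_t>{2, 3};
  assert(maybe_shape && maybe_shape->size() == 2);  // -> and * are valid
  return 0;
}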

python/paddle/fluid/tests/custom_op/custom_optional.cc

Lines changed: 105 additions & 24 deletions
@@ -19,21 +19,19 @@
 #include "paddle/extension.h"
 
 template <typename data_t>
-void add_forward_kernel(const data_t* x_data,
-                        const data_t* y_data,
-                        data_t* out_data,
-                        int64_t numel) {
+void add_one_pointer(const data_t* x_data, data_t* out_data, int64_t numel) {
   for (size_t i = 0; i < numel; ++i) {
-    out_data[i] = x_data[i] + y_data[i];
+    out_data[i] += x_data[i];
   }
 }
 
 template <typename data_t>
-void add_backward_kernel(data_t* x_grad_data,
-                         const data_t* out_grad_data,
-                         int64_t numel) {
+void add_two_pointers(const data_t* x_data,
+                      const data_t* y_data,
+                      data_t* out_data,
+                      int64_t numel) {
   for (size_t i = 0; i < numel; ++i) {
-    x_grad_data[i] += out_grad_data[i];
+    out_data[i] = x_data[i] + y_data[i];
   }
 }
 
@@ -53,12 +51,12 @@ std::vector<paddle::Tensor> AddForward(
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "AddForward", ([&] {
         if (y) {
-          add_forward_kernel<data_t>(x.data<data_t>(),
-                                     y->data<data_t>(),
-                                     out.data<data_t>(),
-                                     x.size());
+          add_two_pointers<data_t>(x.data<data_t>(),
+                                   y->data<data_t>(),
+                                   out.data<data_t>(),
+                                   x.size());
         } else {
-          add_forward_kernel<data_t>(
+          add_two_pointers<data_t>(
               x.data<data_t>(), x.data<data_t>(), out.data<data_t>(), x.size());
         }
       }));
@@ -69,7 +67,6 @@ std::vector<paddle::DataType> AddInferDtype(
     const paddle::DataType& x_dtype,
     const paddle::optional<paddle::DataType>& y_dtype) {
   if (y_dtype) {
-    std::cout << "DEBUG AddInferDtype" << *y_dtype << std::endl;
     return {*y_dtype};
   }
   return {x_dtype};
@@ -98,18 +95,14 @@ std::vector<paddle::Tensor> AddBackward(
   PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
 
   paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
-  paddle::Tensor y_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
 
   PD_DISPATCH_FLOATING_TYPES(
       out_grad.type(), "AddBackward", ([&] {
-        add_backward_kernel<data_t>(
-            x_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
-        if (y) {
-          add_backward_kernel<data_t>(
-              y_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
-        } else {
-          add_backward_kernel<data_t>(
-              x_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
+        add_one_pointer<data_t>(
+            out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
+        if (!y) {
+          add_one_pointer<data_t>(
+              out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
         }
       }));
 
@@ -127,3 +120,91 @@ PD_BUILD_GRAD_OP(custom_add)
     .Inputs({"X", paddle::Optional("Y"), paddle::Grad("Out")})
     .Outputs({paddle::Grad("X")})
     .SetKernelFn(PD_KERNEL(AddBackward));
+
+/*
+if (y) {
+  out = x + y[0] + y[1] + ...;
+} else {
+  out = x + x;
+}
+*/
+std::vector<paddle::Tensor> AddVectorForward(
+    const paddle::Tensor& x,
+    const paddle::optional<std::vector<paddle::Tensor>>& y) {  // NOLINT
+  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
+  paddle::Tensor out = paddle::zeros(x.shape(), x.dtype(), x.place());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      x.type(), "AddVectorForward", ([&] {
+        if (y) {
+          add_one_pointer<data_t>(
+              x.data<data_t>(), out.data<data_t>(), out.size());
+          for (size_t i = 0; i < y->size(); ++i) {
+            add_one_pointer<data_t>(
+                y->at(i).data<data_t>(), out.data<data_t>(), out.size());
+          }
+        } else {
+          add_two_pointers<data_t>(
+              x.data<data_t>(), x.data<data_t>(), out.data<data_t>(), x.size());
+        }
+      }));
+  return {out};
+}
+
+std::vector<paddle::DataType> AddVectorInferDtype(
+    const paddle::DataType& x_dtype,
+    const paddle::optional<std::vector<paddle::DataType>>& y_dtype) {
+  if (y_dtype) {
+    return {y_dtype->at(0)};
+  }
+  return {x_dtype};
+}
+
+std::vector<std::vector<int64_t>> AddVectorInferShape(
+    const std::vector<int64_t>& x_shape,
+    const paddle::optional<std::vector<std::vector<int64_t>>>& y_shape) {
+  if (y_shape) {
+    return {y_shape->at(0)};
+  }
+  return {x_shape};
+}
+
+/*
+if (y) {
+  x_grad = out_grad;
+} else {
+  x_grad = out_grad + out_grad;
+}
+*/
+std::vector<paddle::Tensor> AddVectorBackward(
+    const paddle::Tensor& x,
+    const paddle::optional<std::vector<paddle::Tensor>>& y,
+    const paddle::Tensor& out_grad) {  // NOLINT
+  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
+
+  paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      out_grad.type(), "AddVectorBackward", ([&] {
+        add_one_pointer<data_t>(
+            out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
+        if (!y) {
+          add_one_pointer<data_t>(
+              out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
+        }
+      }));
+
+  return {x_grad};
+}
+
+PD_BUILD_OP(custom_add_vec)
+    .Inputs({"X", paddle::Optional(paddle::Vec("Y"))})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(AddVectorForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(AddVectorInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(AddVectorInferDtype));
+
+PD_BUILD_GRAD_OP(custom_add_vec)
+    .Inputs({"X", paddle::Optional(paddle::Vec("Y")), paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X")})
+    .SetKernelFn(PD_KERNEL(AddVectorBackward));
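To make the registered op's semantics concrete, here is the arithmetic from the comment blocks above worked through in plain C++ (standalone, no Paddle dependencies): with y present, out accumulates x plus every y[i] and x_grad equals out_grad; with y absent, out = x + x, so x_grad is 2 * out_grad, which is why AddVectorBackward adds out_grad into x_grad a second time.

#include <cassert>
#include <vector>

int main() {
  std::vector<double> x{1.0, 2.0};
  std::vector<std::vector<double>> y{{10.0, 20.0}, {100.0, 200.0}};

  // Forward with y present: out = x + y[0] + y[1]
  std::vector<double> out(x);
  for (const auto& yi : y)
    for (size_t i = 0; i < out.size(); ++i) out[i] += yi[i];
  assert(out[0] == 111.0 && out[1] == 222.0);

  // Backward: x_grad = out_grad with y present, 2 * out_grad with y absent.
  std::vector<double> out_grad{1.0, 1.0};
  std::vector<double> x_grad_with_y = out_grad;
  std::vector<double> x_grad_without_y{2 * out_grad[0], 2 * out_grad[1]};
  assert(x_grad_with_y[0] == 1.0 && x_grad_without_y[1] == 2.0);
  return 0;
}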
