Skip to content

Commit 6b0ff15

Browse files
authored
[CustomOP Inplace] Update custom operator inplace document, ignore Dtype and Shape function (PaddlePaddle#5765)
1 parent 053536a commit 6b0ff15

File tree

1 file changed

+5
-15
lines changed

1 file changed

+5
-15
lines changed

docs/guides/custom_op/new_cpp_op_cn.md

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,9 @@ PD_BUILD_GRAD_OP(custom_inplace_relu)
12321232

12331233
3. 一方面,做 inplace 映射的输出 Tensor,不再作为函数的返回值,如果此时函数没有需要返回的 Tensor,函数的输出类型应为 `void` ;另一方面,其他没有做 inplace 映射的输出 Tensor,仍需作为返回值显式输出,此时函数的输出类型仍为 `std::vector<paddle::Tensor>`。例如 `ReluCpuInplaceForward` 函数中不再显式输出 Tensor,因此函数返回类型为 `void`
12341234

1235-
4. 框架会对算子的输入、输出映射做基本的正确性检查(`SetInplaceMap`中指定的输入 Tensor 命名与 `Inputs` 中定义的名称一致;输出 Tensor 命名与 `Outputs` 中定义的名称一致),因此 `SetInplaceMap` 必须在 `Inputs``Outputs` 之后指定。
1235+
4. 框架会自动为 inplace 的输入输出做 Shape 和 Dtype 映射。因此 `InferShape``InferDtype` 函数只需要返回未被 inplace 映射的输出类型。如果没有需要返回的值,可以不设置这两个函数。
1236+
1237+
5. 框架会对算子的输入、输出映射做基本的正确性检查(`SetInplaceMap`中指定的输入 Tensor 命名与 `Inputs` 中定义的名称一致;输出 Tensor 命名与 `Outputs` 中定义的名称一致),因此 `SetInplaceMap` 必须在 `Inputs``Outputs` 之后指定。
12361238

12371239
下面以一个自定义的 inplace `custom_add` 加法实现为例,来对上述的注意事项进行介绍:
12381240

@@ -1271,17 +1273,7 @@ void AddForward(paddle::Tensor& x, // 输入的 inplace Tensor 类型
12711273
// 输出 Tensor out 指定了 inplace 映射,因此不需要显式的返回
12721274
}
12731275

1274-
// InferDtype 函数的输入类型不需要做特别修改
1275-
std::vector<paddle::DataType> AddInferDtype(const paddle::DataType& x_dtype,
1276-
const paddle::DataType& y_dtype) {
1277-
return {x_dtype};
1278-
}
1279-
1280-
// InferShape 函数的输入类型不需要做特别修改
1281-
std::vector<std::vector<int64_t>> AddInferShape(
1282-
const std::vector<int64_t>& x_shape, const std::vector<int64_t>& y_shape) {
1283-
return {x_shape};
1284-
}
1276+
// 输入的 Tensor 已通过 inplace 指定,不需要设置 InferShapeFn 和 InferDtypeFn
12851277

12861278
// 没有做 inplace 映射的输出 Tensor,仍需作为返回值显式输出,此时函数的输出类型仍为 std::vector<paddle::Tensor>
12871279
std::vector<paddle::Tensor> AddBackward(const paddle::Tensor& x,
@@ -1306,9 +1298,7 @@ PD_BUILD_OP(custom_add)
13061298
.Inputs({"X", "Y"})
13071299
.Outputs({"Out"})
13081300
.SetInplaceMap({{"X", "Out"}}) // 使用 `SetInplaceMap` 指明输入和输出间 inplace 的映射关系
1309-
.SetKernelFn(PD_KERNEL(AddForward))
1310-
.SetInferShapeFn(PD_INFER_SHAPE(AddInferShape))
1311-
.SetInferDtypeFn(PD_INFER_DTYPE(AddInferDtype));
1301+
.SetKernelFn(PD_KERNEL(AddForward));
13121302

13131303
PD_BUILD_GRAD_OP(custom_add)
13141304
.Inputs({"X", "Y", paddle::Grad("Out")})

0 commit comments

Comments
 (0)