@@ -1232,7 +1232,9 @@ PD_BUILD_GRAD_OP(custom_inplace_relu)
1232
1232
1233
1233
3 . 一方面,做 inplace 映射的输出 Tensor,不再作为函数的返回值,如果此时函数没有需要返回的 Tensor,函数的输出类型应为 ` void ` ;另一方面,其他没有做 inplace 映射的输出 Tensor,仍需作为返回值显式输出,此时函数的输出类型仍为 ` std::vector<paddle::Tensor> ` 。例如 ` ReluCpuInplaceForward ` 函数中不再显式输出 Tensor,因此函数返回类型为 ` void ` ;
1234
1234
1235
- 4 . 框架会对算子的输入、输出映射做基本的正确性检查(` SetInplaceMap ` 中指定的输入 Tensor 命名与 ` Inputs ` 中定义的名称一致;输出 Tensor 命名与 ` Outputs ` 中定义的名称一致),因此 ` SetInplaceMap ` 必须在 ` Inputs ` 和 ` Outputs ` 之后指定。
1235
+ 4 . 框架会自动为 inplace 的输入输出做 Shape 和 Dtype 映射。因此 ` InferShape ` 和 ` InferDtype ` 函数只需要返回未被 inplace 映射的输出类型。如果没有需要返回的值,可以不设置这两个函数。
1236
+
1237
+ 5 . 框架会对算子的输入、输出映射做基本的正确性检查(` SetInplaceMap ` 中指定的输入 Tensor 命名与 ` Inputs ` 中定义的名称一致;输出 Tensor 命名与 ` Outputs ` 中定义的名称一致),因此 ` SetInplaceMap ` 必须在 ` Inputs ` 和 ` Outputs ` 之后指定。
1236
1238
1237
1239
下面以一个自定义的 inplace ` custom_add ` 加法实现为例,来对上述的注意事项进行介绍:
1238
1240
@@ -1271,17 +1273,7 @@ void AddForward(paddle::Tensor& x, // 输入的 inplace Tensor 类型
1271
1273
// 输出 Tensor out 指定了 inplace 映射,因此不需要显式的返回
1272
1274
}
1273
1275
1274
- // InferDtype 函数的输入类型不需要做特别修改
1275
- std::vector< paddle::DataType > AddInferDtype(const paddle::DataType& x_dtype,
1276
- const paddle::DataType& y_dtype) {
1277
- return {x_dtype};
1278
- }
1279
-
1280
- // InferShape 函数的输入类型不需要做特别修改
1281
- std::vector< std::vector<int64_t > > AddInferShape(
1282
- const std::vector<int64_t>& x_shape, const std::vector<int64_t>& y_shape) {
1283
- return {x_shape};
1284
- }
1276
+ // 输入的 Tensor 已通过 inplace 指定,不需要设置 InferShapeFn 和 InferDtypeFn
1285
1277
1286
1278
// 没有做 inplace 映射的输出 Tensor,仍需作为返回值显式输出,此时函数的输出类型仍为 std::vector< paddle::Tensor >
1287
1279
std::vector< paddle::Tensor > AddBackward(const paddle::Tensor& x,
@@ -1306,9 +1298,7 @@ PD_BUILD_OP(custom_add)
1306
1298
.Inputs({"X", "Y"})
1307
1299
.Outputs({"Out"})
1308
1300
.SetInplaceMap({{"X", "Out"}}) // 使用 ` SetInplaceMap ` 指明输入和输出间 inplace 的映射关系
1309
- .SetKernelFn(PD_KERNEL(AddForward))
1310
- .SetInferShapeFn(PD_INFER_SHAPE(AddInferShape))
1311
- .SetInferDtypeFn(PD_INFER_DTYPE(AddInferDtype));
1301
+ .SetKernelFn(PD_KERNEL(AddForward));
1312
1302
1313
1303
PD_BUILD_GRAD_OP(custom_add)
1314
1304
.Inputs({"X", "Y", paddle::Grad("Out")})
0 commit comments