From bede385bc26a741fb48c49804a8eb95266890027 Mon Sep 17 00:00:00 2001 From: jiangfan06 Date: Thu, 24 Aug 2023 19:03:15 +0800 Subject: [PATCH 1/4] add element_mul_add_fuse_pass & elementwise_madd_xpu kernel --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../ir/xpu/element_mul_add_fuse_pass.cc | 330 ++++++++++++++++++ .../inference/api/paddle_pass_builder.cc | 1 + paddle/phi/api/yaml/fused_ops.yaml | 9 + paddle/phi/backends/xpu/xpu2_op_list.cc | 4 + paddle/phi/infermeta/fusion.cc | 9 + paddle/phi/infermeta/fusion.h | 5 + .../fusion/xpu/elementwise_madd_xpu_kernel.cc | 61 ++++ paddle/phi/kernels/xpu/concat_kernel.cc | 4 +- .../kernels/xpu/plugin/include/xpu/plugin.h | 3 + .../src/kernel/kunlun2cpp/fast_mul_add.xpu | 77 ++++ .../xpu/plugin/src/wrapper/fast_mul_add.cpp | 76 ++++ .../test_xpu_element_mul_add_fuse_pass.py | 72 ++++ 13 files changed, 652 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc create mode 100644 paddle/phi/kernels/fusion/xpu/elementwise_madd_xpu_kernel.cc create mode 100644 paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_mul_add.xpu create mode 100644 paddle/phi/kernels/xpu/plugin/src/wrapper/fast_mul_add.cpp create mode 100644 test/ir/inference/test_xpu_element_mul_add_fuse_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index b1dafb0d3934db..dee1cc4f9ea1b9 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -290,6 +290,8 @@ if(WITH_XPU) pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(element_mul_add_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) endif() cc_library( diff --git a/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc new file mode 100644 index 00000000000000..996fd17bbb6530 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc @@ -0,0 +1,330 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +/* +fuse elementwise_mul + elementwise_add op to elementwise_madd op +For example: +graph: + x y + \ / + \ / + elementwise_mul w + \ / + \ / + elementwise_add + | + | + output +------------------------------------------------------ +After the pass is applied: + x y w + \ | / + \ | / + elementwise_madd + | + | + output +*/ +struct ElementMulAddFusePattern : public PatternBase { + ElementMulAddFusePattern(PDPattern* pattern, const std::string& name_scope); + // declare operator node's name + PATTERN_DECL_NODE(elementwise_mul); + PATTERN_DECL_NODE(elementwise_add); + // declare variable node's name + PATTERN_DECL_NODE(mul_x); + PATTERN_DECL_NODE(mul_y); + PATTERN_DECL_NODE(mul_out); + PATTERN_DECL_NODE(add_w); + PATTERN_DECL_NODE(add_out); +}; + +ElementMulAddFusePattern::ElementMulAddFusePattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto elementwise_mul = + pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul"); + auto elementwise_add = + pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add"); + auto mul_x = pattern->NewNode(mul_x_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_mul", "X"); + auto mul_y = pattern->NewNode(mul_y_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_mul", "Y"); + auto mul_out = pattern->NewNode(mul_out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_input("elementwise_add", "X"); + elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out}); + auto add_w = pattern->NewNode(add_w_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto add_out = pattern->NewNode(add_out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_add", "Out"); + elementwise_add->LinksFrom({mul_out, add_w}).LinksTo({add_out}); +} + +/* +special case for elementwise_madd op: +graph: + x y + \ / + \ / + elementwise_mul x + \ / + \ / + elementwise_add + | + | + output +------------------------------------------------------ +After the pass is applied: + x y + \ / + \ / + elementwise_madd + | + | + output +*/ +struct ElementMulAddFuseXYPattern : public PatternBase { + ElementMulAddFuseXYPattern(PDPattern* pattern, const std::string& name_scope); + // declare operator node's name + PATTERN_DECL_NODE(elementwise_mul); + PATTERN_DECL_NODE(elementwise_add); + // declare variable node's name + PATTERN_DECL_NODE(mul_x); + PATTERN_DECL_NODE(mul_y); + PATTERN_DECL_NODE(mul_out); + PATTERN_DECL_NODE(add_out); +}; + +ElementMulAddFuseXYPattern::ElementMulAddFuseXYPattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto elementwise_mul = + pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul"); + auto elementwise_add = + pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add"); + auto mul_x = pattern->NewNode(mul_x_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_mul", "X") + ->assert_is_op_input("elementwise_add", "Y"); + auto mul_y = pattern->NewNode(mul_y_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_mul", "Y"); + auto mul_out = pattern->NewNode(mul_out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_input("elementwise_add", "X"); + elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out}); + auto add_out = pattern->NewNode(add_out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_add", "Out"); + elementwise_add->LinksFrom({mul_out, mul_x}).LinksTo({add_out}); +} +} // namespace patterns + +class ElementMulAddFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void FuseElementMulAdd(ir::Graph* graph) const; + void FuseElementMulAddWithOnlyXY(ir::Graph* graph) const; + + const std::string name_scope_{"element_mul_add_fuse_pass"}; +}; + +void ElementMulAddFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + FuseElementMulAdd(graph); + FuseElementMulAddWithOnlyXY(graph); +} + +void ElementMulAddFusePass::FuseElementMulAdd(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::ElementMulAddFusePattern pattern(gpd.mutable_pattern(), + name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ElementMulAddFusePass"; + // declare operator node's name + GET_IR_NODE(elementwise_mul); + GET_IR_NODE(elementwise_add); + // declare variable node's name + GET_IR_NODE(mul_x); + GET_IR_NODE(mul_y); + GET_IR_NODE(mul_out); + GET_IR_NODE(add_w); + GET_IR_NODE(add_out); + + bool flag = true; + auto var_type = mul_x->Var()->GetDataType(); + if (var_type != proto::VarType::FP16 && var_type != proto::VarType::FP32) { + flag = false; + } + + auto x_shape = mul_x->Var()->GetShape(); + auto y_shape = mul_y->Var()->GetShape(); + auto w_shape = add_w->Var()->GetShape(); + if (x_shape.size() == y_shape.size() && x_shape.size() == w_shape.size()) { + for (int i = 0; i < x_shape.size(); ++i) { + if (x_shape[i] != y_shape[i] || x_shape[i] != w_shape[i]) { + flag = false; + } + } + } else { + flag = false; + } + + if (flag) { + auto* block = elementwise_mul->Op()->Block(); + + // delete useless node + std::unordered_set delete_nodes; + + // Generate elementwise_madd op + framework::OpDesc fused_op_desc(block); + fused_op_desc.SetType("elementwise_madd"); + fused_op_desc.SetInput("x", {mul_x->Name()}); + fused_op_desc.SetInput("y", {mul_y->Name()}); + fused_op_desc.SetInput("w", {add_w->Name()}); + fused_op_desc.SetOutput("out", {add_out->Name()}); + auto* fused_op = graph->CreateOpNode(&fused_op_desc); + IR_NODE_LINK_TO(mul_x, fused_op); + IR_NODE_LINK_TO(mul_y, fused_op); + IR_NODE_LINK_TO(add_w, fused_op); + IR_NODE_LINK_TO(fused_op, add_out); + delete_nodes.insert({elementwise_mul, elementwise_add, mul_out}); + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void ElementMulAddFusePass::FuseElementMulAddWithOnlyXY( + ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::ElementMulAddFuseXYPattern pattern(gpd.mutable_pattern(), + name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ElementMulAddFusePass"; + // declare operator node's name + GET_IR_NODE(elementwise_mul); + GET_IR_NODE(elementwise_add); + // declare variable node's name + GET_IR_NODE(mul_x); + GET_IR_NODE(mul_y); + GET_IR_NODE(mul_out); + GET_IR_NODE(add_out); + + bool flag = true; + auto var_type = mul_x->Var()->GetDataType(); + if (var_type != proto::VarType::FP16 && var_type != proto::VarType::FP32) { + flag = false; + } + + auto x_shape = mul_x->Var()->GetShape(); + auto y_shape = mul_y->Var()->GetShape(); + if (x_shape.size() == y_shape.size()) { + for (int i = 0; i < x_shape.size(); ++i) { + if (x_shape[i] != y_shape[i]) { + flag = false; + } + } + } else { + flag = false; + } + + if (flag) { + auto* block = elementwise_mul->Op()->Block(); + + // delete useless node + std::unordered_set delete_nodes; + + // Generate elementwise_madd op + framework::OpDesc fused_op_desc(block); + fused_op_desc.SetType("elementwise_madd"); + fused_op_desc.SetInput("x", {mul_x->Name()}); + fused_op_desc.SetInput("y", {mul_y->Name()}); + fused_op_desc.SetInput("w", {mul_x->Name()}); + fused_op_desc.SetOutput("out", {add_out->Name()}); + auto* fused_op = graph->CreateOpNode(&fused_op_desc); + IR_NODE_LINK_TO(mul_x, fused_op); + IR_NODE_LINK_TO(mul_y, fused_op); + IR_NODE_LINK_TO(fused_op, add_out); + delete_nodes.insert({elementwise_mul, elementwise_add, mul_out}); + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(element_mul_add_fuse_pass, + paddle::framework::ir::ElementMulAddFusePass); + +REGISTER_PASS_CAPABILITY(element_mul_add_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .GE("elementwise_add", 0) + .LE("elementwise_add", 1) + .GE("elementwise_mul", 0) + .LE("elementwise_mul", 1)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 09d1197d35b556..0d1887badf8496 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -552,6 +552,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fast_layernorm_xpu_fuse_pass", "yolo_box_xpu_fuse_pass", "fast_where_xpu_fuse_pass", + "element_mul_add_fuse_pass", "link_xpu_op_max_pass", "delete_isolated_node_pass", // "auto_mixed_precision_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 648384422ca8ab..7850eb1afb3d58 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -53,6 +53,15 @@ data_type : x optional : bias, branch, branch_max ,x_max +- op : elementwise_madd + args : (Tensor x, Tensor y, Tensor w) + output : Tensor(out) + infer_meta : + func : ElementwiseMaddXPUInferMeta + kernel : + func : elementwise_madd + data_type : x + - op : embedding_with_eltwise_add_xpu args : (Tensor[] ids, Tensor[] tables, Tensor mask, int64_t padding_idx) output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 9154d1aa092469..9776bf9be8da99 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -161,6 +161,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, phi::DataType::FLOAT64, + phi::DataType::BOOL, + phi::DataType::INT8, phi::DataType::INT64, phi::DataType::INT32})}, {"conv2d_grad", @@ -238,6 +240,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32})}, {"elementwise_floordiv", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_madd", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_max_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_max", diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 3143c5cde2e1e5..13d7496cc7450a 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -821,4 +821,13 @@ void FastLayernormXPUInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void ElementwiseMaddXPUInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& w, + MetaTensor* out) { + out->set_dims(x.dims()); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 25c27bdd406b96..26c9bf36f0fb32 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -201,4 +201,9 @@ void FastLayernormXPUInferMeta(const MetaTensor& x, float epsilon, MetaTensor* out); +void ElementwiseMaddXPUInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& w, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/elementwise_madd_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/elementwise_madd_xpu_kernel.cc new file mode 100644 index 00000000000000..2202876e5ab940 --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/elementwise_madd_xpu_kernel.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void ElementwiseMaddXPUKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& w, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + const auto* x_data = x.data(); + const auto* y_data = y.data(); + const auto* w_data = w.data(); + + auto* out_data = ctx.template Alloc(out); + +#ifdef PADDLE_WITH_XPU_PLUGIN + int r = xpu::plugin::fast_mul_add(ctx.x_context(), + reinterpret_cast(w_data), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + reinterpret_cast(out_data), + x.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_mul_add"); +#else + int r = xpu::addcmul(ctx.x_context(), + reinterpret_cast(w_data), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + reinterpret_cast(out_data), + 1.0f, + x.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "addcmul"); +#endif +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(elementwise_madd, + XPU, + ALL_LAYOUT, + phi::fusion::ElementwiseMaddXPUKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/concat_kernel.cc b/paddle/phi/kernels/xpu/concat_kernel.cc index f1fac997061c59..5afdf2612981eb 100644 --- a/paddle/phi/kernels/xpu/concat_kernel.cc +++ b/paddle/phi/kernels/xpu/concat_kernel.cc @@ -119,4 +119,6 @@ PD_REGISTER_KERNEL(concat, double, phi::dtype::float16, int64_t, - int) {} + int, + int8_t, + bool) {} diff --git a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h index 1357ca43001c8b..4a43c9e742b4b5 100644 --- a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h +++ b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h @@ -114,6 +114,9 @@ DLL_EXPORT int fast_embedding(Context* ctx, int64_t ym, int64_t padding_idx, TID start_index = 0); +template +DLL_EXPORT int fast_mul_add( + Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len); } // namespace plugin } // namespace api diff --git a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_mul_add.xpu b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_mul_add.xpu new file mode 100644 index 00000000000000..2a24321d80184e --- /dev/null +++ b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_mul_add.xpu @@ -0,0 +1,77 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +template +static inline __device__ void primitive_addcmul(T* x, const T* y, int len) { + float32x16_t vx0; + float32x16_t vy0; + float32x16_t vx1; + float32x16_t vy1; + for (int i = 0; i < len; i += 32) { + vload2_lm(x + i, vx0, vx1); + vload2_lm(y + i, vy0, vy1); + vx0 = vvmac_float32x16(vx0, vy0, vx0); + vx1 = vvmac_float32x16(vx1, vy1, vx1); + vstore2_lm(x + i, vx0, vx1); + } + mfence_lm(); +} + +template +__global__ void fast_mul_add(const T* x, const T* y, T* z, int64_t len) { + int cid = core_id(); + const int ncores = core_num(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + const int buf_len = 512 / sizeof(T); + __simd__ float local_x_after_cast[buf_len]; + __simd__ float local_y_after_cast[buf_len]; + T* local_x = (T*)(local_x_after_cast); + T* local_y = (T*)(local_y_after_cast); + + int loop = 0; + for (int64_t i = tid * buf_len; i < len; i += nthreads * buf_len) { + int read_len = min(static_cast(buf_len), len - i); + GM2LM_ASYNC(x + i, local_x, read_len * sizeof(T)); + GM2LM(y + i, local_y, read_len * sizeof(T)); + primitive_addcmul(local_x, local_y, read_len); + LM2GM_ASYNC(local_x, z + i, read_len * sizeof(T)); + mfence_lm(); +#ifndef __XPU3__ + loop++; + if ((loop & 0xF) == 0) { + sync_all(); + } +#endif + } +} + +#define _XPU_DEF__FAST_MUL_ADD_(DTYPE) \ + template __global__ void fast_mul_add( \ + const DTYPE* x, const DTYPE* y, DTYPE* z, int64_t len); +_XPU_DEF__FAST_MUL_ADD_(float); +_XPU_DEF__FAST_MUL_ADD_(float16); + +} // namespace plugin +} // namespace xpu2 diff --git a/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_mul_add.cpp b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_mul_add.cpp new file mode 100644 index 00000000000000..2a1da86fe5f970 --- /dev/null +++ b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_mul_add.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include "xpu/refactor/util/vector_util.h" + +namespace xpu2 { +namespace plugin { +template +__attribute__((global)) void fast_mul_add(const T* x, + const T* y, + T* z, + int64_t len); +} // namespace plugin +} // namespace xpu2 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +template +static int xpu2_wrapper( + Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) { + if (x == w) { + xpu2::plugin::fast_mul_add + <<ncluster(), 64, ctx->xpu_stream>>>(x, y, z, len); + } else { + return addcmul(ctx, w, x, y, z, 1.0f, len); + } + return SUCCESS; +} + +template +int fast_mul_add( + Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_mul_add", T); + WRAPPER_DUMP_PARAM4(ctx, w, x, y, z); + WRAPPER_DUMP_PARAM2(ctx, len, ctx->_l3_mgr.get_size()); + WRAPPER_DUMP(ctx); + WRAPPER_CHECK_4PTRS(ctx, T, len, w, x, y, z); + if (ctx->dev().type() == api::kXPU2) { + return xpu2_wrapper(ctx, w, x, y, z, len); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +template int fast_mul_add( + Context*, const float*, const float*, const float*, float*, int64_t); +template int fast_mul_add(Context*, + const float16*, + const float16*, + const float16*, + float16*, + int64_t); + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/test/ir/inference/test_xpu_element_mul_add_fuse_pass.py b/test/ir/inference/test_xpu_element_mul_add_fuse_pass.py new file mode 100644 index 00000000000000..a49417c083d0b5 --- /dev/null +++ b/test/ir/inference/test_xpu_element_mul_add_fuse_pass.py @@ -0,0 +1,72 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestGatherAddTransposePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["elementwise_madd"], (1e-3, 1e-3) + + def sample_program_config(self, draw): + x_shape = draw( + st.lists( + st.integers(min_value=1, max_value=4), min_size=3, max_size=4 + ) + ) + + def generate_data(shape): + return np.random.random(shape).astype(np.float32) + + mul_op = OpConfig( + "elementwise_mul", + inputs={"X": ["mul_x"], "Y": ["mul_y"]}, + outputs={"Out": ["mul_out"]}, + ) + + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["mul_out"], "Y": ["add_w"]}, + outputs={"Out": ["add_out"]}, + ) + + ops = [mul_op, add_op] + + program_config = ProgramConfig( + ops=ops, + inputs={ + "mul_x": TensorConfig(data_gen=partial(generate_data, x_shape)), + "mul_y": TensorConfig(data_gen=partial(generate_data, x_shape)), + "add_w": TensorConfig(data_gen=partial(generate_data, x_shape)), + }, + weights={}, + outputs=["add_out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, max_examples=25, passes=["element_mul_add_fuse_pass"] + ) + + +if __name__ == "__main__": + unittest.main() From 7eae4ae83551ca9bdef23871267a8dd72db3ed75 Mon Sep 17 00:00:00 2001 From: jiangfan06 Date: Fri, 25 Aug 2023 11:26:23 +0800 Subject: [PATCH 2/4] fix --- .../fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc index 996fd17bbb6530..e78c922e63ad89 100644 --- a/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc @@ -213,8 +213,9 @@ void ElementMulAddFusePass::FuseElementMulAdd(ir::Graph* graph) const { auto y_shape = mul_y->Var()->GetShape(); auto w_shape = add_w->Var()->GetShape(); if (x_shape.size() == y_shape.size() && x_shape.size() == w_shape.size()) { - for (int i = 0; i < x_shape.size(); ++i) { - if (x_shape[i] != y_shape[i] || x_shape[i] != w_shape[i]) { + for (size_t i = 0; i < x_shape.size(); ++i) { + if (x_shape[i] != y_shape[i] || x_shape[i] != w_shape[i] || + x_shape[i] == -1) { flag = false; } } @@ -278,8 +279,8 @@ void ElementMulAddFusePass::FuseElementMulAddWithOnlyXY( auto x_shape = mul_x->Var()->GetShape(); auto y_shape = mul_y->Var()->GetShape(); if (x_shape.size() == y_shape.size()) { - for (int i = 0; i < x_shape.size(); ++i) { - if (x_shape[i] != y_shape[i]) { + for (size_t i = 0; i < x_shape.size(); ++i) { + if (x_shape[i] != y_shape[i] || x_shape[i] == -1) { flag = false; } } From ecdc4f33718b38e4b69866d2cb128256ef7ac59f Mon Sep 17 00:00:00 2001 From: jiangfan06 Date: Fri, 25 Aug 2023 16:23:53 +0800 Subject: [PATCH 3/4] update after review --- paddle/fluid/framework/ir/CMakeLists.txt | 2 +- ...ss.cc => elementwise_mul_add_fuse_pass.cc} | 65 ++++++++++--------- .../inference/api/paddle_pass_builder.cc | 2 +- paddle/phi/api/yaml/fused_ops.yaml | 18 ++--- paddle/phi/backends/xpu/xpu2_op_list.cc | 4 +- paddle/phi/infermeta/fusion.cc | 8 +-- paddle/phi/infermeta/fusion.h | 8 +-- ...dd_xpu_kernel.cc => addcmul_xpu_kernel.cc} | 18 ++--- .../kernels/xpu/plugin/include/xpu/plugin.h | 2 +- .../{fast_mul_add.xpu => fast_addcmul.xpu} | 10 +-- .../{fast_mul_add.cpp => fast_addcmul.cpp} | 10 +-- ...test_xpu_elementwise_mul_add_fuse_pass.py} | 6 +- 12 files changed, 78 insertions(+), 75 deletions(-) rename paddle/fluid/framework/ir/xpu/{element_mul_add_fuse_pass.cc => elementwise_mul_add_fuse_pass.cc} (84%) rename paddle/phi/kernels/fusion/xpu/{elementwise_madd_xpu_kernel.cc => addcmul_xpu_kernel.cc} (81%) rename paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/{fast_mul_add.xpu => fast_addcmul.xpu} (90%) rename paddle/phi/kernels/xpu/plugin/src/wrapper/{fast_mul_add.cpp => fast_addcmul.cpp} (92%) rename test/ir/inference/{test_xpu_element_mul_add_fuse_pass.py => test_xpu_elementwise_mul_add_fuse_pass.py} (93%) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index dee1cc4f9ea1b9..f856eca2aa526d 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -290,7 +290,7 @@ if(WITH_XPU) pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) - pass_library(element_mul_add_fuse_pass inference DIR xpu DEPS + pass_library(elementwise_mul_add_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) endif() diff --git a/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc similarity index 84% rename from paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc rename to paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc index e78c922e63ad89..400371f48b60d3 100644 --- a/paddle/fluid/framework/ir/xpu/element_mul_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc @@ -40,7 +40,7 @@ namespace ir { namespace patterns { /* -fuse elementwise_mul + elementwise_add op to elementwise_madd op +fuse elementwise_mul + elementwise_add op to addcmul_xpu op For example: graph: x y @@ -58,13 +58,13 @@ After the pass is applied: x y w \ | / \ | / - elementwise_madd + addcmul_xpu | | output */ -struct ElementMulAddFusePattern : public PatternBase { - ElementMulAddFusePattern(PDPattern* pattern, const std::string& name_scope); +struct ElementwiseMulAddFusePass : public PatternBase { + ElementwiseMulAddFusePass(PDPattern* pattern, const std::string& name_scope); // declare operator node's name PATTERN_DECL_NODE(elementwise_mul); PATTERN_DECL_NODE(elementwise_add); @@ -76,7 +76,7 @@ struct ElementMulAddFusePattern : public PatternBase { PATTERN_DECL_NODE(add_out); }; -ElementMulAddFusePattern::ElementMulAddFusePattern( +ElementwiseMulAddFusePass::ElementwiseMulAddFusePass( PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, name_scope) { auto elementwise_mul = @@ -104,7 +104,7 @@ ElementMulAddFusePattern::ElementMulAddFusePattern( } /* -special case for elementwise_madd op: +special case for addcmul_xpu op: graph: x y \ / @@ -121,13 +121,14 @@ After the pass is applied: x y \ / \ / - elementwise_madd + addcmul_xpu | | output */ -struct ElementMulAddFuseXYPattern : public PatternBase { - ElementMulAddFuseXYPattern(PDPattern* pattern, const std::string& name_scope); +struct ElementwiseMulAddFuseXYPattern : public PatternBase { + ElementwiseMulAddFuseXYPattern(PDPattern* pattern, + const std::string& name_scope); // declare operator node's name PATTERN_DECL_NODE(elementwise_mul); PATTERN_DECL_NODE(elementwise_add); @@ -138,7 +139,7 @@ struct ElementMulAddFuseXYPattern : public PatternBase { PATTERN_DECL_NODE(add_out); }; -ElementMulAddFuseXYPattern::ElementMulAddFuseXYPattern( +ElementwiseMulAddFuseXYPattern::ElementwiseMulAddFuseXYPattern( PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, name_scope) { auto elementwise_mul = @@ -164,35 +165,35 @@ ElementMulAddFuseXYPattern::ElementMulAddFuseXYPattern( } } // namespace patterns -class ElementMulAddFusePass : public FusePassBase { +class ElementwiseMulAddFusePass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const override; private: - void FuseElementMulAdd(ir::Graph* graph) const; - void FuseElementMulAddWithOnlyXY(ir::Graph* graph) const; + void FuseElementwiseMulAdd(ir::Graph* graph) const; + void FuseElementwiseMulAddWithOnlyXY(ir::Graph* graph) const; - const std::string name_scope_{"element_mul_add_fuse_pass"}; + const std::string name_scope_{"elementwise_mul_add_fuse_pass"}; }; -void ElementMulAddFusePass::ApplyImpl(ir::Graph* graph) const { +void ElementwiseMulAddFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); - FuseElementMulAdd(graph); - FuseElementMulAddWithOnlyXY(graph); + FuseElementwiseMulAdd(graph); + FuseElementwiseMulAddWithOnlyXY(graph); } -void ElementMulAddFusePass::FuseElementMulAdd(ir::Graph* graph) const { +void ElementwiseMulAddFusePass::FuseElementwiseMulAdd(ir::Graph* graph) const { GraphPatternDetector gpd; - patterns::ElementMulAddFusePattern pattern(gpd.mutable_pattern(), - name_scope_); + patterns::ElementwiseMulAddFusePass pattern(gpd.mutable_pattern(), + name_scope_); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { - VLOG(4) << "handle ElementMulAddFusePass"; + VLOG(4) << "handle ElementwiseMulAddFusePass"; // declare operator node's name GET_IR_NODE(elementwise_mul); GET_IR_NODE(elementwise_add); @@ -229,9 +230,9 @@ void ElementMulAddFusePass::FuseElementMulAdd(ir::Graph* graph) const { // delete useless node std::unordered_set delete_nodes; - // Generate elementwise_madd op + // Generate addcmul_xpu op framework::OpDesc fused_op_desc(block); - fused_op_desc.SetType("elementwise_madd"); + fused_op_desc.SetType("addcmul_xpu"); fused_op_desc.SetInput("x", {mul_x->Name()}); fused_op_desc.SetInput("y", {mul_y->Name()}); fused_op_desc.SetInput("w", {add_w->Name()}); @@ -251,16 +252,16 @@ void ElementMulAddFusePass::FuseElementMulAdd(ir::Graph* graph) const { AddStatis(found_subgraph_count); } -void ElementMulAddFusePass::FuseElementMulAddWithOnlyXY( +void ElementwiseMulAddFusePass::FuseElementwiseMulAddWithOnlyXY( ir::Graph* graph) const { GraphPatternDetector gpd; - patterns::ElementMulAddFuseXYPattern pattern(gpd.mutable_pattern(), - name_scope_); + patterns::ElementwiseMulAddFuseXYPattern pattern(gpd.mutable_pattern(), + name_scope_); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { - VLOG(4) << "handle ElementMulAddFusePass"; + VLOG(4) << "handle ElementwiseMulAddFusePass"; // declare operator node's name GET_IR_NODE(elementwise_mul); GET_IR_NODE(elementwise_add); @@ -294,9 +295,9 @@ void ElementMulAddFusePass::FuseElementMulAddWithOnlyXY( // delete useless node std::unordered_set delete_nodes; - // Generate elementwise_madd op + // Generate addcmul_xpu op framework::OpDesc fused_op_desc(block); - fused_op_desc.SetType("elementwise_madd"); + fused_op_desc.SetType("addcmul_xpu"); fused_op_desc.SetInput("x", {mul_x->Name()}); fused_op_desc.SetInput("y", {mul_y->Name()}); fused_op_desc.SetInput("w", {mul_x->Name()}); @@ -319,10 +320,10 @@ void ElementMulAddFusePass::FuseElementMulAddWithOnlyXY( } // namespace framework } // namespace paddle -REGISTER_PASS(element_mul_add_fuse_pass, - paddle::framework::ir::ElementMulAddFusePass); +REGISTER_PASS(elementwise_mul_add_fuse_pass, + paddle::framework::ir::ElementwiseMulAddFusePass); -REGISTER_PASS_CAPABILITY(element_mul_add_fuse_pass) +REGISTER_PASS_CAPABILITY(elementwise_mul_add_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .GE("elementwise_add", 0) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 0d1887badf8496..3f1e57f7f8adc8 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -552,7 +552,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fast_layernorm_xpu_fuse_pass", "yolo_box_xpu_fuse_pass", "fast_where_xpu_fuse_pass", - "element_mul_add_fuse_pass", + "elementwise_mul_add_fuse_pass", "link_xpu_op_max_pass", "delete_isolated_node_pass", // "auto_mixed_precision_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 7850eb1afb3d58..ab7b96f05086e1 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -23,6 +23,15 @@ func : add_layernorm_xpu data_type : x +- op : addcmul_xpu + args : (Tensor x, Tensor y, Tensor w) + output : Tensor(out) + infer_meta : + func : AddCMulXPUInferMeta + kernel : + func : addcmul_xpu + data_type : x + - op : conv1d_xpu args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, str padding_algorithm, int dilations, int strides, int groups, int act_type, float act_param) output : Tensor(out), Tensor(out_max) @@ -53,15 +62,6 @@ data_type : x optional : bias, branch, branch_max ,x_max -- op : elementwise_madd - args : (Tensor x, Tensor y, Tensor w) - output : Tensor(out) - infer_meta : - func : ElementwiseMaddXPUInferMeta - kernel : - func : elementwise_madd - data_type : x - - op : embedding_with_eltwise_add_xpu args : (Tensor[] ids, Tensor[] tables, Tensor mask, int64_t padding_idx) output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 9776bf9be8da99..d52769723e34f8 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -36,6 +36,8 @@ XPUOpMap& get_kl2_ops() { {"adam_dense_param_sparse_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"adagrad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"addcmul_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"arg_max", XPUKernelSet({phi::DataType::INT32, phi::DataType::FLOAT32, @@ -240,8 +242,6 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32})}, {"elementwise_floordiv", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"elementwise_madd", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_max_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_max", diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 13d7496cc7450a..907522d6c3fe6d 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -821,10 +821,10 @@ void FastLayernormXPUInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } -void ElementwiseMaddXPUInferMeta(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& w, - MetaTensor* out) { +void AddCMulXPUInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& w, + MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(x.dtype()); out->set_layout(x.layout()); diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 26c9bf36f0fb32..ca29797318288b 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -201,9 +201,9 @@ void FastLayernormXPUInferMeta(const MetaTensor& x, float epsilon, MetaTensor* out); -void ElementwiseMaddXPUInferMeta(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& w, - MetaTensor* out); +void AddCMulXPUInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& w, + MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/elementwise_madd_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/addcmul_xpu_kernel.cc similarity index 81% rename from paddle/phi/kernels/fusion/xpu/elementwise_madd_xpu_kernel.cc rename to paddle/phi/kernels/fusion/xpu/addcmul_xpu_kernel.cc index 2202876e5ab940..57c71bcd4bd7da 100644 --- a/paddle/phi/kernels/fusion/xpu/elementwise_madd_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/addcmul_xpu_kernel.cc @@ -19,11 +19,11 @@ namespace phi { namespace fusion { template -void ElementwiseMaddXPUKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& w, - DenseTensor* out) { +void AddCMulXPUKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& w, + DenseTensor* out) { using XPUType = typename XPUTypeTrait::Type; const auto* x_data = x.data(); const auto* y_data = y.data(); @@ -32,13 +32,13 @@ void ElementwiseMaddXPUKernel(const Context& ctx, auto* out_data = ctx.template Alloc(out); #ifdef PADDLE_WITH_XPU_PLUGIN - int r = xpu::plugin::fast_mul_add(ctx.x_context(), + int r = xpu::plugin::fast_addcmul(ctx.x_context(), reinterpret_cast(w_data), reinterpret_cast(x_data), reinterpret_cast(y_data), reinterpret_cast(out_data), x.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_mul_add"); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_addcmul"); #else int r = xpu::addcmul(ctx.x_context(), reinterpret_cast(w_data), @@ -53,9 +53,9 @@ void ElementwiseMaddXPUKernel(const Context& ctx, } // namespace fusion } // namespace phi -PD_REGISTER_KERNEL(elementwise_madd, +PD_REGISTER_KERNEL(addcmul_xpu, XPU, ALL_LAYOUT, - phi::fusion::ElementwiseMaddXPUKernel, + phi::fusion::AddCMulXPUKernel, float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h index 4a43c9e742b4b5..2038fef8023938 100644 --- a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h +++ b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h @@ -115,7 +115,7 @@ DLL_EXPORT int fast_embedding(Context* ctx, int64_t padding_idx, TID start_index = 0); template -DLL_EXPORT int fast_mul_add( +DLL_EXPORT int fast_addcmul( Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len); } // namespace plugin diff --git a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_mul_add.xpu b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_addcmul.xpu similarity index 90% rename from paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_mul_add.xpu rename to paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_addcmul.xpu index 2a24321d80184e..dd2f2f6488a5b0 100644 --- a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_mul_add.xpu +++ b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_addcmul.xpu @@ -39,7 +39,7 @@ static inline __device__ void primitive_addcmul(T* x, const T* y, int len) { } template -__global__ void fast_mul_add(const T* x, const T* y, T* z, int64_t len) { +__global__ void fast_addcmul(const T* x, const T* y, T* z, int64_t len) { int cid = core_id(); const int ncores = core_num(); int tid = cid * cluster_num() + cluster_id(); @@ -67,11 +67,11 @@ __global__ void fast_mul_add(const T* x, const T* y, T* z, int64_t len) { } } -#define _XPU_DEF__FAST_MUL_ADD_(DTYPE) \ - template __global__ void fast_mul_add( \ +#define _XPU_DEF__FAST_ADDCMUL_(DTYPE) \ + template __global__ void fast_addcmul( \ const DTYPE* x, const DTYPE* y, DTYPE* z, int64_t len); -_XPU_DEF__FAST_MUL_ADD_(float); -_XPU_DEF__FAST_MUL_ADD_(float16); +_XPU_DEF__FAST_ADDCMUL_(float); +_XPU_DEF__FAST_ADDCMUL_(float16); } // namespace plugin } // namespace xpu2 diff --git a/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_mul_add.cpp b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_addcmul.cpp similarity index 92% rename from paddle/phi/kernels/xpu/plugin/src/wrapper/fast_mul_add.cpp rename to paddle/phi/kernels/xpu/plugin/src/wrapper/fast_addcmul.cpp index 2a1da86fe5f970..a333cbd7a43a23 100644 --- a/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_mul_add.cpp +++ b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_addcmul.cpp @@ -22,7 +22,7 @@ namespace xpu2 { namespace plugin { template -__attribute__((global)) void fast_mul_add(const T* x, +__attribute__((global)) void fast_addcmul(const T* x, const T* y, T* z, int64_t len); @@ -38,7 +38,7 @@ template static int xpu2_wrapper( Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) { if (x == w) { - xpu2::plugin::fast_mul_add + xpu2::plugin::fast_addcmul <<ncluster(), 64, ctx->xpu_stream>>>(x, y, z, len); } else { return addcmul(ctx, w, x, y, z, 1.0f, len); @@ -47,7 +47,7 @@ static int xpu2_wrapper( } template -int fast_mul_add( +int fast_addcmul( Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_mul_add", T); @@ -61,9 +61,9 @@ int fast_mul_add( WRAPPER_UNIMPLEMENTED(ctx); } -template int fast_mul_add( +template int fast_addcmul( Context*, const float*, const float*, const float*, float*, int64_t); -template int fast_mul_add(Context*, +template int fast_addcmul(Context*, const float16*, const float16*, const float16*, diff --git a/test/ir/inference/test_xpu_element_mul_add_fuse_pass.py b/test/ir/inference/test_xpu_elementwise_mul_add_fuse_pass.py similarity index 93% rename from test/ir/inference/test_xpu_element_mul_add_fuse_pass.py rename to test/ir/inference/test_xpu_elementwise_mul_add_fuse_pass.py index a49417c083d0b5..48603acf90de9f 100644 --- a/test/ir/inference/test_xpu_element_mul_add_fuse_pass.py +++ b/test/ir/inference/test_xpu_elementwise_mul_add_fuse_pass.py @@ -24,7 +24,7 @@ class TestGatherAddTransposePass(PassAutoScanTest): def sample_predictor_configs(self, program_config): config = self.create_inference_config(use_xpu=True) - yield config, ["elementwise_madd"], (1e-3, 1e-3) + yield config, ["addcmul_xpu"], (1e-3, 1e-3) def sample_program_config(self, draw): x_shape = draw( @@ -64,7 +64,9 @@ def generate_data(shape): def test(self): self.run_and_statis( - quant=False, max_examples=25, passes=["element_mul_add_fuse_pass"] + quant=False, + max_examples=25, + passes=["elementwise_mul_add_fuse_pass"], ) From aeae13afd681f868347d9db8f62d7b5336726950 Mon Sep 17 00:00:00 2001 From: jiangfan06 Date: Tue, 29 Aug 2023 18:07:19 +0800 Subject: [PATCH 4/4] fix --- paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc index 400371f48b60d3..6057f200ac0c18 100644 --- a/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc @@ -92,7 +92,8 @@ ElementwiseMulAddFusePass::ElementwiseMulAddFusePass( auto mul_out = pattern->NewNode(mul_out_repr()) ->AsOutput() ->assert_is_op_output("elementwise_mul", "Out") - ->assert_is_op_input("elementwise_add", "X"); + ->assert_is_op_input("elementwise_add", "X") + ->assert_has_n_outputs(1); elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out}); auto add_w = pattern->NewNode(add_w_repr()) ->AsInput()