diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index b1dafb0d3934db..f856eca2aa526d 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -290,6 +290,8 @@ if(WITH_XPU)
   pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(elementwise_mul_add_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()
 
 cc_library(
diff --git a/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc
new file mode 100644
index 00000000000000..6057f200ac0c18
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc
@@ -0,0 +1,333 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "glog/logging.h"
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+/*
+Fuse elementwise_mul + elementwise_add into a single addcmul_xpu op
+(out = x * y + w).
+For example:
+graph:
+                  x        y
+                   \      /
+                    \    /
+              elementwise_mul    w
+                      \         /
+                       \       /
+                    elementwise_add
+                           |
+                           |
+                         output
+------------------------------------------------------
+After the pass is applied:
+                  x      y      w
+                   \     |     /
+                    \    |    /
+                    addcmul_xpu
+                         |
+                         |
+                       output
+*/
+struct ElementwiseMulAddFusePass : public PatternBase {
+  ElementwiseMulAddFusePass(PDPattern* pattern, const std::string& name_scope);
+  // declare operator node's name
+  PATTERN_DECL_NODE(elementwise_mul);
+  PATTERN_DECL_NODE(elementwise_add);
+  // declare variable node's name
+  PATTERN_DECL_NODE(mul_x);
+  PATTERN_DECL_NODE(mul_y);
+  PATTERN_DECL_NODE(mul_out);
+  PATTERN_DECL_NODE(add_w);
+  PATTERN_DECL_NODE(add_out);
+};
+
+ElementwiseMulAddFusePass::ElementwiseMulAddFusePass(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto elementwise_mul =
+      pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul");
+  auto elementwise_add =
+      pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add");
+  auto mul_x = pattern->NewNode(mul_x_repr())
+                   ->AsInput()
+                   ->assert_is_op_input("elementwise_mul", "X");
+  auto mul_y = pattern->NewNode(mul_y_repr())
+                   ->AsInput()
+                   ->assert_is_op_input("elementwise_mul", "Y");
+  auto mul_out = pattern->NewNode(mul_out_repr())
+                     ->AsOutput()
+                     ->assert_is_op_output("elementwise_mul", "Out")
+                     ->assert_is_op_input("elementwise_add", "X")
+                     ->assert_has_n_outputs(1);
+  elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out});
+  auto add_w = pattern->NewNode(add_w_repr())
+                   ->AsInput()
+                   ->assert_is_op_input("elementwise_add", "Y");
+  auto add_out = pattern->NewNode(add_out_repr())
+                     ->AsOutput()
+                     ->assert_is_op_output("elementwise_add", "Out");
+  elementwise_add->LinksFrom({mul_out, add_w}).LinksTo({add_out});
+}
+
+/*
+Special case for addcmul_xpu, where elementwise_add reuses the mul input x
+(out = x * y + x):
+graph:
+                  x        y
+                   \      /
+                    \    /
+              elementwise_mul    x
+                      \         /
+                       \       /
+                    elementwise_add
+                           |
+                           |
+                         output
+------------------------------------------------------
+After the pass is applied:
+                  x      y
+                   \    /
+                    \  /
+               addcmul_xpu
+                     |
+                     |
+                   output
+*/
+struct ElementwiseMulAddFuseXYPattern : public PatternBase {
+  ElementwiseMulAddFuseXYPattern(PDPattern* pattern,
+                                 const std::string& name_scope);
+  // declare operator node's name
+  PATTERN_DECL_NODE(elementwise_mul);
+  PATTERN_DECL_NODE(elementwise_add);
+  // declare variable node's name
+  PATTERN_DECL_NODE(mul_x);
+  PATTERN_DECL_NODE(mul_y);
+  PATTERN_DECL_NODE(mul_out);
+  PATTERN_DECL_NODE(add_out);
+};
+
+ElementwiseMulAddFuseXYPattern::ElementwiseMulAddFuseXYPattern(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto elementwise_mul =
+      pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul");
+  auto elementwise_add =
+      pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add");
+  auto mul_x = pattern->NewNode(mul_x_repr())
+                   ->AsInput()
+                   ->assert_is_op_input("elementwise_mul", "X")
+                   ->assert_is_op_input("elementwise_add", "Y");
+  auto mul_y = pattern->NewNode(mul_y_repr())
+                   ->AsInput()
+                   ->assert_is_op_input("elementwise_mul", "Y");
+  auto mul_out = pattern->NewNode(mul_out_repr())
+                     ->AsOutput()
+                     ->assert_is_op_output("elementwise_mul", "Out")
+                     ->assert_is_op_input("elementwise_add", "X");
+  elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out});
+  auto add_out = pattern->NewNode(add_out_repr())
+                     ->AsOutput()
+                     ->assert_is_op_output("elementwise_add", "Out");
+  elementwise_add->LinksFrom({mul_out, mul_x}).LinksTo({add_out});
+}
+}  // namespace patterns
+
+class ElementwiseMulAddFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void FuseElementwiseMulAdd(ir::Graph* graph) const;
+  void FuseElementwiseMulAddWithOnlyXY(ir::Graph* graph) const;
+
+  const std::string name_scope_{"elementwise_mul_add_fuse_pass"};
+};
+
+void ElementwiseMulAddFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph,
+      platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+
+  FuseElementwiseMulAdd(graph);
+  FuseElementwiseMulAddWithOnlyXY(graph);
+}
+
+void ElementwiseMulAddFusePass::FuseElementwiseMulAdd(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::ElementwiseMulAddFusePass pattern(gpd.mutable_pattern(),
+                                              name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ElementwiseMulAddFusePass";
+    // declare operator node's name
+    GET_IR_NODE(elementwise_mul);
+    GET_IR_NODE(elementwise_add);
+    // declare variable node's name
+    GET_IR_NODE(mul_x);
+    GET_IR_NODE(mul_y);
+    GET_IR_NODE(mul_out);
+    GET_IR_NODE(add_w);
+    GET_IR_NODE(add_out);
+
+    bool flag = true;
+    auto var_type = mul_x->Var()->GetDataType();
+    if (var_type != proto::VarType::FP16 &&
+        var_type != proto::VarType::FP32) {
+      flag = false;
+    }
+
+    auto x_shape = mul_x->Var()->GetShape();
+    auto y_shape = mul_y->Var()->GetShape();
+    auto w_shape = add_w->Var()->GetShape();
+    if (x_shape.size() == y_shape.size() && x_shape.size() == w_shape.size()) {
+      for (size_t i = 0; i < x_shape.size(); ++i) {
+        if (x_shape[i] != y_shape[i] || x_shape[i] != w_shape[i] ||
+            x_shape[i] == -1) {
+          flag = false;
+        }
+      }
+    } else {
+      flag = false;
+    }
+
+    if (flag) {
+      auto* block = elementwise_mul->Op()->Block();
+
+      // nodes to delete after fusion
+      std::unordered_set<const Node*> delete_nodes;
+
+      // Generate addcmul_xpu op
+      framework::OpDesc fused_op_desc(block);
+      fused_op_desc.SetType("addcmul_xpu");
+      fused_op_desc.SetInput("x", {mul_x->Name()});
+      fused_op_desc.SetInput("y", {mul_y->Name()});
+      fused_op_desc.SetInput("w", {add_w->Name()});
+      fused_op_desc.SetOutput("out", {add_out->Name()});
+      auto* fused_op = graph->CreateOpNode(&fused_op_desc);
+      IR_NODE_LINK_TO(mul_x, fused_op);
+      IR_NODE_LINK_TO(mul_y, fused_op);
+      IR_NODE_LINK_TO(add_w, fused_op);
+      IR_NODE_LINK_TO(fused_op, add_out);
+      delete_nodes.insert({elementwise_mul, elementwise_add, mul_out});
+      GraphSafeRemoveNodes(graph, delete_nodes);
+      found_subgraph_count++;
+    }
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+void ElementwiseMulAddFusePass::FuseElementwiseMulAddWithOnlyXY(
+    ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::ElementwiseMulAddFuseXYPattern pattern(gpd.mutable_pattern(),
+                                                   name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ElementwiseMulAddFusePass";
+    // declare operator node's name
+    GET_IR_NODE(elementwise_mul);
+    GET_IR_NODE(elementwise_add);
+    // declare variable node's name
+    GET_IR_NODE(mul_x);
+    GET_IR_NODE(mul_y);
+    GET_IR_NODE(mul_out);
+    GET_IR_NODE(add_out);
+
+    bool flag = true;
+    auto var_type = mul_x->Var()->GetDataType();
+    if (var_type != proto::VarType::FP16 &&
+        var_type != proto::VarType::FP32) {
+      flag = false;
+    }
+
+    auto x_shape = mul_x->Var()->GetShape();
+    auto y_shape = mul_y->Var()->GetShape();
+    if (x_shape.size() == y_shape.size()) {
+      for (size_t i = 0; i < x_shape.size(); ++i) {
+        if (x_shape[i] != y_shape[i] || x_shape[i] == -1) {
+          flag = false;
+        }
+      }
+    } else {
+      flag = false;
+    }
+
+    if (flag) {
+      auto* block = elementwise_mul->Op()->Block();
+
+      // nodes to delete after fusion
+      std::unordered_set<const Node*> delete_nodes;
+
+      // Generate addcmul_xpu op
+      framework::OpDesc fused_op_desc(block);
+      fused_op_desc.SetType("addcmul_xpu");
+      fused_op_desc.SetInput("x", {mul_x->Name()});
+      fused_op_desc.SetInput("y", {mul_y->Name()});
+      fused_op_desc.SetInput("w", {mul_x->Name()});
+      fused_op_desc.SetOutput("out", {add_out->Name()});
+      auto* fused_op = graph->CreateOpNode(&fused_op_desc);
+      IR_NODE_LINK_TO(mul_x, fused_op);
+      IR_NODE_LINK_TO(mul_y, fused_op);
+      IR_NODE_LINK_TO(fused_op, add_out);
+      delete_nodes.insert({elementwise_mul, elementwise_add, mul_out});
+      GraphSafeRemoveNodes(graph, delete_nodes);
+      found_subgraph_count++;
+    }
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(elementwise_mul_add_fuse_pass,
+              paddle::framework::ir::ElementwiseMulAddFusePass);
+
+REGISTER_PASS_CAPABILITY(elementwise_mul_add_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .GE("elementwise_add", 0)
.LE("elementwise_add", 1) + .GE("elementwise_mul", 0) + .LE("elementwise_mul", 1)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 31d044f8c0b489..0c5423fe4d9152 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -552,6 +552,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fast_layernorm_xpu_fuse_pass", "yolo_box_xpu_fuse_pass", "fast_where_xpu_fuse_pass", + "elementwise_mul_add_fuse_pass", "link_xpu_op_max_pass", "delete_isolated_node_pass", // "auto_mixed_precision_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 09ccd2fe7d87d5..09e40595bedde7 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -23,6 +23,15 @@ func : add_layernorm_xpu data_type : x +- op : addcmul_xpu + args : (Tensor x, Tensor y, Tensor w) + output : Tensor(out) + infer_meta : + func : AddCMulXPUInferMeta + kernel : + func : addcmul_xpu + data_type : x + - op : conv1d_xpu args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, str padding_algorithm, int dilations, int strides, int groups, int act_type, float act_param) output : Tensor(out), Tensor(out_max) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 9154d1aa092469..d52769723e34f8 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -36,6 +36,8 @@ XPUOpMap& get_kl2_ops() { {"adam_dense_param_sparse_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"adagrad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"addcmul_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"arg_max", XPUKernelSet({phi::DataType::INT32, phi::DataType::FLOAT32, @@ -161,6 +163,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, phi::DataType::FLOAT64, + phi::DataType::BOOL, + phi::DataType::INT8, phi::DataType::INT64, phi::DataType::INT32})}, {"conv2d_grad", diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 993fb5d5887b82..1baf780f67610c 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -821,6 +821,15 @@ void FastLayernormXPUInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void AddCMulXPUInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& w, + MetaTensor* out) { + out->set_dims(x.dims()); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + void FusedScaleBiasReluConvBnstatsInferMeta( const MetaTensor& x, const MetaTensor& w, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 3d7ba19c4ec3f5..ee41d55ca5524a 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -201,6 +201,11 @@ void FastLayernormXPUInferMeta(const MetaTensor& x, float epsilon, MetaTensor* out); +void AddCMulXPUInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& w, + MetaTensor* out); + void FusedScaleBiasReluConvBnstatsInferMeta( const MetaTensor& x, const MetaTensor& w, diff --git a/paddle/phi/kernels/fusion/xpu/addcmul_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/addcmul_xpu_kernel.cc new file mode 100644 index 00000000000000..57c71bcd4bd7da --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/addcmul_xpu_kernel.cc @@ -0,0 +1,61 @@ +// 
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace fusion {
+
+// Computes out = x * y + w elementwise.
+template <typename T, typename Context>
+void AddCMulXPUKernel(const Context& ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& y,
+                      const DenseTensor& w,
+                      DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  const auto* x_data = x.data<T>();
+  const auto* y_data = y.data<T>();
+  const auto* w_data = w.data<T>();
+
+  auto* out_data = ctx.template Alloc<T>(out);
+
+#ifdef PADDLE_WITH_XPU_PLUGIN
+  int r = xpu::plugin::fast_addcmul(ctx.x_context(),
+                                    reinterpret_cast<const XPUType*>(w_data),
+                                    reinterpret_cast<const XPUType*>(x_data),
+                                    reinterpret_cast<const XPUType*>(y_data),
+                                    reinterpret_cast<XPUType*>(out_data),
+                                    x.numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_addcmul");
+#else
+  int r = xpu::addcmul(ctx.x_context(),
+                       reinterpret_cast<const XPUType*>(w_data),
+                       reinterpret_cast<const XPUType*>(x_data),
+                       reinterpret_cast<const XPUType*>(y_data),
+                       reinterpret_cast<XPUType*>(out_data),
+                       1.0f,
+                       x.numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "addcmul");
+#endif
+}
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(addcmul_xpu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::AddCMulXPUKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/concat_kernel.cc b/paddle/phi/kernels/xpu/concat_kernel.cc
index f1fac997061c59..5afdf2612981eb 100644
--- a/paddle/phi/kernels/xpu/concat_kernel.cc
+++ b/paddle/phi/kernels/xpu/concat_kernel.cc
@@ -119,4 +119,6 @@ PD_REGISTER_KERNEL(concat,
                    double,
                    phi::dtype::float16,
                    int64_t,
-                   int) {}
+                   int,
+                   int8_t,
+                   bool) {}
diff --git a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
index 1357ca43001c8b..2038fef8023938 100644
--- a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
+++ b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
@@ -114,6 +114,9 @@ DLL_EXPORT int fast_embedding(Context* ctx,
                               int64_t ym,
                               int64_t padding_idx,
                               TID start_index = 0);
+template <typename T>
+DLL_EXPORT int fast_addcmul(
+    Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len);
 
 }  // namespace plugin
 }  // namespace api
diff --git a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_addcmul.xpu b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_addcmul.xpu
new file mode 100644
index 00000000000000..dd2f2f6488a5b0
--- /dev/null
+++ b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_addcmul.xpu
@@ -0,0 +1,77 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/*
+ * copyright (C) 2022 KUNLUNXIN, Inc
+ */
+
+#include "xpu/kernel/cluster.h"
+#include "xpu/kernel/cluster_partition.h"
+#include "xpu/kernel/cluster_primitive.h"
+
+namespace xpu2 {
+namespace plugin {
+
+// In-place tile update: x = x * y + x (used when the addend aliases x).
+template <typename T>
+static inline __device__ void primitive_addcmul(T* x, const T* y, int len) {
+  float32x16_t vx0;
+  float32x16_t vy0;
+  float32x16_t vx1;
+  float32x16_t vy1;
+  for (int i = 0; i < len; i += 32) {
+    vload2_lm(x + i, vx0, vx1);
+    vload2_lm(y + i, vy0, vy1);
+    vx0 = vvmac_float32x16(vx0, vy0, vx0);
+    vx1 = vvmac_float32x16(vx1, vy1, vx1);
+    vstore2_lm(x + i, vx0, vx1);
+  }
+  mfence_lm();
+}
+
+template <typename T>
+__global__ void fast_addcmul(const T* x, const T* y, T* z, int64_t len) {
+  int cid = core_id();
+  const int ncores = core_num();
+  int tid = cid * cluster_num() + cluster_id();
+  int nthreads = cluster_num() * ncores;
+  const int buf_len = 512 / sizeof(T);
+  __simd__ float local_x_after_cast[buf_len];
+  __simd__ float local_y_after_cast[buf_len];
+  T* local_x = (T*)(local_x_after_cast);
+  T* local_y = (T*)(local_y_after_cast);
+
+  int loop = 0;
+  for (int64_t i = tid * buf_len; i < len; i += nthreads * buf_len) {
+    int read_len = min(static_cast<int64_t>(buf_len), len - i);
+    GM2LM_ASYNC(x + i, local_x, read_len * sizeof(T));
+    GM2LM(y + i, local_y, read_len * sizeof(T));
+    primitive_addcmul(local_x, local_y, read_len);
+    LM2GM_ASYNC(local_x, z + i, read_len * sizeof(T));
+    mfence_lm();
+#ifndef __XPU3__
+    loop++;
+    if ((loop & 0xF) == 0) {
+      sync_all();
+    }
+#endif
+  }
+}
+
+#define _XPU_DEF__FAST_ADDCMUL_(DTYPE)          \
+  template __global__ void fast_addcmul(        \
+      const DTYPE* x, const DTYPE* y, DTYPE* z, int64_t len);
+_XPU_DEF__FAST_ADDCMUL_(float);
+_XPU_DEF__FAST_ADDCMUL_(float16);
+
+}  // namespace plugin
+}  // namespace xpu2
diff --git a/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_addcmul.cpp b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_addcmul.cpp
new file mode 100644
index 00000000000000..a333cbd7a43a23
--- /dev/null
+++ b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_addcmul.cpp
@@ -0,0 +1,76 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/*
+ * copyright (C) 2022 KUNLUNXIN, Inc
+ */
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+#include "xpu/refactor/util/vector_util.h"
+
+namespace xpu2 {
+namespace plugin {
+template <typename T>
+__attribute__((global)) void fast_addcmul(const T* x,
+                                          const T* y,
+                                          T* z,
+                                          int64_t len);
+}  // namespace plugin
+}  // namespace xpu2
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+template <typename T>
+static int xpu2_wrapper(
+    Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) {
+  // Use the fused plugin kernel only when the addend aliases x (z = x * y + x);
+  // otherwise fall back to the generic xpu::addcmul.
+  if (x == w) {
+    xpu2::plugin::fast_addcmul<T>
+        <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, z, len);
+  } else {
+    return addcmul(ctx, w, x, y, z, 1.0f, len);
+  }
+  return SUCCESS;
+}
+
+template <typename T>
+int fast_addcmul(
+    Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_addcmul", T);
+  WRAPPER_DUMP_PARAM4(ctx, w, x, y, z);
+  WRAPPER_DUMP_PARAM2(ctx, len, ctx->_l3_mgr.get_size());
+  WRAPPER_DUMP(ctx);
+  WRAPPER_CHECK_4PTRS(ctx, T, len, w, x, y, z);
+  if (ctx->dev().type() == api::kXPU2) {
+    return xpu2_wrapper<T>(ctx, w, x, y, z, len);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+template int fast_addcmul(
+    Context*, const float*, const float*, const float*, float*, int64_t);
+template int fast_addcmul(Context*,
+                          const float16*,
+                          const float16*,
+                          const float16*,
+                          float16*,
+                          int64_t);
+
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/test/ir/inference/test_xpu_elementwise_mul_add_fuse_pass.py b/test/ir/inference/test_xpu_elementwise_mul_add_fuse_pass.py
new file mode 100644
index 00000000000000..48603acf90de9f
--- /dev/null
+++ b/test/ir/inference/test_xpu_elementwise_mul_add_fuse_pass.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
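+
+# Builds an elementwise_mul followed by an elementwise_add on inputs of
+# identical, fully static shape and checks that elementwise_mul_add_fuse_pass
+# rewrites the subgraph into a single addcmul_xpu op (out = mul_x * mul_y + add_w).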
+
+import unittest
+from functools import partial
+
+import hypothesis.strategies as st
+import numpy as np
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestElementwiseMulAddFusePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ["addcmul_xpu"], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        x_shape = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=4), min_size=3, max_size=4
+            )
+        )
+
+        def generate_data(shape):
+            return np.random.random(shape).astype(np.float32)
+
+        mul_op = OpConfig(
+            "elementwise_mul",
+            inputs={"X": ["mul_x"], "Y": ["mul_y"]},
+            outputs={"Out": ["mul_out"]},
+        )
+
+        add_op = OpConfig(
+            "elementwise_add",
+            inputs={"X": ["mul_out"], "Y": ["add_w"]},
+            outputs={"Out": ["add_out"]},
+        )
+
+        ops = [mul_op, add_op]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            inputs={
+                "mul_x": TensorConfig(data_gen=partial(generate_data, x_shape)),
+                "mul_y": TensorConfig(data_gen=partial(generate_data, x_shape)),
+                "add_w": TensorConfig(data_gen=partial(generate_data, x_shape)),
+            },
+            weights={},
+            outputs=["add_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["elementwise_mul_add_fuse_pass"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
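For reference, the fused addcmul_xpu op computes out = x * y + w elementwise (and out = x * y + x in the X/Y-only pattern). A minimal NumPy sketch of the expected result, assuming float32 inputs of identical static shape as required by the pass; the shapes below are only illustrative:

    import numpy as np

    # Hypothetical inputs of identical static shape.
    x = np.random.random((2, 3, 4)).astype(np.float32)
    y = np.random.random((2, 3, 4)).astype(np.float32)
    w = np.random.random((2, 3, 4)).astype(np.float32)

    # Reference for addcmul_xpu: elementwise_mul followed by elementwise_add.
    expected = x * y + w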