diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 559a34cba02384..7e60b498d501ee 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -191,7 +191,14 @@ XPUOpMap& get_kl3_ops() { {"bitwise_not", XPUKernelSet({phi::DataType::BOOL})}, {"bitwise_and", XPUKernelSet({phi::DataType::BOOL})}, {"bitwise_or", XPUKernelSet({phi::DataType::BOOL})}, - {"broadcast", XPUKernelSet({phi::DataType::FLOAT32})}, + {"broadcast", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL, + phi::DataType::UINT8, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"c_allgather", XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32, @@ -690,6 +697,19 @@ XPUOpMap& get_kl3_ops() { phi::DataType::FLOAT32, phi::DataType::FLOAT16, phi::DataType::BFLOAT16})}, + {"full_like", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, + {"full_batch_size_like", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"fused_multi_transformer_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fused_rotary_position_embedding", @@ -1049,11 +1069,11 @@ XPUOpMap& get_kl3_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"prod_raw", XPUKernelSet({phi::DataType::FLOAT32})}, {"put_along_axis", - XPUKernelSet({ - phi::DataType::FLOAT32, - phi::DataType::FLOAT16, - phi::DataType::BFLOAT16, - })}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BFLOAT16, + phi::DataType::FLOAT16})}, {"quantize_linear_deprecated_infer", XPUKernelSet({phi::DataType::FLOAT32})}, {"quantize_linear", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -1700,6 +1720,12 @@ XPUOpMap& get_kl3_ops() { // Fused op {"resnet_basic_block_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"resnet_basic_block", XPUKernelSet({phi::DataType::FLOAT32})}, + {"fused_bias_act", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})}, + {"fused_bias_residual_layernorm", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::BFLOAT16, + phi::DataType::FLOAT16})}, {"fused_gemm_epilogue", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, @@ -1726,6 +1752,11 @@ XPUOpMap& get_kl3_ops() { phi::DataType::FLOAT64, phi::DataType::INT32, phi::DataType::INT64})}, + {"mp_allreduce_sum", + XPUKernelSet({phi::DataType::BFLOAT16, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::INT32})}, {"blha_get_max_len", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, }; diff --git a/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc index 1bc5c47fdbc863..9417b55e6d8a67 100644 --- a/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc @@ -139,4 +139,5 @@ PD_REGISTER_KERNEL(fused_bias_act, ALL_LAYOUT, phi::fusion::FusedBiasActKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc index 2c506c7f17b5c3..c80286eb7691a6 100644 --- a/paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc @@ -134,6 +134,9 @@ void FusedLayerNormKernel(const Context& dev_ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); } if (residual) { + if (std::is_same::value) { + PD_THROW("NOT supported quant bfloat16. "); + } r = baidu::xpu::api::add_layer_norm_fusion( xpu_ctx->x_context(), reinterpret_cast(x.data()), @@ -179,4 +182,5 @@ PD_REGISTER_KERNEL(fused_bias_residual_layernorm, ALL_LAYOUT, phi::fusion::FusedLayerNormKernel, float, + phi::dtype::bfloat16, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/broadcast_kernel.cc b/paddle/phi/kernels/xpu/broadcast_kernel.cc new file mode 100644 index 00000000000000..8fc4aad4d1ae4f --- /dev/null +++ b/paddle/phi/kernels/xpu/broadcast_kernel.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/all_to_all_kernel.h" +#if defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/phi/core/distributed/bkcl_comm_context.h" +#endif + +namespace phi { + +template +void BroadcastKernel(const Context& dev_ctx, + const DenseTensor& x, + int root, + DenseTensor* out) { +#if defined(PADDLE_WITH_XPU_BKCL) + PADDLE_ENFORCE_GT(x.numel(), + 0, + common::errors::InvalidArgument( + "Tensor need be broadcast must not empty.")); + + dev_ctx.template Alloc(out); + auto comm_context = + static_cast(dev_ctx.GetCommContext()); + PADDLE_ENFORCE_NE( + comm_context, + nullptr, + errors::Unavailable("BKCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + comm_context->Broadcast(out, x, root, comm_context->GetStream()); +#else + PADDLE_THROW(common::errors::PreconditionNotMet( + "PaddlePaddle should be compiled with XPU.")); +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL(broadcast, + XPU, + ALL_LAYOUT, + phi::BroadcastKernel, + float, + double, + int, + bool, + int8_t, + uint8_t, + int16_t, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc index df2b6d1a315f31..8d2152c829c1fe 100644 --- a/paddle/phi/kernels/xpu/full_kernel.cc +++ b/paddle/phi/kernels/xpu/full_kernel.cc @@ -169,6 +169,7 @@ PD_REGISTER_KERNEL(full_with_tensor, int, int64_t, bool, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(0).SetBackend(phi::Backend::CPU); } diff --git a/paddle/phi/kernels/xpu/mp_allreduce_sum_kernel.cc b/paddle/phi/kernels/xpu/mp_allreduce_sum_kernel.cc index ba8886cc1834df..bb0e80c30c6ba8 100644 --- a/paddle/phi/kernels/xpu/mp_allreduce_sum_kernel.cc +++ b/paddle/phi/kernels/xpu/mp_allreduce_sum_kernel.cc @@ -31,4 +31,5 @@ PD_REGISTER_KERNEL(mp_allreduce_sum, phi::MpAllReduceSumKernel, float, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/put_along_axis_kernel.cc b/paddle/phi/kernels/xpu/put_along_axis_kernel.cc index 76029234df2127..a7b59cb0e28bd0 100644 --- a/paddle/phi/kernels/xpu/put_along_axis_kernel.cc +++ b/paddle/phi/kernels/xpu/put_along_axis_kernel.cc @@ -132,6 +132,8 @@ PD_REGISTER_KERNEL(put_along_axis, XPU, ALL_LAYOUT, phi::PutAlongAxisKernel, + float, + int64_t, + int, phi::dtype::float16, - phi::dtype::bfloat16, - float) {} + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/top_k_kernel.cc b/paddle/phi/kernels/xpu/top_k_kernel.cc index ff7f00e53f6895..0eb900dd455ba5 100644 --- a/paddle/phi/kernels/xpu/top_k_kernel.cc +++ b/paddle/phi/kernels/xpu/top_k_kernel.cc @@ -32,32 +32,47 @@ void TopkKernel(const Context& dev_ctx, using XPUType = typename XPUTypeTrait::Type; const auto& in_dims = x.dims(); + if (in_dims.size() == 0) { + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + dev_ctx.template Alloc(indices); + phi::funcs::set_constant(dev_ctx, indices, static_cast(0)); + return; + } + + // axis < 0, calculate the real axis + if (axis < 0) { + axis += in_dims.size(); + } + + int k = k_scalar.to(); + PADDLE_ENFORCE_GE( + x.numel(), + k, + errors::InvalidArgument( + "x has only %d element, can not find %d top values.", x.numel(), k)); + + if (k_scalar.FromTensor()) { + auto out_dims_ = out->dims(); + // according to axis to set K value in the dim + out_dims_[axis] = k; + out->Resize(out_dims_); + indices->Resize(out_dims_); + } + const T* in_data = x.data(); int64_t* indices_data = dev_ctx.template Alloc(indices); T* output_data = dev_ctx.template Alloc(out); const auto& out_dims = out->dims(); - PADDLE_ENFORCE_EQ( - sorted, - true, - errors::External( - "XPU API does not support unsorted topk operation currently." - " Operator will be supported in future update.")); - if (in_dims.size() == 0) { - int r = xpu::copy(dev_ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(out->data()), - x.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - - phi::funcs::set_constant(dev_ctx, indices, static_cast(0)); - - return; - } + // PADDLE_ENFORCE_EQ( + // sorted, + // true, + // errors::External( + // "XPU API does not support unsorted topk operation currently." + // " Operator will be supported in future update.")); if (axis < 0) axis += in_dims.size(); - size_t k = k_scalar.to(); if (axis + 1 == in_dims.size()) { xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); int32_t* indices_int_data = diff --git a/test/xpu/test_put_along_axis_op_int_xpu.py b/test/xpu/test_put_along_axis_op_int_xpu.py new file mode 100644 index 00000000000000..f88020329836fa --- /dev/null +++ b/test/xpu/test_put_along_axis_op_int_xpu.py @@ -0,0 +1,194 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import unittest + +import numpy as np +from get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, +) +from op_test import convert_float_to_uint16 +from op_test_xpu import XPUOpTest + +import paddle + +paddle.enable_static() + + +class XPUTestPutAlongAxisInt(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'put_along_axis' + + class TestXPUPutAlongAxisOpAssign(XPUOpTest): + def setUp(self): + + self.init_config() + self.init_data() + self.x = np.random.random(self.x_shape).astype( + self.dtype if self.dtype != np.uint16 else np.float32 + ) + self.value = np.random.random(self.index.shape).astype( + self.dtype if self.dtype != np.uint16 else np.float32 + ) + broadcast_shape_list = list(self.x_shape) + + self.broadcast_shape = tuple(broadcast_shape_list) + self.index_broadcast = np.broadcast_to( + self.index, self.broadcast_shape + ) + self.value_broadcast = np.broadcast_to( + self.value, self.broadcast_shape + ) + self.target = copy.deepcopy(self.x) + mean_record = {} + for i in range(self.index_broadcast.shape[0]): + for j in range(self.index_broadcast.shape[1]): + for k in range(self.index_broadcast.shape[2]): + loc_ = [i, j, k] + loc_[self.axis] = self.index_broadcast[i, j, k] + if self.reduce == "assign": + self.target[loc_[0], loc_[1], loc_[2]] = ( + self.value_broadcast[i, j, k] + ) + elif self.reduce == "add": + self.target[ + loc_[0], loc_[1], loc_[2] + ] += self.value_broadcast[i, j, k] + elif self.reduce == "mul" or self.reduce == "multiply": + self.target[ + loc_[0], loc_[1], loc_[2] + ] *= self.value_broadcast[i, j, k] + elif self.reduce == "mean": + self.target[ + loc_[0], loc_[1], loc_[2] + ] += self.value_broadcast[i, j, k] + loc = tuple(loc_) + if loc in mean_record.keys(): + mean_record[loc] += 1 + else: + mean_record[loc] = 1 + elif self.reduce == "amax": + self.target[loc_[0], loc_[1], loc_[2]] = max( + self.target[loc_[0], loc_[1], loc_[2]], + self.value_broadcast[i, j, k], + ) + elif self.reduce == "amin": + self.target[loc_[0], loc_[1], loc_[2]] = min( + self.target[loc_[0], loc_[1], loc_[2]], + self.value_broadcast[i, j, k], + ) + elif self.reduce == "max": + self.target[loc_[0], loc_[1], loc_[2]] = max( + self.target[loc_[0], loc_[1], loc_[2]], + self.value_broadcast[i, j, k], + ) + if self.reduce == "mean": + for loc in mean_record: + self.target[loc] /= mean_record[loc] + 1 + + self.inputs = { + 'Input': ( + self.x + if self.dtype != np.uint16 + else convert_float_to_uint16(self.x) + ), + 'Index': self.index_broadcast, + 'Value': ( + self.value_broadcast + if self.dtype != np.uint16 + else convert_float_to_uint16(self.value_broadcast) + ), + } + self.attrs = { + 'Axis': self.axis, + 'Reduce': self.reduce, + 'Include_self': True, + } + self.outputs = { + 'Result': ( + self.target + if self.dtype != np.uint16 + else convert_float_to_uint16(self.target) + ) + } + + def init_config(self): + self.op_type = "put_along_axis" + self.place = paddle.XPUPlace(0) + self.dtype = self.in_type + + def init_data(self): + self.x_shape = (10, 10, 10) + self.reduce = "assign" + self.index_type = np.int32 + self.index = np.array([[[0]]]).astype(self.index_type) + self.axis = 1 + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + self.check_grad_with_place( + self.place, ['Input', 'Value'], 'Result' + ) + + class TestAddCase1(TestXPUPutAlongAxisOpAssign): + def init_data(self): + self.in_type = self.dtype + self.reduce = "add" + self.x_shape = (10, 10, 10) + self.index_type = np.int64 + self.index = np.array([[[0]]]).astype(self.index_type) + self.axis = 1 + + class TestAddCase2(TestXPUPutAlongAxisOpAssign): + def init_data(self): + self.in_type = self.dtype + self.reduce = "add" + self.x_shape = (12, 14, 16) + self.index_type = np.int64 + self.index = np.random.randint(0, 12, size=(12, 14, 1)) + self.axis = 2 + + class TestMulCase1(TestXPUPutAlongAxisOpAssign): + def init_data(self): + self.in_type = self.dtype + self.reduce = "mul" + self.x_shape = (16, 6, 12) + self.index_type = np.int32 + self.index = np.random.randint(0, 6, size=(16, 1, 12)) + self.axis = 1 + + class TestMulCase2(TestXPUPutAlongAxisOpAssign): + def init_data(self): + self.in_type = self.dtype + self.reduce = "mul" + self.x_shape = (8, 6, 12) + self.index_type = np.int32 + self.index = np.random.randint(0, 6, size=(8, 1, 12)) + self.axis = 2 + + +support_types = get_xpu_op_support_types('put_along_axis') +for stype in support_types: + if stype == 'int32' or stype == 'int64': + create_test_class(globals(), XPUTestPutAlongAxisInt, stype) + +if __name__ == "__main__": + unittest.main() diff --git a/test/xpu/test_put_along_axis_op_xpu.py b/test/xpu/test_put_along_axis_op_xpu.py index 7da4b88dbad3cf..3cef0432bd0cf6 100644 --- a/test/xpu/test_put_along_axis_op_xpu.py +++ b/test/xpu/test_put_along_axis_op_xpu.py @@ -245,6 +245,8 @@ def init_data(self): support_types = get_xpu_op_support_types('put_along_axis') for stype in support_types: + if stype == 'int32' or stype == 'int64': + continue create_test_class(globals(), XPUTestPutAlongAxis, stype) if __name__ == "__main__":