Skip to content

[XPU]add xpu dist op, fix topk, regist bfloat16 #71623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,14 @@ XPUOpMap& get_kl3_ops() {
{"bitwise_not", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_and", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_or", XPUKernelSet({phi::DataType::BOOL})},
{"broadcast", XPUKernelSet({phi::DataType::FLOAT32})},
{"broadcast",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"c_allgather",
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
Expand Down Expand Up @@ -690,6 +697,19 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"full_like",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"full_batch_size_like",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"fused_multi_transformer_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fused_rotary_position_embedding",
Expand Down Expand Up @@ -1049,11 +1069,11 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"prod_raw", XPUKernelSet({phi::DataType::FLOAT32})},
{"put_along_axis",
XPUKernelSet({
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT16})},
{"quantize_linear_deprecated_infer",
XPUKernelSet({phi::DataType::FLOAT32})},
{"quantize_linear", XPUKernelSet({phi::DataType::FLOAT32})},
Expand Down Expand Up @@ -1700,6 +1720,12 @@ XPUOpMap& get_kl3_ops() {
// Fused op
{"resnet_basic_block_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"resnet_basic_block", XPUKernelSet({phi::DataType::FLOAT32})},
{"fused_bias_act",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})},
{"fused_bias_residual_layernorm",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT16})},
{"fused_gemm_epilogue",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
Expand All @@ -1726,6 +1752,11 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::FLOAT64,
phi::DataType::INT32,
phi::DataType::INT64})},
{"mp_allreduce_sum",
XPUKernelSet({phi::DataType::BFLOAT16,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
phi::DataType::INT32})},
{"blha_get_max_len",
XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
};
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,5 @@ PD_REGISTER_KERNEL(fused_bias_act,
ALL_LAYOUT,
phi::fusion::FusedBiasActKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
4 changes: 4 additions & 0 deletions paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ void FusedLayerNormKernel(const Context& dev_ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
}
if (residual) {
if (std::is_same<T, phi::dtype::bfloat16>::value) {
PD_THROW("NOT supported quant bfloat16. ");
}
r = baidu::xpu::api::add_layer_norm_fusion(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
Expand Down Expand Up @@ -179,4 +182,5 @@ PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
ALL_LAYOUT,
phi::fusion::FusedLayerNormKernel,
float,
phi::dtype::bfloat16,
phi::dtype::float16) {}
65 changes: 65 additions & 0 deletions paddle/phi/kernels/xpu/broadcast_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/all_to_all_kernel.h"
#if defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/phi/core/distributed/bkcl_comm_context.h"
#endif

namespace phi {

template <typename T, typename Context>
void BroadcastKernel(const Context& dev_ctx,
const DenseTensor& x,
int root,
DenseTensor* out) {
#if defined(PADDLE_WITH_XPU_BKCL)
PADDLE_ENFORCE_GT(x.numel(),
0,
common::errors::InvalidArgument(
"Tensor need be broadcast must not empty."));

dev_ctx.template Alloc<T>(out);
auto comm_context =
static_cast<distributed::BKCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_context,
nullptr,
errors::Unavailable("BKCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_context->Broadcast(out, x, root, comm_context->GetStream());
#else
PADDLE_THROW(common::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with XPU."));
#endif
}

} // namespace phi

PD_REGISTER_KERNEL(broadcast,
XPU,
ALL_LAYOUT,
phi::BroadcastKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/full_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ PD_REGISTER_KERNEL(full_with_tensor,
int,
int64_t,
bool,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/mp_allreduce_sum_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ PD_REGISTER_KERNEL(mp_allreduce_sum,
phi::MpAllReduceSumKernel,
float,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
6 changes: 4 additions & 2 deletions paddle/phi/kernels/xpu/put_along_axis_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ PD_REGISTER_KERNEL(put_along_axis,
XPU,
ALL_LAYOUT,
phi::PutAlongAxisKernel,
float,
int64_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
float) {}
phi::dtype::bfloat16) {}
51 changes: 33 additions & 18 deletions paddle/phi/kernels/xpu/top_k_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,47 @@ void TopkKernel(const Context& dev_ctx,
using XPUType = typename XPUTypeTrait<T>::Type;

const auto& in_dims = x.dims();
if (in_dims.size() == 0) {
phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
dev_ctx.template Alloc<int64_t>(indices);
phi::funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));
return;
}

// axis < 0, calculate the real axis
if (axis < 0) {
axis += in_dims.size();
}

int k = k_scalar.to<int>();
PADDLE_ENFORCE_GE(
x.numel(),
k,
errors::InvalidArgument(
"x has only %d element, can not find %d top values.", x.numel(), k));

if (k_scalar.FromTensor()) {
auto out_dims_ = out->dims();
// according to axis to set K value in the dim
out_dims_[axis] = k;
out->Resize(out_dims_);
indices->Resize(out_dims_);
}

const T* in_data = x.data<T>();
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
T* output_data = dev_ctx.template Alloc<T>(out);

const auto& out_dims = out->dims();

PADDLE_ENFORCE_EQ(
sorted,
true,
errors::External(
"XPU API does not support unsorted topk operation currently."
" Operator will be supported in future update."));
if (in_dims.size() == 0) {
int r = xpu::copy<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");

phi::funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));

return;
}
// PADDLE_ENFORCE_EQ(
// sorted,
// true,
// errors::External(
// "XPU API does not support unsorted topk operation currently."
// " Operator will be supported in future update."));
if (axis < 0) axis += in_dims.size();

size_t k = k_scalar.to<int>();
if (axis + 1 == in_dims.size()) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int32_t* indices_int_data =
Expand Down
Loading