
Commit c159715

add xpu dist op, fix topk, regist bfloat16
1 parent c1747b7 commit c159715

File tree

8 files changed, +151 −26 lines


paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 38 additions & 6 deletions
@@ -191,7 +191,14 @@ XPUOpMap& get_kl3_ops() {
       {"bitwise_not", XPUKernelSet({phi::DataType::BOOL})},
       {"bitwise_and", XPUKernelSet({phi::DataType::BOOL})},
       {"bitwise_or", XPUKernelSet({phi::DataType::BOOL})},
-      {"broadcast", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"broadcast",
+       XPUKernelSet({phi::DataType::INT32,
+                     phi::DataType::INT64,
+                     phi::DataType::BOOL,
+                     phi::DataType::UINT8,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16})},
       {"c_allgather",
        XPUKernelSet({phi::DataType::FLOAT16,
                      phi::DataType::FLOAT32,
@@ -690,6 +697,19 @@ XPUOpMap& get_kl3_ops() {
                      phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
                      phi::DataType::BFLOAT16})},
+      {"full_like",
+       XPUKernelSet({phi::DataType::INT64,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT64,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16})},
+      {"full_batch_size_like",
+       XPUKernelSet({phi::DataType::INT64,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16})},
       {"fused_multi_transformer_xpu",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"fused_rotary_position_embedding",
@@ -1049,11 +1069,12 @@ XPUOpMap& get_kl3_ops() {
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"prod_raw", XPUKernelSet({phi::DataType::FLOAT32})},
       {"put_along_axis",
-       XPUKernelSet({
-           phi::DataType::FLOAT32,
-           phi::DataType::FLOAT16,
-           phi::DataType::BFLOAT16,
-       })},
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::INT16,
+                     phi::DataType::INT64,
+                     phi::DataType::BFLOAT16,
+                     phi::DataType::FLOAT16})},
       {"quantize_linear_deprecated_infer",
        XPUKernelSet({phi::DataType::FLOAT32})},
       {"quantize_linear", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -1700,6 +1721,12 @@ XPUOpMap& get_kl3_ops() {
       // Fused op
       {"resnet_basic_block_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"resnet_basic_block", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"fused_bias_act",
+       XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})},
+      {"fused_bias_residual_layernorm",
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::BFLOAT16,
+                     phi::DataType::FLOAT16})},
       {"fused_gemm_epilogue",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
@@ -1726,6 +1753,11 @@ XPUOpMap& get_kl3_ops() {
                      phi::DataType::FLOAT64,
                      phi::DataType::INT32,
                      phi::DataType::INT64})},
+      {"mp_allreduce_sum",
+       XPUKernelSet({phi::DataType::BFLOAT16,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::INT32})},
       {"blha_get_max_len",
        XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
   };
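
Note on this file: each entry in the KL3 op map ties an op name to the set of data types its XPU3 kernel accepts, so widening an entry (as broadcast, full_like, put_along_axis, and mp_allreduce_sum are widened here) is what makes those dtypes dispatchable to XPU at all. A minimal sketch of how such a map can be consulted, assuming XPUOpMap is an unordered_map from op name to XPUKernelSet (a set of phi::DataType), as in this file; the helper below is hypothetical, not part of this commit:

// Hypothetical lookup helper: is op_type registered for dtype on KL3?
bool SupportedOnKL3(const std::string& op_type, phi::DataType dtype) {
  XPUOpMap& ops = get_kl3_ops();  // the map assembled in this file
  auto it = ops.find(op_type);    // unlisted ops fall back to CPU
  return it != ops.end() && it->second.count(dtype) > 0;
}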

paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc

Lines changed: 2 additions & 1 deletion
@@ -139,4 +139,5 @@ PD_REGISTER_KERNEL(fused_bias_act,
                    ALL_LAYOUT,
                    phi::fusion::FusedBiasActKernel,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
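
Registration note: each type appended to a PD_REGISTER_KERNEL list instantiates the kernel template once more with T bound to that type, and the registry keys the result by (name, backend, layout, dtype). A hedged sketch of the lookup this change enables, assuming the usual phi::KernelFactory interface (the exact call may differ across Paddle versions):

// With bfloat16 registered, this lookup can now resolve on XPU:
phi::KernelKey key(phi::Backend::XPU,
                   phi::DataLayout::ALL_LAYOUT,
                   phi::DataType::BFLOAT16);
auto kernel =
    phi::KernelFactory::Instance().SelectKernel("fused_bias_act", key);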

paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -134,6 +134,9 @@ void FusedLayerNormKernel(const Context& dev_ctx,
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
   }
   if (residual) {
+    if (std::is_same<T, phi::dtype::bfloat16>::value) {
+      PD_THROW("Quant with bfloat16 is not supported.");
+    }
     r = baidu::xpu::api::add_layer_norm_fusion(
         xpu_ctx->x_context(),
         reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -179,4 +182,5 @@ PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
                    ALL_LAYOUT,
                    phi::fusion::FusedLayerNormKernel,
                    float,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
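
A note on the guard added above: T is a template parameter, so the std::is_same test is decided at compile time even though it is written as a runtime branch; under C++17 the same intent can be expressed with if constexpr so the throwing branch is not compiled into the non-bfloat16 instantiations at all. A sketch of the equivalent form, assuming the surrounding kernel body is unchanged:

// Equivalent compile-time guard (C++17); same behavior as the diff:
if constexpr (std::is_same<T, phi::dtype::bfloat16>::value) {
  PD_THROW("Quant with bfloat16 is not supported.");
}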
New file: XPU broadcast kernel

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/all_to_all_kernel.h"
+#if defined(PADDLE_WITH_XPU_BKCL)
+#include "paddle/phi/core/distributed/bkcl_comm_context.h"
+#endif
+
+namespace phi {
+
+template <typename T, typename Context>
+void BroadcastKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     int root,
+                     DenseTensor* out) {
+#if defined(PADDLE_WITH_XPU_BKCL)
+  PADDLE_ENFORCE_GT(x.numel(),
+                    0,
+                    common::errors::InvalidArgument(
+                        "The tensor to be broadcast must not be empty."));
+
+  dev_ctx.template Alloc<T>(out);
+  auto comm_context =
+      static_cast<distributed::BKCLCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(
+      comm_context,
+      nullptr,
+      errors::Unavailable("BKCLCommContext is nullptr; collective ops "
+                          "should have a ring_id attr."));
+  comm_context->Broadcast(out, x, root, comm_context->GetStream());
+#else
+  PADDLE_THROW(common::errors::PreconditionNotMet(
+      "PaddlePaddle should be compiled with XPU."));
+#endif
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(broadcast,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::BroadcastKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
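
The new kernel is deliberately thin: it rejects empty inputs, allocates the output, and delegates the actual collective to the BKCL communicator attached to the device context. A minimal direct-invocation sketch, assuming dev_ctx is a phi::XPUContext that already carries a BKCLCommContext (process-group setup normally arranges this) and x is a device-resident tensor; everything except the kernel itself is a placeholder:

// Every rank calls this; afterwards each rank's out holds rank 0's x.
phi::DenseTensor out;
out.Resize(x.dims());  // shape must match on all ranks
phi::BroadcastKernel<float, phi::XPUContext>(dev_ctx, x, /*root=*/0, &out);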

paddle/phi/kernels/xpu/full_kernel.cc

Lines changed: 2 additions & 1 deletion
@@ -169,6 +169,7 @@ PD_REGISTER_KERNEL(full_with_tensor,
                    int,
                    int64_t,
                    bool,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(0).SetBackend(phi::Backend::CPU);
 }

paddle/phi/kernels/xpu/mp_allreduce_sum_kernel.cc

Lines changed: 2 additions & 1 deletion
@@ -31,4 +31,5 @@ PD_REGISTER_KERNEL(mp_allreduce_sum,
                    phi::MpAllReduceSumKernel,
                    float,
                    int,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}

paddle/phi/kernels/xpu/put_along_axis_kernel.cc

Lines changed: 4 additions & 2 deletions
@@ -132,6 +132,8 @@ PD_REGISTER_KERNEL(put_along_axis,
                    XPU,
                    ALL_LAYOUT,
                    phi::PutAlongAxisKernel,
+                   float,
+                   int64_t,
+                   int,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16,
-                   float) {}
+                   phi::dtype::bfloat16) {}

paddle/phi/kernels/xpu/top_k_kernel.cc

Lines changed: 34 additions & 15 deletions
@@ -32,32 +32,51 @@ void TopkKernel(const Context& dev_ctx,
   using XPUType = typename XPUTypeTrait<T>::Type;
 
   const auto& in_dims = x.dims();
-  const T* in_data = x.data<T>();
-  int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
-  T* output_data = dev_ctx.template Alloc<T>(out);
-
-  const auto& out_dims = out->dims();
-
-  PADDLE_ENFORCE_EQ(
-      sorted,
-      true,
-      errors::External(
-          "XPU API does not support unsorted topk operation currently."
-          " Operator will be supported in future update."));
   if (in_dims.size() == 0) {
     int r = xpu::copy<XPUType>(dev_ctx.x_context(),
                                reinterpret_cast<const XPUType*>(x.data<T>()),
                                reinterpret_cast<XPUType*>(out->data<T>()),
                                x.numel());
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
-
+    dev_ctx.template Alloc<int64_t>(indices);
     phi::funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));
-
     return;
   }
+
+  // axis < 0, calculate the real axis
+  if (axis < 0) {
+    axis += in_dims.size();
+  }
+
+  int k = k_scalar.to<int>();
+  PADDLE_ENFORCE_GE(
+      x.numel(),
+      k,
+      errors::InvalidArgument(
+          "x has only %d elements, cannot find %d top values.", x.numel(), k));
+
+  if (k_scalar.FromTensor()) {
+    auto out_dims_ = out->dims();
+    // according to axis, set the K value in the dims
+    out_dims_[axis] = k;
+    out->Resize(out_dims_);
+    indices->Resize(out_dims_);
+  }
+
+  const T* in_data = x.data<T>();
+  int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
+  T* output_data = dev_ctx.template Alloc<T>(out);
+
+  const auto& out_dims = out->dims();
+
+  // PADDLE_ENFORCE_EQ(
+  //     sorted,
+  //     true,
+  //     errors::External(
+  //         "XPU API does not support unsorted topk operation currently."
+  //         " Operator will be supported in future update."));
   if (axis < 0) axis += in_dims.size();
 
-  size_t k = k_scalar.to<int>();
   if (axis + 1 == in_dims.size()) {
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
     int32_t* indices_int_data =