Skip to content

Commit 978113d

Browse files
authored
XPU Update xft and weight_quant (PaddlePaddle#72053)
1 parent 9f42e45 commit 978113d

File tree

10 files changed

+129
-40
lines changed

10 files changed

+129
-40
lines changed

cmake/external/xpu.cmake

+7-7
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ if(NOT DEFINED XPU_XHPC_BASE_DATE)
3434
endif()
3535
set(XPU_XCCL_BASE_VERSION "3.0.2.5") # For XRE5
3636
if(NOT DEFINED XPU_XFT_BASE_VERSION)
37-
set(XPU_XFT_BASE_VERSION "20230602")
37+
set(XPU_XFT_BASE_VERSION "20250402/xpu3")
3838
endif()
3939

4040
if(NOT DEFINED XPU_XRE_BASE_VERSION)
4141
if(WITH_XPU_XRE5)
42-
set(XPU_XRE_BASE_VERSION "5.0.21.18")
42+
set(XPU_XRE_BASE_VERSION "5.0.21.19")
4343
else()
4444
set(XPU_XRE_BASE_VERSION "4.32.0.1")
4545
endif()
@@ -61,7 +61,7 @@ set(XPU_XCCL_BASE_URL
6161

6262
if(NOT XPU_XFT_BASE_URL)
6363
set(XPU_XFT_BASE_URL
64-
"https://klx-sdk-release-public.su.bcebos.com/xft/dev/${XPU_XFT_BASE_VERSION}"
64+
"https://klx-sdk-release-public.su.bcebos.com/xft_internal/dev/${XPU_XFT_BASE_VERSION}"
6565
)
6666
endif()
6767

@@ -112,7 +112,7 @@ else()
112112
set(XPU_XHPC_DIR_NAME "xhpc-ubuntu1604_x86_64")
113113
endif()
114114
set(XPU_XCCL_DIR_NAME "xccl_Linux_x86_64")
115-
set(XPU_XFT_DIR_NAME "xft_ubuntu1604_x86_64")
115+
set(XPU_XFT_DIR_NAME "xft_internal_ubuntu2004")
116116
endif()
117117

118118
set(XPU_XRE_URL
@@ -187,9 +187,9 @@ if(DEFINED ENV{XPU_LIB_ROOT})
187187
endif()
188188

189189
# XCCL
190-
if(DEFINED ENV{XCCL_DIR_NAME})
191-
set(XPU_XCCL_URL "${XPU_LIB_ROOT}/$ENV{XCCL_DIR_NAME}")
192-
set(XCCL_DIR_NAME "$ENV{XCCL_DIR_NAME}")
190+
if(DEFINED ENV{XPU_XCCL_DIR_NAME})
191+
set(XPU_XCCL_URL "${XPU_LIB_ROOT}/$ENV{XPU_XCCL_DIR_NAME}")
192+
set(XPU_XCCL_DIR_NAME "$ENV{XPU_XCCL_DIR_NAME}")
193193
endif()
194194

195195
# XHPC

paddle/fluid/framework/ir/xpu/weight_only_linear_xpu_pass.cc

+6-6
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ PermuteINT8WeightOnlyPattern::PermuteINT8WeightOnlyPattern(
4444
PDPattern* pattern, const std::string& name_scope)
4545
: PatternBase(pattern, name_scope, name_scope) {
4646
auto* input = pattern->NewNode(input_repr())
47-
->assert_is_op_input("weight_only_linear_xpu", "x")
47+
->assert_is_op_input("weight_only_linear", "x")
4848
->AsInput();
4949
auto* weight = pattern->NewNode(weight_repr())
50-
->assert_is_op_input("weight_only_linear_xpu", "weight")
50+
->assert_is_op_input("weight_only_linear", "weight")
5151
->AsInput();
5252
auto* weight_scale =
5353
pattern->NewNode(weight_scale_repr())
54-
->assert_is_op_input("weight_only_linear_xpu", "weight_scale")
54+
->assert_is_op_input("weight_only_linear", "weight_scale")
5555
->AsInput();
5656
auto* out = pattern->NewNode(out_repr())
57-
->assert_is_op_output("weight_only_linear_xpu", "out")
57+
->assert_is_op_output("weight_only_linear", "out")
5858
->AsOutput();
5959
auto* weight_only_linear = pattern->NewNode(weight_only_linear_repr())
60-
->assert_is_op("weight_only_linear_xpu");
60+
->assert_is_op("weight_only_linear");
6161

6262
std::vector<PDNode*> input_vars{input, weight, weight_scale};
6363
std::vector<PDNode*> output_vars{out};
@@ -236,4 +236,4 @@ REGISTER_PASS(weight_only_linear_xpu_pass,
236236
REGISTER_PASS_CAPABILITY(weight_only_linear_xpu_pass)
237237
.AddCombination(
238238
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
239-
"weight_only_linear_xpu", 0));
239+
"weight_only_linear", 0));

paddle/fluid/framework/new_executor/instruction/instruction_util.cc

+18-3
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
107107

108108
// only gpu need update. xpu not need, because xpu memcpy op kernel is
109109
// synchronous.
110-
if (phi::is_gpu_place(place) || phi::is_custom_place(place)) {
110+
if (phi::is_gpu_place(place) || phi::is_custom_place(place) ||
111+
phi::is_xpu_place(place)) {
111112
VLOG(6) << "Parse DeviceContext for " << op_name
112113
<< ", execution stream = " << execution_stream;
113114
if (execution_stream != kDefaultStream) {
@@ -136,7 +137,7 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
136137
}
137138

138139
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
139-
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
140+
defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU_BKCL)
140141
// NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
141142
// with use_cal_stream==false by returning a device context getting from the
142143
// global NCCLCommContext instance. Because when use_calc_stream==false, in
@@ -205,7 +206,21 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
205206
op_name.compare(paddle::dialect::AllToAllOp::name()) == 0 ||
206207
op_name.compare(
207208
paddle::dialect::CSoftmaxWithCrossEntropyOp::name()) == 0) {
208-
#ifdef PADDLE_WITH_CUSTOM_DEVICE
209+
#if defined(PADDLE_WITH_XPU_BKCL)
210+
if (phi::is_xpu_place(place) && execution_stream == kDefaultStream) {
211+
VLOG(3) << "set stream for " << op_name << "in XPU device";
212+
if (origin_dev_ctx != nullptr) {
213+
// set stream
214+
auto default_stream =
215+
static_cast<DEVICE_CONTEXT*>(origin_dev_ctx)->stream();
216+
static_cast<DEVICE_CONTEXT*>(dev_ctx)->SetStream(default_stream);
217+
// todo set allocator
218+
} else {
219+
VLOG(3) << "CUSTOM DEVICE op " << op_name << " ring_id "
220+
<< ring_id << " origin_dev_ctx is nullptr";
221+
}
222+
}
223+
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
209224
if (phi::is_custom_place(place) &&
210225
execution_stream == kDefaultStream) {
211226
VLOG(3) << "set stream for " << op_name << "in Custom device";

paddle/phi/backends/xpu/xpu2_op_list.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ XPUOpMap& get_kl2_ops() {
12201220
phi::DataType::FLOAT32})},
12211221
{"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
12221222
{"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
1223-
{"weight_only_linear_xpu",
1223+
{"weight_only_linear",
12241224
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})},
12251225
{"where_index",
12261226
XPUKernelSet({phi::DataType::INT32,

paddle/phi/backends/xpu/xpu3_op_list.cc

+2
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,8 @@ XPUOpMap& get_kl3_ops() {
16901690
phi::DataType::BOOL,
16911691
phi::DataType::FLOAT32,
16921692
phi::DataType::INT64})},
1693+
{"weight_quantize",
1694+
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})},
16931695
{"where_grad",
16941696
XPUKernelSet({phi::DataType::INT32,
16951697
phi::DataType::INT64,

paddle/phi/kernels/fusion/xpu/spatial_transformer_resblock_xpu_kernel.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ void SpatialTransformerResblockXPUKernel(
6262
bool include_silu,
6363
DenseTensor* out,
6464
DenseTensor* out_max) {
65-
#ifdef PADDLE_WITH_XPU_XFT
65+
// not supported in current xft
66+
#if defined(PADDLE_WITH_XPU_XFT_NOT_SUPPORT)
6667
using XPUType = typename XPUTypeTrait<T>::Type;
6768

6869
auto* in1 = reinterpret_cast<const XPUType*>(x.data<T>());

paddle/phi/kernels/xpu/top_p_sampling_kernel.cc

+6-6
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ void TopPSamplingKernel(const Context& dev_ctx,
5858
auto x_dims = x.dims();
5959
int bs = x_dims[0];
6060
int vocab_size = x_dims[1];
61-
int p_num = ps.numel();
61+
// int p_num = ps.numel();
6262

63-
PADDLE_ENFORCE_EQ(
64-
p_num,
65-
bs,
66-
common::errors::PreconditionNotMet(
67-
"Expected bs == p_num, but got bs=%d, p_num=%d.", bs, p_num));
63+
// PADDLE_ENFORCE_EQ(
64+
// p_num,
65+
// bs,
66+
// common::errors::PreconditionNotMet(
67+
// "Expected bs == p_num, but got bs=%d, p_num=%d.", bs, p_num));
6868

6969
std::vector<int64_t> infer_seed(bs, random_seed);
7070
if (topp_seed.get_ptr() != nullptr) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_XPU_XFT)
#include <xft/xdnn_plugin.h>
#endif
#include "paddle/common/enforce.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {

// XPU implementation of the `weight_quantize` op.
//
// Quantizes a 2-D weight matrix `x` of shape [k, n] into int8 with one
// float scale per output channel (per-column), via XFT's
// xft_quant2d_per_channel kernel.
//
// Parameters:
//   dev_ctx    - XPU device context used for allocation and kernel launch.
//   x          - input weight, expected shape [k, n], dtype T
//                (float16/bfloat16 per the registration below).
//   algo       - quantization algorithm; only "weight_only_int8" is
//                supported here.
//   arch       - target architecture id; currently unused on XPU
//                (kept for interface parity with the GPU kernel).
//   group_size - group quantization size; only -1 (per-channel) is
//                supported here.
//   out        - output int8 tensor, resized to [k, n].
//   scale      - output float scales, resized to [n] (one per column).
//
// Throws Unimplemented for unsupported `algo`/`group_size`, or when Paddle
// was built without PADDLE_WITH_XPU_XFT.
template <typename T, typename Context>
void WeightQuantizeKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const std::string& algo,
                          const int32_t arch,
                          const int32_t group_size,
                          DenseTensor* out,
                          DenseTensor* scale) {
#if defined(PADDLE_WITH_XPU_XFT)
  using XPUType = typename XPUTypeTrait<T>::Type;
  auto xpu_ctx = static_cast<const phi::XPUContext*>(&dev_ctx);

  // The XFT quantizer operates on a [k, n] matrix; reject other ranks
  // instead of silently mis-sizing the outputs.
  PADDLE_ENFORCE_EQ(
      x.dims().size(),
      2,
      common::errors::InvalidArgument(
          "weight_quantize on XPU expects a 2-D weight, but got rank %d.",
          x.dims().size()));
  int k = static_cast<int>(x.dims()[0]);
  int n = static_cast<int>(x.dims()[1]);

  // One float scale per output channel (column).
  scale->Resize({static_cast<int64_t>(n)});
  dev_ctx.template Alloc<float>(scale);

  if (algo == "weight_only_int8") {
    // Only per-channel quantization is wired up; fail loudly rather than
    // ignore a grouped-quantization request.
    PADDLE_ENFORCE_EQ(
        group_size,
        -1,
        common::errors::Unimplemented(
            "weight_quantize on XPU only supports per-channel quantization "
            "(group_size == -1), but got group_size=%d.",
            group_size));
    out->Resize({static_cast<int64_t>(k), static_cast<int64_t>(n)});
    dev_ctx.template Alloc<int8_t>(out);

    int ret = baidu::xpu::xftkernel::xft_quant2d_per_channel<XPUType, float>(
        xpu_ctx->x_context(),
        reinterpret_cast<const XPUType*>(x.template data<T>()),
        nullptr,  // no precomputed max; let the kernel derive scales
        out->data<int8_t>(),
        scale->data<float>(),
        k,
        n);
    PADDLE_ENFORCE_XDNN_SUCCESS(ret, "quant2d");
  } else {
    PADDLE_THROW(common::errors::Unimplemented(
        "Weight quantize only supports weight_only_int8 on XPU now."));
  }
#else
  PADDLE_THROW(common::errors::Unimplemented(
      "weight_quantize is not supported since it's not "
      "compiled with XPU_XFT"));
#endif
}
}  // namespace phi

PD_REGISTER_KERNEL(weight_quantize,
                   XPU,
                   ALL_LAYOUT,
                   phi::WeightQuantizeKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

paddle/phi/ops/yaml/fused_ops.yaml

-10
Original file line numberDiff line numberDiff line change
@@ -828,16 +828,6 @@
828828
optional : mask
829829
support_dygraph_mode : true
830830

831-
- op : weight_only_linear_xpu
832-
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80, int group_size = -1)
833-
output : Tensor(out)
834-
infer_meta :
835-
func : WeightOnlyLinearInferMeta
836-
kernel :
837-
func : weight_only_linear_xpu
838-
data_type : x
839-
optional : bias
840-
841831
- op : yolo_box_xpu
842832
args : (Tensor x, Tensor x_max, Tensor grid, Tensor stride, Tensor anchor_grid, float offset)
843833
output : Tensor(out), Tensor(out_max)

0 commit comments

Comments
 (0)