[XPU] Support int31 weight dynamic quantization for fc and conv2d (#59981) #67058

Merged (1 commit) on Aug 7, 2024
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
@@ -763,6 +763,19 @@ void Conv2dXPUFusePass::CreateFusionWeightsAndBias(
                                false,
                                weight_scale,
                                true);
  } else if (quant_post_type.find("conv2d") != quant_post_type.end() &&
             quant_post_type.find("conv2d")->second == 4) {
    VLOG(5) << "Use int31 per-tensor weight";
    PrepareWeight<float, float>(graph,
                                scope,
                                block,
                                conv_filter_replicated_node,
                                &filter_intx,
                                &filter_max,
                                &scale_max,
                                false,
                                weight_scale,
                                false);
  } else if (quant_post_type.find("conv2d") != quant_post_type.end() &&
                 quant_post_type.find("conv2d")->second == 0 ||
             quant_post_type.find("conv2d") != quant_post_type.end() &&
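
The new branch keys off quant_post_type, a per-op-type map of quantization method ids. Per the diff, id 4 now selects the int31 path, which keeps the weights in fp32 and records only a per-tensor max; the trailing false argument disables per-channel quantization. The fc pass below adds the identical branch. A minimal sketch of this dispatch, with hypothetical names and only the id-4 mapping taken from the diff:

#include <map>
#include <string>

// Hypothetical helper, not Paddle code: decide whether an op's weights
// should take the int31 (fp32 passthrough, per-tensor max) path.
bool UseInt31Weights(const std::map<std::string, int>& quant_post_type,
                     const std::string& op_type) {
  auto it = quant_post_type.find(op_type);
  // Method id 4 == int31 per-tensor weight, per the branch added above.
  return it != quant_post_type.end() && it->second == 4;
}
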
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -572,6 +572,19 @@ void FcXPUFusePass::CreateFusionWeightsAndBias(
                                !transpose_w,
                                weight_scale,
                                true);
  } else if (quant_post_type.find("fc") != quant_post_type.end() &&
             quant_post_type.find("fc")->second == 4) {
    VLOG(5) << "Use int31 per-tensor weight";
    PrepareWeight<float, float>(graph,
                                scope,
                                block,
                                mul_w_replicated_node,
                                &filter_intx,
                                &filter_max,
                                &scale_max,
                                !transpose_w,
                                weight_scale,
                                false);
  } else if (quant_post_type.find("fc") != quant_post_type.end() &&
                 quant_post_type.find("fc")->second == 0 ||
             quant_post_type.find("fc") != quant_post_type.end() &&
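
Relative to the conv2d branch, the fc version differs only in the weight node (mul_w_replicated_node) and in passing !transpose_w, so the weight matrix can be transposed into the layout the XPU GEMM expects (ConvertWithoutQuant, shown later, applies Transpose2D when that flag is set). A standalone sketch of such a row-major 2D transpose; Paddle's own Transpose2D works in place on a DenseTensor, while this illustrative version uses plain vectors:

#include <vector>

// Illustrative only: transpose a rows x cols row-major matrix.
std::vector<float> Transpose2D(const std::vector<float>& src,
                               int rows, int cols) {
  std::vector<float> dst(src.size());
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      dst[c * rows + r] = src[r * cols + c];
    }
  }
  return dst;
}
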
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/xpu/pass_utils.cc
@@ -256,6 +256,18 @@ void PrepareWeight(Graph* graph,
  }
}

template void PrepareWeight<float, float>(
    Graph* graph,
    Scope* scope,
    BlockDesc* block,
    Node* weight,
    Node** dst_weight,
    Node** dst_weight_max,
    Node** dst_scale_max,
    bool transpose,
    const std::vector<float>& weight_scales,
    bool per_channel_quant = false);

template void PrepareWeight<float, int16_t>(
    Graph* graph,
    Scope* scope,
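
PrepareWeight is a function template defined in pass_utils.cc rather than in a header, so every type pair the fuse passes use must be explicitly instantiated there; the new <float, float> instantiation is what lets the int31 branches above link. A generic reminder of the pattern, with made-up names:

// prepare.h
template <typename T>
void Prepare(const T* src, T* dst, int n);

// prepare.cc
template <typename T>
void Prepare(const T* src, T* dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] = src[i];
}
// Without this line, callers in other translation units would hit
// undefined-symbol link errors for Prepare<float>.
template void Prepare<float>(const float*, float*, int);
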
53 changes: 47 additions & 6 deletions paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -245,6 +245,16 @@ static void QuantFP32ToIntX(const float* src_ptr,
  LOG(FATAL) << "Not support.";
}

template <>
void QuantFP32ToIntX<float>(const float* src_ptr,
                            float* dst_ptr,
                            float max_val,
                            int numel) {
  for (int i = 0; i < numel; i++) {
    dst_ptr[i] = static_cast<float>(src_ptr[i]);
  }
}

template <>
void QuantFP32ToIntX<int16_t>(const float* src_ptr,
                              int16_t* dst_ptr,
@@ -364,16 +374,16 @@ void ConvertWithoutQuant(phi::DenseTensor* weight,
                          phi::DenseTensor* weight_max,
                          phi::DenseTensor* scale_max,
                          bool transpose,
                          const std::vector<float>& weight_scales) {
-  PADDLE_ENFORCE_EQ(
-      weight_scales.empty(),
-      false,
-      platform::errors::InvalidArgument(
-          "ConvertWithoutQuant is not allowed weight scales is empty!"));
   if (transpose) {
     Transpose2D(weight);
   }
   bool per_tensor_quant = weight_scales.size() == 1;
   if (std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value) {
+    PADDLE_ENFORCE_EQ(
+        weight_scales.empty(),
+        false,
+        platform::errors::InvalidArgument(
+            "ConvertWithoutQuant is not allowed weight scales is empty!"));
     auto* cpu_ctx = static_cast<phi::CPUContext*>(
         platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
     if (per_tensor_quant) {
@@ -400,8 +410,32 @@
                  weight_scales.data(),
                  weight_scales.size() * sizeof(float));
     }
+  } else if (std::is_same<T, float>::value) {
+    // Convert fp16 to fp32
+    phi::DenseTensor weight_fp32;
+    CastToFp32(weight, &weight_fp32);
+    // Find max
+    int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
+    int size = weight_fp32.numel();
+    auto* weight_data = weight_fp32.data<float>();
+    float max_val = FindMaxAbs(weight_data, size);
+    std::vector<float> max_vec(max_ptr_size, max_val);
+    weight_max->set_type(phi::DataType::FLOAT32);
+    weight_max->Resize({max_ptr_size});
+    auto* cpu_ctx = static_cast<phi::CPUContext*>(
+        platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+    memcpy(cpu_ctx->Alloc<float>(weight_max),
+           max_vec.data(),
+           max_ptr_size * sizeof(float));
+
+    // Quant
+    weight->set_type(phi::DataType::FLOAT32);
+    weight->Resize(weight_fp32.dims());
+    QuantFP32ToIntX<float>(
+        weight_data, cpu_ctx->Alloc<float>(weight), max_val, size);
   } else {
-    LOG(FATAL) << "Only support int8<->int8 and int16<->int16 convert.";
+    LOG(FATAL)
+        << "Only support float<->int31, int8<->int8 and int16<->int16 convert.";
   }
 }

Expand All @@ -424,6 +458,13 @@ template void ConvertWithoutQuant<int8_t>(
bool transpose,
const std::vector<float>& weight_scales);

template void ConvertWithoutQuant<float>(
phi::DenseTensor* weight,
phi::DenseTensor* weight_max,
phi::DenseTensor* scale_max,
bool transpose,
const std::vector<float>& weight_scales);

bool IsPerTensorQuant(const std::vector<float>& weight_max) {
bool per_tensor = true;
PADDLE_ENFORCE_GT(
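
For int31, "quantization" is a pass-through: QuantFP32ToIntX<float> copies the fp32 values unchanged, and the only work in the new ConvertWithoutQuant<float> branch is computing the per-tensor absolute max and replicating it to the XPU max-pointer width so the kernel can scale internally. A self-contained sketch of that bookkeeping; max_ptr_size comes from get_xpu_max_ptr_size in the real code, while here it is a plain parameter:

#include <cmath>
#include <vector>

// Compute the per-tensor |max| of an fp32 weight and replicate it
// max_ptr_size times, mirroring the memcpy into weight_max above.
std::vector<float> Int31WeightMax(const float* weight, int numel,
                                  int max_ptr_size) {
  float max_val = 0.0f;
  for (int i = 0; i < numel; ++i) {
    max_val = std::fmax(max_val, std::fabs(weight[i]));
  }
  return std::vector<float>(max_ptr_size, max_val);
}
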
2 changes: 2 additions & 0 deletions paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc
@@ -221,6 +221,8 @@ void Conv2dXPUKernel(const Context& ctx,
          DataTypeToString(filter.dtype()),
          DataTypeToString(out_dtype)));
    }
  } else if (filter.dtype() == DataType::FLOAT32) {
    CONV2D_XPU_KERNEL_IMPL(float, float, float, int32_t);
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.",
2 changes: 2 additions & 0 deletions paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
@@ -165,6 +165,8 @@ void FcXPUKernel(const Context& ctx,
          DataTypeToString(w.dtype()),
          DataTypeToString(out_dtype)));
    }
  } else if (w.dtype() == DataType::FLOAT32) {
    FC_XPU_KERNEL_IMPL(float, float, float, int32_t);
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.",
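
Both kernels gain the same dispatch arm: fp32 weights now route to an implementation instantiated with int32_t as the GEMM type, which, per the fuse-pass log messages, corresponds to the int31 compute path (fp32 data in, higher-precision integer compute inside the device library). A hedged sketch of the shape of that dispatch, with hypothetical helper names standing in for the *_XPU_KERNEL_IMPL macros:

#include <cstdint>
#include <stdexcept>

enum class WeightDType { kFloat32, kInt16 };

// Stand-in for the templated kernel body the macros invoke; TW is the
// weight type, TGEMM the GEMM/accumulation tag.
template <typename TW, typename TGEMM>
void RunFusedXPUOp() { /* call into the device library */ }

void Dispatch(WeightDType w) {
  switch (w) {
    case WeightDType::kFloat32:
      RunFusedXPUOp<float, int32_t>();  // int31 path, as added above
      break;
    case WeightDType::kInt16:
      RunFusedXPUOp<int16_t, int16_t>();
      break;
    default:
      throw std::runtime_error("unsupported weight dtype");
  }
}
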